fix: include go.mod in vendored oryx

GitOrigin-RevId: 20365bbe6b2cf95ac7973bcca9056455d2cb3803
This commit is contained in:
Henning Perl 2025-07-07 19:13:04 +02:00 committed by ory-bot
parent 8967cc7e31
commit 7c0d9c6ddc
182 changed files with 6 additions and 17357 deletions

View File

@ -1,2 +1,8 @@
"module name","licenses"
"github.com/arbovm/levenshtein","BSD-3-Clause"
"github.com/ory/x","Apache-2.0"
"github.com/stretchr/testify","MIT"
"go.opentelemetry.io/otel/sdk","Apache-2.0"
"golang.org/x/text","BSD-3-Clause"

1 module name licenses
2 github.com/arbovm/levenshtein BSD-3-Clause
3 github.com/ory/x Apache-2.0
4 github.com/stretchr/testify MIT
5 go.opentelemetry.io/otel/sdk Apache-2.0
6 golang.org/x/text BSD-3-Clause
7
8

View File

View File

@ -1,20 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package assertx
import (
"testing"
"time"
)
func TestEqualAsJSONExcept(t *testing.T) {
a := map[string]interface{}{"foo": "bar", "baz": "bar", "bar": "baz"}
b := map[string]interface{}{"foo": "bar", "baz": "bar", "bar": "not-baz"}
EqualAsJSONExcept(t, a, b, []string{"bar"})
}
func TestTimeDifferenceLess(t *testing.T) {
TimeDifferenceLess(t, time.Now(), time.Now().Add(time.Second), 2)
}

View File

@ -1,57 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package castx
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestToFloatSliceE(t *testing.T) {
tests := []struct {
input interface{}
expect []float64
iserr bool
}{
{[]int{1, 3}, []float64{1, 3}, false},
{[]interface{}{1.2, 3.2}, []float64{1.2, 3.2}, false},
{[]string{"2", "3"}, []float64{2, 3}, false},
{[]string{"2.2", "3.2"}, []float64{2.2, 3.2}, false},
{[2]string{"2", "3"}, []float64{2, 3}, false},
{[2]string{"2.2", "3.2"}, []float64{2.2, 3.2}, false},
// errors
{nil, nil, true},
{testing.T{}, nil, true},
{[]string{"foo", "bar"}, nil, true},
}
for i, test := range tests {
errmsg := fmt.Sprintf("i = %d", i) // assert helper message
v, err := ToFloatSliceE(test.input)
if test.iserr {
assert.Error(t, err, errmsg)
continue
}
assert.NoError(t, err, errmsg)
assert.Equal(t, test.expect, v, errmsg)
// Non-E test
v = ToFloatSlice(test.input)
assert.Equal(t, test.expect, v, errmsg)
}
}
func TestToStringSlice(t *testing.T) {
assert.Equal(t, []string{"foo", "bar"}, ToStringSlice("foo,bar"))
assert.NotEqual(t, []string{"foo bar baz"}, ToStringSlice("foo bar baz,"))
assert.Equal(t, []string{"foo bar baz", ""}, ToStringSlice("foo bar baz,"))
assert.NotEqual(t, []string{"foo", "bar", "baz"}, ToStringSlice("foo bar baz"))
assert.Equal(t, []string{"foo bar baz"}, ToStringSlice("foo bar baz"))
assert.Equal(t, []string{"foo", "bar", "baz,", " asdf"}, ToStringSlice("foo,bar,\"baz,\", asdf"))
assert.Equal(t, []string{"'foo'", "x\"bar", "baz"}, ToStringSlice("'foo',\"x\"\"bar\",baz"))
}

View File

@ -1,95 +0,0 @@
package clidoc
import (
"bytes"
"io/fs"
"os"
"path/filepath"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func noopRun(_ *cobra.Command, _ []string) {}
var (
root = &cobra.Command{Use: "root", Run: noopRun, Long: `A sample text
root
<[some argument]>
`}
child1 = &cobra.Command{Use: "child1", Run: noopRun, Long: `A sample text
child1
<[some argument]>
`, Example: "{{ .CommandPath }} --whatever"}
child2 = &cobra.Command{Use: "child2", Run: noopRun, Long: `A sample text
child2
<[some argument]>
`}
subChild1 = &cobra.Command{Use: "subChild1 <args>", Run: noopRun, Long: `A sample text
subChild1
<[some argument]>
`}
)
func snapshotDir(t *testing.T, path ...string) (assertNoChange func(t *testing.T)) {
var (
as []func(*testing.T)
fps []string
)
require.NoError(t, filepath.WalkDir(filepath.Join(path...), func(path string, d fs.DirEntry, err error) error {
require.NoError(t, err, path)
if !d.IsDir() {
fps = append(fps, path)
as = append(as, snapshotFile(t, path))
}
return nil
}))
return func(t *testing.T) {
fileN := 0
require.NoError(t, filepath.WalkDir(filepath.Join(path...), func(path string, d fs.DirEntry, err error) error {
require.NoError(t, err)
if !d.IsDir() {
assert.Contains(t, fps, path)
fileN++
}
return nil
}))
assert.Equal(t, len(fps), fileN)
for _, a := range as {
a(t)
}
}
}
func snapshotFile(t *testing.T, path ...string) (assertNoChange func(t *testing.T)) {
pre, err := os.ReadFile(filepath.Join(path...))
require.NoError(t, err)
pre = bytes.ReplaceAll(pre, []byte("\r\n"), []byte("\n"))
return func(t *testing.T) {
post, err := os.ReadFile(filepath.Join(path...))
require.NoError(t, err)
assert.Equal(t, string(pre), string(post), "%s", post)
}
}
func init() {
child1.AddCommand(subChild1)
root.AddCommand(child1, child2)
}
func TestGenerate(t *testing.T) {
assertNoChange := snapshotDir(t, "testdata")
require.NoError(t, Generate(root, []string{"testdata"}))
assertNoChange(t)
}

View File

@ -1,14 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package cmdx
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestEnvVarExamplesHelpMessage(t *testing.T) {
assert.NotEmpty(t, EnvVarExamplesHelpMessage(""))
}

View File

@ -1,82 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package cmdx
import (
"bytes"
"fmt"
"strings"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConditionalPrinter(t *testing.T) {
const (
msgAlwaysOut = "always out"
msgAlwaysErr = "always err"
msgQuietOut = "quiet out"
msgQuietErr = "quiet err"
msgLoudOut = "loud out"
msgLoudErr = "loud err"
msgArgsSet = "args were set"
)
setup := func() *cobra.Command {
cmd := &cobra.Command{
Use: "test cmd",
Run: func(cmd *cobra.Command, args []string) {
_, _ = fmt.Fprint(cmd.OutOrStdout(), msgAlwaysOut)
_, _ = fmt.Fprint(cmd.ErrOrStderr(), msgAlwaysErr)
_, _ = NewQuietOutPrinter(cmd).Print(msgQuietOut)
_, _ = NewQuietErrPrinter(cmd).Print(msgQuietErr)
_, _ = NewLoudOutPrinter(cmd).Print(msgLoudOut)
_, _ = NewLoudErrPrinter(cmd).Print(msgLoudErr)
_, _ = NewConditionalPrinter(cmd.OutOrStdout(), len(args) > 0).Print(msgArgsSet)
},
}
RegisterNoiseFlags(cmd.Flags())
return cmd
}
for _, tc := range []struct {
stdErrMsg, stdOutMsg, args []string
setQuiet bool
}{
{
stdOutMsg: []string{msgLoudOut},
stdErrMsg: []string{msgLoudErr},
setQuiet: false,
args: []string{},
},
{
stdOutMsg: []string{msgQuietOut},
stdErrMsg: []string{msgQuietErr},
setQuiet: true,
args: []string{},
},
{
stdOutMsg: []string{msgQuietOut, msgArgsSet},
stdErrMsg: []string{msgQuietErr},
setQuiet: true,
args: []string{"foo"},
},
} {
t.Run(fmt.Sprintf("case=quiet:%v", tc.setQuiet), func(t *testing.T) {
cmd := setup()
if tc.setQuiet {
require.NoError(t, cmd.Flags().Set(FlagQuiet, "true"))
}
out, err := &bytes.Buffer{}, &bytes.Buffer{}
cmd.SetOut(out)
cmd.SetErr(err)
cmd.SetArgs(tc.args)
require.NoError(t, cmd.Execute())
assert.Equal(t, strings.Join(append([]string{msgAlwaysOut}, tc.stdOutMsg...), ""), out.String())
assert.Equal(t, strings.Join(append([]string{msgAlwaysErr}, tc.stdErrMsg...), ""), err.String())
})
}
}

View File

@ -1,40 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package cmdx
import (
"bytes"
"io"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPagination(t *testing.T) {
cmd := &cobra.Command{}
cmd.SetErr(io.Discard)
page, perPage, err := ParsePaginationArgs(cmd, "1", "2")
require.NoError(t, err)
assert.EqualValues(t, 1, page)
assert.EqualValues(t, 2, perPage)
_, _, err = ParsePaginationArgs(cmd, "abcd", "")
require.Error(t, err)
}
func TestTokenPagination(t *testing.T) {
var stderr bytes.Buffer
cmd := &cobra.Command{}
cmd.SetErr(&stderr)
RegisterTokenPaginationFlags(cmd)
require.NoError(t, cmd.Flags().Set(FlagPageToken, "1"))
require.NoError(t, cmd.Flags().Set(FlagPageSize, "2"))
page, perPage, err := ParseTokenPaginationArgs(cmd)
require.NoError(t, err, stderr.String())
assert.EqualValues(t, "1", page)
assert.EqualValues(t, 2, perPage)
}

View File

@ -1,363 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package cmdx
import (
"bytes"
"fmt"
"slices"
"strconv"
"testing"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type (
dynamicTable struct {
t [][]string
cs int
}
dynamicIDAbleTable struct {
*dynamicTable
idColumn int
}
dynamicRow []string
dynamicIDAbleRow struct {
dynamicRow
idColumn int
}
)
var (
_ Table = (*dynamicTable)(nil)
_ Table = (*dynamicIDAbleTable)(nil)
_ TableRow = (dynamicRow)(nil)
_ TableRow = (*dynamicIDAbleRow)(nil)
)
func dynamicHeader(l int) []string {
h := make([]string, l)
for i := range h {
h[i] = "C" + strconv.Itoa(i)
}
return h
}
func (d *dynamicTable) Header() []string {
return dynamicHeader(d.cs)
}
func (d *dynamicTable) Table() [][]string {
return d.t
}
func (d *dynamicTable) Interface() interface{} {
return d.t
}
func (d *dynamicIDAbleTable) IDs() []string {
ids := make([]string, d.Len())
for i, row := range d.Table() {
ids[i] = row[d.idColumn]
}
return ids
}
func (d *dynamicTable) Len() int {
return len(d.t)
}
func (d dynamicRow) Header() []string {
return dynamicHeader(len(d))
}
func (d dynamicRow) Columns() []string {
return d
}
func (d dynamicRow) Interface() interface{} {
return d
}
func (d *dynamicIDAbleRow) ID() string {
return d.dynamicRow[d.idColumn]
}
func TestPrinting(t *testing.T) {
t.Run("case=format flags", func(t *testing.T) {
t.Run("format=no value", func(t *testing.T) {
flags := pflag.NewFlagSet("test flags", pflag.ContinueOnError)
RegisterFormatFlags(flags)
require.NoError(t, flags.Parse([]string{}))
f, err := flags.GetString(FlagFormat)
require.NoError(t, err)
assert.Equal(t, FormatDefault, format(f))
})
})
t.Run("method=table row", func(t *testing.T) {
t.Run("case=all formats", func(t *testing.T) {
tr := dynamicRow{"AAA", "BBB", "CCC"}
allFields := append(tr.Header(), tr...)
for _, tc := range []struct {
fArgs []string
contained []string
}{
{
fArgs: []string{"--" + FlagFormat, string(FormatTable)},
contained: allFields,
},
{
fArgs: []string{"--" + FlagQuiet},
contained: []string{tr[0]},
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSON)},
contained: tr,
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSONPretty)},
contained: tr,
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSONPointer) + "=/0"},
contained: []string{"AAA"},
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSONPointer) + "=/2"},
contained: []string{"CCC"},
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSONPointer) + "=/1"},
contained: []string{"BBB"},
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSONPath) + "=0"},
contained: []string{"AAA"},
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSONPath) + "=2"},
contained: []string{"CCC"},
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSONPath) + "=[0,1]"},
contained: []string{"AAA", "BBB"},
},
{
fArgs: []string{"--" + FlagFormat, string(FormatYAML)},
contained: tr,
},
} {
t.Run(fmt.Sprintf("format=%v", tc.fArgs), func(t *testing.T) {
cmd := &cobra.Command{Use: "x"}
RegisterFormatFlags(cmd.Flags())
out := &bytes.Buffer{}
cmd.SetOut(out)
require.NoError(t, cmd.Flags().Parse(tc.fArgs))
PrintRow(cmd, tr)
for _, s := range tc.contained {
assert.Contains(t, out.String(), s, "%s", out.String())
}
notContained := slices.DeleteFunc(slices.Clone(allFields), func(s string) bool {
return slices.Contains(tc.contained, s)
})
for _, s := range notContained {
assert.NotContains(t, out.String(), s, "%s", out.String())
}
assert.Equal(t, "\n", out.String()[len(out.String())-1:])
})
}
})
t.Run("case=uses ID()", func(t *testing.T) {
tr := &dynamicIDAbleRow{
dynamicRow: []string{"foo", "bar"},
idColumn: 1,
}
cmd := &cobra.Command{Use: "x"}
RegisterFormatFlags(cmd.Flags())
out := &bytes.Buffer{}
cmd.SetOut(out)
require.NoError(t, cmd.Flags().Parse([]string{"--" + FlagQuiet}))
PrintRow(cmd, tr)
assert.Equal(t, tr.dynamicRow[1]+"\n", out.String())
})
})
t.Run("method=table", func(t *testing.T) {
t.Run("case=full table", func(t *testing.T) {
tb := &dynamicTable{
t: [][]string{
{"a0", "b0", "c0"},
{"a1", "b1", "c1"},
},
cs: 3,
}
allFields := append(tb.Header(), append(tb.t[0], tb.t[1]...)...)
for _, tc := range []struct {
fArgs []string
contained []string
}{
{
fArgs: []string{"--" + FlagFormat, string(FormatTable)},
contained: allFields,
},
{
fArgs: []string{"--" + FlagQuiet},
contained: []string{tb.t[0][0], tb.t[1][0]},
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSON)},
contained: append(tb.t[0], tb.t[1]...),
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSONPretty)},
contained: append(tb.t[0], tb.t[1]...),
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSONPath) + "=1.1"},
contained: []string{tb.t[1][1]},
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSONPointer) + "=/1/1"},
contained: []string{tb.t[1][1]},
},
{
fArgs: []string{"--" + FlagFormat, string(FormatYAML)},
contained: append(tb.t[0], tb.t[1]...),
},
} {
t.Run(fmt.Sprintf("format=%v", tc.fArgs), func(t *testing.T) {
cmd := &cobra.Command{Use: "x"}
RegisterFormatFlags(cmd.Flags())
out := &bytes.Buffer{}
cmd.SetOut(out)
require.NoError(t, cmd.Flags().Parse(tc.fArgs))
PrintTable(cmd, tb)
for _, s := range tc.contained {
assert.Contains(t, out.String(), s, "%s", out.String())
}
notContained := slices.DeleteFunc(slices.Clone(allFields), func(s string) bool {
return slices.Contains(tc.contained, s)
})
for _, s := range notContained {
assert.NotContains(t, out.String(), s, "%s", out.String())
}
assert.Equal(t, "\n", out.String()[len(out.String())-1:])
})
}
})
t.Run("case=empty table", func(t *testing.T) {
tb := &dynamicTable{
t: nil,
cs: 1,
}
for _, tc := range []struct {
fArgs []string
expected string
}{
{
fArgs: []string{"--" + FlagFormat, string(FormatTable)},
expected: "C0\t",
},
{
fArgs: []string{"--" + FlagQuiet},
expected: "",
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSON)},
expected: "null",
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSONPretty)},
expected: "null",
},
{
fArgs: []string{"--" + FlagFormat, string(FormatJSONPath) + "=foo"},
expected: "null",
},
{
fArgs: []string{"--" + FlagFormat, string(FormatYAML)},
expected: "null",
},
} {
t.Run(fmt.Sprintf("format=%v", tc.fArgs), func(t *testing.T) {
cmd := &cobra.Command{Use: "x"}
RegisterFormatFlags(cmd.Flags())
out := &bytes.Buffer{}
cmd.SetOut(out)
require.NoError(t, cmd.Flags().Parse(tc.fArgs))
PrintTable(cmd, tb)
assert.Equal(t, tc.expected+"\n", out.String())
})
}
})
t.Run("case=uses IDs()", func(t *testing.T) {
tb := &dynamicIDAbleTable{
dynamicTable: &dynamicTable{
t: [][]string{
{"a0", "b0", "c0"},
{"a1", "b1", "c1"},
},
cs: 3,
},
idColumn: 1,
}
cmd := &cobra.Command{Use: "x"}
RegisterFormatFlags(cmd.Flags())
out := &bytes.Buffer{}
cmd.SetOut(out)
require.NoError(t, cmd.Flags().Parse([]string{"--" + FlagQuiet}))
PrintTable(cmd, tb)
assert.Equal(t, tb.t[0][1]+"\n"+tb.t[1][1]+"\n", out.String())
})
})
t.Run("method=jsonable", func(t *testing.T) {
t.Run("case=nil", func(t *testing.T) {
for _, f := range []format{FormatDefault, FormatJSON, FormatJSONPretty, FormatJSONPath, FormatJSONPointer, FormatYAML} {
t.Run("format="+string(f), func(t *testing.T) {
out := &bytes.Buffer{}
cmd := &cobra.Command{}
cmd.SetOut(out)
RegisterJSONFormatFlags(cmd.Flags())
PrintJSONAble(cmd, nil)
assert.Equal(t, "null", out.String())
})
}
})
})
}

View File

@ -1,84 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package cmdx
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)
func TestUsageTemplating(t *testing.T) {
root := &cobra.Command{
Use: "root",
Short: "{{ .Name }}",
}
cmdWithTemplate := &cobra.Command{
Use: "with-template",
Long: "{{ .Name }}",
Example: "{{ .Name }}",
}
cmdWithoutTemplate := &cobra.Command{
Use: "without-template",
Long: "{{ .Name }}",
Example: "{{ .Name }}",
}
root.AddCommand(cmdWithTemplate, cmdWithoutTemplate)
EnableUsageTemplating(root)
DisableUsageTemplating(cmdWithoutTemplate)
assert.NotContains(t, root.UsageString(), "{{ .Name }}")
assert.NotContains(t, cmdWithTemplate.UsageString(), "{{ .Name }}")
assert.Contains(t, cmdWithoutTemplate.UsageString(), "{{ .Name }}")
}
func TestAssertUsageTemplates(t *testing.T) {
var cmdsCalled []string
AddUsageTemplateFunc("called", func(use string) string {
cmdsCalled = append(cmdsCalled, use)
return use
})
root := &cobra.Command{
Use: "root",
Short: "{{ called .Use }}",
}
child := &cobra.Command{
Use: "child",
Long: "{{ called .Use }}",
}
otherChild := &cobra.Command{
Use: "other-child",
Example: "{{ called .Use }}",
}
childChild := &cobra.Command{
Use: "child-child",
Example: "{{ called .Use }}",
}
root.AddCommand(child, otherChild)
child.AddCommand(childChild)
EnableUsageTemplating(root)
require.NotPanics(t, func() {
AssertUsageTemplates(&panicT{}, root)
})
assert.ElementsMatch(t, []string{root.Use, child.Use, otherChild.Use, childChild.Use}, cmdsCalled)
}
type panicT struct{}
func (t *panicT) FailNow() {
panic("failing")
}
func (*panicT) Errorf(format string, args ...interface{}) {
panic("erroring: " + fmt.Sprintf(format, args...))
}
var _ require.TestingT = (*panicT)(nil)

View File

@ -1,81 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package cmdx
import (
"bytes"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAskForConfirmation(t *testing.T) {
t.Run("case=prints question", func(t *testing.T) {
testQuestion := "test-question"
stdin, stdout := new(bytes.Buffer), new(bytes.Buffer)
_, err := stdin.Write([]byte("y\n"))
require.NoError(t, err)
AskForConfirmation(testQuestion, stdin, stdout)
prompt, err := io.ReadAll(stdout)
require.NoError(t, err)
assert.Contains(t, string(prompt), testQuestion)
})
t.Run("case=accept", func(t *testing.T) {
for _, input := range []string{
"y\n",
"yes\n",
} {
stdin := new(bytes.Buffer)
_, err := stdin.Write([]byte(input))
require.NoError(t, err)
confirmed := AskForConfirmation("", stdin, new(bytes.Buffer))
assert.True(t, confirmed)
}
})
t.Run("case=reject", func(t *testing.T) {
for _, input := range []string{
"n\n",
"no\n",
} {
stdin := new(bytes.Buffer)
_, err := stdin.Write([]byte(input))
require.NoError(t, err)
confirmed := AskForConfirmation("", stdin, new(bytes.Buffer))
assert.False(t, confirmed)
}
})
t.Run("case=reprompt on random input", func(t *testing.T) {
testQuestion := "question"
for _, input := range []string{
"foo\ny\n",
"bar\nn\n",
} {
stdin, stdout := new(bytes.Buffer), new(bytes.Buffer)
_, err := stdin.Write([]byte(input))
require.NoError(t, err)
AskForConfirmation(testQuestion, stdin, stdout)
output, err := io.ReadAll(stdout)
require.NoError(t, err)
assert.Equal(t, 2, bytes.Count(output, []byte(testQuestion)))
}
})
}

View File

@ -1,34 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package configx
import (
"context"
_ "embed"
"testing"
"github.com/dgraph-io/ristretto/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
//go:embed stub/kratos/config.schema.json
var kratosSchema []byte
func TestNewKoanfEnvCache(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ref, compiler, err := newCompiler(kratosSchema)
require.NoError(t, err)
schema, err := compiler.Compile(ctx, ref)
require.NoError(t, err)
c := *schemaPathCacheConfig
c.Metrics = true
schemaPathCache, _ = ristretto.NewCache(&c)
_, _ = NewKoanfEnv("", kratosSchema, schema)
_, _ = NewKoanfEnv("", kratosSchema, schema)
assert.EqualValues(t, 1, schemaPathCache.Metrics.Hits())
}

View File

@ -1,90 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package configx
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/ghodss/yaml"
"github.com/pelletier/go-toml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestKoanfFile(t *testing.T) {
setupFile := func(t *testing.T, fn, fc, subKey string) *KoanfFile {
dir := t.TempDir()
fn = filepath.Join(dir, fn)
require.NoError(t, os.WriteFile(fn, []byte(fc), 0600))
kf, err := NewKoanfFileSubKey(fn, subKey)
require.NoError(t, err)
return kf
}
t.Run("case=reads json root file", func(t *testing.T) {
v := map[string]interface{}{
"foo": "bar",
}
encV, err := json.Marshal(v)
require.NoError(t, err)
kf := setupFile(t, "config.json", string(encV), "")
actual, err := kf.Read()
require.NoError(t, err)
assert.Equal(t, v, actual)
})
t.Run("case=reads yaml root file", func(t *testing.T) {
v := map[string]interface{}{
"foo": "yaml string",
}
encV, err := yaml.Marshal(v)
require.NoError(t, err)
kf := setupFile(t, "config.yml", string(encV), "")
actual, err := kf.Read()
require.NoError(t, err)
assert.Equal(t, v, actual)
})
t.Run("case=reads toml root file", func(t *testing.T) {
v := map[string]interface{}{
"foo": "toml string",
}
encV, err := toml.Marshal(v)
require.NoError(t, err)
kf := setupFile(t, "config.toml", string(encV), "")
actual, err := kf.Read()
require.NoError(t, err)
assert.Equal(t, v, actual)
})
t.Run("case=reads json file as subkey", func(t *testing.T) {
v := map[string]interface{}{
"bar": "asdf",
}
encV, err := json.Marshal(v)
require.NoError(t, err)
kf := setupFile(t, "config.json", string(encV), "parent.of.config")
actual, err := kf.Read()
require.NoError(t, err)
assert.Equal(t, map[string]interface{}{
"parent": map[string]interface{}{
"of": map[string]interface{}{
"config": v,
},
},
}, actual)
})
}

View File

@ -1,30 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package configx
import (
stdjson "encoding/json"
"testing"
"github.com/knadh/koanf/parsers/json"
"github.com/knadh/koanf/providers/rawbytes"
"github.com/knadh/koanf/v2"
)
func TestKoanfMergeArray(t *testing.T) {
k := koanf.NewWithConf(koanf.Conf{Delim: Delimiter, StrictMerge: true})
if err := k.Load(rawbytes.Provider([]byte(`{"foo":[{"id":"bar"}]}`)), json.Parser()); err != nil {
t.Fatal(err)
}
if err := k.Load(rawbytes.Provider([]byte(`{"foo":[{"key":"baz"},{"baz":"bar"}]}`)), json.Parser(), koanf.WithMergeFunc(MergeAllTypes)); err != nil {
t.Fatal(err)
}
expected := `{"foo":[{"id":"bar","key":"baz"},{"baz":"bar"}]}`
out, _ := stdjson.Marshal(k.All())
if string(out) != expected {
t.Fatalf("Expected %s but got: %s", expected, out)
}
}

View File

@ -1,30 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package configx
import (
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"github.com/ory/x/assertx"
)
func TestKoanfMemory(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
doc := []byte(`{
"foo": {
"bar": "baz"
}
}`)
kf := NewKoanfMemory(ctx, doc)
actual, err := kf.Read()
require.NoError(t, err)
assertx.EqualAsJSON(t, json.RawMessage(doc), actual)
}

View File

@ -1,43 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package configx
import (
"bytes"
"context"
"os"
"path"
"testing"
"github.com/stretchr/testify/require"
"github.com/ory/jsonschema/v3"
"github.com/ory/x/snapshotx"
)
func TestKoanfSchemaDefaults(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schemaPath := path.Join("stub", "domain-aliases", "config.schema.json")
rawSchema, err := os.ReadFile(schemaPath)
require.NoError(t, err)
c := jsonschema.NewCompiler()
require.NoError(t, c.AddResource(schemaPath, bytes.NewReader(rawSchema)))
schema, err := c.Compile(ctx, schemaPath)
require.NoError(t, err)
k, err := newKoanf(ctx, schemaPath, nil)
require.NoError(t, err)
def, err := NewKoanfSchemaDefaults(rawSchema, schema)
require.NoError(t, err)
require.NoError(t, k.Load(def, nil))
snapshotx.SnapshotT(t, k.All())
}

View File

@ -1,128 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package configx
import (
"context"
"fmt"
"os"
"path"
"testing"
"github.com/spf13/pflag"
"github.com/dgraph-io/ristretto/v2"
"github.com/stretchr/testify/require"
)
func newKoanf(ctx context.Context, schemaPath string, configPaths []string, modifiers ...OptionModifier) (*Provider, error) {
schema, err := os.ReadFile(schemaPath)
if err != nil {
return nil, err
}
f := pflag.NewFlagSet("config", pflag.ContinueOnError)
f.StringSliceP("config", "c", configPaths, "")
modifiers = append(modifiers, WithFlags(f))
k, err := New(ctx, schema, modifiers...)
if err != nil {
return nil, err
}
return k, nil
}
func setEnvs(t testing.TB, envs [][2]string) {
for _, v := range envs {
require.NoError(t, os.Setenv(v[0], v[1]))
}
t.Cleanup(func() {
for _, v := range envs {
_ = os.Unsetenv(v[0])
}
})
}
func BenchmarkNewKoanf(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
setEnvs(b, [][2]string{{"MUTATORS_HEADER_ENABLED", "true"}})
schemaPath := path.Join("stub/benchmark/schema.config.json")
for i := 0; i < b.N; i++ {
_, err := newKoanf(ctx, schemaPath, []string{}, WithValues(map[string]interface{}{
"dsn": "memory",
}))
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkKoanf(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
setEnvs(b, [][2]string{{"MUTATORS_HEADER_ENABLED", "true"}})
schemaPath := path.Join("stub/benchmark/schema.config.json")
k, err := newKoanf(ctx, schemaPath, []string{"stub/benchmark/benchmark.yaml"})
require.NoError(b, err)
keys := k.Koanf.Keys()
numKeys := len(keys)
b.Run("cache=false", func(b *testing.B) {
var key string
b.ResetTimer()
for i := 0; i < b.N; i++ {
key = keys[i%numKeys]
if k.Koanf.Get(key) == nil {
b.Fatalf("cachedFind returned a nil value for key: %s", key)
}
}
})
b.Run("cache=true", func(b *testing.B) {
for i, c := range []*ristretto.Config[string, any]{
{
NumCounters: int64(numKeys),
MaxCost: 500000,
BufferItems: 64,
},
{
NumCounters: int64(numKeys * 10),
MaxCost: 1000000,
BufferItems: 64,
},
{
NumCounters: int64(numKeys * 10),
MaxCost: 5000000,
BufferItems: 64,
},
} {
cache, err := ristretto.NewCache[string, any](c)
require.NoError(b, err)
b.Run(fmt.Sprintf("config=%d", i), func(b *testing.B) {
b.ResetTimer()
for i := range b.N {
key := keys[i%numKeys]
val, found := cache.Get(key)
if !found {
val = k.Koanf.Get(key)
_ = cache.Set(key, val, 0)
}
if val == nil {
b.Fatalf("cachedFind returned a nil value for key: %s", key)
}
}
})
}
})
}

View File

@ -1,29 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package configx
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOptions(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
t.Run("case=does not load env if disabled", func(t *testing.T) {
schema := `{"type": "object", "properties": {"path": {"type": "string"}}}`
envP, err := New(ctx, []byte(schema))
require.NoError(t, err)
assert.NotZero(t, envP.String("path"))
nonEnvP, err := New(ctx, []byte(schema), DisableEnvLoading())
require.NoError(t, err)
assert.Nil(t, nonEnvP.Get("path"))
})
}

View File

@ -1,34 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package configx
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSetPerm(t *testing.T) {
f, e := os.CreateTemp("", "test")
require.NoError(t, e)
path := f.Name()
// We cannot test setting owner and group, because we don't know what the
// tester has access to.
_ = (&UnixPermission{
Owner: "",
Group: "",
Mode: 0654,
}).SetPermission(path)
stat, err := f.Stat()
require.NoError(t, err)
assert.Equal(t, os.FileMode(0654), stat.Mode())
require.NoError(t, f.Close())
require.NoError(t, os.Remove(path))
}

View File

@ -1,49 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package configx
import (
"context"
"testing"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/jsonschema/v3"
)
func TestPFlagProvider(t *testing.T) {
const schema = `
{
"type": "object",
"properties": {
"foo": {
"type": "string"
}
}
}
`
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := jsonschema.CompileString(ctx, "", schema)
require.NoError(t, err)
t.Run("only parses known flags", func(t *testing.T) {
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
flags.String("foo", "", "")
flags.String("bar", "", "")
require.NoError(t, flags.Parse([]string{"--foo", "x", "--bar", "y"}))
p, err := NewPFlagProvider([]byte(schema), s, flags, nil)
require.NoError(t, err)
values, err := p.Read()
require.NoError(t, err)
assert.Equal(t, map[string]interface{}{
"foo": "x",
}, values)
})
}

View File

@ -1,258 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package configx
import (
"context"
"os"
"path"
"testing"
"time"
"github.com/inhies/go-bytesize"
"github.com/knadh/koanf/parsers/json"
"github.com/ory/x/urlx"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newProvider(t testing.TB) *Provider {
// Fake some flags
f := pflag.NewFlagSet("config", pflag.ContinueOnError)
f.String("foo-bar-baz", "", "")
f.StringP("b", "b", "", "")
args := []string{"/var/folders/mt/m1dwr59n73zgsq7bk0q2lrmc0000gn/T/go-build533083141/b001/exe/asdf", "aaaa", "-b", "bbbb", "dddd", "eeee", "--foo-bar-baz", "fff"}
require.NoError(t, f.Parse(args[1:]))
RegisterFlags(f)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
p, err := New(ctx, []byte(`{"type": "object", "properties": {"foo-bar-baz": {"type": "string"}, "b": {"type": "string"}}}`), WithFlags(f), WithContext(ctx))
require.NoError(t, err)
return p
}
func TestProviderMethods(t *testing.T) {
p := newProvider(t)
t.Run("check flags", func(t *testing.T) {
assert.Equal(t, "fff", p.String("foo-bar-baz"))
assert.Equal(t, "bbbb", p.String("b"))
})
t.Run("check fallbacks", func(t *testing.T) {
t.Run("type=string", func(t *testing.T) {
require.NoError(t, p.Set("some.string", "bar"))
assert.Equal(t, "bar", p.StringF("some.string", "baz"))
assert.Equal(t, "baz", p.StringF("not.some.string", "baz"))
})
t.Run("type=float", func(t *testing.T) {
require.NoError(t, p.Set("some.float", 123.123))
assert.Equal(t, 123.123, p.Float64F("some.float", 321.321))
assert.Equal(t, 321.321, p.Float64F("not.some.float", 321.321))
})
t.Run("type=int", func(t *testing.T) {
require.NoError(t, p.Set("some.int", 123))
assert.Equal(t, 123, p.IntF("some.int", 123))
assert.Equal(t, 321, p.IntF("not.some.int", 321))
})
t.Run("type=bytesize", func(t *testing.T) {
const key = "some.bytesize"
for _, v := range []interface{}{
bytesize.MB,
float64(1024 * 1024),
"1MB",
} {
require.NoError(t, p.Set(key, v))
assert.Equal(t, bytesize.MB, p.ByteSizeF(key, 0))
}
})
github := urlx.ParseOrPanic("https://github.com/ory")
ory := urlx.ParseOrPanic("https://www.ory.sh/")
t.Run("type=url", func(t *testing.T) {
require.NoError(t, p.Set("some.url", "https://github.com/ory"))
assert.Equal(t, github, p.URIF("some.url", ory))
assert.Equal(t, ory, p.URIF("not.some.url", ory))
})
t.Run("type=request_uri", func(t *testing.T) {
require.NoError(t, p.Set("some.request_uri", "https://github.com/ory"))
assert.Equal(t, github, p.RequestURIF("some.request_uri", ory))
assert.Equal(t, ory, p.RequestURIF("not.some.request_uri", ory))
require.NoError(t, p.Set("invalid.request_uri", "foo"))
assert.Equal(t, ory, p.RequestURIF("invalid.request_uri", ory))
})
})
t.Run("allow integer as duration", func(t *testing.T) {
assert.NoError(t, p.Set("duration.integer1", -1))
assert.NoError(t, p.Set("duration.integer2", "-1"))
assert.Equal(t, -1*time.Nanosecond, p.DurationF("duration.integer1", time.Second))
assert.Equal(t, -1*time.Nanosecond, p.DurationF("duration.integer2", time.Second))
})
t.Run("use complex set operations", func(t *testing.T) {
assert.NoError(t, p.Set("nested", nil))
assert.NoError(t, p.Set("nested.value", "https://www.ory.sh/kratos"))
assert.Equal(t, "https://www.ory.sh/kratos", p.Get("nested.value"))
})
t.Run("use DirtyPatch operations", func(t *testing.T) {
assert.NoError(t, p.DirtyPatch("nested", nil))
assert.NoError(t, p.DirtyPatch("nested.value", "https://www.ory.sh/kratos"))
assert.Equal(t, "https://www.ory.sh/kratos", p.Get("nested.value"))
assert.NoError(t, p.DirtyPatch("duration.integer1", -1))
assert.NoError(t, p.DirtyPatch("duration.integer2", "-1"))
assert.Equal(t, -1*time.Nanosecond, p.DurationF("duration.integer1", time.Second))
assert.Equal(t, -1*time.Nanosecond, p.DurationF("duration.integer2", time.Second))
require.NoError(t, p.DirtyPatch("some.float", 123.123))
assert.Equal(t, 123.123, p.Float64F("some.float", 321.321))
assert.Equal(t, 321.321, p.Float64F("not.some.float", 321.321))
})
}
func TestAdvancedConfigs(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for _, tc := range []struct {
stub string
configs []string
envs [][2]string
ops []OptionModifier
isValid bool
expectedF func(*testing.T, *Provider)
}{
{
stub: "nested-array",
configs: []string{"stub/nested-array/kratos.yaml"},
isValid: true, envs: [][2]string{
{"PROVIDERS_0_CLIENT_ID", "client@example.com"},
{"PROVIDERS_1_CLIENT_ID", "some@example.com"},
},
},
{
stub: "kratos",
configs: []string{"stub/kratos/kratos.yaml"},
isValid: true, envs: [][2]string{
{"SELFSERVICE_METHODS_OIDC_CONFIG_PROVIDERS", `[{"id":"google","provider":"google","mapper_url":"file:///etc/config/kratos/oidc.google.jsonnet","client_id":"client@example.com","client_secret":"secret"}]`},
{"DSN", "sqlite:///var/lib/sqlite/db.sqlite?_fk=true"},
{"SELFSERVICE_FLOWS_REGISTRATION_AFTER_PASSWORD_HOOKS_0_HOOK", "session"},
},
},
{
stub: "multi",
configs: []string{"stub/multi/a.yaml", "stub/multi/b.yaml"},
isValid: true, envs: [][2]string{
{"DSN", "sqlite:///var/lib/sqlite/db.sqlite?_fk=true"},
}},
{
stub: "from-files",
isValid: true, envs: [][2]string{
{"DSN", "sqlite:///var/lib/sqlite/db.sqlite?_fk=true"},
},
ops: []OptionModifier{WithConfigFiles("stub/multi/a.yaml", "stub/multi/b.yaml")}},
{
stub: "hydra",
configs: []string{"stub/hydra/hydra.yaml"},
isValid: true,
envs: [][2]string{
{"DSN", "sqlite:///var/lib/sqlite/db.sqlite?_fk=true"},
{"TRACING_PROVIDER", "jaeger"},
{"TRACING_PROVIDERS_JAEGER_SAMPLING_SERVER_URL", "http://jaeger:5778/sampling"},
{"TRACING_PROVIDERS_JAEGER_LOCAL_AGENT_ADDRESS", "jaeger:6831"},
{"TRACING_PROVIDERS_JAEGER_SAMPLING_TYPE", "const"},
{"TRACING_PROVIDERS_JAEGER_SAMPLING_VALUE", "1"},
},
expectedF: func(t *testing.T, p *Provider) {
assert.Equal(t, "sqlite:///var/lib/sqlite/db.sqlite?_fk=true", p.Get("dsn"))
assert.Equal(t, "jaeger", p.Get("tracing.provider"))
}},
{
stub: "hydra",
configs: []string{"stub/hydra/hydra.yaml"},
isValid: false,
ops: []OptionModifier{WithUserProviders(NewKoanfMemory(ctx, []byte(`{"dsn": null}`)))},
},
{
stub: "hydra",
configs: []string{"stub/hydra/hydra.yaml"},
isValid: true,
ops: []OptionModifier{WithUserProviders(NewKoanfMemory(ctx, []byte(`{"dsn": "invalid"}`)))},
envs: [][2]string{
{"DSN", "sqlite:///var/lib/sqlite/db.sqlite?_fk=true"},
{"TRACING_PROVIDER", "jaeger"},
{"TRACING_PROVIDERS_JAEGER_LOCAL_AGENT_ADDRESS", "jaeger:6831"},
{"TRACING_PROVIDERS_JAEGER_SAMPLING_SERVER_URL", "http://jaeger:5778/sampling"},
{"TRACING_PROVIDERS_JAEGER_SAMPLING_TYPE", "const"},
{"TRACING_PROVIDERS_JAEGER_SAMPLING_VALUE", "1"},
},
},
} {
t.Run("service="+tc.stub, func(t *testing.T) {
setEnvs(t, tc.envs)
expected, err := os.ReadFile(path.Join("stub", tc.stub, "expected.json"))
require.NoError(t, err)
schemaPath := path.Join("stub", tc.stub, "config.schema.json")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
k, err := newKoanf(ctx, schemaPath, tc.configs, append(tc.ops, WithContext(ctx))...)
if !tc.isValid {
require.Error(t, err)
return
}
require.NoError(t, err)
out, err := k.Koanf.Marshal(json.Parser())
require.NoError(t, err)
assert.JSONEq(t, string(expected), string(out), "%s", out)
if tc.expectedF != nil {
tc.expectedF(t, k)
}
})
}
}
func BenchmarkSet(b *testing.B) {
// Benchmark set function
p := newProvider(b)
var err error
for i := 0; i < b.N; i++ {
err = p.Set("nested.value", "https://www.ory.sh/kratos")
if err != nil {
b.Fatalf("Unexpected error: %s", err)
}
}
}
func BenchmarkDirtyPatch(b *testing.B) {
// Benchmark set function
p := newProvider(b)
var err error
for i := 0; i < b.N; i++ {
err = p.DirtyPatch("nested.value", "https://www.ory.sh/kratos")
if err != nil {
b.Fatalf("Unexpected error: %s", err)
}
}
}

View File

@ -1,284 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package configx
import (
"bytes"
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/x/logrusx"
"github.com/ory/x/watcherx"
)
func tmpConfigFile(t *testing.T, dsn, foo string) (string, string) {
config := fmt.Sprintf("dsn: %s\nfoo: %s\n", dsn, foo)
tdir := t.TempDir()
fn := "config.yml"
watcherx.KubernetesAtomicWrite(t, tdir, fn, config)
return tdir, fn
}
func updateConfigFile(t *testing.T, c <-chan struct{}, dir, name, dsn, foo, bar string) {
config := fmt.Sprintf(`dsn: %s
foo: %s
bar: %s`, dsn, foo, bar)
watcherx.KubernetesAtomicWrite(t, dir, name, config)
<-c // Wait for changes to propagate
time.Sleep(time.Millisecond)
}
func assertNoOpenFDs(t require.TestingT, dir, name string) {
if runtime.GOOS == "windows" {
return
}
var b, be bytes.Buffer
// we are only interested in the file descriptors, so we use the `-F f` option
c := exec.Command("lsof", "-n", "-F", "f", "--", filepath.Join(dir, name))
c.Stdout = &b
c.Stderr = &be
exitErr := new(exec.ExitError)
require.ErrorAsf(t, c.Run(), &exitErr, "File %q has open file descriptor.\nGot stout: %s\nstderr: %s", filepath.Join(dir, name), b.String(), be.String())
assert.Equal(t, 1, exitErr.ExitCode(), "got stout: %s\nstderr: %s", b.String(), be.String())
}
func TestReload(t *testing.T) {
setup := func(t *testing.T, dir, name string, c chan<- struct{}, modifiers ...OptionModifier) (*Provider, *logrusx.Logger) {
l := logrusx.New("configx", "test")
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
modifiers = append(modifiers,
WithLogrusWatcher(l),
WithLogger(l),
AttachWatcher(func(event watcherx.Event, err error) {
fmt.Printf("Received event: %+v error: %+v\n", event, err)
c <- struct{}{}
}),
WithContext(ctx),
)
p, err := newKoanf(ctx, "./stub/watch/config.schema.json", []string{filepath.Join(dir, name)}, modifiers...)
require.NoError(t, err)
return p, l
}
t.Run("case=rejects not validating changes", func(t *testing.T) {
t.Parallel()
dir, name := tmpConfigFile(t, "memory", "bar")
c := make(chan struct{})
p, l := setup(t, dir, name, c)
hook := test.NewLocal(l.Entry.Logger)
assertNoOpenFDs(t, dir, name)
assert.Equal(t, []*logrus.Entry{}, hook.AllEntries())
assert.Equal(t, "memory", p.String("dsn"))
assert.Equal(t, "bar", p.String("foo"))
updateConfigFile(t, c, dir, name, "memory", "not bar", "bar")
entries := hook.AllEntries()
require.False(t, len(entries) > 4, "%+v", entries) // should be 2 but addresses flake https://github.com/ory/x/runs/2332130952
assert.Equal(t, "A change to a configuration file was detected.", entries[0].Message)
assert.Equal(t, "The changed configuration is invalid and could not be loaded. Rolling back to the last working configuration revision. Please address the validation errors before restarting the process.", entries[1].Message)
assert.Equal(t, "memory", p.String("dsn"))
assert.Equal(t, "bar", p.String("foo"))
// but it is still watching the files
updateConfigFile(t, c, dir, name, "memory", "bar", "baz")
assert.Equal(t, "baz", p.String("bar"))
time.Sleep(time.Millisecond * 250)
assertNoOpenFDs(t, dir, name)
})
t.Run("case=rejects to update immutable", func(t *testing.T) {
t.Parallel()
dir, name := tmpConfigFile(t, "memory", "bar")
c := make(chan struct{})
p, l := setup(t, dir, name, c,
WithImmutables("dsn"))
hook := test.NewLocal(l.Entry.Logger)
assertNoOpenFDs(t, dir, name)
assert.Equal(t, []*logrus.Entry{}, hook.AllEntries())
assert.Equal(t, "memory", p.String("dsn"))
assert.Equal(t, "bar", p.String("foo"))
updateConfigFile(t, c, dir, name, "some db", "bar", "baz")
entries := hook.AllEntries()
require.False(t, len(entries) > 4, "%+v", entries) // should be 2 but addresses flake https://github.com/ory/x/runs/2332130952
assert.Equal(t, "A change to a configuration file was detected.", entries[0].Message)
assert.Equal(t, "A configuration value marked as immutable has changed. Rolling back to the last working configuration revision. To reload the values please restart the process.", entries[1].Message)
assert.Equal(t, "memory", p.String("dsn"))
assert.Equal(t, "bar", p.String("foo"))
// but it is still watching the files
updateConfigFile(t, c, dir, name, "memory", "bar", "baz")
assert.Equal(t, "baz", p.String("bar"))
assertNoOpenFDs(t, dir, name)
})
t.Run("case=allows to update excepted immutable", func(t *testing.T) {
t.Parallel()
config := `{"foo": {"bar": "a", "baz": "b"}}`
dir := t.TempDir()
name := "config.json"
watcherx.KubernetesAtomicWrite(t, dir, name, config)
c := make(chan struct{})
p, _ := setup(t, dir, name, c,
WithImmutables("foo"),
WithExceptImmutables("foo.baz"),
SkipValidation())
assert.Equal(t, "a", p.String("foo.bar"))
assert.Equal(t, "b", p.String("foo.baz"))
config = `{"foo": {"bar": "a", "baz": "x"}}`
watcherx.KubernetesAtomicWrite(t, dir, name, config)
<-c
time.Sleep(time.Millisecond)
assert.Equal(t, "x", p.String("foo.baz"))
})
t.Run("case=runs without validation errors", func(t *testing.T) {
t.Parallel()
dir, name := tmpConfigFile(t, "some string", "bar")
c := make(chan struct{})
p, l := setup(t, dir, name, c)
hook := test.NewLocal(l.Entry.Logger)
assert.Equal(t, []*logrus.Entry{}, hook.AllEntries())
assert.Equal(t, "some string", p.String("dsn"))
assert.Equal(t, "bar", p.String("foo"))
})
t.Run("case=runs and reloads", func(t *testing.T) {
t.Parallel()
dir, name := tmpConfigFile(t, "some string", "bar")
c := make(chan struct{})
p, l := setup(t, dir, name, c)
hook := test.NewLocal(l.Entry.Logger)
assert.Equal(t, []*logrus.Entry{}, hook.AllEntries())
assert.Equal(t, "some string", p.String("dsn"))
assert.Equal(t, "bar", p.String("foo"))
updateConfigFile(t, c, dir, name, "memory", "bar", "baz")
assert.Equal(t, "baz", p.String("bar"))
})
t.Run("case=has with validation errors", func(t *testing.T) {
t.Parallel()
dir, name := tmpConfigFile(t, "some string", "not bar")
l := logrusx.New("", "")
hook := test.NewLocal(l.Entry.Logger)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var b bytes.Buffer
_, err := newKoanf(ctx, "./stub/watch/config.schema.json", []string{filepath.Join(dir, name)},
WithStandardValidationReporter(&b),
WithLogrusWatcher(l),
)
require.Error(t, err)
entries := hook.AllEntries()
require.Equal(t, 0, len(entries))
assert.Equal(t, "The configuration contains values or keys which are invalid:\nfoo: not bar\n ^-- value must be \"bar\"\n\n", b.String())
})
t.Run("case=is not leaking open files", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip()
}
dir, name := tmpConfigFile(t, "some string", "bar")
c := make(chan struct{})
p, _ := setup(t, dir, name, c)
assertNoOpenFDs(t, dir, name)
for i := range 30 {
t.Run(fmt.Sprintf("iteration=%d", i), func(t *testing.T) {
expected := []string{"foo", "bar", "baz"}[i%3]
updateConfigFile(t, c, dir, name, "memory", "bar", expected)
assertNoOpenFDs(t, dir, name)
require.EqualValues(t, expected, p.String("bar"))
})
}
assertNoOpenFDs(t, dir, name)
})
t.Run("case=callback can use the provider to get the new value", func(t *testing.T) {
t.Parallel()
dsn := "old"
dir, name := tmpConfigFile(t, dsn, "bar")
c := make(chan struct{})
var p *Provider
p, _ = setup(t, dir, name, c, AttachWatcher(func(watcherx.Event, error) {
dsn = p.String("dsn")
}))
// change dsn
updateConfigFile(t, c, dir, name, "new", "bar", "bar")
assert.Equal(t, "new", dsn)
})
}
type mockTestingT struct {
failed bool
}
func (m *mockTestingT) FailNow() {
m.failed = true
}
func (m *mockTestingT) Errorf(string, ...interface{}) {}
var _ require.TestingT = (*mockTestingT)(nil)
func TestAssertNoOpenFDs(t *testing.T) {
t.Parallel()
mt := &mockTestingT{}
dir := t.TempDir()
f, err := os.Create(filepath.Join(dir, "foo"))
require.NoError(t, err)
assertNoOpenFDs(mt, dir, "foo")
assert.True(t, mt.failed)
mt = &mockTestingT{}
require.NoError(t, f.Close())
assertNoOpenFDs(mt, dir, "foo")
assert.False(t, mt.failed)
}

View File

@ -1,19 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package configx
import (
"testing"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m,
goleak.IgnoreCurrent(),
// We have the global schema cache that is never closed.
goleak.IgnoreTopFunction("github.com/dgraph-io/ristretto/v2.(*defaultPolicy[...]).processItems"),
goleak.IgnoreTopFunction("github.com/dgraph-io/ristretto/v2.(*Cache[...]).processItems"),
)
}

View File

@ -1,50 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package contextx
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/x/configx"
)
func TestContext(t *testing.T) {
ctx := context.Background()
actual, err := ConfigFromContext(ctx)
require.Error(t, err)
require.Nil(t, actual)
assert.Panics(t, func() {
_ = MustConfigFromContext(ctx)
})
expected := &configx.Provider{}
ctx = WithConfig(ctx, expected)
actual, err = ConfigFromContext(ctx)
require.NoError(t, err)
require.Equal(t, expected, actual)
actual = MustConfigFromContext(ctx)
require.Equal(t, expected, actual)
}
func ExampleConfigFromContext() {
ctx := context.Background()
config, err := configx.New(ctx, []byte(`{"type":"object","properties":{"foo":{"type":"string"}}}`), configx.WithValue("foo", "bar"))
if err != nil {
panic(err)
}
ctx = WithConfig(ctx, config)
fmt.Printf("foo = %s", MustConfigFromContext(ctx).String("foo"))
// Output: foo = bar
}

View File

@ -1,17 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package contextx
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestTreeContext(t *testing.T) {
assert.True(t, IsRootContext(RootContext))
assert.True(t, IsRootContext(context.WithValue(RootContext, "foo", "bar"))) //lint:ignore SA1029 builtin type for context is OK in test
assert.False(t, IsRootContext(context.Background()))
}

View File

@ -1,111 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package corsx
import (
"net/http"
"testing"
"github.com/rs/cors"
"github.com/stretchr/testify/assert"
)
func TestCheckOrigin(t *testing.T) {
for _, tc := range []struct {
name string
allowedOrigins []string
expect, expectOther bool
}{
{
name: "empty",
allowedOrigins: []string{},
expect: true,
expectOther: true,
},
{
name: "wildcard",
allowedOrigins: []string{"https://example.com", "*"},
expect: true,
expectOther: true,
},
{
name: "exact",
allowedOrigins: []string{"https://www.ory.sh"},
expect: true,
},
{
name: "wildcard in the beginning",
allowedOrigins: []string{"*.ory.sh"},
expect: true,
},
{
name: "wildcard in the middle",
allowedOrigins: []string{"https://*.ory.sh"},
expect: true,
},
{
name: "wildcard in the end",
allowedOrigins: []string{"https://www.ory.*"},
expect: true,
},
{
name: "second wildcard is ignored",
allowedOrigins: []string{"https://*.ory.*"},
expect: false,
},
{
name: "multiple exact",
allowedOrigins: []string{"https://example.com", "https://www.ory.sh"},
expect: true,
},
{
name: "multiple wildcard",
allowedOrigins: []string{"https://*.example.com", "https://*.ory.sh"},
expect: true,
},
{
name: "wildcard and exact origins 1",
allowedOrigins: []string{"https://*.example.com", "https://www.ory.sh"},
expect: true,
},
{
name: "wildcard and exact origins 2",
allowedOrigins: []string{"https://example.com", "https://*.ory.sh"},
expect: true,
},
{
name: "multiple unrelated exact",
allowedOrigins: []string{"https://example.com", "https://example.org"},
expect: false,
},
{
name: "multiple unrelated with wildcard",
allowedOrigins: []string{"https://*.example.com", "https://*.example.org"},
expect: false,
},
{
name: "uppercase exact",
allowedOrigins: []string{"https://www.ORY.sh"},
expect: true,
},
{
name: "uppercase wildcard",
allowedOrigins: []string{"https://*.ORY.sh"},
expect: true,
},
} {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expect, CheckOrigin(tc.allowedOrigins, "https://www.ory.sh"))
assert.Equal(t, tc.expectOther, CheckOrigin(tc.allowedOrigins, "https://google.com"))
// check for consistency with rs/cors
assert.Equal(t, tc.expect, cors.New(cors.Options{AllowedOrigins: tc.allowedOrigins}).
OriginAllowed(&http.Request{Header: http.Header{"Origin": []string{"https://www.ory.sh"}}}))
assert.Equal(t, tc.expectOther, cors.New(cors.Options{AllowedOrigins: tc.allowedOrigins}).
OriginAllowed(&http.Request{Header: http.Header{"Origin": []string{"https://google.com"}}}))
})
}
}

View File

@ -1,14 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package corsx
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestHelpMessage(t *testing.T) {
assert.NotEmpty(t, HelpMessage())
}

View File

@ -1,73 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package corsx
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/rs/cors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
)
func TestContextualizedMiddleware(t *testing.T) {
createServer := func(t *testing.T, cb func(ctx context.Context) (cors.Options, bool)) *httptest.Server {
n := negroni.New()
n.UseFunc(ContextualizedMiddleware(cb))
n.UseHandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, _ = rw.Write([]byte("ok"))
})
ts := httptest.NewServer(n)
t.Cleanup(ts.Close)
return ts
}
fetchCORS := func(t *testing.T, origin string, ts *httptest.Server) http.Header {
req, err := http.NewRequest("OPTIONS", ts.URL, nil)
require.NoError(t, err)
req.Header.Set("Origin", origin)
req.Header.Set("Access-Control-Request-Method", "DELETE")
req.Header.Set("Access-Control-Request-Headers", "")
res, err := ts.Client().Do(req)
require.NoError(t, err)
defer res.Body.Close()
return res.Header
}
t.Run("switches enabled on and off", func(t *testing.T) {
var enabled bool
var origins []string
ts := createServer(t, func(ctx context.Context) (cors.Options, bool) {
return cors.Options{
AllowedMethods: []string{"OPTIONS", "DELETE"},
AllowedOrigins: origins,
Debug: true,
}, enabled
})
origins = append(origins, "http://localhost:8080")
actual := fetchCORS(t, "http://localhost:8080", ts)
assert.Empty(t, actual.Get("Access-Control-Allow-Origin"))
enabled = true
actual = fetchCORS(t, "http://localhost:8080", ts)
assert.Equal(t, "http://localhost:8080", actual.Get("Access-Control-Allow-Origin"), actual)
enabled = false
actual = fetchCORS(t, "http://localhost:8080", ts)
assert.Empty(t, actual.Get("Access-Control-Allow-Origin"))
enabled = true
origins = []string{"http://localhost:9090"}
actual = fetchCORS(t, "http://localhost:8080", ts)
assert.Empty(t, actual.Get("Access-Control-Allow-Origin"))
actual = fetchCORS(t, "http://localhost:9090", ts)
assert.Equal(t, "http://localhost:9090", actual.Get("Access-Control-Allow-Origin"), actual)
})
}

View File

@ -1,26 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package corsx
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/x/urlx"
)
func TestNormalizeOrigins(t *testing.T) {
assert.EqualValues(t,
[]string{"https://example.org:1234"},
NormalizeOrigins([]url.URL{*urlx.ParseOrPanic("https://example.org:1234/asdf")}))
}
func TestNormalizeOriginStrings(t *testing.T) {
actual, err := NormalizeOriginStrings([]string{"https://example.org:1234/asdf"})
require.NoError(t, err)
assert.EqualValues(t, []string{"https://example.org:1234"}, actual)
}

View File

@ -1,74 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package crdbx
import (
"fmt"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/ory/x/urlx"
)
func TestConsistencyLevelFromString(t *testing.T) {
assert.Equal(t, ConsistencyLevelUnset, ConsistencyLevelFromString(""))
assert.Equal(t, ConsistencyLevelStrong, ConsistencyLevelFromString("strong"))
assert.Equal(t, ConsistencyLevelEventual, ConsistencyLevelFromString("eventual"))
assert.Equal(t, ConsistencyLevelStrong, ConsistencyLevelFromString("lol"))
}
func TestConsistencyLevelFromRequest(t *testing.T) {
assert.Equal(t, ConsistencyLevelStrong, ConsistencyLevelFromRequest(&http.Request{URL: urlx.ParseOrPanic("/?consistency=strong")}))
assert.Equal(t, ConsistencyLevelEventual, ConsistencyLevelFromRequest(&http.Request{URL: urlx.ParseOrPanic("/?consistency=eventual")}))
assert.Equal(t, ConsistencyLevelStrong, ConsistencyLevelFromRequest(&http.Request{URL: urlx.ParseOrPanic("/?consistency=asdf")}))
assert.Equal(t, ConsistencyLevelUnset, ConsistencyLevelFromRequest(&http.Request{URL: urlx.ParseOrPanic("/?consistency")}))
}
func TestGetTransactionConsistency(t *testing.T) {
for k, tc := range []struct {
in ConsistencyLevel
fallback ConsistencyLevel
dialect string
expected string
}{
{
in: ConsistencyLevelUnset,
fallback: ConsistencyLevelStrong,
dialect: "cockroach",
expected: "",
},
{
in: ConsistencyLevelStrong,
fallback: ConsistencyLevelStrong,
dialect: "cockroach",
expected: "",
},
{
in: ConsistencyLevelStrong,
fallback: ConsistencyLevelEventual,
dialect: "cockroach",
expected: "",
},
{
in: ConsistencyLevelUnset,
fallback: ConsistencyLevelEventual,
dialect: "cockroach",
expected: transactionFollowerReadTimestamp,
},
{
in: ConsistencyLevelEventual,
fallback: ConsistencyLevelEventual,
dialect: "cockroach",
expected: transactionFollowerReadTimestamp,
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
q := getTransactionConsistencyQuery(tc.dialect, tc.in, tc.fallback)
assert.EqualValues(t, tc.expected, q)
})
}
}

View File

@ -1,41 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package dbal
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsMemorySQLite(t *testing.T) {
testCases := map[string]bool{
SQLiteInMemory: true,
SQLiteSharedInMemory: true,
"memory": true,
":memory:": true,
"sqlite://:memory:?_fk=true": true,
"sqlite://file:uniquedb:?_fk=true&mode=memory": true,
"sqlite://file:uniquedb:?_fk=true&mode=memory&cache=shared": true,
"sqlite://file:uniquedb:?_fk=true&cache=shared&mode=memory": true,
"sqlite://file:uniquedb:?mode=memory": true,
"sqlite://file:::uniquedb:?_fk=true&mode=memory": true,
"sqlite://file:memdb1?mode=memory&cache=shared": true,
"sqlite://file:uniquedb:?_fk=true&cache=shared": false,
"sqlite://": false,
"sqlite://file": false,
"sqlite://file:::": false,
"sqlite://?_fk=true&mode=memory": false,
"sqlite://?_fk=true&cache=shared": false,
"sqlite://file::?_fk=true": false,
"sqlite://file:::?_fk=true": false,
"postgresql://username:secret@localhost:5432/database": false,
}
for dsn, expected := range testCases {
t.Run("dsn="+dsn, func(t *testing.T) {
assert.Equal(t, expected, IsMemorySQLite(dsn))
})
}
}

View File

@ -1,616 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package decoderx
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"sync"
"testing"
"github.com/ory/x/assertx"
"github.com/tidwall/gjson"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/jsonschema/v3"
)
func newRequest(t *testing.T, method, url string, body io.Reader, ct string) *http.Request {
req := httptest.NewRequest(method, url, body)
req.Header.Set("Content-Type", ct)
return req
}
func TestHTTPFormDecoder(t *testing.T) {
for k, tc := range []struct {
d string
request *http.Request
contentType string
options []HTTPDecoderOption
expected string
expectedError string
}{
{
d: "should fail because the method is GET",
request: &http.Request{Header: map[string][]string{}, Method: "GET"},
expectedError: "HTTP Request Method",
},
{
d: "should fail because the body is empty",
request: &http.Request{Header: map[string][]string{}, Method: "POST"},
expectedError: "Content-Length",
},
{
d: "should fail because content type is missing",
request: newRequest(t, "POST", "/", nil, ""),
expectedError: "Content-Length",
},
{
d: "should fail because content type is missing",
request: newRequest(t, "POST", "/", bytes.NewBufferString("foo"), ""),
expectedError: "Content-Type",
},
{
d: "should pass with json without validation",
request: newRequest(t, "POST", "/", bytes.NewBufferString(`{"foo":"bar"}`), httpContentTypeJSON),
expected: `{"foo":"bar"}`,
},
{
d: "should fail json if content type is not accepted",
request: newRequest(t, "POST", "/", bytes.NewBufferString(`{"foo":"bar"}`), httpContentTypeJSON),
options: []HTTPDecoderOption{HTTPFormDecoder()},
expectedError: "Content-Type: application/json",
},
{
d: "should fail json if validation fails",
request: newRequest(t, "POST", "/", bytes.NewBufferString(`{"foo":"bar", "bar":"baz"}`), httpContentTypeJSON),
options: []HTTPDecoderOption{HTTPJSONDecoder(), MustHTTPRawJSONSchemaCompiler([]byte(`{
"$id": "https://example.com/config.schema.json",
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"foo": {
"type": "number"
},
"bar": {
"type": "string"
}
}
}`),
)},
expectedError: "expected number, but got string",
expected: `{ "bar": "baz", "foo": "bar" }`,
},
{
d: "should pass json with validation",
request: newRequest(t, "POST", "/", bytes.NewBufferString(`{"foo":"bar"}`), httpContentTypeJSON),
options: []HTTPDecoderOption{HTTPJSONDecoder(), MustHTTPRawJSONSchemaCompiler([]byte(`{
"$id": "https://example.com/config.schema.json",
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"foo": {
"type": "string"
}
}
}`),
),
},
expected: `{"foo":"bar"}`,
},
{
d: "should fail form request when form is used but only json is allowed",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{"foo": {"bar"}}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{HTTPJSONDecoder()},
expectedError: "Content-Type: application/x-www-form-urlencoded",
},
{
d: "should fail form request when schema is missing",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{"foo": {"bar"}}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{},
expectedError: "no validation schema was provided",
},
{
d: "should fail form request when schema does not validate request",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{"bar": {"bar"}}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{HTTPJSONSchemaCompiler("stub/schema.json", nil)},
expectedError: `missing properties: "foo"`,
},
{
d: "should fail for invalid JSON data with unrestricted object",
request: newRequest(t, "POST", "/", bytes.NewBufferString(`{"dynamic_object":{"stuff":{"blub":[42,3.14152,"fu":"bar"},"consent":true}}`), httpContentTypeJSON),
options: []HTTPDecoderOption{
HTTPJSONSchemaCompiler("stub/dynamic-object.json", nil),
HTTPJSONDecoder()},
expectedError: "The request was malformed or contained invalid parameters",
},
{
d: "should fail validation for wrong JSON type with unrestricted object",
request: newRequest(t, "POST", "/", bytes.NewBufferString(`{"dynamic_object":[42,3.14152]}`), httpContentTypeJSON),
options: []HTTPDecoderOption{
HTTPJSONSchemaCompiler("stub/dynamic-object.json", nil),
HTTPJSONDecoder()},
expectedError: "expected object, but got array",
},
{
d: "should accept JSON data with unrestricted object",
request: newRequest(t, "POST", "/", bytes.NewBufferString(`{"dynamic_object":{"stuff":{"blub":[42,3.14152],"fu":"bar"},"consent":true}}`), httpContentTypeJSON),
options: []HTTPDecoderOption{
HTTPJSONSchemaCompiler("stub/dynamic-object.json", nil),
HTTPJSONDecoder()},
expected: `{
"dynamic_object": {
"stuff": {
"blub": [42, 3.14152],
"fu": "bar"
},
"consent": true
}
}`,
},
{
d: "should accept JSON data with unrestricted object and mixed object syntax and query parameter",
request: newRequest(t, "POST", "/?name.last=Horstmann", bytes.NewBufferString(`{"dynamic_object":{"stuff":{"blub":[42,3.14152],"fu":"bar"},"consent":true},"name.first":"Horst"}`), httpContentTypeJSON),
options: []HTTPDecoderOption{
HTTPJSONSchemaCompiler("stub/dynamic-object.json", nil),
HTTPJSONDecoder(),
HTTPDecoderJSONFollowsFormFormat(),
HTTPDecoderUseQueryAndBody()},
expected: `{
"dynamic_object": {
"stuff": {
"blub": [42, 3.14152],
"fu": "bar"
},
"consent": true
},
"name": {
"first": "Horst",
"last": "Horstmann"
}
}`,
},
{
d: "should accept JSON data with unrestricted object and mixed object syntax",
request: newRequest(t, "POST", "/", bytes.NewBufferString(`{"dynamic_object":{"stuff":{"blub":[42,3.14152],"fu":"bar"},"consent":true},"name.first":"Horst","name.last":"Horstmann"}`), httpContentTypeJSON),
options: []HTTPDecoderOption{
HTTPJSONSchemaCompiler("stub/dynamic-object.json", nil),
HTTPJSONDecoder(),
HTTPDecoderJSONFollowsFormFormat()},
expected: `{
"dynamic_object": {
"stuff": {
"blub": [42, 3.14152],
"fu": "bar"
},
"consent": true
},
"name": {
"first": "Horst",
"last": "Horstmann"
}
}`,
},
{
d: "should fail form data with invalid premarshalled JSON object",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{
"dynamic_object": {`{"stuff":{"blub":[42, 3.14152,"fu":"bar"},"consent":true}`},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{
HTTPJSONSchemaCompiler("stub/dynamic-object.json", nil),
HTTPFormDecoder()},
expectedError: "The request was malformed or contained invalid parameters",
},
{
d: "should fail validation for form data with wrong premarshalled JSON type",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{
"dynamic_object": {`[42, 3.14152]`},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{
HTTPJSONSchemaCompiler("stub/dynamic-object.json", nil),
HTTPFormDecoder()},
expectedError: "expected object, but got array",
},
{
d: "should accept form data with premarshalled JSON object",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{
"dynamic_object": {`{"stuff":{"blub":[42, 3.14152],"fu":"bar"},"consent":true}`},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{
HTTPJSONSchemaCompiler("stub/dynamic-object.json", nil),
HTTPFormDecoder()},
expected: `{
"dynamic_object": {
"stuff": {
"blub": [42, 3.14152],
"fu": "bar"
},
"consent": true
},
"name": {}
}`,
},
{
d: "should accept form data with premarshalled JSON object and mixed object syntax and query parameter",
request: newRequest(t, "POST", "/?name.last=Horstmann", bytes.NewBufferString(url.Values{
"dynamic_object": {`{"stuff":{"blub":[42, 3.14152],"fu":"bar"},"consent":true}`},
"name.first": {"Horst"},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{
HTTPJSONSchemaCompiler("stub/dynamic-object.json", nil),
HTTPFormDecoder(),
HTTPDecoderUseQueryAndBody()},
expected: `{
"dynamic_object": {
"stuff": {
"blub": [42, 3.14152],
"fu": "bar"
},
"consent": true
},
"name": {
"first": "Horst",
"last": "Horstmann"
}
}`,
},
{
d: "should accept form data with premarshalled JSON object and mixed object syntax",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{
"dynamic_object": {`{"stuff":{"blub":[42, 3.14152],"fu":"bar"},"consent":true}`},
"name.first": {"Horst"},
"name.last": {"Horstmann"},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{
HTTPJSONSchemaCompiler("stub/dynamic-object.json", nil),
HTTPFormDecoder()},
expected: `{
"dynamic_object": {
"stuff": {
"blub": [42, 3.14152],
"fu": "bar"
},
"consent": true
},
"name": {
"first": "Horst",
"last": "Horstmann"
}
}`,
},
{
d: "should pass form request and type assert data",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{
"name.first": {"Aeneas"},
"name.last": {"Rekkas"},
"age": {"29"},
"ratio": {"0.9"},
"consent": {"true"},
// newsletter represents a special case for checkbox input with true/false and raw HTML.
"newsletter": {
"false", // comes from <input type="hidden" name="newsletter" value="false">
"true", // comes from <input type="checkbox" name="newsletter" value="true" checked>
},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{HTTPJSONSchemaCompiler("stub/person.json", nil)},
expected: `{
"name": {"first": "Aeneas", "last": "Rekkas"},
"age": 29,
"newsletter": true,
"consent": true,
"ratio": 0.9
}`,
},
{
d: "should mark the correct fields when nested objects are required",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{
// newsletter represents a special case for checkbox input with true/false and raw HTML.
"foo": {"bar"},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{
HTTPJSONSchemaCompiler("stub/consent.json", nil),
HTTPKeepRequestBody(true),
HTTPDecoderSetValidatePayloads(false),
HTTPDecoderUseQueryAndBody(),
HTTPDecoderAllowedMethods("POST", "GET"),
HTTPDecoderJSONFollowsFormFormat(),
},
expected: `{
"traits": {
"consent": {
"inner": {}
},
"notrequired": {}
}
}`,
},
{
d: "should pass form request with payload in query and type assert data",
request: newRequest(t, "POST", "/?age=29", bytes.NewBufferString(url.Values{
"name.first": {"Aeneas"},
"name.last": {"Rekkas"},
"ratio": {"0.9"},
"consent": {"true"},
// newsletter represents a special case for checkbox input with true/false and raw HTML.
"newsletter": {
"false", // comes from <input type="hidden" name="newsletter" value="false">
"true", // comes from <input type="checkbox" name="newsletter" value="true" checked>
},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{HTTPJSONSchemaCompiler("stub/person.json", nil)},
expected: `{
"name": {"first": "Aeneas", "last": "Rekkas"},
"newsletter": true,
"consent": true,
"ratio": 0.9
}`,
},
{
d: "should pass form request with payload in query and type assert data",
request: newRequest(t, "POST", "/?age=29", bytes.NewBufferString(url.Values{
"name.first": {"Aeneas"},
"name.last": {"Rekkas"},
"ratio": {"0.9"},
"consent": {"true"},
// newsletter represents a special case for checkbox input with true/false and raw HTML.
"newsletter": {
"false", // comes from <input type="hidden" name="newsletter" value="false">
"true", // comes from <input type="checkbox" name="newsletter" value="true" checked>
},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{
HTTPDecoderUseQueryAndBody(),
HTTPJSONSchemaCompiler("stub/person.json", nil),
},
expected: `{
"name": {"first": "Aeneas", "last": "Rekkas"},
"age": 29,
"newsletter": true,
"consent": true,
"ratio": 0.9
}`,
},
{
d: "should fail form request if empty values are sent because of required fields",
request: newRequest(t, "POST", "/?age=29", bytes.NewBufferString(url.Values{
"name.first": {""},
"name.last": {""},
"name2.first": {""},
"name2.last": {""},
"ratio": {""},
"ratio2": {""},
"age": {""},
"age2": {""},
"consent": {""},
"consent2": {""},
// newsletter represents a special case for checkbox input with true/false and raw HTML.
"newsletter": {""},
"newsletter2": {""},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{
HTTPDecoderUseQueryAndBody(),
HTTPJSONSchemaCompiler("stub/required-defaults.json", nil),
},
expectedError: `I[#/name2] S[#/properties/name2/required] missing properties: "first"`,
},
{
d: "should fail json request formatted as form if payload is invalid",
request: newRequest(t, "POST", "/", bytes.NewBufferString(`{"name.first":"Aeneas", "name.last":"Rekkas","age":"not-a-number"}`), httpContentTypeJSON),
options: []HTTPDecoderOption{HTTPJSONSchemaCompiler("stub/person.json", nil)},
expectedError: "expected integer, but got string",
},
{
d: "should pass JSON request formatted as a form",
request: newRequest(t, "POST", "/", bytes.NewBufferString(`{
"name.first": "Aeneas",
"name.last": "Rekkas",
"age": 29,
"ratio": 0.9,
"consent": false,
"newsletter": true
}`), httpContentTypeJSON),
options: []HTTPDecoderOption{HTTPDecoderJSONFollowsFormFormat(),
HTTPJSONSchemaCompiler("stub/person.json", nil)},
expected: `{
"name": {"first": "Aeneas", "last": "Rekkas"},
"age": 29,
"newsletter": true,
"consent": false,
"ratio": 0.9
}`,
},
{
d: "should pass JSON request formatted as a form",
request: newRequest(t, "POST", "/?age=29", bytes.NewBufferString(`{
"name.first": "Aeneas",
"name.last": "Rekkas",
"ratio": 0.9,
"consent": false,
"newsletter": true
}`), httpContentTypeJSON),
options: []HTTPDecoderOption{HTTPDecoderJSONFollowsFormFormat(),
HTTPJSONSchemaCompiler("stub/person.json", nil)},
expected: `{
"name": {"first": "Aeneas", "last": "Rekkas"},
"newsletter": true,
"consent": false,
"ratio": 0.9
}`,
},
{
d: "should pass JSON request formatted as a JSON even if HTTPDecoderJSONFollowsFormFormat is used",
request: newRequest(t, "POST", "/?age=29", bytes.NewBufferString(`{
"name": {"first": "Aeneas", "last": "Rekkas"},
"ratio": 0.9,
"consent": false,
"newsletter": true
}`), httpContentTypeJSON),
options: []HTTPDecoderOption{HTTPDecoderJSONFollowsFormFormat(),
HTTPJSONSchemaCompiler("stub/person.json", nil)},
expected: `{
"name": {"first": "Aeneas", "last": "Rekkas"},
"newsletter": true,
"consent": false,
"ratio": 0.9
}`,
},
{
d: "should not retry indefinitely if key does not exist",
request: newRequest(t, "POST", "/?age=29", bytes.NewBufferString(`{
"not-foo": "bar"
}`), httpContentTypeJSON),
options: []HTTPDecoderOption{HTTPDecoderJSONFollowsFormFormat(),
HTTPJSONSchemaCompiler("stub/schema.json", nil)},
expectedError: "I[#] S[#/required] missing properties",
},
{
d: "should indicate the true missing fields from nested form",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{"leaf": {"foo"}}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{
HTTPDecoderUseQueryAndBody(),
HTTPDecoderSetIgnoreParseErrorsStrategy(ParseErrorIgnoreConversionErrors),
HTTPJSONSchemaCompiler("stub/nested.json", nil)},
expectedError: `I[#/node/node/node] S[#/properties/node/properties/node/properties/node/required] missing properties: "leaf"`,
},
{
d: "should pass JSON request formatted as a form",
request: newRequest(t, "POST", "/?age=29", bytes.NewBufferString(`{
"name.first": "Aeneas",
"name.last": "Rekkas",
"ratio": 0.9,
"consent": false,
"newsletter": true
}`), httpContentTypeJSON),
options: []HTTPDecoderOption{
HTTPDecoderUseQueryAndBody(),
HTTPDecoderJSONFollowsFormFormat(),
HTTPJSONSchemaCompiler("stub/person.json", nil)},
expected: `{
"name": {"first": "Aeneas", "last": "Rekkas"},
"age": 29,
"newsletter": true,
"consent": false,
"ratio": 0.9
}`,
},
{
d: "should pass JSON request GET request",
request: newRequest(t, "GET", "/?"+url.Values{
"name.first": {"Aeneas"},
"name.last": {"Rekkas"},
"age": {"29"},
"ratio": {"0.9"},
"consent": {"false"},
"newsletter": {"true"},
}.Encode(), nil, ""),
options: []HTTPDecoderOption{
HTTPJSONSchemaCompiler("stub/person.json", nil),
HTTPDecoderAllowedMethods("GET"),
},
expected: `{
"name": {"first": "Aeneas", "last": "Rekkas"},
"age": 29,
"newsletter": true,
"consent": false,
"ratio": 0.9
}`,
},
{
d: "should fail because json is not an object when using form format",
request: newRequest(t, "POST", "/", bytes.NewBufferString(`[]`), httpContentTypeJSON),
options: []HTTPDecoderOption{HTTPDecoderJSONFollowsFormFormat(),
HTTPJSONSchemaCompiler("stub/person.json", nil)},
expectedError: "be an object",
},
{
d: "should work with ParseErrorIgnoreConversionErrors",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{
"ratio": {"foobar"},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{
HTTPJSONSchemaCompiler("stub/person.json", nil),
HTTPDecoderSetIgnoreParseErrorsStrategy(ParseErrorIgnoreConversionErrors),
HTTPDecoderSetValidatePayloads(false),
},
expected: `{"name": {}, "ratio": "foobar"}`,
},
{
d: "should work with ParseErrorIgnoreConversionErrors",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{
"ratio": {"foobar"},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{HTTPJSONSchemaCompiler("stub/person.json", nil), HTTPDecoderSetIgnoreParseErrorsStrategy(ParseErrorUseEmptyValueOnConversionErrors)},
expected: `{"name": {}, "ratio": 0.0}`,
},
{
d: "should work with ParseErrorIgnoreConversionErrors",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{
"ratio": {"foobar"},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{HTTPJSONSchemaCompiler("stub/person.json", nil), HTTPDecoderSetIgnoreParseErrorsStrategy(ParseErrorReturnOnConversionErrors)},
expectedError: `strconv.ParseFloat: parsing "foobar"`,
},
{
d: "should interpret numbers as string if mandated by the schema",
request: newRequest(t, "POST", "/", bytes.NewBufferString(url.Values{
"name.first": {"12345"},
}.Encode()), httpContentTypeURLEncodedForm),
options: []HTTPDecoderOption{HTTPJSONSchemaCompiler("stub/person.json", nil), HTTPDecoderSetIgnoreParseErrorsStrategy(ParseErrorUseEmptyValueOnConversionErrors)},
expected: `{"name": {"first": "12345"}}`,
},
} {
t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) {
dec := NewHTTP()
var destination json.RawMessage
err := dec.Decode(tc.request, &destination, tc.options...)
if tc.expectedError != "" {
if e, ok := errors.Cause(err).(*jsonschema.ValidationError); ok {
t.Logf("%+v", e)
}
require.Error(t, err)
require.Contains(t, fmt.Sprintf("%+v", err), tc.expectedError)
if len(tc.expected) > 0 {
assert.JSONEq(t, tc.expected, string(destination))
}
return
}
require.NoError(t, err)
assertx.EqualAsJSON(t, json.RawMessage(tc.expected), destination)
})
}
t.Run("description=read body twice", func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
dec := NewHTTP()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer wg.Done()
var destination json.RawMessage
require.NoError(t, dec.Decode(r, &destination, HTTPJSONSchemaCompiler("stub/person.json", nil), HTTPKeepRequestBody(true)))
assert.EqualValues(t, "12345", gjson.GetBytes(destination, "name.first").String())
require.NoError(t, dec.Decode(r, &destination, HTTPJSONSchemaCompiler("stub/person.json", nil), HTTPKeepRequestBody(true)))
assert.EqualValues(t, "12345", gjson.GetBytes(destination, "name.first").String())
}))
t.Cleanup(ts.Close)
_, err := ts.Client().PostForm(ts.URL, url.Values{"name.first": {"12345"}})
require.NoError(t, err)
wg.Wait()
})
}

View File

@ -1,21 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package errorsx
import (
"testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)
func TestWithStack(t *testing.T) {
t.Run("case=wrap", func(t *testing.T) {
orig := errors.New("hi")
wrap := WithStack(orig)
assert.EqualValues(t, orig.(StackTracer).StackTrace(), wrap.(StackTracer).StackTrace())
assert.EqualValues(t, orig.(StackTracer).StackTrace(), WithStack(wrap).(StackTracer).StackTrace())
})
}

View File

@ -1,135 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package fetcher
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"net/http"
"os"
"sync/atomic"
"testing"
"time"
"github.com/dgraph-io/ristretto/v2"
"github.com/hashicorp/go-retryablehttp"
"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFetcher(t *testing.T) {
router := httprouter.New()
router.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
_, _ = w.Write([]byte(`{"foo":"bar"}`))
})
ts := httptest.NewServer(router)
t.Cleanup(ts.Close)
file, err := os.CreateTemp(os.TempDir(), "source.*.json")
require.NoError(t, err)
_, err = file.WriteString(`{"foo":"baz"}`)
require.NoError(t, err)
require.NoError(t, file.Close())
rClient := retryablehttp.NewClient()
rClient.HTTPClient = ts.Client()
for fc, fetcher := range []*Fetcher{
NewFetcher(WithClient(rClient)),
NewFetcher(),
} {
for k, tc := range []struct {
source string
expect string
}{
{
source: "base64://" + base64.StdEncoding.EncodeToString([]byte(`{"foo":"zab"}`)),
expect: `{"foo":"zab"}`,
},
{
source: "file://" + file.Name(),
expect: `{"foo":"baz"}`,
},
{
source: ts.URL,
expect: `{"foo":"bar"}`,
},
} {
t.Run(fmt.Sprintf("config=%d/case=%d", fc, k), func(t *testing.T) {
actual, err := fetcher.Fetch(tc.source)
require.NoError(t, err)
assert.JSONEq(t, tc.expect, actual.String())
})
}
}
t.Run("case=returns proper error on unknown scheme", func(t *testing.T) {
_, err := NewFetcher().Fetch("unknown-scheme://foo")
assert.ErrorIs(t, err, ErrUnknownScheme)
assert.Contains(t, err.Error(), "unknown-scheme")
})
t.Run("case=FetcherContext cancels the HTTP request", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, err := NewFetcher().FetchContext(ctx, "https://config.invalid")
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("case=with-limit", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(bytes.Repeat([]byte("test"), 1000))
}))
t.Cleanup(srv.Close)
_, err := NewFetcher(WithMaxHTTPMaxBytes(3999)).Fetch(srv.URL)
assert.ErrorIs(t, err, bytes.ErrTooLarge)
_, err = NewFetcher(WithMaxHTTPMaxBytes(4000)).Fetch(srv.URL)
assert.NoError(t, err)
})
t.Run("case=with-cache", func(t *testing.T) {
var hits int32
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("toodaloo"))
atomic.AddInt32(&hits, 1)
}))
t.Cleanup(srv.Close)
cache, err := ristretto.NewCache[[]byte, []byte](&ristretto.Config[[]byte, []byte]{
NumCounters: 100 * 10,
MaxCost: 100,
BufferItems: 64,
})
require.NoError(t, err)
f := NewFetcher(WithCache(cache, time.Hour))
res, err := f.Fetch(srv.URL)
require.NoError(t, err)
require.Equal(t, "toodaloo", res.String())
require.EqualValues(t, 1, atomic.LoadInt32(&hits))
f.cache.Wait()
for i := 0; i < 100; i++ {
res2, err := f.Fetch(srv.URL)
require.NoError(t, err)
require.Equal(t, "toodaloo", res2.String())
if &res == &res2 {
t.Fatalf("cache should not return the same pointer")
}
}
require.EqualValues(t, 1, atomic.LoadInt32(&hits))
})
}

View File

@ -1,31 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package flagx
import (
"testing"
"github.com/spf13/cobra"
)
func TestStringToStringCommand(t *testing.T) {
cmd := &cobra.Command{}
cmd.Flags().StringToString("map-value", nil, "test string to string map usage")
cmd.SetArgs([]string{"--map-value", "foo=bar,key=val"})
cmd.Execute()
mapped := MustGetStringToStringMap(cmd, "map-value")
if len(mapped) != 2 {
t.Errorf("expected 2 values in map and got %d", len(mapped))
}
val, ok := mapped["foo"]
if !ok {
t.Errorf("failed to get value 'foo' from flags")
}
if val != "bar" {
t.Errorf("failed to get expected value from map, got %s", val)
}
}

View File

@ -1,123 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package fsx
import (
"testing"
"testing/fstest"
"github.com/laher/mergefs"
"github.com/stretchr/testify/assert"
)
var (
a = fstest.MapFS{
"a": &fstest.MapFile{},
"dir/c": &fstest.MapFile{},
}
b = fstest.MapFS{
"b": &fstest.MapFile{},
"dir/d": &fstest.MapFile{},
}
x = fstest.MapFS{
"x": &fstest.MapFile{},
"dir/y": &fstest.MapFile{},
}
)
func TestMergeFS(t *testing.T) {
assert.NoError(t, fstest.TestFS(
Merge(a, b),
"a",
"b",
"dir",
"dir/c",
"dir/d",
))
assert.NoError(t, fstest.TestFS(
Merge(a, b, x),
"a",
"b",
"dir",
"dir/c",
"dir/d",
"dir/y",
"x",
))
assert.NoError(t, fstest.TestFS(
Merge(x, b, a),
"a",
"b",
"dir",
"dir/c",
"dir/d",
"dir/y",
"x",
))
assert.NoError(t, fstest.TestFS(
Merge(Merge(a, b), x),
"a",
"b",
"dir",
"dir/c",
"dir/d",
"dir/y",
"x",
))
assert.NoError(t, fstest.TestFS(
Merge(Merge(x, b), a),
"a",
"b",
"dir",
"dir/c",
"dir/d",
"dir/y",
"x",
))
}
func TestLaherMergeFS(t *testing.T) {
assert.Error(t, fstest.TestFS(
mergefs.Merge(a, b),
"a",
"b",
"dir",
"dir/c",
"dir/d",
))
t.Skip("laher/mergefs does not handle recursive merges correctly")
assert.NoError(t, fstest.TestFS(
mergefs.Merge(mergefs.Merge(a, b), x),
"a",
"b",
"dir",
"dir/c",
"dir/d",
"dir/y",
"x",
))
assert.NoError(t, fstest.TestFS(
mergefs.Merge(a, mergefs.Merge(b, x)),
"a",
"b",
"dir",
"dir/c",
"dir/d",
"dir/y",
"x",
))
assert.NoError(t, fstest.TestFS(
mergefs.Merge(x, mergefs.Merge(b, a)),
"a",
"b",
"dir",
"dir/c",
"dir/d",
"dir/y",
"x",
))
}

View File

@ -1,275 +0,0 @@
package hasherx_test
import (
"context"
"crypto/rand"
"fmt"
"testing"
"github.com/inhies/go-bytesize"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/ory/x/hasherx"
)
func mkpw(t *testing.T, length int) []byte {
pw := make([]byte, length)
_, err := rand.Read(pw)
require.NoError(t, err)
return pw
}
func TestArgonHasher(t *testing.T) {
c := gomock.NewController(t)
t.Cleanup(c.Finish)
reg := NewMockArgon2Configurator(c)
reg.EXPECT().HasherArgon2Config(gomock.Any()).Return(&hasherx.Argon2Config{
Memory: bytesize.KB,
Iterations: 2,
Parallelism: 1,
SaltLength: 32,
KeyLength: 32,
}).AnyTimes()
for k, pw := range [][]byte{
mkpw(t, 8),
mkpw(t, 16),
mkpw(t, 32),
mkpw(t, 64),
mkpw(t, 128),
} {
k := k
pw := pw
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
t.Parallel()
for kk, h := range []hasherx.Hasher{
hasherx.NewHasherArgon2(reg),
} {
kk := kk
h := h
t.Run(fmt.Sprintf("hasher=%T/password=%d", h, kk), func(t *testing.T) {
t.Parallel()
hs, err := h.Generate(context.Background(), pw)
require.NoError(t, err)
assert.NotEqual(t, pw, hs)
t.Logf("hash: %s", hs)
require.NoError(t, hasherx.CompareArgon2id(context.Background(), pw, hs))
mod := make([]byte, len(pw))
copy(mod, pw)
mod[len(pw)-1] = ^pw[len(pw)-1]
require.Error(t, hasherx.CompareArgon2id(context.Background(), mod, hs))
})
}
})
}
}
func newBCryptRegistry(t *testing.T) *MockBCryptConfigurator {
c := gomock.NewController(t)
t.Cleanup(c.Finish)
reg := NewMockBCryptConfigurator(c)
reg.EXPECT().HasherBcryptConfig(gomock.Any()).Return(&hasherx.BCryptConfig{Cost: 4}).AnyTimes()
return reg
}
func TestBcryptHasherGeneratesErrorWhenPasswordIsLong(t *testing.T) {
hasher := hasherx.NewHasherBcrypt(newBCryptRegistry(t))
password := mkpw(t, 73)
res, err := hasher.Generate(context.Background(), password)
assert.Error(t, err, "password is too long")
assert.Nil(t, res)
}
func TestBcryptHasherGeneratesHash(t *testing.T) {
for k, pw := range [][]byte{
mkpw(t, 8),
mkpw(t, 16),
mkpw(t, 32),
mkpw(t, 64),
mkpw(t, 72),
} {
k := k
pw := pw
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
t.Parallel()
hasher := hasherx.NewHasherBcrypt(newBCryptRegistry(t))
hs, err := hasher.Generate(context.Background(), pw)
assert.Nil(t, err)
assert.True(t, hasher.Understands(hs))
// Valid format: $2a$12$[22 character salt][31 character hash]
assert.Equal(t, 60, len(string(hs)), "invalid bcrypt hash length")
assert.Equal(t, "$2a$04$", string(hs)[:7], "invalid bcrypt identifier")
})
}
}
func TestComparatorBcryptFailsWhenPasswordIsTooLong(t *testing.T) {
password := mkpw(t, 73)
err := hasherx.CompareBcrypt(context.Background(), password, []byte("hash"))
assert.Error(t, err, "password is too long")
}
func TestComparatorBcryptSuccess(t *testing.T) {
for k, pw := range [][]byte{
mkpw(t, 8),
mkpw(t, 16),
mkpw(t, 32),
mkpw(t, 64),
mkpw(t, 72),
} {
k := k
pw := pw
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
t.Parallel()
hasher := hasherx.NewHasherBcrypt(newBCryptRegistry(t))
hs, err := hasher.Generate(context.Background(), pw)
assert.Nil(t, err)
assert.True(t, hasher.Understands(hs))
err = hasherx.CompareBcrypt(context.Background(), pw, hs)
assert.Nil(t, err, "hash validation fails")
})
}
}
func TestComparatorBcryptFail(t *testing.T) {
for k, pw := range [][]byte{
mkpw(t, 8),
mkpw(t, 16),
mkpw(t, 32),
mkpw(t, 64),
mkpw(t, 72),
} {
k := k
pw := pw
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
t.Parallel()
mod := make([]byte, len(pw))
copy(mod, pw)
mod[len(pw)-1] = ^pw[len(pw)-1]
err := hasherx.CompareBcrypt(context.Background(), pw, mod)
assert.Error(t, err)
})
}
}
func TestPbkdf2Hasher(t *testing.T) {
for k, pw := range [][]byte{
mkpw(t, 8),
mkpw(t, 16),
mkpw(t, 32),
mkpw(t, 64),
mkpw(t, 128),
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
t.Parallel()
for kk, config := range []*hasherx.PBKDF2Config{
{
Algorithm: "sha1",
Iterations: 100000,
SaltLength: 32,
KeyLength: 32,
},
{
Algorithm: "sha224",
Iterations: 100000,
SaltLength: 32,
KeyLength: 32,
},
{
Algorithm: "sha256",
Iterations: 100000,
SaltLength: 32,
KeyLength: 32,
},
{
Algorithm: "sha384",
Iterations: 100000,
SaltLength: 32,
KeyLength: 32,
},
{
Algorithm: "sha512",
Iterations: 100000,
SaltLength: 32,
KeyLength: 32,
},
} {
kk := kk
config := config
t.Run(fmt.Sprintf("config=%T/password=%d", config.Algorithm, kk), func(t *testing.T) {
t.Parallel()
c := gomock.NewController(t)
t.Cleanup(c.Finish)
reg := NewMockPBKDF2Configurator(c)
reg.EXPECT().HasherPBKDF2Config(gomock.Any()).Return(config).AnyTimes()
hasher := hasherx.NewHasherPBKDF2(reg)
hs, err := hasher.Generate(context.Background(), pw)
require.NoError(t, err)
assert.NotEqual(t, pw, hs)
t.Logf("hash: %s", hs)
require.NoError(t, hasherx.ComparePbkdf2(context.Background(), pw, hs))
assert.True(t, hasher.Understands(hs))
mod := make([]byte, len(pw))
copy(mod, pw)
mod[len(pw)-1] = ^pw[len(pw)-1]
require.Error(t, hasherx.ComparePbkdf2(context.Background(), mod, hs))
})
}
})
}
}
func TestCompare(t *testing.T) {
assert.Error(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$unknown$12$o6hx.Wog/wvFSkT/Bp/6DOxCtLRTDj7lm9on9suF/WaCGNVHbkfL6")))
assert.Nil(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$2a$12$o6hx.Wog/wvFSkT/Bp/6DOxCtLRTDj7lm9on9suF/WaCGNVHbkfL6")))
assert.Nil(t, hasherx.CompareBcrypt(context.Background(), []byte("test"), []byte("$2a$12$o6hx.Wog/wvFSkT/Bp/6DOxCtLRTDj7lm9on9suF/WaCGNVHbkfL6")))
assert.Error(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$2a$12$o6hx.Wog/wvFSkT/Bp/6DOxCtLRTDj7lm9on9suF/WaCGNVHbkfL7")))
assert.Nil(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$2a$15$GRvRO2nrpYTEuPQX6AieaOlZ4.7nMGsXpt.QWMev1zrP86JNspZbO")))
assert.Nil(t, hasherx.CompareBcrypt(context.Background(), []byte("test"), []byte("$2a$15$GRvRO2nrpYTEuPQX6AieaOlZ4.7nMGsXpt.QWMev1zrP86JNspZbO")))
assert.Error(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$2a$15$GRvRO2nrpYTEuPQX6AieaOlZ4.7nMGsXpt.QWMev1zrP86JNspZb1")))
assert.Nil(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$argon2id$v=19$m=32,t=2,p=4$cm94YnRVOW5jZzFzcVE4bQ$MNzk5BtR2vUhrp6qQEjRNw")))
assert.Nil(t, hasherx.CompareArgon2id(context.Background(), []byte("test"), []byte("$argon2id$v=19$m=32,t=2,p=4$cm94YnRVOW5jZzFzcVE4bQ$MNzk5BtR2vUhrp6qQEjRNw")))
assert.Error(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$argon2id$v=19$m=32,t=2,p=4$cm94YnRVOW5jZzFzcVE4bQ$MNzk5BtR2vUhrp6qQEjRN2")))
assert.Nil(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$argon2i$v=19$m=65536,t=3,p=4$kk51rW/vxIVCYn+EG4kTSg$NyT88uraJ6im6dyha/M5jhXvpqlEdlS/9fEm7ScMb8c")))
assert.Nil(t, hasherx.CompareArgon2i(context.Background(), []byte("test"), []byte("$argon2i$v=19$m=65536,t=3,p=4$kk51rW/vxIVCYn+EG4kTSg$NyT88uraJ6im6dyha/M5jhXvpqlEdlS/9fEm7ScMb8c")))
assert.Error(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$argon2i$v=19$m=65536,t=3,p=4$pZ+27D6B0bCi0DwSmANF1w$4RNCUu4Uyu7eTIvzIdSuKz+I9idJlX/ykn6J10/W0EU")))
assert.Nil(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$argon2id$v=19$m=32,t=5,p=4$cm94YnRVOW5jZzFzcVE4bQ$fBxypOL0nP/zdPE71JtAV71i487LbX3fJI5PoTN6Lp4")))
assert.Nil(t, hasherx.CompareArgon2id(context.Background(), []byte("test"), []byte("$argon2id$v=19$m=32,t=5,p=4$cm94YnRVOW5jZzFzcVE4bQ$fBxypOL0nP/zdPE71JtAV71i487LbX3fJI5PoTN6Lp4")))
assert.Error(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$argon2id$v=19$m=32,t=5,p=4$cm94YnRVOW5jZzFzcVE4bQ$fBxypOL0nP/zdPE71JtAV71i487LbX3fJI5PoTN6Lp5")))
assert.Nil(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$pbkdf2-sha256$i=100000,l=32$1jP+5Zxpxgtee/iPxGgOz0RfE9/KJuDElP1ley4VxXc$QJxzfvdbHYBpydCbHoFg3GJEqMFULwskiuqiJctoYpI")))
assert.Nil(t, hasherx.ComparePbkdf2(context.Background(), []byte("test"), []byte("$pbkdf2-sha256$i=100000,l=32$1jP+5Zxpxgtee/iPxGgOz0RfE9/KJuDElP1ley4VxXc$QJxzfvdbHYBpydCbHoFg3GJEqMFULwskiuqiJctoYpI")))
assert.Error(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$pbkdf2-sha256$i=100000,l=32$1jP+5Zxpxgtee/iPxGgOz0RfE9/KJuDElP1ley4VxXc$QJxzfvdbHYBpydCbHoFg3GJEqMFULwskiuqiJctoYpp")))
assert.Nil(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$pbkdf2-sha512$i=100000,l=32$bdHBpn7OWOivJMVJypy2UqR0UnaD5prQXRZevj/05YU$+wArTfv1a+bNGO1iZrmEdVjhA+lL11wF4/IxpgYfPwc")))
assert.Nil(t, hasherx.ComparePbkdf2(context.Background(), []byte("test"), []byte("$pbkdf2-sha512$i=100000,l=32$bdHBpn7OWOivJMVJypy2UqR0UnaD5prQXRZevj/05YU$+wArTfv1a+bNGO1iZrmEdVjhA+lL11wF4/IxpgYfPwc")))
assert.Error(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$pbkdf2-sha512$i=100000,l=32$bdHBpn7OWOivJMVJypy2UqR0UnaD5prQXRZevj/05YU$+wArTfv1a+bNGO1iZrmEdVjhA+lL11wF4/IxpgYfPww")))
assert.Error(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$pbkdf2-sha256$1jP+5Zxpxgtee/iPxGgOz0RfE9/KJuDElP1ley4VxXc$QJxzfvdbHYBpydCbHoFg3GJEqMFULwskiuqiJctoYpI")))
assert.Error(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$pbkdf2-sha256$aaaa$1jP+5Zxpxgtee/iPxGgOz0RfE9/KJuDElP1ley4VxXc$QJxzfvdbHYBpydCbHoFg3GJEqMFULwskiuqiJctoYpI")))
assert.Error(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$pbkdf2-sha256$i=100000,l=32$1jP+5Zxpxgtee/iPxGgOz0RfE9/KJuDElP1ley4VxXcc$QJxzfvdbHYBpydCbHoFg3GJEqMFULwskiuqiJctoYpI")))
assert.Error(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$pbkdf2-sha256$i=100000,l=32$1jP+5Zxpxgtee/iPxGgOz0RfE9/KJuDElP1ley4VxXc$QJxzfvdbHYBpydCbHoFg3GJEqMFULwskiuqiJctoYpII")))
assert.Error(t, hasherx.Compare(context.Background(), []byte("test"), []byte("$pbkdf2-sha512$I=100000,l=32$bdHBpn7OWOivJMVJypy2UqR0UnaD5prQXRZevj/05YU$+wArTfv1a+bNGO1iZrmEdVjhA+lL11wF4/IxpgYfPwc")))
}

View File

@ -1,50 +0,0 @@
package hasherx_test
import (
"context"
"fmt"
"testing"
"time"
"go.uber.org/mock/gomock"
"github.com/ory/x/hasherx"
"github.com/ory/x/randx"
)
func TestPBKDF2Performance(t *testing.T) {
for _, iters := range []uint32{
100, 1000, 10000, 25000, 100000, 1000000,
} {
t.Run(fmt.Sprintf("%d", iters), func(t *testing.T) {
runPBKDF2(t, iters, 100)
})
}
}
func runPBKDF2(t *testing.T, iterations uint32, hashCount uint32) {
c := gomock.NewController(t)
t.Cleanup(c.Finish)
reg := NewMockPBKDF2Configurator(c)
reg.EXPECT().HasherPBKDF2Config(gomock.Any()).Return(&hasherx.PBKDF2Config{
Algorithm: "sha256",
Iterations: iterations,
SaltLength: 32,
KeyLength: 32,
}).AnyTimes()
pw := randx.MustString(16, randx.AlphaLower)
hasher := hasherx.NewHasherPBKDF2(reg)
ctx := context.Background()
var err error
start := time.Now()
for i := uint32(0); i < hashCount; i++ {
if _, err = hasher.Generate(ctx, []byte(pw)); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
end := time.Now()
diff := end.Sub(start).Round(time.Millisecond)
t.Logf("%d hashes in %s with %d iterations, %dms per hash", hashCount, diff, iterations, diff.Milliseconds()/int64(hashCount))
}

View File

@ -1,57 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ory/x/hasherx (interfaces: Argon2Configurator)
//
// Generated by this command:
//
// mockgen -package hasherx_test -destination hasherx/mocks_argon2_test.go github.com/ory/x/hasherx Argon2Configurator
//
// Package hasherx_test is a generated GoMock package.
package hasherx_test
import (
context "context"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
hasherx "github.com/ory/x/hasherx"
)
// MockArgon2Configurator is a mock of Argon2Configurator interface.
type MockArgon2Configurator struct {
ctrl *gomock.Controller
recorder *MockArgon2ConfiguratorMockRecorder
isgomock struct{}
}
// MockArgon2ConfiguratorMockRecorder is the mock recorder for MockArgon2Configurator.
type MockArgon2ConfiguratorMockRecorder struct {
mock *MockArgon2Configurator
}
// NewMockArgon2Configurator creates a new mock instance.
func NewMockArgon2Configurator(ctrl *gomock.Controller) *MockArgon2Configurator {
mock := &MockArgon2Configurator{ctrl: ctrl}
mock.recorder = &MockArgon2ConfiguratorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockArgon2Configurator) EXPECT() *MockArgon2ConfiguratorMockRecorder {
return m.recorder
}
// HasherArgon2Config mocks base method.
func (m *MockArgon2Configurator) HasherArgon2Config(ctx context.Context) *hasherx.Argon2Config {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HasherArgon2Config", ctx)
ret0, _ := ret[0].(*hasherx.Argon2Config)
return ret0
}
// HasherArgon2Config indicates an expected call of HasherArgon2Config.
func (mr *MockArgon2ConfiguratorMockRecorder) HasherArgon2Config(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasherArgon2Config", reflect.TypeOf((*MockArgon2Configurator)(nil).HasherArgon2Config), ctx)
}

View File

@ -1,57 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ory/x/hasherx (interfaces: BCryptConfigurator)
//
// Generated by this command:
//
// mockgen -package hasherx_test -destination hasherx/mocks_bcrypt_test.go github.com/ory/x/hasherx BCryptConfigurator
//
// Package hasherx_test is a generated GoMock package.
package hasherx_test
import (
context "context"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
hasherx "github.com/ory/x/hasherx"
)
// MockBCryptConfigurator is a mock of BCryptConfigurator interface.
type MockBCryptConfigurator struct {
ctrl *gomock.Controller
recorder *MockBCryptConfiguratorMockRecorder
isgomock struct{}
}
// MockBCryptConfiguratorMockRecorder is the mock recorder for MockBCryptConfigurator.
type MockBCryptConfiguratorMockRecorder struct {
mock *MockBCryptConfigurator
}
// NewMockBCryptConfigurator creates a new mock instance.
func NewMockBCryptConfigurator(ctrl *gomock.Controller) *MockBCryptConfigurator {
mock := &MockBCryptConfigurator{ctrl: ctrl}
mock.recorder = &MockBCryptConfiguratorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockBCryptConfigurator) EXPECT() *MockBCryptConfiguratorMockRecorder {
return m.recorder
}
// HasherBcryptConfig mocks base method.
func (m *MockBCryptConfigurator) HasherBcryptConfig(ctx context.Context) *hasherx.BCryptConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HasherBcryptConfig", ctx)
ret0, _ := ret[0].(*hasherx.BCryptConfig)
return ret0
}
// HasherBcryptConfig indicates an expected call of HasherBcryptConfig.
func (mr *MockBCryptConfiguratorMockRecorder) HasherBcryptConfig(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasherBcryptConfig", reflect.TypeOf((*MockBCryptConfigurator)(nil).HasherBcryptConfig), ctx)
}

View File

@ -1,57 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ory/x/hasherx (interfaces: PBKDF2Configurator)
//
// Generated by this command:
//
// mockgen -package hasherx_test -destination hasherx/mocks_pkdbf2_test.go github.com/ory/x/hasherx PBKDF2Configurator
//
// Package hasherx_test is a generated GoMock package.
package hasherx_test
import (
context "context"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
hasherx "github.com/ory/x/hasherx"
)
// MockPBKDF2Configurator is a mock of PBKDF2Configurator interface.
type MockPBKDF2Configurator struct {
ctrl *gomock.Controller
recorder *MockPBKDF2ConfiguratorMockRecorder
isgomock struct{}
}
// MockPBKDF2ConfiguratorMockRecorder is the mock recorder for MockPBKDF2Configurator.
type MockPBKDF2ConfiguratorMockRecorder struct {
mock *MockPBKDF2Configurator
}
// NewMockPBKDF2Configurator creates a new mock instance.
func NewMockPBKDF2Configurator(ctrl *gomock.Controller) *MockPBKDF2Configurator {
mock := &MockPBKDF2Configurator{ctrl: ctrl}
mock.recorder = &MockPBKDF2ConfiguratorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPBKDF2Configurator) EXPECT() *MockPBKDF2ConfiguratorMockRecorder {
return m.recorder
}
// HasherPBKDF2Config mocks base method.
func (m *MockPBKDF2Configurator) HasherPBKDF2Config(ctx context.Context) *hasherx.PBKDF2Config {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HasherPBKDF2Config", ctx)
ret0, _ := ret[0].(*hasherx.PBKDF2Config)
return ret0
}
// HasherPBKDF2Config indicates an expected call of HasherPBKDF2Config.
func (mr *MockPBKDF2ConfiguratorMockRecorder) HasherPBKDF2Config(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasherPBKDF2Config", reflect.TypeOf((*MockPBKDF2Configurator)(nil).HasherPBKDF2Config), ctx)
}

View File

@ -1,191 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package healthx
import (
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/herodot"
)
func TestHealth(t *testing.T) {
const mockHeaderKey = "middleware-header"
const mockHeaderValue = "test-header-value"
const mockVersion = "test version"
// middlware to run an assert function on the requested handler
testMiddleware := func(t *testing.T, assertFunc func(*testing.T, http.ResponseWriter, *http.Request)) func(next http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
writer.Header().Add(mockHeaderKey, mockHeaderValue)
assertFunc(t, writer, req)
h.ServeHTTP(writer, req)
})
}
}
assertAliveCheck := func(t *testing.T, endpoint string, handler *Handler) *http.Response {
var healthBody swaggerHealthStatus
c := http.DefaultClient
response, err := c.Get(endpoint)
require.NoError(t, err)
require.EqualValues(t, http.StatusOK, response.StatusCode)
require.NoError(t, json.NewDecoder(response.Body).Decode(&healthBody))
assert.EqualValues(t, "ok", healthBody.Status)
return response
}
assertVersionResponse := func(t *testing.T, endpoint string, handler *Handler) *http.Response {
var versionBody swaggerVersion
c := http.DefaultClient
response, err := c.Get(endpoint)
require.NoError(t, err)
require.EqualValues(t, http.StatusOK, response.StatusCode)
require.NoError(t, json.NewDecoder(response.Body).Decode(&versionBody))
require.EqualValues(t, mockVersion, versionBody.Version)
return response
}
assertReadyCheckNotAlive := func(t *testing.T, endpoint string, handler *Handler) *http.Response {
handler.ReadyChecks = map[string]ReadyChecker{
"test": func(r *http.Request) error {
return errors.New("not alive")
},
}
c := http.DefaultClient
response, err := c.Get(endpoint)
require.NoError(t, err)
require.EqualValues(t, http.StatusServiceUnavailable, response.StatusCode)
out, err := io.ReadAll(response.Body)
require.NoError(t, err)
assert.Equal(t, "{\"error\":{\"code\":500,\"status\":\"Internal Server Error\",\"message\":\"not alive\"}}", strings.TrimSpace(string(out)))
return response
}
assertReadyCheck := func(t *testing.T, endpoint string, handler *Handler) *http.Response {
var healthCheck swaggerHealthStatus
c := http.DefaultClient
response, err := c.Get(endpoint)
require.NoError(t, err)
require.EqualValues(t, http.StatusOK, response.StatusCode)
require.NoError(t, json.NewDecoder(response.Body).Decode(&healthCheck))
require.EqualValues(t, swaggerHealthStatus{Status: "ok"}, healthCheck)
return response
}
testCases := []struct {
description string
url func(mockServerURL string) string
test func(t *testing.T, endpoint string, handler *Handler) *http.Response
}{
{
description: "ready check should return status ok",
url: func(mockServerURL string) string {
return mockServerURL + ReadyCheckPath
},
test: assertReadyCheck,
},
{
description: "ready check should return error",
url: func(mockServerURL string) string {
return mockServerURL + ReadyCheckPath
},
test: assertReadyCheckNotAlive,
},
{
description: "alive check should return status ok",
url: func(mockServerURL string) string {
return mockServerURL + AliveCheckPath
},
test: assertAliveCheck,
},
{
description: "version should return",
url: func(mockServerURL string) string {
return mockServerURL + VersionPath
},
test: assertVersionResponse,
},
}
t.Run("case=without middleware", func(t *testing.T) {
router := httprouter.New()
handler := &Handler{
H: herodot.NewJSONWriter(nil),
VersionString: mockVersion,
ReadyChecks: map[string]ReadyChecker{
"test": func(r *http.Request) error {
return nil
},
},
}
ts := httptest.NewServer(router)
defer ts.Close()
handler.SetHealthRoutes(router, true)
handler.SetVersionRoutes(router)
for _, tc := range testCases {
t.Run("case="+tc.description, func(t *testing.T) {
tc.test(t, tc.url(ts.URL), handler)
})
}
})
t.Run("case=with middleware", func(t *testing.T) {
router := httprouter.New()
var alive error
handler := &Handler{
H: herodot.NewJSONWriter(nil),
VersionString: mockVersion,
ReadyChecks: map[string]ReadyChecker{
"test": func(r *http.Request) error {
return alive
},
},
}
ts := httptest.NewServer(router)
defer ts.Close()
// set the health handlers with middleware
handler.SetHealthRoutes(router, true, WithMiddleware(
testMiddleware(t, func(t *testing.T, rw http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
}),
))
handler.SetVersionRoutes(router, WithMiddleware(
testMiddleware(t, func(t *testing.T, rw http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
}),
))
for _, tc := range testCases {
t.Run("case="+tc.description, func(t *testing.T) {
handler.ReadyChecks = map[string]ReadyChecker{
"test": func(r *http.Request) error {
return nil
},
}
response := tc.test(t, tc.url(ts.URL), handler)
assert.EqualValues(t, mockHeaderValue, response.Header.Get(mockHeaderKey))
})
}
})
}

View File

@ -1,67 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package httprouterx_test
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
x "github.com/ory/x/httprouterx"
"github.com/ory/x/urlx"
)
func TestRedirectToPublicAdminRoute(t *testing.T) {
var ts *httptest.Server
router := x.NewRouterAdminWithPrefix("/admin", func(ctx context.Context) *url.URL {
return urlx.ParseOrPanic(ts.URL)
})
ts = httptest.NewServer(router)
t.Cleanup(ts.Close)
router.POST("/privileged", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
body, _ := io.ReadAll(r.Body)
w.Write(body)
})
router.POST("/read", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
body, _ := io.ReadAll(r.Body)
w.Write(body)
})
for _, tc := range []struct {
source string
dest string
}{
{
source: ts.URL + "/admin/privileged?foo=bar",
dest: ts.URL + "/admin/privileged?foo=bar",
},
{
source: ts.URL + "/privileged?foo=bar",
dest: ts.URL + "/admin/privileged?foo=bar",
},
} {
t.Run(fmt.Sprintf("source=%s", tc.source), func(t *testing.T) {
id := uuid.Must(uuid.NewV4()).String()
res, err := ts.Client().Post(tc.source, "", strings.NewReader(id))
require.NoError(t, err)
assert.EqualValues(t, http.StatusOK, res.StatusCode)
assert.Equal(t, tc.dest, res.Request.URL.String())
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, id, string(body))
})
}
}

View File

@ -1,81 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package httprouterx
import (
"context"
"net/http"
"net/url"
"testing"
"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewRouterAdmin(t *testing.T) {
require.NotEmpty(t, NewRouterAdmin())
require.NotEmpty(t, NewRouterPublic())
}
func TestCacheHandling(t *testing.T) {
router := NewRouterPublic()
ts := httptest.NewServer(router)
t.Cleanup(ts.Close)
router.GET("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
w.WriteHeader(http.StatusNoContent)
})
router.DELETE("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
w.WriteHeader(http.StatusNoContent)
})
router.POST("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
w.WriteHeader(http.StatusNoContent)
})
router.PUT("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
w.WriteHeader(http.StatusNoContent)
})
router.PATCH("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
w.WriteHeader(http.StatusNoContent)
})
for _, method := range []string{} {
req, _ := http.NewRequest(method, ts.URL+"/foo", nil)
res, err := ts.Client().Do(req)
require.NoError(t, err)
assert.EqualValues(t, "0", res.Header.Get("Cache-Control"))
}
}
func TestAdminPrefix(t *testing.T) {
router := NewRouterAdminWithPrefix("/admin", func(ctx context.Context) *url.URL {
return &url.URL{Path: "https://www.ory.sh/"}
})
ts := httptest.NewServer(router)
t.Cleanup(ts.Close)
router.GET("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
w.WriteHeader(http.StatusNoContent)
})
router.DELETE("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
w.WriteHeader(http.StatusNoContent)
})
router.POST("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
w.WriteHeader(http.StatusNoContent)
})
router.PUT("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
w.WriteHeader(http.StatusNoContent)
})
router.PATCH("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
w.WriteHeader(http.StatusNoContent)
})
for _, method := range []string{} {
req, _ := http.NewRequest(method, ts.URL+"/admin/foo", nil)
res, err := ts.Client().Do(req)
require.NoError(t, err)
assert.EqualValues(t, http.StatusNoContent, res.StatusCode)
}
}

View File

@ -1,32 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package httpx
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestChanHandler(t *testing.T) {
h, c := NewChanHandler(1)
s := httptest.NewServer(h)
c <- func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(555)
}
resp, err := s.Client().Get(s.URL)
require.NoError(t, err)
assert.Equal(t, 555, resp.StatusCode)
c <- func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(337)
}
resp, err = s.Client().Get(s.URL)
require.NoError(t, err)
assert.Equal(t, 337, resp.StatusCode)
}

View File

@ -1,101 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package httpx
import (
"context"
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIgnoresInternalIPs(t *testing.T) {
input := "54.155.246.232,10.145.1.10"
res, err := GetClientIPAddressesWithoutInternalIPs(strings.Split(input, ","))
require.NoError(t, err)
assert.Equal(t, "54.155.246.232", res)
}
func TestEmptyInputArray(t *testing.T) {
res, err := GetClientIPAddressesWithoutInternalIPs([]string{})
require.NoError(t, err)
assert.Equal(t, "", res)
}
func TestClientIP(t *testing.T) {
req := http.Request{
RemoteAddr: "1.0.0.4",
Header: http.Header{},
}
req.Header.Add("true-client-ip", "1.0.0.1")
req.Header.Add("cf-connecting-ip", "1.0.0.2")
req.Header.Add("x-real-ip", "1.0.0.3")
req.Header.Add("x-forwarded-for", "192.168.1.1,1.0.0.3,10.0.0.1")
t.Run("true-client-ip", func(t *testing.T) {
req := req.Clone(context.Background())
assert.Equal(t, "1.0.0.1", ClientIP(req))
})
t.Run("cf-connecting-ip", func(t *testing.T) {
req := req.Clone(context.Background())
req.Header.Del("true-client-ip")
assert.Equal(t, "1.0.0.2", ClientIP(req))
})
t.Run("x-real-ip", func(t *testing.T) {
req := req.Clone(context.Background())
req.Header.Del("true-client-ip")
req.Header.Del("cf-connecting-ip")
assert.Equal(t, "1.0.0.3", ClientIP(req))
})
t.Run("x-forwarded-for", func(t *testing.T) {
req := req.Clone(context.Background())
req.Header.Del("true-client-ip")
req.Header.Del("cf-connecting-ip")
req.Header.Del("x-real-ip")
assert.Equal(t, "1.0.0.3", ClientIP(req))
})
t.Run("remote-addr", func(t *testing.T) {
req := req.Clone(context.Background())
req.Header.Del("true-client-ip")
req.Header.Del("cf-connecting-ip")
req.Header.Del("x-real-ip")
req.Header.Del("x-forwarded-for")
assert.Equal(t, "1.0.0.4", ClientIP(req))
})
}
func TestClientGeoLocation(t *testing.T) {
req := http.Request{
Header: http.Header{},
}
req.Header.Add("cf-ipcity", "Berlin")
req.Header.Add("cf-ipcountry", "Germany")
req.Header.Add("cf-region-code", "BE")
t.Run("cf-ipcity", func(t *testing.T) {
req := req.Clone(context.Background())
assert.Equal(t, "Berlin", ClientGeoLocation(req).City)
})
t.Run("cf-ipcountry", func(t *testing.T) {
req := req.Clone(context.Background())
assert.Equal(t, "Germany", ClientGeoLocation(req).Country)
})
t.Run("cf-region-code", func(t *testing.T) {
req := req.Clone(context.Background())
assert.Equal(t, "BE", ClientGeoLocation(req).Region)
})
t.Run("empty", func(t *testing.T) {
req := req.Clone(context.Background())
req.Header.Del("cf-ipcity")
req.Header.Del("cf-ipcountry")
req.Header.Del("cf-region-code")
assert.Equal(t, GeoLocation{}, *ClientGeoLocation(req))
})
}

View File

@ -1,23 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package httpx
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
func TestHasContentType(t *testing.T) {
assert.True(t, HasContentType(&http.Request{Header: map[string][]string{}}, "application/octet-stream"))
assert.False(t, HasContentType(&http.Request{Header: map[string][]string{}}, "not-application/octet-stream"))
assert.True(t, HasContentType(&http.Request{Header: map[string][]string{"Content-Type": {"application/octet-stream"}}}, "application/octet-stream"))
// Invalid conent types
assert.False(t, HasContentType(&http.Request{Header: map[string][]string{"Content-Type": {"application/octet-stream, not-application/application"}}}, "not-application/application"))
assert.False(t, HasContentType(&http.Request{Header: map[string][]string{"Content-Type": {"application/octet-stream,not-application/application"}}}, "not-application/application"))
assert.False(t, HasContentType(&http.Request{Header: map[string][]string{"Content-Type": {"application/octet-stream, application/not-application"}}}, "not-application/not-octet-stream"))
assert.False(t, HasContentType(&http.Request{Header: map[string][]string{"Content-Type": {"a"}}}, "not-application/not-octet-stream"))
}

View File

@ -1,54 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package httpx
import (
"bytes"
gzip2 "compress/gzip"
"encoding/json"
"net/http"
"testing"
"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
)
func makeRequest(t *testing.T, data string, ts *httptest.Server) {
var buf bytes.Buffer
gzip := gzip2.NewWriter(&buf)
_, err := gzip.Write([]byte(data))
require.NoError(t, err)
require.NoError(t, gzip.Close())
c := http.Client{}
req, err := http.NewRequest("POST", ts.URL, &buf)
req.Header.Set("Content-Encoding", "gzip")
require.NoError(t, err)
res, err := c.Do(req)
require.NoError(t, err)
res.Body.Close()
assert.EqualValues(t, http.StatusNoContent, res.StatusCode)
}
func TestGZipServer(t *testing.T) {
router := httprouter.New()
router.POST("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
var f json.RawMessage
require.NoError(t, json.NewDecoder(r.Body).Decode(&f))
t.Logf("%s", f)
w.WriteHeader(http.StatusNoContent)
})
n := negroni.New(NewCompressionRequestReader(func(w http.ResponseWriter, r *http.Request, err error) {
require.NoError(t, err)
}))
n.UseHandler(router)
ts := httptest.NewServer(n)
defer ts.Close()
makeRequest(t, "true", ts)
}

View File

@ -1,107 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package httpx
import (
"net/http"
"testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsAssociatedIPAllowed(t *testing.T) {
for _, disallowed := range []string{
"localhost",
"https://localhost/foo?bar=baz#zab",
"127.0.0.0",
"127.255.255.255",
"172.16.0.0",
"172.31.255.255",
"192.168.0.0",
"192.168.255.255",
"10.0.0.0",
"0.0.0.0",
"10.255.255.255",
"::1",
"100::1",
"fe80::1",
"169.254.169.254", // AWS instance metadata service
} {
t.Run("case="+disallowed, func(t *testing.T) {
assert.Error(t, DisallowIPPrivateAddresses(disallowed))
})
}
}
func TestDisallowLocalIPAddressesWhenSet(t *testing.T) {
require.NoError(t, DisallowIPPrivateAddresses(""))
require.Error(t, DisallowIPPrivateAddresses("127.0.0.1"))
require.ErrorAs(t, DisallowIPPrivateAddresses("127.0.0.1"), new(ErrPrivateIPAddressDisallowed))
}
type noOpRoundTripper struct{}
func (n noOpRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
return &http.Response{}, nil
}
var _ http.RoundTripper = new(noOpRoundTripper)
type errRoundTripper struct{ err error }
var errNotOnWhitelist = errors.New("OK")
var errOnWhitelist = errors.New("OK (on whitelist)")
func (n errRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
return nil, n.err
}
var _ http.RoundTripper = new(errRoundTripper)
// TestInternalRespectsRoundTripper tests if the RoundTripper picks the correct
// underlying transport for two allowed requests.
func TestInternalRespectsRoundTripper(t *testing.T) {
rt := &noInternalIPRoundTripper{
onWhitelist: &errRoundTripper{errOnWhitelist},
notOnWhitelist: &errRoundTripper{errNotOnWhitelist},
internalIPExceptions: []string{
"https://127.0.0.1/foo",
}}
req, err := http.NewRequest("GET", "https://google.com/foo", nil)
require.NoError(t, err)
_, err = rt.RoundTrip(req)
require.ErrorIs(t, err, errNotOnWhitelist)
req, err = http.NewRequest("GET", "https://127.0.0.1/foo", nil)
require.NoError(t, err)
_, err = rt.RoundTrip(req)
require.ErrorIs(t, err, errOnWhitelist)
}
func TestAllowExceptions(t *testing.T) {
rt := noInternalIPRoundTripper{
onWhitelist: &errRoundTripper{errOnWhitelist},
notOnWhitelist: &errRoundTripper{errNotOnWhitelist},
internalIPExceptions: []string{
"http://localhost/asdf",
}}
req, err := http.NewRequest("GET", "http://localhost/asdf", nil)
require.NoError(t, err)
_, err = rt.RoundTrip(req)
require.ErrorIs(t, err, errOnWhitelist)
req, err = http.NewRequest("GET", "http://localhost/not-asdf", nil)
require.NoError(t, err)
_, err = rt.RoundTrip(req)
require.ErrorIs(t, err, errNotOnWhitelist)
req, err = http.NewRequest("GET", "http://127.0.0.1", nil)
require.NoError(t, err)
_, err = rt.RoundTrip(req)
require.ErrorIs(t, err, errNotOnWhitelist)
}

View File

@ -1,130 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package httpx
import (
"context"
"net"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"net/netip"
"net/url"
"sync/atomic"
"testing"
"github.com/hashicorp/go-retryablehttp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNoPrivateIPs(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("Hello, world!"))
}))
t.Cleanup(ts.Close)
target, err := url.ParseRequestURI(ts.URL)
require.NoError(t, err)
_, port, err := net.SplitHostPort(target.Host)
require.NoError(t, err)
allowedURL := "http://localhost:" + port + "/foobar"
allowedGlob := "http://localhost:" + port + "/glob/*"
c := NewResilientClient(
ResilientClientWithMaxRetry(1),
ResilientClientDisallowInternalIPs(),
ResilientClientAllowInternalIPRequestsTo(allowedURL, allowedGlob),
)
for i := 0; i < 10; i++ {
for destination, passes := range map[string]bool{
"http://127.0.0.1:" + port: false,
"http://localhost:" + port: false,
"http://192.168.178.5:" + port: false,
allowedURL: true,
"http://localhost:" + port + "/glob/bar": true,
"http://localhost:" + port + "/glob/bar/baz": false,
"http://localhost:" + port + "/FOOBAR": false,
} {
_, err := c.Get(destination)
if !passes {
require.Errorf(t, err, "dest = %s", destination)
assert.Containsf(t, err.Error(), "is not a permitted destination", "dest = %s", destination)
} else {
require.NoErrorf(t, err, "dest = %s", destination)
}
}
}
}
func TestNoIPV6(t *testing.T) {
for _, tc := range []struct {
name string
c *retryablehttp.Client
}{
{
"internal IPs allowed",
NewResilientClient(
ResilientClientWithMaxRetry(1),
ResilientClientNoIPv6(),
),
}, {
"internal IPs disallowed",
NewResilientClient(
ResilientClientWithMaxRetry(1),
ResilientClientDisallowInternalIPs(),
ResilientClientNoIPv6(),
),
},
} {
t.Run(tc.name, func(t *testing.T) {
var connectDone int32
ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
DNSDone: func(dnsInfo httptrace.DNSDoneInfo) {
for _, ip := range dnsInfo.Addrs {
netIP, ok := netip.AddrFromSlice(ip.IP)
assert.True(t, ok)
assert.Truef(t, netIP.Is4(), "ip = %s", ip)
}
},
ConnectDone: func(network, addr string, err error) {
atomic.AddInt32(&connectDone, 1)
assert.NoError(t, err)
assert.Equalf(t, "tcp4", network, "network = %s addr = %s", network, addr)
},
})
// Dual stack
req, err := retryablehttp.NewRequestWithContext(ctx, "GET", "http://dual.tlund.se/", nil)
require.NoError(t, err)
atomic.StoreInt32(&connectDone, 0)
res, err := tc.c.Do(req)
require.GreaterOrEqual(t, int32(1), atomic.LoadInt32(&connectDone))
require.NoError(t, err)
t.Cleanup(func() { _ = res.Body.Close() })
require.EqualValues(t, http.StatusOK, res.StatusCode)
// IPv4 only
req, err = retryablehttp.NewRequestWithContext(ctx, "GET", "http://ipv4.tlund.se/", nil)
require.NoError(t, err)
atomic.StoreInt32(&connectDone, 0)
res, err = tc.c.Do(req)
require.EqualValues(t, 1, atomic.LoadInt32(&connectDone))
require.NoError(t, err)
t.Cleanup(func() { _ = res.Body.Close() })
require.EqualValues(t, http.StatusOK, res.StatusCode)
// IPv6 only
req, err = retryablehttp.NewRequestWithContext(ctx, "GET", "http://ipv6.tlund.se/", nil)
require.NoError(t, err)
atomic.StoreInt32(&connectDone, 0)
_, err = tc.c.Do(req)
require.EqualValues(t, 0, atomic.LoadInt32(&connectDone))
require.ErrorContains(t, err, "no such host")
})
}
}

View File

@ -1,28 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package httpx_test
import (
"crypto/tls"
"net/http"
"testing"
"github.com/ory/x/httpx"
"github.com/stretchr/testify/assert"
"github.com/ory/x/urlx"
)
func TestIncomingRequestURL(t *testing.T) {
assert.EqualValues(t, httpx.IncomingRequestURL(&http.Request{
URL: urlx.ParseOrPanic("/foo"), Host: "foobar", TLS: &tls.ConnectionState{},
}).String(), "https://foobar/foo")
assert.EqualValues(t, httpx.IncomingRequestURL(&http.Request{
URL: urlx.ParseOrPanic("/foo"), Host: "foobar",
}).String(), "http://foobar/foo")
assert.EqualValues(t, httpx.IncomingRequestURL(&http.Request{
URL: urlx.ParseOrPanic("/foo"), Host: "foobar", Header: http.Header{"X-Forwarded-Host": []string{"notfoobar"}, "X-Forwarded-Proto": {"https"}},
}).String(), "https://notfoobar/foo")
}

View File

@ -1,43 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package ipx
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIsAssociatedIPAllowed(t *testing.T) {
for _, disallowed := range []string{
"localhost",
"https://localhost/foo?bar=baz#zab",
"127.0.0.0",
"127.255.255.255",
"172.16.0.0",
"172.31.255.255",
"192.168.0.0",
"192.168.255.255",
"10.0.0.0",
"10.255.255.255",
"::1",
} {
t.Run("case="+disallowed, func(t *testing.T) {
require.Error(t, IsAssociatedIPAllowed(disallowed))
})
}
// Do not error if invalid data is used
require.NoError(t, IsAssociatedIPAllowed("idonotexist"))
require.NoError(t, IsAssociatedIPAllowedWhenSet(""))
require.NoError(t, AreAllAssociatedIPsAllowed(map[string]string{
"foo": "https://google.com",
"bar": "microsoft.com",
}))
require.Error(t, AreAllAssociatedIPsAllowed(map[string]string{
"foo": "https://google.com",
"bar": "microsoft.com",
"baz": "localhost",
}))
}

View File

@ -1,386 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jsonnetsecure
import (
"bufio"
"errors"
"fmt"
"math/rand"
"os/exec"
"runtime"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/google/go-jsonnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
func ensureChildProcessStoppedEarly(t testing.TB, err error) {
t.Helper()
require.Error(t, err)
// The actual string is OS-specific and our tests run on all major ones.
// Additionally the child process may have stopped/been stopped for a variety of reasons,
// depending on which limit was hit first.
errStr := err.Error()
require.True(t,
// Killed by the parent or the OS (due to hitting the memory limit).
strings.Contains(errStr, "reached limits") ||
strings.Contains(errStr, "killed") ||
// The Go runtime hit the memory limit and quit.
strings.Contains(errStr, "cannot allocate memory") ||
strings.Contains(errStr, "out of memory") ||
// Invalid input.
strings.Contains(errStr, "encountered an error") ||
// Timeout.
strings.Contains(errStr, "deadline exceeded") ||
// Too much output (this error comes from `bufio.Scanner` which has its own internal limit).
strings.Contains(errStr, "token too long"),
errStr,
)
var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
assert.NotEqual(t, exitErr.ProcessState.ExitCode(), 0)
}
}
func TestSecureVM(t *testing.T) {
testBinary := JsonnetTestBinary(t)
for _, optCase := range []struct {
name string
opts []Option
}{
{"none", []Option{}},
{"process pool vm", []Option{
WithProcessPool(procPool),
WithJsonnetBinary(testBinary),
}},
} {
t.Run("options="+optCase.name, func(t *testing.T) {
for i, contents := range []string{
"local contents = importstr 'jsonnet.go'; { contents: contents }",
"local contents = import 'stub/import.jsonnet'; { contents: contents }",
`{user_id: ` + strings.Repeat("a", jsonnetErrLimit*5),
} {
t.Run(fmt.Sprintf("case=%d", i), func(t *testing.T) {
vm := MakeSecureVM(optCase.opts...)
result, err := vm.EvaluateAnonymousSnippet("test", contents)
require.Error(t, err, "%s", result)
})
}
})
}
// Test that all VM behave the same for sane input
t.Run("suite=feature parity", func(t *testing.T) {
t.Run("case=simple input", func(t *testing.T) {
// from https://jsonnet.org/learning/tutorial.html
snippet := `
/* A C-style comment. */
# A Python-style comment.
{
cocktails: {
// Ingredient quantities are in fl oz.
'Tom Collins': {
ingredients: [
{ kind: "Farmer's Gin", qty: 1.5 },
{ kind: 'Lemon', qty: 1 },
{ kind: 'Simple Syrup', qty: 0.5 },
{ kind: 'Soda', qty: 2 },
{ kind: 'Angostura', qty: 'dash' },
],
garnish: 'Maraschino Cherry',
served: 'Tall',
description: |||
The Tom Collins is essentially gin and
lemonade. The bitters add complexity.
|||,
},
Manhattan: {
ingredients: [
{ kind: 'Rye', qty: 2.5 },
{ kind: 'Sweet Red Vermouth', qty: 1 },
{ kind: 'Angostura', qty: 'dash' },
],
garnish: 'Maraschino Cherry',
served: 'Straight Up',
description: @'A clear \ red drink.',
},
},
}`
assertEqualVMOutput(t, func(factory func(t *testing.T) VM) string {
vm := factory(t)
out, err := vm.EvaluateAnonymousSnippet("test", snippet)
assert.NoError(t, err)
return out
})
})
t.Run("case=ext variables", func(t *testing.T) {
assertEqualVMOutput(t, func(factory func(t *testing.T) VM) string {
vm := factory(t)
vm.ExtVar("one", "1")
vm.ExtVar("two", "2")
vm.ExtCode("bool", "true")
vm.TLAVar("oneArg", "1")
vm.TLAVar("twoArg", "2")
vm.TLACode("boolArg", "false")
out, err := vm.EvaluateAnonymousSnippet(
"test",
`function (oneArg, twoArg, boolArg) {
one: std.extVar("one"), two: std.extVar("two"), bool: std.extVar("bool"),
oneTLA: oneArg, twoTLA: twoArg, boolTLA: boolArg,
}`)
assert.NoError(t, err)
return out
})
})
})
t.Run("case=stack overflow pool", func(t *testing.T) {
snippet := "local f(x) = if x == 0 then [] else [f(x - 1), f(x - 1)]; f(100)"
vm := MakeSecureVM(
WithJsonnetBinary(testBinary),
WithProcessPool(procPool),
)
result, err := vm.EvaluateAnonymousSnippet("test", snippet)
ensureChildProcessStoppedEarly(t, err)
assert.Empty(t, result)
})
t.Run("case=stdout too lengthy pool", func(t *testing.T) {
// This script outputs more than the limit.
snippet := `{user_id: std.repeat("a", ` + strconv.FormatUint(jsonnetOutputLimit, 10) + `)}`
vm := MakeSecureVM(
WithProcessPool(procPool),
WithJsonnetBinary(testBinary),
)
_, err := vm.EvaluateAnonymousSnippet("test", snippet)
ensureChildProcessStoppedEarly(t, err)
})
t.Run("case=importbin", func(t *testing.T) {
// importbin does not exist in the current version, but is already merged on the main branch:
// https://github.com/google/go-jsonnet/commit/856bd58872418eee1cede0badea5b7b462c429eb
vm := MakeSecureVM()
result, err := vm.EvaluateAnonymousSnippet(
"test",
"local contents = importbin 'stub/import.jsonnet'; { contents: contents }")
require.Error(t, err, "%s", result)
})
}
func standardVM(t *testing.T) VM {
t.Helper()
return jsonnet.MakeVM()
}
func secureVM(t *testing.T) VM {
t.Helper()
return MakeSecureVM()
}
func poolVM(t *testing.T) VM {
t.Helper()
pool := NewProcessPool(10)
t.Cleanup(pool.Close)
return MakeSecureVM(
WithProcessPool(pool),
WithJsonnetBinary(JsonnetTestBinary(t)))
}
func assertEqualVMOutput(t *testing.T, run func(factory func(t *testing.T) VM) string) {
t.Helper()
expectedOut := run(standardVM)
secureOut := run(secureVM)
poolOut := run(poolVM)
assert.Equal(t, expectedOut, secureOut, "secure output incorrect")
assert.Equal(t, expectedOut, poolOut, "pool output incorrect")
}
func TestStressTestOnlyValid(t *testing.T) {
wg := errgroup.Group{}
testBinary := JsonnetTestBinary(t)
count := 100
procPool := NewProcessPool(runtime.GOMAXPROCS(0))
defer procPool.Close()
snippet := `{a:1}`
for range count {
wg.Go(func() error {
vm := MakeSecureVM(
WithProcessPool(procPool),
WithJsonnetBinary(testBinary),
)
out, err := vm.EvaluateAnonymousSnippet("test", snippet)
require.NoError(t, err)
require.NotEmpty(t, out)
return err
})
}
require.NoError(t, wg.Wait())
}
func TestStressTest(t *testing.T) {
wg := errgroup.Group{}
testBinary := JsonnetTestBinary(t)
count := 100
cases := []string{
`{a:1}`, // Correct.
`{a: std.repeat("a",1000000)}`, // Correct but output is too lengthy.
`{a:`, // Incorrect syntax (will print on stderr).
`{a:` + strings.Repeat("a", 1024*1024), // Big script which will be printed to stderr.
}
for i := range count {
wg.Go(func() error {
vm := MakeSecureVM(
WithProcessPool(procPool),
WithJsonnetBinary(testBinary),
)
snippet := cases[i%len(cases)]
// Due to the documented edge cases, we cannot really assert anything about
// the result and error in the presence of misbehaving scripts.
vm.EvaluateAnonymousSnippet("test", snippet)
return nil
})
}
require.NoError(t, wg.Wait())
}
func TestMain(m *testing.M) {
procPool = NewProcessPool(runtime.GOMAXPROCS(0))
defer procPool.Close()
m.Run()
}
var (
procPool Pool
snippet = "{a:std.extVar('a')}"
)
func BenchmarkIsolatedVM(b *testing.B) {
binary := JsonnetTestBinary(b)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
vm := MakeSecureVM(
WithJsonnetBinary(binary),
)
i := rand.Int()
vm.ExtCode("a", strconv.Itoa(i))
res, err := vm.EvaluateAnonymousSnippet("test", snippet)
require.NoError(b, err)
require.JSONEq(b, fmt.Sprintf(`{"a": %d}`, i), res)
}
})
}
func BenchmarkProcessPoolVM(b *testing.B) {
binary := JsonnetTestBinary(b)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
vm := MakeSecureVM(
WithJsonnetBinary(binary),
WithProcessPool(procPool),
)
i := rand.Int()
vm.ExtCode("a", strconv.Itoa(i))
res, err := vm.EvaluateAnonymousSnippet("test", snippet)
require.NoError(b, err)
require.JSONEq(b, fmt.Sprintf(`{"a": %d}`, i), res)
}
})
}
func BenchmarkRegularVM(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
vm := MakeSecureVM()
i := rand.Int()
vm.ExtCode("a", strconv.Itoa(i))
res, err := vm.EvaluateAnonymousSnippet("test", snippet)
require.NoError(b, err)
require.JSONEq(b, fmt.Sprintf(`{"a": %d}`, i), res)
}
})
}
func BenchmarkReusableProcessVM(b *testing.B) {
var (
binary = JsonnetTestBinary(b)
cmd = exec.Command(binary, "-0")
inputs = make(chan struct{})
stderr strings.Builder
eg errgroup.Group
count int32 = 0
)
stdin, err := cmd.StdinPipe()
require.NoError(b, err)
stdout, err := cmd.StdoutPipe()
require.NoError(b, err)
cmd.Stderr = &stderr
require.NoError(b, cmd.Start())
b.Cleanup(func() {
close(inputs)
assert.NoError(b, stdin.Close())
assert.NoError(b, eg.Wait())
assert.NoError(b, cmd.Wait())
assert.Empty(b, stderr.String())
})
eg.Go(func() error {
scanner := bufio.NewScanner(stdout)
scanner.Split(splitNull)
for scanner.Scan() {
c := atomic.AddInt32(&count, 1)
require.JSONEq(b, fmt.Sprintf(`{"a": %d}`, c), scanner.Text())
}
return scanner.Err()
})
eg.Go(func() error {
a := 1
for range inputs {
pp := processParameters{Snippet: snippet, ExtCodes: []kv{{"a", strconv.Itoa(a)}}}
a++
require.NoError(b, pp.EncodeTo(stdin))
_, err := stdin.Write([]byte{0})
require.NoError(b, err)
}
return nil
})
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
inputs <- struct{}{}
}
})
for atomic.LoadInt32(&count) != int32(b.N) {
time.Sleep(1 * time.Millisecond)
}
}

View File

@ -1,305 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jsonschemax
import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"testing"
"github.com/ory/x/snapshotx"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"github.com/ory/jsonschema/v3"
)
const recursiveSchema = `{
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "test.json",
"definitions": {
"foo": {
"type": "object",
"properties": {
"bars": {
"type": "string",
"format": "email",
"pattern": ".*"
},
"bar": {
"$ref": "#/definitions/bar"
}
},
"required":["bars"]
},
"bar": {
"type": "object",
"properties": {
"foos": {
"type": "string",
"minLength": 1,
"maxLength": 10
},
"foo": {
"$ref": "#/definitions/foo"
}
}
}
},
"type": "object",
"properties": {
"bar": {
"$ref": "#/definitions/bar"
}
}
}`
func readFile(t *testing.T, path string) string {
schema, err := os.ReadFile(path)
require.NoError(t, err)
return string(schema)
}
const fooExtensionName = "fooExtension"
type (
extensionConfig struct {
NotAJSONSchemaKey string `json:"not-a-json-schema-key"`
}
)
func fooExtensionCompile(_ jsonschema.CompilerContext, m map[string]interface{}) (interface{}, error) {
if raw, ok := m[fooExtensionName]; ok {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(raw); err != nil {
return nil, errors.WithStack(err)
}
var e extensionConfig
if err := json.NewDecoder(&b).Decode(&e); err != nil {
return nil, errors.WithStack(err)
}
return &e, nil
}
return nil, nil
}
func fooExtensionValidate(_ jsonschema.ValidationContext, _, _ interface{}) error {
return nil
}
func (ec *extensionConfig) EnhancePath(p Path) map[string]interface{} {
if ec.NotAJSONSchemaKey != "" {
fmt.Printf("enhancing path: %s with custom property %s\n", p.Name, ec.NotAJSONSchemaKey)
return map[string]interface{}{
ec.NotAJSONSchemaKey: p.Name,
}
}
return nil
}
func TestListPathsWithRecursion(t *testing.T) {
for k, tc := range []struct {
recursion uint8
expected interface{}
}{
{
recursion: 5,
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
c := jsonschema.NewCompiler()
require.NoError(t, c.AddResource("test.json", bytes.NewBufferString(recursiveSchema)))
actual, err := ListPathsWithRecursion(context.Background(), "test.json", c, tc.recursion)
require.NoError(t, err)
snapshotx.SnapshotT(t, actual)
})
}
}
func TestListPaths(t *testing.T) {
for k, tc := range []struct {
schema string
expectErr bool
extension *jsonschema.Extension
}{
{
schema: readFile(t, "./stub/.oathkeeper.schema.json"),
},
{
schema: readFile(t, "./stub/nested-simple-array.schema.json"),
},
{
schema: readFile(t, "./stub/config.schema.json"),
},
{
schema: readFile(t, "./stub/nested-array.schema.json"),
},
{
// this should fail because of recursion
schema: recursiveSchema,
expectErr: true,
},
{
schema: `{
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "test.json",
"oneOf": [
{
"type": "object",
"properties": {
"list": {
"type": "array",
"items": {
"type": "string"
}
},
"foo": {
"default": false,
"type": "boolean"
},
"bar": {
"type": "boolean",
"default": "asdf",
"readOnly": true
}
}
},
{
"type": "object",
"properties": {
"foo": {
"type": "boolean"
}
}
}
]
}`,
},
{
schema: `{
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "test.json",
"type": "object",
"required": ["foo"],
"properties": {
"foo": {
"type": "boolean"
},
"bar": {
"type": "string",
"fooExtension": {
"not-a-json-schema-key": "foobar"
}
}
}
}`,
extension: &jsonschema.Extension{
Meta: nil,
Compile: fooExtensionCompile,
Validate: fooExtensionValidate,
},
},
{
schema: `{
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "test.json",
"type": "object",
"definitions": {
"foo": {
"type": "string"
}
},
"properties": {
"bar": {
"type": "array",
"items": {
"$ref": "#/definitions/foo"
}
}
}
}`,
},
{
schema: `{
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "test.json",
"type": "object",
"definitions": {
"foo": {
"type": "string"
},
"bar": {
"type": "array",
"items": {
"$ref": "#/definitions/foo"
},
"required": ["foo"]
}
},
"properties": {
"baz": {
"type": "array",
"items": {
"$ref": "#/definitions/bar"
}
}
}
}`,
},
{
schema: `{
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "test.json",
"type": "object",
"definitions": {
"foo": {
"type": "string"
},
"bar": {
"type": "object",
"properties": {
"foo": {
"$ref": "#/definitions/foo"
}
},
"required": ["foo"]
}
},
"properties": {
"baz": {
"type": "array",
"items": {
"$ref": "#/definitions/bar"
}
}
}
}`,
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
c := jsonschema.NewCompiler()
if tc.extension != nil {
c.Extensions[fooExtensionName] = *tc.extension
}
require.NoError(t, c.AddResource("test.json", bytes.NewBufferString(tc.schema)))
actual, err := ListPathsWithArraysIncluded(context.Background(), "test.json", c)
if tc.expectErr {
require.Error(t, err, "%+v", actual)
return
}
require.NoError(t, err)
snapshotx.SnapshotT(t, actual)
})
}
}

View File

@ -1,31 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jsonschemax
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestJSONPointerToDotNotation(t *testing.T) {
for k, tc := range [][]string{
{"#/foo/bar/baz", "foo.bar.baz"},
{"#/baz", "baz"},
{"#/properties/ory.sh~1kratos/type", "properties.ory\\.sh/kratos.type"},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
path, err := JSONPointerToDotNotation(tc[0])
require.NoError(t, err)
require.Equal(t, tc[1], path)
})
}
_, err := JSONPointerToDotNotation("http://foo/#/bar")
require.Error(t, err, "should fail because remote pointers are not supported")
_, err = JSONPointerToDotNotation("http://foo/#/bar%zz")
require.Error(t, err, "should fail because %3b is not a valid escaped path.")
}

View File

@ -1,125 +0,0 @@
// Copyright © 2025 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jsonx_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/ory/x/jsonx"
)
func TestJSONShape(t *testing.T) {
for _, tc := range []struct {
name string
in string
expected string
}{{
name: "user patch",
in: `{
"schemas" : [ "urn:ietf:params:scim:schemas:core:2.0:User" ],
"id" : "d4b4f9db-2361-4845-a4cd-51e12527b92e",
"externalId" : "00uo3xq5f75s2KCOE5d7",
"userName" : "henning.perl@ory.sh",
"name" : {
"familyName" : "Perl",
"givenName" : "Henning"
},
"displayName" : "Henning Perl",
"locale" : "en-US",
"active" : true,
"emails" : [ {
"value" : "henning.perl@ory.sh",
"primary" : true,
"type" : "work"
} ],
"groups" : [ {
"value" : "21c5f2f9-8fb0-45b3-9bb6-61ecd1090549",
"display" : "Developers",
"type" : "direct"
}, {
"value" : "a37d499d-739c-4e08-8273-c124f85172fe",
"display" : "SCIM pros",
"type" : "direct"
} ],
"meta" : {
"resourceType" : "User",
"created" : "2025-04-25T07:53:43Z",
"lastModified" : "2025-04-25T08:31:23Z"
},
"roles" : [ "foo", "bar" ]
}`,
expected: `{
"active": "boolean",
"displayName": "string",
"emails": [
{
"primary": "boolean",
"type": "string",
"value": "string"
}
],
"externalId": "string",
"groups": [
{
"display": "string",
"type": "string",
"value": "string"
},
{
"display": "string",
"type": "string",
"value": "string"
}
],
"id": "d4b4f9db-2361-4845-a4cd-51e12527b92e",
"locale": "string",
"meta": {
"created": "string",
"lastModified": "string",
"resourceType": "string"
},
"name": {
"familyName": "string",
"givenName": "string"
},
"roles": [
"string",
"string"
],
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User"
],
"userName": "string"
}`,
}, {
name: "invalid JSON",
in: `{`,
expected: `{"error": "invalid JSON", "message": "unexpected end of JSON input"}`,
}, {
name: "different types",
in: `{
"float": 0.42,
"int": 42,
"string": "foo",
"bool": true,
"null": null,
"array": [1, "2", 0]
}`,
expected: `{
"float": "number",
"int": "number",
"string": "string",
"bool": "boolean",
"null": "null",
"array": ["number", "string", "number"]
}`,
}} {
t.Run(tc.name, func(t *testing.T) {
actual := string(jsonx.Anonymize([]byte(tc.in), "id", "schemas"))
assert.JSONEq(t, tc.expected, actual, actual)
})
}
}

View File

@ -1,63 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jsonx
import (
"io/fs"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/x/snapshotx"
)
func TestEmbedSources(t *testing.T) {
t.Run("fixtures", func(t *testing.T) {
require.NoError(t, filepath.Walk("fixture/embed", func(p string, i fs.FileInfo, err error) error {
if err != nil {
return err
}
if i.IsDir() {
return nil
}
t.Run("fixture="+i.Name(), func(t *testing.T) {
t.Parallel()
input, err := os.ReadFile(p)
require.NoError(t, err)
actual, err := EmbedSources(input, WithIgnoreKeys(
"ignore_this_key",
))
require.NoError(t, err)
snapshotx.SnapshotT(t, actual)
})
return nil
}))
})
t.Run("only embeds base64", func(t *testing.T) {
actual, err := EmbedSources([]byte(`{"key":"https://foobar.com", "bar":"base64://YXNkZg=="}`), WithOnlySchemes(
"base64",
))
require.NoError(t, err)
snapshotx.SnapshotT(t, actual)
})
t.Run("fails on invalid source", func(t *testing.T) {
expected := []byte(`{"foo":"base64://invalid}`)
actual, err := EmbedSources(expected)
require.NoError(t, err)
assert.Equal(t, string(expected), string(actual))
})
}

View File

@ -1,42 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jsonx
import (
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFlatten(t *testing.T) {
f, err := os.ReadFile("./stub/random.json")
require.NoError(t, err)
for k, tc := range []struct {
raw []byte
expected map[string]interface{}
}{
{
raw: f,
expected: map[string]interface{}{"fall": "to", "floating.0": -1.273085434e+09, "floating.1": 9.53442581e+08, "floating.2.gray.buy": true, "floating.2.gray.hold.0.0": 1.81518765e+08, "floating.2.gray.hold.0.1.0.flies": -1.571371799e+09, "floating.2.gray.hold.0.1.0.leather": "across", "floating.2.gray.hold.0.1.0.over": 5.12666854e+08, "floating.2.gray.hold.0.1.0.shaking": true, "floating.2.gray.hold.0.1.0.steam.ago": true, "floating.2.gray.hold.0.1.0.steam.appropriate": 1.249911539e+09, "floating.2.gray.hold.0.1.0.steam.box": false, "floating.2.gray.hold.0.1.0.steam.cry": 1.463961818e+09, "floating.2.gray.hold.0.1.0.steam.entirely": -8.51427469e+08, "floating.2.gray.hold.0.1.0.steam.through": 6.95239749e+08, "floating.2.gray.hold.0.1.0.thank": true, "floating.2.gray.hold.0.1.1": "hit", "floating.2.gray.hold.0.1.2": -6.481787444899056e+08, "floating.2.gray.hold.0.1.3": 1.225027271e+09, "floating.2.gray.hold.0.1.4": -1.481507228e+09, "floating.2.gray.hold.0.1.5": true, "floating.2.gray.hold.0.2": -2.114582277e+09, "floating.2.gray.hold.0.3": 1.3900602049360588e+09, "floating.2.gray.hold.0.4": 1.6156026309049141e+09, "floating.2.gray.hold.0.5": "darkness", "floating.2.gray.hold.1": 6.3427197713988304e+07, "floating.2.gray.hold.2": -5.80344963961421e+08, "floating.2.gray.hold.3": "stems", "floating.2.gray.hold.4": 1.016960217612642e+09, "floating.2.gray.hold.5": 1.240918909e+09, "floating.2.gray.parent": "pull", "floating.2.gray.shore": -7.38396277e+08, "floating.2.gray.usually": 1.050049449e+09, "floating.2.gray.wonder": false, "floating.2.joy": "difference", "floating.2.little": "cloud", "floating.2.probably": -4.13625494e+08, "floating.2.ready": "silent", "floating.2.worker": "situation", "floating.3": "grade", "floating.4": false, "floating.5": "thou", "product": "whale", "shop": 1.294397217e+09, "spend": "greatest", "wagon": -1.722583702e+09},
},
{raw: []byte(`{"foo":"bar"}`), expected: map[string]interface{}{"foo": "bar"}},
{raw: []byte(`{"foo":["bar",{"foo":"bar"}]}`), expected: map[string]interface{}{"foo.0": "bar", "foo.1.foo": "bar"}},
{raw: []byte(`{"foo":"bar","baz":{"bar":"foo"}}`), expected: map[string]interface{}{"foo": "bar", "baz.bar": "foo"}},
{
raw: []byte(`{"foo":"bar","baz":{"bar":"foo"},"bar":["foo","bar","baz"]}`),
expected: map[string]interface{}{"bar.0": "foo", "bar.1": "bar", "bar.2": "baz", "baz.bar": "foo", "foo": "bar"},
},
{raw: []byte(`[]`), expected: nil},
{raw: []byte(`null`), expected: nil},
{raw: []byte(`"bar"`), expected: nil},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
assert.EqualValues(t, tc.expected, Flatten(tc.raw))
})
}
}

View File

@ -1,141 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jsonx
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestGetJSONKeys(t *testing.T) {
type A struct {
B string
}
for _, tc := range []struct {
name string
input interface{}
expected []string
}{
{
name: "simple struct",
input: struct {
A, B string
}{},
expected: []string{"A", "B"},
},
{
name: "struct with json tags",
input: struct {
A string `json:"a"`
B string `json:"b"`
}{},
expected: []string{"a", "b"},
},
{
name: "struct with unexported field",
input: struct {
A, b string
C string `json:"c"`
}{},
expected: []string{"A", "c"},
},
{
name: "struct with omitempty",
input: struct {
A string `json:"a"`
B string `json:"b,omitempty"`
}{
B: "we have to set this to a non-empty value because gjson keys collection will not work otherwise",
},
expected: []string{"a", "b"},
},
{
name: "pointer to struct",
input: &struct {
A string
}{},
expected: []string{"A"},
},
{
name: "embedded struct",
input: struct {
A
}{},
expected: []string{"B"},
},
{
name: "nested structs",
input: struct {
A struct {
B string `json:"b"`
} `json:"a"`
}{},
expected: []string{"a.b"},
},
} {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, AllValidJSONKeys(tc.input))
// collect keys with gjson, which only works reliably for non-omitempty fields
var collectKeys func(gjson.Result) []string
collectKeys = func(res gjson.Result) []string {
var keys []string
res.ForEach(func(key, value gjson.Result) bool {
if value.IsObject() {
childKeys := collectKeys(value)
for _, k := range childKeys {
keys = append(keys, key.String()+"."+k)
}
} else {
keys = append(keys, key.String())
}
return true
})
return keys
}
assert.ElementsMatch(t, tc.expected, collectKeys(gjson.Parse(TestMarshalJSONString(t, tc.input))))
})
}
}
func TestResultGetValidKey(t *testing.T) {
t.Run("case=fails on invalid key", func(t *testing.T) {
r := ParseEnsureKeys(struct{ A string }{}, []byte("{}"))
assert.Panics(t, func() {
r.GetRequireValidKey(&panicFail{}, "b")
})
})
t.Run("case=does not fail on valid key", func(t *testing.T) {
r := ParseEnsureKeys(struct{ A string }{}, []byte(`{"A":"a"}`))
var v string
require.NotPanics(t, func() {
v = r.GetRequireValidKey(&panicFail{}, "A").Str
})
assert.Equal(t, "a", v)
})
t.Run("case=nested key", func(t *testing.T) {
r := ParseEnsureKeys(struct{ A struct{ B string } }{}, []byte(`{"A":{"B":"b"}}`))
var v string
require.NotPanics(t, func() {
v = r.GetRequireValidKey(&panicFail{}, "A.B").Str
})
assert.Equal(t, "b", v)
})
}
var _ require.TestingT = (*panicFail)(nil)
type panicFail struct{}
func (*panicFail) Errorf(string, ...interface{}) {}
func (*panicFail) FailNow() {
panic("failing")
}

View File

@ -1,183 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jsonx
import (
"testing"
"github.com/mohae/deepcopy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type TestType struct {
Field1 string
Field2 []string
Field3 struct {
Field1 bool
Field2 []int
}
FieldNull *struct {
Field1 any
}
OmitEmptyField string `json:"OmitEmptyField,omitempty"`
}
func TestApplyJSONPatch(t *testing.T) {
object := TestType{
Field1: "foo",
Field2: []string{
"foo",
"bar",
"baz",
"kaz",
},
Field3: struct {
Field1 bool
Field2 []int
}{
Field1: true,
Field2: []int{
1,
2,
3,
},
},
}
t.Run("case=empty patch", func(t *testing.T) {
rawPatch := []byte(`[]`)
obj := deepcopy.Copy(object).(TestType)
require.NoError(t, ApplyJSONPatch(rawPatch, &obj))
require.Equal(t, object, obj)
})
t.Run("case=field replace", func(t *testing.T) {
rawPatch := []byte(`[{"op": "replace", "path": "/Field1", "value": "boo"}]`)
expected := deepcopy.Copy(object).(TestType)
expected.Field1 = "boo"
obj := deepcopy.Copy(object).(TestType)
require.NoError(t, ApplyJSONPatch(rawPatch, &obj))
require.Equal(t, expected, obj)
})
t.Run("case=array replace", func(t *testing.T) {
rawPatch := []byte(`[{"op": "replace", "path": "/Field2/0", "value": "boo"}]`)
expected := deepcopy.Copy(object).(TestType)
expected.Field2[0] = "boo"
obj := deepcopy.Copy(object).(TestType)
require.NoError(t, ApplyJSONPatch(rawPatch, &obj))
require.Equal(t, expected, obj)
})
t.Run("case=array append", func(t *testing.T) {
rawPatch := []byte(`[{"op": "add", "path": "/Field2/-", "value": "boo"}]`)
expected := deepcopy.Copy(object).(TestType)
expected.Field2 = append(expected.Field2, "boo")
obj := deepcopy.Copy(object).(TestType)
require.NoError(t, ApplyJSONPatch(rawPatch, &obj))
require.Equal(t, expected, obj)
})
t.Run("case=array remove", func(t *testing.T) {
rawPatch := []byte(`[{"op": "remove", "path": "/Field2/0"}]`)
expected := deepcopy.Copy(object).(TestType)
expected.Field2 = expected.Field2[1:]
obj := deepcopy.Copy(object).(TestType)
require.NoError(t, ApplyJSONPatch(rawPatch, &obj))
require.Equal(t, expected, obj)
})
t.Run("case=nested field replace", func(t *testing.T) {
rawPatch := []byte(`[{"op": "replace", "path": "/Field3/Field1", "value": false}]`)
expected := deepcopy.Copy(object).(TestType)
expected.Field3.Field1 = false
obj := deepcopy.Copy(object).(TestType)
require.NoError(t, ApplyJSONPatch(rawPatch, &obj))
require.Equal(t, expected, obj)
})
t.Run("case=nested array append", func(t *testing.T) {
rawPatch := []byte(`[{"op": "add", "path": "/Field3/Field2/-", "value": 4}]`)
expected := deepcopy.Copy(object).(TestType)
expected.Field3.Field2 = append(expected.Field3.Field2, 4)
obj := deepcopy.Copy(object).(TestType)
require.NoError(t, ApplyJSONPatch(rawPatch, &obj))
require.Equal(t, expected, obj)
})
t.Run("case=nested array remove", func(t *testing.T) {
rawPatch := []byte(`[{"op": "remove", "path": "/Field3/Field2/2"}]`)
expected := deepcopy.Copy(object).(TestType)
expected.Field3.Field2 = expected.Field3.Field2[:2]
obj := deepcopy.Copy(object).(TestType)
require.NoError(t, ApplyJSONPatch(rawPatch, &obj))
require.Equal(t, expected, obj)
})
t.Run("case=patch denied path", func(t *testing.T) {
for _, path := range []string{
"/Field1",
"/field1",
"/fIeld1",
"/FIELD1",
} {
t.Run("path="+path, func(t *testing.T) {
rawPatch := []byte(`[{"op": "replace", "path": "/Field1", "value": "bar"}]`)
obj := deepcopy.Copy(object).(TestType)
assert.Error(t, ApplyJSONPatch(rawPatch, &obj, path))
require.Equal(t, object, obj)
})
}
})
t.Run("case=patch denied sub-path", func(t *testing.T) {
rawPatch := []byte(`[{"op": "replace", "path": "/Field3/Field1", "value": true}]`)
obj := deepcopy.Copy(object).(TestType)
err := ApplyJSONPatch(rawPatch, &obj, "/Field3/**", "/Field1/*/Unknown")
require.Error(t, err)
require.Equal(t, object, obj)
})
t.Run("case=patch allowed path", func(t *testing.T) {
rawPatch := []byte(`[{"op": "add", "path": "/Field2/-", "value": "bar"}]`)
expected := deepcopy.Copy(object).(TestType)
expected.Field2 = append(expected.Field2, "bar")
obj := deepcopy.Copy(object).(TestType)
require.NoError(t, ApplyJSONPatch(rawPatch, &obj, "/Field1"))
require.Equal(t, expected, obj)
})
t.Run("case=patch object field when object null", func(t *testing.T) {
rawPatch := []byte(`[{"op": "add", "path": "/FieldNull/Field1", "value": "bar"}]`)
expected := deepcopy.Copy(object).(TestType)
expected.FieldNull = &struct{ Field1 any }{Field1: "bar"}
obj := deepcopy.Copy(object).(TestType)
require.NoError(t, ApplyJSONPatch(rawPatch, &obj, "/Field1"))
require.Equal(t, expected, obj)
})
t.Run("case=replace non-existing path adds value", func(t *testing.T) {
rawPatch := []byte(`[{"op": "replace", "path": "/OmitEmptyField", "value": "boo"}]`)
expected := deepcopy.Copy(object).(TestType)
expected.OmitEmptyField = "boo"
obj := deepcopy.Copy(object).(TestType)
require.NoError(t, ApplyJSONPatch(rawPatch, &obj))
require.Equal(t, expected, obj)
})
t.Run("suite=invalid patches", func(t *testing.T) {
cases := []struct {
name string
patch []byte
}{{
name: "test",
patch: []byte(`[{"op": "test", "path": "/"}]`),
}, {
name: "add",
patch: []byte(`[{"op": "add", "path": "/"}]`),
}, {
name: "remove",
patch: []byte(`[{"op": "remove"}]`),
}, {
name: "replace",
patch: []byte(`[{"op": "replace", "path": "/"}]`),
}}
for _, tc := range cases {
t.Run("case="+tc.name, func(t *testing.T) {
obj := &TestType{}
assert.Error(t, ApplyJSONPatch(tc.patch, &obj))
})
}
})
}

View File

@ -1,55 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jwksx
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
keys = `{
"keys": [
{
"use": "sig",
"kty": "oct",
"kid": "7d5f5ad0674ec2f2960b1a34f33370a0f71471fa0e3ef0c0a692977d276dafe8",
"alg": "HS256",
"k": "Y2hhbmdlbWVjaGFuZ2VtZWNoYW5nZW1lY2hhbmdlbWU"
}
]
}`
secret = "changemechangemechangemechangeme"
)
func TestFetcher(t *testing.T) {
var called int
var h http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
called++
w.Write([]byte(keys))
}
ts := httptest.NewServer(h)
defer ts.Close()
f := NewFetcher(ts.URL)
k, err := f.GetKey("7d5f5ad0674ec2f2960b1a34f33370a0f71471fa0e3ef0c0a692977d276dafe8")
require.NoError(t, err)
assert.EqualValues(t, secret, fmt.Sprintf("%s", k.Key))
assert.Equal(t, 1, called)
k, err = f.GetKey("7d5f5ad0674ec2f2960b1a34f33370a0f71471fa0e3ef0c0a692977d276dafe8")
require.NoError(t, err)
assert.EqualValues(t, secret, fmt.Sprintf("%s", k.Key))
assert.Equal(t, 1, called)
_, err = f.GetKey("does-not-exist")
require.Error(t, err)
assert.Equal(t, 2, called)
}

View File

@ -1,212 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jwksx
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/lestrrat-go/jwx/jwk"
"github.com/hashicorp/go-retryablehttp"
"github.com/pkg/errors"
"github.com/dgraph-io/ristretto/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/x/snapshotx"
)
const (
multiKeys = `{
"keys": [
{
"use": "sig",
"kty": "oct",
"kid": "7d5f5ad0674ec2f2960b1a34f33370a0f71471fa0e3ef0c0a692977d276dafe8",
"alg": "HS256",
"k": "Y2hhbmdlbWVjaGFuZ2VtZWNoYW5nZW1lY2hhbmdlbWU"
},
{
"use": "sig",
"kty": "oct",
"kid": "8d5f5ad0674ec2f2960b1a34f33370a0f71471fa0e3ef0c0a692977d276dafe8",
"alg": "HS256",
"k": "Y2hhbmdlbWVjaGFuZ2VtZWNoYW5nZW1lY2hhbmdlbWU"
},
{
"use": "sig",
"kty": "oct",
"kid": "9d5f5ad0674ec2f2960b1a34f33370a0f71471fa0e3ef0c0a692977d276dafe8",
"alg": "HS256",
"k": "Y2hhbmdlbWVjaGFuZ2VtZWNoYW5nZW1lY2hhbmdlbWU"
}
]
}`
)
type brokenTransport struct{}
var _ http.RoundTripper = new(brokenTransport)
var errBroken = errors.New("broken")
func (b brokenTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
return nil, errBroken
}
func TestFetcherNext(t *testing.T) {
ctx := context.Background()
cache, err := ristretto.NewCache[[]byte, jwk.Set](&ristretto.Config[[]byte, jwk.Set]{
NumCounters: 100 * 10,
MaxCost: 100,
BufferItems: 64,
Metrics: true,
IgnoreInternalCost: true,
Cost: func(jwk.Set) int64 {
return 1
},
})
require.NoError(t, err)
f := NewFetcherNext(cache)
createRemoteProvider := func(called *int, payload string) *httptest.Server {
cache.Clear()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
*called++
_, _ = w.Write([]byte(payload))
}))
t.Cleanup(ts.Close)
return ts
}
t.Run("case=resolve multiple source urls", func(t *testing.T) {
t.Run("case=fails without forced kid", func(t *testing.T) {
var called int
ts1 := createRemoteProvider(&called, keys)
ts2 := createRemoteProvider(&called, multiKeys)
_, err := f.ResolveKeyFromLocations(ctx, []string{ts1.URL, ts2.URL})
require.Error(t, err)
})
t.Run("case=succeeds with forced kid", func(t *testing.T) {
var called int
ts1 := createRemoteProvider(&called, keys)
ts2 := createRemoteProvider(&called, multiKeys)
k, err := f.ResolveKeyFromLocations(ctx, []string{ts1.URL, ts2.URL}, WithForceKID("8d5f5ad0674ec2f2960b1a34f33370a0f71471fa0e3ef0c0a692977d276dafe8"))
require.NoError(t, err)
snapshotx.SnapshotT(t, k)
})
})
t.Run("case=resolve single source url", func(t *testing.T) {
t.Run("case=with forced key", func(t *testing.T) {
var called int
ts := createRemoteProvider(&called, keys)
k, err := f.ResolveKey(ctx, ts.URL, WithForceKID("7d5f5ad0674ec2f2960b1a34f33370a0f71471fa0e3ef0c0a692977d276dafe8"))
require.NoError(t, err)
snapshotx.SnapshotT(t, k)
})
t.Run("case=forced key is not found", func(t *testing.T) {
var called int
ts := createRemoteProvider(&called, keys)
_, err := f.ResolveKey(ctx, ts.URL, WithForceKID("not-found"))
require.Error(t, err)
})
t.Run("case=no key in remote", func(t *testing.T) {
var called int
ts := createRemoteProvider(&called, "{}")
_, err := f.ResolveKey(ctx, ts.URL)
require.Error(t, err)
})
t.Run("case=remote not returning JSON", func(t *testing.T) {
var called int
ts := createRemoteProvider(&called, "lol")
_, err := f.ResolveKey(ctx, ts.URL)
require.Error(t, err)
})
t.Run("case=without cache", func(t *testing.T) {
var called int
ts := createRemoteProvider(&called, keys)
k, err := f.ResolveKey(ctx, ts.URL)
require.NoError(t, err)
snapshotx.SnapshotT(t, k)
assert.Equal(t, called, 1)
cache.Wait()
_, err = f.ResolveKey(ctx, ts.URL)
require.NoError(t, err)
assert.Equal(t, called, 2)
})
t.Run("case=with cache", func(t *testing.T) {
var called int
ts := createRemoteProvider(&called, keys)
k, err := f.ResolveKey(ctx, ts.URL, WithCacheEnabled())
require.NoError(t, err)
assert.Equal(t, called, 1)
cache.Wait()
k, err = f.ResolveKey(ctx, ts.URL, WithCacheEnabled())
require.NoError(t, err)
assert.Equal(t, called, 1)
snapshotx.SnapshotT(t, k)
})
t.Run("case=with cache and TTL", func(t *testing.T) {
var called int
ts := createRemoteProvider(&called, keys)
waitTime := time.Millisecond * 100
k, err := f.ResolveKey(ctx, ts.URL, WithCacheEnabled(), WithCacheTTL(waitTime))
require.NoError(t, err)
assert.Equal(t, called, 1)
cache.Wait()
k, err = f.ResolveKey(ctx, ts.URL, WithCacheEnabled())
require.NoError(t, err)
assert.Equal(t, called, 1)
time.Sleep(waitTime)
cache.Wait()
k, err = f.ResolveKey(ctx, ts.URL, WithCacheEnabled())
require.NoError(t, err)
assert.Equal(t, called, 2)
snapshotx.SnapshotT(t, k)
})
t.Run("case=with broken HTTP client", func(t *testing.T) {
var called int
ts := createRemoteProvider(&called, keys)
broken := retryablehttp.NewClient()
broken.RetryMax = 0
broken.HTTPClient.Transport = new(brokenTransport)
_, err := f.ResolveKey(ctx, ts.URL, WithHTTPClient(broken))
require.ErrorIs(t, err, errBroken)
})
})
}

View File

@ -1,37 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jwksx
import (
"fmt"
"testing"
"github.com/go-jose/go-jose/v3"
"github.com/stretchr/testify/require"
)
func TestGenerateSigningKeys(t *testing.T) {
for _, alg := range GenerateSigningKeysAvailableAlgorithms() {
t.Run(fmt.Sprintf("alg=%s", alg), func(t *testing.T) {
key, err := GenerateSigningKeys("", alg, 0)
require.NoError(t, err)
t.Logf("%+v", key)
})
}
for _, tc := range []struct {
alg jose.SignatureAlgorithm
bits int
}{
{alg: jose.HS256, bits: 128}, // should fail because minimum 256 bit
{alg: jose.HS384, bits: 256}, // should fail because minimum 384 bit
{alg: jose.HS512, bits: 384}, // should fail because minimum 512 bit
{alg: jose.HS512, bits: 555}, // should fail because not modulo 8
} {
t.Run(fmt.Sprintf("alg=%s/bit=%d", tc.alg, tc.bits), func(t *testing.T) {
_, err := GenerateSigningKeys("", string(tc.alg), tc.bits)
require.Error(t, err)
})
}
}

View File

@ -1,176 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jwtmiddleware_test
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/tidwall/gjson"
"github.com/golang-jwt/jwt/v5"
"github.com/rakutentech/jwk-go/jwk"
"github.com/stretchr/testify/assert"
"github.com/ory/x/jwtmiddleware"
_ "embed"
"github.com/tidwall/sjson"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
)
func mustString(s string, err error) string {
if err != nil {
panic(err)
}
return s
}
var key *jwk.KeySpec
//go:embed stub/jwks.json
var rawKey []byte
func init() {
key = &jwk.KeySpec{}
if err := json.Unmarshal(rawKey, key); err != nil {
panic(err)
}
}
func newKeyServer(t *testing.T) string {
public, err := key.PublicOnly()
require.NoError(t, err)
keys, err := json.Marshal(map[string]interface{}{
"keys": []interface{}{
public,
},
})
require.NoError(t, err)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write(keys)
}))
t.Cleanup(ts.Close)
return ts.URL
}
func TestSessionFromRequest(t *testing.T) {
ks := newKeyServer(t)
router := httprouter.New()
router.GET("/anonymous", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
w.Write([]byte("ok"))
})
router.GET("/me", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
s, err := jwtmiddleware.SessionFromContext(r.Context())
require.NoError(t, err)
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(s))
})
n := negroni.New()
n.Use(jwtmiddleware.NewMiddleware(ks, jwtmiddleware.MiddlewareExcludePaths("/anonymous")))
n.UseHandler(router)
ts := httptest.NewServer(n)
defer ts.Close()
for k, tc := range []struct {
token string
expectedStatusCode int
expectedErrorReason string
expectedResponse string
}{
// token without token
{
token: "",
expectedStatusCode: 401,
expectedErrorReason: "Authorization header format must be Bearer {token}",
},
// token without kid
{
token: func() string {
c := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{})
delete(c.Header, "kid")
s, err := c.SignedString(key.Key)
require.NoError(t, err)
return s
}(),
expectedStatusCode: 401,
expectedErrorReason: "token is unverifiable: error while executing keyfunc: jwt from authorization HTTP header is missing value for \"kid\" in token header",
},
// token with int kid
{
token: func() string {
c := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{})
c.Header["kid"] = 42
s, err := c.SignedString(key.Key)
require.NoError(t, err)
return s
}(),
expectedStatusCode: 401,
expectedErrorReason: "token is unverifiable: error while executing keyfunc: jwt from authorization HTTP header is expecting string value for \"kid\" in tokenWithoutKid header but got: float64",
},
// token with unknown kid
{
token: func() string {
c := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{})
c.Header["kid"] = "not " + key.KeyID
s, err := c.SignedString(key.Key)
require.NoError(t, err)
return s
}(),
expectedStatusCode: 401,
expectedErrorReason: "token is unverifiable: error while executing keyfunc: unable to find JSON Web Key with ID: not b71ff5bd-a016-4ac0-9f3f-a172552578ea",
},
// token with valid kid
{
token: func() string {
c := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
"identity": map[string]interface{}{"email": "foo@bar.com"},
})
c.Header["kid"] = key.KeyID
s, err := c.SignedString(key.Key)
require.NoError(t, err)
return s
}(),
expectedStatusCode: 200,
expectedResponse: mustString(sjson.SetRaw("{}", "identity.email", `"foo@bar.com"`)),
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
req, err := http.NewRequest("GET", ts.URL+"/me", nil)
require.NoError(t, err)
req.Header.Set("Authorization", "bearer "+tc.token)
require.NoError(t, err)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, tc.expectedStatusCode, res.StatusCode, string(body))
assert.Equal(t, tc.expectedErrorReason, gjson.GetBytes(body, "error.reason").String())
if tc.expectedResponse != "" {
assert.JSONEq(t, tc.expectedResponse, string(body))
}
})
}
res, err := http.Get(ts.URL + "/anonymous")
require.NoError(t, err)
assert.Equal(t, 200, res.StatusCode)
}

View File

@ -1,64 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jwtx
import (
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseMapStringInterfaceClaims(t *testing.T) {
assert.EqualValues(t, &Claims{
JTI: "jti",
Subject: "sub",
Issuer: "iss",
Audience: []string{"aud"},
ExpiresAt: time.Unix(1234, 0),
IssuedAt: time.Unix(1234, 0),
NotBefore: time.Unix(1234, 0),
}, ParseMapStringInterfaceClaims(map[string]interface{}{
"jti": "jti",
"aud": "aud",
"iss": "iss",
"sub": "sub",
"exp": 1234,
"iat": 1234,
"nbf": 1234,
}))
assert.EqualValues(t, &Claims{
Audience: []string{"aud", "dua"},
ExpiresAt: time.Unix(1234, 0),
IssuedAt: time.Unix(1234, 0),
NotBefore: time.Unix(1234, 0),
}, ParseMapStringInterfaceClaims(map[string]interface{}{
"aud": []string{"aud", "dua"},
"exp": 1234,
"iat": 1234,
"nbf": 1234,
}))
out, err := json.Marshal(map[string]interface{}{
"aud": []string{"aud", "dua"},
"exp": 1234,
"iat": 1234,
"nbf": 1234,
})
require.NoError(t, err)
var in map[string]interface{}
require.NoError(t, json.Unmarshal(out, &in))
assert.EqualValues(t, &Claims{
Audience: []string{"aud", "dua"},
ExpiresAt: time.Unix(1234, 0),
IssuedAt: time.Unix(1234, 0),
NotBefore: time.Unix(1234, 0),
}, ParseMapStringInterfaceClaims(in))
}

View File

@ -1,74 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package logrusx
import (
"context"
"testing"
"github.com/sirupsen/logrus/hooks/test"
"github.com/knadh/koanf/parsers/json"
"github.com/knadh/koanf/providers/rawbytes"
"github.com/knadh/koanf/v2"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/sjson"
"github.com/ory/jsonschema/v3"
)
func TestConfigSchema(t *testing.T) {
config := func(t *testing.T, vals map[string]interface{}) []byte {
rawConfig, err := sjson.Set("{}", "log", vals)
require.NoError(t, err)
return []byte(rawConfig)
}
t.Run("case=basic validation and retrieval", func(t *testing.T) {
c := jsonschema.NewCompiler()
require.NoError(t, AddConfigSchema(c))
schema, err := c.Compile(context.Background(), ConfigSchemaID)
require.NoError(t, err)
logConfig := map[string]interface{}{
"level": "trace",
"format": "json_pretty",
"leak_sensitive_values": true,
"additional_redacted_headers": []interface{}{
"custom_header_1",
"custom_header_2",
},
}
assert.NoError(t, schema.ValidateInterface(logConfig))
k := koanf.New(".")
require.NoError(t, k.Load(rawbytes.Provider(config(t, logConfig)), json.Parser()))
l := New("foo", "bar", WithConfigurator(k))
assert.True(t, l.leakSensitive)
assert.Equal(t, logrus.TraceLevel, l.Logger.Level)
assert.Contains(t, l.additionalRedactedHeaders, "custom_header_1")
assert.Contains(t, l.additionalRedactedHeaders, "custom_header_2")
assert.IsType(t, &logrus.JSONFormatter{}, l.Logger.Formatter)
})
t.Run("case=warns on unknown format", func(t *testing.T) {
h := &test.Hook{}
New("foo", "bar", WithHook(h), ForceFormat("unknown"))
require.Len(t, h.Entries, 1)
assert.Contains(t, h.LastEntry().Message, "got unknown \"log.format\", falling back to \"text\"")
})
t.Run("case=does not warn on text format", func(t *testing.T) {
h := &test.Hook{}
New("foo", "bar", WithHook(h), ForceFormat("text"))
assert.Len(t, h.Entries, 0)
})
}

View File

@ -1,287 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package logrusx_test
import (
"bytes"
"net/http"
"net/url"
"strconv"
"strings"
"testing"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/ory/herodot"
. "github.com/ory/x/logrusx"
)
var fakeRequest = &http.Request{
Method: "GET",
URL: &url.URL{Path: "/foo/bar", RawQuery: "bar=foo"},
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{
"User-Agent": {"Go-http-client/1.1"},
"Accept-Encoding": {"gzip"},
"X-Request-Id": {"id1234"},
"Accept": {"application/json"},
"Set-Cookie": {"kratos_session=2198ef09ac09d09ff098dd123ab128353"},
"Cookie": {"kratos_cookie=2198ef09ac09d09ff098dd123ab128353"},
"X-Session-Token": {"2198ef09ac09d09ff098dd123ab128353"},
"X-Custom-Header": {"2198ef09ac09d09ff098dd123ab128353"},
"Authorization": {"Bearer 2198ef09ac09d09ff098dd123ab128353"},
},
Body: nil,
Host: "127.0.0.1:63232",
RemoteAddr: "127.0.0.1:63233",
RequestURI: "/foo/bar?bar=foo",
}
func TestOptions(t *testing.T) {
logger := New("", "", ForceLevel(logrus.DebugLevel))
assert.EqualValues(t, logrus.DebugLevel, logger.Logger.Level)
}
func TestJSONFormatter(t *testing.T) {
t.Run("pretty=true", func(t *testing.T) {
l := New("logrusx-audit", "v0.0.0", ForceFormat("json_pretty"), ForceLevel(logrus.DebugLevel))
var b bytes.Buffer
l.Logrus().Out = &b
l.Info("foo bar")
assert.True(t, strings.Count(b.String(), "\n") > 1)
assert.Contains(t, b.String(), " ")
})
t.Run("pretty=false", func(t *testing.T) {
l := New("logrusx-audit", "v0.0.0", ForceFormat("json"), ForceLevel(logrus.DebugLevel))
var b bytes.Buffer
l.Logrus().Out = &b
l.Info("foo bar")
assert.EqualValues(t, 1, strings.Count(b.String(), "\n"))
assert.NotContains(t, b.String(), " ")
})
}
func TestGelfFormatter(t *testing.T) {
t.Run("gelf formatter", func(t *testing.T) {
l := New("logrusx-audit", "v0.0.0", ForceFormat("gelf"), ForceLevel(logrus.DebugLevel))
var b bytes.Buffer
l.Logrus().Out = &b
l.Info("foo bar")
assert.Contains(t, b.String(), "_pid")
assert.Contains(t, b.String(), "level")
assert.Contains(t, b.String(), "short_message")
})
}
func TestTextLogger(t *testing.T) {
audit := NewAudit("logrusx-audit", "v0.0.0", ForceFormat("text"), ForceLevel(logrus.TraceLevel))
tracer := New("logrusx-app", "v0.0.0", ForceFormat("text"), ForceLevel(logrus.TraceLevel))
debugger := New("logrusx-server", "v0.0.1", ForceFormat("text"), ForceLevel(logrus.DebugLevel))
warner := New("logrusx-server", "v0.0.1", ForceFormat("text"), ForceLevel(logrus.WarnLevel))
for k, tc := range []struct {
l *Logger
expect []string
notExpect []string
call func(l *Logger)
}{
{
l: audit,
expect: []string{"logrus_test.go", "logrusx_test.TestTextLogger",
"audience=audit", "service_name=logrusx-audit", "service_version=v0.0.0",
"An error occurred.", "message:some error", "trace", "testing.tRunner"},
call: func(l *Logger) {
l.WithError(errors.New("some error")).Error("An error occurred.")
},
},
{
l: tracer,
expect: []string{"logrus_test.go", "logrusx_test.TestTextLogger",
"audience=application", "service_name=logrusx-app", "service_version=v0.0.0",
"An error occurred.", "message:some error", "trace", "testing.tRunner"},
call: func(l *Logger) {
l.WithError(errors.New("some error")).Error("An error occurred.")
},
},
{
l: tracer,
expect: []string{"logrus_test.go", "logrusx_test.TestTextLogger",
"audience=application", "service_name=logrusx-app", "service_version=v0.0.0",
"An error occurred.", "headers:map[", "accept:application/json", "accept-encoding:gzip",
"user-agent:Go-http-client/1.1", "x-request-id:id1234", "host:127.0.0.1:63232", "method:GET",
"query:Value is sensitive and has been redacted. To see the value set config key \"log.leak_sensitive_values = true\" or environment variable \"LOG_LEAK_SENSITIVE_VALUES=true\".",
"remote:127.0.0.1:63233", "scheme:http", "path:/foo/bar",
},
notExpect: []string{"testing.tRunner", "bar=foo"},
call: func(l *Logger) {
l.WithRequest(fakeRequest).Error("An error occurred.")
},
},
{
l: New("logrusx-app", "v0.0.0", ForceFormat("text"), ForceLevel(logrus.TraceLevel), RedactionText("redacted")),
expect: []string{"logrus_test.go", "logrusx_test.TestTextLogger",
"audience=application", "service_name=logrusx-app", "service_version=v0.0.0",
"An error occurred.", "headers:map[", "accept:application/json", "accept-encoding:gzip",
"user-agent:Go-http-client/1.1", "x-request-id:id1234", "host:127.0.0.1:63232", "method:GET",
"query:redacted",
},
notExpect: []string{"testing.tRunner", "bar=foo"},
call: func(l *Logger) {
l.WithRequest(fakeRequest).Error("An error occurred.")
},
},
{
l: New("logrusx-server", "v0.0.1", ForceFormat("text"), LeakSensitive(), ForceLevel(logrus.DebugLevel)),
expect: []string{
"audience=application", "service_name=logrusx-server", "service_version=v0.0.1",
"An error occurred.",
"headers:map[", "accept:application/json", "accept-encoding:gzip",
"user-agent:Go-http-client/1.1", "x-request-id:id1234", "host:127.0.0.1:63232", "method:GET",
"query:bar=foo",
"remote:127.0.0.1:63233", "scheme:http", "path:/foo/bar",
},
notExpect: []string{"logrus_test.go", "logrusx_test.TestTextLogger", "testing.tRunner", "?bar=foo"},
call: func(l *Logger) {
l.WithRequest(fakeRequest).Error("An error occurred.")
},
},
{
l: tracer,
expect: []string{"logrus_test.go", "logrusx_test.TestTextLogger",
"audience=application", "service_name=logrusx-app", "service_version=v0.0.0",
"An error occurred.", "message:The requested resource could not be found", "reason:some reason",
"status:Not Found", "status_code:404", "debug:some debug", "trace", "testing.tRunner"},
call: func(l *Logger) {
l.WithError(errors.WithStack(herodot.ErrNotFound.WithReason("some reason").WithDebug("some debug"))).Error("An error occurred.")
},
},
{
l: debugger,
expect: []string{"audience=application", "service_name=logrusx-server", "service_version=v0.0.1",
"An error occurred.", "message:some error"},
call: func(l *Logger) {
l.WithError(errors.New("some error")).Error("An error occurred.")
},
},
{
l: warner,
expect: []string{"audience=application", "service_name=logrusx-server", "service_version=v0.0.1",
"An error occurred.", "message:some error"},
notExpect: []string{"logrus_test.go", "logrusx_test.TestTextLogger", "trace", "testing.tRunner"},
call: func(l *Logger) {
l.WithError(errors.New("some error")).Error("An error occurred.")
},
},
{
l: debugger,
expect: []string{"audience=application", "service_name=logrusx-server", "service_version=v0.0.1", "baz!", "foo=bar"},
notExpect: []string{"logrus_test.go", "logrusx_test.TestTextLogger"},
call: func(l *Logger) {
l.WithField("foo", "bar").Info("baz!")
},
},
{
l: New("logrusx-server", "v0.0.1", ForceFormat("text"), ForceLevel(logrus.DebugLevel)),
expect: []string{
"set-cookie:Value is sensitive and has been redacted. To see the value set config key \"log.leak_sensitive_values = true\" or environment variable \"LOG_LEAK_SENSITIVE_VALUES=true\".",
`cookie:Value is sensitive and has been redacted. To see the value set config key "log.leak_sensitive_values = true" or environment variable "LOG_LEAK_SENSITIVE_VALUES=true".`,
`x-session-token:Value is sensitive and has been redacted. To see the value set config key "log.leak_sensitive_values = true" or environment variable "LOG_LEAK_SENSITIVE_VALUES=true".`,
`authorization:Value is sensitive and has been redacted. To see the value set config key "log.leak_sensitive_values = true" or environment variable "LOG_LEAK_SENSITIVE_VALUES=true".`,
"x-custom-header:2198ef09ac09d09ff098dd123ab128353",
},
notExpect: []string{
"set-cookie:kratos_session=2198ef09ac09d09ff098dd123ab128353",
"cookie:kratos_cookie=2198ef09ac09d09ff098dd123ab128353",
"x-session-token:2198ef09ac09d09ff098dd123ab128353",
"authorization:Bearer 2198ef09ac09d09ff098dd123ab128353",
},
call: func(l *Logger) {
l.WithRequest(fakeRequest).Debug()
},
},
{
l: New("logrusx-server", "v0.0.1", ForceFormat("text"), WithAdditionalRedactedHeaders([]string{"x-custom-header"}), ForceLevel(logrus.DebugLevel)),
expect: []string{
"set-cookie:Value is sensitive and has been redacted. To see the value set config key \"log.leak_sensitive_values = true\" or environment variable \"LOG_LEAK_SENSITIVE_VALUES=true\".",
`cookie:Value is sensitive and has been redacted. To see the value set config key "log.leak_sensitive_values = true" or environment variable "LOG_LEAK_SENSITIVE_VALUES=true".`,
`x-session-token:Value is sensitive and has been redacted. To see the value set config key "log.leak_sensitive_values = true" or environment variable "LOG_LEAK_SENSITIVE_VALUES=true".`,
`authorization:Value is sensitive and has been redacted. To see the value set config key "log.leak_sensitive_values = true" or environment variable "LOG_LEAK_SENSITIVE_VALUES=true".`,
`x-custom-header:Value is sensitive and has been redacted. To see the value set config key "log.leak_sensitive_values = true" or environment variable "LOG_LEAK_SENSITIVE_VALUES=true".`,
},
notExpect: []string{
"set-cookie:kratos_session=2198ef09ac09d09ff098dd123ab128353",
"cookie:kratos_cookie=2198ef09ac09d09ff098dd123ab128353",
"x-session-token:2198ef09ac09d09ff098dd123ab128353",
"authorization:Bearer 2198ef09ac09d09ff098dd123ab128353",
"x-custom-header:2198ef09ac09d09ff098dd123ab128353",
},
call: func(l *Logger) {
l.WithRequest(fakeRequest).Debug()
},
},
{
l: tracer,
notExpect: []string{"?bar=foo"},
call: func(l *Logger) {
l.Printf("%s", fakeRequest.URL)
},
},
{
l: New("logrusx-app", "v0.0.0", ForceFormat("text"), ForceLevel(logrus.TraceLevel), LeakSensitive()),
expect: []string{"?bar=foo"},
call: func(l *Logger) {
l.Printf("%s", fakeRequest.URL)
},
},
{
l: tracer,
notExpect: []string{"RawQuery:bar=foo"},
call: func(l *Logger) {
l.Printf("%+v", *fakeRequest.URL)
},
},
{
l: New("logrusx-app", "v0.0.0", ForceFormat("text"), ForceLevel(logrus.TraceLevel), LeakSensitive()),
expect: []string{"RawQuery:bar=foo"},
call: func(l *Logger) {
l.Printf("%+v", *fakeRequest.URL)
},
},
} {
t.Run("case="+strconv.Itoa(k), func(t *testing.T) {
var b bytes.Buffer
tc.l.Logrus().Out = &b
tc.call(tc.l)
t.Log(b.String())
for _, expect := range tc.expect {
assert.Contains(t, b.String(), expect)
}
for _, expect := range tc.notExpect {
assert.NotContains(t, b.String(), expect)
}
})
}
}
func TestLogger(t *testing.T) {
l := New("logrus test", "test")
t.Run("case=does not panic on nil error", func(t *testing.T) {
defer func() {
assert.Nil(t, recover())
}()
l.WithError(nil)
})
}

View File

@ -1,171 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package mapx
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetString(t *testing.T) {
m := map[interface{}]interface{}{"foo": "bar", "baz": 1234}
v, err := GetString(m, "foo")
require.NoError(t, err)
assert.EqualValues(t, "bar", v)
_, err = GetString(m, "bar")
require.Error(t, err)
_, err = GetString(m, "baz")
require.Error(t, err)
}
func TestGetStringSlice(t *testing.T) {
m := map[interface{}]interface{}{"foo": []string{"foo", "bar"}, "baz": "bar"}
v, err := GetStringSlice(m, "foo")
require.NoError(t, err)
assert.EqualValues(t, []string{"foo", "bar"}, v)
_, err = GetStringSlice(m, "bar")
require.Error(t, err)
_, err = GetStringSlice(m, "baz")
require.Error(t, err)
}
func TestGetStringSliceDefault(t *testing.T) {
m := map[interface{}]interface{}{"foo": []string{"foo", "bar"}, "baz": "bar"}
assert.EqualValues(t, []string{"foo", "bar"}, GetStringSliceDefault(m, "foo", []string{"default"}))
assert.EqualValues(t, []string{"default"}, GetStringSliceDefault(m, "baz", []string{"default"}))
assert.EqualValues(t, []string{"default"}, GetStringSliceDefault(m, "bar", []string{"default"}))
}
func TestGetStringDefault(t *testing.T) {
m := map[interface{}]interface{}{"foo": "bar", "baz": 1234}
assert.EqualValues(t, "bar", GetStringDefault(m, "foo", "default"))
assert.EqualValues(t, "default", GetStringDefault(m, "baz", "default"))
assert.EqualValues(t, "default", GetStringDefault(m, "bar", "default"))
}
func TestGetFloat32(t *testing.T) {
m := map[interface{}]interface{}{"foo": "bar", "baz": float32(1234)}
v, err := GetFloat32(m, "baz")
require.NoError(t, err)
assert.EqualValues(t, float32(1234), v)
_, err = GetFloat32(m, "foo")
require.Error(t, err)
_, err = GetFloat32(m, "bar")
require.Error(t, err)
}
func TestGetFloat64(t *testing.T) {
m := map[interface{}]interface{}{"foo": "bar", "baz": float64(1234)}
v, err := GetFloat64(m, "baz")
require.NoError(t, err)
assert.EqualValues(t, float64(1234), v)
_, err = GetFloat64(m, "foo")
require.Error(t, err)
_, err = GetFloat64(m, "bar")
require.Error(t, err)
}
func TestGetGetFloat64Default(t *testing.T) {
m := map[interface{}]interface{}{"foo": "bar", "baz": float64(1234)}
v := GetFloat64Default(m, "baz", 0)
assert.EqualValues(t, float64(1234), v)
v = GetFloat64Default(m, "foo", float64(1))
assert.EqualValues(t, float64(1), v)
v = GetFloat64Default(m, "bar", float64(2))
assert.EqualValues(t, float64(2), v)
}
func TestGetGetFloat32Default(t *testing.T) {
m := map[interface{}]interface{}{"foo": "bar", "baz": float32(1234)}
v := GetFloat32Default(m, "baz", 0)
assert.EqualValues(t, float32(1234), v)
v = GetFloat32Default(m, "foo", float32(1))
assert.EqualValues(t, float32(1), v)
v = GetFloat32Default(m, "bar", float32(2))
assert.EqualValues(t, float32(2), v)
}
func TestGetGetInt32Default(t *testing.T) {
m := map[interface{}]interface{}{"foo": "bar", "baz": int32(1234)}
v := GetInt32Default(m, "baz", 0)
assert.EqualValues(t, int32(1234), v)
v = GetInt32Default(m, "foo", int32(1))
assert.EqualValues(t, int32(1), v)
v = GetInt32Default(m, "bar", int32(2))
assert.EqualValues(t, int32(2), v)
}
func TestGetGetInt64Default(t *testing.T) {
m := map[interface{}]interface{}{"foo": "bar", "baz": int64(1234)}
v := GetInt64Default(m, "baz", 0)
assert.EqualValues(t, int64(1234), v)
v = GetInt64Default(m, "foo", int64(1))
assert.EqualValues(t, int64(1), v)
v = GetInt64Default(m, "bar", int64(2))
assert.EqualValues(t, int64(2), v)
}
func TestGetGetIntDefault(t *testing.T) {
m := map[interface{}]interface{}{"foo": "bar", "baz": int(1234)}
v := GetIntDefault(m, "baz", 0)
assert.EqualValues(t, int(1234), v)
v = GetIntDefault(m, "foo", int(1))
assert.EqualValues(t, int(1), v)
v = GetIntDefault(m, "bar", int(2))
assert.EqualValues(t, int(2), v)
}
func TestGetInt64(t *testing.T) {
m := map[interface{}]interface{}{"foo": "bar", "baz": int64(1234)}
v, err := GetInt64(m, "baz")
require.NoError(t, err)
assert.EqualValues(t, int64(1234), v)
_, err = GetInt64(m, "foo")
require.Error(t, err)
_, err = GetInt64(m, "bar")
require.Error(t, err)
}
func TestGetInt32(t *testing.T) {
m := map[interface{}]interface{}{"foo": "bar", "baz": int32(1234), "baz2": int(1234)}
v, err := GetInt32(m, "baz")
require.NoError(t, err)
assert.EqualValues(t, int32(1234), v)
v, err = GetInt32(m, "baz2")
require.NoError(t, err)
assert.EqualValues(t, int32(1234), v)
_, err = GetInt32(m, "foo")
require.Error(t, err)
_, err = GetInt32(m, "bar")
require.Error(t, err)
}
func TestKeyStringToInterface(t *testing.T) {
assert.EqualValues(t, map[interface{}]interface{}{"foo": "bar", "baz": 1234, "baz2": int32(1234)}, KeyStringToInterface(map[string]interface{}{"foo": "bar", "baz": 1234, "baz2": int32(1234)}))
}
func TestGetInt(t *testing.T) {
m := map[interface{}]interface{}{"foo": "bar", "baz": 1234, "baz2": int32(1234)}
v, err := GetInt32(m, "baz")
require.NoError(t, err)
assert.EqualValues(t, int32(1234), v)
_, err = GetInt32(m, "foo")
require.Error(t, err)
_, err = GetInt32(m, "bar")
require.Error(t, err)
}
func TestToJSONMap(t *testing.T) {
assert.EqualValues(t, map[string]interface{}{"baz": []interface{}{map[string]interface{}{"bar": "bar"}}, "foo": "bar"}, ToJSONMap(map[string]interface{}{
"foo": "bar",
"baz": []interface{}{
map[interface{}]interface{}{
"bar": "bar",
},
},
}))
}

View File

@ -1,38 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package metricsx
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAnonymizePath(t *testing.T) {
m := &Service{
o: &Options{WhitelistedPaths: []string{"/keys"}},
}
assert.Equal(t, "/keys", m.anonymizePath("/keys/1234/sub-path"))
assert.Equal(t, "/keys", m.anonymizePath("/keys/1234"))
assert.Equal(t, "/keys", m.anonymizePath("/keys"))
assert.Equal(t, "/", m.anonymizePath("/not-keys"))
}
func TestAnonymizeQuery(t *testing.T) {
m := &Service{}
assert.EqualValues(t, "foo=2ec879270efe890972d975251e9d454f4af49df1f07b4317fd5b6ae90de4c774&foo=1864a573566eba1b9ddab79d8f4bab5a39c938918a21b80a64ae1c9c12fa9aa2&foo2=186084f6bd8e222bedade9439d6ae69ed274b954eeebe9b54fd5f47e54dd7675&foo2=1ee7158281cc3b5a27de4c337e07987e8677f5f687a4671ca369b79c653d379d", m.anonymizeQuery(url.Values{
"foo": []string{"bar", "baz"},
"foo2": []string{"bar2", "baz2"},
}, "somesupersaltysalt"))
assert.EqualValues(t, "", m.anonymizeQuery(url.Values{
"foo": []string{},
}, "somesupersaltysalt"))
assert.EqualValues(t, "foo=", m.anonymizeQuery(url.Values{
"foo": []string{""},
}, "somesupersaltysalt"))
assert.EqualValues(t, "", m.anonymizeQuery(url.Values{}, "somesupersaltysalt"))
}

View File

@ -1,104 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package modx
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const stub = `module github.com/ory/x
// remove once https://github.com/seatgeek/logrus-gelf-formatter/pull/5 is merged
replace github.com/seatgeek/logrus-gelf-formatter => github.com/zepatrik/logrus-gelf-formatter v0.0.0-20210305135027-b8b3731dba10
require (
github.com/DataDog/datadog-go v4.0.0+incompatible // indirect
github.com/bmatcuk/doublestar/v2 v2.0.3
github.com/containerd/containerd v1.4.3 // indirect
github.com/dgraph-io/ristretto v0.0.2
github.com/docker/distribution v2.7.1+incompatible // indirect
github.com/docker/docker v17.12.0-ce-rc1.0.20201201034508-7d75c1d40d88+incompatible
github.com/fatih/structs v1.1.0
github.com/fsnotify/fsnotify v1.4.9
github.com/ghodss/yaml v1.0.0
github.com/go-bindata/go-bindata v3.1.1+incompatible
github.com/go-openapi/errors v0.20.0 // indirect
github.com/go-openapi/runtime v0.19.26
github.com/go-sql-driver/mysql v1.5.0
github.com/gobuffalo/fizz v1.10.0
github.com/gobuffalo/httptest v1.0.2
github.com/gobuffalo/packr v1.22.0
github.com/ory/pop/v5 v5.3.1
github.com/golang/mock v1.3.1
github.com/google/go-jsonnet v0.16.0
github.com/google/uuid v1.1.2
github.com/gorilla/websocket v1.4.2
github.com/hashicorp/go-retryablehttp v0.6.8
github.com/inhies/go-bytesize v0.0.0-20201103132853-d0aed0d254f8
github.com/jackc/pgconn v1.6.0
github.com/jackc/pgx/v4 v4.6.0
github.com/jandelgado/gcov2lcov v1.0.4-0.20210120124023-b83752c6dc08
github.com/jmoiron/sqlx v1.2.0
github.com/julienschmidt/httprouter v1.2.0
github.com/knadh/koanf v0.14.1-0.20201201075439-e0853799f9ec
github.com/lib/pq v1.3.0
github.com/markbates/pkger v0.17.1
github.com/morikuni/aec v1.0.0 // indirect
github.com/opentracing/opentracing-go v1.2.0
github.com/openzipkin-contrib/zipkin-go-opentracing v0.4.5
github.com/openzipkin/zipkin-go v0.2.2
github.com/ory/analytics-go/v5 v5.0.0
github.com/ory/dockertest/v3 v3.6.3
github.com/ory/go-acc v0.2.6
github.com/ory/herodot v0.9.2
github.com/ory/jsonschema/v3 v3.0.1
github.com/pborman/uuid v1.2.0
github.com/pelletier/go-toml v1.8.0
github.com/philhofer/fwd v1.0.0 // indirect
github.com/pkg/errors v0.9.1
github.com/pkg/profile v1.2.1
github.com/rs/cors v1.6.0
github.com/rubenv/sql-migrate v0.0.0-20190212093014-1007f53448d7
github.com/seatgeek/logrus-gelf-formatter v0.0.0-20210219220335-367fa274be2c
github.com/sirupsen/logrus v1.6.0
github.com/spf13/cast v1.3.2-0.20200723214538-8d17101741c8
github.com/spf13/cobra v1.0.0
github.com/spf13/pflag v1.0.5
github.com/go-jose/go-jose/v3 v3.0.0-20200630053402-0a67ce9b0693
github.com/stretchr/testify v1.6.1
github.com/tidwall/gjson v1.3.2
github.com/tidwall/sjson v1.0.4
github.com/uber/jaeger-client-go v2.22.1+incompatible
github.com/urfave/negroni v1.0.0
go.elastic.co/apm v1.8.0
go.elastic.co/apm/module/apmot v1.8.0
go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.13.0
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37
gonum.org/v1/plot v0.0.0-20200111075622-4abb28f724d5
google.golang.org/grpc v1.36.0
gopkg.in/DataDog/dd-trace-go.v1 v1.27.0
gopkg.in/square/go-jose.v2 v2.2.2
)
go 1.16
`
func TestVersion(t *testing.T) {
for _, tc := range [][]string{
{"google.golang.org/grpc", "v1.36.0"},
{"golang.org/x/crypto", "v0.0.0-20200510223506-06a226fb4e37"},
} {
v, err := FindVersion([]byte(stub), tc[0])
require.NoError(t, err)
assert.Equal(t, tc[1], v)
}
_, err := FindVersion([]byte(stub), "notgithub.com/idonot/exist")
require.Error(t, err)
}

View File

@ -1,25 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package networkx
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddressIsUnixSocket(t *testing.T) {
for k, tc := range []struct {
a string
e bool
}{
{a: "unix:/var/baz", e: true},
{a: "https://foo", e: false},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
assert.EqualValues(t, tc.e, AddressIsUnixSocket(tc.a))
})
}
}

View File

@ -1,40 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package networkx
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/pop/v6"
"github.com/ory/x/dbal"
"github.com/ory/x/logrusx"
)
func TestManager(t *testing.T) {
ctx := context.Background()
c, err := pop.NewConnection(&pop.ConnectionDetails{URL: dbal.SQLiteInMemory})
require.NoError(t, err)
require.NoError(t, c.Open())
l := logrusx.New("", "")
m := NewManager(c, l, nil)
require.NoError(t, m.MigrateUp(ctx))
first, err := m.Determine(ctx)
require.NoError(t, err)
assert.NotNil(t, first.ID)
second, err := m.Determine(ctx)
require.NoError(t, err)
assert.EqualValues(t, first.ID, second.ID)
}

File diff suppressed because one or more lines are too long

View File

@ -1,57 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package otelx
import (
"bytes"
"context"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/sjson"
"github.com/ory/jsonschema/v3"
)
const rootSchema = `{
"properties": {
"tracing": {
"$ref": "%s"
}
}
}
`
func TestConfigSchema(t *testing.T) {
t.Run("func=AddConfigSchema", func(t *testing.T) {
c := jsonschema.NewCompiler()
require.NoError(t, AddConfigSchema(c))
conf := Config{
ServiceName: "Ory X",
Provider: "jaeger",
Providers: ProvidersConfig{
Jaeger: JaegerConfig{
LocalAgentAddress: "localhost:6831",
Sampling: JaegerSampling{
ServerURL: "http://localhost:5778/sampling",
TraceIdRatio: 1,
},
},
},
}
rawConfig, err := sjson.Set("{}", "otelx", &conf)
require.NoError(t, err)
require.NoError(t, c.AddResource("config", bytes.NewBufferString(fmt.Sprintf(rootSchema, ConfigSchemaID))))
schema, err := c.Compile(context.Background(), "config")
require.NoError(t, err)
assert.NoError(t, schema.Validate(bytes.NewBufferString(rawConfig)))
})
}

View File

@ -1,93 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package otelx
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/urfave/negroni"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/sdk/trace/tracetest"
)
func TestShouldNotTraceHealthEndpoint(t *testing.T) {
testCases := []struct {
path string
testDescription string
}{
{
path: "health/ready",
testDescription: "health",
},
{
path: "admin/alive",
testDescription: "adminHealth",
},
{
path: "foo/bar",
testDescription: "notHealth",
},
}
for _, test := range testCases {
t.Run(test.testDescription, func(t *testing.T) {
recorder := tracetest.NewSpanRecorder()
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))
req := httptest.NewRequest(http.MethodGet, "https://api.example.com/"+test.path, nil)
h := NewHandler(negroni.New(), "test op", otelhttp.WithTracerProvider(tp))
h.ServeHTTP(negroni.NewResponseWriter(httptest.NewRecorder()), req)
spans := recorder.Ended()
if strings.Contains(test.path, "health") {
assert.Len(t, spans, 0)
} else {
assert.Len(t, spans, 1)
}
})
}
}
func TestTraceHandlerSpanName(t *testing.T) {
testCases := []struct {
path string
expectedName string
opts []otelhttp.Option
}{
{
path: "testPath",
expectedName: "/testPath",
opts: []otelhttp.Option{},
},
{
path: "testPath",
expectedName: "/overwritten/name",
opts: []otelhttp.Option{
otelhttp.WithSpanNameFormatter(func(operation string, r *http.Request) string {
return "/overwritten/name"
}),
},
},
}
for _, test := range testCases {
recorder := tracetest.NewSpanRecorder()
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))
opts := append([]otelhttp.Option{
otelhttp.WithTracerProvider(tp),
}, test.opts...)
req := httptest.NewRequest(http.MethodGet, "https://api.example.com/"+test.path, nil)
h := TraceHandler(negroni.New(), opts...)
h.ServeHTTP(negroni.NewResponseWriter(httptest.NewRecorder()), req)
spans := recorder.Ended()
assert.Len(t, spans, 1)
assert.Equal(t, test.expectedName, spans[0].Name())
}
}

View File

@ -1,285 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package otelx
import (
"compress/gzip"
"compress/zlib"
"context"
"encoding/json"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
tracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
"github.com/ory/x/logrusx"
)
const testTracingComponent = "github.com/ory/x/otelx"
func decodeResponseBody(t *testing.T, r *http.Request) []byte {
var reader io.ReadCloser
switch r.Header.Get("Content-Encoding") {
case "gzip":
var err error
reader, err = gzip.NewReader(r.Body)
if err != nil {
t.Fatal(err)
}
case "deflate":
var err error
reader, err = zlib.NewReader(r.Body)
if err != nil {
t.Fatal(err)
}
default:
reader = r.Body
}
respBody, err := io.ReadAll(reader)
require.NoError(t, err)
require.NoError(t, reader.Close())
return respBody
}
type zipkinSpanRequest struct {
Id string
TraceId string
Timestamp uint64
Name string
LocalEndpoint struct {
ServiceName string
}
Tags map[string]string
}
// runTestJaegerAgent starts a mock server listening on a random port for Jaeger spans sent over UDP.
func runTestJaegerAgent(t *testing.T, errs *errgroup.Group, done chan<- struct{}) net.Conn {
addr := "127.0.0.1:0"
udpAddr, err := net.ResolveUDPAddr("udp", addr)
require.NoError(t, err)
srv, err := net.ListenUDP("udp", udpAddr)
require.NoError(t, err)
errs.Go(func() error {
t.Logf("Starting test UDP server for Jaeger spans on %s", srv.LocalAddr().String())
for {
buf := make([]byte, 2048)
_, conn, err := srv.ReadFromUDP(buf)
if err != nil {
return err
}
if conn == nil {
continue
}
if len(buf) != 0 {
t.Log("received span!")
done <- struct{}{}
}
break
}
return nil
})
return srv
}
func TestJaegerTracer(t *testing.T) {
done := make(chan struct{})
errs := errgroup.Group{}
srv := runTestJaegerAgent(t, &errs, done)
jt, err := New(testTracingComponent, logrusx.New("ory/x", "1"), &Config{
ServiceName: "Ory X",
Provider: "jaeger",
Providers: ProvidersConfig{
Jaeger: JaegerConfig{
LocalAgentAddress: srv.LocalAddr().String(),
Sampling: JaegerSampling{
TraceIdRatio: 1,
},
},
},
})
require.NoError(t, err)
trc := jt.Tracer()
_, span := trc.Start(context.Background(), "testSpan")
span.SetAttributes(attribute.Bool("testAttribute", true))
span.End()
select {
case <-done:
case <-time.After(15 * time.Second):
t.Fatalf("Test server did not receive spans")
}
require.NoError(t, errs.Wait())
}
func TestJaegerTracerRespectsParentSamplingDecision(t *testing.T) {
done := make(chan struct{})
errs := errgroup.Group{}
srv := runTestJaegerAgent(t, &errs, done)
jt, err := New(testTracingComponent, logrusx.New("ory/x", "1"), &Config{
ServiceName: "Ory X",
Provider: "jaeger",
Providers: ProvidersConfig{
Jaeger: JaegerConfig{
LocalAgentAddress: srv.LocalAddr().String(),
Sampling: JaegerSampling{
// Effectively disable local sampling.
TraceIdRatio: 0,
},
},
},
})
require.NoError(t, err)
traceId := strings.Repeat("a", 32)
spanId := strings.Repeat("b", 16)
sampledFlag := "1"
traceHeaders := map[string]string{"uber-trace-id": traceId + ":" + spanId + ":0:" + sampledFlag}
ctx := otel.GetTextMapPropagator().Extract(context.Background(), propagation.MapCarrier(traceHeaders))
spanContext := trace.SpanContextFromContext(ctx)
assert.True(t, spanContext.IsValid())
assert.True(t, spanContext.IsSampled())
assert.True(t, spanContext.IsRemote())
trc := jt.Tracer()
_, span := trc.Start(ctx, "testSpan", trace.WithLinks(trace.Link{SpanContext: spanContext}))
span.SetAttributes(attribute.Bool("testAttribute", true))
span.End()
select {
case <-done:
case <-time.After(15 * time.Second):
t.Fatalf("Test server did not receive spans")
}
require.NoError(t, errs.Wait())
}
func TestZipkinTracer(t *testing.T) {
done := make(chan struct{})
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer close(done)
body, err := io.ReadAll(r.Body)
assert.NoError(t, err)
var spans []zipkinSpanRequest
err = json.Unmarshal(body, &spans)
assert.NoError(t, err)
assert.NotEmpty(t, spans[0].Id)
assert.NotEmpty(t, spans[0].TraceId)
assert.Equal(t, "testspan", spans[0].Name)
assert.Equal(t, "ory x", spans[0].LocalEndpoint.ServiceName)
assert.NotNil(t, spans[0].Tags["testTag"])
assert.Equal(t, "true", spans[0].Tags["testTag"])
}))
defer ts.Close()
zt, err := New(testTracingComponent, logrusx.New("ory/x", "1"), &Config{
ServiceName: "Ory X",
Provider: "zipkin",
Providers: ProvidersConfig{
Zipkin: ZipkinConfig{
ServerURL: ts.URL,
Sampling: ZipkinSampling{
SamplingRatio: 1,
},
},
},
})
assert.NoError(t, err)
trc := zt.Tracer()
_, span := trc.Start(context.Background(), "testspan")
span.SetAttributes(attribute.Bool("testTag", true))
span.End()
select {
case <-done:
case <-time.After(15 * time.Second):
t.Fatalf("Test server did not receive spans")
}
}
func TestOTLPTracer(t *testing.T) {
done := make(chan struct{})
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body := decodeResponseBody(t, r)
var res tracepb.ExportTraceServiceRequest
err := proto.Unmarshal(body, &res)
require.NoError(t, err, "must be able to unmarshal traces")
resourceSpans := res.GetResourceSpans()
spans := resourceSpans[0].GetScopeSpans()[0].GetSpans()
assert.Equal(t, len(spans), 1)
assert.NotEmpty(t, spans[0].GetSpanId())
assert.NotEmpty(t, spans[0].GetTraceId())
assert.Equal(t, "testSpan", spans[0].GetName())
assert.Equal(t, "testAttribute", spans[0].Attributes[0].Key)
close(done)
}))
defer ts.Close()
tsu, err := url.Parse(ts.URL)
require.NoError(t, err)
ot, err := New(testTracingComponent, logrusx.New("ory/x", "1"), &Config{
ServiceName: "ORY X",
Provider: "otel",
Providers: ProvidersConfig{
OTLP: OTLPConfig{
ServerURL: tsu.Host,
Insecure: true,
Sampling: OTLPSampling{
SamplingRatio: 1,
},
},
},
})
assert.NoError(t, err)
trc := ot.Tracer()
_, span := trc.Start(context.Background(), "testSpan")
span.SetAttributes(attribute.Bool("testAttribute", true))
span.End()
select {
case <-done:
case <-time.After(15 * time.Second):
t.Fatalf("Test server did not receive spans")
}
}

View File

@ -1,43 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package semconv
import (
"context"
"testing"
"github.com/gofrs/uuid"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel/attribute"
"github.com/ory/x/httpx"
)
func TestAttributesFromContext(t *testing.T) {
ctx := context.Background()
assert.Len(t, AttributesFromContext(ctx), 0)
nid, wsID := uuid.Must(uuid.NewV4()), uuid.Must(uuid.NewV4())
ctx = ContextWithAttributes(ctx, AttrNID(nid), AttrWorkspace(wsID))
assert.Len(t, AttributesFromContext(ctx), 2)
uid1, uid2 := uuid.Must(uuid.NewV4()), uuid.Must(uuid.NewV4())
location := httpx.GeoLocation{
City: "Berlin",
Country: "Germany",
Region: "BE",
}
ctx = ContextWithAttributes(ctx, append(AttrGeoLocation(location), AttrIdentityID(uid1), AttrClientIP("127.0.0.1"), AttrIdentityID(uid2))...)
attrs := AttributesFromContext(ctx)
assert.Len(t, attrs, 7, "should deduplicate")
assert.Equal(t, []attribute.KeyValue{
attribute.String(AttributeKeyNID.String(), nid.String()),
attribute.String(AttributeKeyWorkspace.String(), wsID.String()),
attribute.String(AttributeKeyGeoLocationCity.String(), "Berlin"),
attribute.String(AttributeKeyGeoLocationCountry.String(), "Germany"),
attribute.String(AttributeKeyGeoLocationRegion.String(), "BE"),
attribute.String(AttributeKeyClientIP.String(), "127.0.0.1"),
attribute.String(AttributeKeyIdentityID.String(), uid2.String()),
}, attrs, "last duplicate attribute wins")
}

View File

@ -1,144 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package otelx
import (
"context"
"errors"
"fmt"
"slices"
"testing"
pkgerrors "github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/sdk/trace/tracetest"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/noop"
)
var errPanic = errors.New("panic-error")
type errWithReason struct {
error
}
func (*errWithReason) Reason() string {
return "some interesting error reason"
}
func (errWithReason) Debug() string {
return "verbose debugging information"
}
func TestWithSpan(t *testing.T) {
tracer := noop.NewTracerProvider().Tracer("test")
ctx, span := tracer.Start(context.Background(), "parent")
defer span.End()
assert.NoError(t, WithSpan(ctx, "no-error", func(ctx context.Context) error { return nil }))
assert.Error(t, WithSpan(ctx, "error", func(ctx context.Context) error { return errors.New("some-error") }))
assert.PanicsWithError(t, errPanic.Error(), func() {
WithSpan(ctx, "panic", func(ctx context.Context) error {
panic(errPanic)
})
})
assert.PanicsWithValue(t, errPanic, func() {
WithSpan(ctx, "panic", func(ctx context.Context) error {
panic(errPanic)
})
})
assert.PanicsWithValue(t, "panic-string", func() {
WithSpan(ctx, "panic", func(ctx context.Context) error {
panic("panic-string")
})
})
}
func returnsNormally(ctx context.Context) (err error) {
_, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "returnsNormally")
defer End(span, &err)
return nil
}
func returnsError(ctx context.Context) (err error) {
_, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "returnsError")
defer End(span, &err)
return fmt.Errorf("wrapped: %w", &errWithReason{errors.New("error from returnsError()")})
}
func returnsStackTracer(ctx context.Context) (err error) {
_, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "returnsStackTracer")
defer End(span, &err)
return pkgerrors.WithStack(errors.New("error from returnsStackTracer()"))
}
func returnsNamedError(ctx context.Context) (err error) {
_, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "returnsNamedError")
defer End(span, &err)
err2 := fmt.Errorf("%w", errWithReason{errors.New("err2 message")})
return err2
}
func panics(ctx context.Context) (err error) {
_, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "panics")
defer End(span, &err)
panic(errors.New("panic from panics()"))
}
func TestEnd(t *testing.T) {
recorder := tracetest.NewSpanRecorder()
tracer := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder)).Tracer("test")
ctx, span := tracer.Start(context.Background(), "parent")
defer span.End()
assert.NoError(t, returnsNormally(ctx))
require.NotEmpty(t, recorder.Ended())
assert.Equal(t, last(recorder).Name(), "returnsNormally")
assert.Equal(t, last(recorder).Status(), sdktrace.Status{codes.Unset, ""})
assert.Error(t, returnsError(ctx))
require.NotEmpty(t, recorder.Ended())
assert.Equal(t, last(recorder).Name(), "returnsError")
assert.Equal(t, last(recorder).Status(), sdktrace.Status{codes.Error, "wrapped: error from returnsError()"})
assert.Contains(t, last(recorder).Attributes(), attribute.String("error.reason", "some interesting error reason"))
assert.Errorf(t, returnsNamedError(ctx), "err2 message")
require.NotEmpty(t, recorder.Ended())
assert.Equal(t, last(recorder).Name(), "returnsNamedError")
assert.Equal(t, last(recorder).Status(), sdktrace.Status{codes.Error, "err2 message"})
assert.Contains(t, last(recorder).Attributes(), attribute.String("error.debug", "verbose debugging information"))
assert.Errorf(t, returnsStackTracer(ctx), "error from returnsStackTracer()")
require.NotEmpty(t, recorder.Ended())
assert.Equal(t, last(recorder).Name(), "returnsStackTracer")
assert.Equal(t, last(recorder).Status(), sdktrace.Status{codes.Error, "error from returnsStackTracer()"})
stackIdx := slices.IndexFunc(last(recorder).Attributes(), func(kv attribute.KeyValue) bool { return kv.Key == "error.stack" })
require.GreaterOrEqual(t, stackIdx, 0)
assert.Contains(t, last(recorder).Attributes()[stackIdx].Value.AsString(), "github.com/ory/x/otelx.returnsStackTracer")
assert.PanicsWithError(t, "panic from panics()", func() { panics(ctx) })
require.NotEmpty(t, recorder.Ended())
assert.Equal(t, last(recorder).Name(), "panics")
assert.Equal(t, last(recorder).Status(), sdktrace.Status{codes.Error, "panic: panic from panics()"})
stackIdx = slices.IndexFunc(last(recorder).Attributes(), func(kv attribute.KeyValue) bool { return kv.Key == "error.stack" })
require.GreaterOrEqual(t, stackIdx, 0)
assert.Contains(t, last(recorder).Attributes()[stackIdx].Value.AsString(), "github.com/ory/x/otelx.panics")
span.End()
require.NotEmpty(t, recorder.Ended())
assert.Equal(t, last(recorder).Name(), "parent")
assert.Equal(t, last(recorder).Status(), sdktrace.Status{codes.Unset, ""})
}
func last(r *tracetest.SpanRecorder) sdktrace.ReadOnlySpan {
ended := r.Ended()
if len(ended) == 0 {
return nil
}
return ended[len(ended)-1]
}

View File

@ -1,107 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package pagination
import (
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestHeader(t *testing.T) {
u, err := url.Parse("http://example.com")
if err != nil {
t.Fatal(err)
}
t.Run("Create previous and first but not next or last if at the end", func(t *testing.T) {
r := httptest.NewRecorder()
Header(r, u, 120, 50, 100)
expect := strings.Join([]string{
"<http://example.com?limit=50&offset=0>; rel=\"first\"",
"<http://example.com?limit=50&offset=50>; rel=\"prev\"",
}, ",")
assert.EqualValues(t, expect, r.Result().Header.Get("Link"))
assert.EqualValues(t, "120", r.Result().Header.Get("X-Total-Count"))
})
t.Run("Create next and last, but not previous or first if at the beginning", func(t *testing.T) {
r := httptest.NewRecorder()
Header(r, u, 120, 50, 0)
expect := strings.Join([]string{
"<http://example.com?limit=50&offset=50>; rel=\"next\"",
"<http://example.com?limit=50&offset=100>; rel=\"last\"",
}, ",")
assert.EqualValues(t, expect, r.Result().Header.Get("Link"))
})
t.Run("Create next and last, but not previous or first if on the first page", func(t *testing.T) {
r := httptest.NewRecorder()
Header(r, u, 120, 50, 10)
expect := strings.Join([]string{
"<http://example.com?limit=50&offset=50>; rel=\"next\"",
"<http://example.com?limit=50&offset=100>; rel=\"last\"",
}, ",")
assert.EqualValues(t, expect, r.Result().Header.Get("Link"))
})
t.Run("Create previous, next, first, and last if in the middle", func(t *testing.T) {
r := httptest.NewRecorder()
Header(r, u, 300, 50, 150)
expect := strings.Join([]string{
"<http://example.com?limit=50&offset=0>; rel=\"first\"",
"<http://example.com?limit=50&offset=200>; rel=\"next\"",
"<http://example.com?limit=50&offset=100>; rel=\"prev\"",
"<http://example.com?limit=50&offset=250>; rel=\"last\"",
}, ",")
assert.EqualValues(t, expect, r.Result().Header.Get("Link"))
})
t.Run("Header should default limit to 1 no limit was provided", func(t *testing.T) {
r := httptest.NewRecorder()
Header(r, u, 100, 0, 20)
expect := strings.Join([]string{
"<http://example.com?limit=1&offset=0>; rel=\"first\"",
"<http://example.com?limit=1&offset=21>; rel=\"next\"",
"<http://example.com?limit=1&offset=19>; rel=\"prev\"",
"<http://example.com?limit=1&offset=99>; rel=\"last\"",
}, ",")
assert.EqualValues(t, expect, r.Result().Header.Get("Link"))
})
t.Run("Create previous, next, first, but not last if in the middle and no total was provided", func(t *testing.T) {
r := httptest.NewRecorder()
Header(r, u, 0, 50, 150)
expect := strings.Join([]string{
"<http://example.com?limit=50&offset=0>; rel=\"first\"",
"<http://example.com?limit=50&offset=200>; rel=\"next\"",
"<http://example.com?limit=50&offset=100>; rel=\"prev\"",
}, ",")
assert.EqualValues(t, expect, r.Result().Header.Get("Link"))
})
t.Run("Create only first if the limits provided exceeds the number of clients found", func(t *testing.T) {
r := httptest.NewRecorder()
Header(r, u, 5, 50, 0)
expect := "<http://example.com?limit=5&offset=0>; rel=\"first\""
assert.EqualValues(t, expect, r.Result().Header.Get("Link"))
})
}

View File

@ -1,16 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package pagination
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestMaxItemsPerPage(t *testing.T) {
assert.Equal(t, 0, MaxItemsPerPage(100, 0))
assert.Equal(t, 10, MaxItemsPerPage(100, 10))
assert.Equal(t, 100, MaxItemsPerPage(100, 110))
}

View File

@ -1,48 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package keysetpagination
import (
"net/http/httptest"
"net/url"
"testing"
"github.com/peterhellberg/link"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHeader(t *testing.T) {
p := &Paginator{
defaultToken: StringPageToken("default"),
token: StringPageToken("next"),
size: 2,
}
u, err := url.Parse("http://ory.sh/")
require.NoError(t, err)
r := httptest.NewRecorder()
Header(r, u, p)
assert.Len(t, r.Result().Header.Values("link"), 1, "make sure we send one header with multiple comma-separated values rather than multiple headers")
links := link.ParseResponse(r.Result())
assert.Contains(t, links, "first")
assert.Contains(t, links["first"].URI, "page_token=default")
assert.Contains(t, links, "next")
assert.Contains(t, links["next"].URI, "page_token=next")
p.isLast = true
r = httptest.NewRecorder()
Header(r, u, p)
links = link.ParseResponse(r.Result())
assert.Contains(t, links, "first")
assert.Contains(t, links["first"].URI, "page_token=default")
assert.NotContains(t, links, "next")
}

View File

@ -1,328 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package keysetpagination
import (
"net/url"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/pop/v6"
)
type testItem struct {
ID string `db:"pk"`
CreatedAt string `db:"created_at"`
}
// Both value and pointer receiver implementations should work with this test:
// func (t testItem) PageToken() PageToken {
func (t *testItem) PageToken() PageToken {
return StringPageToken(t.ID)
}
func TestPaginator(t *testing.T) {
t.Run("paginates correctly", func(t *testing.T) {
c, err := pop.NewConnection(&pop.ConnectionDetails{
URL: "postgres://foo.bar",
})
require.NoError(t, err)
q := pop.Q(c)
paginator := GetPaginator(WithSize(10), WithToken(StringPageToken("token")))
q = q.Scope(Paginate[testItem](paginator))
sql, args := q.ToSQL(&pop.Model{Value: new(testItem)})
assert.Equal(t, `SELECT test_items.created_at, test_items.pk FROM test_items AS test_items WHERE "test_items"."pk" > $1 ORDER BY "test_items"."pk" ASC LIMIT 11`, sql)
assert.Equal(t, []interface{}{"token"}, args)
})
t.Run("paginates correctly with negative size", func(t *testing.T) {
c, err := pop.NewConnection(&pop.ConnectionDetails{
URL: "postgres://foo.bar",
})
require.NoError(t, err)
q := pop.Q(c)
paginator := GetPaginator(WithSize(-1), WithDefaultSize(10), WithToken(StringPageToken("token")))
q = q.Scope(Paginate[testItem](paginator))
sql, args := q.ToSQL(&pop.Model{Value: new(testItem)})
assert.Equal(t, `SELECT test_items.created_at, test_items.pk FROM test_items AS test_items WHERE "test_items"."pk" > $1 ORDER BY "test_items"."pk" ASC LIMIT 11`, sql)
assert.Equal(t, []interface{}{"token"}, args)
})
t.Run("paginates correctly mysql", func(t *testing.T) {
c, err := pop.NewConnection(&pop.ConnectionDetails{
URL: "mysql://user:pass@(host:1337)/database",
})
require.NoError(t, err)
q := pop.Q(c)
paginator := GetPaginator(WithSize(10), WithToken(StringPageToken("token")))
q = q.Scope(Paginate[testItem](paginator))
sql, args := q.ToSQL(&pop.Model{Value: new(testItem)})
assert.Equal(t, "SELECT test_items.created_at, test_items.pk FROM test_items AS test_items WHERE `test_items`.`pk` > ? ORDER BY `test_items`.`pk` ASC LIMIT 11", sql)
assert.Equal(t, []interface{}{"token"}, args)
})
t.Run("returns correct result", func(t *testing.T) {
items := []testItem{
{ID: "1"},
{ID: "2"},
{ID: "3"},
{ID: "4"},
{ID: "5"},
{ID: "6"},
{ID: "7"},
{ID: "8"},
{ID: "9"},
{ID: "10"},
{ID: "11"},
}
paginator := GetPaginator(WithDefaultSize(10), WithToken(StringPageToken("token")))
items, nextPage := Result(items, paginator)
assert.Len(t, items, 10)
assert.Equal(t, StringPageToken("10"), nextPage.Token())
assert.Equal(t, 10, nextPage.Size())
})
t.Run("returns correct size and token", func(t *testing.T) {
for _, tc := range []struct {
name string
opts []Option
expectedSize int
expectedToken PageToken
}{
{
name: "default",
opts: nil,
expectedSize: 100,
},
{
name: "default max size",
opts: []Option{WithSize(1000)},
expectedSize: DefaultMaxSize,
},
{
name: "with size and token",
opts: []Option{WithSize(10), WithToken(StringPageToken("token"))},
expectedSize: 10,
expectedToken: StringPageToken("token"),
},
{
name: "with custom defaults",
opts: []Option{WithDefaultSize(10), WithDefaultToken(StringPageToken("token"))},
expectedSize: 10,
expectedToken: StringPageToken("token"),
},
{
name: "with custom defaults and size and token",
opts: []Option{WithDefaultSize(10), WithDefaultToken(StringPageToken("token")), WithSize(20), WithToken(StringPageToken("token2"))},
expectedSize: 20,
expectedToken: StringPageToken("token2"),
},
{
name: "with size and custom default and max size",
opts: []Option{WithSize(10), WithDefaultSize(20), WithMaxSize(5)},
expectedSize: 5,
},
{
name: "with negative size",
opts: []Option{WithSize(-1), WithDefaultSize(20), WithMaxSize(100)},
expectedSize: 20,
},
} {
t.Run(tc.name, func(t *testing.T) {
paginator := GetPaginator(tc.opts...)
assert.Equal(t, tc.expectedSize, paginator.Size())
assert.Equal(t, tc.expectedToken, paginator.Token())
})
}
})
}
func TestParse(t *testing.T) {
for _, tc := range []struct {
name string
q url.Values
expectedSize int
expectedToken PageToken
f PageTokenConstructor
}{
{
name: "with page token",
q: url.Values{"page_token": {"token3"}},
expectedSize: 100,
expectedToken: StringPageToken("token3"),
f: NewStringPageToken,
},
{
name: "with page size",
q: url.Values{"page_size": {"123"}},
expectedSize: 123,
f: NewStringPageToken,
},
{
name: "with page size and page token",
q: url.Values{"page_size": {"123"}, "page_token": {"token5"}},
expectedSize: 123,
expectedToken: StringPageToken("token5"),
f: NewStringPageToken,
},
{
name: "with page size and page token",
q: url.Values{"page_size": {"123"}, "page_token": {"cGs9dG9rZW41"}},
expectedSize: 123,
expectedToken: MapPageToken{"pk": "token5"},
f: NewMapPageToken,
},
} {
t.Run(tc.name, func(t *testing.T) {
opts, err := Parse(tc.q, tc.f)
require.NoError(t, err)
paginator := GetPaginator(opts...)
assert.Equal(t, tc.expectedSize, paginator.Size())
assert.Equal(t, tc.expectedToken, paginator.Token())
})
}
t.Run("invalid page size leads to err", func(t *testing.T) {
_, err := Parse(url.Values{"page_size": {"invalid-int"}}, NewStringPageToken)
require.ErrorIs(t, err, strconv.ErrSyntax)
})
t.Run("empty tokens and page sizes work as if unset, empty values are skipped", func(t *testing.T) {
opts, err := Parse(url.Values{}, NewStringPageToken)
require.NoError(t, err)
paginator := GetPaginator(append(opts, WithDefaultToken(StringPageToken("default")))...)
assert.Equal(t, "default", paginator.Token().Encode())
assert.Equal(t, 100, paginator.Size())
opts, err = Parse(url.Values{"page_token": {""}, "page_size": {""}}, NewStringPageToken)
require.NoError(t, err)
paginator = GetPaginator(append(opts, WithDefaultToken(StringPageToken("default2")))...)
assert.Equal(t, "default2", paginator.Token().Encode())
assert.Equal(t, 100, paginator.Size())
opts, err = Parse(url.Values{"page_token": {"", "foo", ""}, "page_size": {"", "123", ""}}, NewStringPageToken)
require.NoError(t, err)
paginator = GetPaginator(append(opts, WithDefaultToken(StringPageToken("default3")))...)
assert.Equal(t, "foo", paginator.Token().Encode())
assert.Equal(t, 123, paginator.Size())
})
}
func TestPaginateWithAdditionalColumn(t *testing.T) {
c, err := pop.NewConnection(&pop.ConnectionDetails{
URL: "postgres://foo.bar",
})
require.NoError(t, err)
for _, tc := range []struct {
d string
opts []Option
e string
args []interface{}
}{
{
d: "with sort by created_at DESC",
opts: []Option{WithToken(MapPageToken{"pk": "token_value", "created_at": "timestamp"}), WithColumn("created_at", "DESC")},
e: `WHERE ("test_items"."created_at" < $1 OR ("test_items"."created_at" = $2 AND "test_items"."pk" > $3)) ORDER BY "test_items"."created_at" DESC, "test_items"."pk" ASC`,
args: []interface{}{"timestamp", "timestamp", "token_value"},
},
{
d: "with sort by created_at ASC",
opts: []Option{WithToken(MapPageToken{"pk": "token_value", "created_at": "timestamp"}), WithColumn("created_at", "ASC")},
e: `WHERE ("test_items"."created_at" > $1 OR ("test_items"."created_at" = $2 AND "test_items"."pk" > $3)) ORDER BY "test_items"."created_at" ASC, "test_items"."pk" ASC`,
args: []interface{}{"timestamp", "timestamp", "token_value"},
},
{
d: "with unknown column",
opts: []Option{WithToken(MapPageToken{"pk": "token_value", "created_at": "timestamp"}), WithColumn("unknown_column", "ASC")},
e: `WHERE "test_items"."pk" > $1 ORDER BY "test_items"."pk"`,
args: []interface{}{"token_value"},
},
{
d: "with no token value",
opts: []Option{WithToken(MapPageToken{"pk": "token_value"}), WithColumn("created_at", "ASC")},
e: `WHERE "test_items"."pk" > $1 ORDER BY "test_items"."pk"`,
args: []interface{}{"token_value"},
},
{
d: "with unknown order",
opts: []Option{WithToken(MapPageToken{"pk": "token_value", "created_at": "timestamp"}), WithColumn("created_at", Order("unknown order"))},
e: `WHERE "test_items"."pk" > $1 ORDER BY "test_items"."pk"`,
args: []interface{}{"token_value"},
},
} {
t.Run("case="+tc.d, func(t *testing.T) {
opts := append(tc.opts, WithSize(10))
paginator := GetPaginator(opts...)
sql, args := pop.Q(c).
Scope(Paginate[testItem](paginator)).
ToSQL(&pop.Model{Value: new(testItem)})
assert.Contains(t, sql, tc.e)
assert.Contains(t, sql, "LIMIT 11")
assert.Equal(t, tc.args, args)
})
}
}
func TestOptions(t *testing.T) {
for _, tc := range []struct {
name string
opts []Option
expectedToken PageToken
expectedSize int
}{
{
name: "no options",
opts: nil,
expectedToken: nil,
expectedSize: DefaultSize,
},
{
name: "with token",
opts: []Option{WithToken(StringPageToken("token"))},
expectedToken: StringPageToken("token"),
expectedSize: DefaultSize,
},
{
name: "with size",
opts: []Option{WithSize(10)},
expectedToken: nil,
expectedSize: 10,
},
{
name: "with all options",
opts: []Option{
WithToken(StringPageToken("token")),
WithDefaultToken(StringPageToken("default")),
WithSize(20),
WithDefaultSize(30),
WithMaxSize(50),
WithColumn("created_at", "DESC"),
withIsLast(true),
},
expectedToken: StringPageToken("token"),
expectedSize: 20,
},
{
name: "with explicit defaults",
opts: []Option{WithMaxSize(DefaultMaxSize), WithDefaultSize(DefaultSize)},
expectedToken: nil,
expectedSize: DefaultSize,
},
} {
t.Run(tc.name, func(t *testing.T) {
paginator := GetPaginator(tc.opts...)
assert.Equal(t, tc.expectedToken, paginator.Token())
assert.Equal(t, tc.expectedSize, paginator.Size())
assert.Equal(t, paginator, GetPaginator(paginator.ToOptions()...))
})
}
}

View File

@ -1,49 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package keysetpagination
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseHeader(t *testing.T) {
u, err := url.Parse("https://www.ory.sh/")
require.NoError(t, err)
t.Run("has next page", func(t *testing.T) {
p := &Paginator{
defaultToken: StringPageToken("default"),
token: StringPageToken("next"),
size: 2,
}
r := httptest.NewRecorder()
Header(r, u, p)
result := ParseHeader(&http.Response{Header: r.Header()})
assert.Equal(t, "next", result.NextToken, r.Header())
assert.Equal(t, "default", result.FirstToken, r.Header())
})
t.Run("is last page", func(t *testing.T) {
p := &Paginator{
defaultToken: StringPageToken("default"),
size: 1,
isLast: true,
}
r := httptest.NewRecorder()
Header(r, u, p)
result := ParseHeader(&http.Response{Header: r.Header()})
assert.Equal(t, "", result.NextToken, r.Header())
assert.Equal(t, "default", result.FirstToken, r.Header())
})
}

View File

@ -1,67 +0,0 @@
// Copyright © 2025 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package keysetpagination
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPageToken(t *testing.T) {
t.Parallel()
t.Run("json idempotency", func(t *testing.T) {
token := NewPageToken(Column{Name: "id", Value: "token"}, Column{Name: "name", Order: OrderDescending, Value: "My Name"})
raw, err := token.MarshalJSON()
require.NoError(t, err)
var decodedToken PageToken
require.NoError(t, decodedToken.UnmarshalJSON(raw))
assert.Equal(t, token, decodedToken)
})
t.Run("checks expiration", func(t *testing.T) {
now := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)
token := NewPageToken(Column{Name: "id", Value: "token"})
token.testNow = func() time.Time { return now }
raw, err := token.MarshalJSON()
require.NoError(t, err)
decodedToken := PageToken{
testNow: func() time.Time { return now.Add(2 * time.Hour) },
}
assert.ErrorIs(t, decodedToken.UnmarshalJSON(raw), ErrPageTokenExpired)
})
}
func TestPageToken_Encrypt(t *testing.T) {
t.Parallel()
keys := [][32]byte{{1, 2, 3}, {4, 5, 6}}
token := NewPageToken(Column{Name: "id", Value: "token"})
t.Run("encrypts with the first key", func(t *testing.T) {
encrypted := token.Encrypt(keys)
decrypted, err := ParsePageToken(keys[:1], encrypted)
require.NoError(t, err)
assert.Equal(t, token, decrypted)
_, err = ParsePageToken(keys[1:], encrypted)
assert.ErrorContains(t, err, "decrypt token")
})
t.Run("uses fallback key", func(t *testing.T) {
for _, encrypted := range []string{token.Encrypt(nil), token.Encrypt([][32]byte{})} {
decrypted, err := ParsePageToken([][32]byte{*fallbackEncryptionKey}, encrypted)
require.NoError(t, err)
assert.Equal(t, token, decrypted)
}
})
}

View File

@ -1,198 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package keysetpagination
import (
"strconv"
"testing"
"github.com/stretchr/testify/assert"
)
type testItem struct {
ID int `db:"pk"`
Name string `db:"name"`
CreatedAt string `db:"created_at"`
}
func nTestItems(n int) []testItem {
items := make([]testItem, n)
for i := range items {
items[i] = testItem{
ID: i + 1,
Name: "item" + strconv.Itoa(i+1),
CreatedAt: "2023-01-01T00:00:00Z",
}
}
return items
}
func TestResult(t *testing.T) {
t.Parallel()
defaultToken := NewPageToken(Column{Name: "pk", Value: 0}, Column{Name: "name", Order: OrderDescending, Value: ""})
paginator := NewPaginator(WithSize(10), WithDefaultToken(defaultToken))
t.Run("not last page", func(t *testing.T) {
items := nTestItems(11)
croppedItems, nextPage := Result(items, paginator)
assert.Len(t, croppedItems, 10)
assert.Equal(t, 10, nextPage.Size())
assert.False(t, nextPage.IsLast())
assert.Equal(t, NewPageToken(
Column{Name: "pk", Value: 10},
Column{Name: "name", Order: OrderDescending, Value: items[9].Name},
), nextPage.PageToken())
assert.NotContains(t, croppedItems, items[10], "last item should not be included in the result")
assert.Equal(t, croppedItems, items[:10], "cropped items should match the first 10 items")
})
t.Run("last page is full", func(t *testing.T) {
items := nTestItems(10)
croppedItems, nextPage := Result(items, paginator)
assert.Len(t, croppedItems, 10)
assert.Equal(t, 10, nextPage.Size())
assert.True(t, nextPage.IsLast())
assert.Equal(t, defaultToken, nextPage.PageToken())
assert.Equal(t, croppedItems, items)
})
t.Run("last page not full", func(t *testing.T) {
items := nTestItems(2)
croppedItems, nextPage := Result(items, paginator)
assert.Len(t, croppedItems, 2)
assert.Equal(t, 10, nextPage.Size())
assert.True(t, nextPage.IsLast())
assert.Equal(t, defaultToken, nextPage.PageToken())
assert.Equal(t, croppedItems, items)
})
}
func TestPaginator_Size(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
opts []Option
expected int
}{
{
name: "default",
opts: nil,
expected: DefaultSize,
},
{
name: "enforced default max size",
opts: []Option{WithSize(2 * DefaultMaxSize)},
expected: DefaultMaxSize,
},
{
name: "with size",
opts: []Option{WithSize(10)},
expected: 10,
},
{
name: "with custom default",
opts: []Option{WithDefaultSize(10)},
expected: 10,
},
{
name: "with custom default and size",
opts: []Option{WithDefaultSize(10), WithSize(20)},
expected: 20,
},
{
name: "with size and default bigger than max",
opts: []Option{WithSize(10), WithDefaultSize(20), WithMaxSize(5)},
expected: 5,
},
{
name: "with negative size",
opts: []Option{WithSize(-1), WithDefaultSize(20), WithMaxSize(100)},
expected: 20,
},
} {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, NewPaginator(tc.opts...).Size())
})
}
}
func TestPaginator_Token(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
opts []Option
expected PageToken
}{
{
name: "no options",
opts: nil,
expected: PageToken{},
},
{
name: "with token",
opts: []Option{WithToken(NewPageToken(Column{Name: "id", Value: "token"}))},
expected: NewPageToken(Column{Name: "id", Value: "token"}),
},
{
name: "with default token",
opts: []Option{WithDefaultToken(NewPageToken(Column{Name: "id", Value: "default"}))},
expected: NewPageToken(Column{Name: "id", Value: "default"}),
},
{
name: "with both tokens",
opts: []Option{WithToken(NewPageToken(Column{Name: "id", Value: "token"})), WithDefaultToken(NewPageToken(Column{Name: "id", Value: "default"}))},
expected: NewPageToken(Column{Name: "id", Value: "token"}),
},
} {
t.Run(tc.name, func(t *testing.T) {
paginator := NewPaginator(tc.opts...)
assert.Equal(t, tc.expected, paginator.PageToken())
})
}
}
func TestOptions(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
opts []Option
}{
{
name: "no options",
opts: nil,
},
{
name: "with token",
opts: []Option{WithToken(NewPageToken(Column{Name: "id", Value: "token"}))},
},
{
name: "with size",
opts: []Option{WithSize(10)},
},
{
name: "with all options",
opts: []Option{
WithSize(20),
WithDefaultSize(30),
WithMaxSize(50),
WithToken(NewPageToken(Column{Name: "id", Value: 123})),
WithDefaultToken(NewPageToken(Column{Name: "id", Value: 456})),
withIsLast(true),
},
},
{
name: "with explicit defaults",
opts: []Option{WithMaxSize(DefaultMaxSize), WithDefaultSize(DefaultSize)},
},
} {
t.Run(tc.name, func(t *testing.T) {
paginator := NewPaginator(tc.opts...)
assert.Equal(t, paginator, NewPaginator(paginator.ToOptions()...))
})
}
}

View File

@ -1,55 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package keysetpagination
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseHeader(t *testing.T) {
t.Parallel()
u, err := url.Parse("https://www.ory.sh/")
require.NoError(t, err)
keys := [][32]byte{{1, 2, 3}}
defaultToken, nextToken := NewPageToken(Column{Name: "id", Value: "default"}), NewPageToken(Column{Name: "id", Value: "next"})
t.Run("has next page", func(t *testing.T) {
p := NewPaginator(WithSize(2), WithDefaultToken(defaultToken), WithToken(nextToken))
r := httptest.NewRecorder()
SetLinkHeader(r, keys, u, p)
first, next, isLast := ParseHeader(&http.Response{Header: r.Header()})
require.NotEqual(t, first, next, r.Header())
assert.False(t, isLast)
parsedFirst, err := ParsePageToken(keys, first)
require.NoErrorf(t, err, "raw token %q", first)
assert.Equal(t, defaultToken, parsedFirst, r.Header())
parsedNext, err := ParsePageToken(keys, next)
require.NoErrorf(t, err, "raw token %q", next)
assert.Equal(t, nextToken, parsedNext, r.Header())
})
t.Run("is last page", func(t *testing.T) {
p := NewPaginator(WithSize(2), WithDefaultToken(defaultToken), WithToken(nextToken), withIsLast(true))
r := httptest.NewRecorder()
SetLinkHeader(r, keys, u, p)
first, next, isLast := ParseHeader(&http.Response{Header: r.Header()})
assert.Empty(t, next, r.Header())
assert.True(t, isLast)
parsedFirst, err := ParsePageToken(keys, first)
require.NoErrorf(t, err, "raw token %q", first)
assert.Equal(t, defaultToken, parsedFirst, r.Header())
})
}

View File

@ -1,122 +0,0 @@
// Copyright © 2025 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package keysetpagination
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/pop/v6"
)
func TestBuildWhereAndOrder(t *testing.T) {
t.Parallel()
tests := []struct {
name string
parts []Column
expectedWhere string
expectedArgs []any
expectedOrderBy string
}{
{
name: "single part ascending",
parts: []Column{
{Name: "id", Order: OrderAscending, Value: "first"},
},
expectedWhere: "(id > ?)",
expectedArgs: []any{"first"},
expectedOrderBy: "id ASC",
},
{
name: "single part descending",
parts: []Column{
{Name: "id", Order: OrderDescending, Value: 1},
},
expectedWhere: "(id < ?)",
expectedArgs: []any{1},
expectedOrderBy: "id DESC",
},
{
name: "two cols",
parts: []Column{
{Name: "id", Order: OrderAscending, Value: 1},
{Name: "name", Order: OrderDescending, Value: "test"},
},
expectedWhere: "(id > ?) OR (id = ? AND name < ?)",
expectedArgs: []any{1, 1, "test"},
expectedOrderBy: "id ASC, name DESC",
},
{
name: "many cols",
parts: []Column{
{Name: "id", Order: OrderAscending, Value: 1},
{Name: "name", Order: OrderAscending, Value: "test"},
{Name: "created_at", Order: OrderDescending, Value: "2023-01-01"},
{Name: "owner_id", Order: OrderDescending, Value: "owner123"},
},
expectedWhere: "(id > ?) OR (id = ? AND name > ?) OR (id = ? AND name = ? AND created_at < ?) OR (id = ? AND name = ? AND created_at = ? AND owner_id < ?)",
expectedArgs: []any{1, 1, "test", 1, "test", "2023-01-01", 1, "test", "2023-01-01", "owner123"},
expectedOrderBy: "id ASC, name ASC, created_at DESC, owner_id DESC",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
where, args, order := BuildWhereAndOrder(tc.parts, func(s string) string { return s })
assert.Equal(t, tc.expectedWhere, where)
assert.Equal(t, tc.expectedArgs, args)
assert.Equal(t, tc.expectedOrderBy, order)
})
}
}
func TestPaginate(t *testing.T) {
t.Parallel()
t.Run("paginates correctly", func(t *testing.T) {
c, err := pop.NewConnection(&pop.ConnectionDetails{
URL: "postgres://foo.bar",
})
require.NoError(t, err)
q := pop.Q(c)
paginator := NewPaginator(WithSize(10), WithToken(NewPageToken(Column{Name: "pk", Value: 666})))
q = q.Scope(Paginate[testItem](paginator))
sql, args := q.ToSQL(&pop.Model{Value: new(testItem)})
assert.Equal(t, `SELECT test_items.created_at, test_items.name, test_items.pk FROM test_items AS test_items WHERE ("test_items"."pk" > $1) ORDER BY "test_items"."pk" ASC LIMIT 11`, sql)
assert.Equal(t, []interface{}{666}, args)
})
t.Run("paginates correctly with negative size", func(t *testing.T) {
c, err := pop.NewConnection(&pop.ConnectionDetails{
URL: "postgres://foo.bar",
})
require.NoError(t, err)
q := pop.Q(c)
paginator := NewPaginator(WithSize(-1), WithDefaultSize(10), WithToken(NewPageToken(Column{Name: "pk", Value: 123})))
q = q.Scope(Paginate[testItem](paginator))
sql, args := q.ToSQL(&pop.Model{Value: new(testItem)})
assert.Equal(t, `SELECT test_items.created_at, test_items.name, test_items.pk FROM test_items AS test_items WHERE ("test_items"."pk" > $1) ORDER BY "test_items"."pk" ASC LIMIT 11`, sql)
assert.Equal(t, []interface{}{123}, args)
})
t.Run("paginates correctly mysql", func(t *testing.T) {
c, err := pop.NewConnection(&pop.ConnectionDetails{
URL: "mysql://user:pass@(host:1337)/database",
})
require.NoError(t, err)
q := pop.Q(c)
q = q.Scope(Paginate[testItem](NewPaginator(WithSize(10), WithToken(NewPageToken(Column{Name: "pk", Value: 666})))))
sql, args := q.ToSQL(&pop.Model{Value: new(testItem)})
assert.Equal(t, "SELECT test_items.created_at, test_items.name, test_items.pk FROM test_items AS test_items WHERE (`test_items`.`pk` > ?) ORDER BY `test_items`.`pk` ASC LIMIT 11", sql)
assert.Equal(t, []interface{}{666}, args)
})
}

View File

@ -1,179 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package keysetpagination
import (
"net/http/httptest"
"net/url"
"strconv"
"testing"
"github.com/peterhellberg/link"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSetLinkHeader(t *testing.T) {
t.Parallel()
keys := [][32]byte{{1, 2, 3}}
defaultToken, nextToken := NewPageToken(Column{Name: "id", Value: "default"}), NewPageToken(Column{Name: "id", Value: "next"})
opts := []Option{WithSize(2), WithDefaultToken(defaultToken), WithToken(nextToken)}
u, err := url.Parse("https://ory.sh/")
require.NoError(t, err)
getParsedToken := func(t *testing.T, uri string) PageToken {
u, err := url.Parse(uri)
require.NoError(t, err)
assert.Equal(t, "https", u.Scheme)
assert.Equal(t, "ory.sh", u.Host)
raw := u.Query().Get("page_token")
token, err := ParsePageToken(keys, raw)
require.NoError(t, err)
return token
}
t.Run("case=not last page", func(t *testing.T) {
r := httptest.NewRecorder()
p := NewPaginator(opts...)
SetLinkHeader(r, keys, u, p)
assert.Len(t, r.Result().Header.Values("link"), 1, "make sure we send one header with multiple comma-separated values rather than multiple headers")
links := link.ParseResponse(r.Result())
require.Contains(t, links, "first")
assert.Equal(t, defaultToken, getParsedToken(t, links["first"].URI))
require.Contains(t, links, "next")
assert.Equal(t, nextToken, getParsedToken(t, links["next"].URI))
})
t.Run("case=last page", func(t *testing.T) {
r := httptest.NewRecorder()
p := NewPaginator(append(opts, withIsLast(true))...)
SetLinkHeader(r, keys, u, p)
assert.Len(t, r.Result().Header.Values("link"), 1, "make sure we send one header with multiple comma-separated values rather than multiple headers")
links := link.ParseResponse(r.Result())
require.Contains(t, links, "first")
assert.Equal(t, defaultToken, getParsedToken(t, links["first"].URI))
assert.NotContains(t, links, "next")
})
}
func TestParsePageToken(t *testing.T) {
t.Parallel()
keys := [][32]byte{{1, 2, 3}, {4, 5, 6}}
expectedToken := NewPageToken(Column{Name: "id", Value: "token"}, Column{Name: "name", Order: OrderDescending, Value: "test"})
encryptedToken := expectedToken.Encrypt(keys)
t.Run("with valid key", func(t *testing.T) {
token, err := ParsePageToken(keys, encryptedToken)
require.NoError(t, err)
assert.Equal(t, expectedToken, token)
})
t.Run("with rotated key", func(t *testing.T) {
encryptedToken := expectedToken.Encrypt(keys[1:])
token, err := ParsePageToken(keys, encryptedToken)
require.NoError(t, err)
assert.Equal(t, expectedToken, token)
})
t.Run("with invalid key", func(t *testing.T) {
token, err := ParsePageToken([][32]byte{{7, 8, 9}}, encryptedToken)
require.ErrorContains(t, err, "decrypt token")
assert.Zero(t, token)
})
t.Run("uses fallback key", func(t *testing.T) {
fallbackEncryptedToken := expectedToken.Encrypt(nil)
for _, noKeys := range [][][32]byte{nil, {}} {
token, err := ParsePageToken(noKeys, fallbackEncryptedToken)
require.NoError(t, err)
assert.Equal(t, expectedToken, token)
}
})
}
func TestParse(t *testing.T) {
t.Parallel()
keys := [][32]byte{{1, 2, 3}}
token := NewPageToken(Column{Name: "id", Value: "token"}, Column{Name: "name", Order: OrderDescending, Value: "test"})
defaultToken := NewPageToken(Column{Name: "id", Value: "default"}, Column{Name: "name", Order: OrderDescending, Value: "default name"})
encryptedToken := token.Encrypt(keys)
for _, tc := range []struct {
name string
q url.Values
expectedSize int
expectedToken PageToken
}{
{
name: "no query parameters",
q: url.Values{},
expectedSize: DefaultSize,
expectedToken: defaultToken,
},
{
name: "with page token",
q: url.Values{"page_token": {encryptedToken}},
expectedSize: DefaultSize,
expectedToken: token,
},
{
name: "with page size",
q: url.Values{"page_size": {"123"}},
expectedSize: 123,
expectedToken: defaultToken,
},
{
name: "with page size and page token",
q: url.Values{"page_size": {"123"}, "page_token": {encryptedToken}},
expectedSize: 123,
expectedToken: token,
},
} {
t.Run(tc.name, func(t *testing.T) {
opts, err := ParseQueryParams(keys, tc.q)
require.NoError(t, err)
paginator := NewPaginator(append(opts, WithDefaultToken(defaultToken))...)
assert.Equal(t, tc.expectedSize, paginator.Size())
assert.Equal(t, tc.expectedToken, paginator.PageToken())
})
}
t.Run("invalid page size leads to err", func(t *testing.T) {
_, err := ParseQueryParams(keys, url.Values{"page_size": {"invalid-int"}})
require.ErrorIs(t, err, strconv.ErrSyntax)
})
t.Run("empty tokens and page sizes work as if unset, empty values are skipped", func(t *testing.T) {
opts, err := ParseQueryParams(keys, url.Values{})
require.NoError(t, err)
paginator := NewPaginator(append(opts, WithDefaultToken(defaultToken))...)
assert.Equal(t, defaultToken, paginator.PageToken())
assert.Equal(t, DefaultSize, paginator.Size())
opts, err = ParseQueryParams(keys, url.Values{"page_token": {""}, "page_size": {""}})
require.NoError(t, err)
paginator = NewPaginator(append(opts, WithDefaultToken(defaultToken))...)
assert.Equal(t, defaultToken, paginator.PageToken())
assert.Equal(t, DefaultSize, paginator.Size())
opts, err = ParseQueryParams(keys, url.Values{"page_token": {"", encryptedToken, ""}, "page_size": {"", "123", ""}})
require.NoError(t, err)
paginator = NewPaginator(append(opts, WithDefaultToken(defaultToken))...)
assert.Equal(t, token, paginator.PageToken())
assert.Equal(t, 123, paginator.Size())
})
}

View File

@ -1,74 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package pagination
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIndex(t *testing.T) {
for k, c := range []struct {
s []string
offset int
limit int
e []string
}{
{
s: []string{"a", "b", "c"},
offset: 0,
limit: 100,
e: []string{"a", "b", "c"},
},
{
s: []string{"a", "b", "c"},
offset: 0,
limit: 2,
e: []string{"a", "b"},
},
{
s: []string{"a", "b", "c"},
offset: 1,
limit: 10,
e: []string{"b", "c"},
},
{
s: []string{"a", "b", "c"},
offset: 1,
limit: 2,
e: []string{"b", "c"},
},
{
s: []string{"a", "b", "c"},
offset: 2,
limit: 2,
e: []string{"c"},
},
{
s: []string{"a", "b", "c"},
offset: 3,
limit: 10,
e: []string{},
},
{
s: []string{"a", "b", "c"},
offset: 2,
limit: 10,
e: []string{"c"},
},
{
s: []string{"a", "b", "c"},
offset: 1,
limit: 10,
e: []string{"b", "c"},
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
start, end := Index(c.limit, c.offset, len(c.s))
assert.EqualValues(t, c.e, c.s[start:end])
})
}
}

Some files were not shown because too many files have changed in this diff Show More