mirror of https://github.com/ory/kratos
fix: base64encoded schemaURL cannot be resolved
GitOrigin-RevId: 7b224a32afd2ff5ebf62ef96d797258e8de56afd
This commit is contained in:
parent
df866b315f
commit
a86c212198
|
|
@ -115,16 +115,8 @@ func (h *Handler) getIdentitySchema(w http.ResponseWriter, r *http.Request) {
|
|||
id := r.PathValue("id")
|
||||
s, err := ss.GetByID(id)
|
||||
if err != nil {
|
||||
// Maybe it is a base64 encoded ID?
|
||||
if dec, err := base64.RawURLEncoding.DecodeString(id); err == nil {
|
||||
id = string(dec)
|
||||
}
|
||||
|
||||
s, err = ss.GetByID(id)
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrNotFound.WithReasonf("Identity schema `%s` could not be found.", id)))
|
||||
return
|
||||
}
|
||||
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrNotFound.WithReasonf("Identity schema `%s` could not be found.", id)))
|
||||
return
|
||||
}
|
||||
|
||||
raw, err := h.ReadSchema(ctx, s.URL)
|
||||
|
|
|
|||
|
|
@ -74,12 +74,15 @@ type IdentitySchemaList interface {
|
|||
func (s Schemas) GetByID(id string) (*Schema, error) {
|
||||
id = cmp.Or(id, config.DefaultIdentityTraitsSchemaID)
|
||||
|
||||
for _, ss := range s {
|
||||
if ss.ID == id {
|
||||
return &ss, nil
|
||||
}
|
||||
if ss, ok := s.findSchemaByID(id); ok {
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
if decodedID, ok := TryDecodeID(id); ok {
|
||||
if ss, ok := s.findSchemaByID(decodedID); ok {
|
||||
return ss, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Unable to find JSON Schema ID: %s", id))
|
||||
}
|
||||
|
||||
|
|
@ -98,8 +101,19 @@ func (s Schemas) List(page, perPage int) Schemas {
|
|||
return s[start:end]
|
||||
}
|
||||
|
||||
var orderedKeyCacheMutex sync.RWMutex
|
||||
var orderedKeyCache map[string][]string
|
||||
func (s Schemas) findSchemaByID(id string) (*Schema, bool) {
|
||||
for _, ss := range s {
|
||||
if ss.ID == id {
|
||||
return &ss, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var (
|
||||
orderedKeyCacheMutex sync.RWMutex
|
||||
orderedKeyCache map[string][]string
|
||||
)
|
||||
|
||||
func init() {
|
||||
orderedKeyCache = make(map[string][]string)
|
||||
|
|
@ -154,3 +168,11 @@ func (s *Schema) SchemaURL(host *url.URL) *url.URL {
|
|||
func IDToURL(host *url.URL, id string) *url.URL {
|
||||
return urlx.AppendPaths(host, SchemasPath, base64.RawURLEncoding.EncodeToString([]byte(id)))
|
||||
}
|
||||
|
||||
func TryDecodeID(encoded string) (string, bool) {
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return string(decoded), true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,108 +1,24 @@
|
|||
// Copyright © 2023 Ory Corp
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package schema
|
||||
package schema_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/x/urlx"
|
||||
"github.com/ory/kratos/internal"
|
||||
"github.com/ory/kratos/schema"
|
||||
schematest "github.com/ory/kratos/schema/test"
|
||||
)
|
||||
|
||||
func TestSchemas_GetByID(t *testing.T) {
|
||||
urlFromID := func(id string) string {
|
||||
return fmt.Sprintf("http://%s.com", id)
|
||||
}
|
||||
|
||||
ss := Schemas{
|
||||
Schema{
|
||||
ID: "foo",
|
||||
},
|
||||
Schema{
|
||||
ID: "bar",
|
||||
},
|
||||
Schema{
|
||||
ID: "foobar",
|
||||
},
|
||||
Schema{
|
||||
ID: config.DefaultIdentityTraitsSchemaID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range ss {
|
||||
s.RawURL = urlFromID(s.ID)
|
||||
s.URL = urlx.ParseOrPanic(s.RawURL)
|
||||
}
|
||||
|
||||
t.Run("case=get first schema", func(t *testing.T) {
|
||||
s, err := ss.GetByID("foo")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &ss[0], s)
|
||||
})
|
||||
|
||||
t.Run("case=get second schema", func(t *testing.T) {
|
||||
s, err := ss.GetByID("bar")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &ss[1], s)
|
||||
})
|
||||
|
||||
t.Run("case=get third schema", func(t *testing.T) {
|
||||
s, err := ss.GetByID("foobar")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &ss[2], s)
|
||||
})
|
||||
|
||||
t.Run("case=get default schema", func(t *testing.T) {
|
||||
s1, err := ss.GetByID("")
|
||||
require.NoError(t, err)
|
||||
s2, err := ss.GetByID(config.DefaultIdentityTraitsSchemaID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &ss[3], s1)
|
||||
assert.Equal(t, &ss[3], s2)
|
||||
})
|
||||
|
||||
t.Run("case=should return error on not existing id", func(t *testing.T) {
|
||||
s, err := ss.GetByID("not existing id")
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, (*Schema)(nil), s)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSchemas_List(t *testing.T) {
|
||||
ss := Schemas{
|
||||
Schema{
|
||||
ID: "foo",
|
||||
},
|
||||
Schema{
|
||||
ID: "bar",
|
||||
},
|
||||
Schema{
|
||||
ID: "foobar",
|
||||
},
|
||||
Schema{
|
||||
ID: config.DefaultIdentityTraitsSchemaID,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("case=get all schemas", func(t *testing.T) {
|
||||
p0 := ss.List(0, 4)
|
||||
assert.Equal(t, ss, p0)
|
||||
})
|
||||
|
||||
t.Run("case=smaller pages", func(t *testing.T) {
|
||||
p0, p1 := ss.List(0, 2), ss.List(1, 2)
|
||||
assert.Equal(t, ss, append(p0, p1...))
|
||||
})
|
||||
|
||||
t.Run("case=indexes out of range", func(t *testing.T) {
|
||||
p0 := ss.List(-1, 10)
|
||||
assert.Equal(t, ss, p0)
|
||||
})
|
||||
func TestDefaultIdentityTraitsProvider(t *testing.T) {
|
||||
_, reg := internal.NewFastRegistryWithMocks(t)
|
||||
schematest.TestIdentitySchemaProvider(t, schema.NewDefaultIdentityTraitsProvider(reg))
|
||||
}
|
||||
|
||||
func TestGetKeysInOrder(t *testing.T) {
|
||||
|
|
@ -112,12 +28,14 @@ func TestGetKeysInOrder(t *testing.T) {
|
|||
path string
|
||||
}{
|
||||
{schemaRef: "file://./stub/identity.schema.json", keys: []string{"bar", "email"}},
|
||||
{schemaRef: "file://./stub/complex.schema.json", keys: []string{"meal.name", "meal.chef", "traits.email",
|
||||
{schemaRef: "file://./stub/complex.schema.json", keys: []string{
|
||||
"meal.name", "meal.chef", "traits.email",
|
||||
"traits.stringy", "traits.numby", "traits.booly", "traits.should_big_number", "traits.should_long_string",
|
||||
"fruits", "vegetables"}},
|
||||
"fruits", "vegetables",
|
||||
}},
|
||||
} {
|
||||
t.Run(fmt.Sprintf("case=%d schemaRef=%s", i, tc.schemaRef), func(t *testing.T) {
|
||||
actual, err := GetKeysInOrder(ctx, tc.schemaRef)
|
||||
actual, err := schema.GetKeysInOrder(context.Background(), tc.schemaRef)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.keys, actual)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -0,0 +1,103 @@
|
|||
// Copyright © 2025 Ory Corp
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/schema"
|
||||
"github.com/ory/x/contextx"
|
||||
"github.com/ory/x/urlx"
|
||||
)
|
||||
|
||||
func TestIdentitySchemaProvider(t *testing.T, provider schema.IdentitySchemaProvider) {
|
||||
urlFromID := func(id string) string {
|
||||
return fmt.Sprintf("http://%s.com", id)
|
||||
}
|
||||
|
||||
schemas := schema.Schemas{
|
||||
{ID: "foo"},
|
||||
{ID: "preset://email"},
|
||||
{ID: config.DefaultIdentityTraitsSchemaID},
|
||||
}
|
||||
|
||||
for i := range schemas {
|
||||
raw := urlFromID(schemas[i].ID)
|
||||
schemas[i].RawURL = raw
|
||||
schemas[i].URL = urlx.ParseOrPanic(raw)
|
||||
}
|
||||
|
||||
ctx := contextx.WithConfigValues(t.Context(), map[string]any{
|
||||
config.ViperKeyIdentitySchemas: func() (cs config.Schemas) {
|
||||
for _, s := range schemas {
|
||||
cs = append(cs, config.Schema{
|
||||
ID: s.ID,
|
||||
URL: s.RawURL,
|
||||
})
|
||||
}
|
||||
return cs
|
||||
}(),
|
||||
})
|
||||
|
||||
list, err := provider.IdentityTraitsSchemas(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("GetByID", func(t *testing.T) {
|
||||
t.Run("case=get with raw schemaID", func(t *testing.T) {
|
||||
for _, schema := range schemas {
|
||||
actual, err := list.GetByID(schema.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, schema, *actual)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("case=get with encoded schemaID", func(t *testing.T) {
|
||||
for _, schema := range schemas {
|
||||
encodedID := base64.RawURLEncoding.EncodeToString([]byte(schema.ID))
|
||||
|
||||
actual, err := list.GetByID(encodedID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, schema, *actual)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("case=get default schema", func(t *testing.T) {
|
||||
s1, err := list.GetByID("")
|
||||
require.NoError(t, err)
|
||||
s2, err := list.GetByID(config.DefaultIdentityTraitsSchemaID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &schemas[2], s1)
|
||||
assert.Equal(t, &schemas[2], s2)
|
||||
})
|
||||
|
||||
t.Run("case=should return error on not existing id", func(t *testing.T) {
|
||||
s, err := list.GetByID("not existing id")
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, (*schema.Schema)(nil), s)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("List", func(t *testing.T) {
|
||||
t.Run("case=get all schemas", func(t *testing.T) {
|
||||
p0 := list.List(0, 4)
|
||||
assert.Equal(t, schemas, p0)
|
||||
})
|
||||
|
||||
t.Run("case=smaller pages", func(t *testing.T) {
|
||||
p0, p1 := list.List(0, 2), list.List(1, 2)
|
||||
assert.Equal(t, schemas, append(p0, p1...))
|
||||
})
|
||||
|
||||
t.Run("case=indexes out of range", func(t *testing.T) {
|
||||
p0 := list.List(-1, 10)
|
||||
assert.Equal(t, schemas, p0)
|
||||
})
|
||||
})
|
||||
}
|
||||
Loading…
Reference in New Issue