fix: base64encoded schemaURL cannot be resolved

GitOrigin-RevId: 7b224a32afd2ff5ebf62ef96d797258e8de56afd
This commit is contained in:
Davud Safarov 2025-12-16 13:03:05 +01:00 committed by ory-bot
parent df866b315f
commit a86c212198
4 changed files with 146 additions and 111 deletions

View File

@ -115,16 +115,8 @@ func (h *Handler) getIdentitySchema(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
s, err := ss.GetByID(id)
if err != nil {
// Maybe it is a base64 encoded ID?
if dec, err := base64.RawURLEncoding.DecodeString(id); err == nil {
id = string(dec)
}
s, err = ss.GetByID(id)
if err != nil {
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrNotFound.WithReasonf("Identity schema `%s` could not be found.", id)))
return
}
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrNotFound.WithReasonf("Identity schema `%s` could not be found.", id)))
return
}
raw, err := h.ReadSchema(ctx, s.URL)

View File

@ -74,12 +74,15 @@ type IdentitySchemaList interface {
func (s Schemas) GetByID(id string) (*Schema, error) {
id = cmp.Or(id, config.DefaultIdentityTraitsSchemaID)
for _, ss := range s {
if ss.ID == id {
return &ss, nil
}
if ss, ok := s.findSchemaByID(id); ok {
return ss, nil
}
if decodedID, ok := TryDecodeID(id); ok {
if ss, ok := s.findSchemaByID(decodedID); ok {
return ss, nil
}
}
return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Unable to find JSON Schema ID: %s", id))
}
@ -98,8 +101,19 @@ func (s Schemas) List(page, perPage int) Schemas {
return s[start:end]
}
var orderedKeyCacheMutex sync.RWMutex
var orderedKeyCache map[string][]string
func (s Schemas) findSchemaByID(id string) (*Schema, bool) {
for _, ss := range s {
if ss.ID == id {
return &ss, true
}
}
return nil, false
}
var (
orderedKeyCacheMutex sync.RWMutex
orderedKeyCache map[string][]string
)
func init() {
orderedKeyCache = make(map[string][]string)
@ -154,3 +168,11 @@ func (s *Schema) SchemaURL(host *url.URL) *url.URL {
func IDToURL(host *url.URL, id string) *url.URL {
return urlx.AppendPaths(host, SchemasPath, base64.RawURLEncoding.EncodeToString([]byte(id)))
}
func TryDecodeID(encoded string) (string, bool) {
decoded, err := base64.RawURLEncoding.DecodeString(encoded)
if err != nil {
return "", false
}
return string(decoded), true
}

View File

@ -1,108 +1,24 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package schema
package schema_test
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/kratos/driver/config"
"github.com/ory/x/urlx"
"github.com/ory/kratos/internal"
"github.com/ory/kratos/schema"
schematest "github.com/ory/kratos/schema/test"
)
func TestSchemas_GetByID(t *testing.T) {
urlFromID := func(id string) string {
return fmt.Sprintf("http://%s.com", id)
}
ss := Schemas{
Schema{
ID: "foo",
},
Schema{
ID: "bar",
},
Schema{
ID: "foobar",
},
Schema{
ID: config.DefaultIdentityTraitsSchemaID,
},
}
for _, s := range ss {
s.RawURL = urlFromID(s.ID)
s.URL = urlx.ParseOrPanic(s.RawURL)
}
t.Run("case=get first schema", func(t *testing.T) {
s, err := ss.GetByID("foo")
require.NoError(t, err)
assert.Equal(t, &ss[0], s)
})
t.Run("case=get second schema", func(t *testing.T) {
s, err := ss.GetByID("bar")
require.NoError(t, err)
assert.Equal(t, &ss[1], s)
})
t.Run("case=get third schema", func(t *testing.T) {
s, err := ss.GetByID("foobar")
require.NoError(t, err)
assert.Equal(t, &ss[2], s)
})
t.Run("case=get default schema", func(t *testing.T) {
s1, err := ss.GetByID("")
require.NoError(t, err)
s2, err := ss.GetByID(config.DefaultIdentityTraitsSchemaID)
require.NoError(t, err)
assert.Equal(t, &ss[3], s1)
assert.Equal(t, &ss[3], s2)
})
t.Run("case=should return error on not existing id", func(t *testing.T) {
s, err := ss.GetByID("not existing id")
require.Error(t, err)
assert.Equal(t, (*Schema)(nil), s)
})
}
func TestSchemas_List(t *testing.T) {
ss := Schemas{
Schema{
ID: "foo",
},
Schema{
ID: "bar",
},
Schema{
ID: "foobar",
},
Schema{
ID: config.DefaultIdentityTraitsSchemaID,
},
}
t.Run("case=get all schemas", func(t *testing.T) {
p0 := ss.List(0, 4)
assert.Equal(t, ss, p0)
})
t.Run("case=smaller pages", func(t *testing.T) {
p0, p1 := ss.List(0, 2), ss.List(1, 2)
assert.Equal(t, ss, append(p0, p1...))
})
t.Run("case=indexes out of range", func(t *testing.T) {
p0 := ss.List(-1, 10)
assert.Equal(t, ss, p0)
})
func TestDefaultIdentityTraitsProvider(t *testing.T) {
_, reg := internal.NewFastRegistryWithMocks(t)
schematest.TestIdentitySchemaProvider(t, schema.NewDefaultIdentityTraitsProvider(reg))
}
func TestGetKeysInOrder(t *testing.T) {
@ -112,12 +28,14 @@ func TestGetKeysInOrder(t *testing.T) {
path string
}{
{schemaRef: "file://./stub/identity.schema.json", keys: []string{"bar", "email"}},
{schemaRef: "file://./stub/complex.schema.json", keys: []string{"meal.name", "meal.chef", "traits.email",
{schemaRef: "file://./stub/complex.schema.json", keys: []string{
"meal.name", "meal.chef", "traits.email",
"traits.stringy", "traits.numby", "traits.booly", "traits.should_big_number", "traits.should_long_string",
"fruits", "vegetables"}},
"fruits", "vegetables",
}},
} {
t.Run(fmt.Sprintf("case=%d schemaRef=%s", i, tc.schemaRef), func(t *testing.T) {
actual, err := GetKeysInOrder(ctx, tc.schemaRef)
actual, err := schema.GetKeysInOrder(context.Background(), tc.schemaRef)
require.NoError(t, err)
assert.Equal(t, tc.keys, actual)
})

View File

@ -0,0 +1,103 @@
// Copyright © 2025 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package test
import (
"encoding/base64"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/schema"
"github.com/ory/x/contextx"
"github.com/ory/x/urlx"
)
func TestIdentitySchemaProvider(t *testing.T, provider schema.IdentitySchemaProvider) {
urlFromID := func(id string) string {
return fmt.Sprintf("http://%s.com", id)
}
schemas := schema.Schemas{
{ID: "foo"},
{ID: "preset://email"},
{ID: config.DefaultIdentityTraitsSchemaID},
}
for i := range schemas {
raw := urlFromID(schemas[i].ID)
schemas[i].RawURL = raw
schemas[i].URL = urlx.ParseOrPanic(raw)
}
ctx := contextx.WithConfigValues(t.Context(), map[string]any{
config.ViperKeyIdentitySchemas: func() (cs config.Schemas) {
for _, s := range schemas {
cs = append(cs, config.Schema{
ID: s.ID,
URL: s.RawURL,
})
}
return cs
}(),
})
list, err := provider.IdentityTraitsSchemas(ctx)
require.NoError(t, err)
t.Run("GetByID", func(t *testing.T) {
t.Run("case=get with raw schemaID", func(t *testing.T) {
for _, schema := range schemas {
actual, err := list.GetByID(schema.ID)
require.NoError(t, err)
require.Equal(t, schema, *actual)
}
})
t.Run("case=get with encoded schemaID", func(t *testing.T) {
for _, schema := range schemas {
encodedID := base64.RawURLEncoding.EncodeToString([]byte(schema.ID))
actual, err := list.GetByID(encodedID)
require.NoError(t, err)
require.Equal(t, schema, *actual)
}
})
t.Run("case=get default schema", func(t *testing.T) {
s1, err := list.GetByID("")
require.NoError(t, err)
s2, err := list.GetByID(config.DefaultIdentityTraitsSchemaID)
require.NoError(t, err)
assert.Equal(t, &schemas[2], s1)
assert.Equal(t, &schemas[2], s2)
})
t.Run("case=should return error on not existing id", func(t *testing.T) {
s, err := list.GetByID("not existing id")
require.Error(t, err)
assert.Equal(t, (*schema.Schema)(nil), s)
})
})
t.Run("List", func(t *testing.T) {
t.Run("case=get all schemas", func(t *testing.T) {
p0 := list.List(0, 4)
assert.Equal(t, schemas, p0)
})
t.Run("case=smaller pages", func(t *testing.T) {
p0, p1 := list.List(0, 2), list.List(1, 2)
assert.Equal(t, schemas, append(p0, p1...))
})
t.Run("case=indexes out of range", func(t *testing.T) {
p0 := list.List(-1, 10)
assert.Equal(t, schemas, p0)
})
})
}