mirror of https://github.com/ory/kratos
chore(kratos): cleanup and improve some tests
GitOrigin-RevId: fe540ee2c43da7ba11c9ed069dc6176f25de4461
This commit is contained in:
parent
9338f3b454
commit
3215ab25b3
|
|
@ -4,6 +4,5 @@
|
|||
"github.com/ory/x","Apache-2.0"
|
||||
"github.com/stretchr/testify","MIT"
|
||||
"go.opentelemetry.io/otel/sdk","Apache-2.0"
|
||||
"go.opentelemetry.io/otel/sdk","BSD-3-Clause"
|
||||
"golang.org/x/text","BSD-3-Clause"
|
||||
|
||||
|
|
|
|||
|
|
|
@ -13,6 +13,7 @@ package client
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// checks if the IdentitySchemaContainer type satisfies the MappedNullable interface at compile time
|
||||
|
|
@ -21,9 +22,9 @@ var _ MappedNullable = &IdentitySchemaContainer{}
|
|||
// IdentitySchemaContainer An Identity JSON Schema Container
|
||||
type IdentitySchemaContainer struct {
|
||||
// The ID of the Identity JSON Schema
|
||||
Id *string `json:"id,omitempty"`
|
||||
Id string `json:"id"`
|
||||
// The actual Identity JSON Schema
|
||||
Schema map[string]interface{} `json:"schema,omitempty"`
|
||||
Schema map[string]interface{} `json:"schema"`
|
||||
AdditionalProperties map[string]interface{}
|
||||
}
|
||||
|
||||
|
|
@ -33,8 +34,10 @@ type _IdentitySchemaContainer IdentitySchemaContainer
|
|||
// This constructor will assign default values to properties that have it defined,
|
||||
// and makes sure properties required by API are set, but the set of arguments
|
||||
// will change when the set of required properties is changed
|
||||
func NewIdentitySchemaContainer() *IdentitySchemaContainer {
|
||||
func NewIdentitySchemaContainer(id string, schema map[string]interface{}) *IdentitySchemaContainer {
|
||||
this := IdentitySchemaContainer{}
|
||||
this.Id = id
|
||||
this.Schema = schema
|
||||
return &this
|
||||
}
|
||||
|
||||
|
|
@ -46,66 +49,50 @@ func NewIdentitySchemaContainerWithDefaults() *IdentitySchemaContainer {
|
|||
return &this
|
||||
}
|
||||
|
||||
// GetId returns the Id field value if set, zero value otherwise.
|
||||
// GetId returns the Id field value
|
||||
func (o *IdentitySchemaContainer) GetId() string {
|
||||
if o == nil || IsNil(o.Id) {
|
||||
if o == nil {
|
||||
var ret string
|
||||
return ret
|
||||
}
|
||||
return *o.Id
|
||||
|
||||
return o.Id
|
||||
}
|
||||
|
||||
// GetIdOk returns a tuple with the Id field value if set, nil otherwise
|
||||
// GetIdOk returns a tuple with the Id field value
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *IdentitySchemaContainer) GetIdOk() (*string, bool) {
|
||||
if o == nil || IsNil(o.Id) {
|
||||
if o == nil {
|
||||
return nil, false
|
||||
}
|
||||
return o.Id, true
|
||||
return &o.Id, true
|
||||
}
|
||||
|
||||
// HasId returns a boolean if a field has been set.
|
||||
func (o *IdentitySchemaContainer) HasId() bool {
|
||||
if o != nil && !IsNil(o.Id) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetId gets a reference to the given string and assigns it to the Id field.
|
||||
// SetId sets field value
|
||||
func (o *IdentitySchemaContainer) SetId(v string) {
|
||||
o.Id = &v
|
||||
o.Id = v
|
||||
}
|
||||
|
||||
// GetSchema returns the Schema field value if set, zero value otherwise.
|
||||
// GetSchema returns the Schema field value
|
||||
func (o *IdentitySchemaContainer) GetSchema() map[string]interface{} {
|
||||
if o == nil || IsNil(o.Schema) {
|
||||
if o == nil {
|
||||
var ret map[string]interface{}
|
||||
return ret
|
||||
}
|
||||
|
||||
return o.Schema
|
||||
}
|
||||
|
||||
// GetSchemaOk returns a tuple with the Schema field value if set, nil otherwise
|
||||
// GetSchemaOk returns a tuple with the Schema field value
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *IdentitySchemaContainer) GetSchemaOk() (map[string]interface{}, bool) {
|
||||
if o == nil || IsNil(o.Schema) {
|
||||
if o == nil {
|
||||
return map[string]interface{}{}, false
|
||||
}
|
||||
return o.Schema, true
|
||||
}
|
||||
|
||||
// HasSchema returns a boolean if a field has been set.
|
||||
func (o *IdentitySchemaContainer) HasSchema() bool {
|
||||
if o != nil && !IsNil(o.Schema) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetSchema gets a reference to the given map[string]interface{} and assigns it to the Schema field.
|
||||
// SetSchema sets field value
|
||||
func (o *IdentitySchemaContainer) SetSchema(v map[string]interface{}) {
|
||||
o.Schema = v
|
||||
}
|
||||
|
|
@ -120,12 +107,8 @@ func (o IdentitySchemaContainer) MarshalJSON() ([]byte, error) {
|
|||
|
||||
func (o IdentitySchemaContainer) ToMap() (map[string]interface{}, error) {
|
||||
toSerialize := map[string]interface{}{}
|
||||
if !IsNil(o.Id) {
|
||||
toSerialize["id"] = o.Id
|
||||
}
|
||||
if !IsNil(o.Schema) {
|
||||
toSerialize["schema"] = o.Schema
|
||||
}
|
||||
toSerialize["id"] = o.Id
|
||||
toSerialize["schema"] = o.Schema
|
||||
|
||||
for key, value := range o.AdditionalProperties {
|
||||
toSerialize[key] = value
|
||||
|
|
@ -135,6 +118,28 @@ func (o IdentitySchemaContainer) ToMap() (map[string]interface{}, error) {
|
|||
}
|
||||
|
||||
func (o *IdentitySchemaContainer) UnmarshalJSON(data []byte) (err error) {
|
||||
// This validates that all required properties are included in the JSON object
|
||||
// by unmarshalling the object into a generic map with string keys and checking
|
||||
// that every required field exists as a key in the generic map.
|
||||
requiredProperties := []string{
|
||||
"id",
|
||||
"schema",
|
||||
}
|
||||
|
||||
allProperties := make(map[string]interface{})
|
||||
|
||||
err = json.Unmarshal(data, &allProperties)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, requiredProperty := range requiredProperties {
|
||||
if _, exists := allProperties[requiredProperty]; !exists {
|
||||
return fmt.Errorf("no value given for required property %v", requiredProperty)
|
||||
}
|
||||
}
|
||||
|
||||
varIdentitySchemaContainer := _IdentitySchemaContainer{}
|
||||
|
||||
err = json.Unmarshal(data, &varIdentitySchemaContainer)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
|
@ -23,7 +24,6 @@ import (
|
|||
"github.com/ory/x/jsonnetsecure"
|
||||
"github.com/ory/x/logrusx"
|
||||
"github.com/ory/x/randx"
|
||||
"github.com/ory/x/stringsx"
|
||||
)
|
||||
|
||||
func NewConfigurationWithDefaults(t testing.TB, opts ...configx.OptionModifier) *config.Config {
|
||||
|
|
@ -72,7 +72,7 @@ func NewFastRegistryWithMocks(t *testing.T, opts ...configx.OptionModifier) (*co
|
|||
func NewRegistryDefaultWithDSN(t testing.TB, dsn string, opts ...configx.OptionModifier) (*config.Config, *driver.RegistryDefault) {
|
||||
ctx := context.Background()
|
||||
c := NewConfigurationWithDefaults(t, append([]configx.OptionModifier{configx.WithValues(map[string]interface{}{
|
||||
config.ViperKeyDSN: stringsx.Coalesce(dsn, dbal.NewSQLiteTestDatabase(t)+"&lock=false&max_conns=1"),
|
||||
config.ViperKeyDSN: cmp.Or(dsn, dbal.NewSQLiteTestDatabase(t)+"&lock=false&max_conns=1"),
|
||||
"dev": true,
|
||||
config.ViperKeySecretsCipher: []string{randx.MustString(32, randx.AlphaNum)},
|
||||
config.ViperKeySecretsCookie: []string{randx.MustString(32, randx.AlphaNum)},
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ package client
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// checks if the IdentitySchemaContainer type satisfies the MappedNullable interface at compile time
|
||||
|
|
@ -21,9 +22,9 @@ var _ MappedNullable = &IdentitySchemaContainer{}
|
|||
// IdentitySchemaContainer An Identity JSON Schema Container
|
||||
type IdentitySchemaContainer struct {
|
||||
// The ID of the Identity JSON Schema
|
||||
Id *string `json:"id,omitempty"`
|
||||
Id string `json:"id"`
|
||||
// The actual Identity JSON Schema
|
||||
Schema map[string]interface{} `json:"schema,omitempty"`
|
||||
Schema map[string]interface{} `json:"schema"`
|
||||
AdditionalProperties map[string]interface{}
|
||||
}
|
||||
|
||||
|
|
@ -33,8 +34,10 @@ type _IdentitySchemaContainer IdentitySchemaContainer
|
|||
// This constructor will assign default values to properties that have it defined,
|
||||
// and makes sure properties required by API are set, but the set of arguments
|
||||
// will change when the set of required properties is changed
|
||||
func NewIdentitySchemaContainer() *IdentitySchemaContainer {
|
||||
func NewIdentitySchemaContainer(id string, schema map[string]interface{}) *IdentitySchemaContainer {
|
||||
this := IdentitySchemaContainer{}
|
||||
this.Id = id
|
||||
this.Schema = schema
|
||||
return &this
|
||||
}
|
||||
|
||||
|
|
@ -46,66 +49,50 @@ func NewIdentitySchemaContainerWithDefaults() *IdentitySchemaContainer {
|
|||
return &this
|
||||
}
|
||||
|
||||
// GetId returns the Id field value if set, zero value otherwise.
|
||||
// GetId returns the Id field value
|
||||
func (o *IdentitySchemaContainer) GetId() string {
|
||||
if o == nil || IsNil(o.Id) {
|
||||
if o == nil {
|
||||
var ret string
|
||||
return ret
|
||||
}
|
||||
return *o.Id
|
||||
|
||||
return o.Id
|
||||
}
|
||||
|
||||
// GetIdOk returns a tuple with the Id field value if set, nil otherwise
|
||||
// GetIdOk returns a tuple with the Id field value
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *IdentitySchemaContainer) GetIdOk() (*string, bool) {
|
||||
if o == nil || IsNil(o.Id) {
|
||||
if o == nil {
|
||||
return nil, false
|
||||
}
|
||||
return o.Id, true
|
||||
return &o.Id, true
|
||||
}
|
||||
|
||||
// HasId returns a boolean if a field has been set.
|
||||
func (o *IdentitySchemaContainer) HasId() bool {
|
||||
if o != nil && !IsNil(o.Id) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetId gets a reference to the given string and assigns it to the Id field.
|
||||
// SetId sets field value
|
||||
func (o *IdentitySchemaContainer) SetId(v string) {
|
||||
o.Id = &v
|
||||
o.Id = v
|
||||
}
|
||||
|
||||
// GetSchema returns the Schema field value if set, zero value otherwise.
|
||||
// GetSchema returns the Schema field value
|
||||
func (o *IdentitySchemaContainer) GetSchema() map[string]interface{} {
|
||||
if o == nil || IsNil(o.Schema) {
|
||||
if o == nil {
|
||||
var ret map[string]interface{}
|
||||
return ret
|
||||
}
|
||||
|
||||
return o.Schema
|
||||
}
|
||||
|
||||
// GetSchemaOk returns a tuple with the Schema field value if set, nil otherwise
|
||||
// GetSchemaOk returns a tuple with the Schema field value
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *IdentitySchemaContainer) GetSchemaOk() (map[string]interface{}, bool) {
|
||||
if o == nil || IsNil(o.Schema) {
|
||||
if o == nil {
|
||||
return map[string]interface{}{}, false
|
||||
}
|
||||
return o.Schema, true
|
||||
}
|
||||
|
||||
// HasSchema returns a boolean if a field has been set.
|
||||
func (o *IdentitySchemaContainer) HasSchema() bool {
|
||||
if o != nil && !IsNil(o.Schema) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetSchema gets a reference to the given map[string]interface{} and assigns it to the Schema field.
|
||||
// SetSchema sets field value
|
||||
func (o *IdentitySchemaContainer) SetSchema(v map[string]interface{}) {
|
||||
o.Schema = v
|
||||
}
|
||||
|
|
@ -120,12 +107,8 @@ func (o IdentitySchemaContainer) MarshalJSON() ([]byte, error) {
|
|||
|
||||
func (o IdentitySchemaContainer) ToMap() (map[string]interface{}, error) {
|
||||
toSerialize := map[string]interface{}{}
|
||||
if !IsNil(o.Id) {
|
||||
toSerialize["id"] = o.Id
|
||||
}
|
||||
if !IsNil(o.Schema) {
|
||||
toSerialize["schema"] = o.Schema
|
||||
}
|
||||
toSerialize["id"] = o.Id
|
||||
toSerialize["schema"] = o.Schema
|
||||
|
||||
for key, value := range o.AdditionalProperties {
|
||||
toSerialize[key] = value
|
||||
|
|
@ -135,6 +118,28 @@ func (o IdentitySchemaContainer) ToMap() (map[string]interface{}, error) {
|
|||
}
|
||||
|
||||
func (o *IdentitySchemaContainer) UnmarshalJSON(data []byte) (err error) {
|
||||
// This validates that all required properties are included in the JSON object
|
||||
// by unmarshalling the object into a generic map with string keys and checking
|
||||
// that every required field exists as a key in the generic map.
|
||||
requiredProperties := []string{
|
||||
"id",
|
||||
"schema",
|
||||
}
|
||||
|
||||
allProperties := make(map[string]interface{})
|
||||
|
||||
err = json.Unmarshal(data, &allProperties)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, requiredProperty := range requiredProperties {
|
||||
if _, exists := allProperties[requiredProperty]; !exists {
|
||||
return fmt.Errorf("no value given for required property %v", requiredProperty)
|
||||
}
|
||||
}
|
||||
|
||||
varIdentitySchemaContainer := _IdentitySchemaContainer{}
|
||||
|
||||
err = json.Unmarshal(data, &varIdentitySchemaContainer)
|
||||
|
|
|
|||
|
|
@ -16,8 +16,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ory/kratos/x/nosurfx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
|
@ -30,6 +28,7 @@ import (
|
|||
"github.com/ory/kratos/selfservice/flow"
|
||||
"github.com/ory/kratos/selfservice/flow/registration"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/kratos/x/nosurfx"
|
||||
"github.com/ory/x/assertx"
|
||||
"github.com/ory/x/httpx"
|
||||
"github.com/ory/x/ioutilx"
|
||||
|
|
@ -37,15 +36,12 @@ import (
|
|||
)
|
||||
|
||||
func setupServer(t *testing.T, reg *driver.RegistryDefault) *httptest.Server {
|
||||
conf := reg.Config()
|
||||
router := x.NewRouterPublic(reg)
|
||||
admin := x.NewRouterAdmin(reg)
|
||||
|
||||
publicTS, _ := testhelpers.NewKratosServerWithRouters(t, reg, router, admin)
|
||||
publicTS, _ := testhelpers.NewKratosServer(t, reg)
|
||||
redirTS := testhelpers.NewRedirSessionEchoTS(t, reg)
|
||||
ctx := context.Background()
|
||||
conf.MustSet(ctx, config.ViperKeySelfServiceBrowserDefaultReturnTo, redirTS.URL+"/default-return-to")
|
||||
conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationAfter+"."+config.DefaultBrowserReturnURL, redirTS.URL+"/registration-return-ts")
|
||||
|
||||
conf := reg.Config()
|
||||
conf.MustSet(t.Context(), config.ViperKeySelfServiceBrowserDefaultReturnTo, redirTS.URL+"/default-return-to") //nolint:staticcheck
|
||||
conf.MustSet(t.Context(), config.ViperKeySelfServiceRegistrationAfter+"."+config.DefaultBrowserReturnURL, redirTS.URL+"/registration-return-ts") //nolint:staticcheck
|
||||
return publicTS
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ func startE2EServerOnly(t *testing.T, configFile string, isTLS bool, configOptio
|
|||
|
||||
t.Log("Starting server...")
|
||||
stdOut, stdErr := &bytes.Buffer{}, &bytes.Buffer{}
|
||||
eg := executor.ExecBackground(nil, stdErr, stdOut, "serve", "--config", configFile, "--watch-courier")
|
||||
eg := executor.ExecBackground(nil, io.MultiWriter(os.Stdout, stdOut), io.MultiWriter(os.Stdout, stdErr), "serve", "--config", configFile, "--watch-courier")
|
||||
|
||||
err = waitTimeout(t, eg, time.Second)
|
||||
if err != nil && tries < 5 {
|
||||
|
|
|
|||
|
|
@ -13,12 +13,10 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gobuffalo/httptest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"github.com/urfave/negroni"
|
||||
|
||||
"github.com/ory/kratos/driver"
|
||||
"github.com/ory/kratos/driver/config"
|
||||
|
|
@ -26,7 +24,6 @@ import (
|
|||
kratos "github.com/ory/kratos/internal/httpclient"
|
||||
"github.com/ory/kratos/selfservice/flow/settings"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/kratos/x/nosurfx"
|
||||
"github.com/ory/x/ioutilx"
|
||||
"github.com/ory/x/urlx"
|
||||
)
|
||||
|
|
@ -174,32 +171,6 @@ func NewSettingsLoginAcceptAPIServer(t *testing.T, publicClient *kratos.APIClien
|
|||
return loginTS
|
||||
}
|
||||
|
||||
func NewSettingsAPIServer(t *testing.T, reg *driver.RegistryDefault, ids map[string]*identity.Identity) (*httptest.Server, *httptest.Server, map[string]*http.Client) {
|
||||
ctx := context.Background()
|
||||
public, admin := x.NewRouterPublic(reg), x.NewRouterAdmin(reg)
|
||||
reg.SettingsHandler().RegisterAdminRoutes(admin)
|
||||
|
||||
n := negroni.Classic()
|
||||
n.UseHandler(public)
|
||||
hh := nosurfx.NewTestCSRFHandler(n, reg)
|
||||
reg.WithCSRFHandler(hh)
|
||||
|
||||
reg.SettingsHandler().RegisterPublicRoutes(public)
|
||||
reg.SettingsStrategies(context.Background()).RegisterPublicRoutes(public)
|
||||
reg.LoginHandler().RegisterPublicRoutes(public)
|
||||
reg.LoginHandler().RegisterAdminRoutes(admin)
|
||||
reg.LoginStrategies(context.Background()).RegisterPublicRoutes(public)
|
||||
|
||||
tsp, tsa := httptest.NewServer(hh), httptest.NewServer(admin)
|
||||
t.Cleanup(tsp.Close)
|
||||
t.Cleanup(tsa.Close)
|
||||
|
||||
reg.Config().MustSet(ctx, config.ViperKeyPublicBaseURL, tsp.URL)
|
||||
reg.Config().MustSet(ctx, config.ViperKeyAdminBaseURL, tsa.URL)
|
||||
//#nosec G112
|
||||
return tsp, tsa, AddAndLoginIdentities(t, reg, &httptest.Server{Config: &http.Server{Handler: public}, URL: tsp.URL}, ids)
|
||||
}
|
||||
|
||||
// AddAndLoginIdentities adds the given identities to the store (like a registration flow) and returns http.Clients
|
||||
// which contain their sessions.
|
||||
func AddAndLoginIdentities(t *testing.T, reg *driver.RegistryDefault, public *httptest.Server, ids map[string]*identity.Identity) map[string]*http.Client {
|
||||
|
|
|
|||
|
|
@ -8,15 +8,13 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ory/kratos/x/nosurfx"
|
||||
|
||||
"github.com/urfave/negroni"
|
||||
|
||||
"github.com/gobuffalo/httptest"
|
||||
"github.com/urfave/negroni"
|
||||
|
||||
"github.com/ory/kratos/driver"
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/kratos/x/nosurfx"
|
||||
)
|
||||
|
||||
func NewKratosServer(t *testing.T, reg driver.Registry) (public, admin *httptest.Server) {
|
||||
|
|
@ -43,12 +41,12 @@ func NewKratosServerWithCSRFAndRouters(t *testing.T, reg driver.Registry) (publi
|
|||
|
||||
public = httptest.NewServer(nosurfx.NewTestCSRFHandler(rpn, reg))
|
||||
admin = httptest.NewServer(ran)
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
|
||||
// Workaround for:
|
||||
// - https://github.com/golang/go/issues/12610
|
||||
// - https://github.com/golang/go/issues/31054
|
||||
public.URL = strings.Replace(public.URL, "127.0.0.1", "localhost", -1)
|
||||
public.URL = strings.ReplaceAll(public.URL, "127.0.0.1", "localhost")
|
||||
|
||||
if len(reg.Config().GetProvider(ctx).String(config.ViperKeySelfServiceLoginUI)) == 0 {
|
||||
reg.Config().MustSet(ctx, config.ViperKeySelfServiceLoginUI, "http://NewKratosServerWithCSRF/you-forgot-to-set-me/login")
|
||||
|
|
@ -67,14 +65,14 @@ func NewKratosServerWithRouters(t *testing.T, reg driver.Registry, rp *x.RouterP
|
|||
public = httptest.NewServer(rp)
|
||||
admin = httptest.NewServer(ra)
|
||||
|
||||
InitKratosServers(t, reg, public, admin)
|
||||
InitKratosServers(t, reg, public, admin, rp, ra)
|
||||
|
||||
t.Cleanup(public.Close)
|
||||
t.Cleanup(admin.Close)
|
||||
return
|
||||
}
|
||||
|
||||
func InitKratosServers(t *testing.T, reg driver.Registry, public, admin *httptest.Server) {
|
||||
func InitKratosServers(t *testing.T, reg driver.Registry, public, admin *httptest.Server, rp *x.RouterPublic, ra *x.RouterAdmin) {
|
||||
ctx := t.Context()
|
||||
if len(reg.Config().GetProvider(ctx).String(config.ViperKeySelfServiceLoginUI)) == 0 {
|
||||
reg.Config().MustSet(ctx, config.ViperKeySelfServiceLoginUI, "http://NewKratosServerWithRouters/you-forgot-to-set-me/login")
|
||||
|
|
@ -82,15 +80,6 @@ func InitKratosServers(t *testing.T, reg driver.Registry, public, admin *httptes
|
|||
reg.Config().MustSet(ctx, config.ViperKeyPublicBaseURL, public.URL)
|
||||
reg.Config().MustSet(ctx, config.ViperKeyAdminBaseURL, admin.URL)
|
||||
|
||||
reg.RegisterRoutes(context.Background(), public.Config.Handler.(*x.RouterPublic), admin.Config.Handler.(*x.RouterAdmin))
|
||||
}
|
||||
|
||||
func NewKratosServers(t *testing.T, reg driver.Registry) (public, admin *httptest.Server) {
|
||||
public = httptest.NewServer(x.NewRouterPublic(reg))
|
||||
admin = httptest.NewServer(x.NewRouterAdmin(reg))
|
||||
|
||||
public.URL = strings.Replace(public.URL, "127.0.0.1", "localhost", -1)
|
||||
t.Cleanup(public.Close)
|
||||
t.Cleanup(admin.Close)
|
||||
return
|
||||
reg.RegisterPublicRoutes(ctx, rp)
|
||||
reg.RegisterAdminRoutes(ctx, ra)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,10 +31,6 @@ func (p *SessionLifespanProvider) SessionLifespan(context.Context) time.Duration
|
|||
return p.e
|
||||
}
|
||||
|
||||
func NewSessionLifespanProvider(expiresIn time.Duration) *SessionLifespanProvider {
|
||||
return &SessionLifespanProvider{e: expiresIn}
|
||||
}
|
||||
|
||||
func NewSessionClient(t *testing.T, u string) *http.Client {
|
||||
c := NewClientWithCookies(t)
|
||||
MockHydrateCookieClient(t, c, u)
|
||||
|
|
|
|||
|
|
@ -11,15 +11,14 @@ import (
|
|||
)
|
||||
|
||||
//go:embed 404.html
|
||||
var page404HTML []byte
|
||||
var page404HTML string
|
||||
|
||||
//go:embed 404.json
|
||||
var page404JSON []byte
|
||||
var page404JSON string
|
||||
|
||||
// DefaultNotFoundHandler is a default handler for handling 404 errors.
|
||||
var DefaultNotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var contentType string
|
||||
var body []byte
|
||||
var contentType, body string
|
||||
switch httputil.NegotiateContentType(r, []string{
|
||||
"text/html",
|
||||
"text/plain",
|
||||
|
|
@ -27,18 +26,18 @@ var DefaultNotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *htt
|
|||
}, "text/html") {
|
||||
case "text/plain":
|
||||
contentType = "text/plain"
|
||||
body = []byte(`Error 404 - The requested route does not exist. Make sure you are using the right path, domain, and port.`) // #nosec
|
||||
body = "Error 404 - The requested route does not exist. Make sure you are using the right path, domain, and port."
|
||||
case "application/json":
|
||||
contentType = "application/json"
|
||||
body = page404JSON // #nosec
|
||||
case "text/html":
|
||||
fallthrough
|
||||
body = page404JSON
|
||||
default:
|
||||
fallthrough
|
||||
case "text/html":
|
||||
contentType = "text/html"
|
||||
body = page404HTML
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", contentType+"; charset=utf-8")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
_, _ = w.Write(body) // #nosec
|
||||
_, _ = w.Write([]byte(body))
|
||||
})
|
||||
|
|
|
|||
|
|
@ -10,17 +10,17 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ory/kratos/x/nosurfx"
|
||||
"github.com/ory/kratos/x/redir"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/herodot"
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/kratos/x/nosurfx"
|
||||
"github.com/ory/kratos/x/redir"
|
||||
"github.com/ory/x/otelx"
|
||||
"github.com/ory/x/pagination/migrationpagination"
|
||||
)
|
||||
|
|
@ -47,7 +47,10 @@ func NewHandler(r handlerDependencies) *Handler {
|
|||
return &Handler{r: r}
|
||||
}
|
||||
|
||||
const SchemasPath string = "schemas"
|
||||
const (
|
||||
SchemasPath string = "schemas"
|
||||
maxSchemaSize = 1024 * 1024 // 1 MB
|
||||
)
|
||||
|
||||
func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {
|
||||
h.r.CSRFHandler().IgnoreGlobs(
|
||||
|
|
@ -68,27 +71,12 @@ func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) {
|
|||
// Raw JSON Schema
|
||||
//
|
||||
// swagger:model identitySchema
|
||||
//
|
||||
//nolint:deadcode,unused
|
||||
//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
|
||||
type identitySchema json.RawMessage
|
||||
|
||||
func (m identitySchema) MarshalJSON() ([]byte, error) {
|
||||
return json.RawMessage(m).MarshalJSON()
|
||||
}
|
||||
|
||||
func (m *identitySchema) UnmarshalJSON(data []byte) error {
|
||||
mm := json.RawMessage(*m)
|
||||
return mm.UnmarshalJSON(data)
|
||||
}
|
||||
type _ json.RawMessage
|
||||
|
||||
// Get Identity JSON Schema Response
|
||||
//
|
||||
// swagger:parameters getIdentitySchema
|
||||
//
|
||||
//nolint:deadcode,unused
|
||||
//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
|
||||
type getIdentitySchema struct {
|
||||
type _ struct {
|
||||
// ID must be set to the ID of schema you want to get
|
||||
//
|
||||
// required: true
|
||||
|
|
@ -136,18 +124,14 @@ func (h *Handler) getIdentitySchema(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
src, err := h.ReadSchema(ctx, s)
|
||||
raw, err := h.ReadSchema(ctx, s.URL)
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The file for this JSON Schema ID could not be found or opened. This is a configuration issue.").WithDebugf("%+v", err)))
|
||||
return
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
if _, err := io.Copy(w, src); err != nil {
|
||||
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The file for this JSON Schema ID could not be found or opened. This is a configuration issue.").WithDebugf("%+v", err)))
|
||||
return
|
||||
}
|
||||
h.r.Writer().Write(w, r, json.RawMessage(raw))
|
||||
}
|
||||
|
||||
// List of Identity JSON Schemas
|
||||
|
|
@ -160,28 +144,24 @@ type IdentitySchemas []identitySchemaContainer
|
|||
// swagger:model identitySchemaContainer
|
||||
type identitySchemaContainer struct {
|
||||
// The ID of the Identity JSON Schema
|
||||
// required: true
|
||||
ID string `json:"id"`
|
||||
// The actual Identity JSON Schema
|
||||
// required: true
|
||||
Schema json.RawMessage `json:"schema"`
|
||||
}
|
||||
|
||||
// List Identity JSON Schemas Response
|
||||
//
|
||||
// swagger:parameters listIdentitySchemas
|
||||
//
|
||||
//nolint:deadcode,unused
|
||||
//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
|
||||
type listIdentitySchemas struct {
|
||||
type _ struct {
|
||||
migrationpagination.RequestParameters
|
||||
}
|
||||
|
||||
// List Identity JSON Schemas Response
|
||||
//
|
||||
// swagger:response identitySchemas
|
||||
//
|
||||
//nolint:deadcode,unused
|
||||
//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
|
||||
type identitySchemasResponse struct {
|
||||
type _ struct {
|
||||
migrationpagination.ResponseHeaderAnnotation
|
||||
|
||||
// in: body
|
||||
|
|
@ -216,53 +196,50 @@ func (h *Handler) getAll(w http.ResponseWriter, r *http.Request) {
|
|||
total := allSchemas.Total()
|
||||
schemas := allSchemas.List(page, itemsPerPage)
|
||||
|
||||
var ss IdentitySchemas
|
||||
for k := range schemas {
|
||||
schema := schemas[k]
|
||||
src, err := h.ReadSchema(ctx, &schema)
|
||||
ss := make(IdentitySchemas, len(schemas))
|
||||
for i, schema := range schemas {
|
||||
raw, err := h.ReadSchema(ctx, schema.URL)
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The file for this JSON Schema ID could not be found or opened. This is a configuration issue.").WithDebugf("%+v", err)))
|
||||
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The file for a JSON Schema ID could not be found or opened. This is a configuration issue.").WithDebugf("%+v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
raw, err := io.ReadAll(io.LimitReader(src, 1024*1024))
|
||||
_ = src.Close()
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The file for this JSON Schema ID could not be found or opened. This is a configuration issue.").WithDebugf("%+v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
ss = append(ss, identitySchemaContainer{
|
||||
ss[i] = identitySchemaContainer{
|
||||
ID: schema.ID,
|
||||
Schema: raw,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
x.PaginationHeader(w, *r.URL, int64(total), page, itemsPerPage)
|
||||
h.r.Writer().Write(w, r, ss)
|
||||
}
|
||||
|
||||
func (h *Handler) ReadSchema(ctx context.Context, schema *Schema) (src io.ReadCloser, err error) {
|
||||
func (h *Handler) ReadSchema(ctx context.Context, uri *url.URL) (data []byte, err error) {
|
||||
ctx, span := h.r.Tracer(ctx).Tracer().Start(ctx, "schema.Handler.ReadSchema")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
if schema.URL.Scheme == "file" {
|
||||
src, err = os.Open(schema.URL.Host + schema.URL.Path)
|
||||
switch uri.Scheme {
|
||||
case "file":
|
||||
data, err = os.ReadFile(uri.Host + uri.Path)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(herodot.ErrInternalServerError.WithWrap(err).WithReason("Unable to fetch identity schema."))
|
||||
return nil, errors.WithStack(fmt.Errorf("could not read schema file: %w", err))
|
||||
}
|
||||
} else if schema.URL.Scheme == "base64" {
|
||||
data, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(schema.RawURL, "base64://"))
|
||||
case "base64":
|
||||
data, err = base64.StdEncoding.DecodeString(strings.TrimPrefix(uri.String(), "base64://"))
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(herodot.ErrInternalServerError.WithWrap(err).WithReason("Unable to fetch identity schema."))
|
||||
return nil, errors.WithStack(fmt.Errorf("could not decode schema file: %w", err))
|
||||
}
|
||||
src = io.NopCloser(strings.NewReader(string(data)))
|
||||
} else {
|
||||
resp, err := h.r.HTTPClient(ctx).Get(schema.URL.String())
|
||||
default:
|
||||
resp, err := h.r.HTTPClient(ctx).Get(uri.String())
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(herodot.ErrInternalServerError.WithWrap(err).WithReason("Unable to fetch identity schema."))
|
||||
return nil, errors.WithStack(fmt.Errorf("could not fetch schema: %w", err))
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, errors.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
data, err = io.ReadAll(io.LimitReader(resp.Body, maxSchemaSize))
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(fmt.Errorf("could not read schema response: %w", err))
|
||||
}
|
||||
src = resp.Body
|
||||
}
|
||||
return src, nil
|
||||
return data, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,9 +10,8 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
|
@ -22,253 +21,176 @@ import (
|
|||
_ "github.com/ory/jsonschema/v3/fileloader"
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/internal"
|
||||
"github.com/ory/kratos/schema"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/x/configx"
|
||||
"github.com/ory/x/contextx"
|
||||
"github.com/ory/x/urlx"
|
||||
)
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
conf, reg := internal.NewFastRegistryWithMocks(t)
|
||||
router := x.NewTestRouterPublic(t)
|
||||
ts := contextx.NewConfigurableTestServer(router)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
schemas := map[string]struct {
|
||||
uri string
|
||||
getRaw func() ([]byte, error)
|
||||
}{
|
||||
"default": {
|
||||
uri: "file://./stub/identity.schema.json",
|
||||
getRaw: func() ([]byte, error) { return os.ReadFile("./stub/identity.schema.json") },
|
||||
},
|
||||
"identity2": {
|
||||
uri: "file://./stub/identity-2.schema.json",
|
||||
getRaw: func() ([]byte, error) { return os.ReadFile("./stub/identity-2.schema.json") },
|
||||
},
|
||||
"base64": {
|
||||
uri: "base64://ewogICIkc2NoZW1hIjogImh0dHA6Ly9qc29uLXNjaGVtYS5vcmcvZHJhZnQtMDcvc2NoZW1hIyIsCiAgInR5cGUiOiAib2JqZWN0IiwKICAicHJvcGVydGllcyI6IHsKICAgICJiYXIiOiB7CiAgICAgICJ0eXBlIjogInN0cmluZyIKICAgIH0KICB9LAogICJyZXF1aXJlZCI6IFsKICAgICJiYXIiCiAgXQp9",
|
||||
getRaw: func() ([]byte, error) {
|
||||
return base64.StdEncoding.DecodeString("ewogICIkc2NoZW1hIjogImh0dHA6Ly9qc29uLXNjaGVtYS5vcmcvZHJhZnQtMDcvc2NoZW1hIyIsCiAgInR5cGUiOiAib2JqZWN0IiwKICAicHJvcGVydGllcyI6IHsKICAgICJiYXIiOiB7CiAgICAgICJ0eXBlIjogInN0cmluZyIKICAgIH0KICB9LAogICJyZXF1aXJlZCI6IFsKICAgICJiYXIiCiAgXQp9")
|
||||
},
|
||||
},
|
||||
"unreachable": {
|
||||
uri: "http://127.0.0.1:12345/unreachable-schema",
|
||||
getRaw: func() ([]byte, error) {
|
||||
return nil, fmt.Errorf("dial tcp 127.0.0.1:12345: connect: connection refused")
|
||||
},
|
||||
},
|
||||
"no-file": {
|
||||
uri: "file://./stub/does-not-exist.schema.json",
|
||||
getRaw: func() ([]byte, error) { return nil, fmt.Errorf("no such file or directory") },
|
||||
},
|
||||
"directory": {
|
||||
uri: "file://./stub",
|
||||
getRaw: func() ([]byte, error) { return nil, fmt.Errorf("is a directory") },
|
||||
},
|
||||
"preset://email": {
|
||||
uri: "file://./stub/identity-2.schema.json",
|
||||
getRaw: func() ([]byte, error) { return os.ReadFile("./stub/identity-2.schema.json") },
|
||||
},
|
||||
}
|
||||
configSchemas := make(config.Schemas, 0, len(schemas))
|
||||
for id, s := range schemas {
|
||||
configSchemas = append(configSchemas, config.Schema{
|
||||
ID: id,
|
||||
URL: s.uri,
|
||||
})
|
||||
}
|
||||
|
||||
_, reg := internal.NewFastRegistryWithMocks(t, configx.WithValues(map[string]any{
|
||||
config.ViperKeyPublicBaseURL: ts.URL,
|
||||
config.ViperKeyDefaultIdentitySchemaID: "default",
|
||||
config.ViperKeyIdentitySchemas: configSchemas,
|
||||
}))
|
||||
reg.SchemaHandler().RegisterPublicRoutes(router)
|
||||
ts := httptest.NewServer(router)
|
||||
defer ts.Close()
|
||||
|
||||
schemas := schema.Schemas{
|
||||
{
|
||||
ID: "default",
|
||||
URL: urlx.ParseOrPanic("file://./stub/identity.schema.json"),
|
||||
RawURL: "file://./stub/identity.schema.json",
|
||||
},
|
||||
{
|
||||
ID: "identity2",
|
||||
URL: urlx.ParseOrPanic("file://./stub/identity-2.schema.json"),
|
||||
RawURL: "file://./stub/identity-2.schema.json",
|
||||
},
|
||||
{
|
||||
ID: "base64",
|
||||
URL: urlx.ParseOrPanic("base64://ewogICIkc2NoZW1hIjogImh0dHA6Ly9qc29uLXNjaGVtYS5vcmcvZHJhZnQtMDcvc2NoZW1hIyIsCiAgInR5cGUiOiAib2JqZWN0IiwKICAicHJvcGVydGllcyI6IHsKICAgICJiYXIiOiB7CiAgICAgICJ0eXBlIjogInN0cmluZyIKICAgIH0KICB9LAogICJyZXF1aXJlZCI6IFsKICAgICJiYXIiCiAgXQp9"),
|
||||
RawURL: "base64://ewogICIkc2NoZW1hIjogImh0dHA6Ly9qc29uLXNjaGVtYS5vcmcvZHJhZnQtMDcvc2NoZW1hIyIsCiAgInR5cGUiOiAib2JqZWN0IiwKICAicHJvcGVydGllcyI6IHsKICAgICJiYXIiOiB7CiAgICAgICJ0eXBlIjogInN0cmluZyIKICAgIH0KICB9LAogICJyZXF1aXJlZCI6IFsKICAgICJiYXIiCiAgXQp9",
|
||||
},
|
||||
{
|
||||
ID: "unreachable",
|
||||
URL: urlx.ParseOrPanic("http://127.0.0.1:12345/unreachable-schema"),
|
||||
RawURL: "http://127.0.0.1:12345/unreachable-schema",
|
||||
},
|
||||
{
|
||||
ID: "no-file",
|
||||
URL: urlx.ParseOrPanic("file://./stub/does-not-exist.schema.json"),
|
||||
RawURL: "file://./stub/does-not-exist.schema.json",
|
||||
},
|
||||
{
|
||||
ID: "directory",
|
||||
URL: urlx.ParseOrPanic("file://./stub"),
|
||||
RawURL: "file://./stub",
|
||||
},
|
||||
{
|
||||
ID: "preset://email",
|
||||
URL: urlx.ParseOrPanic("file://./stub/identity-2.schema.json"),
|
||||
RawURL: "file://./stub/identity-2.schema.json",
|
||||
},
|
||||
}
|
||||
|
||||
getSchemaById := func(id string) *schema.Schema {
|
||||
s, err := schemas.GetByID(id)
|
||||
require.NoError(t, err)
|
||||
return s
|
||||
}
|
||||
|
||||
getFromTS := func(t *testing.T, url string, expectCode int) []byte {
|
||||
res, err := ts.Client().Get(url)
|
||||
getReq := func(ctx context.Context, t *testing.T, path string, expectCode int) []byte {
|
||||
res, err := ts.Client(ctx).Get(ts.URL + path)
|
||||
require.NoError(t, err)
|
||||
body, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
|
||||
require.EqualValues(t, expectCode, res.StatusCode, "%s", body)
|
||||
require.EqualValuesf(t, expectCode, res.StatusCode, "%s", body)
|
||||
return body
|
||||
}
|
||||
|
||||
getFromTSById := func(t *testing.T, id string, expectCode int) []byte {
|
||||
return getFromTS(t, fmt.Sprintf("%s/schemas/%s", ts.URL, id), expectCode)
|
||||
for id, s := range schemas {
|
||||
t.Run(fmt.Sprintf("case=get %s schema", id), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expected, err := s.getRaw()
|
||||
expectedStatus := http.StatusOK
|
||||
if err != nil {
|
||||
expectedStatus = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
actual := getReq(t.Context(), t, fmt.Sprintf("/schemas/%s", url.PathEscape(id)), expectedStatus)
|
||||
|
||||
if expectedStatus == http.StatusOK {
|
||||
require.JSONEq(t, string(expected), string(actual))
|
||||
} else {
|
||||
require.Contains(t, string(actual), "could not be found or opened")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
getFromTSPaginated := func(t *testing.T, page, perPage, expectCode int) []byte {
|
||||
return getFromTS(t, fmt.Sprintf("%s/schemas?page=%d&per_page=%d", ts.URL, page, perPage), expectCode)
|
||||
}
|
||||
t.Run("case=get schema with base64 encoded ID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
getFromFS := func(id string) []byte {
|
||||
schema := getSchemaById(id)
|
||||
expected, err := schemas["preset://email"].getRaw()
|
||||
require.NoError(t, err)
|
||||
|
||||
if schema.URL.Scheme == "file" {
|
||||
raw, err := os.ReadFile(strings.TrimPrefix(schema.RawURL, "file://"))
|
||||
require.NoError(t, err)
|
||||
return raw
|
||||
} else if schema.URL.Scheme == "base64" {
|
||||
data, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(schema.RawURL, "base64://"))
|
||||
require.NoError(t, err)
|
||||
return data
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
setSchemas := func(newSchemas schema.Schemas) {
|
||||
schemas = newSchemas
|
||||
var schemasConfig []config.Schema
|
||||
for _, s := range schemas {
|
||||
schemasConfig = append(schemasConfig, config.Schema{
|
||||
ID: s.ID,
|
||||
URL: s.RawURL,
|
||||
})
|
||||
}
|
||||
conf.MustSet(ctx, config.ViperKeyIdentitySchemas, schemasConfig)
|
||||
}
|
||||
|
||||
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, ts.URL)
|
||||
conf.MustSet(ctx, config.ViperKeyDefaultIdentitySchemaID, config.DefaultIdentityTraitsSchemaID)
|
||||
setSchemas(schemas)
|
||||
|
||||
t.Run("case=get default schema", func(t *testing.T) {
|
||||
server := getFromTSById(t, config.DefaultIdentityTraitsSchemaID, http.StatusOK)
|
||||
file := getFromFS(config.DefaultIdentityTraitsSchemaID)
|
||||
require.JSONEq(t, string(file), string(server))
|
||||
})
|
||||
|
||||
t.Run("case=get other schema", func(t *testing.T) {
|
||||
server := getFromTSById(t, "identity2", http.StatusOK)
|
||||
file := getFromFS("identity2")
|
||||
require.JSONEq(t, string(file), string(server))
|
||||
})
|
||||
|
||||
t.Run("case=get base64 schema", func(t *testing.T) {
|
||||
server := getFromTSById(t, "base64", http.StatusOK)
|
||||
file := getFromFS("base64")
|
||||
require.JSONEq(t, string(file), string(server))
|
||||
})
|
||||
|
||||
t.Run("case=get encoded schema", func(t *testing.T) {
|
||||
server := getFromTSById(t, "cHJlc2V0Oi8vZW1haWw", http.StatusOK)
|
||||
file := getFromFS("preset://email")
|
||||
require.JSONEq(t, string(file), string(server))
|
||||
})
|
||||
|
||||
t.Run("case=get unreachable schema", func(t *testing.T) {
|
||||
reason := getFromTSById(t, "unreachable", http.StatusInternalServerError)
|
||||
require.Contains(t, string(reason), "could not be found or opened")
|
||||
})
|
||||
|
||||
t.Run("case=get no-file schema", func(t *testing.T) {
|
||||
reason := getFromTSById(t, "no-file", http.StatusInternalServerError)
|
||||
require.Contains(t, string(reason), "could not be found or opened")
|
||||
})
|
||||
|
||||
t.Run("case=get directory schema", func(t *testing.T) {
|
||||
reason := getFromTSById(t, "directory", http.StatusInternalServerError)
|
||||
require.Contains(t, string(reason), "could not be found or opened")
|
||||
})
|
||||
|
||||
t.Run("case=get not-existing schema", func(t *testing.T) {
|
||||
_ = getFromTSById(t, "not-existing", http.StatusNotFound)
|
||||
actual := getReq(t.Context(), t, "/schemas/"+base64.RawURLEncoding.EncodeToString([]byte("preset://email")), http.StatusOK)
|
||||
require.JSONEq(t, string(expected), string(actual))
|
||||
})
|
||||
|
||||
t.Run("case=get all schemas", func(t *testing.T) {
|
||||
setSchemas(schema.Schemas{
|
||||
{
|
||||
ID: "default",
|
||||
URL: urlx.ParseOrPanic("file://./stub/identity.schema.json"),
|
||||
RawURL: "file://./stub/identity.schema.json",
|
||||
},
|
||||
{
|
||||
ID: "identity2",
|
||||
URL: urlx.ParseOrPanic("file://./stub/identity-2.schema.json"),
|
||||
RawURL: "file://./stub/identity-2.schema.json",
|
||||
},
|
||||
})
|
||||
t.Parallel()
|
||||
|
||||
body := getFromTSPaginated(t, 0, 2, http.StatusOK)
|
||||
defaultSchema, err := configSchemas.FindSchemaByID("default")
|
||||
require.NoError(t, err)
|
||||
identity2Schema, err := configSchemas.FindSchemaByID("identity2")
|
||||
require.NoError(t, err)
|
||||
ctx := contextx.WithConfigValue(t.Context(), config.ViperKeyIdentitySchemas, config.Schemas{*defaultSchema, *identity2Schema})
|
||||
|
||||
getSchemasPaginated := func(t *testing.T, page, perPage, expectCode int) []byte {
|
||||
return getReq(ctx, t, fmt.Sprintf("/schemas?page=%d&per_page=%d", page, perPage), expectCode)
|
||||
}
|
||||
|
||||
body := getSchemasPaginated(t, 0, 10, http.StatusOK)
|
||||
|
||||
var result []client.IdentitySchemaContainer
|
||||
require.NoError(t, json.Unmarshal(body, &result), "%s", body)
|
||||
require.NoErrorf(t, json.Unmarshal(body, &result), "%s", body)
|
||||
|
||||
ids_orig := []string{}
|
||||
for _, s := range schemas {
|
||||
ids_orig = append(ids_orig, s.ID)
|
||||
}
|
||||
ids_list := []string{}
|
||||
var actualIDs []string
|
||||
for _, s := range result {
|
||||
ids_list = append(ids_list, *s.Id)
|
||||
actualIDs = append(actualIDs, s.Id)
|
||||
}
|
||||
for _, id := range ids_orig {
|
||||
require.Contains(t, ids_list, id)
|
||||
assert.Equal(t, []string{defaultSchema.ID, identity2Schema.ID}, actualIDs)
|
||||
|
||||
assertCorrectSchema := func(t *testing.T, r client.IdentitySchemaContainer) {
|
||||
expected, err := schemas[r.Id].getRaw()
|
||||
require.NoError(t, err)
|
||||
actual, err := json.Marshal(r.Schema)
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(t, string(expected), string(actual))
|
||||
}
|
||||
|
||||
for _, s := range schemas {
|
||||
for _, r := range result {
|
||||
if *r.Id == s.ID {
|
||||
j, err := json.Marshal(r.Schema)
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(t, string(getFromFS(s.ID)), string(j))
|
||||
}
|
||||
}
|
||||
for _, r := range result {
|
||||
assertCorrectSchema(t, r)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("case=get paginated schemas", func(t *testing.T) {
|
||||
setSchemas(schema.Schemas{
|
||||
{
|
||||
ID: "default",
|
||||
URL: urlx.ParseOrPanic("file://./stub/identity.schema.json"),
|
||||
RawURL: "file://./stub/identity.schema.json",
|
||||
},
|
||||
{
|
||||
ID: "identity2",
|
||||
URL: urlx.ParseOrPanic("file://./stub/identity-2.schema.json"),
|
||||
RawURL: "file://./stub/identity-2.schema.json",
|
||||
},
|
||||
})
|
||||
for page := range 2 {
|
||||
t.Run(fmt.Sprintf("page=%d", page), func(t *testing.T) {
|
||||
body := getSchemasPaginated(t, page, 1, http.StatusOK)
|
||||
|
||||
body1, body2 := getFromTSPaginated(t, 0, 1, http.StatusOK), getFromTSPaginated(t, 1, 1, http.StatusOK)
|
||||
var result []client.IdentitySchemaContainer
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
|
||||
var result1, result2 schema.IdentitySchemas
|
||||
require.NoError(t, json.Unmarshal(body1, &result1))
|
||||
require.NoError(t, json.Unmarshal(body2, &result2))
|
||||
|
||||
result := append(result1, result2...)
|
||||
|
||||
ids_orig := []string{}
|
||||
for _, s := range schemas {
|
||||
ids_orig = append(ids_orig, s.ID)
|
||||
}
|
||||
ids_list := []string{}
|
||||
for _, s := range result {
|
||||
ids_list = append(ids_list, s.ID)
|
||||
}
|
||||
for _, id := range ids_orig {
|
||||
require.Contains(t, ids_list, id)
|
||||
require.Len(t, result, 1)
|
||||
assert.Equal(t, actualIDs[page], result[0].Id)
|
||||
assertCorrectSchema(t, result[0])
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("case=read schema", func(t *testing.T) {
|
||||
setSchemas(schema.Schemas{
|
||||
{
|
||||
ID: "default",
|
||||
URL: urlx.ParseOrPanic("file://./stub/identity.schema.json"),
|
||||
RawURL: "file://./stub/identity.schema.json",
|
||||
},
|
||||
{
|
||||
ID: "default",
|
||||
URL: urlx.ParseOrPanic(fmt.Sprintf("%s/schemas/default", ts.URL)),
|
||||
RawURL: fmt.Sprintf("%s/schemas/default", ts.URL),
|
||||
},
|
||||
})
|
||||
t.Parallel()
|
||||
|
||||
src, err := reg.SchemaHandler().ReadSchema(ctx, &schemas[0])
|
||||
require.NoError(t, err)
|
||||
defer src.Close()
|
||||
for _, s := range schemas {
|
||||
expected, expectedErr := s.getRaw()
|
||||
|
||||
src, err = reg.SchemaHandler().ReadSchema(ctx, &schemas[1])
|
||||
require.NoError(t, err)
|
||||
defer src.Close()
|
||||
actual, err := reg.SchemaHandler().ReadSchema(t.Context(), urlx.ParseOrPanic(s.uri))
|
||||
if expectedErr == nil {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.ErrorContains(t, err, expectedErr.Error()) // not using error.is because some of the errors are not accessible
|
||||
}
|
||||
|
||||
if expectedErr == nil {
|
||||
require.JSONEq(t, string(expected), string(actual))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -186,8 +186,8 @@ func (s *Strategy) CountActiveMultiFactorCredentials(ctx context.Context, cc map
|
|||
return validAddresses, nil
|
||||
}
|
||||
|
||||
func NewStrategy(deps any) *Strategy {
|
||||
return &Strategy{deps: deps.(strategyDependencies), dx: decoderx.NewHTTP()}
|
||||
func NewStrategy(deps strategyDependencies) *Strategy {
|
||||
return &Strategy{deps: deps, dx: decoderx.NewHTTP()}
|
||||
}
|
||||
|
||||
func (s *Strategy) ID() identity.CredentialsType {
|
||||
|
|
|
|||
|
|
@ -38,9 +38,9 @@ type Strategy struct {
|
|||
hd *decoderx.HTTP
|
||||
}
|
||||
|
||||
func NewStrategy(d any) *Strategy {
|
||||
func NewStrategy(d dependencies) *Strategy {
|
||||
return &Strategy{
|
||||
d: d.(dependencies),
|
||||
d: d,
|
||||
v: validator.New(),
|
||||
hd: decoderx.NewHTTP(),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,8 +86,8 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
func NewStrategy(d any) *Strategy {
|
||||
return &Strategy{d: d.(strategyDependencies), dx: decoderx.NewHTTP()}
|
||||
func NewStrategy(d strategyDependencies) *Strategy {
|
||||
return &Strategy{d: d, dx: decoderx.NewHTTP()}
|
||||
}
|
||||
|
||||
func (s *Strategy) NodeGroup() node.UiNodeGroup {
|
||||
|
|
|
|||
|
|
@ -76,9 +76,9 @@ type Strategy struct {
|
|||
hd *decoderx.HTTP
|
||||
}
|
||||
|
||||
func NewStrategy(d any) *Strategy {
|
||||
func NewStrategy(d lookupStrategyDependencies) *Strategy {
|
||||
return &Strategy{
|
||||
d: d.(lookupStrategyDependencies),
|
||||
d: d,
|
||||
hd: decoderx.NewHTTP(),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -276,9 +276,9 @@ func (s *Strategy) SetOnConflictingIdentity(t testing.TB, handler ConflictingIde
|
|||
s.conflictingIdentityPolicy = handler
|
||||
}
|
||||
|
||||
func NewStrategy(d any, opts ...NewStrategyOpt) *Strategy {
|
||||
func NewStrategy(d Dependencies, opts ...NewStrategyOpt) *Strategy {
|
||||
s := &Strategy{
|
||||
d: d.(Dependencies),
|
||||
d: d,
|
||||
validator: schema.NewValidator(),
|
||||
credType: identity.CredentialsTypeOIDC,
|
||||
handleUnknownProviderError: func(err error) error { return err },
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ import (
|
|||
"github.com/ory/kratos/ui/node"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/kratos/x/nosurfx"
|
||||
"github.com/ory/x/configx"
|
||||
"github.com/ory/x/contextx"
|
||||
"github.com/ory/x/snapshotx"
|
||||
"github.com/ory/x/sqlxx"
|
||||
|
|
@ -49,17 +50,20 @@ func TestSettingsStrategy(t *testing.T) {
|
|||
t.Skip()
|
||||
}
|
||||
|
||||
var (
|
||||
conf, reg = internal.NewFastRegistryWithMocks(t)
|
||||
subject string
|
||||
claims idTokenClaims
|
||||
scope []string
|
||||
conf, reg := internal.NewFastRegistryWithMocks(t,
|
||||
configx.WithValues(testhelpers.DefaultIdentitySchemaConfig("file://./stub/settings.schema.json")),
|
||||
configx.WithValue(config.ViperKeySelfServiceBrowserDefaultReturnTo, "https://www.ory.sh/kratos"),
|
||||
)
|
||||
|
||||
var (
|
||||
subject string
|
||||
claims idTokenClaims
|
||||
scope []string
|
||||
)
|
||||
remoteAdmin, remotePublic, _ := newHydra(t, &subject, &claims, &scope)
|
||||
uiTS := newUI(t, reg)
|
||||
errTS := testhelpers.NewErrorTestServer(t, reg)
|
||||
publicTS, adminTS := testhelpers.NewKratosServers(t, reg)
|
||||
publicTS, _ := testhelpers.NewKratosServer(t, reg)
|
||||
|
||||
viperSetProviderConfig(
|
||||
t,
|
||||
|
|
@ -73,9 +77,6 @@ func TestSettingsStrategy(t *testing.T) {
|
|||
newOIDCProvider(t, publicTS, remotePublic, remoteAdmin, "google"),
|
||||
newOIDCProvider(t, publicTS, remotePublic, remoteAdmin, "github"),
|
||||
)
|
||||
testhelpers.InitKratosServers(t, reg, publicTS, adminTS)
|
||||
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/settings.schema.json")
|
||||
conf.MustSet(ctx, config.ViperKeySelfServiceBrowserDefaultReturnTo, "https://www.ory.sh/kratos")
|
||||
|
||||
// Make test data for this test run unique
|
||||
testID := x.NewUUID().String()
|
||||
|
|
@ -133,8 +134,8 @@ func TestSettingsStrategy(t *testing.T) {
|
|||
agents := testhelpers.AddAndLoginIdentities(t, reg, publicTS, users)
|
||||
|
||||
newProfileFlow := func(t *testing.T, client *http.Client, redirectTo string, exp time.Duration) *settings.Flow {
|
||||
req, err := reg.SettingsFlowPersister().GetSettingsFlow(context.Background(),
|
||||
x.ParseUUID(string(testhelpers.InitializeSettingsFlowViaBrowser(t, client, false, publicTS).Id)))
|
||||
req, err := reg.SettingsFlowPersister().GetSettingsFlow(t.Context(),
|
||||
x.ParseUUID(testhelpers.InitializeSettingsFlowViaBrowser(t, client, false, publicTS).Id))
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, req.Active)
|
||||
|
||||
|
|
@ -225,7 +226,7 @@ func TestSettingsStrategy(t *testing.T) {
|
|||
} {
|
||||
t.Run("agent="+tc.agent, func(t *testing.T) {
|
||||
rs := nprSDK(t, agents[tc.agent], "", time.Hour)
|
||||
snapshotx.SnapshotTExcept(t, rs.Ui.Nodes, []string{"0.attributes.value", "1.attributes.value"})
|
||||
snapshotx.SnapshotT(t, rs.Ui.Nodes, snapshotx.ExceptPaths("0.attributes.value", "1.attributes.value"))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
|
@ -335,9 +336,9 @@ func TestSettingsStrategy(t *testing.T) {
|
|||
lf, _, err := fa.GetLoginFlow(context.Background()).Id(res.Request.URL.Query()["flow"][0]).Execute()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, node := range lf.Ui.Nodes {
|
||||
if node.Group == "oidc" && node.Attributes.UiNodeInputAttributes.Name == "provider" {
|
||||
assert.Contains(t, []string{"ory", "github"}, node.Attributes.UiNodeInputAttributes.Value)
|
||||
for _, n := range lf.Ui.Nodes {
|
||||
if n.Group == "oidc" && n.Attributes.UiNodeInputAttributes.Name == "provider" {
|
||||
assert.Contains(t, []string{"ory", "github"}, n.Attributes.UiNodeInputAttributes.Value)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -388,7 +389,7 @@ func TestSettingsStrategy(t *testing.T) {
|
|||
assert.Contains(t, gjson.GetBytes(body, "ui.action").String(), publicTS.URL+settings.RouteSubmitFlow+"?flow=")
|
||||
|
||||
// The original options to link google and github are still there
|
||||
snapshotx.SnapshotTExcept(t, json.RawMessage(gjson.GetBytes(body, `ui.nodes`).Raw), []string{"0.attributes.value", "1.attributes.value"})
|
||||
snapshotx.SnapshotT(t, json.RawMessage(gjson.GetBytes(body, `ui.nodes`).Raw), snapshotx.ExceptPaths("0.attributes.value", "1.attributes.value"))
|
||||
|
||||
assert.Contains(t, gjson.GetBytes(body, `ui.messages.0.text`).String(),
|
||||
"can not link unknown or already existing OpenID Connect connection")
|
||||
|
|
@ -462,13 +463,13 @@ func TestSettingsStrategy(t *testing.T) {
|
|||
require.EqualValues(t, flow.StateSuccess, updatedFlowSDK.State)
|
||||
|
||||
t.Run("flow=original", func(t *testing.T) {
|
||||
snapshotx.SnapshotTExcept(t, originalFlow.Ui.Nodes, []string{"0.attributes.value", "1.attributes.value"})
|
||||
snapshotx.SnapshotT(t, originalFlow.Ui.Nodes, snapshotx.ExceptPaths("0.attributes.value", "1.attributes.value"))
|
||||
})
|
||||
t.Run("flow=response", func(t *testing.T) {
|
||||
snapshotx.SnapshotTExcept(t, json.RawMessage(gjson.GetBytes(updatedFlow, "ui.nodes").Raw), []string{"0.attributes.value", "1.attributes.value"})
|
||||
snapshotx.SnapshotT(t, json.RawMessage(gjson.GetBytes(updatedFlow, "ui.nodes").Raw), snapshotx.ExceptPaths("0.attributes.value", "1.attributes.value"))
|
||||
})
|
||||
t.Run("flow=fetch", func(t *testing.T) {
|
||||
snapshotx.SnapshotTExcept(t, updatedFlowSDK.Ui.Nodes, []string{"0.attributes.value", "1.attributes.value"})
|
||||
snapshotx.SnapshotT(t, updatedFlowSDK.Ui.Nodes, snapshotx.ExceptPaths("0.attributes.value", "1.attributes.value"))
|
||||
})
|
||||
|
||||
checkCredentials(t, true, users[agent].ID, provider, subject, true)
|
||||
|
|
@ -516,7 +517,7 @@ func TestSettingsStrategy(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.EqualValues(t, flow.StateSuccess, rs.State)
|
||||
|
||||
snapshotx.SnapshotTExcept(t, rs.Ui.Nodes, []string{"0.attributes.value", "1.attributes.value"})
|
||||
snapshotx.SnapshotT(t, rs.Ui.Nodes, snapshotx.ExceptPaths("0.attributes.value", "1.attributes.value"))
|
||||
|
||||
checkCredentials(t, true, users[agent].ID, provider, subject, true)
|
||||
})
|
||||
|
|
@ -600,9 +601,9 @@ func TestSettingsStrategy(t *testing.T) {
|
|||
lf, _, err := fa.GetLoginFlow(context.Background()).Id(res.Request.URL.Query()["flow"][0]).Execute()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, node := range lf.Ui.Nodes {
|
||||
if node.Group == "oidc" && node.Attributes.UiNodeInputAttributes.Name == "provider" {
|
||||
assert.Contains(t, []string{"ory", "github"}, node.Attributes.UiNodeInputAttributes.Value)
|
||||
for _, n := range lf.Ui.Nodes {
|
||||
if n.Group == "oidc" && n.Attributes.UiNodeInputAttributes.Name == "provider" {
|
||||
assert.Contains(t, []string{"ory", "github"}, n.Attributes.UiNodeInputAttributes.Value)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -77,9 +77,9 @@ type Strategy struct {
|
|||
hd *decoderx.HTTP
|
||||
}
|
||||
|
||||
func NewStrategy(d any) *Strategy {
|
||||
func NewStrategy(d strategyDependencies) *Strategy {
|
||||
return &Strategy{
|
||||
d: d.(strategyDependencies),
|
||||
d: d,
|
||||
hd: decoderx.NewHTTP(),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -82,9 +82,9 @@ type Strategy struct {
|
|||
hd *decoderx.HTTP
|
||||
}
|
||||
|
||||
func NewStrategy(d any) *Strategy {
|
||||
func NewStrategy(d registrationStrategyDependencies) *Strategy {
|
||||
return &Strategy{
|
||||
d: d.(registrationStrategyDependencies),
|
||||
d: d,
|
||||
v: validator.New(),
|
||||
hd: decoderx.NewHTTP(),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -76,8 +76,8 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
func NewStrategy(d any) *Strategy {
|
||||
return &Strategy{d: d.(strategyDependencies), dc: decoderx.NewHTTP()}
|
||||
func NewStrategy(d strategyDependencies) *Strategy {
|
||||
return &Strategy{d: d, dc: decoderx.NewHTTP()}
|
||||
}
|
||||
|
||||
func (s *Strategy) SettingsStrategyID() string {
|
||||
|
|
|
|||
|
|
@ -76,9 +76,9 @@ type Strategy struct {
|
|||
hd *decoderx.HTTP
|
||||
}
|
||||
|
||||
func NewStrategy(d any) *Strategy {
|
||||
func NewStrategy(d totpStrategyDependencies) *Strategy {
|
||||
return &Strategy{
|
||||
d: d.(totpStrategyDependencies),
|
||||
d: d,
|
||||
hd: decoderx.NewHTTP(),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -77,9 +77,9 @@ type Strategy struct {
|
|||
hd *decoderx.HTTP
|
||||
}
|
||||
|
||||
func NewStrategy(d any) *Strategy {
|
||||
func NewStrategy(d webauthnStrategyDependencies) *Strategy {
|
||||
return &Strategy{
|
||||
d: d.(webauthnStrategyDependencies),
|
||||
d: d,
|
||||
hd: decoderx.NewHTTP(),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -843,7 +843,7 @@ func (h *Handler) listMySessions(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
page, perPage := x.ParsePagination(r)
|
||||
sess, total, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), s.IdentityID, pointerx.Bool(true), page, perPage, s.ID, ExpandEverything)
|
||||
sess, total, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), s.IdentityID, pointerx.Ptr(true), page, perPage, s.ID, ExpandEverything)
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, err)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -18,31 +18,27 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ory/kratos/x/nosurfx"
|
||||
"github.com/ory/x/sqlxx"
|
||||
|
||||
"github.com/go-faker/faker/v4"
|
||||
"github.com/peterhellberg/link"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/ory/kratos/identity"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/peterhellberg/link"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/kratos/corpx"
|
||||
"github.com/ory/x/pagination/keysetpagination"
|
||||
"github.com/ory/x/sqlcon"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/ory/kratos/corpx"
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/internal"
|
||||
"github.com/ory/kratos/internal/testhelpers"
|
||||
. "github.com/ory/kratos/session"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/kratos/x/nosurfx"
|
||||
"github.com/ory/x/configx"
|
||||
"github.com/ory/x/ioutilx"
|
||||
"github.com/ory/x/pagination/keysetpagination"
|
||||
"github.com/ory/x/sqlcon"
|
||||
"github.com/ory/x/sqlxx"
|
||||
"github.com/ory/x/urlx"
|
||||
)
|
||||
|
||||
|
|
@ -57,6 +53,8 @@ func send(code int) http.HandlerFunc {
|
|||
}
|
||||
|
||||
func TestSessionWhoAmI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conf, reg := internal.NewFastRegistryWithMocks(t)
|
||||
ts, _, r, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)
|
||||
ctx := context.Background()
|
||||
|
|
@ -99,17 +97,17 @@ func TestSessionWhoAmI(t *testing.T) {
|
|||
|
||||
t.Run("case=aal requirements", func(t *testing.T) {
|
||||
h1, _ := testhelpers.MockSessionCreateHandlerWithIdentityAndAMR(t, reg,
|
||||
createAAL2Identity(t, reg),
|
||||
newAAL2Identity(),
|
||||
[]identity.CredentialsType{identity.CredentialsTypePassword, identity.CredentialsTypeWebAuthn})
|
||||
r.GET("/set/aal2-aal2", h1)
|
||||
|
||||
h2, _ := testhelpers.MockSessionCreateHandlerWithIdentityAndAMR(t, reg,
|
||||
createAAL2Identity(t, reg),
|
||||
newAAL2Identity(),
|
||||
[]identity.CredentialsType{identity.CredentialsTypePassword})
|
||||
r.GET("/set/aal2-aal1", h2)
|
||||
|
||||
h3, _ := testhelpers.MockSessionCreateHandlerWithIdentityAndAMR(t, reg,
|
||||
createAAL1Identity(t, reg),
|
||||
newAAL1Identity(),
|
||||
[]identity.CredentialsType{identity.CredentialsTypePassword})
|
||||
r.GET("/set/aal1-aal1", h3)
|
||||
|
||||
|
|
@ -241,7 +239,7 @@ func TestSessionWhoAmI(t *testing.T) {
|
|||
setTokenizeConfig(conf, "es256", "jwk.es256.json", "")
|
||||
conf.MustSet(ctx, config.ViperKeySessionWhoAmICaching, true)
|
||||
|
||||
h3, _ := testhelpers.MockSessionCreateHandlerWithIdentityAndAMR(t, reg, createAAL1Identity(t, reg), []identity.CredentialsType{identity.CredentialsTypePassword})
|
||||
h3, _ := testhelpers.MockSessionCreateHandlerWithIdentityAndAMR(t, reg, newAAL1Identity(), []identity.CredentialsType{identity.CredentialsTypePassword})
|
||||
r.GET("/set/tokenize", h3)
|
||||
|
||||
client := testhelpers.NewClientWithCookies(t)
|
||||
|
|
@ -350,6 +348,8 @@ func TestSessionWhoAmI(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestIsNotAuthenticatedSecurecookie(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conf, reg := internal.NewFastRegistryWithMocks(t)
|
||||
r := x.NewRouterPublic(reg)
|
||||
|
|
@ -378,6 +378,8 @@ func TestIsNotAuthenticatedSecurecookie(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestIsNotAuthenticated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conf, reg := internal.NewFastRegistryWithMocks(t)
|
||||
r := x.NewRouterPublic(reg)
|
||||
|
|
@ -434,6 +436,8 @@ func TestIsNotAuthenticated(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestIsAuthenticated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conf, reg := internal.NewFastRegistryWithMocks(t)
|
||||
reg.WithCSRFHandler(new(nosurfx.FakeCSRFHandler))
|
||||
|
|
@ -487,14 +491,10 @@ func TestIsAuthenticated(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandlerAdminSessionManagement(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
conf, reg := internal.NewFastRegistryWithMocks(t)
|
||||
_, ts, _, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)
|
||||
t.Parallel()
|
||||
|
||||
// set this intermediate because kratos needs some valid url for CRUDE operations
|
||||
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, "http://example.com")
|
||||
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json")
|
||||
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, ts.URL)
|
||||
_, reg := internal.NewFastRegistryWithMocks(t, configx.WithValues(testhelpers.DefaultIdentitySchemaConfig("file://./stub/identity.schema.json")))
|
||||
public, ts := testhelpers.NewKratosServer(t, reg)
|
||||
|
||||
t.Run("case=should return 202 after invalidating all sessions", func(t *testing.T) {
|
||||
client := testhelpers.NewClientWithCookies(t)
|
||||
|
|
@ -505,7 +505,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
|
|||
{Method: identity.CredentialsTypePassword, CompletedAt: time.Now().UTC().Round(time.Second)},
|
||||
{Method: identity.CredentialsTypeOIDC, CompletedAt: time.Now().UTC().Round(time.Second)},
|
||||
}
|
||||
require.NoError(t, reg.Persister().CreateIdentity(ctx, s.Identity))
|
||||
require.NoError(t, reg.Persister().CreateIdentity(t.Context(), s.Identity))
|
||||
|
||||
var expectedSessionDevice Device
|
||||
require.NoError(t, faker.FakeData(&expectedSessionDevice))
|
||||
|
|
@ -513,10 +513,10 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
|
|||
expectedSessionDevice,
|
||||
}
|
||||
|
||||
assert.Equal(t, uuid.Nil, s.ID)
|
||||
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, s))
|
||||
assert.NotEqual(t, uuid.Nil, s.ID)
|
||||
assert.NotEqual(t, uuid.Nil, s.Identity.ID)
|
||||
assert.Zero(t, s.ID)
|
||||
require.NoError(t, reg.SessionPersister().UpsertSession(t.Context(), s))
|
||||
assert.NotZero(t, s.ID)
|
||||
assert.NotZero(t, s.Identity.ID)
|
||||
|
||||
t.Run("get session", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", ts.URL+"/admin/sessions/"+s.ID.String(), nil)
|
||||
|
|
@ -561,12 +561,13 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
|
|||
req, _ := http.NewRequest("GET", ts.URL+"/admin/sessions/"+s.ID.String()+tc.expand, nil)
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
body := ioutilx.MustReadAll(res.Body)
|
||||
require.Equalf(t, http.StatusOK, res.StatusCode, "%s", body)
|
||||
|
||||
assert.Equal(t, s.ID.String(), gjson.GetBytes(body, "id").String())
|
||||
assert.Equal(t, tc.expectedIdentityId, gjson.GetBytes(body, "identity.id").String())
|
||||
assert.Equal(t, fmt.Sprint(tc.expectedDevices), gjson.GetBytes(body, "devices.#").String())
|
||||
assert.EqualValuesf(t, tc.expectedDevices, gjson.GetBytes(body, "devices.#").Int(), "%s", gjson.GetBytes(body, "devices").Raw)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
|
@ -579,7 +580,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should redirect to public for whoami", func(t *testing.T) {
|
||||
client := testhelpers.NewHTTPClientWithSessionToken(t, ctx, reg, s)
|
||||
client := testhelpers.NewHTTPClientWithSessionToken(t, t.Context(), reg, s)
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
|
|
@ -588,7 +589,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
|
|||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusTemporaryRedirect, res.StatusCode)
|
||||
require.Equal(t, ts.URL+"/sessions/whoami", res.Header.Get("Location"))
|
||||
require.Equal(t, public.URL+"/sessions/whoami", res.Header.Get("Location"))
|
||||
})
|
||||
|
||||
assertPageToken := func(t *testing.T, id, linkHeader string) {
|
||||
|
|
@ -709,7 +710,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
|
|||
client := testhelpers.NewClientWithCookies(t)
|
||||
|
||||
s.ExpiresAt = time.Now().Add(-time.Hour * 1)
|
||||
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, s))
|
||||
require.NoError(t, reg.SessionPersister().UpsertSession(t.Context(), s))
|
||||
|
||||
assert.NotEqual(t, uuid.Nil, s.ID)
|
||||
assert.NotEqual(t, uuid.Nil, s.Identity.ID)
|
||||
|
|
@ -730,10 +731,10 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
|
|||
require.NoError(t, faker.FakeData(&s1))
|
||||
s1.Active = true
|
||||
s1.Identity.State = identity.StateInactive
|
||||
require.NoError(t, reg.Persister().CreateIdentity(ctx, s1.Identity))
|
||||
require.NoError(t, reg.Persister().CreateIdentity(t.Context(), s1.Identity))
|
||||
|
||||
assert.Equal(t, uuid.Nil, s1.ID)
|
||||
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, s1))
|
||||
require.NoError(t, reg.SessionPersister().UpsertSession(t.Context(), s1))
|
||||
assert.NotEqual(t, uuid.Nil, s1.ID)
|
||||
assert.NotEqual(t, uuid.Nil, s1.Identity.ID)
|
||||
|
||||
|
|
@ -752,7 +753,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusNoContent, res.StatusCode)
|
||||
|
||||
_, err = reg.SessionPersister().GetSession(ctx, s.ID, ExpandNothing)
|
||||
_, err = reg.SessionPersister().GetSession(t.Context(), s.ID, ExpandNothing)
|
||||
require.True(t, errors.Is(err, sqlcon.ErrNoRows))
|
||||
|
||||
t.Run("should not list session", func(t *testing.T) {
|
||||
|
|
@ -790,7 +791,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
|
|||
client := testhelpers.NewClientWithCookies(t)
|
||||
var i *identity.Identity
|
||||
require.NoError(t, faker.FakeData(&i))
|
||||
require.NoError(t, reg.Persister().CreateIdentity(ctx, i))
|
||||
require.NoError(t, reg.Persister().CreateIdentity(t.Context(), i))
|
||||
|
||||
numSessions := 5
|
||||
numSessionsActive := 2
|
||||
|
|
@ -806,7 +807,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
|
|||
sess[j].Active = false
|
||||
sess[j].ExpiresAt = time.Now().UTC().Add(-time.Hour)
|
||||
}
|
||||
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, &sess[j]))
|
||||
require.NoError(t, reg.SessionPersister().UpsertSession(t.Context(), &sess[j]))
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
|
|
@ -827,7 +828,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
|
|||
},
|
||||
} {
|
||||
t.Run(fmt.Sprintf("active=%#v", tc.activeOnly), func(t *testing.T) {
|
||||
sessions, _, _ := reg.SessionPersister().ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil, ExpandEverything)
|
||||
sessions, _, _ := reg.SessionPersister().ListSessionsByIdentity(t.Context(), i.ID, nil, 1, 10, uuid.Nil, ExpandEverything)
|
||||
require.Equal(t, 5, len(sessions))
|
||||
assert.True(t, sort.IsSorted(sort.Reverse(byCreatedAt(sessions))))
|
||||
|
||||
|
|
@ -855,6 +856,8 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandlerSelfServiceSessionManagement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conf, reg := internal.NewFastRegistryWithMocks(t)
|
||||
ts, _, r, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)
|
||||
|
|
@ -1044,6 +1047,8 @@ func TestHandlerSelfServiceSessionManagement(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandlerRefreshSessionBySessionID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conf, reg := internal.NewFastRegistryWithMocks(t)
|
||||
publicServer, adminServer, _, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ory/kratos/driver"
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/internal"
|
||||
|
|
@ -35,35 +34,22 @@ type mockCSRFHandler struct {
|
|||
c int
|
||||
}
|
||||
|
||||
func (f *mockCSRFHandler) DisablePath(s string) {
|
||||
}
|
||||
func (f *mockCSRFHandler) DisablePath(string) {}
|
||||
func (f *mockCSRFHandler) DisableGlob(string) {}
|
||||
func (f *mockCSRFHandler) DisableGlobs(...string) {}
|
||||
func (f *mockCSRFHandler) IgnoreGlob(string) {}
|
||||
func (f *mockCSRFHandler) IgnoreGlobs(...string) {}
|
||||
func (f *mockCSRFHandler) ExemptPath(string) {}
|
||||
func (f *mockCSRFHandler) IgnorePath(string) {}
|
||||
func (f *mockCSRFHandler) ServeHTTP(http.ResponseWriter, *http.Request) {}
|
||||
|
||||
func (f *mockCSRFHandler) DisableGlob(s string) {
|
||||
}
|
||||
|
||||
func (f *mockCSRFHandler) DisableGlobs(s ...string) {
|
||||
}
|
||||
|
||||
func (f *mockCSRFHandler) IgnoreGlob(s string) {
|
||||
}
|
||||
|
||||
func (f *mockCSRFHandler) IgnoreGlobs(s ...string) {
|
||||
}
|
||||
|
||||
func (f *mockCSRFHandler) ExemptPath(s string) {}
|
||||
|
||||
func (f *mockCSRFHandler) IgnorePath(s string) {}
|
||||
|
||||
func (f *mockCSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (f *mockCSRFHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) string {
|
||||
func (f *mockCSRFHandler) RegenerateToken(_ http.ResponseWriter, _ *http.Request) string {
|
||||
f.c++
|
||||
return nosurfx.FakeCSRFToken
|
||||
}
|
||||
|
||||
func createAAL2Identity(t *testing.T, reg driver.Registry) *identity.Identity {
|
||||
idAAL2 := identity.Identity{
|
||||
func newAAL2Identity() *identity.Identity {
|
||||
return &identity.Identity{
|
||||
SchemaID: "default",
|
||||
Traits: []byte("{}"),
|
||||
State: identity.StateActive,
|
||||
|
|
@ -80,11 +66,10 @@ func createAAL2Identity(t *testing.T, reg driver.Registry) *identity.Identity {
|
|||
},
|
||||
},
|
||||
}
|
||||
return &idAAL2
|
||||
}
|
||||
|
||||
func createAAL1Identity(t *testing.T, reg driver.Registry) *identity.Identity {
|
||||
idAAL1 := identity.Identity{
|
||||
func newAAL1Identity() *identity.Identity {
|
||||
return &identity.Identity{
|
||||
SchemaID: "default",
|
||||
Traits: []byte("{}"),
|
||||
State: identity.StateActive,
|
||||
|
|
@ -96,7 +81,6 @@ func createAAL1Identity(t *testing.T, reg driver.Registry) *identity.Identity {
|
|||
},
|
||||
},
|
||||
}
|
||||
return &idAAL1
|
||||
}
|
||||
|
||||
func TestManagerHTTP(t *testing.T) {
|
||||
|
|
@ -474,8 +458,8 @@ func TestManagerHTTP(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run("identity available AAL is not hydrated", func(t *testing.T) {
|
||||
idAAL2 := createAAL2Identity(t, reg)
|
||||
idAAL1 := createAAL1Identity(t, reg)
|
||||
idAAL2 := newAAL2Identity()
|
||||
idAAL1 := newAAL1Identity()
|
||||
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), idAAL1))
|
||||
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), idAAL2))
|
||||
test(t, idAAL1, idAAL2)
|
||||
|
|
@ -484,7 +468,7 @@ func TestManagerHTTP(t *testing.T) {
|
|||
t.Run("identity available AAL is hydrated and updated in the DB", func(t *testing.T) {
|
||||
// We do not create the identity in the database, proving that we do not need
|
||||
// to do any DB roundtrips in this case.
|
||||
idAAL1 := createAAL2Identity(t, reg)
|
||||
idAAL1 := newAAL2Identity()
|
||||
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), idAAL1))
|
||||
|
||||
s := session.NewInactiveSession()
|
||||
|
|
@ -500,10 +484,10 @@ func TestManagerHTTP(t *testing.T) {
|
|||
t.Run("identity available AAL is hydrated without DB", func(t *testing.T) {
|
||||
// We do not create the identity in the database, proving that we do not need
|
||||
// to do any DB roundtrips in this case.
|
||||
idAAL2 := createAAL2Identity(t, reg)
|
||||
idAAL2 := newAAL2Identity()
|
||||
idAAL2.InternalAvailableAAL = identity.NewNullableAuthenticatorAssuranceLevel(identity.AuthenticatorAssuranceLevel2)
|
||||
|
||||
idAAL1 := createAAL1Identity(t, reg)
|
||||
idAAL1 := newAAL1Identity()
|
||||
idAAL1.InternalAvailableAAL = identity.NewNullableAuthenticatorAssuranceLevel(identity.AuthenticatorAssuranceLevel1)
|
||||
|
||||
test(t, idAAL1, idAAL2)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ory/herodot"
|
||||
|
||||
"github.com/ory/kratos/session"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1287,6 +1287,10 @@
|
|||
"type": "object"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"id",
|
||||
"schema"
|
||||
],
|
||||
"type": "object"
|
||||
},
|
||||
"identitySchemas": {
|
||||
|
|
|
|||
|
|
@ -4636,6 +4636,10 @@
|
|||
"identitySchemaContainer": {
|
||||
"description": "An Identity JSON Schema Container",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"schema"
|
||||
],
|
||||
"properties": {
|
||||
"id": {
|
||||
"description": "The ID of the Identity JSON Schema",
|
||||
|
|
|
|||
|
|
@ -11,16 +11,14 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ory/kratos/x/redir"
|
||||
|
||||
"github.com/ory/x/configx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/internal"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/kratos/x/redir"
|
||||
"github.com/ory/x/configx"
|
||||
)
|
||||
|
||||
func TestRedirectToPublicAdminRoute(t *testing.T) {
|
||||
|
|
@ -38,14 +36,12 @@ func TestRedirectToPublicAdminRoute(t *testing.T) {
|
|||
pub.POST("/privileged", redir.RedirectToAdminRoute(reg))
|
||||
pub.POST("/admin/privileged", redir.RedirectToAdminRoute(reg))
|
||||
adm.POST("/privileged", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
_, _ = w.Write(body)
|
||||
_, _ = io.Copy(w, r.Body)
|
||||
})
|
||||
|
||||
adm.POST("/read", redir.RedirectToPublicRoute(reg))
|
||||
pub.POST("/read", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
_, _ = w.Write(body)
|
||||
_, _ = io.Copy(w, r.Body)
|
||||
})
|
||||
|
||||
for k, tc := range []struct {
|
||||
|
|
|
|||
Loading…
Reference in New Issue