// Copyright © 2023 Ory Corp // SPDX-License-Identifier: Apache-2.0 package testhelpers import ( "context" "errors" "io" "net/http" "net/url" "testing" "github.com/gofrs/uuid" "github.com/go-faker/faker/v4" "github.com/gobuffalo/httptest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" "github.com/ory/kratos/driver" "github.com/ory/kratos/driver/config" "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/registration" "github.com/ory/kratos/selfservice/flow/settings" ) func TestSelfServicePreHook( configKey string, makeRequestPre func(t *testing.T, ts *httptest.Server) (*http.Response, string), newServer func(t *testing.T) *httptest.Server, conf *config.Config, ) func(t *testing.T) { ctx := context.Background() return func(t *testing.T) { t.Run("case=pass without hooks", func(t *testing.T) { t.Cleanup(SelfServiceHookConfigReset(t, conf)) res, _ := makeRequestPre(t, newServer(t)) assert.EqualValues(t, http.StatusOK, res.StatusCode) }) t.Run("case=pass if hooks pass", func(t *testing.T) { t.Cleanup(SelfServiceHookConfigReset(t, conf)) conf.MustSet(ctx, configKey, []config.SelfServiceHook{{Name: "err", Config: []byte(`{}`)}}) res, _ := makeRequestPre(t, newServer(t)) assert.EqualValues(t, http.StatusOK, res.StatusCode) }) t.Run("case=err if hooks err", func(t *testing.T) { t.Cleanup(SelfServiceHookConfigReset(t, conf)) conf.MustSet(ctx, configKey, []config.SelfServiceHook{{Name: "err", Config: []byte(`{"ExecuteLoginPreHook": "err","ExecuteRegistrationPreHook": "err","ExecuteSettingsPreHook": "err","ExecuteVerificationPreHook": "err","ExecuteRecoveryPreHook": "err"}`)}}) res, body := makeRequestPre(t, newServer(t)) assert.EqualValues(t, http.StatusInternalServerError, res.StatusCode, "%s", body) assert.EqualValues(t, "err", body) }) t.Run("case=abort if hooks aborts", func(t *testing.T) { t.Cleanup(SelfServiceHookConfigReset(t, conf)) conf.MustSet(ctx, configKey, []config.SelfServiceHook{{Name: "err", Config: []byte(`{"ExecuteLoginPreHook": "abort","ExecuteRegistrationPreHook": "abort","ExecuteSettingsPreHook": "abort","ExecuteVerificationPreHook": "abort","ExecuteRecoveryPreHook": "abort"}`)}}) res, body := makeRequestPre(t, newServer(t)) assert.EqualValues(t, http.StatusOK, res.StatusCode) assert.Empty(t, body) }) t.Run("case=redirect", func(t *testing.T) { t.Skipf("Skipped because pre-redirect is no longer supported") t.Cleanup(SelfServiceHookConfigReset(t, conf)) conf.MustSet(ctx, configKey, []config.SelfServiceHook{{Name: "redirect", Config: []byte(`{"to": "https://www.ory.sh/"}`)}}) res, _ := makeRequestPre(t, newServer(t)) assert.EqualValues(t, http.StatusOK, res.StatusCode) assert.EqualValues(t, "https://www.ory.sh/", res.Request.URL.String()) }) } } func SelfServiceHookCreateFakeIdentity(t *testing.T, reg driver.Registry) *identity.Identity { i := SelfServiceHookFakeIdentity(t) require.NoError(t, reg.IdentityManager().Create(context.Background(), i)) return i } func SelfServiceHookFakeIdentity(t *testing.T) *identity.Identity { var i identity.Identity require.NoError(t, faker.FakeData(&i)) i.Traits = identity.Traits(`{}`) i.State = identity.StateActive i.NID = uuid.Must(uuid.NewV4()) return &i } func SelfServiceHookConfigReset(t *testing.T, conf *config.Config) func() { ctx := context.Background() return func() { conf.MustSet(ctx, config.ViperKeySelfServiceLoginAfter, nil) conf.MustSet(ctx, config.ViperKeySelfServiceLoginAfter+".hooks", nil) conf.MustSet(ctx, config.ViperKeySelfServiceLoginBeforeHooks, nil) conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryAfter, nil) conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryAfter+".hooks", nil) conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationAfter, nil) conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationAfter+".hooks", nil) conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationBeforeHooks, nil) conf.MustSet(ctx, config.ViperKeySelfServiceSettingsAfter, nil) conf.MustSet(ctx, config.ViperKeySelfServiceSettingsAfter+".hooks", nil) } } func SelfServiceHookSettingsSetDefaultRedirectTo(t *testing.T, conf *config.Config, value string) { ctx := context.Background() conf.MustSet(ctx, config.ViperKeySelfServiceSettingsAfter+"."+config.DefaultBrowserReturnURL, value) } func SelfServiceHookSettingsSetDefaultRedirectToStrategy(t *testing.T, conf *config.Config, strategy, value string) { ctx := context.Background() conf.MustSet(ctx, config.ViperKeySelfServiceSettingsAfter+"."+strategy+"."+config.DefaultBrowserReturnURL, value) } func SelfServiceHookLoginSetDefaultRedirectTo(t *testing.T, conf *config.Config, value string) { ctx := context.Background() conf.MustSet(ctx, config.ViperKeySelfServiceLoginAfter+"."+config.DefaultBrowserReturnURL, value) } func SelfServiceHookLoginSetDefaultRedirectToStrategy(t *testing.T, conf *config.Config, strategy, value string) { ctx := context.Background() conf.MustSet(ctx, config.ViperKeySelfServiceLoginAfter+"."+strategy+"."+config.DefaultBrowserReturnURL, value) } func SelfServiceHookRegistrationSetDefaultRedirectTo(t *testing.T, conf *config.Config, value string) { ctx := context.Background() conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationAfter+"."+config.DefaultBrowserReturnURL, value) } func SelfServiceHookRegistrationSetDefaultRedirectToStrategy(t *testing.T, conf *config.Config, strategy, value string) { ctx := context.Background() conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationAfter+"."+strategy+"."+config.DefaultBrowserReturnURL, value) } func SelfServiceHookLoginViperSetPost(t *testing.T, conf *config.Config, strategy string, c []config.SelfServiceHook) { ctx := context.Background() conf.MustSet(ctx, config.HookStrategyKey(config.ViperKeySelfServiceLoginAfter, strategy), c) } func SelfServiceHookRegistrationViperSetPost(t *testing.T, conf *config.Config, strategy string, c []config.SelfServiceHook) { ctx := context.Background() conf.MustSet(ctx, config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter, strategy), c) } func SelfServiceHookLoginErrorHandler(t *testing.T, w http.ResponseWriter, r *http.Request, err error) bool { return SelfServiceHookErrorHandler(t, w, r, login.ErrHookAbortFlow, err) } func SelfServiceHookRegistrationErrorHandler(t *testing.T, w http.ResponseWriter, r *http.Request, err error) bool { return SelfServiceHookErrorHandler(t, w, r, registration.ErrHookAbortFlow, err) } func SelfServiceHookSettingsErrorHandler(t *testing.T, w http.ResponseWriter, r *http.Request, err error) bool { return SelfServiceHookErrorHandler(t, w, r, settings.ErrHookAbortFlow, err) } func SelfServiceHookErrorHandler(t *testing.T, w http.ResponseWriter, _ *http.Request, abortErr error, actualErr error) bool { if actualErr != nil { t.Logf("received error: %+v", actualErr) if errors.Is(actualErr, abortErr) { return false } w.WriteHeader(http.StatusInternalServerError) _, _ = w.Write([]byte(actualErr.Error())) return false } return true } func SelfServiceMakeLoginPreHookRequest(t *testing.T, ts *httptest.Server) (*http.Response, string) { return SelfServiceMakeHookRequest(t, ts, "/login/pre", false, url.Values{}) } func SelfServiceMakeLoginPostHookRequest(t *testing.T, ts *httptest.Server, asAPI bool, query url.Values) (*http.Response, string) { return SelfServiceMakeHookRequest(t, ts, "/login/post", asAPI, query) } func SelfServiceMakeRegistrationPreHookRequest(t *testing.T, ts *httptest.Server) (*http.Response, string) { return SelfServiceMakeHookRequest(t, ts, "/registration/pre", false, url.Values{}) } func SelfServiceMakeSettingsPreHookRequest(t *testing.T, ts *httptest.Server) (*http.Response, string) { return SelfServiceMakeHookRequest(t, ts, "/settings/pre", false, url.Values{}) } func SelfServiceMakeRecoveryPreHookRequest(t *testing.T, ts *httptest.Server) (*http.Response, string) { return SelfServiceMakeHookRequest(t, ts, "/recovery/pre", false, url.Values{}) } func SelfServiceMakeVerificationPreHookRequest(t *testing.T, ts *httptest.Server) (*http.Response, string) { return SelfServiceMakeHookRequest(t, ts, "/verification/pre", false, url.Values{}) } func SelfServiceMakeRegistrationPostHookRequest(t *testing.T, ts *httptest.Server, asAPI bool, query url.Values) (*http.Response, string) { return SelfServiceMakeHookRequest(t, ts, "/registration/post", asAPI, query) } func SelfServiceMakeSettingsPostHookRequest(t *testing.T, ts *httptest.Server, asAPI bool, query url.Values) (*http.Response, string) { return SelfServiceMakeHookRequest(t, ts, "/settings/post", asAPI, query) } func SelfServiceMakeHookRequest(t *testing.T, ts *httptest.Server, suffix string, asAPI bool, query url.Values) (*http.Response, string) { if len(query) > 0 { suffix += "?" + query.Encode() } req, err := http.NewRequest("GET", ts.URL+suffix, nil) require.NoError(t, err) req.Header.Set("Accept", "text/html") if asAPI { req.Header.Set("Accept", "application/json") } res, err := ts.Client().Do(req) require.NoError(t, err) defer func() { _ = res.Body.Close() }() body, err := io.ReadAll(res.Body) require.NoError(t, err) return res, string(body) } func GetSelfServiceRedirectLocation(t *testing.T, url string) string { c := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, } req, err := http.NewRequest("GET", url, nil) require.NoError(t, err) res, err := c.Do(req) require.NoError(t, err) defer func() { _ = res.Body.Close() }() return res.Header.Get("Location") } func AssertMessage(t *testing.T, body []byte, message string) { t.Helper() assert.Len(t, gjson.GetBytes(body, "ui.messages").Array(), 1) assert.Equal(t, message, gjson.GetBytes(body, "ui.messages.0.text").String(), "%v", string(body)) } func AssertFieldMessage(t *testing.T, body []byte, fieldName string, message string) { t.Helper() messages := gjson.GetBytes(body, "ui.nodes.#(attributes.name=="+fieldName+").messages") assert.Len(t, messages.Array(), 1, "expected field %s to have one message, got %s", fieldName, messages) assert.Equal(t, message, messages.Get("0.text").String(), "%v", string(body)) }