feat: use stdlib HTTP router in Kratos

GitOrigin-RevId: 799513e99acbf43a05fe3113ffda45d2fff2a9e0
This commit is contained in:
Henning Perl 2025-07-23 13:12:42 +02:00 committed by ory-bot
parent 9e94951d9e
commit acfa6ef2ec
76 changed files with 468 additions and 512 deletions

View File

@ -7,6 +7,7 @@ import (
"context"
"net/http"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/spf13/cobra"
"github.com/urfave/negroni"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
@ -14,9 +15,9 @@ import (
"github.com/ory/graceful"
"github.com/ory/kratos/driver"
"github.com/ory/kratos/x"
"github.com/ory/x/configx"
"github.com/ory/x/otelx"
"github.com/ory/x/prometheusx"
"github.com/ory/x/reqlog"
)
@ -58,9 +59,9 @@ func ServeMetrics(ctx context.Context, r driver.Registry, port int) error {
l := r.Logger()
n := negroni.New()
router := x.NewRouterAdmin()
router := http.NewServeMux()
r.MetricsHandler().SetRoutes(router.Router)
router.Handle(prometheusx.MetricsPrometheusPath, promhttp.Handler())
n.Use(reqlog.NewMiddlewareFromLogger(l, "admin#"+cfg.BaseURL.String()))
n.Use(r.PrometheusManager())

View File

@ -94,7 +94,6 @@ func servePublic(ctx context.Context, r *driver.RegistryDefault, cmd *cobra.Comm
csrf.DisablePath(prometheus.MetricsPrometheusPath)
r.RegisterPublicRoutes(ctx, router)
r.PrometheusManager().RegisterRouter(router.Router)
var handler http.Handler = n
if tracer := r.Tracer(ctx); tracer.IsLoaded() {
@ -163,7 +162,6 @@ func serveAdmin(ctx context.Context, r *driver.RegistryDefault, cmd *cobra.Comma
router := x.NewRouterAdmin()
r.RegisterAdminRoutes(ctx, router)
r.PrometheusManager().RegisterRouter(router.Router)
n.UseHandler(http.MaxBytesHandler(router, 5*1024*1024 /* 5 MB */))

View File

@ -20,7 +20,6 @@ import (
"github.com/ory/x/ioutilx"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@ -56,22 +55,22 @@ func TestManager(t *testing.T) {
newServer := func(t *testing.T, p continuity.Manager, tc *persisterTestCase) *httptest.Server {
writer := herodot.NewJSONWriter(logrusx.New("", ""))
router := httprouter.New()
router.PUT("/:name", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
if err := p.Pause(r.Context(), w, r, ps.ByName("name"), tc.ro...); err != nil {
router := http.NewServeMux()
router.HandleFunc("PUT /{name}", func(w http.ResponseWriter, r *http.Request) {
if err := p.Pause(r.Context(), w, r, r.PathValue("name"), tc.ro...); err != nil {
writer.WriteError(w, r, err)
return
}
w.WriteHeader(http.StatusNoContent)
})
router.POST("/:name", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
if err := p.Pause(r.Context(), w, r, ps.ByName("name"), tc.ro...); err != nil {
router.HandleFunc("POST /{name}", func(w http.ResponseWriter, r *http.Request) {
if err := p.Pause(r.Context(), w, r, r.PathValue("name"), tc.ro...); err != nil {
writer.WriteError(w, r, err)
return
}
c, err := p.Continue(r.Context(), w, r, ps.ByName("name"), tc.wo...)
c, err := p.Continue(r.Context(), w, r, r.PathValue("name"), tc.wo...)
if err != nil {
writer.WriteError(w, r, err)
return
@ -79,8 +78,8 @@ func TestManager(t *testing.T) {
writer.Write(w, r, c)
})
router.GET("/:name", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
c, err := p.Continue(r.Context(), w, r, ps.ByName("name"), tc.ro...)
router.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
c, err := p.Continue(r.Context(), w, r, r.PathValue("name"), tc.ro...)
if err != nil {
writer.WriteError(w, r, err)
return
@ -88,8 +87,8 @@ func TestManager(t *testing.T) {
writer.Write(w, r, c)
})
router.DELETE("/:name", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
err := p.Abort(r.Context(), w, r, ps.ByName("name"))
router.HandleFunc("DELETE /{name}", func(w http.ResponseWriter, r *http.Request) {
err := p.Abort(r.Context(), w, r, r.PathValue("name"))
if err != nil {
writer.WriteError(w, r, err)
return

View File

@ -16,8 +16,6 @@ import (
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/pagination/migrationpagination"
"github.com/julienschmidt/httprouter"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/x"
)
@ -25,7 +23,7 @@ import (
const (
AdminRouteCourier = "/courier"
AdminRouteListMessages = AdminRouteCourier + "/messages"
AdminRouteGetMessage = AdminRouteCourier + "/messages/:msgID"
AdminRouteGetMessage = AdminRouteCourier + "/messages/{msgID}"
)
type (
@ -113,7 +111,7 @@ type ListCourierMessagesParameters struct {
// 200: listCourierMessages
// 400: errorGeneric
// default: errorGeneric
func (h *Handler) listCourierMessages(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) listCourierMessages(w http.ResponseWriter, r *http.Request) {
filter, paginator, err := parseMessagesFilter(r)
if err != nil {
h.r.Writer().WriteErrorCode(w, r, http.StatusBadRequest, err)
@ -193,10 +191,10 @@ type getCourierMessage struct {
// 200: message
// 400: errorGeneric
// default: errorGeneric
func (h *Handler) getCourierMessage(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
msgID, err := uuid.FromString(ps.ByName("msgID"))
func (h *Handler) getCourierMessage(w http.ResponseWriter, r *http.Request) {
msgID, err := uuid.FromString(r.PathValue("msgID"))
if err != nil {
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithError(err.Error()).WithDebugf("could not parse parameter {id} as UUID, got %s", ps.ByName("id")))
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithError(err.Error()).WithDebugf("could not parse parameter {id} as UUID, got %s", r.PathValue("id")))
return
}

View File

@ -14,8 +14,6 @@ import (
"testing"
"time"
"github.com/julienschmidt/httprouter"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/internal"
@ -145,11 +143,11 @@ func TestLoadTextTemplate(t *testing.T) {
})
t.Run("case=http resource", func(t *testing.T) {
router := httprouter.New()
router.Handle("GET", "/html", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
router := http.NewServeMux()
router.HandleFunc("GET /html", func(writer http.ResponseWriter, request *http.Request) {
http.ServeFile(writer, request, "courier/builtin/templates/test_stub/email.body.html.en_US.gotmpl")
})
router.Handle("GET", "/plaintext", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
router.HandleFunc("GET /plaintext", func(writer http.ResponseWriter, request *http.Request) {
http.ServeFile(writer, request, "courier/builtin/templates/test_stub/email.body.plaintext.gotmpl")
})
ts := httptest.NewServer(router)

View File

@ -14,7 +14,6 @@ import (
"github.com/ory/kratos/courier/template/email"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -94,9 +93,9 @@ func TestRemoteTemplates(t *testing.T, basePath string, tmplType template.Templa
t.Run("case=http resource", func(t *testing.T) {
t.Parallel()
router := httprouter.New()
router.Handle("GET", "/:filename", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
http.ServeFile(writer, request, path.Join(basePath, params.ByName("filename")))
router := http.NewServeMux()
router.HandleFunc("GET /{filename}", func(writer http.ResponseWriter, request *http.Request) {
http.ServeFile(writer, request, path.Join(basePath, request.PathValue("filename")))
})
ts := httptest.NewServer(router)
defer ts.Close()

View File

@ -8,16 +8,15 @@ import (
"fmt"
"net/http"
"github.com/julienschmidt/httprouter"
"github.com/knadh/koanf/parsers/json"
)
type router interface {
GET(path string, handle httprouter.Handle)
HandlerFunc(method, path string, handler http.HandlerFunc)
}
func NewConfigHashHandler(c Provider, router router) {
router.GET("/health/config", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandlerFunc("GET", "/health/config", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
if revision := c.Config().GetProvider(r.Context()).String("revision"); len(revision) > 0 {
_, _ = fmt.Fprintf(w, "%s", revision)

View File

@ -8,12 +8,12 @@ import (
"io"
"testing"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/internal"
"github.com/ory/kratos/x"
"github.com/ory/x/contextx"
)
@ -28,7 +28,7 @@ func (c *configProvider) Config() *config.Config {
func TestNewConfigHashHandler(t *testing.T) {
ctx := context.Background()
cfg := internal.NewConfigurationWithDefaults(t)
router := httprouter.New()
router := x.NewRouterPublic()
config.NewConfigHashHandler(&configProvider{cfg: cfg}, router)
ts := contextx.NewConfigurableTestServer(router)
t.Cleanup(ts.Close)

View File

@ -197,7 +197,7 @@ func (m *RegistryDefault) RegisterPublicRoutes(ctx context.Context, router *x.Ro
m.VerificationHandler().RegisterPublicRoutes(router)
m.AllVerificationStrategies().RegisterPublicRoutes(router)
m.HealthHandler(ctx).SetHealthRoutes(router.Router, false)
m.HealthHandler(ctx).SetHealthRoutes(router, false)
}
func (m *RegistryDefault) RegisterAdminRoutes(ctx context.Context, router *x.RouterAdmin) {
@ -222,7 +222,7 @@ func (m *RegistryDefault) RegisterAdminRoutes(ctx context.Context, router *x.Rou
m.HealthHandler(ctx).SetHealthRoutes(router, true)
m.HealthHandler(ctx).SetVersionRoutes(router)
m.MetricsHandler().SetRoutes(router)
m.MetricsHandler().SetMuxRoutes(router)
config.NewConfigHashHandler(m, router)
}

View File

@ -30,7 +30,6 @@ import (
"github.com/ory/herodot"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"github.com/ory/x/decoderx"
@ -44,8 +43,8 @@ import (
const (
RouteCollection = "/identities"
RouteItem = RouteCollection + "/:id"
RouteCredentialItem = RouteItem + "/credentials/:type"
RouteItem = RouteCollection + "/{id}"
RouteCredentialItem = RouteItem + "/credentials/{type}"
BatchPatchIdentitiesLimit = 1000
BatchPatchIdentitiesWithPasswordLimit = 200
@ -266,7 +265,7 @@ func parseListIdentitiesParameters(r *http.Request) (params ListIdentityParamete
// Responses:
// 200: listIdentities
// default: errorGeneric
func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) list(w http.ResponseWriter, r *http.Request) {
params, err := parseListIdentitiesParameters(r)
if err != nil {
h.r.Writer().WriteError(w, r, err)
@ -355,8 +354,8 @@ type getIdentity struct {
// 200: identity
// 404: errorGeneric
// default: errorGeneric
func (h *Handler) get(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
i, err := h.r.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), x.ParseUUID(ps.ByName("id")))
func (h *Handler) get(w http.ResponseWriter, r *http.Request) {
i, err := h.r.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), x.ParseUUID(r.PathValue("id")))
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
@ -577,7 +576,7 @@ type AdminCreateIdentityImportCredentialsSAMLProvider struct {
// 400: errorGeneric
// 409: errorGeneric
// default: errorGeneric
func (h *Handler) create(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) create(w http.ResponseWriter, r *http.Request) {
var cr CreateIdentityBody
if err := jsonx.NewStrictDecoder(r.Body).Decode(&cr); err != nil {
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithError(err.Error())))
@ -688,7 +687,7 @@ func (h *Handler) identityFromCreateIdentityBody(ctx context.Context, cr *Create
// 400: errorGeneric
// 409: errorGeneric
// default: errorGeneric
func (h *Handler) batchPatchIdentities(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) batchPatchIdentities(w http.ResponseWriter, r *http.Request) {
var (
req BatchPatchIdentitiesBody
res batchPatchIdentitiesResponse
@ -845,7 +844,7 @@ type UpdateIdentityBody struct {
// 404: errorGeneric
// 409: errorGeneric
// default: errorGeneric
func (h *Handler) update(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) update(w http.ResponseWriter, r *http.Request) {
var ur UpdateIdentityBody
if err := h.dx.Decode(r, &ur,
decoderx.HTTPJSONDecoder()); err != nil {
@ -853,7 +852,7 @@ func (h *Handler) update(w http.ResponseWriter, r *http.Request, ps httprouter.P
return
}
id := x.ParseUUID(ps.ByName("id"))
id := x.ParseUUID(r.PathValue("id"))
identity, err := h.r.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), id)
if err != nil {
h.r.Writer().WriteError(w, r, err)
@ -933,8 +932,8 @@ type deleteIdentity struct {
// 204: emptyResponse
// 404: errorGeneric
// default: errorGeneric
func (h *Handler) delete(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
if err := h.r.PrivilegedIdentityPool().DeleteIdentity(r.Context(), x.ParseUUID(ps.ByName("id"))); err != nil {
func (h *Handler) delete(w http.ResponseWriter, r *http.Request) {
if err := h.r.PrivilegedIdentityPool().DeleteIdentity(r.Context(), x.ParseUUID(r.PathValue("id"))); err != nil {
h.r.Writer().WriteError(w, r, err)
return
}
@ -983,14 +982,14 @@ type patchIdentity struct {
// 404: errorGeneric
// 409: errorGeneric
// default: errorGeneric
func (h *Handler) patch(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) patch(w http.ResponseWriter, r *http.Request) {
requestBody, err := io.ReadAll(r.Body)
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
}
id := x.ParseUUID(ps.ByName("id"))
id := x.ParseUUID(r.PathValue("id"))
identity, err := h.r.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), id)
if err != nil {
h.r.Writer().WriteError(w, r, err)
@ -1095,17 +1094,17 @@ type _ struct {
// 204: emptyResponse
// 404: errorGeneric
// default: errorGeneric
func (h *Handler) deleteIdentityCredentials(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) deleteIdentityCredentials(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
identity, err := h.r.PrivilegedIdentityPool().GetIdentityConfidential(ctx, x.ParseUUID(ps.ByName("id")))
identity, err := h.r.PrivilegedIdentityPool().GetIdentityConfidential(ctx, x.ParseUUID(r.PathValue("id")))
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
}
cred, ok := identity.GetCredentials(CredentialsType(ps.ByName("type")))
cred, ok := identity.GetCredentials(CredentialsType(r.PathValue("type")))
if !ok {
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrNotFound.WithReasonf("You tried to remove a %s but this user have no %s set up.", ps.ByName("type"), ps.ByName("type"))))
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrNotFound.WithReasonf("You tried to remove a %s but this user have no %s set up.", r.PathValue("type"), r.PathValue("type"))))
return
}

View File

@ -19,7 +19,6 @@ import (
"github.com/ory/x/httpx"
"github.com/golang/mock/gomock"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/require"
"github.com/ory/kratos/driver/config"
@ -38,10 +37,10 @@ func TestSchemaValidatorDisallowsInternalNetworkRequests(t *testing.T) {
v := NewValidator(reg)
n := negroni.New(x.HTTPLoaderContextMiddleware(reg))
router := httprouter.New()
router.GET("/:id", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
router := http.NewServeMux()
router.HandleFunc("GET /{id}", func(w http.ResponseWriter, r *http.Request) {
i := &Identity{
SchemaID: ps.ByName("id"),
SchemaID: r.PathValue("id"),
Traits: Traits(`{ "firstName": "first-name", "lastName": "last-name", "age": 1 }`),
}
_, _ = w.Write([]byte(fmt.Sprintf("%+v", v.Validate(r.Context(), i))))
@ -75,8 +74,8 @@ func TestSchemaValidator(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
router := httprouter.New()
router.GET("/schema/:name", func(w http.ResponseWriter, _ *http.Request, ps httprouter.Params) {
router := http.NewServeMux()
router.HandleFunc("GET /schema/{name}", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{
"$id": "https://example.com/person.schema.json",
"$schema": "http://json-schema.org/draft-07/schema#",
@ -86,7 +85,7 @@ func TestSchemaValidator(t *testing.T) {
"traits": {
"type": "object",
"properties": {
"` + ps.ByName("name") + `": {
"` + r.PathValue("name") + `": {
"type": "string",
"description": "The person's first name."
},

View File

@ -14,7 +14,6 @@ import (
"github.com/go-faker/faker/v4"
"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -33,8 +32,8 @@ type mockDeps interface {
config.Provider
}
func MockSetSession(t *testing.T, reg mockDeps, conf *config.Config) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func MockSetSession(t *testing.T, reg mockDeps, conf *config.Config) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
i.NID = uuid.Must(uuid.NewV4())
require.NoError(t, i.SetCredentialsWithConfig(
@ -46,12 +45,12 @@ func MockSetSession(t *testing.T, reg mockDeps, conf *config.Config) httprouter.
json.RawMessage(`{"hashed_password":"$"}`)))
require.NoError(t, reg.IdentityManager().Create(context.Background(), i))
MockSetSessionWithIdentity(t, reg, conf, i)(w, r, ps)
MockSetSessionWithIdentity(t, reg, conf, i)(w, r)
}
}
func MockSetSessionWithIdentity(t *testing.T, reg mockDeps, _ *config.Config, i *identity.Identity) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func MockSetSessionWithIdentity(t *testing.T, reg mockDeps, _ *config.Config, i *identity.Identity) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
activeSession, err := NewActiveSession(r, reg, i, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
require.NoError(t, err)
if aal := r.URL.Query().Get("set_aal"); len(aal) > 0 {
@ -63,20 +62,24 @@ func MockSetSessionWithIdentity(t *testing.T, reg mockDeps, _ *config.Config, i
}
}
func MockMakeAuthenticatedRequest(t *testing.T, reg mockDeps, conf *config.Config, router *httprouter.Router, req *http.Request) ([]byte, *http.Response) {
type router interface {
HandleFunc(pattern string, handler http.HandlerFunc)
}
func MockMakeAuthenticatedRequest(t *testing.T, reg mockDeps, conf *config.Config, router router, req *http.Request) ([]byte, *http.Response) {
return MockMakeAuthenticatedRequestWithClient(t, reg, conf, router, req, NewClientWithCookies(t))
}
func MockMakeAuthenticatedRequestWithClient(t *testing.T, reg mockDeps, conf *config.Config, router *httprouter.Router, req *http.Request, client *http.Client) ([]byte, *http.Response) {
func MockMakeAuthenticatedRequestWithClient(t *testing.T, reg mockDeps, conf *config.Config, router router, req *http.Request, client *http.Client) ([]byte, *http.Response) {
return MockMakeAuthenticatedRequestWithClientAndID(t, reg, conf, router, req, client, nil)
}
func MockMakeAuthenticatedRequestWithClientAndID(t *testing.T, reg mockDeps, conf *config.Config, router *httprouter.Router, req *http.Request, client *http.Client, id *identity.Identity) ([]byte, *http.Response) {
func MockMakeAuthenticatedRequestWithClientAndID(t *testing.T, reg mockDeps, conf *config.Config, router router, req *http.Request, client *http.Client, id *identity.Identity) ([]byte, *http.Response) {
set := "/" + uuid.Must(uuid.NewV4()).String() + "/set"
if id == nil {
router.GET(set, MockSetSession(t, reg, conf))
router.HandleFunc("GET "+set, MockSetSession(t, reg, conf))
} else {
router.GET(set, MockSetSessionWithIdentity(t, reg, conf, id))
router.HandleFunc("GET "+set, MockSetSessionWithIdentity(t, reg, conf, id))
}
MockHydrateCookieClient(t, client, "http://"+req.URL.Host+set+"?"+req.URL.Query().Encode())
@ -128,11 +131,11 @@ func MockHydrateCookieClient(t *testing.T, c *http.Client, u string) *http.Cooki
return sessionCookie
}
func MockSessionCreateHandlerWithIdentity(t *testing.T, reg mockDeps, i *identity.Identity) (httprouter.Handle, *session.Session) {
func MockSessionCreateHandlerWithIdentity(t *testing.T, reg mockDeps, i *identity.Identity) (http.HandlerFunc, *session.Session) {
return MockSessionCreateHandlerWithIdentityAndAMR(t, reg, i, []identity.CredentialsType{"password"})
}
func MockSessionCreateHandlerWithIdentityAndAMR(t *testing.T, reg mockDeps, i *identity.Identity, methods []identity.CredentialsType) (httprouter.Handle, *session.Session) {
func MockSessionCreateHandlerWithIdentityAndAMR(t *testing.T, reg mockDeps, i *identity.Identity, methods []identity.CredentialsType) (http.HandlerFunc, *session.Session) {
var sess session.Session
require.NoError(t, faker.FakeData(&sess))
// require AuthenticatedAt to be time.Now() as we always compare it to the current time
@ -159,12 +162,12 @@ func MockSessionCreateHandlerWithIdentityAndAMR(t *testing.T, reg mockDeps, i *i
require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), &sess))
require.Len(t, inserted.Credentials, len(i.Credentials))
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
return func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, reg.SessionManager().IssueCookie(context.Background(), w, r, &sess))
}, &sess
}
func MockSessionCreateHandler(t *testing.T, reg mockDeps) (httprouter.Handle, *session.Session) {
func MockSessionCreateHandler(t *testing.T, reg mockDeps) (http.HandlerFunc, *session.Session) {
return MockSessionCreateHandlerWithIdentity(t, reg, &identity.Identity{
ID: x.NewUUID(), State: identity.StateActive, Traits: identity.Traits(`{"baz":"bar","foo":true,"bar":2.5}`)})
}

View File

@ -13,7 +13,7 @@ import (
"time"
"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@ -111,11 +111,11 @@ func ExpectURL(isAPI bool, api, browser string) string {
}
func NewSettingsUITestServer(t *testing.T, conf *config.Config) *httptest.Server {
router := httprouter.New()
router.GET("/settings", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router := http.NewServeMux()
router.HandleFunc("GET /settings", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
router.GET("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
})
ts := httptest.NewServer(router)
@ -129,14 +129,14 @@ func NewSettingsUITestServer(t *testing.T, conf *config.Config) *httptest.Server
}
func NewSettingsUIEchoServer(t *testing.T, reg *driver.RegistryDefault) *httptest.Server {
router := httprouter.New()
router.GET("/settings", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router := http.NewServeMux()
router.HandleFunc("GET /settings", func(w http.ResponseWriter, r *http.Request) {
res, err := reg.SettingsFlowPersister().GetSettingsFlow(r.Context(), x.ParseUUID(r.URL.Query().Get("flow")))
require.NoError(t, err)
reg.Writer().Write(w, r, res)
})
router.GET("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
})
ts := httptest.NewServer(router)
@ -211,9 +211,9 @@ func AddAndLoginIdentities(t *testing.T, reg *driver.RegistryDefault, public *ht
location := "/sessions/set/" + tid
if router, ok := public.Config.Handler.(*x.RouterPublic); ok {
router.Router.GET(location, route)
} else if router, ok := public.Config.Handler.(*httprouter.Router); ok {
router.GET(location, route)
} else if router, ok := public.Config.Handler.(*http.ServeMux); ok {
router.Handle("GET "+location, route)
} else if router, ok := public.Config.Handler.(*x.RouterAdmin); ok {
router.GET(location, route)
} else {

View File

@ -42,6 +42,15 @@ func (h *Handler) SetRoutes(r router) {
r.GET(MetricsPrometheusPath, h.Metrics)
}
type muxrouter interface {
GET(path string, handle http.HandlerFunc)
}
// SetMuxRoutes registers this handler's routes on a ServeMux.
func (h *Handler) SetMuxRoutes(mux muxrouter) {
mux.GET(MetricsPrometheusPath, promhttp.Handler().ServeHTTP)
}
// Metrics outputs prometheus metrics
//
// swagger:route GET /metrics/prometheus metadata prometheus

View File

@ -5,6 +5,7 @@ package prometheusx
import (
"net/http"
"regexp"
"strings"
"sync"
@ -61,7 +62,16 @@ func (pmm *MetricsManager) RegisterRouter(router *httprouter.Router) {
pmm.routers.data = append(pmm.routers.data, router)
}
var paramPlaceHolderRE = regexp.MustCompile(`\{[a-zA-Z0-9_-]+\}`)
func (pmm *MetricsManager) getLabelForPath(r *http.Request) string {
// If the request came through a http.ServeMux, it already has a pattern that we
// can use as a label. We just need to replace all path parameters with a generic
// placeholder and remove the trailing slash pattern.
if p := r.Pattern; p != "" {
return paramPlaceHolderRE.ReplaceAllString(strings.TrimSuffix(p, "/{$}"), "{param}")
}
// looking for a match in one of registered routers
pmm.routers.Lock()
defer pmm.routers.Unlock()

1
package-lock.json generated
View File

@ -4,7 +4,6 @@
"requires": true,
"packages": {
"": {
"name": "kratos-oss",
"dependencies": {
"@openapitools/openapi-generator-cli": "2.20.0",
"yamljs": "0.3.0"

View File

@ -16,7 +16,6 @@ import (
"github.com/ory/kratos/x/nosurfx"
"github.com/ory/kratos/x/redir"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"github.com/ory/herodot"
@ -55,14 +54,14 @@ func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {
"/"+SchemasPath+"/*",
x.AdminPrefix+"/"+SchemasPath+"/*",
)
public.GET(fmt.Sprintf("/%s/:id", SchemasPath), h.getIdentitySchema)
public.GET(fmt.Sprintf("/%s/{id}", SchemasPath), h.getIdentitySchema)
public.GET(fmt.Sprintf("/%s", SchemasPath), h.getAll)
public.GET(fmt.Sprintf("%s/%s/:id", x.AdminPrefix, SchemasPath), h.getIdentitySchema)
public.GET(fmt.Sprintf("%s/%s/{id}", x.AdminPrefix, SchemasPath), h.getIdentitySchema)
public.GET(fmt.Sprintf("%s/%s", x.AdminPrefix, SchemasPath), h.getAll)
}
func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) {
admin.GET(fmt.Sprintf("/%s/:id", SchemasPath), redir.RedirectToPublicRoute(h.r))
admin.GET(fmt.Sprintf("/%s/{id}", SchemasPath), redir.RedirectToPublicRoute(h.r))
admin.GET(fmt.Sprintf("/%s", SchemasPath), redir.RedirectToPublicRoute(h.r))
}
@ -112,7 +111,7 @@ type getIdentitySchema struct {
// 200: identitySchema
// 404: errorGeneric
// default: errorGeneric
func (h *Handler) getIdentitySchema(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) getIdentitySchema(w http.ResponseWriter, r *http.Request) {
ctx, span := h.r.Tracer(r.Context()).Tracer().Start(r.Context(), "schema.Handler.getIdentitySchema")
defer span.End()
@ -122,7 +121,7 @@ func (h *Handler) getIdentitySchema(w http.ResponseWriter, r *http.Request, ps h
return
}
id := ps.ByName("id")
id := r.PathValue("id")
s, err := ss.GetByID(id)
if err != nil {
// Maybe it is a base64 encoded ID?
@ -203,7 +202,7 @@ type identitySchemasResponse struct {
// Responses:
// 200: identitySchemas
// default: errorGeneric
func (h *Handler) getAll(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) getAll(w http.ResponseWriter, r *http.Request) {
ctx, span := h.r.Tracer(r.Context()).Tracer().Start(r.Context(), "schema.Handler.getAll")
defer span.End()

View File

@ -14,16 +14,15 @@ import (
"github.com/ory/jsonschema/v3/httploader"
"github.com/ory/x/httpx"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/require"
"github.com/ory/x/stringsx"
)
func TestSchemaValidator(t *testing.T) {
router := httprouter.New()
router := http.NewServeMux()
fs := http.StripPrefix("/schema", http.FileServer(http.Dir("stub/validator")))
router.GET("/schema/:name", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("/schema/{name}", func(w http.ResponseWriter, r *http.Request) {
fs.ServeHTTP(w, r)
})
ts := httptest.NewServer(router)

View File

@ -14,8 +14,6 @@ import (
"github.com/ory/kratos/driver/config"
"github.com/julienschmidt/httprouter"
"github.com/ory/herodot"
"github.com/ory/kratos/x"
"github.com/ory/nosurf"
@ -92,7 +90,7 @@ type getFlowError struct {
// 403: errorGeneric
// 404: errorGeneric
// 500: errorGeneric
func (h *Handler) publicFetchError(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) publicFetchError(w http.ResponseWriter, r *http.Request) {
if err := h.fetchError(w, r); err != nil {
h.r.Writer().WriteError(w, r, err)
return

View File

@ -16,7 +16,6 @@ import (
"github.com/ory/x/assertx"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -39,11 +38,11 @@ func TestHandler(t *testing.T) {
ns := nosurfx.NewTestCSRFHandler(router, reg)
h.RegisterPublicRoutes(router)
router.GET("/regen", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /regen", func(w http.ResponseWriter, r *http.Request) {
ns.RegenerateToken(w, r)
w.WriteHeader(http.StatusNoContent)
})
router.GET("/set-error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /set-error", func(w http.ResponseWriter, r *http.Request) {
id, err := reg.SelfServiceErrorPersister().CreateErrorContainer(context.Background(), nosurf.Token(r), herodot.ErrNotFound.WithReason("foobar"))
require.NoError(t, err)
_, _ = w.Write([]byte(id.String()))

View File

@ -18,7 +18,7 @@ import (
"github.com/ory/kratos/ui/node"
"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@ -45,7 +45,7 @@ func TestHandleError(t *testing.T) {
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/password.schema.json")
router := httprouter.New()
router := http.NewServeMux()
ts := httptest.NewServer(router)
t.Cleanup(ts.Close)
@ -58,7 +58,7 @@ func TestHandleError(t *testing.T) {
var loginFlow *login.Flow
var flowError error
var ct node.UiNodeGroup
router.GET("/error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /error", func(w http.ResponseWriter, r *http.Request) {
h.WriteFlowError(w, r, loginFlow, ct, flowError)
})

View File

@ -10,7 +10,6 @@ import (
"time"
"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
@ -378,7 +377,7 @@ type createNativeLoginFlow struct {
// 200: loginFlow
// 400: errorGeneric
// default: errorGeneric
func (h *Handler) createNativeLoginFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) createNativeLoginFlow(w http.ResponseWriter, r *http.Request) {
var err error
ctx, span := h.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.flow.login.createNativeLoginFlow")
r = r.WithContext(ctx)
@ -497,7 +496,7 @@ type createBrowserLoginFlow struct {
// 303: emptyResponse
// 400: errorGeneric
// default: errorGeneric
func (h *Handler) createBrowserLoginFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) createBrowserLoginFlow(w http.ResponseWriter, r *http.Request) {
var err error
ctx, span := h.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.flow.login.createBrowserLoginFlow")
r = r.WithContext(ctx)
@ -655,7 +654,7 @@ type getLoginFlow struct {
// 404: errorGeneric
// 410: errorGeneric
// default: errorGeneric
func (h *Handler) getLoginFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) getLoginFlow(w http.ResponseWriter, r *http.Request) {
var err error
ctx, span := h.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.flow.login.getLoginFlow")
r = r.WithContext(ctx)
@ -797,7 +796,7 @@ type updateLoginFlowBody struct{}
// 410: errorGeneric
// 422: errorBrowserLocationChangeRequired
// default: errorGeneric
func (h *Handler) updateLoginFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) updateLoginFlow(w http.ResponseWriter, r *http.Request) {
var err error
ctx, span := h.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.flow.login.updateLoginFlow")
ctx = semconv.ContextWithAttributes(ctx, attribute.String(events.AttributeKeySelfServiceStrategyUsed.String(), "login"))

View File

@ -16,7 +16,6 @@ import (
"github.com/ory/kratos/x/nosurfx"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"github.com/ory/x/urlx"
@ -88,7 +87,7 @@ func TestFlowLifecycle(t *testing.T) {
}
req := testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+route, nil)
req.URL.RawQuery = extQuery.Encode()
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, req)
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, req)
if isAPI {
assert.Len(t, res.Header.Get("Set-Cookie"), 0)
}
@ -354,7 +353,7 @@ func TestFlowLifecycle(t *testing.T) {
require.NoError(t, err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, req)
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, req)
return string(body), res
}
@ -458,7 +457,7 @@ func TestFlowLifecycle(t *testing.T) {
})
require.NoError(t, reg.IdentityManager().Update(context.Background(), id, identity.ManagerAllowWriteProtectedTraits))
h := func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
h := func(w http.ResponseWriter, r *http.Request) {
sess, err := testhelpers.NewActiveSession(r, reg, id, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
require.NoError(t, err)
sess.AuthenticatorAssuranceLevel = identity.AuthenticatorAssuranceLevel1

View File

@ -16,7 +16,7 @@ import (
"github.com/ory/x/configx"
"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@ -50,9 +50,9 @@ func TestLoginExecutor(t *testing.T) {
_ = testhelpers.NewLoginUIFlowEchoServer(t, reg)
newServer := func(t *testing.T, ft flow.Type, useIdentity *identity.Identity, flowCallback ...func(*login.Flow)) *httptest.Server {
router := httprouter.New()
router := http.NewServeMux()
router.GET("/login/pre", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /login/pre", func(w http.ResponseWriter, r *http.Request) {
loginFlow, err := login.NewFlow(conf, time.Minute, "", r, ft)
require.NoError(t, err)
if testhelpers.SelfServiceHookLoginErrorHandler(t, w, r, reg.LoginHookExecutor().PreLoginHook(w, r, loginFlow)) {
@ -60,7 +60,7 @@ func TestLoginExecutor(t *testing.T) {
}
})
router.GET("/login/post", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /login/post", func(w http.ResponseWriter, r *http.Request) {
loginFlow, err := login.NewFlow(conf, time.Minute, "", r, ft)
require.NoError(t, err)
loginFlow.Active = strategy
@ -79,7 +79,7 @@ func TestLoginExecutor(t *testing.T) {
reg.LoginHookExecutor().PostLoginHook(w, r, strategy.ToUiNodeGroup(), loginFlow, useIdentity, sess, ""))
})
router.GET("/login/post2fa", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /login/post2fa", func(w http.ResponseWriter, r *http.Request) {
loginFlow, err := login.NewFlow(conf, time.Minute, "", r, ft)
require.NoError(t, err)
loginFlow.Active = strategy

View File

@ -22,8 +22,6 @@ import (
"github.com/ory/x/sqlcon"
"github.com/ory/x/urlx"
"github.com/julienschmidt/httprouter"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/selfservice/errorx"
"github.com/ory/kratos/session"
@ -141,7 +139,7 @@ type createBrowserLogoutFlow struct {
// 400: errorGeneric
// 401: errorGeneric
// 500: errorGeneric
func (h *Handler) createBrowserLogoutFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) createBrowserLogoutFlow(w http.ResponseWriter, r *http.Request) {
sess, err := h.d.SessionManager().FetchFromRequest(r.Context(), r)
if err != nil {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
@ -232,7 +230,7 @@ type performNativeLogoutBody struct {
// 204: emptyResponse
// 400: errorGeneric
// default: errorGeneric
func (h *Handler) performNativeLogout(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) performNativeLogout(w http.ResponseWriter, r *http.Request) {
var p performNativeLogoutBody
if err := h.dx.Decode(r, &p,
decoderx.HTTPJSONDecoder(),
@ -323,7 +321,7 @@ type updateLogoutFlow struct {
// 303: emptyResponse
// 204: emptyResponse
// default: errorGeneric
func (h *Handler) updateLogoutFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) updateLogoutFlow(w http.ResponseWriter, r *http.Request) {
expected := r.URL.Query().Get("token")
if len(expected) == 0 {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReason("Please include a token in the URL query.")))

View File

@ -17,7 +17,6 @@ import (
"github.com/ory/kratos/session"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@ -38,8 +37,10 @@ func TestLogout(t *testing.T) {
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json")
public, _, publicRouter, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)
publicRouter.GET("/session/browser/set", testhelpers.MockSetSession(t, reg, conf))
publicRouter.GET("/session/browser/get", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
publicRouter.GET("/session/browser/set", func(writer http.ResponseWriter, request *http.Request) {
testhelpers.MockSetSession(t, reg, conf)(writer, request)
})
publicRouter.HandleFunc("GET /session/browser/get", func(w http.ResponseWriter, r *http.Request) {
sess, err := reg.SessionManager().FetchFromRequest(r.Context(), r)
if err != nil {
reg.Writer().WriteError(w, r, err)
@ -47,7 +48,7 @@ func TestLogout(t *testing.T) {
}
reg.Writer().Write(w, r, sess)
})
publicRouter.POST("/csrf/check", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
publicRouter.HandleFunc("POST /csrf/check", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
conf.MustSet(ctx, config.ViperKeySelfServiceLogoutBrowserDefaultReturnTo, public.URL+"/session/browser/get")

View File

@ -21,7 +21,7 @@ import (
"github.com/ory/kratos/ui/node"
"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@ -49,7 +49,7 @@ func TestHandleError(t *testing.T) {
public, _ := testhelpers.NewKratosServer(t, reg)
router := httprouter.New()
router := http.NewServeMux()
ts := httptest.NewServer(router)
t.Cleanup(ts.Close)
@ -62,7 +62,7 @@ func TestHandleError(t *testing.T) {
var recoveryFlow *recovery.Flow
var flowError error
var methodName node.UiNodeGroup
router.GET("/error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /error", func(w http.ResponseWriter, r *http.Request) {
h.WriteFlowError(w, r, recoveryFlow, methodName, flowError)
})
@ -307,7 +307,7 @@ func TestHandleError_WithContinueWith(t *testing.T) {
public, _ := testhelpers.NewKratosServer(t, reg)
router := httprouter.New()
router := http.NewServeMux()
ts := httptest.NewServer(router)
t.Cleanup(ts.Close)
@ -320,7 +320,7 @@ func TestHandleError_WithContinueWith(t *testing.T) {
var recoveryFlow *recovery.Flow
var flowError error
var methodName node.UiNodeGroup
router.GET("/error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /error", func(w http.ResponseWriter, r *http.Request) {
h.WriteFlowError(w, r, recoveryFlow, methodName, flowError)
})

View File

@ -20,7 +20,6 @@ import (
"github.com/ory/herodot"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"github.com/ory/x/urlx"
@ -73,11 +72,11 @@ func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {
h.d.CSRFHandler().IgnorePath(RouteSubmitFlow)
redirect := session.RedirectOnAuthenticated(h.d)
public.GET(RouteInitBrowserFlow, h.d.SessionHandler().IsNotAuthenticated(h.createBrowserRecoveryFlow, func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
public.GET(RouteInitBrowserFlow, h.d.SessionHandler().IsNotAuthenticated(h.createBrowserRecoveryFlow, func(w http.ResponseWriter, r *http.Request) {
if x.IsJSONRequest(r) {
h.d.Writer().WriteError(w, r, errors.WithStack(ErrAlreadyLoggedIn))
} else {
redirect(w, r, ps)
redirect(w, r)
}
}))
@ -122,7 +121,7 @@ func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) {
// 200: recoveryFlow
// 400: errorGeneric
// default: errorGeneric
func (h *Handler) createNativeRecoveryFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) createNativeRecoveryFlow(w http.ResponseWriter, r *http.Request) {
if !h.d.Config().SelfServiceFlowRecoveryEnabled(r.Context()) {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Recovery is not allowed because it was disabled.")))
return
@ -187,7 +186,7 @@ type createBrowserRecoveryFlow struct {
// 303: emptyResponse
// 400: errorGeneric
// default: errorGeneric
func (h *Handler) createBrowserRecoveryFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) createBrowserRecoveryFlow(w http.ResponseWriter, r *http.Request) {
if !h.d.Config().SelfServiceFlowRecoveryEnabled(r.Context()) {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Recovery is not allowed because it was disabled.")))
return
@ -277,7 +276,7 @@ type getRecoveryFlow struct {
// 404: errorGeneric
// 410: errorGeneric
// default: errorGeneric
func (h *Handler) getRecoveryFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) getRecoveryFlow(w http.ResponseWriter, r *http.Request) {
if !h.d.Config().SelfServiceFlowRecoveryEnabled(r.Context()) {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Recovery is not allowed because it was disabled.")))
return
@ -402,7 +401,7 @@ type updateRecoveryFlowBody struct{}
// 410: errorGeneric
// 422: errorBrowserLocationChangeRequired
// default: errorGeneric
func (h *Handler) updateRecoveryFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) updateRecoveryFlow(w http.ResponseWriter, r *http.Request) {
rid, err := flow.GetFlowID(r)
if err != nil {
h.d.RecoveryFlowErrorHandler().WriteFlowError(w, r, nil, node.DefaultGroup, err)

View File

@ -50,13 +50,13 @@ func TestHandlerRedirectOnAuthenticated(t *testing.T) {
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json")
t.Run("does redirect to default on authenticated request", func(t *testing.T) {
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+recovery.RouteInitBrowserFlow, nil))
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+recovery.RouteInitBrowserFlow, nil))
assert.Contains(t, res.Request.URL.String(), redirTS.URL, "%+v", res)
assert.EqualValues(t, "already authenticated", string(body))
})
t.Run("does redirect to default on authenticated request", func(t *testing.T) {
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+recovery.RouteInitAPIFlow, nil))
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+recovery.RouteInitAPIFlow, nil))
assert.Contains(t, res.Request.URL.String(), recovery.RouteInitAPIFlow)
assert.EqualValues(t, text.ErrIDAlreadyLoggedIn, gjson.GetBytes(body, "error.id").Str)
assertx.EqualAsJSON(t, recovery.ErrAlreadyLoggedIn, json.RawMessage(gjson.GetBytes(body, "error").Raw))
@ -96,7 +96,7 @@ func TestInitFlow(t *testing.T) {
if isSPA {
req.Header.Set("Accept", "application/json")
}
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, req)
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, req)
if isAPI {
assert.Len(t, res.Header.Get("Set-Cookie"), 0)
}

View File

@ -16,7 +16,7 @@ import (
"github.com/ory/kratos/selfservice/strategy/code"
"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -35,8 +35,8 @@ func TestRecoveryExecutor(t *testing.T) {
s := code.NewStrategy(reg)
newServer := func(t *testing.T, i *identity.Identity, ft flow.Type) *httptest.Server {
router := httprouter.New()
router.GET("/recovery/pre", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router := http.NewServeMux()
router.HandleFunc("GET /recovery/pre", func(w http.ResponseWriter, r *http.Request) {
a, err := recovery.NewFlow(conf, time.Minute, nosurfx.FakeCSRFToken, r, s, ft)
require.NoError(t, err)
if testhelpers.SelfServiceHookErrorHandler(t, w, r, recovery.ErrHookAbortFlow, reg.RecoveryExecutor().PreRecoveryHook(w, r, a)) {
@ -44,7 +44,7 @@ func TestRecoveryExecutor(t *testing.T) {
}
})
router.GET("/recovery/post", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /recovery/post", func(w http.ResponseWriter, r *http.Request) {
a, err := recovery.NewFlow(conf, time.Minute, nosurfx.FakeCSRFToken, r, s, ft)
require.NoError(t, err)
s, err := testhelpers.NewActiveSession(r,

View File

@ -20,7 +20,7 @@ import (
"github.com/ory/kratos/ui/node"
"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@ -48,7 +48,7 @@ func TestHandleError(t *testing.T) {
public, _ := testhelpers.NewKratosServer(t, reg)
router := httprouter.New()
router := http.NewServeMux()
ts := httptest.NewServer(router)
t.Cleanup(ts.Close)
@ -61,7 +61,7 @@ func TestHandleError(t *testing.T) {
var registrationFlow *registration.Flow
var flowError error
var group node.UiNodeGroup
router.GET("/error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /error", func(w http.ResponseWriter, r *http.Request) {
h.WriteFlowError(w, r, registrationFlow, group, flowError)
})

View File

@ -11,7 +11,6 @@ import (
"github.com/ory/kratos/x/nosurfx"
"github.com/ory/kratos/x/redir"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
@ -88,13 +87,13 @@ func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {
public.GET(RouteSubmitFlow, h.d.SessionHandler().IsNotAuthenticated(h.updateRegistrationFlow, h.onAuthenticated))
}
func (h *Handler) onAuthenticated(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) onAuthenticated(w http.ResponseWriter, r *http.Request) {
handler := session.RedirectOnAuthenticated(h.d)
if x.IsJSONRequest(r) {
handler = session.RespondWithJSONErrorOnAuthenticated(h.d.Writer(), ErrAlreadyLoggedIn)
}
handler(w, r, ps)
handler(w, r)
}
func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) {
@ -224,7 +223,7 @@ func (h *Handler) FromOldFlow(w http.ResponseWriter, r *http.Request, of Flow) (
// 200: registrationFlow
// 400: errorGeneric
// default: errorGeneric
func (h *Handler) createNativeRegistrationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) createNativeRegistrationFlow(w http.ResponseWriter, r *http.Request) {
a, err := h.NewRegistrationFlow(w, r, flow.TypeAPI)
if err != nil {
h.d.Writer().WriteError(w, r, err)
@ -336,7 +335,7 @@ type createBrowserRegistrationFlow struct {
// 200: registrationFlow
// 303: emptyResponse
// default: errorGeneric
func (h *Handler) createBrowserRegistrationFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) createBrowserRegistrationFlow(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
a, err := h.NewRegistrationFlow(w, r, flow.TypeBrowser)
@ -501,7 +500,7 @@ type getRegistrationFlow struct {
// 404: errorGeneric
// 410: errorGeneric
// default: errorGeneric
func (h *Handler) getRegistrationFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) getRegistrationFlow(w http.ResponseWriter, r *http.Request) {
if !h.d.Config().SelfServiceFlowRegistrationEnabled(r.Context()) {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(ErrRegistrationDisabled))
return
@ -638,7 +637,7 @@ type updateRegistrationFlowBody struct{}
// 410: errorGeneric
// 422: errorBrowserLocationChangeRequired
// default: errorGeneric
func (h *Handler) updateRegistrationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) updateRegistrationFlow(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = semconv.ContextWithAttributes(ctx, attribute.String(events.AttributeKeySelfServiceStrategyUsed.String(), "registration"))
r = r.WithContext(ctx)

View File

@ -65,19 +65,19 @@ func TestHandlerRedirectOnAuthenticated(t *testing.T) {
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json")
t.Run("does redirect to default on authenticated request", func(t *testing.T) {
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow, nil))
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow, nil))
assert.Contains(t, res.Request.URL.String(), redirTS.URL)
assert.EqualValues(t, "already authenticated", string(body))
})
t.Run("does redirect to default on authenticated request", func(t *testing.T) {
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitAPIFlow, nil))
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitAPIFlow, nil))
assert.Contains(t, res.Request.URL.String(), registration.RouteInitAPIFlow)
assertx.EqualAsJSON(t, registration.ErrAlreadyLoggedIn, json.RawMessage(gjson.GetBytes(body, "error").Raw))
})
t.Run("does redirect to return_to url on authenticated request", func(t *testing.T) {
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow+"?return_to="+returnToTS.URL, nil))
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow+"?return_to="+returnToTS.URL, nil))
assert.Contains(t, res.Request.URL.String(), returnToTS.URL)
assert.EqualValues(t, "return_to", string(body))
})
@ -92,7 +92,7 @@ func TestHandlerRedirectOnAuthenticated(t *testing.T) {
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
_, res := testhelpers.MockMakeAuthenticatedRequestWithClient(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow+"?login_challenge="+hydra.FakeValidLoginChallenge, nil), client)
_, res := testhelpers.MockMakeAuthenticatedRequestWithClient(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow+"?login_challenge="+hydra.FakeValidLoginChallenge, nil), client)
assert.Contains(t, res.Header.Get("location"), login.RouteInitBrowserFlow)
})
@ -106,7 +106,7 @@ func TestHandlerRedirectOnAuthenticated(t *testing.T) {
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
_, res := testhelpers.MockMakeAuthenticatedRequestWithClient(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow+"?login_challenge="+hydra.FakeValidLoginChallenge, nil), client)
_, res := testhelpers.MockMakeAuthenticatedRequestWithClient(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow+"?login_challenge="+hydra.FakeValidLoginChallenge, nil), client)
assert.Contains(t, res.Header.Get("location"), hydra.FakePostLoginURL)
})
}
@ -142,7 +142,7 @@ func TestInitFlow(t *testing.T) {
if isSPA {
req.Header.Set("Accept", "application/json")
}
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, req)
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, req)
if isAPI {
assert.Len(t, res.Header.Get("Set-Cookie"), 0)
}

View File

@ -15,7 +15,7 @@ import (
"github.com/gobuffalo/httptest"
"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@ -47,10 +47,10 @@ func TestRegistrationExecutor(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySelfServiceBrowserDefaultReturnTo, returnToServer.URL)
newServer := func(t *testing.T, i *identity.Identity, ft flow.Type, flowCallbacks ...func(*registration.Flow)) *httptest.Server {
router := httprouter.New()
router := http.NewServeMux()
handleErr := testhelpers.SelfServiceHookRegistrationErrorHandler
router.GET("/registration/pre", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /registration/pre", func(w http.ResponseWriter, r *http.Request) {
f, err := registration.NewFlow(conf, time.Minute, nosurfx.FakeCSRFToken, r, ft)
require.NoError(t, err)
if handleErr(t, w, r, reg.RegistrationHookExecutor().PreRegistrationHook(w, r, f)) {
@ -58,7 +58,7 @@ func TestRegistrationExecutor(t *testing.T) {
}
})
router.GET("/registration/post", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /registration/post", func(w http.ResponseWriter, r *http.Request) {
if i == nil {
i = testhelpers.SelfServiceHookFakeIdentity(t)
}

View File

@ -16,7 +16,7 @@ import (
"github.com/go-faker/faker/v4"
"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@ -45,7 +45,7 @@ func TestHandleError(t *testing.T) {
public, _ := testhelpers.NewKratosServer(t, reg)
router := httprouter.New()
router := http.NewServeMux()
ts := httptest.NewServer(router)
t.Cleanup(ts.Close)
@ -65,11 +65,11 @@ func TestHandleError(t *testing.T) {
id.State = identity.StateActive
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &id))
router.GET("/error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /error", func(w http.ResponseWriter, r *http.Request) {
h.WriteFlowError(ctx, w, r, flowMethod, settingsFlow, &id, flowError)
})
router.GET("/fake-redirect", func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
router.HandleFunc("GET /fake-redirect", func(w http.ResponseWriter, r *http.Request) {
reg.LoginHandler().NewLoginFlow(w, r, flow.TypeBrowser)
})

View File

@ -14,7 +14,6 @@ import (
"github.com/ory/x/otelx"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"github.com/ory/herodot"
@ -96,7 +95,7 @@ func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {
h.d.CSRFHandler().IgnorePath(RouteInitAPIFlow)
h.d.CSRFHandler().IgnorePath(RouteSubmitFlow)
public.GET(RouteInitBrowserFlow, h.d.SessionHandler().IsAuthenticated(h.createBrowserSettingsFlow, func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
public.GET(RouteInitBrowserFlow, h.d.SessionHandler().IsAuthenticated(h.createBrowserSettingsFlow, func(w http.ResponseWriter, r *http.Request) {
if x.IsJSONRequest(r) {
h.d.Writer().WriteError(w, r, session.NewErrNoActiveSessionFound())
} else {
@ -222,7 +221,7 @@ type createNativeSettingsFlow struct {
// 200: settingsFlow
// 400: errorGeneric
// default: errorGeneric
func (h *Handler) createNativeSettingsFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) createNativeSettingsFlow(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
s, err := h.d.SessionManager().FetchFromRequestContext(ctx, r)
if err != nil {
@ -306,7 +305,7 @@ type createBrowserSettingsFlow struct {
// 401: errorGeneric
// 403: errorGeneric
// default: errorGeneric
func (h *Handler) createBrowserSettingsFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) createBrowserSettingsFlow(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
s, err := h.d.SessionManager().FetchFromRequestContext(ctx, r)
if err != nil {
@ -405,7 +404,7 @@ type getSettingsFlow struct {
// 404: errorGeneric
// 410: errorGeneric
// default: errorGeneric
func (h *Handler) getSettingsFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) getSettingsFlow(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
rid := x.ParseUUID(r.URL.Query().Get("id"))
pr, err := h.d.SettingsFlowPersister().GetSettingsFlow(ctx, rid)
@ -567,7 +566,7 @@ type updateSettingsFlowBody struct{}
// 410: errorGeneric
// 422: errorBrowserLocationChangeRequired
// default: errorGeneric
func (h *Handler) updateSettingsFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) updateSettingsFlow(w http.ResponseWriter, r *http.Request) {
var (
err error
ctx = r.Context()

View File

@ -13,7 +13,7 @@ import (
"github.com/tidwall/gjson"
"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -47,9 +47,9 @@ func TestSettingsExecutor(t *testing.T) {
newServer := func(t *testing.T, i *identity.Identity, ft flow.Type) *httptest.Server {
t.Helper()
router := httprouter.New()
router := http.NewServeMux()
handleErr := testhelpers.SelfServiceHookSettingsErrorHandler
router.GET("/settings/pre", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /settings/pre", func(w http.ResponseWriter, r *http.Request) {
if i == nil {
i = testhelpers.SelfServiceHookCreateFakeIdentity(t, reg)
}
@ -62,7 +62,7 @@ func TestSettingsExecutor(t *testing.T) {
}
})
router.GET("/settings/post", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /settings/post", func(w http.ResponseWriter, r *http.Request) {
if i == nil {
i = testhelpers.SelfServiceHookCreateFakeIdentity(t, reg)
}

View File

@ -9,7 +9,6 @@ import (
"time"
"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"github.com/ory/herodot"
@ -102,13 +101,13 @@ func GetFlowID(r *http.Request) (uuid.UUID, error) {
func OnUnauthenticated(reg interface {
config.Provider
x.WriterProvider
}) func(http.ResponseWriter, *http.Request, httprouter.Params) {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
}) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
handler := session.RedirectOnUnauthenticated(reg.Config().SelfServiceFlowLoginUI(r.Context()).String())
if x.IsJSONRequest(r) {
handler = session.RespondWithJSONErrorOnAuthenticated(reg.Writer(), herodot.ErrUnauthorized.WithReasonf("A valid Ory Session Cookie or Ory Session Token is missing."))
}
handler(w, r, ps)
handler(w, r)
}
}

View File

@ -19,7 +19,7 @@ import (
"github.com/ory/kratos/ui/node"
"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@ -46,7 +46,7 @@ func TestHandleError(t *testing.T) {
public, _ := testhelpers.NewKratosServer(t, reg)
router := httprouter.New()
router := http.NewServeMux()
ts := httptest.NewServer(router)
t.Cleanup(ts.Close)
@ -59,7 +59,7 @@ func TestHandleError(t *testing.T) {
var verificationFlow *verification.Flow
var flowError error
var methodName node.UiNodeGroup
router.GET("/error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /error", func(w http.ResponseWriter, r *http.Request) {
h.WriteFlowError(w, r, verificationFlow, methodName, flowError)
})

View File

@ -20,7 +20,6 @@ import (
"github.com/ory/herodot"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"github.com/ory/x/urlx"
@ -162,7 +161,7 @@ type createNativeVerificationFlow struct {
// 200: verificationFlow
// 400: errorGeneric
// default: errorGeneric
func (h *Handler) createNativeVerificationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) createNativeVerificationFlow(w http.ResponseWriter, r *http.Request) {
if !h.d.Config().SelfServiceFlowVerificationEnabled(r.Context()) {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Verification is not allowed because it was disabled.")))
return
@ -209,7 +208,7 @@ type createBrowserVerificationFlow struct {
// 200: verificationFlow
// 303: emptyResponse
// default: errorGeneric
func (h *Handler) createBrowserVerificationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) createBrowserVerificationFlow(w http.ResponseWriter, r *http.Request) {
if !h.d.Config().SelfServiceFlowVerificationEnabled(r.Context()) {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Verification is not allowed because it was disabled.")))
return
@ -284,7 +283,7 @@ type getVerificationFlow struct {
// 403: errorGeneric
// 404: errorGeneric
// default: errorGeneric
func (h *Handler) getVerificationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) getVerificationFlow(w http.ResponseWriter, r *http.Request) {
if !h.d.Config().SelfServiceFlowVerificationEnabled(r.Context()) {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Verification is not allowed because it was disabled.")))
return
@ -408,7 +407,7 @@ type updateVerificationFlowBody struct{}
// 400: verificationFlow
// 410: errorGeneric
// default: errorGeneric
func (h *Handler) updateVerificationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) updateVerificationFlow(w http.ResponseWriter, r *http.Request) {
rid, err := flow.GetFlowID(r)
if err != nil {
h.d.VerificationFlowErrorHandler().WriteFlowError(w, r, nil, node.DefaultGroup, err)

View File

@ -15,7 +15,7 @@ import (
"github.com/ory/kratos/selfservice/flow/verification"
"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -32,8 +32,8 @@ func TestVerificationExecutor(t *testing.T) {
conf, reg := internal.NewFastRegistryWithMocks(t)
newServer := func(t *testing.T, i *identity.Identity, ft flow.Type) *httptest.Server {
router := httprouter.New()
router.GET("/verification/pre", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router := http.NewServeMux()
router.HandleFunc("GET /verification/pre", func(w http.ResponseWriter, r *http.Request) {
strategy, err := reg.GetActiveVerificationStrategy(r.Context())
require.NoError(t, err)
a, err := verification.NewFlow(conf, time.Minute, nosurfx.FakeCSRFToken, r, strategy, ft)
@ -43,7 +43,7 @@ func TestVerificationExecutor(t *testing.T) {
}
})
router.GET("/verification/post", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /verification/post", func(w http.ResponseWriter, r *http.Request) {
strategy, err := reg.GetActiveVerificationStrategy(r.Context())
require.NoError(t, err)
a, err := verification.NewFlow(conf, time.Minute, nosurfx.FakeCSRFToken, r, strategy, ft)

View File

@ -20,7 +20,6 @@ import (
"testing"
"time"
"github.com/julienschmidt/httprouter"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -103,8 +102,8 @@ func TestWebHooks(t *testing.T) {
Method string
}
webHookEndPoint := func(whr *WebHookRequest) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
webHookEndPoint := func(whr *WebHookRequest) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
@ -116,14 +115,14 @@ func TestWebHooks(t *testing.T) {
}
}
webHookHttpCodeEndPoint := func(code int) httprouter.Handle {
return func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
webHookHttpCodeEndPoint := func(code int) http.HandlerFunc {
return func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(code)
}
}
webHookHttpCodeWithBodyEndPoint := func(t *testing.T, code int, body []byte) httprouter.Handle {
return func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
webHookHttpCodeWithBodyEndPoint := func(t *testing.T, code int, body []byte) http.HandlerFunc {
return func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(code)
_, err := w.Write(body)
assert.NoError(t, err, "error while returning response from webHookHttpCodeWithBodyEndPoint")
@ -131,17 +130,17 @@ func TestWebHooks(t *testing.T) {
}
path := "/web_hook"
newServer := func(f httprouter.Handle) *httptest.Server {
r := httprouter.New()
newServer := func(f http.HandlerFunc) *httptest.Server {
r := http.NewServeMux()
r.Handle("CONNECT", path, f)
r.DELETE(path, f)
r.GET(path, f)
r.OPTIONS(path, f)
r.PATCH(path, f)
r.POST(path, f)
r.PUT(path, f)
r.Handle("TRACE", path, f)
r.HandleFunc("CONNECT "+path, f)
r.HandleFunc("DELETE "+path, f)
r.HandleFunc("GET "+path, f)
r.HandleFunc("OPTIONS "+path, f)
r.HandleFunc("PATCH "+path, f)
r.HandleFunc("POST "+path, f)
r.HandleFunc("PUT "+path, f)
r.HandleFunc("TRACE "+path, f)
ts := httptest.NewServer(r)
t.Cleanup(ts.Close)
@ -917,7 +916,7 @@ func TestWebHooks(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
waitTime := time.Millisecond * 100
ts := newServer(func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
ts := newServer(func(w http.ResponseWriter, _ *http.Request) {
defer wg.Done()
time.Sleep(waitTime)
w.WriteHeader(http.StatusBadRequest)
@ -956,7 +955,7 @@ func TestWebHooks(t *testing.T) {
var wg sync.WaitGroup
wg.Add(3) // HTTP client does 3 attempts
ts := newServer(func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
ts := newServer(func(w http.ResponseWriter, r *http.Request) {
defer wg.Done()
w.WriteHeader(500)
_, _ = w.Write([]byte(`{"error":"some error"}`))

View File

@ -145,9 +145,9 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F
if _, err := s.deps.SessionManager().FetchFromRequest(ctx, r); err == nil {
// User is already logged in
if x.IsJSONRequest(r) {
session.RespondWithJSONErrorOnAuthenticated(s.deps.Writer(), recovery.ErrAlreadyLoggedIn)(w, r, nil)
session.RespondWithJSONErrorOnAuthenticated(s.deps.Writer(), recovery.ErrAlreadyLoggedIn)(w, r)
} else {
session.RedirectOnAuthenticated(s.deps)(w, r, nil)
session.RedirectOnAuthenticated(s.deps)(w, r)
}
return errors.WithStack(flow.ErrCompletedByStrategy)
}

View File

@ -10,7 +10,6 @@ import (
"time"
"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/trace"
@ -138,7 +137,7 @@ type recoveryCodeForIdentity struct {
// 400: errorGeneric
// 404: errorGeneric
// default: errorGeneric
func (s *Strategy) createRecoveryCodeForIdentity(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (s *Strategy) createRecoveryCodeForIdentity(w http.ResponseWriter, r *http.Request) {
var p createRecoveryCodeForIdentityBody
if err := s.dx.Decode(r, &p, decoderx.HTTPJSONDecoder()); err != nil {
s.deps.Writer().WriteError(w, r, err)

View File

@ -6,8 +6,6 @@ package strategy
import (
"net/http"
"github.com/julienschmidt/httprouter"
"github.com/ory/herodot"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/x"
@ -20,32 +18,16 @@ type disabledChecker interface {
x.WriterProvider
}
func disabledWriter(c disabledChecker, enabled bool, wrap httprouter.Handle, w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func disabledWriter(c disabledChecker, enabled bool, wrap http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if !enabled {
c.Writer().WriteError(w, r, herodot.ErrNotFound.WithReason(EndpointDisabledMessage))
return
}
wrap(w, r, ps)
wrap(w, r)
}
func IsDisabled(c disabledChecker, strategy string, wrap httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
disabledWriter(c, c.Config().SelfServiceStrategy(r.Context(), strategy).Enabled, wrap, w, r, ps)
}
}
func IsRecoveryDisabled(c disabledChecker, strategy string, wrap httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
disabledWriter(c,
c.Config().SelfServiceStrategy(r.Context(), strategy).Enabled && c.Config().SelfServiceFlowRecoveryEnabled(r.Context()),
wrap, w, r, ps)
}
}
func IsVerificationDisabled(c disabledChecker, strategy string, wrap httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
disabledWriter(c,
c.Config().SelfServiceStrategy(r.Context(), strategy).Enabled && c.Config().SelfServiceFlowVerificationEnabled(r.Context()),
wrap, w, r, ps)
func IsDisabled(c disabledChecker, strategy string, wrap http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
disabledWriter(c, c.Config().SelfServiceStrategy(r.Context(), strategy).Enabled, wrap, w, r)
}
}

View File

@ -127,7 +127,7 @@ func TestCompleteLogin(t *testing.T) {
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
actual, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, req)
actual, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, req)
assert.Contains(t, res.Request.URL.String(), publicTS.URL+login.RouteSubmitFlow)
assert.Equal(t, text.NewErrorValidationLoginNoStrategyFound().Text, gjson.GetBytes(actual, "ui.messages.0.text").String())
})

View File

@ -13,7 +13,6 @@ import (
"github.com/ory/kratos/x/redir"
"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
@ -148,7 +147,7 @@ type recoveryLinkForIdentity struct {
// 400: errorGeneric
// 404: errorGeneric
// default: errorGeneric
func (s *Strategy) createRecoveryLinkForIdentity(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (s *Strategy) createRecoveryLinkForIdentity(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var p createRecoveryLinkForIdentityBody
@ -275,9 +274,9 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F
if _, err := s.d.SessionManager().FetchFromRequest(r.Context(), r); err == nil {
if x.IsJSONRequest(r) {
session.RespondWithJSONErrorOnAuthenticated(s.d.Writer(), recovery.ErrAlreadyLoggedIn)(w, r, nil)
session.RespondWithJSONErrorOnAuthenticated(s.d.Writer(), recovery.ErrAlreadyLoggedIn)(w, r)
} else {
session.RedirectOnAuthenticated(s.d)(w, r, nil)
session.RedirectOnAuthenticated(s.d)(w, r)
}
return errors.WithStack(flow.ErrCompletedByStrategy)
}

View File

@ -680,7 +680,7 @@ func TestRecovery(t *testing.T) {
}
check(t, expectSuccess(t, nil, false, false, values), email, testhelpers.NewClientWithCookies(t), func(cl *http.Client, req *http.Request) (*http.Response, error) {
_, res := testhelpers.MockMakeAuthenticatedRequestWithClient(t, reg, conf, publicRouter.Router, req, cl)
_, res := testhelpers.MockMakeAuthenticatedRequestWithClient(t, reg, conf, publicRouter, req, cl)
return res, nil
})
})
@ -692,7 +692,7 @@ func TestRecovery(t *testing.T) {
cl := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, ctx, reg, id)
check(t, expectSuccess(t, nil, false, false, values), email, cl, func(_ *http.Client, req *http.Request) (*http.Response, error) {
_, res := testhelpers.MockMakeAuthenticatedRequestWithClientAndID(t, reg, conf, publicRouter.Router, req, cl, id)
_, res := testhelpers.MockMakeAuthenticatedRequestWithClientAndID(t, reg, conf, publicRouter, req, cl, id)
return res, nil
})
})

View File

@ -146,12 +146,12 @@ func (p Configuration) Redir(public *url.URL) string {
if p.OrganizationID != "" {
route := RouteOrganizationCallback
route = strings.Replace(route, ":provider", p.ID, 1)
route = strings.Replace(route, ":organization", p.OrganizationID, 1)
route = strings.Replace(route, "{provider}", p.ID, 1)
route = strings.Replace(route, "{organization}", p.OrganizationID, 1)
return urlx.AppendPaths(public, route).String()
}
return urlx.AppendPaths(public, strings.Replace(RouteCallback, ":provider", p.ID, 1)).String()
return urlx.AppendPaths(public, strings.Replace(RouteCallback, "{provider}", p.ID, 1)).String()
}
type ConfigurationCollection struct {

View File

@ -5,6 +5,7 @@ package oidc
import (
"bytes"
"cmp"
"context"
"encoding/json"
"maps"
@ -20,7 +21,6 @@ import (
"github.com/ory/kratos/x/redir"
"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"github.com/tidwall/gjson"
"go.opentelemetry.io/otel/attribute"
@ -57,10 +57,10 @@ import (
const (
RouteBase = "/self-service/methods/oidc"
RouteAuth = RouteBase + "/auth/:flow"
RouteCallback = RouteBase + "/callback/:provider"
RouteAuth = RouteBase + "/auth/{flow}"
RouteCallback = RouteBase + "/callback/{provider}"
RouteCallbackGeneric = RouteBase + "/callback"
RouteOrganizationCallback = RouteBase + "/organization/:organization/callback/:provider"
RouteOrganizationCallback = RouteBase + "/organization/{organization}/callback/{provider}"
)
var _ identity.ActiveCredentialsCounter = new(Strategy)
@ -198,15 +198,15 @@ func (s *Strategy) CountActiveMultiFactorCredentials(_ context.Context, _ map[id
func (s *Strategy) setRoutes(r *x.RouterPublic) {
wrappedHandleCallback := strategy.IsDisabled(s.d, s.ID().String(), s.HandleCallback)
if handle, _, _ := r.Lookup("GET", RouteCallback); handle == nil {
if !r.HasRoute("GET", RouteCallback) {
r.GET(RouteCallback, wrappedHandleCallback)
}
if handle, _, _ := r.Lookup("GET", RouteCallbackGeneric); handle == nil {
if !r.HasRoute("GET", RouteCallbackGeneric) {
r.GET(RouteCallbackGeneric, wrappedHandleCallback)
}
// Apple can use the POST request method when calling the callback
if handle, _, _ := r.Lookup("POST", RouteCallback); handle == nil {
if !r.HasRoute("POST", RouteCallback) {
// Apple is the only (known) provider that sometimes does a form POST to the callback URL.
// This is a workaround to handle this case.
// But since the URL contains the `id` of the provider, we just allow all OIDC provider callbacks to bypass CSRF.
@ -221,7 +221,7 @@ func (s *Strategy) setRoutes(r *x.RouterPublic) {
}
// Redirect POST request to GET rewriting form fields to query params.
func (s *Strategy) redirectToGET(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (s *Strategy) redirectToGET(w http.ResponseWriter, r *http.Request) {
publicUrl := s.d.Config().SelfPublicURL(r.Context())
dest := *r.URL
dest.Host = publicUrl.Host
@ -330,7 +330,7 @@ func (s *Strategy) validateFlow(ctx context.Context, r *http.Request, rid uuid.U
return ar, err // this must return the error
}
func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request, ps httprouter.Params) (flow.Flow, *oidcv1.State, *AuthCodeContainer, error) {
func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request) (flow.Flow, *oidcv1.State, *AuthCodeContainer, error) {
var (
codeParam = stringsx.Coalesce(r.URL.Query().Get("code"), r.URL.Query().Get("authCode"))
stateParam = r.URL.Query().Get("state")
@ -345,7 +345,7 @@ func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request, ps h
return nil, nil, nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`Unable to complete OpenID Connect flow because the state parameter is invalid.`))
}
if providerFromURL := ps.ByName("provider"); providerFromURL != "" {
if providerFromURL := r.PathValue("provider"); providerFromURL != "" {
// We're serving an OIDC callback URL with provider in the URL.
if state.ProviderId == "" {
// provider in URL, but not in state: compatiblity mode, remove this fallback later
@ -442,18 +442,18 @@ func (s *Strategy) alreadyAuthenticated(ctx context.Context, w http.ResponseWrit
return false, nil
}
func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request) {
var (
code = stringsx.Coalesce(r.URL.Query().Get("code"), r.URL.Query().Get("authCode"))
code = cmp.Or(r.URL.Query().Get("code"), r.URL.Query().Get("authCode"))
err error
)
ctx := context.WithValue(r.Context(), httprouter.ParamsKey, ps)
ctx := r.Context()
ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "strategy.oidc.HandleCallback")
defer otelx.End(span, &err)
r = r.WithContext(ctx)
req, state, cntnr, err := s.ValidateCallback(w, r, ps)
req, state, cntnr, err := s.ValidateCallback(w, r)
if err != nil {
if req != nil {
s.forwardError(ctx, w, r, req, s.HandleError(ctx, w, r, req, state.ProviderId, nil, err))

View File

@ -20,7 +20,7 @@ import (
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/julienschmidt/httprouter"
"github.com/phayes/freeport"
"github.com/pkg/errors"
"github.com/rakutentech/jwk-go/jwk"
@ -135,7 +135,7 @@ func createClient(t *testing.T, remote string, redir []string) (id, secret strin
}
func newHydraIntegration(t *testing.T, remote *string, subject *string, claims *idTokenClaims, scope *[]string, addr string) (*http.Server, string) {
router := httprouter.New()
router := http.NewServeMux()
type p struct {
Subject string `json:"subject,omitempty"`
@ -164,7 +164,7 @@ func newHydraIntegration(t *testing.T, remote *string, subject *string, claims *
http.Redirect(w, r, response.RedirectTo, http.StatusSeeOther)
}
router.GET("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
require.NotEmpty(t, *remote)
require.NotEmpty(t, *subject)
@ -179,7 +179,7 @@ func newHydraIntegration(t *testing.T, remote *string, subject *string, claims *
do(w, r, href, &b)
})
router.GET("/consent", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /consent", func(w http.ResponseWriter, r *http.Request) {
require.NotEmpty(t, *remote)
require.NotNil(t, *scope)

View File

@ -402,9 +402,9 @@ func TestSettingsStrategy(t *testing.T) {
t.Run("case=should not be able to link a connection already linked by another identity", func(t *testing.T) {
// While this theoretically allows for account enumeration - because we see an error indicator if an
// oidc connection is being linked that exists already - it would require the attacker to already
// OIDC connection is being linked that exists already - it would require the attacker to already
// have control over the social profile, in which case account enumeration is the least of our worries.
// Instead of using the oidc profile for enumeration, the attacker would use it for account takeover.
// Instead of using the OIDC profile for enumeration, the attacker would use it for account takeover.
// This is the multiuser login id for google
subject = "hackerman+multiuser+" + testID

View File

@ -397,7 +397,7 @@ func TestStrategy(t *testing.T) {
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
actual, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, routerP.Router, req)
actual, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, routerP, req)
assert.Contains(t, res.Request.URL.String(), ts.URL+login.RouteSubmitFlow)
assert.Equal(t, text.NewErrorValidationLoginNoStrategyFound().Text, gjson.GetBytes(actual, "ui.messages.0.text").String())
})
@ -1677,6 +1677,7 @@ func TestStrategy(t *testing.T) {
subject = email2
t.Run("step=should fail login if existing identity identifier doesn't match", func(t *testing.T) {
require.NotNil(t, linkingLoginFlow.ID)
require.NotEmpty(t, linkingLoginFlow.ID)
res, body := loginWithOIDC(t, client, uuid.Must(uuid.FromString(linkingLoginFlow.ID)), "valid")
assertUIError(t, res, body, "Linked credentials do not match.")
})

View File

@ -155,7 +155,7 @@ func TestCompleteLogin(t *testing.T) {
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
actual, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, req)
actual, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, req)
assert.Contains(t, res.Request.URL.String(), publicTS.URL+login.RouteSubmitFlow)
assert.Equal(t, text.NewErrorValidationLoginNoStrategyFound().Text, gjson.GetBytes(actual, "ui.messages.0.text").String())
})

View File

@ -15,7 +15,6 @@ import (
"github.com/ory/kratos/x/nosurfx"
"github.com/julienschmidt/httprouter"
"github.com/tidwall/gjson"
"github.com/urfave/negroni"
"golang.org/x/oauth2"
@ -53,7 +52,7 @@ func TestOAuth2Provider(t *testing.T) {
errTS := testhelpers.NewErrorTestServer(t, reg)
redirTS := testhelpers.NewRedirSessionEchoTS(t, reg)
router.GET("/login-ts", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /login-ts", func(w http.ResponseWriter, r *http.Request) {
t.Log("[loginTS] navigated to the login ui")
c := r.Context().Value(TestUIConfig).(*testConfig)
*c.callTrace = append(*c.callTrace, LoginUI)
@ -115,7 +114,7 @@ func TestOAuth2Provider(t *testing.T) {
}
})
router.GET("/consent", func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
router.HandleFunc("GET /consent", func(w http.ResponseWriter, r *http.Request) {
t.Log("[consentTS] navigated to the consent ui")
c := r.Context().Value(TestUIConfig).(*testConfig)
*c.callTrace = append(*c.callTrace, Consent)

View File

@ -15,7 +15,6 @@ import (
"golang.org/x/oauth2"
"github.com/julienschmidt/httprouter"
"github.com/urfave/negroni"
hydraclientgo "github.com/ory/hydra-client-go/v2"
@ -50,7 +49,7 @@ func TestOAuth2ProviderRegistration(t *testing.T) {
TestOAuthClientState contextKey = "test-oauth-client-state"
)
router.GET("/login-ts", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /login-ts", func(w http.ResponseWriter, r *http.Request) {
t.Log("[loginTS] navigated to the login ui")
c := r.Context().Value(TestUIConfig).(*testConfig)
*c.callTrace = append(*c.callTrace, LoginUI)
@ -76,7 +75,7 @@ func TestOAuth2ProviderRegistration(t *testing.T) {
}
})
router.GET("/registration-ts", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /registration-ts", func(w http.ResponseWriter, r *http.Request) {
t.Log("[registrationTS] navigated to the registration ui")
c := r.Context().Value(TestUIConfig).(*testConfig)
*c.callTrace = append(*c.callTrace, RegistrationUI)
@ -144,7 +143,7 @@ func TestOAuth2ProviderRegistration(t *testing.T) {
}
})
router.GET("/consent", func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
router.HandleFunc("GET /consent", func(w http.ResponseWriter, r *http.Request) {
t.Log("[consentTS] navigated to the consent ui")
c := r.Context().Value(TestUIConfig).(*testConfig)
*c.callTrace = append(*c.callTrace, Consent)

View File

@ -34,7 +34,6 @@ import (
"github.com/ory/kratos/corpx"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@ -512,8 +511,8 @@ func TestStrategyTraits(t *testing.T) {
setPrivileged(t)
var returned bool
router := httprouter.New()
router.GET("/return-ts", func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
router := http.NewServeMux()
router.HandleFunc("GET /return-ts", func(w http.ResponseWriter, r *http.Request) {
returned = true
})
rts := httptest.NewServer(router)

View File

@ -6,13 +6,11 @@ package session
import (
"net/http"
"github.com/julienschmidt/httprouter"
"github.com/ory/herodot"
)
func RespondWithJSONErrorOnAuthenticated(h herodot.Writer, err error) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func RespondWithJSONErrorOnAuthenticated(h herodot.Writer, err error) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h.WriteError(w, r, err)
}
}

View File

@ -21,7 +21,6 @@ import (
"github.com/ory/x/pointerx"
"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
"github.com/ory/x/decoderx"
@ -66,12 +65,12 @@ const (
RouteCollection = "/sessions"
RouteExchangeCodeForSessionToken = RouteCollection + "/token-exchange" // #nosec G101
RouteWhoami = RouteCollection + "/whoami"
RouteSession = RouteCollection + "/:id"
RouteSession = RouteCollection + "/{id}"
)
const (
AdminRouteIdentity = "/identities"
AdminRouteIdentitiesSessions = AdminRouteIdentity + "/:id/sessions"
AdminRouteIdentitiesSessions = AdminRouteIdentity + "/{id}/sessions"
AdminRouteSessionExtendId = RouteSession + "/extend"
)
@ -212,7 +211,7 @@ type toSession struct {
// 401: errorGeneric
// 403: errorGeneric
// default: errorGeneric
func (h *Handler) whoami(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) whoami(w http.ResponseWriter, r *http.Request) {
ctx, span := h.r.Tracer(r.Context()).Tracer().Start(r.Context(), "sessions.Handler.whoami")
defer span.End()
@ -307,8 +306,8 @@ type deleteIdentitySessions struct {
// 401: errorGeneric
// 404: errorGeneric
// default: errorGeneric
func (h *Handler) deleteIdentitySessions(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
iID, err := uuid.FromString(ps.ByName("id"))
func (h *Handler) deleteIdentitySessions(w http.ResponseWriter, r *http.Request) {
iID, err := uuid.FromString(r.PathValue("id"))
if err != nil {
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID"))
return
@ -386,7 +385,7 @@ type listSessionsResponse struct {
// 200: listSessions
// 400: errorGeneric
// default: errorGeneric
func (h *Handler) adminListSessions(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) adminListSessions(w http.ResponseWriter, r *http.Request) {
activeRaw := r.URL.Query().Get("active")
activeBool, err := strconv.ParseBool(activeRaw)
if activeRaw != "" && err != nil {
@ -470,14 +469,14 @@ type getSession struct {
// 200: session
// 400: errorGeneric
// default: errorGeneric
func (h *Handler) getSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
if ps.ByName("id") == "whoami" {
func (h *Handler) getSession(w http.ResponseWriter, r *http.Request) {
if r.PathValue("id") == "whoami" {
// for /admin/sessions/whoami redirect to the public route
redir.RedirectToPublicRoute(h.r)(w, r, ps)
redir.RedirectToPublicRoute(h.r)(w, r)
return
}
sID, err := uuid.FromString(ps.ByName("id"))
sID, err := uuid.FromString(r.PathValue("id"))
if err != nil {
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID"))
return
@ -539,8 +538,8 @@ type disableSession struct {
// 400: errorGeneric
// 401: errorGeneric
// default: errorGeneric
func (h *Handler) disableSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
sID, err := uuid.FromString(ps.ByName("id"))
func (h *Handler) disableSession(w http.ResponseWriter, r *http.Request) {
sID, err := uuid.FromString(r.PathValue("id"))
if err != nil {
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID"))
return
@ -605,8 +604,8 @@ type listIdentitySessionsResponse struct {
// 400: errorGeneric
// 404: errorGeneric
// default: errorGeneric
func (h *Handler) listIdentitySessions(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
iID, err := uuid.FromString(ps.ByName("id"))
func (h *Handler) listIdentitySessions(w http.ResponseWriter, r *http.Request) {
iID, err := uuid.FromString(r.PathValue("id"))
if err != nil {
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID")))
return
@ -679,7 +678,7 @@ type disableMyOtherSessions struct {
// 400: errorGeneric
// 401: errorGeneric
// default: errorGeneric
func (h *Handler) deleteMySessions(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) deleteMySessions(w http.ResponseWriter, r *http.Request) {
s, err := h.r.SessionManager().FetchFromRequest(r.Context(), r)
if err != nil {
h.r.Audit().WithRequest(r).WithError(err).Info("No valid session cookie found.")
@ -738,11 +737,11 @@ type disableMySession struct {
// 400: errorGeneric
// 401: errorGeneric
// default: errorGeneric
func (h *Handler) deleteMySession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
sid := ps.ByName("id")
func (h *Handler) deleteMySession(w http.ResponseWriter, r *http.Request) {
sid := r.PathValue("id")
if sid == "whoami" {
// Special case where we actually want to handle the whoami endpoint.
h.whoami(w, r, ps)
h.whoami(w, r)
return
}
@ -822,7 +821,7 @@ type listMySessionsResponse struct {
// 400: errorGeneric
// 401: errorGeneric
// default: errorGeneric
func (h *Handler) listMySessions(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) listMySessions(w http.ResponseWriter, r *http.Request) {
s, err := h.r.SessionManager().FetchFromRequest(r.Context(), r)
if err != nil {
h.r.Audit().WithRequest(r).WithError(err).Info("No valid session cookie found.")
@ -860,13 +859,13 @@ const (
sessionInContextKey sessionInContext = iota
)
func (h *Handler) IsAuthenticated(wrap httprouter.Handle, onUnauthenticated httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) IsAuthenticated(wrap http.HandlerFunc, onUnauthenticated http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
sess, err := h.r.SessionManager().FetchFromRequest(ctx, r)
if err != nil {
if onUnauthenticated != nil {
onUnauthenticated(w, r, ps)
onUnauthenticated(w, r)
return
}
@ -874,7 +873,7 @@ func (h *Handler) IsAuthenticated(wrap httprouter.Handle, onUnauthenticated http
return
}
wrap(w, r.WithContext(context.WithValue(ctx, sessionInContextKey, sess)), ps)
wrap(w, r.WithContext(context.WithValue(ctx, sessionInContextKey, sess)))
}
}
@ -917,8 +916,8 @@ type extendSession struct {
// 400: errorGeneric
// 404: errorGeneric
// default: errorGeneric
func (h *Handler) adminSessionExtend(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
id, err := uuid.FromString(ps.ByName("id"))
func (h *Handler) adminSessionExtend(w http.ResponseWriter, r *http.Request) {
id, err := uuid.FromString(r.PathValue("id"))
if err != nil {
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID")))
return
@ -945,11 +944,11 @@ func (h *Handler) adminSessionExtend(w http.ResponseWriter, r *http.Request, ps
h.r.Writer().Write(w, r, s)
}
func (h *Handler) IsNotAuthenticated(wrap httprouter.Handle, onAuthenticated httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (h *Handler) IsNotAuthenticated(wrap http.HandlerFunc, onAuthenticated http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if _, err := h.r.SessionManager().FetchFromRequest(r.Context(), r); err != nil {
if e := new(ErrNoActiveSessionFound); errors.As(err, &e) {
wrap(w, r, ps)
wrap(w, r)
return
}
h.r.Writer().WriteError(w, r, err)
@ -957,7 +956,7 @@ func (h *Handler) IsNotAuthenticated(wrap httprouter.Handle, onAuthenticated htt
}
if onAuthenticated != nil {
onAuthenticated(w, r, ps)
onAuthenticated(w, r)
return
}
@ -965,8 +964,8 @@ func (h *Handler) IsNotAuthenticated(wrap httprouter.Handle, onAuthenticated htt
}
}
func RedirectOnAuthenticated(d interface{ config.Provider }) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func RedirectOnAuthenticated(d interface{ config.Provider }) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
returnTo, err := redir.SecureRedirectTo(r, d.Config().SelfServiceBrowserDefaultReturnTo(ctx), redir.SecureRedirectAllowSelfServiceURLs(d.Config().SelfPublicURL(ctx)))
if err != nil {
@ -978,14 +977,14 @@ func RedirectOnAuthenticated(d interface{ config.Provider }) httprouter.Handle {
}
}
func RedirectOnUnauthenticated(to string) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func RedirectOnUnauthenticated(to string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, to, http.StatusFound)
}
}
func RespondWitherrorGenericOnAuthenticated(h herodot.Writer, err error) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func RespondWitherrorGenericOnAuthenticated(h herodot.Writer, err error) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h.WriteError(w, r, err)
}
}
@ -1048,7 +1047,7 @@ type CodeExchangeResponse struct {
// 404: errorGeneric
// 410: errorGeneric
// default: errorGeneric
func (h *Handler) exchangeCode(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func (h *Handler) exchangeCode(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
initCode = r.URL.Query().Get("init_code")

View File

@ -33,7 +33,6 @@ import (
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/sqlcon"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -50,8 +49,8 @@ func init() {
corpx.RegisterFakes()
}
func send(code int) httprouter.Handle {
return func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
func send(code int) http.HandlerFunc {
return func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(code)
}
}
@ -530,7 +529,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
}{
{
description: "expand Identity",
expand: "/?expand=Identity",
expand: "?expand=Identity",
expectedIdentityId: s.Identity.ID.String(),
expectedDevices: 0,
},
@ -859,9 +858,9 @@ func TestHandlerSelfServiceSessionManagement(t *testing.T) {
// we limit the scope of the channels, so you cannot accidentally mess up a test case
ident := make(chan *identity.Identity, 1)
sess := make(chan *Session, 1)
r.GET("/set", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
r.GET("/set", func(w http.ResponseWriter, r *http.Request) {
h, s := testhelpers.MockSessionCreateHandlerWithIdentity(t, reg, <-ident)
h(w, r, ps)
h(w, r)
sess <- s
})

View File

@ -13,7 +13,6 @@ import (
"testing"
"time"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -222,17 +221,17 @@ func TestManagerHTTP(t *testing.T) {
var s *session.Session
rp := x.NewRouterPublic()
rp.GET("/session/revoke", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
rp.GET("/session/revoke", func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, reg.SessionManager().PurgeFromRequest(r.Context(), w, r))
w.WriteHeader(http.StatusOK)
})
rp.GET("/session/set", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
rp.GET("/session/set", func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, reg.SessionManager().UpsertAndIssueCookie(r.Context(), w, r, s))
w.WriteHeader(http.StatusOK)
})
rp.GET("/session/get", func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
rp.GET("/session/get", func(w http.ResponseWriter, r *http.Request) {
sess, err := reg.SessionManager().FetchFromRequest(r.Context(), r)
if err != nil {
t.Logf("Got error on lookup: %s %T", err, errors.Unwrap(err))
@ -242,7 +241,7 @@ func TestManagerHTTP(t *testing.T) {
reg.Writer().Write(w, r, sess)
})
rp.GET("/session/get-middleware", reg.SessionHandler().IsAuthenticated(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
rp.GET("/session/get-middleware", reg.SessionHandler().IsAuthenticated(func(w http.ResponseWriter, r *http.Request) {
sess, err := reg.SessionManager().FetchFromRequestContext(r.Context(), r)
if err != nil {
t.Logf("Got error on lookup: %s %T", err, errors.Unwrap(err))
@ -311,7 +310,7 @@ func TestManagerHTTP(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySessionName, "")
})
rp.GET("/session/set/invalid", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
rp.GET("/session/set/invalid", func(w http.ResponseWriter, r *http.Request) {
require.Error(t, reg.SessionManager().UpsertAndIssueCookie(r.Context(), w, r, s))
w.WriteHeader(http.StatusInternalServerError)
})

View File

@ -5,7 +5,6 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/julienschmidt/httprouter v1.3.0
github.com/ory/hydra-client-go v1.7.4
github.com/ory/kratos-client-go v0.10.1
github.com/ory/x v0.0.722-0.20250620091013-eeb8bd14b65a

View File

@ -5,10 +5,9 @@ package main
import (
"fmt"
"log"
"net/http"
"github.com/julienschmidt/httprouter"
"github.com/ory/hydra-client-go/client"
"github.com/ory/hydra-client-go/client/admin"
"github.com/ory/hydra-client-go/models"
@ -18,12 +17,6 @@ import (
"github.com/ory/x/urlx"
)
func check(err error) {
if err != nil {
panic(err)
}
}
func checkReq(w http.ResponseWriter, err error) bool {
if err != nil {
http.Error(w, fmt.Sprintf("%+v", err), 500)
@ -33,16 +26,14 @@ func checkReq(w http.ResponseWriter, err error) bool {
}
func main() {
router := httprouter.New()
kratosPublicURL := urlx.ParseOrPanic(osx.GetenvDefault("KRATOS_PUBLIC_URL", "http://localhost:4433"))
adminURL := urlx.ParseOrPanic(osx.GetenvDefault("HYDRA_ADMIN_URL", "http://localhost:4445"))
hc := client.NewHTTPClientWithConfig(nil, &client.TransportConfig{Schemes: []string{adminURL.Scheme}, Host: adminURL.Host, BasePath: adminURL.Path})
router.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
http.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`ok`))
})
router.GET("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
http.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
res, err := hc.Admin.GetLoginRequest(admin.NewGetLoginRequestParams().
WithLoginChallenge(r.URL.Query().Get("login_challenge")))
if !checkReq(w, err) {
@ -73,7 +64,7 @@ func main() {
</html>`, challenge)
})
router.POST("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
http.HandleFunc("POST /login", func(w http.ResponseWriter, r *http.Request) {
if !checkReq(w, r.ParseForm()) {
return
}
@ -98,7 +89,7 @@ func main() {
http.Redirect(w, r, *res.Payload.RedirectTo, http.StatusFound)
})
router.GET("/consent", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
http.HandleFunc("GET /consent", func(w http.ResponseWriter, r *http.Request) {
res, err := hc.Admin.GetConsentRequest(admin.NewGetConsentRequestParams().
WithConsentChallenge(r.URL.Query().Get("consent_challenge")))
if !checkReq(w, err) {
@ -136,7 +127,7 @@ func main() {
</html>`, challenge, checkoxes)
})
router.POST("/consent", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
http.HandleFunc("POST /consent", func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
if r.Form.Get("action") == "accept" {
kratosConfig := kratos.NewConfiguration()
@ -180,7 +171,6 @@ func main() {
})
addr := ":" + osx.GetenvDefault("PORT", "4746")
server := &http.Server{Addr: addr, Handler: router}
fmt.Printf("Starting web server at %s\n", addr)
check(server.ListenAndServe())
log.Fatal(http.ListenAndServe(addr, nil))
}

View File

@ -7,8 +7,6 @@ import (
"fmt"
"net/http"
"github.com/julienschmidt/httprouter"
client "github.com/ory/hydra-client-go/v2"
"github.com/ory/x/osx"
@ -31,7 +29,7 @@ func checkReq(w http.ResponseWriter, err error) bool {
}
func main() {
router := httprouter.New()
router := http.NewServeMux()
adminURL := urlx.ParseOrPanic(osx.GetenvDefault("HYDRA_ADMIN_URL", "http://localhost:4445"))
cfg := client.NewConfiguration()
@ -40,10 +38,10 @@ func main() {
}
hc := client.NewAPIClient(cfg)
router.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`ok`))
})
router.GET("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
res, _, err := hc.OAuth2Api.GetOAuth2LoginRequest(r.Context()).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute()
if !checkReq(w, err) {
return
@ -77,7 +75,7 @@ func main() {
</html>`, challenge)
})
router.POST("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("POST /login", func(w http.ResponseWriter, r *http.Request) {
check(r.ParseForm())
remember := pointerx.Bool(r.Form.Get("remember") == "true")
if r.Form.Get("action") == "accept" {
@ -103,7 +101,7 @@ func main() {
http.Redirect(w, r, res.RedirectTo, http.StatusFound)
})
router.GET("/consent", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /consent", func(w http.ResponseWriter, r *http.Request) {
res, _, err := hc.OAuth2Api.GetOAuth2ConsentRequest(r.Context()).ConsentChallenge(r.URL.Query().
Get("consent_challenge")).Execute()
if !checkReq(w, err) {
@ -144,7 +142,7 @@ func main() {
</html>`, challenge, checkoxes)
})
router.POST("/consent", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("POST /consent", func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
remember := pointerx.Bool(r.Form.Get("remember") == "true")
if r.Form.Get("action") == "accept" {

View File

@ -1,13 +1,3 @@
module github.com/ory/mock
go 1.24.4
require (
github.com/julienschmidt/httprouter v1.3.0
github.com/ory/graceful v0.1.3
)
require (
github.com/pkg/errors v0.9.1 // indirect
github.com/stretchr/testify v1.7.0 // indirect
)

View File

@ -1,18 +0,0 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/ory/graceful v0.1.3 h1:FaeXcHZh168WzS+bqruqWEw/HgXWLdNv2nJ+fbhxbhc=
github.com/ory/graceful v0.1.3/go.mod h1:4zFz687IAF7oNHHiB586U4iL+/4aV09o/PYLE34t2bA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -11,10 +11,6 @@ import (
"net/http"
"os"
"sync"
"github.com/julienschmidt/httprouter"
"github.com/ory/graceful"
)
var (
@ -24,22 +20,16 @@ var (
func main() {
port := cmp.Or(os.Getenv("PORT"), "4471")
server := graceful.WithDefaults(&http.Server{Addr: fmt.Sprintf(":%s", port)})
register(server)
if err := graceful.Graceful(server.ListenAndServe, server.Shutdown); err != nil {
log.Fatalln(err)
}
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%s", port), nil))
}
func register(server *http.Server) {
router := httprouter.New()
router.GET("/health", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func init() {
http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("OK"))
})
router.GET("/documents/:id", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
id := ps.ByName("id")
http.HandleFunc("GET /documents/{id}", func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
documentsLock.RLock()
doc, ok := documents[id]
@ -52,10 +42,10 @@ func register(server *http.Server) {
}
})
router.PUT("/documents/:id", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
http.HandleFunc("PUT /documents/{id}", func(w http.ResponseWriter, r *http.Request) {
documentsLock.Lock()
defer documentsLock.Unlock()
id := ps.ByName("id")
id := r.PathValue("id")
body, err := io.ReadAll(r.Body)
if err != nil {
@ -66,14 +56,12 @@ func register(server *http.Server) {
documents[id] = body
})
router.DELETE("/documents/:id", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
http.HandleFunc("DELETE /documents/{id}", func(w http.ResponseWriter, r *http.Request) {
documentsLock.Lock()
defer documentsLock.Unlock()
id := ps.ByName("id")
id := r.PathValue("id")
delete(documents, id)
w.WriteHeader(http.StatusNoContent)
})
server.Handler = router
}

View File

@ -11,7 +11,7 @@ import (
"testing"
"github.com/gorilla/sessions"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -30,8 +30,8 @@ func TestSession(t *testing.T) {
assert.EqualValues(t, 78652871, cookie.Options.MaxAge, "we ensure the options are always copied correctly.")
}
router := httprouter.New()
router.GET("/set", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router := http.NewServeMux()
router.HandleFunc("GET /set", func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, SessionPersistValues(w, r, s, sid, map[string]interface{}{
"string-1": "foo",
"string-2": "bar",
@ -55,7 +55,7 @@ func TestSession(t *testing.T) {
t.Run("case=GetString", func(t *testing.T) {
id := "get-string"
router.GET("/"+id, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /"+id, func(w http.ResponseWriter, r *http.Request) {
got, err := SessionGetString(r, s, sid, "string-1")
require.NoError(t, err)
assert.EqualValues(t, "foo", got)
@ -81,7 +81,7 @@ func TestSession(t *testing.T) {
t.Run("case=GetStringMultipleCookies", func(t *testing.T) {
id := "get-string-multiple"
router.GET("/set/"+id, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /set/"+id, func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, SessionPersistValues(w, r, s, sid, map[string]interface{}{
"multiple-string-1": "foo",
}))
@ -92,7 +92,7 @@ func TestSession(t *testing.T) {
w.WriteHeader(http.StatusNoContent)
})
router.GET("/get/"+id, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /get/"+id, func(w http.ResponseWriter, r *http.Request) {
got, err := SessionGetString(r, s, sid, "multiple-string-1")
require.NoError(t, err)
assert.EqualValues(t, "foo", got)
@ -122,7 +122,7 @@ func TestSession(t *testing.T) {
t.Run("case=GetStringOr", func(t *testing.T) {
id := "get-string-or"
router.GET("/"+id, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /"+id, func(w http.ResponseWriter, r *http.Request) {
assert.EqualValues(t, "foo", SessionGetStringOr(r, s, sid, "string-1", "baz"))
assert.EqualValues(t, "bar", SessionGetStringOr(r, s, sid, "string-2", "baz"))
assert.EqualValues(t, "", SessionGetStringOr(r, s, sid, "string-3", "baz"))
@ -189,14 +189,14 @@ func TestSession(t *testing.T) {
})
id := "session-unset"
router.GET("/"+id+"/unset", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /"+id+"/unset", func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, SessionUnset(w, r, s, sid))
w.WriteHeader(http.StatusNoContent)
cookie, _ := s.Get(r, sid)
assert.EqualValues(t, -1, cookie.Options.MaxAge, "we ensure the options are always copied correctly.")
})
router.GET("/"+id+"/get", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /"+id+"/get", func(w http.ResponseWriter, r *http.Request) {
require.Empty(t, SessionGetStringOr(r, s, sid, "string-1", ""))
require.Empty(t, SessionGetStringOr(r, s, sid, "string-2", ""))
require.Empty(t, SessionGetStringOr(r, s, sid, "string-3", ""))
@ -231,13 +231,13 @@ func TestSession(t *testing.T) {
})
id := "session-unset-key"
router.GET("/"+id+"/unset", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /"+id+"/unset", func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, SessionUnsetKey(w, r, s, sid, "string-1"))
w.WriteHeader(http.StatusNoContent)
isExpiryCorrect(t, r)
})
router.GET("/"+id+"/expect-unset", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /"+id+"/expect-unset", func(w http.ResponseWriter, r *http.Request) {
require.Empty(t, SessionGetStringOr(r, s, sid, "string-1", ""))
require.Empty(t, SessionGetStringOr(r, s, sid, "string-2", ""))
require.Empty(t, SessionGetStringOr(r, s, sid, "string-3", ""))
@ -246,7 +246,7 @@ func TestSession(t *testing.T) {
w.WriteHeader(http.StatusNoContent)
})
router.GET("/"+id+"/expect-one", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /"+id+"/expect-one", func(w http.ResponseWriter, r *http.Request) {
require.Empty(t, SessionGetStringOr(r, s, sid, "string-1", ""))
assert.EqualValues(t, "bar", SessionGetStringOr(r, s, sid, "string-2", "baz"))
assert.EqualValues(t, "", SessionGetStringOr(r, s, sid, "string-3", "baz"))

View File

@ -10,15 +10,14 @@ import (
"strings"
"testing"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
)
func TestRedirectAdmin(t *testing.T) {
router := httprouter.New()
router.GET("/admin/identities", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router := http.NewServeMux()
router.HandleFunc("GET /admin/identities", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("identities"))
})
n := negroni.New()

View File

@ -5,8 +5,6 @@ package x
import (
"net/http"
"github.com/julienschmidt/httprouter"
)
// NoCache adds `Cache-Control: private, no-cache, no-store, must-revalidate` to the response header.
@ -14,14 +12,6 @@ func NoCache(w http.ResponseWriter) {
w.Header().Set("Cache-Control", "private, no-cache, no-store, must-revalidate")
}
// NoCacheHandle wraps httprouter.Handle with `Cache-Control: private, no-cache, no-store, must-revalidate` headers.
func NoCacheHandle(handle httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
NoCache(w)
handle(w, r, ps)
}
}
// NoCacheHandlerFunc wraps http.HandlerFunc with `Cache-Control: private, no-cache, no-store, must-revalidate` headers.
func NoCacheHandlerFunc(handle http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {

View File

@ -8,14 +8,12 @@ import (
"path"
"strings"
"github.com/julienschmidt/httprouter"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/x"
)
func RedirectToAdminRoute(reg config.Provider) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func RedirectToAdminRoute(reg config.Provider) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
admin := reg.Config().SelfAdminURL(r.Context())
dest := *r.URL
@ -28,8 +26,8 @@ func RedirectToAdminRoute(reg config.Provider) httprouter.Handle {
}
}
func RedirectToPublicRoute(reg config.Provider) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
func RedirectToPublicRoute(reg config.Provider) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
public := reg.Config().SelfPublicURL(r.Context())
dest := *r.URL

View File

@ -15,7 +15,6 @@ import (
"github.com/ory/x/configx"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -38,13 +37,13 @@ func TestRedirectToPublicAdminRoute(t *testing.T) {
pub.POST("/privileged", redir.RedirectToAdminRoute(reg))
pub.POST("/admin/privileged", redir.RedirectToAdminRoute(reg))
adm.POST("/privileged", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
adm.POST("/privileged", func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
_, _ = w.Write(body)
})
adm.POST("/read", redir.RedirectToPublicRoute(reg))
pub.POST("/read", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
pub.POST("/read", func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
_, _ = w.Write(body)
})

View File

@ -14,7 +14,6 @@ import (
"github.com/ory/kratos/x/redir"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -31,14 +30,14 @@ func TestSecureContentNegotiationRedirection(t *testing.T) {
var jsonActual = json.RawMessage(`{"foo":"bar"}` + "\n")
writer := herodot.NewJSONWriter(nil)
router := httprouter.New()
router.GET("/redir", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router := http.NewServeMux()
router.HandleFunc("GET /redir", func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, redir.SecureContentNegotiationRedirection(w, r, jsonActual, x.RequestURL(r).String(), writer, conf))
})
router.GET("/default-return-to", func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
router.HandleFunc("GET /default-return-to", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
router.GET("/return-to", func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
router.HandleFunc("GET /return-to", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})

View File

@ -5,105 +5,161 @@ package x
import (
"net/http"
"net/http/httptest"
"path"
"github.com/julienschmidt/httprouter"
)
type RouterPublic struct {
*httprouter.Router
mux *http.ServeMux
}
func NewRouterPublic() *RouterPublic {
return &RouterPublic{
Router: httprouter.New(),
mux: http.NewServeMux(),
}
}
func (r *RouterPublic) GET(path string, handle httprouter.Handle) {
r.Handle("GET", path, NoCacheHandle(handle))
func (r *RouterPublic) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.mux.ServeHTTP(w, req)
}
func (r *RouterPublic) HEAD(path string, handle httprouter.Handle) {
r.Handle("HEAD", path, NoCacheHandle(handle))
func (r *RouterPublic) GET(path string, handler http.HandlerFunc) {
r.HandlerFunc("GET", path, handler)
}
func (r *RouterPublic) POST(path string, handle httprouter.Handle) {
r.Handle("POST", path, NoCacheHandle(handle))
func (r *RouterPublic) HEAD(path string, handler http.HandlerFunc) {
r.HandlerFunc("HEAD", path, handler)
}
func (r *RouterPublic) PUT(path string, handle httprouter.Handle) {
r.Handle("PUT", path, NoCacheHandle(handle))
func (r *RouterPublic) POST(path string, handler http.HandlerFunc) {
r.HandlerFunc("POST", path, handler)
}
func (r *RouterPublic) PATCH(path string, handle httprouter.Handle) {
r.Handle("PATCH", path, NoCacheHandle(handle))
func (r *RouterPublic) PUT(path string, handler http.HandlerFunc) {
r.HandlerFunc("PUT", path, handler)
}
func (r *RouterPublic) DELETE(path string, handle httprouter.Handle) {
r.Handle("DELETE", path, NoCacheHandle(handle))
func (r *RouterPublic) PATCH(path string, handler http.HandlerFunc) {
r.HandlerFunc("PATCH", path, handler)
}
func (r *RouterPublic) Handle(method, path string, handle httprouter.Handle) {
r.Router.Handle(method, path, NoCacheHandle(handle))
func (r *RouterPublic) DELETE(path string, handler http.HandlerFunc) {
r.HandlerFunc("DELETE", path, handler)
}
func (r *RouterPublic) HandlerFunc(method, path string, handler http.HandlerFunc) {
r.Router.HandlerFunc(method, path, NoCacheHandlerFunc(handler))
func (r *RouterPublic) Handle(method, route string, handle http.HandlerFunc) {
for _, pattern := range []string{
method + " " + path.Join(route),
method + " " + path.Join(route, "{$}"),
} {
r.mux.HandleFunc(pattern, func(w http.ResponseWriter, req *http.Request) {
NoCache(w)
handle(w, req)
})
}
}
func (r *RouterPublic) HandlerFunc(method, route string, handler http.HandlerFunc) {
for _, pattern := range []string{
method + " " + path.Join(route),
method + " " + path.Join(route, "{$}"),
} {
r.mux.HandleFunc(pattern, NoCacheHandlerFunc(handler))
}
}
func (r *RouterPublic) HandleFunc(pattern string, handler http.HandlerFunc) {
for _, pattern := range []string{
path.Join(pattern),
path.Join(pattern, "{$}"),
} {
r.mux.HandleFunc(pattern, NoCacheHandlerFunc(handler))
}
}
func (r *RouterPublic) Handler(method, path string, handler http.Handler) {
r.Router.Handler(method, path, NoCacheHandler(handler))
route := method + " " + path
r.mux.Handle(route, NoCacheHandler(handler))
}
type RouterAdmin struct {
*httprouter.Router
func (r *RouterPublic) HasRoute(method, path string) bool {
_, pattern := r.mux.Handler(httptest.NewRequest(method, path, nil))
return pattern != ""
}
type RouterAdmin struct{ mux *http.ServeMux }
func NewRouterAdmin() *RouterAdmin {
return &RouterAdmin{
Router: httprouter.New(),
mux: http.NewServeMux(),
}
}
func (r *RouterAdmin) GET(publicPath string, handle httprouter.Handle) {
r.Router.GET(path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
func (r *RouterAdmin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.mux.ServeHTTP(w, req)
}
func (r *RouterAdmin) HEAD(publicPath string, handle httprouter.Handle) {
r.Router.HEAD(path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
func (r *RouterAdmin) GET(publicPath string, handler http.HandlerFunc) {
r.HandlerFunc("GET", publicPath, handler)
}
func (r *RouterAdmin) POST(publicPath string, handle httprouter.Handle) {
r.Router.POST(path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
func (r *RouterAdmin) HEAD(publicPath string, handler http.HandlerFunc) {
r.HandlerFunc("HEAD", publicPath, handler)
}
func (r *RouterAdmin) PUT(publicPath string, handle httprouter.Handle) {
r.Router.PUT(path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
func (r *RouterAdmin) POST(publicPath string, handler http.HandlerFunc) {
r.HandlerFunc("POST", publicPath, handler)
}
func (r *RouterAdmin) PATCH(publicPath string, handle httprouter.Handle) {
r.Router.PATCH(path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
func (r *RouterAdmin) PUT(publicPath string, handler http.HandlerFunc) {
r.HandlerFunc("PUT", publicPath, handler)
}
func (r *RouterAdmin) DELETE(publicPath string, handle httprouter.Handle) {
r.Router.DELETE(path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
func (r *RouterAdmin) PATCH(publicPath string, handler http.HandlerFunc) {
r.HandlerFunc("PATCH", publicPath, handler)
}
func (r *RouterAdmin) Handle(method, publicPath string, handle httprouter.Handle) {
r.Router.Handle(method, path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
func (r *RouterAdmin) DELETE(publicPath string, handler http.HandlerFunc) {
r.HandlerFunc("DELETE", publicPath, handler)
}
func (r *RouterAdmin) Handle(method, publicPath string, handle http.HandlerFunc) {
for _, pattern := range []string{
method + " " + path.Join(AdminPrefix, publicPath),
method + " " + path.Join(AdminPrefix, publicPath, "{$}"),
} {
r.mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) {
NoCache(w)
handle(w, r)
})
}
}
func (r *RouterAdmin) HandlerFunc(method, publicPath string, handler http.HandlerFunc) {
r.Router.HandlerFunc(method, path.Join(AdminPrefix, publicPath), NoCacheHandlerFunc(handler))
for _, pattern := range []string{
method + " " + path.Join(AdminPrefix, publicPath),
method + " " + path.Join(AdminPrefix, publicPath, "{$}"),
} {
r.mux.HandleFunc(pattern, NoCacheHandlerFunc(handler))
}
}
func (r *RouterAdmin) Handler(method, publicPath string, handler http.Handler) {
r.Router.Handler(method, path.Join(AdminPrefix, publicPath), NoCacheHandler(handler))
for _, pattern := range []string{
method + " " + path.Join(AdminPrefix, publicPath),
method + " " + path.Join(AdminPrefix, publicPath, "{$}"),
} {
r.mux.Handle(pattern, NoCacheHandler(handler))
}
}
func (r *RouterAdmin) Lookup(method, publicPath string) {
r.Router.Lookup(method, path.Join(AdminPrefix, publicPath))
func (r *RouterAdmin) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
for _, p := range []string{
path.Join(pattern),
path.Join(pattern, "{$}"),
} {
r.mux.HandleFunc(p, NoCacheHandlerFunc(handler))
}
}
type HandlerRegistrar interface {

View File

@ -8,7 +8,7 @@ import (
"testing"
"github.com/gobuffalo/httptest"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -23,19 +23,19 @@ func TestCacheHandling(t *testing.T) {
ts := httptest.NewServer(router)
t.Cleanup(ts.Close)
router.GET("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /foo", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
router.DELETE("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("DELETE /foo", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
router.POST("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("POST /foo", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
router.PUT("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("PUT /foo", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
router.PATCH("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("PATCH /foo", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
@ -52,19 +52,19 @@ func TestAdminPrefix(t *testing.T) {
ts := httptest.NewServer(router)
t.Cleanup(ts.Close)
router.GET("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("GET /foo", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
router.DELETE("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("DELETE /foo", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
router.POST("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("POST /foo", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
router.PUT("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("PUT /foo", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
router.PATCH("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
router.HandleFunc("PATCH /foo", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})

View File

@ -7,8 +7,6 @@ import (
_ "embed"
"net/http"
"github.com/julienschmidt/httprouter"
"github.com/ory/kratos/x"
)
@ -45,8 +43,8 @@ type webAuthnJavaScript string
// Responses:
// 200: webAuthnJavaScript
func RegisterWebauthnRoute(r *x.RouterPublic) {
if handle, _, _ := r.Lookup("GET", ScriptURL); handle == nil {
r.GET(ScriptURL, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
if !r.HasRoute("GET", ScriptURL) {
r.GET(ScriptURL, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/javascript; charset=UTF-8")
_, _ = w.Write(jsOnLoad)
})