mirror of https://github.com/ory/kratos
feat: use stdlib HTTP router in Kratos
GitOrigin-RevId: 799513e99acbf43a05fe3113ffda45d2fff2a9e0
This commit is contained in:
parent
9e94951d9e
commit
acfa6ef2ec
|
|
@ -7,6 +7,7 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/urfave/negroni"
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
|
|
@ -14,9 +15,9 @@ import (
|
|||
|
||||
"github.com/ory/graceful"
|
||||
"github.com/ory/kratos/driver"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/x/configx"
|
||||
"github.com/ory/x/otelx"
|
||||
"github.com/ory/x/prometheusx"
|
||||
"github.com/ory/x/reqlog"
|
||||
)
|
||||
|
||||
|
|
@ -58,9 +59,9 @@ func ServeMetrics(ctx context.Context, r driver.Registry, port int) error {
|
|||
l := r.Logger()
|
||||
n := negroni.New()
|
||||
|
||||
router := x.NewRouterAdmin()
|
||||
router := http.NewServeMux()
|
||||
|
||||
r.MetricsHandler().SetRoutes(router.Router)
|
||||
router.Handle(prometheusx.MetricsPrometheusPath, promhttp.Handler())
|
||||
n.Use(reqlog.NewMiddlewareFromLogger(l, "admin#"+cfg.BaseURL.String()))
|
||||
n.Use(r.PrometheusManager())
|
||||
|
||||
|
|
|
|||
|
|
@ -94,7 +94,6 @@ func servePublic(ctx context.Context, r *driver.RegistryDefault, cmd *cobra.Comm
|
|||
csrf.DisablePath(prometheus.MetricsPrometheusPath)
|
||||
|
||||
r.RegisterPublicRoutes(ctx, router)
|
||||
r.PrometheusManager().RegisterRouter(router.Router)
|
||||
|
||||
var handler http.Handler = n
|
||||
if tracer := r.Tracer(ctx); tracer.IsLoaded() {
|
||||
|
|
@ -163,7 +162,6 @@ func serveAdmin(ctx context.Context, r *driver.RegistryDefault, cmd *cobra.Comma
|
|||
|
||||
router := x.NewRouterAdmin()
|
||||
r.RegisterAdminRoutes(ctx, router)
|
||||
r.PrometheusManager().RegisterRouter(router.Router)
|
||||
|
||||
n.UseHandler(http.MaxBytesHandler(router, 5*1024*1024 /* 5 MB */))
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import (
|
|||
|
||||
"github.com/ory/x/ioutilx"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
|
@ -56,22 +55,22 @@ func TestManager(t *testing.T) {
|
|||
|
||||
newServer := func(t *testing.T, p continuity.Manager, tc *persisterTestCase) *httptest.Server {
|
||||
writer := herodot.NewJSONWriter(logrusx.New("", ""))
|
||||
router := httprouter.New()
|
||||
router.PUT("/:name", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
if err := p.Pause(r.Context(), w, r, ps.ByName("name"), tc.ro...); err != nil {
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("PUT /{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := p.Pause(r.Context(), w, r, r.PathValue("name"), tc.ro...); err != nil {
|
||||
writer.WriteError(w, r, err)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
|
||||
router.POST("/:name", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
if err := p.Pause(r.Context(), w, r, ps.ByName("name"), tc.ro...); err != nil {
|
||||
router.HandleFunc("POST /{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := p.Pause(r.Context(), w, r, r.PathValue("name"), tc.ro...); err != nil {
|
||||
writer.WriteError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
c, err := p.Continue(r.Context(), w, r, ps.ByName("name"), tc.wo...)
|
||||
c, err := p.Continue(r.Context(), w, r, r.PathValue("name"), tc.wo...)
|
||||
if err != nil {
|
||||
writer.WriteError(w, r, err)
|
||||
return
|
||||
|
|
@ -79,8 +78,8 @@ func TestManager(t *testing.T) {
|
|||
writer.Write(w, r, c)
|
||||
})
|
||||
|
||||
router.GET("/:name", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
c, err := p.Continue(r.Context(), w, r, ps.ByName("name"), tc.ro...)
|
||||
router.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := p.Continue(r.Context(), w, r, r.PathValue("name"), tc.ro...)
|
||||
if err != nil {
|
||||
writer.WriteError(w, r, err)
|
||||
return
|
||||
|
|
@ -88,8 +87,8 @@ func TestManager(t *testing.T) {
|
|||
writer.Write(w, r, c)
|
||||
})
|
||||
|
||||
router.DELETE("/:name", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
err := p.Abort(r.Context(), w, r, ps.ByName("name"))
|
||||
router.HandleFunc("DELETE /{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := p.Abort(r.Context(), w, r, r.PathValue("name"))
|
||||
if err != nil {
|
||||
writer.WriteError(w, r, err)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -16,8 +16,6 @@ import (
|
|||
"github.com/ory/x/pagination/keysetpagination"
|
||||
"github.com/ory/x/pagination/migrationpagination"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/x"
|
||||
)
|
||||
|
|
@ -25,7 +23,7 @@ import (
|
|||
const (
|
||||
AdminRouteCourier = "/courier"
|
||||
AdminRouteListMessages = AdminRouteCourier + "/messages"
|
||||
AdminRouteGetMessage = AdminRouteCourier + "/messages/:msgID"
|
||||
AdminRouteGetMessage = AdminRouteCourier + "/messages/{msgID}"
|
||||
)
|
||||
|
||||
type (
|
||||
|
|
@ -113,7 +111,7 @@ type ListCourierMessagesParameters struct {
|
|||
// 200: listCourierMessages
|
||||
// 400: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) listCourierMessages(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) listCourierMessages(w http.ResponseWriter, r *http.Request) {
|
||||
filter, paginator, err := parseMessagesFilter(r)
|
||||
if err != nil {
|
||||
h.r.Writer().WriteErrorCode(w, r, http.StatusBadRequest, err)
|
||||
|
|
@ -193,10 +191,10 @@ type getCourierMessage struct {
|
|||
// 200: message
|
||||
// 400: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) getCourierMessage(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
msgID, err := uuid.FromString(ps.ByName("msgID"))
|
||||
func (h *Handler) getCourierMessage(w http.ResponseWriter, r *http.Request) {
|
||||
msgID, err := uuid.FromString(r.PathValue("msgID"))
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithError(err.Error()).WithDebugf("could not parse parameter {id} as UUID, got %s", ps.ByName("id")))
|
||||
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithError(err.Error()).WithDebugf("could not parse parameter {id} as UUID, got %s", r.PathValue("id")))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,8 +14,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/ory/kratos/courier/template"
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/internal"
|
||||
|
|
@ -145,11 +143,11 @@ func TestLoadTextTemplate(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("case=http resource", func(t *testing.T) {
|
||||
router := httprouter.New()
|
||||
router.Handle("GET", "/html", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("GET /html", func(writer http.ResponseWriter, request *http.Request) {
|
||||
http.ServeFile(writer, request, "courier/builtin/templates/test_stub/email.body.html.en_US.gotmpl")
|
||||
})
|
||||
router.Handle("GET", "/plaintext", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
|
||||
router.HandleFunc("GET /plaintext", func(writer http.ResponseWriter, request *http.Request) {
|
||||
http.ServeFile(writer, request, "courier/builtin/templates/test_stub/email.body.plaintext.gotmpl")
|
||||
})
|
||||
ts := httptest.NewServer(router)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ import (
|
|||
|
||||
"github.com/ory/kratos/courier/template/email"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
|
|
@ -94,9 +93,9 @@ func TestRemoteTemplates(t *testing.T, basePath string, tmplType template.Templa
|
|||
|
||||
t.Run("case=http resource", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
router := httprouter.New()
|
||||
router.Handle("GET", "/:filename", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
|
||||
http.ServeFile(writer, request, path.Join(basePath, params.ByName("filename")))
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("GET /{filename}", func(writer http.ResponseWriter, request *http.Request) {
|
||||
http.ServeFile(writer, request, path.Join(basePath, request.PathValue("filename")))
|
||||
})
|
||||
ts := httptest.NewServer(router)
|
||||
defer ts.Close()
|
||||
|
|
|
|||
|
|
@ -8,16 +8,15 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/knadh/koanf/parsers/json"
|
||||
)
|
||||
|
||||
type router interface {
|
||||
GET(path string, handle httprouter.Handle)
|
||||
HandlerFunc(method, path string, handler http.HandlerFunc)
|
||||
}
|
||||
|
||||
func NewConfigHashHandler(c Provider, router router) {
|
||||
router.GET("/health/config", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandlerFunc("GET", "/health/config", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
if revision := c.Config().GetProvider(r.Context()).String("revision"); len(revision) > 0 {
|
||||
_, _ = fmt.Fprintf(w, "%s", revision)
|
||||
|
|
|
|||
|
|
@ -8,12 +8,12 @@ import (
|
|||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/internal"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/x/contextx"
|
||||
)
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ func (c *configProvider) Config() *config.Config {
|
|||
func TestNewConfigHashHandler(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cfg := internal.NewConfigurationWithDefaults(t)
|
||||
router := httprouter.New()
|
||||
router := x.NewRouterPublic()
|
||||
config.NewConfigHashHandler(&configProvider{cfg: cfg}, router)
|
||||
ts := contextx.NewConfigurableTestServer(router)
|
||||
t.Cleanup(ts.Close)
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ func (m *RegistryDefault) RegisterPublicRoutes(ctx context.Context, router *x.Ro
|
|||
m.VerificationHandler().RegisterPublicRoutes(router)
|
||||
m.AllVerificationStrategies().RegisterPublicRoutes(router)
|
||||
|
||||
m.HealthHandler(ctx).SetHealthRoutes(router.Router, false)
|
||||
m.HealthHandler(ctx).SetHealthRoutes(router, false)
|
||||
}
|
||||
|
||||
func (m *RegistryDefault) RegisterAdminRoutes(ctx context.Context, router *x.RouterAdmin) {
|
||||
|
|
@ -222,7 +222,7 @@ func (m *RegistryDefault) RegisterAdminRoutes(ctx context.Context, router *x.Rou
|
|||
|
||||
m.HealthHandler(ctx).SetHealthRoutes(router, true)
|
||||
m.HealthHandler(ctx).SetVersionRoutes(router)
|
||||
m.MetricsHandler().SetRoutes(router)
|
||||
m.MetricsHandler().SetMuxRoutes(router)
|
||||
|
||||
config.NewConfigHashHandler(m, router)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,7 +30,6 @@ import (
|
|||
|
||||
"github.com/ory/herodot"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/x/decoderx"
|
||||
|
|
@ -44,8 +43,8 @@ import (
|
|||
|
||||
const (
|
||||
RouteCollection = "/identities"
|
||||
RouteItem = RouteCollection + "/:id"
|
||||
RouteCredentialItem = RouteItem + "/credentials/:type"
|
||||
RouteItem = RouteCollection + "/{id}"
|
||||
RouteCredentialItem = RouteItem + "/credentials/{type}"
|
||||
|
||||
BatchPatchIdentitiesLimit = 1000
|
||||
BatchPatchIdentitiesWithPasswordLimit = 200
|
||||
|
|
@ -266,7 +265,7 @@ func parseListIdentitiesParameters(r *http.Request) (params ListIdentityParamete
|
|||
// Responses:
|
||||
// 200: listIdentities
|
||||
// default: errorGeneric
|
||||
func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) list(w http.ResponseWriter, r *http.Request) {
|
||||
params, err := parseListIdentitiesParameters(r)
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, err)
|
||||
|
|
@ -355,8 +354,8 @@ type getIdentity struct {
|
|||
// 200: identity
|
||||
// 404: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) get(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
i, err := h.r.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), x.ParseUUID(ps.ByName("id")))
|
||||
func (h *Handler) get(w http.ResponseWriter, r *http.Request) {
|
||||
i, err := h.r.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), x.ParseUUID(r.PathValue("id")))
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, err)
|
||||
return
|
||||
|
|
@ -577,7 +576,7 @@ type AdminCreateIdentityImportCredentialsSAMLProvider struct {
|
|||
// 400: errorGeneric
|
||||
// 409: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) create(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) create(w http.ResponseWriter, r *http.Request) {
|
||||
var cr CreateIdentityBody
|
||||
if err := jsonx.NewStrictDecoder(r.Body).Decode(&cr); err != nil {
|
||||
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithError(err.Error())))
|
||||
|
|
@ -688,7 +687,7 @@ func (h *Handler) identityFromCreateIdentityBody(ctx context.Context, cr *Create
|
|||
// 400: errorGeneric
|
||||
// 409: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) batchPatchIdentities(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) batchPatchIdentities(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
req BatchPatchIdentitiesBody
|
||||
res batchPatchIdentitiesResponse
|
||||
|
|
@ -845,7 +844,7 @@ type UpdateIdentityBody struct {
|
|||
// 404: errorGeneric
|
||||
// 409: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) update(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) update(w http.ResponseWriter, r *http.Request) {
|
||||
var ur UpdateIdentityBody
|
||||
if err := h.dx.Decode(r, &ur,
|
||||
decoderx.HTTPJSONDecoder()); err != nil {
|
||||
|
|
@ -853,7 +852,7 @@ func (h *Handler) update(w http.ResponseWriter, r *http.Request, ps httprouter.P
|
|||
return
|
||||
}
|
||||
|
||||
id := x.ParseUUID(ps.ByName("id"))
|
||||
id := x.ParseUUID(r.PathValue("id"))
|
||||
identity, err := h.r.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), id)
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, err)
|
||||
|
|
@ -933,8 +932,8 @@ type deleteIdentity struct {
|
|||
// 204: emptyResponse
|
||||
// 404: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) delete(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
if err := h.r.PrivilegedIdentityPool().DeleteIdentity(r.Context(), x.ParseUUID(ps.ByName("id"))); err != nil {
|
||||
func (h *Handler) delete(w http.ResponseWriter, r *http.Request) {
|
||||
if err := h.r.PrivilegedIdentityPool().DeleteIdentity(r.Context(), x.ParseUUID(r.PathValue("id"))); err != nil {
|
||||
h.r.Writer().WriteError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
|
@ -983,14 +982,14 @@ type patchIdentity struct {
|
|||
// 404: errorGeneric
|
||||
// 409: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) patch(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) patch(w http.ResponseWriter, r *http.Request) {
|
||||
requestBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
id := x.ParseUUID(ps.ByName("id"))
|
||||
id := x.ParseUUID(r.PathValue("id"))
|
||||
identity, err := h.r.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), id)
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, err)
|
||||
|
|
@ -1095,17 +1094,17 @@ type _ struct {
|
|||
// 204: emptyResponse
|
||||
// 404: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) deleteIdentityCredentials(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) deleteIdentityCredentials(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
identity, err := h.r.PrivilegedIdentityPool().GetIdentityConfidential(ctx, x.ParseUUID(ps.ByName("id")))
|
||||
identity, err := h.r.PrivilegedIdentityPool().GetIdentityConfidential(ctx, x.ParseUUID(r.PathValue("id")))
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
cred, ok := identity.GetCredentials(CredentialsType(ps.ByName("type")))
|
||||
cred, ok := identity.GetCredentials(CredentialsType(r.PathValue("type")))
|
||||
if !ok {
|
||||
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrNotFound.WithReasonf("You tried to remove a %s but this user have no %s set up.", ps.ByName("type"), ps.ByName("type"))))
|
||||
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrNotFound.WithReasonf("You tried to remove a %s but this user have no %s set up.", r.PathValue("type"), r.PathValue("type"))))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ import (
|
|||
"github.com/ory/x/httpx"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
|
|
@ -38,10 +37,10 @@ func TestSchemaValidatorDisallowsInternalNetworkRequests(t *testing.T) {
|
|||
|
||||
v := NewValidator(reg)
|
||||
n := negroni.New(x.HTTPLoaderContextMiddleware(reg))
|
||||
router := httprouter.New()
|
||||
router.GET("/:id", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("GET /{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
i := &Identity{
|
||||
SchemaID: ps.ByName("id"),
|
||||
SchemaID: r.PathValue("id"),
|
||||
Traits: Traits(`{ "firstName": "first-name", "lastName": "last-name", "age": 1 }`),
|
||||
}
|
||||
_, _ = w.Write([]byte(fmt.Sprintf("%+v", v.Validate(r.Context(), i))))
|
||||
|
|
@ -75,8 +74,8 @@ func TestSchemaValidator(t *testing.T) {
|
|||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
router := httprouter.New()
|
||||
router.GET("/schema/:name", func(w http.ResponseWriter, _ *http.Request, ps httprouter.Params) {
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("GET /schema/{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(`{
|
||||
"$id": "https://example.com/person.schema.json",
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
|
|
@ -86,7 +85,7 @@ func TestSchemaValidator(t *testing.T) {
|
|||
"traits": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"` + ps.ByName("name") + `": {
|
||||
"` + r.PathValue("name") + `": {
|
||||
"type": "string",
|
||||
"description": "The person's first name."
|
||||
},
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ import (
|
|||
|
||||
"github.com/go-faker/faker/v4"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -33,8 +32,8 @@ type mockDeps interface {
|
|||
config.Provider
|
||||
}
|
||||
|
||||
func MockSetSession(t *testing.T, reg mockDeps, conf *config.Config) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func MockSetSession(t *testing.T, reg mockDeps, conf *config.Config) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
|
||||
i.NID = uuid.Must(uuid.NewV4())
|
||||
require.NoError(t, i.SetCredentialsWithConfig(
|
||||
|
|
@ -46,12 +45,12 @@ func MockSetSession(t *testing.T, reg mockDeps, conf *config.Config) httprouter.
|
|||
json.RawMessage(`{"hashed_password":"$"}`)))
|
||||
require.NoError(t, reg.IdentityManager().Create(context.Background(), i))
|
||||
|
||||
MockSetSessionWithIdentity(t, reg, conf, i)(w, r, ps)
|
||||
MockSetSessionWithIdentity(t, reg, conf, i)(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func MockSetSessionWithIdentity(t *testing.T, reg mockDeps, _ *config.Config, i *identity.Identity) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func MockSetSessionWithIdentity(t *testing.T, reg mockDeps, _ *config.Config, i *identity.Identity) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
activeSession, err := NewActiveSession(r, reg, i, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
|
||||
require.NoError(t, err)
|
||||
if aal := r.URL.Query().Get("set_aal"); len(aal) > 0 {
|
||||
|
|
@ -63,20 +62,24 @@ func MockSetSessionWithIdentity(t *testing.T, reg mockDeps, _ *config.Config, i
|
|||
}
|
||||
}
|
||||
|
||||
func MockMakeAuthenticatedRequest(t *testing.T, reg mockDeps, conf *config.Config, router *httprouter.Router, req *http.Request) ([]byte, *http.Response) {
|
||||
type router interface {
|
||||
HandleFunc(pattern string, handler http.HandlerFunc)
|
||||
}
|
||||
|
||||
func MockMakeAuthenticatedRequest(t *testing.T, reg mockDeps, conf *config.Config, router router, req *http.Request) ([]byte, *http.Response) {
|
||||
return MockMakeAuthenticatedRequestWithClient(t, reg, conf, router, req, NewClientWithCookies(t))
|
||||
}
|
||||
|
||||
func MockMakeAuthenticatedRequestWithClient(t *testing.T, reg mockDeps, conf *config.Config, router *httprouter.Router, req *http.Request, client *http.Client) ([]byte, *http.Response) {
|
||||
func MockMakeAuthenticatedRequestWithClient(t *testing.T, reg mockDeps, conf *config.Config, router router, req *http.Request, client *http.Client) ([]byte, *http.Response) {
|
||||
return MockMakeAuthenticatedRequestWithClientAndID(t, reg, conf, router, req, client, nil)
|
||||
}
|
||||
|
||||
func MockMakeAuthenticatedRequestWithClientAndID(t *testing.T, reg mockDeps, conf *config.Config, router *httprouter.Router, req *http.Request, client *http.Client, id *identity.Identity) ([]byte, *http.Response) {
|
||||
func MockMakeAuthenticatedRequestWithClientAndID(t *testing.T, reg mockDeps, conf *config.Config, router router, req *http.Request, client *http.Client, id *identity.Identity) ([]byte, *http.Response) {
|
||||
set := "/" + uuid.Must(uuid.NewV4()).String() + "/set"
|
||||
if id == nil {
|
||||
router.GET(set, MockSetSession(t, reg, conf))
|
||||
router.HandleFunc("GET "+set, MockSetSession(t, reg, conf))
|
||||
} else {
|
||||
router.GET(set, MockSetSessionWithIdentity(t, reg, conf, id))
|
||||
router.HandleFunc("GET "+set, MockSetSessionWithIdentity(t, reg, conf, id))
|
||||
}
|
||||
|
||||
MockHydrateCookieClient(t, client, "http://"+req.URL.Host+set+"?"+req.URL.Query().Encode())
|
||||
|
|
@ -128,11 +131,11 @@ func MockHydrateCookieClient(t *testing.T, c *http.Client, u string) *http.Cooki
|
|||
return sessionCookie
|
||||
}
|
||||
|
||||
func MockSessionCreateHandlerWithIdentity(t *testing.T, reg mockDeps, i *identity.Identity) (httprouter.Handle, *session.Session) {
|
||||
func MockSessionCreateHandlerWithIdentity(t *testing.T, reg mockDeps, i *identity.Identity) (http.HandlerFunc, *session.Session) {
|
||||
return MockSessionCreateHandlerWithIdentityAndAMR(t, reg, i, []identity.CredentialsType{"password"})
|
||||
}
|
||||
|
||||
func MockSessionCreateHandlerWithIdentityAndAMR(t *testing.T, reg mockDeps, i *identity.Identity, methods []identity.CredentialsType) (httprouter.Handle, *session.Session) {
|
||||
func MockSessionCreateHandlerWithIdentityAndAMR(t *testing.T, reg mockDeps, i *identity.Identity, methods []identity.CredentialsType) (http.HandlerFunc, *session.Session) {
|
||||
var sess session.Session
|
||||
require.NoError(t, faker.FakeData(&sess))
|
||||
// require AuthenticatedAt to be time.Now() as we always compare it to the current time
|
||||
|
|
@ -159,12 +162,12 @@ func MockSessionCreateHandlerWithIdentityAndAMR(t *testing.T, reg mockDeps, i *i
|
|||
require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), &sess))
|
||||
require.Len(t, inserted.Credentials, len(i.Credentials))
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, reg.SessionManager().IssueCookie(context.Background(), w, r, &sess))
|
||||
}, &sess
|
||||
}
|
||||
|
||||
func MockSessionCreateHandler(t *testing.T, reg mockDeps) (httprouter.Handle, *session.Session) {
|
||||
func MockSessionCreateHandler(t *testing.T, reg mockDeps) (http.HandlerFunc, *session.Session) {
|
||||
return MockSessionCreateHandlerWithIdentity(t, reg, &identity.Identity{
|
||||
ID: x.NewUUID(), State: identity.StateActive, Traits: identity.Traits(`{"baz":"bar","foo":true,"bar":2.5}`)})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gobuffalo/httptest"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
|
@ -111,11 +111,11 @@ func ExpectURL(isAPI bool, api, browser string) string {
|
|||
}
|
||||
|
||||
func NewSettingsUITestServer(t *testing.T, conf *config.Config) *httptest.Server {
|
||||
router := httprouter.New()
|
||||
router.GET("/settings", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("GET /settings", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
router.GET("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
})
|
||||
ts := httptest.NewServer(router)
|
||||
|
|
@ -129,14 +129,14 @@ func NewSettingsUITestServer(t *testing.T, conf *config.Config) *httptest.Server
|
|||
}
|
||||
|
||||
func NewSettingsUIEchoServer(t *testing.T, reg *driver.RegistryDefault) *httptest.Server {
|
||||
router := httprouter.New()
|
||||
router.GET("/settings", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("GET /settings", func(w http.ResponseWriter, r *http.Request) {
|
||||
res, err := reg.SettingsFlowPersister().GetSettingsFlow(r.Context(), x.ParseUUID(r.URL.Query().Get("flow")))
|
||||
require.NoError(t, err)
|
||||
reg.Writer().Write(w, r, res)
|
||||
})
|
||||
|
||||
router.GET("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
})
|
||||
ts := httptest.NewServer(router)
|
||||
|
|
@ -211,9 +211,9 @@ func AddAndLoginIdentities(t *testing.T, reg *driver.RegistryDefault, public *ht
|
|||
location := "/sessions/set/" + tid
|
||||
|
||||
if router, ok := public.Config.Handler.(*x.RouterPublic); ok {
|
||||
router.Router.GET(location, route)
|
||||
} else if router, ok := public.Config.Handler.(*httprouter.Router); ok {
|
||||
router.GET(location, route)
|
||||
} else if router, ok := public.Config.Handler.(*http.ServeMux); ok {
|
||||
router.Handle("GET "+location, route)
|
||||
} else if router, ok := public.Config.Handler.(*x.RouterAdmin); ok {
|
||||
router.GET(location, route)
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -42,6 +42,15 @@ func (h *Handler) SetRoutes(r router) {
|
|||
r.GET(MetricsPrometheusPath, h.Metrics)
|
||||
}
|
||||
|
||||
type muxrouter interface {
|
||||
GET(path string, handle http.HandlerFunc)
|
||||
}
|
||||
|
||||
// SetMuxRoutes registers this handler's routes on a ServeMux.
|
||||
func (h *Handler) SetMuxRoutes(mux muxrouter) {
|
||||
mux.GET(MetricsPrometheusPath, promhttp.Handler().ServeHTTP)
|
||||
}
|
||||
|
||||
// Metrics outputs prometheus metrics
|
||||
//
|
||||
// swagger:route GET /metrics/prometheus metadata prometheus
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ package prometheusx
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
|
|
@ -61,7 +62,16 @@ func (pmm *MetricsManager) RegisterRouter(router *httprouter.Router) {
|
|||
pmm.routers.data = append(pmm.routers.data, router)
|
||||
}
|
||||
|
||||
var paramPlaceHolderRE = regexp.MustCompile(`\{[a-zA-Z0-9_-]+\}`)
|
||||
|
||||
func (pmm *MetricsManager) getLabelForPath(r *http.Request) string {
|
||||
// If the request came through a http.ServeMux, it already has a pattern that we
|
||||
// can use as a label. We just need to replace all path parameters with a generic
|
||||
// placeholder and remove the trailing slash pattern.
|
||||
if p := r.Pattern; p != "" {
|
||||
return paramPlaceHolderRE.ReplaceAllString(strings.TrimSuffix(p, "/{$}"), "{param}")
|
||||
}
|
||||
|
||||
// looking for a match in one of registered routers
|
||||
pmm.routers.Lock()
|
||||
defer pmm.routers.Unlock()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "kratos-oss",
|
||||
"dependencies": {
|
||||
"@openapitools/openapi-generator-cli": "2.20.0",
|
||||
"yamljs": "0.3.0"
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import (
|
|||
"github.com/ory/kratos/x/nosurfx"
|
||||
"github.com/ory/kratos/x/redir"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/herodot"
|
||||
|
|
@ -55,14 +54,14 @@ func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {
|
|||
"/"+SchemasPath+"/*",
|
||||
x.AdminPrefix+"/"+SchemasPath+"/*",
|
||||
)
|
||||
public.GET(fmt.Sprintf("/%s/:id", SchemasPath), h.getIdentitySchema)
|
||||
public.GET(fmt.Sprintf("/%s/{id}", SchemasPath), h.getIdentitySchema)
|
||||
public.GET(fmt.Sprintf("/%s", SchemasPath), h.getAll)
|
||||
public.GET(fmt.Sprintf("%s/%s/:id", x.AdminPrefix, SchemasPath), h.getIdentitySchema)
|
||||
public.GET(fmt.Sprintf("%s/%s/{id}", x.AdminPrefix, SchemasPath), h.getIdentitySchema)
|
||||
public.GET(fmt.Sprintf("%s/%s", x.AdminPrefix, SchemasPath), h.getAll)
|
||||
}
|
||||
|
||||
func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) {
|
||||
admin.GET(fmt.Sprintf("/%s/:id", SchemasPath), redir.RedirectToPublicRoute(h.r))
|
||||
admin.GET(fmt.Sprintf("/%s/{id}", SchemasPath), redir.RedirectToPublicRoute(h.r))
|
||||
admin.GET(fmt.Sprintf("/%s", SchemasPath), redir.RedirectToPublicRoute(h.r))
|
||||
}
|
||||
|
||||
|
|
@ -112,7 +111,7 @@ type getIdentitySchema struct {
|
|||
// 200: identitySchema
|
||||
// 404: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) getIdentitySchema(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) getIdentitySchema(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := h.r.Tracer(r.Context()).Tracer().Start(r.Context(), "schema.Handler.getIdentitySchema")
|
||||
defer span.End()
|
||||
|
||||
|
|
@ -122,7 +121,7 @@ func (h *Handler) getIdentitySchema(w http.ResponseWriter, r *http.Request, ps h
|
|||
return
|
||||
}
|
||||
|
||||
id := ps.ByName("id")
|
||||
id := r.PathValue("id")
|
||||
s, err := ss.GetByID(id)
|
||||
if err != nil {
|
||||
// Maybe it is a base64 encoded ID?
|
||||
|
|
@ -203,7 +202,7 @@ type identitySchemasResponse struct {
|
|||
// Responses:
|
||||
// 200: identitySchemas
|
||||
// default: errorGeneric
|
||||
func (h *Handler) getAll(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) getAll(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := h.r.Tracer(r.Context()).Tracer().Start(r.Context(), "schema.Handler.getAll")
|
||||
defer span.End()
|
||||
|
||||
|
|
|
|||
|
|
@ -14,16 +14,15 @@ import (
|
|||
"github.com/ory/jsonschema/v3/httploader"
|
||||
"github.com/ory/x/httpx"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ory/x/stringsx"
|
||||
)
|
||||
|
||||
func TestSchemaValidator(t *testing.T) {
|
||||
router := httprouter.New()
|
||||
router := http.NewServeMux()
|
||||
fs := http.StripPrefix("/schema", http.FileServer(http.Dir("stub/validator")))
|
||||
router.GET("/schema/:name", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("/schema/{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||
fs.ServeHTTP(w, r)
|
||||
})
|
||||
ts := httptest.NewServer(router)
|
||||
|
|
|
|||
|
|
@ -14,8 +14,6 @@ import (
|
|||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/ory/herodot"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/nosurf"
|
||||
|
|
@ -92,7 +90,7 @@ type getFlowError struct {
|
|||
// 403: errorGeneric
|
||||
// 404: errorGeneric
|
||||
// 500: errorGeneric
|
||||
func (h *Handler) publicFetchError(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) publicFetchError(w http.ResponseWriter, r *http.Request) {
|
||||
if err := h.fetchError(w, r); err != nil {
|
||||
h.r.Writer().WriteError(w, r, err)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import (
|
|||
|
||||
"github.com/ory/x/assertx"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -39,11 +38,11 @@ func TestHandler(t *testing.T) {
|
|||
ns := nosurfx.NewTestCSRFHandler(router, reg)
|
||||
|
||||
h.RegisterPublicRoutes(router)
|
||||
router.GET("/regen", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /regen", func(w http.ResponseWriter, r *http.Request) {
|
||||
ns.RegenerateToken(w, r)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
router.GET("/set-error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /set-error", func(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := reg.SelfServiceErrorPersister().CreateErrorContainer(context.Background(), nosurf.Token(r), herodot.ErrNotFound.WithReason("foobar"))
|
||||
require.NoError(t, err)
|
||||
_, _ = w.Write([]byte(id.String()))
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import (
|
|||
"github.com/ory/kratos/ui/node"
|
||||
|
||||
"github.com/gobuffalo/httptest"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
|
@ -45,7 +45,7 @@ func TestHandleError(t *testing.T) {
|
|||
|
||||
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/password.schema.json")
|
||||
|
||||
router := httprouter.New()
|
||||
router := http.NewServeMux()
|
||||
ts := httptest.NewServer(router)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
|
|
@ -58,7 +58,7 @@ func TestHandleError(t *testing.T) {
|
|||
var loginFlow *login.Flow
|
||||
var flowError error
|
||||
var ct node.UiNodeGroup
|
||||
router.GET("/error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /error", func(w http.ResponseWriter, r *http.Request) {
|
||||
h.WriteFlowError(w, r, loginFlow, ct, flowError)
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
|
@ -378,7 +377,7 @@ type createNativeLoginFlow struct {
|
|||
// 200: loginFlow
|
||||
// 400: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) createNativeLoginFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) createNativeLoginFlow(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
ctx, span := h.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.flow.login.createNativeLoginFlow")
|
||||
r = r.WithContext(ctx)
|
||||
|
|
@ -497,7 +496,7 @@ type createBrowserLoginFlow struct {
|
|||
// 303: emptyResponse
|
||||
// 400: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) createBrowserLoginFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) createBrowserLoginFlow(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
ctx, span := h.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.flow.login.createBrowserLoginFlow")
|
||||
r = r.WithContext(ctx)
|
||||
|
|
@ -655,7 +654,7 @@ type getLoginFlow struct {
|
|||
// 404: errorGeneric
|
||||
// 410: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) getLoginFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) getLoginFlow(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
ctx, span := h.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.flow.login.getLoginFlow")
|
||||
r = r.WithContext(ctx)
|
||||
|
|
@ -797,7 +796,7 @@ type updateLoginFlowBody struct{}
|
|||
// 410: errorGeneric
|
||||
// 422: errorBrowserLocationChangeRequired
|
||||
// default: errorGeneric
|
||||
func (h *Handler) updateLoginFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) updateLoginFlow(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
ctx, span := h.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.flow.login.updateLoginFlow")
|
||||
ctx = semconv.ContextWithAttributes(ctx, attribute.String(events.AttributeKeySelfServiceStrategyUsed.String(), "login"))
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import (
|
|||
|
||||
"github.com/ory/kratos/x/nosurfx"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/x/urlx"
|
||||
|
|
@ -88,7 +87,7 @@ func TestFlowLifecycle(t *testing.T) {
|
|||
}
|
||||
req := testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+route, nil)
|
||||
req.URL.RawQuery = extQuery.Encode()
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, req)
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, req)
|
||||
if isAPI {
|
||||
assert.Len(t, res.Header.Get("Set-Cookie"), 0)
|
||||
}
|
||||
|
|
@ -354,7 +353,7 @@ func TestFlowLifecycle(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, req)
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, req)
|
||||
return string(body), res
|
||||
}
|
||||
|
||||
|
|
@ -458,7 +457,7 @@ func TestFlowLifecycle(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, reg.IdentityManager().Update(context.Background(), id, identity.ManagerAllowWriteProtectedTraits))
|
||||
|
||||
h := func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
h := func(w http.ResponseWriter, r *http.Request) {
|
||||
sess, err := testhelpers.NewActiveSession(r, reg, id, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
|
||||
require.NoError(t, err)
|
||||
sess.AuthenticatorAssuranceLevel = identity.AuthenticatorAssuranceLevel1
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import (
|
|||
"github.com/ory/x/configx"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
|
@ -50,9 +50,9 @@ func TestLoginExecutor(t *testing.T) {
|
|||
_ = testhelpers.NewLoginUIFlowEchoServer(t, reg)
|
||||
|
||||
newServer := func(t *testing.T, ft flow.Type, useIdentity *identity.Identity, flowCallback ...func(*login.Flow)) *httptest.Server {
|
||||
router := httprouter.New()
|
||||
router := http.NewServeMux()
|
||||
|
||||
router.GET("/login/pre", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /login/pre", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginFlow, err := login.NewFlow(conf, time.Minute, "", r, ft)
|
||||
require.NoError(t, err)
|
||||
if testhelpers.SelfServiceHookLoginErrorHandler(t, w, r, reg.LoginHookExecutor().PreLoginHook(w, r, loginFlow)) {
|
||||
|
|
@ -60,7 +60,7 @@ func TestLoginExecutor(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
router.GET("/login/post", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /login/post", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginFlow, err := login.NewFlow(conf, time.Minute, "", r, ft)
|
||||
require.NoError(t, err)
|
||||
loginFlow.Active = strategy
|
||||
|
|
@ -79,7 +79,7 @@ func TestLoginExecutor(t *testing.T) {
|
|||
reg.LoginHookExecutor().PostLoginHook(w, r, strategy.ToUiNodeGroup(), loginFlow, useIdentity, sess, ""))
|
||||
})
|
||||
|
||||
router.GET("/login/post2fa", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /login/post2fa", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginFlow, err := login.NewFlow(conf, time.Minute, "", r, ft)
|
||||
require.NoError(t, err)
|
||||
loginFlow.Active = strategy
|
||||
|
|
|
|||
|
|
@ -22,8 +22,6 @@ import (
|
|||
"github.com/ory/x/sqlcon"
|
||||
"github.com/ory/x/urlx"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/selfservice/errorx"
|
||||
"github.com/ory/kratos/session"
|
||||
|
|
@ -141,7 +139,7 @@ type createBrowserLogoutFlow struct {
|
|||
// 400: errorGeneric
|
||||
// 401: errorGeneric
|
||||
// 500: errorGeneric
|
||||
func (h *Handler) createBrowserLogoutFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) createBrowserLogoutFlow(w http.ResponseWriter, r *http.Request) {
|
||||
sess, err := h.d.SessionManager().FetchFromRequest(r.Context(), r)
|
||||
if err != nil {
|
||||
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
|
||||
|
|
@ -232,7 +230,7 @@ type performNativeLogoutBody struct {
|
|||
// 204: emptyResponse
|
||||
// 400: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) performNativeLogout(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) performNativeLogout(w http.ResponseWriter, r *http.Request) {
|
||||
var p performNativeLogoutBody
|
||||
if err := h.dx.Decode(r, &p,
|
||||
decoderx.HTTPJSONDecoder(),
|
||||
|
|
@ -323,7 +321,7 @@ type updateLogoutFlow struct {
|
|||
// 303: emptyResponse
|
||||
// 204: emptyResponse
|
||||
// default: errorGeneric
|
||||
func (h *Handler) updateLogoutFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) updateLogoutFlow(w http.ResponseWriter, r *http.Request) {
|
||||
expected := r.URL.Query().Get("token")
|
||||
if len(expected) == 0 {
|
||||
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReason("Please include a token in the URL query.")))
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ import (
|
|||
|
||||
"github.com/ory/kratos/session"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
|
@ -38,8 +37,10 @@ func TestLogout(t *testing.T) {
|
|||
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json")
|
||||
public, _, publicRouter, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)
|
||||
|
||||
publicRouter.GET("/session/browser/set", testhelpers.MockSetSession(t, reg, conf))
|
||||
publicRouter.GET("/session/browser/get", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
publicRouter.GET("/session/browser/set", func(writer http.ResponseWriter, request *http.Request) {
|
||||
testhelpers.MockSetSession(t, reg, conf)(writer, request)
|
||||
})
|
||||
publicRouter.HandleFunc("GET /session/browser/get", func(w http.ResponseWriter, r *http.Request) {
|
||||
sess, err := reg.SessionManager().FetchFromRequest(r.Context(), r)
|
||||
if err != nil {
|
||||
reg.Writer().WriteError(w, r, err)
|
||||
|
|
@ -47,7 +48,7 @@ func TestLogout(t *testing.T) {
|
|||
}
|
||||
reg.Writer().Write(w, r, sess)
|
||||
})
|
||||
publicRouter.POST("/csrf/check", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
publicRouter.HandleFunc("POST /csrf/check", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
conf.MustSet(ctx, config.ViperKeySelfServiceLogoutBrowserDefaultReturnTo, public.URL+"/session/browser/get")
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ import (
|
|||
"github.com/ory/kratos/ui/node"
|
||||
|
||||
"github.com/gobuffalo/httptest"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
|
@ -49,7 +49,7 @@ func TestHandleError(t *testing.T) {
|
|||
|
||||
public, _ := testhelpers.NewKratosServer(t, reg)
|
||||
|
||||
router := httprouter.New()
|
||||
router := http.NewServeMux()
|
||||
ts := httptest.NewServer(router)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ func TestHandleError(t *testing.T) {
|
|||
var recoveryFlow *recovery.Flow
|
||||
var flowError error
|
||||
var methodName node.UiNodeGroup
|
||||
router.GET("/error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /error", func(w http.ResponseWriter, r *http.Request) {
|
||||
h.WriteFlowError(w, r, recoveryFlow, methodName, flowError)
|
||||
})
|
||||
|
||||
|
|
@ -307,7 +307,7 @@ func TestHandleError_WithContinueWith(t *testing.T) {
|
|||
|
||||
public, _ := testhelpers.NewKratosServer(t, reg)
|
||||
|
||||
router := httprouter.New()
|
||||
router := http.NewServeMux()
|
||||
ts := httptest.NewServer(router)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
|
|
@ -320,7 +320,7 @@ func TestHandleError_WithContinueWith(t *testing.T) {
|
|||
var recoveryFlow *recovery.Flow
|
||||
var flowError error
|
||||
var methodName node.UiNodeGroup
|
||||
router.GET("/error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /error", func(w http.ResponseWriter, r *http.Request) {
|
||||
h.WriteFlowError(w, r, recoveryFlow, methodName, flowError)
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import (
|
|||
|
||||
"github.com/ory/herodot"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/x/urlx"
|
||||
|
|
@ -73,11 +72,11 @@ func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {
|
|||
h.d.CSRFHandler().IgnorePath(RouteSubmitFlow)
|
||||
|
||||
redirect := session.RedirectOnAuthenticated(h.d)
|
||||
public.GET(RouteInitBrowserFlow, h.d.SessionHandler().IsNotAuthenticated(h.createBrowserRecoveryFlow, func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
public.GET(RouteInitBrowserFlow, h.d.SessionHandler().IsNotAuthenticated(h.createBrowserRecoveryFlow, func(w http.ResponseWriter, r *http.Request) {
|
||||
if x.IsJSONRequest(r) {
|
||||
h.d.Writer().WriteError(w, r, errors.WithStack(ErrAlreadyLoggedIn))
|
||||
} else {
|
||||
redirect(w, r, ps)
|
||||
redirect(w, r)
|
||||
}
|
||||
}))
|
||||
|
||||
|
|
@ -122,7 +121,7 @@ func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) {
|
|||
// 200: recoveryFlow
|
||||
// 400: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) createNativeRecoveryFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) createNativeRecoveryFlow(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.d.Config().SelfServiceFlowRecoveryEnabled(r.Context()) {
|
||||
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Recovery is not allowed because it was disabled.")))
|
||||
return
|
||||
|
|
@ -187,7 +186,7 @@ type createBrowserRecoveryFlow struct {
|
|||
// 303: emptyResponse
|
||||
// 400: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) createBrowserRecoveryFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) createBrowserRecoveryFlow(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.d.Config().SelfServiceFlowRecoveryEnabled(r.Context()) {
|
||||
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Recovery is not allowed because it was disabled.")))
|
||||
return
|
||||
|
|
@ -277,7 +276,7 @@ type getRecoveryFlow struct {
|
|||
// 404: errorGeneric
|
||||
// 410: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) getRecoveryFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) getRecoveryFlow(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.d.Config().SelfServiceFlowRecoveryEnabled(r.Context()) {
|
||||
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Recovery is not allowed because it was disabled.")))
|
||||
return
|
||||
|
|
@ -402,7 +401,7 @@ type updateRecoveryFlowBody struct{}
|
|||
// 410: errorGeneric
|
||||
// 422: errorBrowserLocationChangeRequired
|
||||
// default: errorGeneric
|
||||
func (h *Handler) updateRecoveryFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) updateRecoveryFlow(w http.ResponseWriter, r *http.Request) {
|
||||
rid, err := flow.GetFlowID(r)
|
||||
if err != nil {
|
||||
h.d.RecoveryFlowErrorHandler().WriteFlowError(w, r, nil, node.DefaultGroup, err)
|
||||
|
|
|
|||
|
|
@ -50,13 +50,13 @@ func TestHandlerRedirectOnAuthenticated(t *testing.T) {
|
|||
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json")
|
||||
|
||||
t.Run("does redirect to default on authenticated request", func(t *testing.T) {
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+recovery.RouteInitBrowserFlow, nil))
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+recovery.RouteInitBrowserFlow, nil))
|
||||
assert.Contains(t, res.Request.URL.String(), redirTS.URL, "%+v", res)
|
||||
assert.EqualValues(t, "already authenticated", string(body))
|
||||
})
|
||||
|
||||
t.Run("does redirect to default on authenticated request", func(t *testing.T) {
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+recovery.RouteInitAPIFlow, nil))
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+recovery.RouteInitAPIFlow, nil))
|
||||
assert.Contains(t, res.Request.URL.String(), recovery.RouteInitAPIFlow)
|
||||
assert.EqualValues(t, text.ErrIDAlreadyLoggedIn, gjson.GetBytes(body, "error.id").Str)
|
||||
assertx.EqualAsJSON(t, recovery.ErrAlreadyLoggedIn, json.RawMessage(gjson.GetBytes(body, "error").Raw))
|
||||
|
|
@ -96,7 +96,7 @@ func TestInitFlow(t *testing.T) {
|
|||
if isSPA {
|
||||
req.Header.Set("Accept", "application/json")
|
||||
}
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, req)
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, req)
|
||||
if isAPI {
|
||||
assert.Len(t, res.Header.Get("Set-Cookie"), 0)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import (
|
|||
"github.com/ory/kratos/selfservice/strategy/code"
|
||||
|
||||
"github.com/gobuffalo/httptest"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
|
|
@ -35,8 +35,8 @@ func TestRecoveryExecutor(t *testing.T) {
|
|||
s := code.NewStrategy(reg)
|
||||
|
||||
newServer := func(t *testing.T, i *identity.Identity, ft flow.Type) *httptest.Server {
|
||||
router := httprouter.New()
|
||||
router.GET("/recovery/pre", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("GET /recovery/pre", func(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := recovery.NewFlow(conf, time.Minute, nosurfx.FakeCSRFToken, r, s, ft)
|
||||
require.NoError(t, err)
|
||||
if testhelpers.SelfServiceHookErrorHandler(t, w, r, recovery.ErrHookAbortFlow, reg.RecoveryExecutor().PreRecoveryHook(w, r, a)) {
|
||||
|
|
@ -44,7 +44,7 @@ func TestRecoveryExecutor(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
router.GET("/recovery/post", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /recovery/post", func(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := recovery.NewFlow(conf, time.Minute, nosurfx.FakeCSRFToken, r, s, ft)
|
||||
require.NoError(t, err)
|
||||
s, err := testhelpers.NewActiveSession(r,
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ import (
|
|||
"github.com/ory/kratos/ui/node"
|
||||
|
||||
"github.com/gobuffalo/httptest"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
|
@ -48,7 +48,7 @@ func TestHandleError(t *testing.T) {
|
|||
|
||||
public, _ := testhelpers.NewKratosServer(t, reg)
|
||||
|
||||
router := httprouter.New()
|
||||
router := http.NewServeMux()
|
||||
ts := httptest.NewServer(router)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
|
|
@ -61,7 +61,7 @@ func TestHandleError(t *testing.T) {
|
|||
var registrationFlow *registration.Flow
|
||||
var flowError error
|
||||
var group node.UiNodeGroup
|
||||
router.GET("/error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /error", func(w http.ResponseWriter, r *http.Request) {
|
||||
h.WriteFlowError(w, r, registrationFlow, group, flowError)
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import (
|
|||
"github.com/ory/kratos/x/nosurfx"
|
||||
"github.com/ory/kratos/x/redir"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
|
@ -88,13 +87,13 @@ func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {
|
|||
public.GET(RouteSubmitFlow, h.d.SessionHandler().IsNotAuthenticated(h.updateRegistrationFlow, h.onAuthenticated))
|
||||
}
|
||||
|
||||
func (h *Handler) onAuthenticated(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) onAuthenticated(w http.ResponseWriter, r *http.Request) {
|
||||
handler := session.RedirectOnAuthenticated(h.d)
|
||||
if x.IsJSONRequest(r) {
|
||||
handler = session.RespondWithJSONErrorOnAuthenticated(h.d.Writer(), ErrAlreadyLoggedIn)
|
||||
}
|
||||
|
||||
handler(w, r, ps)
|
||||
handler(w, r)
|
||||
}
|
||||
|
||||
func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) {
|
||||
|
|
@ -224,7 +223,7 @@ func (h *Handler) FromOldFlow(w http.ResponseWriter, r *http.Request, of Flow) (
|
|||
// 200: registrationFlow
|
||||
// 400: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) createNativeRegistrationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) createNativeRegistrationFlow(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.NewRegistrationFlow(w, r, flow.TypeAPI)
|
||||
if err != nil {
|
||||
h.d.Writer().WriteError(w, r, err)
|
||||
|
|
@ -336,7 +335,7 @@ type createBrowserRegistrationFlow struct {
|
|||
// 200: registrationFlow
|
||||
// 303: emptyResponse
|
||||
// default: errorGeneric
|
||||
func (h *Handler) createBrowserRegistrationFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) createBrowserRegistrationFlow(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
a, err := h.NewRegistrationFlow(w, r, flow.TypeBrowser)
|
||||
|
|
@ -501,7 +500,7 @@ type getRegistrationFlow struct {
|
|||
// 404: errorGeneric
|
||||
// 410: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) getRegistrationFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) getRegistrationFlow(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.d.Config().SelfServiceFlowRegistrationEnabled(r.Context()) {
|
||||
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(ErrRegistrationDisabled))
|
||||
return
|
||||
|
|
@ -638,7 +637,7 @@ type updateRegistrationFlowBody struct{}
|
|||
// 410: errorGeneric
|
||||
// 422: errorBrowserLocationChangeRequired
|
||||
// default: errorGeneric
|
||||
func (h *Handler) updateRegistrationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) updateRegistrationFlow(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx = semconv.ContextWithAttributes(ctx, attribute.String(events.AttributeKeySelfServiceStrategyUsed.String(), "registration"))
|
||||
r = r.WithContext(ctx)
|
||||
|
|
|
|||
|
|
@ -65,19 +65,19 @@ func TestHandlerRedirectOnAuthenticated(t *testing.T) {
|
|||
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json")
|
||||
|
||||
t.Run("does redirect to default on authenticated request", func(t *testing.T) {
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow, nil))
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow, nil))
|
||||
assert.Contains(t, res.Request.URL.String(), redirTS.URL)
|
||||
assert.EqualValues(t, "already authenticated", string(body))
|
||||
})
|
||||
|
||||
t.Run("does redirect to default on authenticated request", func(t *testing.T) {
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitAPIFlow, nil))
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitAPIFlow, nil))
|
||||
assert.Contains(t, res.Request.URL.String(), registration.RouteInitAPIFlow)
|
||||
assertx.EqualAsJSON(t, registration.ErrAlreadyLoggedIn, json.RawMessage(gjson.GetBytes(body, "error").Raw))
|
||||
})
|
||||
|
||||
t.Run("does redirect to return_to url on authenticated request", func(t *testing.T) {
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow+"?return_to="+returnToTS.URL, nil))
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow+"?return_to="+returnToTS.URL, nil))
|
||||
assert.Contains(t, res.Request.URL.String(), returnToTS.URL)
|
||||
assert.EqualValues(t, "return_to", string(body))
|
||||
})
|
||||
|
|
@ -92,7 +92,7 @@ func TestHandlerRedirectOnAuthenticated(t *testing.T) {
|
|||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
_, res := testhelpers.MockMakeAuthenticatedRequestWithClient(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow+"?login_challenge="+hydra.FakeValidLoginChallenge, nil), client)
|
||||
_, res := testhelpers.MockMakeAuthenticatedRequestWithClient(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow+"?login_challenge="+hydra.FakeValidLoginChallenge, nil), client)
|
||||
assert.Contains(t, res.Header.Get("location"), login.RouteInitBrowserFlow)
|
||||
})
|
||||
|
||||
|
|
@ -106,7 +106,7 @@ func TestHandlerRedirectOnAuthenticated(t *testing.T) {
|
|||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
_, res := testhelpers.MockMakeAuthenticatedRequestWithClient(t, reg, conf, router.Router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow+"?login_challenge="+hydra.FakeValidLoginChallenge, nil), client)
|
||||
_, res := testhelpers.MockMakeAuthenticatedRequestWithClient(t, reg, conf, router, testhelpers.NewTestHTTPRequest(t, "GET", ts.URL+registration.RouteInitBrowserFlow+"?login_challenge="+hydra.FakeValidLoginChallenge, nil), client)
|
||||
assert.Contains(t, res.Header.Get("location"), hydra.FakePostLoginURL)
|
||||
})
|
||||
}
|
||||
|
|
@ -142,7 +142,7 @@ func TestInitFlow(t *testing.T) {
|
|||
if isSPA {
|
||||
req.Header.Set("Accept", "application/json")
|
||||
}
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, req)
|
||||
body, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, req)
|
||||
if isAPI {
|
||||
assert.Len(t, res.Header.Get("Set-Cookie"), 0)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import (
|
|||
|
||||
"github.com/gobuffalo/httptest"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
|
@ -47,10 +47,10 @@ func TestRegistrationExecutor(t *testing.T) {
|
|||
conf.MustSet(ctx, config.ViperKeySelfServiceBrowserDefaultReturnTo, returnToServer.URL)
|
||||
|
||||
newServer := func(t *testing.T, i *identity.Identity, ft flow.Type, flowCallbacks ...func(*registration.Flow)) *httptest.Server {
|
||||
router := httprouter.New()
|
||||
router := http.NewServeMux()
|
||||
|
||||
handleErr := testhelpers.SelfServiceHookRegistrationErrorHandler
|
||||
router.GET("/registration/pre", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /registration/pre", func(w http.ResponseWriter, r *http.Request) {
|
||||
f, err := registration.NewFlow(conf, time.Minute, nosurfx.FakeCSRFToken, r, ft)
|
||||
require.NoError(t, err)
|
||||
if handleErr(t, w, r, reg.RegistrationHookExecutor().PreRegistrationHook(w, r, f)) {
|
||||
|
|
@ -58,7 +58,7 @@ func TestRegistrationExecutor(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
router.GET("/registration/post", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /registration/post", func(w http.ResponseWriter, r *http.Request) {
|
||||
if i == nil {
|
||||
i = testhelpers.SelfServiceHookFakeIdentity(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import (
|
|||
|
||||
"github.com/go-faker/faker/v4"
|
||||
"github.com/gobuffalo/httptest"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
|
@ -45,7 +45,7 @@ func TestHandleError(t *testing.T) {
|
|||
|
||||
public, _ := testhelpers.NewKratosServer(t, reg)
|
||||
|
||||
router := httprouter.New()
|
||||
router := http.NewServeMux()
|
||||
ts := httptest.NewServer(router)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
|
|
@ -65,11 +65,11 @@ func TestHandleError(t *testing.T) {
|
|||
id.State = identity.StateActive
|
||||
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &id))
|
||||
|
||||
router.GET("/error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /error", func(w http.ResponseWriter, r *http.Request) {
|
||||
h.WriteFlowError(ctx, w, r, flowMethod, settingsFlow, &id, flowError)
|
||||
})
|
||||
|
||||
router.GET("/fake-redirect", func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||
router.HandleFunc("GET /fake-redirect", func(w http.ResponseWriter, r *http.Request) {
|
||||
reg.LoginHandler().NewLoginFlow(w, r, flow.TypeBrowser)
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ import (
|
|||
|
||||
"github.com/ory/x/otelx"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/herodot"
|
||||
|
|
@ -96,7 +95,7 @@ func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {
|
|||
h.d.CSRFHandler().IgnorePath(RouteInitAPIFlow)
|
||||
h.d.CSRFHandler().IgnorePath(RouteSubmitFlow)
|
||||
|
||||
public.GET(RouteInitBrowserFlow, h.d.SessionHandler().IsAuthenticated(h.createBrowserSettingsFlow, func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
public.GET(RouteInitBrowserFlow, h.d.SessionHandler().IsAuthenticated(h.createBrowserSettingsFlow, func(w http.ResponseWriter, r *http.Request) {
|
||||
if x.IsJSONRequest(r) {
|
||||
h.d.Writer().WriteError(w, r, session.NewErrNoActiveSessionFound())
|
||||
} else {
|
||||
|
|
@ -222,7 +221,7 @@ type createNativeSettingsFlow struct {
|
|||
// 200: settingsFlow
|
||||
// 400: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) createNativeSettingsFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) createNativeSettingsFlow(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
s, err := h.d.SessionManager().FetchFromRequestContext(ctx, r)
|
||||
if err != nil {
|
||||
|
|
@ -306,7 +305,7 @@ type createBrowserSettingsFlow struct {
|
|||
// 401: errorGeneric
|
||||
// 403: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) createBrowserSettingsFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) createBrowserSettingsFlow(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
s, err := h.d.SessionManager().FetchFromRequestContext(ctx, r)
|
||||
if err != nil {
|
||||
|
|
@ -405,7 +404,7 @@ type getSettingsFlow struct {
|
|||
// 404: errorGeneric
|
||||
// 410: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) getSettingsFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) getSettingsFlow(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
rid := x.ParseUUID(r.URL.Query().Get("id"))
|
||||
pr, err := h.d.SettingsFlowPersister().GetSettingsFlow(ctx, rid)
|
||||
|
|
@ -567,7 +566,7 @@ type updateSettingsFlowBody struct{}
|
|||
// 410: errorGeneric
|
||||
// 422: errorBrowserLocationChangeRequired
|
||||
// default: errorGeneric
|
||||
func (h *Handler) updateSettingsFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) updateSettingsFlow(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
err error
|
||||
ctx = r.Context()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import (
|
|||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/gobuffalo/httptest"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
|
|
@ -47,9 +47,9 @@ func TestSettingsExecutor(t *testing.T) {
|
|||
|
||||
newServer := func(t *testing.T, i *identity.Identity, ft flow.Type) *httptest.Server {
|
||||
t.Helper()
|
||||
router := httprouter.New()
|
||||
router := http.NewServeMux()
|
||||
handleErr := testhelpers.SelfServiceHookSettingsErrorHandler
|
||||
router.GET("/settings/pre", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /settings/pre", func(w http.ResponseWriter, r *http.Request) {
|
||||
if i == nil {
|
||||
i = testhelpers.SelfServiceHookCreateFakeIdentity(t, reg)
|
||||
}
|
||||
|
|
@ -62,7 +62,7 @@ func TestSettingsExecutor(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
router.GET("/settings/post", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /settings/post", func(w http.ResponseWriter, r *http.Request) {
|
||||
if i == nil {
|
||||
i = testhelpers.SelfServiceHookCreateFakeIdentity(t, reg)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/herodot"
|
||||
|
|
@ -102,13 +101,13 @@ func GetFlowID(r *http.Request) (uuid.UUID, error) {
|
|||
func OnUnauthenticated(reg interface {
|
||||
config.Provider
|
||||
x.WriterProvider
|
||||
}) func(http.ResponseWriter, *http.Request, httprouter.Params) {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
}) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := session.RedirectOnUnauthenticated(reg.Config().SelfServiceFlowLoginUI(r.Context()).String())
|
||||
if x.IsJSONRequest(r) {
|
||||
handler = session.RespondWithJSONErrorOnAuthenticated(reg.Writer(), herodot.ErrUnauthorized.WithReasonf("A valid Ory Session Cookie or Ory Session Token is missing."))
|
||||
}
|
||||
|
||||
handler(w, r, ps)
|
||||
handler(w, r)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import (
|
|||
"github.com/ory/kratos/ui/node"
|
||||
|
||||
"github.com/gobuffalo/httptest"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
|
@ -46,7 +46,7 @@ func TestHandleError(t *testing.T) {
|
|||
|
||||
public, _ := testhelpers.NewKratosServer(t, reg)
|
||||
|
||||
router := httprouter.New()
|
||||
router := http.NewServeMux()
|
||||
ts := httptest.NewServer(router)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
|
|
@ -59,7 +59,7 @@ func TestHandleError(t *testing.T) {
|
|||
var verificationFlow *verification.Flow
|
||||
var flowError error
|
||||
var methodName node.UiNodeGroup
|
||||
router.GET("/error", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /error", func(w http.ResponseWriter, r *http.Request) {
|
||||
h.WriteFlowError(w, r, verificationFlow, methodName, flowError)
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import (
|
|||
|
||||
"github.com/ory/herodot"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/x/urlx"
|
||||
|
|
@ -162,7 +161,7 @@ type createNativeVerificationFlow struct {
|
|||
// 200: verificationFlow
|
||||
// 400: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) createNativeVerificationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) createNativeVerificationFlow(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.d.Config().SelfServiceFlowVerificationEnabled(r.Context()) {
|
||||
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Verification is not allowed because it was disabled.")))
|
||||
return
|
||||
|
|
@ -209,7 +208,7 @@ type createBrowserVerificationFlow struct {
|
|||
// 200: verificationFlow
|
||||
// 303: emptyResponse
|
||||
// default: errorGeneric
|
||||
func (h *Handler) createBrowserVerificationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) createBrowserVerificationFlow(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.d.Config().SelfServiceFlowVerificationEnabled(r.Context()) {
|
||||
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Verification is not allowed because it was disabled.")))
|
||||
return
|
||||
|
|
@ -284,7 +283,7 @@ type getVerificationFlow struct {
|
|||
// 403: errorGeneric
|
||||
// 404: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) getVerificationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) getVerificationFlow(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.d.Config().SelfServiceFlowVerificationEnabled(r.Context()) {
|
||||
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Verification is not allowed because it was disabled.")))
|
||||
return
|
||||
|
|
@ -408,7 +407,7 @@ type updateVerificationFlowBody struct{}
|
|||
// 400: verificationFlow
|
||||
// 410: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) updateVerificationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) updateVerificationFlow(w http.ResponseWriter, r *http.Request) {
|
||||
rid, err := flow.GetFlowID(r)
|
||||
if err != nil {
|
||||
h.d.VerificationFlowErrorHandler().WriteFlowError(w, r, nil, node.DefaultGroup, err)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import (
|
|||
"github.com/ory/kratos/selfservice/flow/verification"
|
||||
|
||||
"github.com/gobuffalo/httptest"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
|
|
@ -32,8 +32,8 @@ func TestVerificationExecutor(t *testing.T) {
|
|||
conf, reg := internal.NewFastRegistryWithMocks(t)
|
||||
|
||||
newServer := func(t *testing.T, i *identity.Identity, ft flow.Type) *httptest.Server {
|
||||
router := httprouter.New()
|
||||
router.GET("/verification/pre", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("GET /verification/pre", func(w http.ResponseWriter, r *http.Request) {
|
||||
strategy, err := reg.GetActiveVerificationStrategy(r.Context())
|
||||
require.NoError(t, err)
|
||||
a, err := verification.NewFlow(conf, time.Minute, nosurfx.FakeCSRFToken, r, strategy, ft)
|
||||
|
|
@ -43,7 +43,7 @@ func TestVerificationExecutor(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
router.GET("/verification/post", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /verification/post", func(w http.ResponseWriter, r *http.Request) {
|
||||
strategy, err := reg.GetActiveVerificationStrategy(r.Context())
|
||||
require.NoError(t, err)
|
||||
a, err := verification.NewFlow(conf, time.Minute, nosurfx.FakeCSRFToken, r, strategy, ft)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -103,8 +102,8 @@ func TestWebHooks(t *testing.T) {
|
|||
Method string
|
||||
}
|
||||
|
||||
webHookEndPoint := func(whr *WebHookRequest) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
webHookEndPoint := func(whr *WebHookRequest) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
|
@ -116,14 +115,14 @@ func TestWebHooks(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
webHookHttpCodeEndPoint := func(code int) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
|
||||
webHookHttpCodeEndPoint := func(code int) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(code)
|
||||
}
|
||||
}
|
||||
|
||||
webHookHttpCodeWithBodyEndPoint := func(t *testing.T, code int, body []byte) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
|
||||
webHookHttpCodeWithBodyEndPoint := func(t *testing.T, code int, body []byte) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(code)
|
||||
_, err := w.Write(body)
|
||||
assert.NoError(t, err, "error while returning response from webHookHttpCodeWithBodyEndPoint")
|
||||
|
|
@ -131,17 +130,17 @@ func TestWebHooks(t *testing.T) {
|
|||
}
|
||||
|
||||
path := "/web_hook"
|
||||
newServer := func(f httprouter.Handle) *httptest.Server {
|
||||
r := httprouter.New()
|
||||
newServer := func(f http.HandlerFunc) *httptest.Server {
|
||||
r := http.NewServeMux()
|
||||
|
||||
r.Handle("CONNECT", path, f)
|
||||
r.DELETE(path, f)
|
||||
r.GET(path, f)
|
||||
r.OPTIONS(path, f)
|
||||
r.PATCH(path, f)
|
||||
r.POST(path, f)
|
||||
r.PUT(path, f)
|
||||
r.Handle("TRACE", path, f)
|
||||
r.HandleFunc("CONNECT "+path, f)
|
||||
r.HandleFunc("DELETE "+path, f)
|
||||
r.HandleFunc("GET "+path, f)
|
||||
r.HandleFunc("OPTIONS "+path, f)
|
||||
r.HandleFunc("PATCH "+path, f)
|
||||
r.HandleFunc("POST "+path, f)
|
||||
r.HandleFunc("PUT "+path, f)
|
||||
r.HandleFunc("TRACE "+path, f)
|
||||
|
||||
ts := httptest.NewServer(r)
|
||||
t.Cleanup(ts.Close)
|
||||
|
|
@ -917,7 +916,7 @@ func TestWebHooks(t *testing.T) {
|
|||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
waitTime := time.Millisecond * 100
|
||||
ts := newServer(func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
|
||||
ts := newServer(func(w http.ResponseWriter, _ *http.Request) {
|
||||
defer wg.Done()
|
||||
time.Sleep(waitTime)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
|
@ -956,7 +955,7 @@ func TestWebHooks(t *testing.T) {
|
|||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(3) // HTTP client does 3 attempts
|
||||
ts := newServer(func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
ts := newServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer wg.Done()
|
||||
w.WriteHeader(500)
|
||||
_, _ = w.Write([]byte(`{"error":"some error"}`))
|
||||
|
|
|
|||
|
|
@ -145,9 +145,9 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F
|
|||
if _, err := s.deps.SessionManager().FetchFromRequest(ctx, r); err == nil {
|
||||
// User is already logged in
|
||||
if x.IsJSONRequest(r) {
|
||||
session.RespondWithJSONErrorOnAuthenticated(s.deps.Writer(), recovery.ErrAlreadyLoggedIn)(w, r, nil)
|
||||
session.RespondWithJSONErrorOnAuthenticated(s.deps.Writer(), recovery.ErrAlreadyLoggedIn)(w, r)
|
||||
} else {
|
||||
session.RedirectOnAuthenticated(s.deps)(w, r, nil)
|
||||
session.RedirectOnAuthenticated(s.deps)(w, r)
|
||||
}
|
||||
return errors.WithStack(flow.ErrCompletedByStrategy)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
|
|
@ -138,7 +137,7 @@ type recoveryCodeForIdentity struct {
|
|||
// 400: errorGeneric
|
||||
// 404: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (s *Strategy) createRecoveryCodeForIdentity(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (s *Strategy) createRecoveryCodeForIdentity(w http.ResponseWriter, r *http.Request) {
|
||||
var p createRecoveryCodeForIdentityBody
|
||||
if err := s.dx.Decode(r, &p, decoderx.HTTPJSONDecoder()); err != nil {
|
||||
s.deps.Writer().WriteError(w, r, err)
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@ package strategy
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/ory/herodot"
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/x"
|
||||
|
|
@ -20,32 +18,16 @@ type disabledChecker interface {
|
|||
x.WriterProvider
|
||||
}
|
||||
|
||||
func disabledWriter(c disabledChecker, enabled bool, wrap httprouter.Handle, w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func disabledWriter(c disabledChecker, enabled bool, wrap http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
||||
if !enabled {
|
||||
c.Writer().WriteError(w, r, herodot.ErrNotFound.WithReason(EndpointDisabledMessage))
|
||||
return
|
||||
}
|
||||
wrap(w, r, ps)
|
||||
wrap(w, r)
|
||||
}
|
||||
|
||||
func IsDisabled(c disabledChecker, strategy string, wrap httprouter.Handle) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
disabledWriter(c, c.Config().SelfServiceStrategy(r.Context(), strategy).Enabled, wrap, w, r, ps)
|
||||
}
|
||||
}
|
||||
|
||||
func IsRecoveryDisabled(c disabledChecker, strategy string, wrap httprouter.Handle) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
disabledWriter(c,
|
||||
c.Config().SelfServiceStrategy(r.Context(), strategy).Enabled && c.Config().SelfServiceFlowRecoveryEnabled(r.Context()),
|
||||
wrap, w, r, ps)
|
||||
}
|
||||
}
|
||||
|
||||
func IsVerificationDisabled(c disabledChecker, strategy string, wrap httprouter.Handle) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
disabledWriter(c,
|
||||
c.Config().SelfServiceStrategy(r.Context(), strategy).Enabled && c.Config().SelfServiceFlowVerificationEnabled(r.Context()),
|
||||
wrap, w, r, ps)
|
||||
func IsDisabled(c disabledChecker, strategy string, wrap http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
disabledWriter(c, c.Config().SelfServiceStrategy(r.Context(), strategy).Enabled, wrap, w, r)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ func TestCompleteLogin(t *testing.T) {
|
|||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
actual, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, req)
|
||||
actual, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, req)
|
||||
assert.Contains(t, res.Request.URL.String(), publicTS.URL+login.RouteSubmitFlow)
|
||||
assert.Equal(t, text.NewErrorValidationLoginNoStrategyFound().Text, gjson.GetBytes(actual, "ui.messages.0.text").String())
|
||||
})
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import (
|
|||
"github.com/ory/kratos/x/redir"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
|
@ -148,7 +147,7 @@ type recoveryLinkForIdentity struct {
|
|||
// 400: errorGeneric
|
||||
// 404: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (s *Strategy) createRecoveryLinkForIdentity(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (s *Strategy) createRecoveryLinkForIdentity(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var p createRecoveryLinkForIdentityBody
|
||||
|
|
@ -275,9 +274,9 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F
|
|||
|
||||
if _, err := s.d.SessionManager().FetchFromRequest(r.Context(), r); err == nil {
|
||||
if x.IsJSONRequest(r) {
|
||||
session.RespondWithJSONErrorOnAuthenticated(s.d.Writer(), recovery.ErrAlreadyLoggedIn)(w, r, nil)
|
||||
session.RespondWithJSONErrorOnAuthenticated(s.d.Writer(), recovery.ErrAlreadyLoggedIn)(w, r)
|
||||
} else {
|
||||
session.RedirectOnAuthenticated(s.d)(w, r, nil)
|
||||
session.RedirectOnAuthenticated(s.d)(w, r)
|
||||
}
|
||||
return errors.WithStack(flow.ErrCompletedByStrategy)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -680,7 +680,7 @@ func TestRecovery(t *testing.T) {
|
|||
}
|
||||
|
||||
check(t, expectSuccess(t, nil, false, false, values), email, testhelpers.NewClientWithCookies(t), func(cl *http.Client, req *http.Request) (*http.Response, error) {
|
||||
_, res := testhelpers.MockMakeAuthenticatedRequestWithClient(t, reg, conf, publicRouter.Router, req, cl)
|
||||
_, res := testhelpers.MockMakeAuthenticatedRequestWithClient(t, reg, conf, publicRouter, req, cl)
|
||||
return res, nil
|
||||
})
|
||||
})
|
||||
|
|
@ -692,7 +692,7 @@ func TestRecovery(t *testing.T) {
|
|||
|
||||
cl := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, ctx, reg, id)
|
||||
check(t, expectSuccess(t, nil, false, false, values), email, cl, func(_ *http.Client, req *http.Request) (*http.Response, error) {
|
||||
_, res := testhelpers.MockMakeAuthenticatedRequestWithClientAndID(t, reg, conf, publicRouter.Router, req, cl, id)
|
||||
_, res := testhelpers.MockMakeAuthenticatedRequestWithClientAndID(t, reg, conf, publicRouter, req, cl, id)
|
||||
return res, nil
|
||||
})
|
||||
})
|
||||
|
|
|
|||
|
|
@ -146,12 +146,12 @@ func (p Configuration) Redir(public *url.URL) string {
|
|||
|
||||
if p.OrganizationID != "" {
|
||||
route := RouteOrganizationCallback
|
||||
route = strings.Replace(route, ":provider", p.ID, 1)
|
||||
route = strings.Replace(route, ":organization", p.OrganizationID, 1)
|
||||
route = strings.Replace(route, "{provider}", p.ID, 1)
|
||||
route = strings.Replace(route, "{organization}", p.OrganizationID, 1)
|
||||
return urlx.AppendPaths(public, route).String()
|
||||
}
|
||||
|
||||
return urlx.AppendPaths(public, strings.Replace(RouteCallback, ":provider", p.ID, 1)).String()
|
||||
return urlx.AppendPaths(public, strings.Replace(RouteCallback, "{provider}", p.ID, 1)).String()
|
||||
}
|
||||
|
||||
type ConfigurationCollection struct {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ package oidc
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"maps"
|
||||
|
|
@ -20,7 +21,6 @@ import (
|
|||
"github.com/ory/kratos/x/redir"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
|
|
@ -57,10 +57,10 @@ import (
|
|||
const (
|
||||
RouteBase = "/self-service/methods/oidc"
|
||||
|
||||
RouteAuth = RouteBase + "/auth/:flow"
|
||||
RouteCallback = RouteBase + "/callback/:provider"
|
||||
RouteAuth = RouteBase + "/auth/{flow}"
|
||||
RouteCallback = RouteBase + "/callback/{provider}"
|
||||
RouteCallbackGeneric = RouteBase + "/callback"
|
||||
RouteOrganizationCallback = RouteBase + "/organization/:organization/callback/:provider"
|
||||
RouteOrganizationCallback = RouteBase + "/organization/{organization}/callback/{provider}"
|
||||
)
|
||||
|
||||
var _ identity.ActiveCredentialsCounter = new(Strategy)
|
||||
|
|
@ -198,15 +198,15 @@ func (s *Strategy) CountActiveMultiFactorCredentials(_ context.Context, _ map[id
|
|||
|
||||
func (s *Strategy) setRoutes(r *x.RouterPublic) {
|
||||
wrappedHandleCallback := strategy.IsDisabled(s.d, s.ID().String(), s.HandleCallback)
|
||||
if handle, _, _ := r.Lookup("GET", RouteCallback); handle == nil {
|
||||
if !r.HasRoute("GET", RouteCallback) {
|
||||
r.GET(RouteCallback, wrappedHandleCallback)
|
||||
}
|
||||
if handle, _, _ := r.Lookup("GET", RouteCallbackGeneric); handle == nil {
|
||||
if !r.HasRoute("GET", RouteCallbackGeneric) {
|
||||
r.GET(RouteCallbackGeneric, wrappedHandleCallback)
|
||||
}
|
||||
|
||||
// Apple can use the POST request method when calling the callback
|
||||
if handle, _, _ := r.Lookup("POST", RouteCallback); handle == nil {
|
||||
if !r.HasRoute("POST", RouteCallback) {
|
||||
// Apple is the only (known) provider that sometimes does a form POST to the callback URL.
|
||||
// This is a workaround to handle this case.
|
||||
// But since the URL contains the `id` of the provider, we just allow all OIDC provider callbacks to bypass CSRF.
|
||||
|
|
@ -221,7 +221,7 @@ func (s *Strategy) setRoutes(r *x.RouterPublic) {
|
|||
}
|
||||
|
||||
// Redirect POST request to GET rewriting form fields to query params.
|
||||
func (s *Strategy) redirectToGET(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (s *Strategy) redirectToGET(w http.ResponseWriter, r *http.Request) {
|
||||
publicUrl := s.d.Config().SelfPublicURL(r.Context())
|
||||
dest := *r.URL
|
||||
dest.Host = publicUrl.Host
|
||||
|
|
@ -330,7 +330,7 @@ func (s *Strategy) validateFlow(ctx context.Context, r *http.Request, rid uuid.U
|
|||
return ar, err // this must return the error
|
||||
}
|
||||
|
||||
func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request, ps httprouter.Params) (flow.Flow, *oidcv1.State, *AuthCodeContainer, error) {
|
||||
func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request) (flow.Flow, *oidcv1.State, *AuthCodeContainer, error) {
|
||||
var (
|
||||
codeParam = stringsx.Coalesce(r.URL.Query().Get("code"), r.URL.Query().Get("authCode"))
|
||||
stateParam = r.URL.Query().Get("state")
|
||||
|
|
@ -345,7 +345,7 @@ func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request, ps h
|
|||
return nil, nil, nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`Unable to complete OpenID Connect flow because the state parameter is invalid.`))
|
||||
}
|
||||
|
||||
if providerFromURL := ps.ByName("provider"); providerFromURL != "" {
|
||||
if providerFromURL := r.PathValue("provider"); providerFromURL != "" {
|
||||
// We're serving an OIDC callback URL with provider in the URL.
|
||||
if state.ProviderId == "" {
|
||||
// provider in URL, but not in state: compatiblity mode, remove this fallback later
|
||||
|
|
@ -442,18 +442,18 @@ func (s *Strategy) alreadyAuthenticated(ctx context.Context, w http.ResponseWrit
|
|||
return false, nil
|
||||
}
|
||||
|
||||
func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
code = stringsx.Coalesce(r.URL.Query().Get("code"), r.URL.Query().Get("authCode"))
|
||||
code = cmp.Or(r.URL.Query().Get("code"), r.URL.Query().Get("authCode"))
|
||||
err error
|
||||
)
|
||||
|
||||
ctx := context.WithValue(r.Context(), httprouter.ParamsKey, ps)
|
||||
ctx := r.Context()
|
||||
ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "strategy.oidc.HandleCallback")
|
||||
defer otelx.End(span, &err)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
req, state, cntnr, err := s.ValidateCallback(w, r, ps)
|
||||
req, state, cntnr, err := s.ValidateCallback(w, r)
|
||||
if err != nil {
|
||||
if req != nil {
|
||||
s.forwardError(ctx, w, r, req, s.HandleError(ctx, w, r, req, state.ProviderId, nil, err))
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/phayes/freeport"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rakutentech/jwk-go/jwk"
|
||||
|
|
@ -135,7 +135,7 @@ func createClient(t *testing.T, remote string, redir []string) (id, secret strin
|
|||
}
|
||||
|
||||
func newHydraIntegration(t *testing.T, remote *string, subject *string, claims *idTokenClaims, scope *[]string, addr string) (*http.Server, string) {
|
||||
router := httprouter.New()
|
||||
router := http.NewServeMux()
|
||||
|
||||
type p struct {
|
||||
Subject string `json:"subject,omitempty"`
|
||||
|
|
@ -164,7 +164,7 @@ func newHydraIntegration(t *testing.T, remote *string, subject *string, claims *
|
|||
http.Redirect(w, r, response.RedirectTo, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
router.GET("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NotEmpty(t, *remote)
|
||||
require.NotEmpty(t, *subject)
|
||||
|
||||
|
|
@ -179,7 +179,7 @@ func newHydraIntegration(t *testing.T, remote *string, subject *string, claims *
|
|||
do(w, r, href, &b)
|
||||
})
|
||||
|
||||
router.GET("/consent", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /consent", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NotEmpty(t, *remote)
|
||||
require.NotNil(t, *scope)
|
||||
|
||||
|
|
|
|||
|
|
@ -402,9 +402,9 @@ func TestSettingsStrategy(t *testing.T) {
|
|||
|
||||
t.Run("case=should not be able to link a connection already linked by another identity", func(t *testing.T) {
|
||||
// While this theoretically allows for account enumeration - because we see an error indicator if an
|
||||
// oidc connection is being linked that exists already - it would require the attacker to already
|
||||
// OIDC connection is being linked that exists already - it would require the attacker to already
|
||||
// have control over the social profile, in which case account enumeration is the least of our worries.
|
||||
// Instead of using the oidc profile for enumeration, the attacker would use it for account takeover.
|
||||
// Instead of using the OIDC profile for enumeration, the attacker would use it for account takeover.
|
||||
|
||||
// This is the multiuser login id for google
|
||||
subject = "hackerman+multiuser+" + testID
|
||||
|
|
|
|||
|
|
@ -397,7 +397,7 @@ func TestStrategy(t *testing.T) {
|
|||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
actual, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, routerP.Router, req)
|
||||
actual, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, routerP, req)
|
||||
assert.Contains(t, res.Request.URL.String(), ts.URL+login.RouteSubmitFlow)
|
||||
assert.Equal(t, text.NewErrorValidationLoginNoStrategyFound().Text, gjson.GetBytes(actual, "ui.messages.0.text").String())
|
||||
})
|
||||
|
|
@ -1677,6 +1677,7 @@ func TestStrategy(t *testing.T) {
|
|||
subject = email2
|
||||
t.Run("step=should fail login if existing identity identifier doesn't match", func(t *testing.T) {
|
||||
require.NotNil(t, linkingLoginFlow.ID)
|
||||
require.NotEmpty(t, linkingLoginFlow.ID)
|
||||
res, body := loginWithOIDC(t, client, uuid.Must(uuid.FromString(linkingLoginFlow.ID)), "valid")
|
||||
assertUIError(t, res, body, "Linked credentials do not match.")
|
||||
})
|
||||
|
|
|
|||
|
|
@ -155,7 +155,7 @@ func TestCompleteLogin(t *testing.T) {
|
|||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
actual, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router.Router, req)
|
||||
actual, res := testhelpers.MockMakeAuthenticatedRequest(t, reg, conf, router, req)
|
||||
assert.Contains(t, res.Request.URL.String(), publicTS.URL+login.RouteSubmitFlow)
|
||||
assert.Equal(t, text.NewErrorValidationLoginNoStrategyFound().Text, gjson.GetBytes(actual, "ui.messages.0.text").String())
|
||||
})
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import (
|
|||
|
||||
"github.com/ory/kratos/x/nosurfx"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/urfave/negroni"
|
||||
"golang.org/x/oauth2"
|
||||
|
|
@ -53,7 +52,7 @@ func TestOAuth2Provider(t *testing.T) {
|
|||
errTS := testhelpers.NewErrorTestServer(t, reg)
|
||||
redirTS := testhelpers.NewRedirSessionEchoTS(t, reg)
|
||||
|
||||
router.GET("/login-ts", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /login-ts", func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Log("[loginTS] navigated to the login ui")
|
||||
c := r.Context().Value(TestUIConfig).(*testConfig)
|
||||
*c.callTrace = append(*c.callTrace, LoginUI)
|
||||
|
|
@ -115,7 +114,7 @@ func TestOAuth2Provider(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
router.GET("/consent", func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||
router.HandleFunc("GET /consent", func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Log("[consentTS] navigated to the consent ui")
|
||||
c := r.Context().Value(TestUIConfig).(*testConfig)
|
||||
*c.callTrace = append(*c.callTrace, Consent)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import (
|
|||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/urfave/negroni"
|
||||
|
||||
hydraclientgo "github.com/ory/hydra-client-go/v2"
|
||||
|
|
@ -50,7 +49,7 @@ func TestOAuth2ProviderRegistration(t *testing.T) {
|
|||
TestOAuthClientState contextKey = "test-oauth-client-state"
|
||||
)
|
||||
|
||||
router.GET("/login-ts", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /login-ts", func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Log("[loginTS] navigated to the login ui")
|
||||
c := r.Context().Value(TestUIConfig).(*testConfig)
|
||||
*c.callTrace = append(*c.callTrace, LoginUI)
|
||||
|
|
@ -76,7 +75,7 @@ func TestOAuth2ProviderRegistration(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
router.GET("/registration-ts", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /registration-ts", func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Log("[registrationTS] navigated to the registration ui")
|
||||
c := r.Context().Value(TestUIConfig).(*testConfig)
|
||||
*c.callTrace = append(*c.callTrace, RegistrationUI)
|
||||
|
|
@ -144,7 +143,7 @@ func TestOAuth2ProviderRegistration(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
router.GET("/consent", func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||
router.HandleFunc("GET /consent", func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Log("[consentTS] navigated to the consent ui")
|
||||
c := r.Context().Value(TestUIConfig).(*testConfig)
|
||||
*c.callTrace = append(*c.callTrace, Consent)
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ import (
|
|||
|
||||
"github.com/ory/kratos/corpx"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
|
@ -512,8 +511,8 @@ func TestStrategyTraits(t *testing.T) {
|
|||
setPrivileged(t)
|
||||
|
||||
var returned bool
|
||||
router := httprouter.New()
|
||||
router.GET("/return-ts", func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("GET /return-ts", func(w http.ResponseWriter, r *http.Request) {
|
||||
returned = true
|
||||
})
|
||||
rts := httptest.NewServer(router)
|
||||
|
|
|
|||
|
|
@ -6,13 +6,11 @@ package session
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/ory/herodot"
|
||||
)
|
||||
|
||||
func RespondWithJSONErrorOnAuthenticated(h herodot.Writer, err error) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func RespondWithJSONErrorOnAuthenticated(h herodot.Writer, err error) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
h.WriteError(w, r, err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ import (
|
|||
"github.com/ory/x/pointerx"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/x/decoderx"
|
||||
|
|
@ -66,12 +65,12 @@ const (
|
|||
RouteCollection = "/sessions"
|
||||
RouteExchangeCodeForSessionToken = RouteCollection + "/token-exchange" // #nosec G101
|
||||
RouteWhoami = RouteCollection + "/whoami"
|
||||
RouteSession = RouteCollection + "/:id"
|
||||
RouteSession = RouteCollection + "/{id}"
|
||||
)
|
||||
|
||||
const (
|
||||
AdminRouteIdentity = "/identities"
|
||||
AdminRouteIdentitiesSessions = AdminRouteIdentity + "/:id/sessions"
|
||||
AdminRouteIdentitiesSessions = AdminRouteIdentity + "/{id}/sessions"
|
||||
AdminRouteSessionExtendId = RouteSession + "/extend"
|
||||
)
|
||||
|
||||
|
|
@ -212,7 +211,7 @@ type toSession struct {
|
|||
// 401: errorGeneric
|
||||
// 403: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) whoami(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) whoami(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := h.r.Tracer(r.Context()).Tracer().Start(r.Context(), "sessions.Handler.whoami")
|
||||
defer span.End()
|
||||
|
||||
|
|
@ -307,8 +306,8 @@ type deleteIdentitySessions struct {
|
|||
// 401: errorGeneric
|
||||
// 404: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) deleteIdentitySessions(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
iID, err := uuid.FromString(ps.ByName("id"))
|
||||
func (h *Handler) deleteIdentitySessions(w http.ResponseWriter, r *http.Request) {
|
||||
iID, err := uuid.FromString(r.PathValue("id"))
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID"))
|
||||
return
|
||||
|
|
@ -386,7 +385,7 @@ type listSessionsResponse struct {
|
|||
// 200: listSessions
|
||||
// 400: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) adminListSessions(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) adminListSessions(w http.ResponseWriter, r *http.Request) {
|
||||
activeRaw := r.URL.Query().Get("active")
|
||||
activeBool, err := strconv.ParseBool(activeRaw)
|
||||
if activeRaw != "" && err != nil {
|
||||
|
|
@ -470,14 +469,14 @@ type getSession struct {
|
|||
// 200: session
|
||||
// 400: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) getSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
if ps.ByName("id") == "whoami" {
|
||||
func (h *Handler) getSession(w http.ResponseWriter, r *http.Request) {
|
||||
if r.PathValue("id") == "whoami" {
|
||||
// for /admin/sessions/whoami redirect to the public route
|
||||
redir.RedirectToPublicRoute(h.r)(w, r, ps)
|
||||
redir.RedirectToPublicRoute(h.r)(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
sID, err := uuid.FromString(ps.ByName("id"))
|
||||
sID, err := uuid.FromString(r.PathValue("id"))
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID"))
|
||||
return
|
||||
|
|
@ -539,8 +538,8 @@ type disableSession struct {
|
|||
// 400: errorGeneric
|
||||
// 401: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) disableSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
sID, err := uuid.FromString(ps.ByName("id"))
|
||||
func (h *Handler) disableSession(w http.ResponseWriter, r *http.Request) {
|
||||
sID, err := uuid.FromString(r.PathValue("id"))
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID"))
|
||||
return
|
||||
|
|
@ -605,8 +604,8 @@ type listIdentitySessionsResponse struct {
|
|||
// 400: errorGeneric
|
||||
// 404: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) listIdentitySessions(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
iID, err := uuid.FromString(ps.ByName("id"))
|
||||
func (h *Handler) listIdentitySessions(w http.ResponseWriter, r *http.Request) {
|
||||
iID, err := uuid.FromString(r.PathValue("id"))
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID")))
|
||||
return
|
||||
|
|
@ -679,7 +678,7 @@ type disableMyOtherSessions struct {
|
|||
// 400: errorGeneric
|
||||
// 401: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) deleteMySessions(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) deleteMySessions(w http.ResponseWriter, r *http.Request) {
|
||||
s, err := h.r.SessionManager().FetchFromRequest(r.Context(), r)
|
||||
if err != nil {
|
||||
h.r.Audit().WithRequest(r).WithError(err).Info("No valid session cookie found.")
|
||||
|
|
@ -738,11 +737,11 @@ type disableMySession struct {
|
|||
// 400: errorGeneric
|
||||
// 401: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) deleteMySession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
sid := ps.ByName("id")
|
||||
func (h *Handler) deleteMySession(w http.ResponseWriter, r *http.Request) {
|
||||
sid := r.PathValue("id")
|
||||
if sid == "whoami" {
|
||||
// Special case where we actually want to handle the whoami endpoint.
|
||||
h.whoami(w, r, ps)
|
||||
h.whoami(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -822,7 +821,7 @@ type listMySessionsResponse struct {
|
|||
// 400: errorGeneric
|
||||
// 401: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) listMySessions(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) listMySessions(w http.ResponseWriter, r *http.Request) {
|
||||
s, err := h.r.SessionManager().FetchFromRequest(r.Context(), r)
|
||||
if err != nil {
|
||||
h.r.Audit().WithRequest(r).WithError(err).Info("No valid session cookie found.")
|
||||
|
|
@ -860,13 +859,13 @@ const (
|
|||
sessionInContextKey sessionInContext = iota
|
||||
)
|
||||
|
||||
func (h *Handler) IsAuthenticated(wrap httprouter.Handle, onUnauthenticated httprouter.Handle) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) IsAuthenticated(wrap http.HandlerFunc, onUnauthenticated http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
sess, err := h.r.SessionManager().FetchFromRequest(ctx, r)
|
||||
if err != nil {
|
||||
if onUnauthenticated != nil {
|
||||
onUnauthenticated(w, r, ps)
|
||||
onUnauthenticated(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -874,7 +873,7 @@ func (h *Handler) IsAuthenticated(wrap httprouter.Handle, onUnauthenticated http
|
|||
return
|
||||
}
|
||||
|
||||
wrap(w, r.WithContext(context.WithValue(ctx, sessionInContextKey, sess)), ps)
|
||||
wrap(w, r.WithContext(context.WithValue(ctx, sessionInContextKey, sess)))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -917,8 +916,8 @@ type extendSession struct {
|
|||
// 400: errorGeneric
|
||||
// 404: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) adminSessionExtend(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
id, err := uuid.FromString(ps.ByName("id"))
|
||||
func (h *Handler) adminSessionExtend(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := uuid.FromString(r.PathValue("id"))
|
||||
if err != nil {
|
||||
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID")))
|
||||
return
|
||||
|
|
@ -945,11 +944,11 @@ func (h *Handler) adminSessionExtend(w http.ResponseWriter, r *http.Request, ps
|
|||
h.r.Writer().Write(w, r, s)
|
||||
}
|
||||
|
||||
func (h *Handler) IsNotAuthenticated(wrap httprouter.Handle, onAuthenticated httprouter.Handle) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func (h *Handler) IsNotAuthenticated(wrap http.HandlerFunc, onAuthenticated http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, err := h.r.SessionManager().FetchFromRequest(r.Context(), r); err != nil {
|
||||
if e := new(ErrNoActiveSessionFound); errors.As(err, &e) {
|
||||
wrap(w, r, ps)
|
||||
wrap(w, r)
|
||||
return
|
||||
}
|
||||
h.r.Writer().WriteError(w, r, err)
|
||||
|
|
@ -957,7 +956,7 @@ func (h *Handler) IsNotAuthenticated(wrap httprouter.Handle, onAuthenticated htt
|
|||
}
|
||||
|
||||
if onAuthenticated != nil {
|
||||
onAuthenticated(w, r, ps)
|
||||
onAuthenticated(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -965,8 +964,8 @@ func (h *Handler) IsNotAuthenticated(wrap httprouter.Handle, onAuthenticated htt
|
|||
}
|
||||
}
|
||||
|
||||
func RedirectOnAuthenticated(d interface{ config.Provider }) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func RedirectOnAuthenticated(d interface{ config.Provider }) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
returnTo, err := redir.SecureRedirectTo(r, d.Config().SelfServiceBrowserDefaultReturnTo(ctx), redir.SecureRedirectAllowSelfServiceURLs(d.Config().SelfPublicURL(ctx)))
|
||||
if err != nil {
|
||||
|
|
@ -978,14 +977,14 @@ func RedirectOnAuthenticated(d interface{ config.Provider }) httprouter.Handle {
|
|||
}
|
||||
}
|
||||
|
||||
func RedirectOnUnauthenticated(to string) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func RedirectOnUnauthenticated(to string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, to, http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
func RespondWitherrorGenericOnAuthenticated(h herodot.Writer, err error) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
func RespondWitherrorGenericOnAuthenticated(h herodot.Writer, err error) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
h.WriteError(w, r, err)
|
||||
}
|
||||
}
|
||||
|
|
@ -1048,7 +1047,7 @@ type CodeExchangeResponse struct {
|
|||
// 404: errorGeneric
|
||||
// 410: errorGeneric
|
||||
// default: errorGeneric
|
||||
func (h *Handler) exchangeCode(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func (h *Handler) exchangeCode(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
initCode = r.URL.Query().Get("init_code")
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ import (
|
|||
"github.com/ory/x/pagination/keysetpagination"
|
||||
"github.com/ory/x/sqlcon"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
|
|
@ -50,8 +49,8 @@ func init() {
|
|||
corpx.RegisterFakes()
|
||||
}
|
||||
|
||||
func send(code int) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
|
||||
func send(code int) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(code)
|
||||
}
|
||||
}
|
||||
|
|
@ -530,7 +529,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
description: "expand Identity",
|
||||
expand: "/?expand=Identity",
|
||||
expand: "?expand=Identity",
|
||||
expectedIdentityId: s.Identity.ID.String(),
|
||||
expectedDevices: 0,
|
||||
},
|
||||
|
|
@ -859,9 +858,9 @@ func TestHandlerSelfServiceSessionManagement(t *testing.T) {
|
|||
// we limit the scope of the channels, so you cannot accidentally mess up a test case
|
||||
ident := make(chan *identity.Identity, 1)
|
||||
sess := make(chan *Session, 1)
|
||||
r.GET("/set", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
r.GET("/set", func(w http.ResponseWriter, r *http.Request) {
|
||||
h, s := testhelpers.MockSessionCreateHandlerWithIdentity(t, reg, <-ident)
|
||||
h(w, r, ps)
|
||||
h(w, r)
|
||||
sess <- s
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
|
|
@ -222,17 +221,17 @@ func TestManagerHTTP(t *testing.T) {
|
|||
|
||||
var s *session.Session
|
||||
rp := x.NewRouterPublic()
|
||||
rp.GET("/session/revoke", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
rp.GET("/session/revoke", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, reg.SessionManager().PurgeFromRequest(r.Context(), w, r))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
rp.GET("/session/set", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
rp.GET("/session/set", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, reg.SessionManager().UpsertAndIssueCookie(r.Context(), w, r, s))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
rp.GET("/session/get", func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||
rp.GET("/session/get", func(w http.ResponseWriter, r *http.Request) {
|
||||
sess, err := reg.SessionManager().FetchFromRequest(r.Context(), r)
|
||||
if err != nil {
|
||||
t.Logf("Got error on lookup: %s %T", err, errors.Unwrap(err))
|
||||
|
|
@ -242,7 +241,7 @@ func TestManagerHTTP(t *testing.T) {
|
|||
reg.Writer().Write(w, r, sess)
|
||||
})
|
||||
|
||||
rp.GET("/session/get-middleware", reg.SessionHandler().IsAuthenticated(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||
rp.GET("/session/get-middleware", reg.SessionHandler().IsAuthenticated(func(w http.ResponseWriter, r *http.Request) {
|
||||
sess, err := reg.SessionManager().FetchFromRequestContext(r.Context(), r)
|
||||
if err != nil {
|
||||
t.Logf("Got error on lookup: %s %T", err, errors.Unwrap(err))
|
||||
|
|
@ -311,7 +310,7 @@ func TestManagerHTTP(t *testing.T) {
|
|||
conf.MustSet(ctx, config.ViperKeySessionName, "")
|
||||
})
|
||||
|
||||
rp.GET("/session/set/invalid", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
rp.GET("/session/set/invalid", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Error(t, reg.SessionManager().UpsertAndIssueCookie(r.Context(), w, r, s))
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ go 1.24.1
|
|||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/julienschmidt/httprouter v1.3.0
|
||||
github.com/ory/hydra-client-go v1.7.4
|
||||
github.com/ory/kratos-client-go v0.10.1
|
||||
github.com/ory/x v0.0.722-0.20250620091013-eeb8bd14b65a
|
||||
|
|
|
|||
|
|
@ -5,10 +5,9 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/ory/hydra-client-go/client"
|
||||
"github.com/ory/hydra-client-go/client/admin"
|
||||
"github.com/ory/hydra-client-go/models"
|
||||
|
|
@ -18,12 +17,6 @@ import (
|
|||
"github.com/ory/x/urlx"
|
||||
)
|
||||
|
||||
func check(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func checkReq(w http.ResponseWriter, err error) bool {
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("%+v", err), 500)
|
||||
|
|
@ -33,16 +26,14 @@ func checkReq(w http.ResponseWriter, err error) bool {
|
|||
}
|
||||
|
||||
func main() {
|
||||
router := httprouter.New()
|
||||
|
||||
kratosPublicURL := urlx.ParseOrPanic(osx.GetenvDefault("KRATOS_PUBLIC_URL", "http://localhost:4433"))
|
||||
adminURL := urlx.ParseOrPanic(osx.GetenvDefault("HYDRA_ADMIN_URL", "http://localhost:4445"))
|
||||
hc := client.NewHTTPClientWithConfig(nil, &client.TransportConfig{Schemes: []string{adminURL.Scheme}, Host: adminURL.Host, BasePath: adminURL.Path})
|
||||
|
||||
router.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
http.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`ok`))
|
||||
})
|
||||
router.GET("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
http.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
|
||||
res, err := hc.Admin.GetLoginRequest(admin.NewGetLoginRequestParams().
|
||||
WithLoginChallenge(r.URL.Query().Get("login_challenge")))
|
||||
if !checkReq(w, err) {
|
||||
|
|
@ -73,7 +64,7 @@ func main() {
|
|||
</html>`, challenge)
|
||||
})
|
||||
|
||||
router.POST("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
http.HandleFunc("POST /login", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !checkReq(w, r.ParseForm()) {
|
||||
return
|
||||
}
|
||||
|
|
@ -98,7 +89,7 @@ func main() {
|
|||
http.Redirect(w, r, *res.Payload.RedirectTo, http.StatusFound)
|
||||
})
|
||||
|
||||
router.GET("/consent", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
http.HandleFunc("GET /consent", func(w http.ResponseWriter, r *http.Request) {
|
||||
res, err := hc.Admin.GetConsentRequest(admin.NewGetConsentRequestParams().
|
||||
WithConsentChallenge(r.URL.Query().Get("consent_challenge")))
|
||||
if !checkReq(w, err) {
|
||||
|
|
@ -136,7 +127,7 @@ func main() {
|
|||
</html>`, challenge, checkoxes)
|
||||
})
|
||||
|
||||
router.POST("/consent", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
http.HandleFunc("POST /consent", func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
if r.Form.Get("action") == "accept" {
|
||||
kratosConfig := kratos.NewConfiguration()
|
||||
|
|
@ -180,7 +171,6 @@ func main() {
|
|||
})
|
||||
|
||||
addr := ":" + osx.GetenvDefault("PORT", "4746")
|
||||
server := &http.Server{Addr: addr, Handler: router}
|
||||
fmt.Printf("Starting web server at %s\n", addr)
|
||||
check(server.ListenAndServe())
|
||||
log.Fatal(http.ListenAndServe(addr, nil))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,8 +7,6 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
client "github.com/ory/hydra-client-go/v2"
|
||||
|
||||
"github.com/ory/x/osx"
|
||||
|
|
@ -31,7 +29,7 @@ func checkReq(w http.ResponseWriter, err error) bool {
|
|||
}
|
||||
|
||||
func main() {
|
||||
router := httprouter.New()
|
||||
router := http.NewServeMux()
|
||||
|
||||
adminURL := urlx.ParseOrPanic(osx.GetenvDefault("HYDRA_ADMIN_URL", "http://localhost:4445"))
|
||||
cfg := client.NewConfiguration()
|
||||
|
|
@ -40,10 +38,10 @@ func main() {
|
|||
}
|
||||
hc := client.NewAPIClient(cfg)
|
||||
|
||||
router.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`ok`))
|
||||
})
|
||||
router.GET("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
|
||||
res, _, err := hc.OAuth2Api.GetOAuth2LoginRequest(r.Context()).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute()
|
||||
if !checkReq(w, err) {
|
||||
return
|
||||
|
|
@ -77,7 +75,7 @@ func main() {
|
|||
</html>`, challenge)
|
||||
})
|
||||
|
||||
router.POST("/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("POST /login", func(w http.ResponseWriter, r *http.Request) {
|
||||
check(r.ParseForm())
|
||||
remember := pointerx.Bool(r.Form.Get("remember") == "true")
|
||||
if r.Form.Get("action") == "accept" {
|
||||
|
|
@ -103,7 +101,7 @@ func main() {
|
|||
http.Redirect(w, r, res.RedirectTo, http.StatusFound)
|
||||
})
|
||||
|
||||
router.GET("/consent", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /consent", func(w http.ResponseWriter, r *http.Request) {
|
||||
res, _, err := hc.OAuth2Api.GetOAuth2ConsentRequest(r.Context()).ConsentChallenge(r.URL.Query().
|
||||
Get("consent_challenge")).Execute()
|
||||
if !checkReq(w, err) {
|
||||
|
|
@ -144,7 +142,7 @@ func main() {
|
|||
</html>`, challenge, checkoxes)
|
||||
})
|
||||
|
||||
router.POST("/consent", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("POST /consent", func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
remember := pointerx.Bool(r.Form.Get("remember") == "true")
|
||||
if r.Form.Get("action") == "accept" {
|
||||
|
|
|
|||
|
|
@ -1,13 +1,3 @@
|
|||
module github.com/ory/mock
|
||||
|
||||
go 1.24.4
|
||||
|
||||
require (
|
||||
github.com/julienschmidt/httprouter v1.3.0
|
||||
github.com/ory/graceful v0.1.3
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/stretchr/testify v1.7.0 // indirect
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,18 +0,0 @@
|
|||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
|
||||
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
||||
github.com/ory/graceful v0.1.3 h1:FaeXcHZh168WzS+bqruqWEw/HgXWLdNv2nJ+fbhxbhc=
|
||||
github.com/ory/graceful v0.1.3/go.mod h1:4zFz687IAF7oNHHiB586U4iL+/4aV09o/PYLE34t2bA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
@ -11,10 +11,6 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/ory/graceful"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -24,22 +20,16 @@ var (
|
|||
|
||||
func main() {
|
||||
port := cmp.Or(os.Getenv("PORT"), "4471")
|
||||
server := graceful.WithDefaults(&http.Server{Addr: fmt.Sprintf(":%s", port)})
|
||||
register(server)
|
||||
if err := graceful.Graceful(server.ListenAndServe, server.Shutdown); err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%s", port), nil))
|
||||
}
|
||||
|
||||
func register(server *http.Server) {
|
||||
router := httprouter.New()
|
||||
|
||||
router.GET("/health", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func init() {
|
||||
http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
router.GET("/documents/:id", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
id := ps.ByName("id")
|
||||
http.HandleFunc("GET /documents/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.PathValue("id")
|
||||
|
||||
documentsLock.RLock()
|
||||
doc, ok := documents[id]
|
||||
|
|
@ -52,10 +42,10 @@ func register(server *http.Server) {
|
|||
}
|
||||
})
|
||||
|
||||
router.PUT("/documents/:id", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
http.HandleFunc("PUT /documents/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
documentsLock.Lock()
|
||||
defer documentsLock.Unlock()
|
||||
id := ps.ByName("id")
|
||||
id := r.PathValue("id")
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
|
|
@ -66,14 +56,12 @@ func register(server *http.Server) {
|
|||
documents[id] = body
|
||||
})
|
||||
|
||||
router.DELETE("/documents/:id", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
http.HandleFunc("DELETE /documents/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
documentsLock.Lock()
|
||||
defer documentsLock.Unlock()
|
||||
id := ps.ByName("id")
|
||||
id := r.PathValue("id")
|
||||
|
||||
delete(documents, id)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
|
||||
server.Handler = router
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
|
@ -30,8 +30,8 @@ func TestSession(t *testing.T) {
|
|||
assert.EqualValues(t, 78652871, cookie.Options.MaxAge, "we ensure the options are always copied correctly.")
|
||||
}
|
||||
|
||||
router := httprouter.New()
|
||||
router.GET("/set", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("GET /set", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, SessionPersistValues(w, r, s, sid, map[string]interface{}{
|
||||
"string-1": "foo",
|
||||
"string-2": "bar",
|
||||
|
|
@ -55,7 +55,7 @@ func TestSession(t *testing.T) {
|
|||
|
||||
t.Run("case=GetString", func(t *testing.T) {
|
||||
id := "get-string"
|
||||
router.GET("/"+id, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /"+id, func(w http.ResponseWriter, r *http.Request) {
|
||||
got, err := SessionGetString(r, s, sid, "string-1")
|
||||
require.NoError(t, err)
|
||||
assert.EqualValues(t, "foo", got)
|
||||
|
|
@ -81,7 +81,7 @@ func TestSession(t *testing.T) {
|
|||
t.Run("case=GetStringMultipleCookies", func(t *testing.T) {
|
||||
id := "get-string-multiple"
|
||||
|
||||
router.GET("/set/"+id, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /set/"+id, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, SessionPersistValues(w, r, s, sid, map[string]interface{}{
|
||||
"multiple-string-1": "foo",
|
||||
}))
|
||||
|
|
@ -92,7 +92,7 @@ func TestSession(t *testing.T) {
|
|||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
|
||||
router.GET("/get/"+id, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /get/"+id, func(w http.ResponseWriter, r *http.Request) {
|
||||
got, err := SessionGetString(r, s, sid, "multiple-string-1")
|
||||
require.NoError(t, err)
|
||||
assert.EqualValues(t, "foo", got)
|
||||
|
|
@ -122,7 +122,7 @@ func TestSession(t *testing.T) {
|
|||
|
||||
t.Run("case=GetStringOr", func(t *testing.T) {
|
||||
id := "get-string-or"
|
||||
router.GET("/"+id, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /"+id, func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.EqualValues(t, "foo", SessionGetStringOr(r, s, sid, "string-1", "baz"))
|
||||
assert.EqualValues(t, "bar", SessionGetStringOr(r, s, sid, "string-2", "baz"))
|
||||
assert.EqualValues(t, "", SessionGetStringOr(r, s, sid, "string-3", "baz"))
|
||||
|
|
@ -189,14 +189,14 @@ func TestSession(t *testing.T) {
|
|||
})
|
||||
|
||||
id := "session-unset"
|
||||
router.GET("/"+id+"/unset", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /"+id+"/unset", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, SessionUnset(w, r, s, sid))
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
cookie, _ := s.Get(r, sid)
|
||||
assert.EqualValues(t, -1, cookie.Options.MaxAge, "we ensure the options are always copied correctly.")
|
||||
})
|
||||
|
||||
router.GET("/"+id+"/get", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /"+id+"/get", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Empty(t, SessionGetStringOr(r, s, sid, "string-1", ""))
|
||||
require.Empty(t, SessionGetStringOr(r, s, sid, "string-2", ""))
|
||||
require.Empty(t, SessionGetStringOr(r, s, sid, "string-3", ""))
|
||||
|
|
@ -231,13 +231,13 @@ func TestSession(t *testing.T) {
|
|||
})
|
||||
|
||||
id := "session-unset-key"
|
||||
router.GET("/"+id+"/unset", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /"+id+"/unset", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, SessionUnsetKey(w, r, s, sid, "string-1"))
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
isExpiryCorrect(t, r)
|
||||
})
|
||||
|
||||
router.GET("/"+id+"/expect-unset", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /"+id+"/expect-unset", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Empty(t, SessionGetStringOr(r, s, sid, "string-1", ""))
|
||||
require.Empty(t, SessionGetStringOr(r, s, sid, "string-2", ""))
|
||||
require.Empty(t, SessionGetStringOr(r, s, sid, "string-3", ""))
|
||||
|
|
@ -246,7 +246,7 @@ func TestSession(t *testing.T) {
|
|||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
|
||||
router.GET("/"+id+"/expect-one", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /"+id+"/expect-one", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Empty(t, SessionGetStringOr(r, s, sid, "string-1", ""))
|
||||
assert.EqualValues(t, "bar", SessionGetStringOr(r, s, sid, "string-2", "baz"))
|
||||
assert.EqualValues(t, "", SessionGetStringOr(r, s, sid, "string-3", "baz"))
|
||||
|
|
|
|||
|
|
@ -10,15 +10,14 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
func TestRedirectAdmin(t *testing.T) {
|
||||
router := httprouter.New()
|
||||
router.GET("/admin/identities", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("GET /admin/identities", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("identities"))
|
||||
})
|
||||
n := negroni.New()
|
||||
|
|
|
|||
10
x/nocache.go
10
x/nocache.go
|
|
@ -5,8 +5,6 @@ package x
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
)
|
||||
|
||||
// NoCache adds `Cache-Control: private, no-cache, no-store, must-revalidate` to the response header.
|
||||
|
|
@ -14,14 +12,6 @@ func NoCache(w http.ResponseWriter) {
|
|||
w.Header().Set("Cache-Control", "private, no-cache, no-store, must-revalidate")
|
||||
}
|
||||
|
||||
// NoCacheHandle wraps httprouter.Handle with `Cache-Control: private, no-cache, no-store, must-revalidate` headers.
|
||||
func NoCacheHandle(handle httprouter.Handle) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
NoCache(w)
|
||||
handle(w, r, ps)
|
||||
}
|
||||
}
|
||||
|
||||
// NoCacheHandlerFunc wraps http.HandlerFunc with `Cache-Control: private, no-cache, no-store, must-revalidate` headers.
|
||||
func NoCacheHandlerFunc(handle http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
|
|||
|
|
@ -8,14 +8,12 @@ import (
|
|||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/x"
|
||||
)
|
||||
|
||||
func RedirectToAdminRoute(reg config.Provider) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func RedirectToAdminRoute(reg config.Provider) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
admin := reg.Config().SelfAdminURL(r.Context())
|
||||
|
||||
dest := *r.URL
|
||||
|
|
@ -28,8 +26,8 @@ func RedirectToAdminRoute(reg config.Provider) httprouter.Handle {
|
|||
}
|
||||
}
|
||||
|
||||
func RedirectToPublicRoute(reg config.Provider) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
func RedirectToPublicRoute(reg config.Provider) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
public := reg.Config().SelfPublicURL(r.Context())
|
||||
|
||||
dest := *r.URL
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import (
|
|||
|
||||
"github.com/ory/x/configx"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
|
|
@ -38,13 +37,13 @@ func TestRedirectToPublicAdminRoute(t *testing.T) {
|
|||
|
||||
pub.POST("/privileged", redir.RedirectToAdminRoute(reg))
|
||||
pub.POST("/admin/privileged", redir.RedirectToAdminRoute(reg))
|
||||
adm.POST("/privileged", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
adm.POST("/privileged", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
_, _ = w.Write(body)
|
||||
})
|
||||
|
||||
adm.POST("/read", redir.RedirectToPublicRoute(reg))
|
||||
pub.POST("/read", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
pub.POST("/read", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
_, _ = w.Write(body)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ import (
|
|||
|
||||
"github.com/ory/kratos/x/redir"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
|
|
@ -31,14 +30,14 @@ func TestSecureContentNegotiationRedirection(t *testing.T) {
|
|||
var jsonActual = json.RawMessage(`{"foo":"bar"}` + "\n")
|
||||
writer := herodot.NewJSONWriter(nil)
|
||||
|
||||
router := httprouter.New()
|
||||
router.GET("/redir", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("GET /redir", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, redir.SecureContentNegotiationRedirection(w, r, jsonActual, x.RequestURL(r).String(), writer, conf))
|
||||
})
|
||||
router.GET("/default-return-to", func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||
router.HandleFunc("GET /default-return-to", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
router.GET("/return-to", func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||
router.HandleFunc("GET /return-to", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
|
||||
|
|
|
|||
140
x/router.go
140
x/router.go
|
|
@ -5,105 +5,161 @@ package x
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
)
|
||||
|
||||
type RouterPublic struct {
|
||||
*httprouter.Router
|
||||
mux *http.ServeMux
|
||||
}
|
||||
|
||||
func NewRouterPublic() *RouterPublic {
|
||||
return &RouterPublic{
|
||||
Router: httprouter.New(),
|
||||
mux: http.NewServeMux(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RouterPublic) GET(path string, handle httprouter.Handle) {
|
||||
r.Handle("GET", path, NoCacheHandle(handle))
|
||||
func (r *RouterPublic) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
r.mux.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
func (r *RouterPublic) HEAD(path string, handle httprouter.Handle) {
|
||||
r.Handle("HEAD", path, NoCacheHandle(handle))
|
||||
func (r *RouterPublic) GET(path string, handler http.HandlerFunc) {
|
||||
r.HandlerFunc("GET", path, handler)
|
||||
}
|
||||
|
||||
func (r *RouterPublic) POST(path string, handle httprouter.Handle) {
|
||||
r.Handle("POST", path, NoCacheHandle(handle))
|
||||
func (r *RouterPublic) HEAD(path string, handler http.HandlerFunc) {
|
||||
r.HandlerFunc("HEAD", path, handler)
|
||||
}
|
||||
|
||||
func (r *RouterPublic) PUT(path string, handle httprouter.Handle) {
|
||||
r.Handle("PUT", path, NoCacheHandle(handle))
|
||||
func (r *RouterPublic) POST(path string, handler http.HandlerFunc) {
|
||||
r.HandlerFunc("POST", path, handler)
|
||||
}
|
||||
|
||||
func (r *RouterPublic) PATCH(path string, handle httprouter.Handle) {
|
||||
r.Handle("PATCH", path, NoCacheHandle(handle))
|
||||
func (r *RouterPublic) PUT(path string, handler http.HandlerFunc) {
|
||||
r.HandlerFunc("PUT", path, handler)
|
||||
}
|
||||
|
||||
func (r *RouterPublic) DELETE(path string, handle httprouter.Handle) {
|
||||
r.Handle("DELETE", path, NoCacheHandle(handle))
|
||||
func (r *RouterPublic) PATCH(path string, handler http.HandlerFunc) {
|
||||
r.HandlerFunc("PATCH", path, handler)
|
||||
}
|
||||
|
||||
func (r *RouterPublic) Handle(method, path string, handle httprouter.Handle) {
|
||||
r.Router.Handle(method, path, NoCacheHandle(handle))
|
||||
func (r *RouterPublic) DELETE(path string, handler http.HandlerFunc) {
|
||||
r.HandlerFunc("DELETE", path, handler)
|
||||
}
|
||||
|
||||
func (r *RouterPublic) HandlerFunc(method, path string, handler http.HandlerFunc) {
|
||||
r.Router.HandlerFunc(method, path, NoCacheHandlerFunc(handler))
|
||||
func (r *RouterPublic) Handle(method, route string, handle http.HandlerFunc) {
|
||||
for _, pattern := range []string{
|
||||
method + " " + path.Join(route),
|
||||
method + " " + path.Join(route, "{$}"),
|
||||
} {
|
||||
r.mux.HandleFunc(pattern, func(w http.ResponseWriter, req *http.Request) {
|
||||
NoCache(w)
|
||||
handle(w, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RouterPublic) HandlerFunc(method, route string, handler http.HandlerFunc) {
|
||||
for _, pattern := range []string{
|
||||
method + " " + path.Join(route),
|
||||
method + " " + path.Join(route, "{$}"),
|
||||
} {
|
||||
r.mux.HandleFunc(pattern, NoCacheHandlerFunc(handler))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RouterPublic) HandleFunc(pattern string, handler http.HandlerFunc) {
|
||||
for _, pattern := range []string{
|
||||
path.Join(pattern),
|
||||
path.Join(pattern, "{$}"),
|
||||
} {
|
||||
r.mux.HandleFunc(pattern, NoCacheHandlerFunc(handler))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RouterPublic) Handler(method, path string, handler http.Handler) {
|
||||
r.Router.Handler(method, path, NoCacheHandler(handler))
|
||||
route := method + " " + path
|
||||
r.mux.Handle(route, NoCacheHandler(handler))
|
||||
}
|
||||
|
||||
type RouterAdmin struct {
|
||||
*httprouter.Router
|
||||
func (r *RouterPublic) HasRoute(method, path string) bool {
|
||||
_, pattern := r.mux.Handler(httptest.NewRequest(method, path, nil))
|
||||
return pattern != ""
|
||||
}
|
||||
|
||||
type RouterAdmin struct{ mux *http.ServeMux }
|
||||
|
||||
func NewRouterAdmin() *RouterAdmin {
|
||||
return &RouterAdmin{
|
||||
Router: httprouter.New(),
|
||||
mux: http.NewServeMux(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RouterAdmin) GET(publicPath string, handle httprouter.Handle) {
|
||||
r.Router.GET(path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
|
||||
func (r *RouterAdmin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
r.mux.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
func (r *RouterAdmin) HEAD(publicPath string, handle httprouter.Handle) {
|
||||
r.Router.HEAD(path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
|
||||
func (r *RouterAdmin) GET(publicPath string, handler http.HandlerFunc) {
|
||||
r.HandlerFunc("GET", publicPath, handler)
|
||||
}
|
||||
|
||||
func (r *RouterAdmin) POST(publicPath string, handle httprouter.Handle) {
|
||||
r.Router.POST(path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
|
||||
func (r *RouterAdmin) HEAD(publicPath string, handler http.HandlerFunc) {
|
||||
r.HandlerFunc("HEAD", publicPath, handler)
|
||||
}
|
||||
|
||||
func (r *RouterAdmin) PUT(publicPath string, handle httprouter.Handle) {
|
||||
r.Router.PUT(path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
|
||||
func (r *RouterAdmin) POST(publicPath string, handler http.HandlerFunc) {
|
||||
r.HandlerFunc("POST", publicPath, handler)
|
||||
}
|
||||
|
||||
func (r *RouterAdmin) PATCH(publicPath string, handle httprouter.Handle) {
|
||||
r.Router.PATCH(path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
|
||||
func (r *RouterAdmin) PUT(publicPath string, handler http.HandlerFunc) {
|
||||
r.HandlerFunc("PUT", publicPath, handler)
|
||||
}
|
||||
|
||||
func (r *RouterAdmin) DELETE(publicPath string, handle httprouter.Handle) {
|
||||
r.Router.DELETE(path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
|
||||
func (r *RouterAdmin) PATCH(publicPath string, handler http.HandlerFunc) {
|
||||
r.HandlerFunc("PATCH", publicPath, handler)
|
||||
}
|
||||
|
||||
func (r *RouterAdmin) Handle(method, publicPath string, handle httprouter.Handle) {
|
||||
r.Router.Handle(method, path.Join(AdminPrefix, publicPath), NoCacheHandle(handle))
|
||||
func (r *RouterAdmin) DELETE(publicPath string, handler http.HandlerFunc) {
|
||||
r.HandlerFunc("DELETE", publicPath, handler)
|
||||
}
|
||||
|
||||
func (r *RouterAdmin) Handle(method, publicPath string, handle http.HandlerFunc) {
|
||||
for _, pattern := range []string{
|
||||
method + " " + path.Join(AdminPrefix, publicPath),
|
||||
method + " " + path.Join(AdminPrefix, publicPath, "{$}"),
|
||||
} {
|
||||
r.mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) {
|
||||
NoCache(w)
|
||||
handle(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RouterAdmin) HandlerFunc(method, publicPath string, handler http.HandlerFunc) {
|
||||
r.Router.HandlerFunc(method, path.Join(AdminPrefix, publicPath), NoCacheHandlerFunc(handler))
|
||||
for _, pattern := range []string{
|
||||
method + " " + path.Join(AdminPrefix, publicPath),
|
||||
method + " " + path.Join(AdminPrefix, publicPath, "{$}"),
|
||||
} {
|
||||
r.mux.HandleFunc(pattern, NoCacheHandlerFunc(handler))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RouterAdmin) Handler(method, publicPath string, handler http.Handler) {
|
||||
r.Router.Handler(method, path.Join(AdminPrefix, publicPath), NoCacheHandler(handler))
|
||||
for _, pattern := range []string{
|
||||
method + " " + path.Join(AdminPrefix, publicPath),
|
||||
method + " " + path.Join(AdminPrefix, publicPath, "{$}"),
|
||||
} {
|
||||
r.mux.Handle(pattern, NoCacheHandler(handler))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RouterAdmin) Lookup(method, publicPath string) {
|
||||
r.Router.Lookup(method, path.Join(AdminPrefix, publicPath))
|
||||
func (r *RouterAdmin) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
|
||||
for _, p := range []string{
|
||||
path.Join(pattern),
|
||||
path.Join(pattern, "{$}"),
|
||||
} {
|
||||
r.mux.HandleFunc(p, NoCacheHandlerFunc(handler))
|
||||
}
|
||||
}
|
||||
|
||||
type HandlerRegistrar interface {
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/gobuffalo/httptest"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
|
@ -23,19 +23,19 @@ func TestCacheHandling(t *testing.T) {
|
|||
ts := httptest.NewServer(router)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
router.GET("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /foo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
router.DELETE("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("DELETE /foo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
router.POST("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("POST /foo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
router.PUT("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("PUT /foo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
router.PATCH("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("PATCH /foo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
|
||||
|
|
@ -52,19 +52,19 @@ func TestAdminPrefix(t *testing.T) {
|
|||
ts := httptest.NewServer(router)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
router.GET("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("GET /foo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
router.DELETE("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("DELETE /foo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
router.POST("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("POST /foo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
router.PUT("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("PUT /foo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
router.PATCH("/foo", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
router.HandleFunc("PATCH /foo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,6 @@ import (
|
|||
_ "embed"
|
||||
"net/http"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/ory/kratos/x"
|
||||
)
|
||||
|
||||
|
|
@ -45,8 +43,8 @@ type webAuthnJavaScript string
|
|||
// Responses:
|
||||
// 200: webAuthnJavaScript
|
||||
func RegisterWebauthnRoute(r *x.RouterPublic) {
|
||||
if handle, _, _ := r.Lookup("GET", ScriptURL); handle == nil {
|
||||
r.GET(ScriptURL, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
if !r.HasRoute("GET", ScriptURL) {
|
||||
r.GET(ScriptURL, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/javascript; charset=UTF-8")
|
||||
_, _ = w.Write(jsOnLoad)
|
||||
})
|
||||
|
|
|
|||
Loading…
Reference in New Issue