chore(kratos): simplify internal APIs

GitOrigin-RevId: 1f209ed68a7e72222b2a29b37bc017ffc5c0cdb4
This commit is contained in:
Patrik 2025-08-19 10:21:10 +02:00 committed by ory-bot
parent 6b2b90e53c
commit 56525d4ef9
32 changed files with 179 additions and 481 deletions

View File

@ -4,7 +4,6 @@
package continuity
import (
"context"
"time"
"github.com/gofrs/uuid"
@ -43,9 +42,7 @@ func (c *Container) UTC() *Container {
return c
}
func (c Container) TableName(ctx context.Context) string {
return "continuity_containers"
}
func (_ Container) TableName() string { return "continuity_containers" }
func NewContainer(name string, o managerOptions) *Container {
return &Container{

View File

@ -4,7 +4,6 @@
package courier
import (
"context"
"encoding/json"
"time"
@ -220,14 +219,5 @@ func (m Message) DefaultPageToken() keysetpagination.PageToken {
}
}
func (m Message) TableName(context.Context) string {
return "courier_messages"
}
func (m *Message) GetID() uuid.UUID {
return m.ID
}
func (m *Message) GetNID() uuid.UUID {
return m.NID
}
func (m Message) TableName() string { return "courier_messages" }
func (m *Message) GetID() uuid.UUID { return m.ID }

View File

@ -671,7 +671,7 @@ func (m *RegistryDefault) Init(ctx context.Context, ctxer contextx.Contextualize
m.Logger().WithError(err).Warnf("Unable to open database, retrying.")
return errors.WithStack(err)
}
p, err := sql.NewPersister(ctx, m, c,
p, err := sql.NewPersister(m, c,
sql.WithExtraMigrations(o.extraMigrations...),
sql.WithExtraGoMigrations(o.extraGoMigrations...),
sql.WithDisabledLogging(o.disableMigrationLogging))

View File

@ -357,10 +357,6 @@ func (i Identity) GetID() uuid.UUID {
return i.ID
}
func (i Identity) GetNID() uuid.UUID {
return i.NID
}
func (i Identity) MarshalJSON() ([]byte, error) {
type localIdentity Identity
i.Credentials = nil

View File

@ -4,7 +4,6 @@
package identity
import (
"context"
"fmt"
"time"
@ -54,17 +53,8 @@ func (v RecoveryAddressType) HTMLFormInputType() string {
return ""
}
func (a RecoveryAddress) TableName(ctx context.Context) string {
return "identity_recovery_addresses"
}
func (a RecoveryAddress) ValidateNID() error {
return nil
}
func (a RecoveryAddress) GetID() uuid.UUID {
return a.ID
}
func (a RecoveryAddress) TableName() string { return "identity_recovery_addresses" }
func (a RecoveryAddress) GetID() uuid.UUID { return a.ID }
// Hash returns a unique string representation for the recovery address.
func (a RecoveryAddress) Hash() string {

View File

@ -108,14 +108,6 @@ func (a VerifiableAddress) GetID() uuid.UUID {
return a.ID
}
func (a VerifiableAddress) GetNID() uuid.UUID {
return a.NID
}
func (a VerifiableAddress) ValidateNID() error {
return nil
}
// Hash returns a unique string representation for the recovery address.
func (a VerifiableAddress) Hash() string {
return fmt.Sprintf("%v|%v|%v|%v|%v|%v|%v", a.Value, a.Verified, a.Via, a.Status, a.VerifiedAt, a.IdentityID, a.NID)

View File

@ -56,34 +56,34 @@ type (
}
)
type persisterOptions struct {
type options struct {
extraMigrations []fs.FS
extraGoMigrations popx.Migrations
disableLogging bool
}
type persisterOption func(o *persisterOptions)
type Option = func(o *options)
func WithExtraMigrations(fss ...fs.FS) persisterOption {
return func(o *persisterOptions) {
func WithExtraMigrations(fss ...fs.FS) Option {
return func(o *options) {
o.extraMigrations = fss
}
}
func WithExtraGoMigrations(ms ...popx.Migration) persisterOption {
return func(o *persisterOptions) {
func WithExtraGoMigrations(ms ...popx.Migration) Option {
return func(o *options) {
o.extraGoMigrations = ms
}
}
func WithDisabledLogging(v bool) persisterOption {
return func(o *persisterOptions) {
func WithDisabledLogging(v bool) Option {
return func(o *options) {
o.disableLogging = v
}
}
func NewPersister(ctx context.Context, r persisterDependencies, c *pop.Connection, opts ...persisterOption) (*Persister, error) {
o := &persisterOptions{}
func NewPersister(r persisterDependencies, c *pop.Connection, opts ...Option) (*Persister, error) {
o := &options{}
for _, f := range opts {
f(o)
}

View File

@ -42,7 +42,7 @@ func withCheckIdentityID(id uuid.UUID) codeOption {
func useOneTimeCode[P any, U interface {
*P
oneTimeCodeProvider
}](ctx context.Context, p *Persister, flowID uuid.UUID, userProvidedCode string, flowTableName string, foreignKeyName string, opts ...codeOption,
}](ctx context.Context, p *Persister, flowID uuid.UUID, userProvidedCode, flowTableName, foreignKeyName string, opts ...codeOption,
) (target U, err error) {
maxSubmissions := p.r.Config().SelfServiceCodeMethodMaxSubmissions(ctx)
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.useOneTimeCode", trace.WithAttributes(attribute.Int("max_submissions", maxSubmissions)))

View File

@ -63,7 +63,7 @@ func (p *Persister) DeleteContinuitySession(ctx context.Context, id uuid.UUID) (
if count, err := p.GetConnection(ctx).RawQuery(
//#nosec G201 -- TableName is static
fmt.Sprintf("DELETE FROM %s WHERE id=? AND nid=?",
new(continuity.Container).TableName(ctx)), id, p.NetworkID(ctx)).ExecWithCount(); err != nil {
continuity.Container{}.TableName()), id, p.NetworkID(ctx)).ExecWithCount(); err != nil {
return sqlcon.HandleError(err)
} else if count == 0 {
return errors.WithStack(sqlcon.ErrNoRows)
@ -76,16 +76,13 @@ func (p *Persister) DeleteExpiredContinuitySessions(ctx context.Context, expires
defer otelx.End(span, &err)
//#nosec G201 -- TableName is static
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(continuity.Container).TableName(ctx),
new(continuity.Container).TableName(ctx),
limit,
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
continuity.Container{}.TableName(),
),
expiresAt,
p.NetworkID(ctx),
limit,
).Exec()
if err != nil {
return sqlcon.HandleError(err)
}
return nil
return sqlcon.HandleError(err)
}

View File

@ -70,7 +70,7 @@ func TestPersisterHMAC(t *testing.T) {
conf := config.MustNew(t, logrusx.New("", ""), contextx.NewTestConfigProvider(embedx.ConfigSchema, opts...), opts...)
c, err := pop.NewConnection(&pop.ConnectionDetails{URL: "sqlite://foo?mode=memory"})
require.NoError(t, err)
p, err := NewPersister(ctx, &logRegistryOnly{c: conf}, c)
p, err := NewPersister(&logRegistryOnly{c: conf}, c)
require.NoError(t, err)
t.Run("case=behaves deterministically", func(t *testing.T) {

View File

@ -72,16 +72,13 @@ func (p *Persister) DeleteExpiredLoginFlows(ctx context.Context, expiresAt time.
defer otelx.End(span, &err)
//#nosec G201 -- TableName is static
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(login.Flow).TableName(ctx),
new(login.Flow).TableName(ctx),
limit,
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
login.Flow{}.TableName(),
),
expiresAt,
p.NetworkID(ctx),
limit,
).Exec()
if err != nil {
return sqlcon.HandleError(err)
}
return nil
return sqlcon.HandleError(err)
}

View File

@ -43,7 +43,7 @@ func (p *Persister) UseLoginCode(ctx context.Context, flowID uuid.UUID, identity
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseLoginCode")
defer otelx.End(span, &err)
codeRow, err := useOneTimeCode[code.LoginCode](ctx, p, flowID, userProvidedCode, new(login.Flow).TableName(ctx), "selfservice_login_flow_id", withCheckIdentityID(identityID))
codeRow, err := useOneTimeCode[code.LoginCode](ctx, p, flowID, userProvidedCode, login.Flow{}.TableName(), "selfservice_login_flow_id", withCheckIdentityID(identityID))
if err != nil {
return nil, err
}

View File

@ -123,16 +123,13 @@ func (p *Persister) DeleteExpiredRecoveryFlows(ctx context.Context, expiresAt ti
defer otelx.End(span, &err)
//#nosec G201 -- TableName is static
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(recovery.Flow).TableName(ctx),
new(recovery.Flow).TableName(ctx),
limit,
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
recovery.Flow{}.TableName(),
),
expiresAt,
p.NetworkID(ctx),
limit,
).Exec()
if err != nil {
return sqlcon.HandleError(err)
}
return nil
return sqlcon.HandleError(err)
}

View File

@ -58,7 +58,7 @@ func (p *Persister) UseRecoveryCode(ctx context.Context, flowID uuid.UUID, userP
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRecoveryCode")
defer otelx.End(span, &err)
codeRow, err := useOneTimeCode[code.RecoveryCode](ctx, p, flowID, userProvidedCode, new(recovery.Flow).TableName(ctx), "selfservice_recovery_flow_id")
codeRow, err := useOneTimeCode[code.RecoveryCode](ctx, p, flowID, userProvidedCode, recovery.Flow{}.TableName(), "selfservice_recovery_flow_id")
if err != nil {
return nil, err
}

View File

@ -54,16 +54,13 @@ func (p *Persister) DeleteExpiredRegistrationFlows(ctx context.Context, expiresA
defer otelx.End(span, &err)
//#nosec G201 -- TableName is static
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(registration.Flow).TableName(ctx),
new(registration.Flow).TableName(ctx),
limit,
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
registration.Flow{}.TableName(),
),
expiresAt,
p.NetworkID(ctx),
limit,
).Exec()
if err != nil {
return sqlcon.HandleError(err)
}
return nil
return sqlcon.HandleError(err)
}

View File

@ -44,7 +44,7 @@ func (p *Persister) UseRegistrationCode(ctx context.Context, flowID uuid.UUID, u
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRegistrationCode")
defer otelx.End(span, &err)
codeRow, err := useOneTimeCode[code.RegistrationCode](ctx, p, flowID, userProvidedCode, new(registration.Flow).TableName(ctx), "selfservice_registration_flow_id")
codeRow, err := useOneTimeCode[code.RegistrationCode](ctx, p, flowID, userProvidedCode, registration.Flow{}.TableName(), "selfservice_registration_flow_id")
if err != nil {
return nil, err
}

View File

@ -304,7 +304,7 @@ func (p *Persister) DeleteSession(ctx context.Context, sid uuid.UUID) (err error
nid := p.NetworkID(ctx)
//#nosec G201 -- TableName is static
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ? AND nid = ?", new(session.Session).TableName(ctx)),
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ? AND nid = ?", session.Session{}.TableName()),
sid,
nid,
).ExecWithCount()
@ -324,7 +324,7 @@ func (p *Persister) DeleteSessionsByIdentity(ctx context.Context, identityID uui
//#nosec G201 -- TableName is static
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE identity_id = ? AND nid = ?",
new(session.Session).TableName(ctx),
session.Session{}.TableName(),
),
identityID,
p.NetworkID(ctx),
@ -390,7 +390,7 @@ func (p *Persister) DeleteSessionByToken(ctx context.Context, token string) (err
//#nosec G201 -- TableName is static
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE token = ? AND nid = ?",
new(session.Session).TableName(ctx),
session.Session{}.TableName(),
),
token,
p.NetworkID(ctx),
@ -411,7 +411,7 @@ func (p *Persister) RevokeSessionByToken(ctx context.Context, token string) (err
//#nosec G201 -- TableName is static
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"UPDATE %s SET active = false WHERE token = ? AND nid = ?",
new(session.Session).TableName(ctx),
session.Session{}.TableName(),
),
token,
p.NetworkID(ctx),
@ -433,7 +433,7 @@ func (p *Persister) RevokeSessionById(ctx context.Context, sID uuid.UUID) (err e
//#nosec G201 -- TableName is static
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"UPDATE %s SET active = false WHERE id = ? AND nid = ?",
new(session.Session).TableName(ctx),
session.Session{}.TableName(),
),
sID,
p.NetworkID(ctx),
@ -456,7 +456,7 @@ func (p *Persister) RevokeSession(ctx context.Context, iID, sID uuid.UUID) (err
//#nosec G201 -- TableName is static
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"UPDATE %s SET active = false WHERE id = ? AND identity_id = ? AND nid = ?",
new(session.Session).TableName(ctx),
session.Session{}.TableName(),
),
sID,
iID,
@ -476,7 +476,7 @@ func (p *Persister) RevokeSessionsIdentityExcept(ctx context.Context, iID, sID u
//#nosec G201 -- TableName is static
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"UPDATE %s SET active = false WHERE identity_id = ? AND id != ? AND nid = ?",
new(session.Session).TableName(ctx),
session.Session{}.TableName(),
),
iID,
sID,
@ -494,16 +494,13 @@ func (p *Persister) DeleteExpiredSessions(ctx context.Context, expiresAt time.Ti
//#nosec G201 -- TableName is static
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(session.Session).TableName(ctx),
new(session.Session).TableName(ctx),
limit,
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
session.Session{}.TableName(),
),
expiresAt,
p.NetworkID(ctx),
limit,
).Exec()
if err != nil {
return sqlcon.HandleError(err)
}
return nil
return sqlcon.HandleError(err)
}

View File

@ -114,13 +114,12 @@ func (p *Persister) DeleteExpiredExchangers(ctx context.Context, at time.Time, l
//#nosec G201 -- TableName is static
err := conn.RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE created_at <= ? and nid = ? ORDER BY created_at ASC LIMIT %d ) AS s )",
conn.Dialect.Quote(new(sessiontokenexchange.Exchanger).TableName()),
conn.Dialect.Quote(new(sessiontokenexchange.Exchanger).TableName()),
limit,
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE created_at <= ? and nid = ? ORDER BY created_at ASC LIMIT ?) AS s)",
sessiontokenexchange.Exchanger{}.TableName(),
),
expiredAfter,
p.NetworkID(ctx),
limit,
).Exec()
return sqlcon.HandleError(err)

View File

@ -64,16 +64,13 @@ func (p *Persister) DeleteExpiredSettingsFlows(ctx context.Context, expiresAt ti
defer otelx.End(span, &err)
//#nosec G201 -- TableName is static
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(settings.Flow).TableName(ctx),
new(settings.Flow).TableName(ctx),
limit,
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
settings.Flow{}.TableName(),
),
expiresAt,
p.NetworkID(ctx),
limit,
).Exec()
if err != nil {
return sqlcon.HandleError(err)
}
return nil
return sqlcon.HandleError(err)
}

View File

@ -45,6 +45,7 @@ import (
"github.com/ory/kratos/x"
"github.com/ory/pop/v6"
"github.com/ory/pop/v6/logging"
"github.com/ory/x/popx"
"github.com/ory/x/sqlcon"
"github.com/ory/x/sqlcon/dockertest"
"github.com/ory/x/sqlxx"
@ -315,7 +316,7 @@ func TestPersister_Transaction(t *testing.T) {
ID: x.NewUUID(),
}
err := c.Transaction(func(tx *pop.Connection) error {
ctx := sql.WithTransaction(context.Background(), tx)
ctx := popx.WithTransaction(context.Background(), tx)
require.NoError(t, p.CreateLoginFlow(ctx, lr), "%+v", lr)
require.NoError(t, getErr(p.GetLoginFlow(ctx, lr.ID)), "%+v", lr)
return errors.New(errMessage)
@ -334,9 +335,7 @@ func Benchmark_BatchCreateIdentities(b *testing.B) {
batchSizes := []int{1, 10, 100, 500, 800, 900, 1000, 2000, 3000}
parallelRequests := []int{1, 4, 8, 16}
for name := range conns {
name := name
reg := conns[name]
for name, reg := range conns {
b.Run(fmt.Sprintf("database=%s", name), func(b *testing.B) {
conf := reg.Config()
_, p := testhelpers.NewNetwork(b, ctx, reg.Persister())

View File

@ -11,10 +11,6 @@ import (
"github.com/ory/pop/v6"
)
func WithTransaction(ctx context.Context, tx *pop.Connection) context.Context {
return popx.WithTransaction(ctx, tx)
}
func (p *Persister) Transaction(ctx context.Context, callback func(ctx context.Context, connection *pop.Connection) error) error {
return popx.Transaction(ctx, p.c.WithContext(ctx), callback)
}

View File

@ -121,16 +121,13 @@ func (p *Persister) DeleteExpiredVerificationFlows(ctx context.Context, expiresA
defer otelx.End(span, &err)
//#nosec G201 -- TableName is static
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(verification.Flow).TableName(ctx),
new(verification.Flow).TableName(ctx),
limit,
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
verification.Flow{}.TableName(),
),
expiresAt,
p.NetworkID(ctx),
limit,
).Exec()
if err != nil {
return sqlcon.HandleError(err)
}
return nil
return sqlcon.HandleError(err)
}

View File

@ -55,7 +55,7 @@ func (p *Persister) UseVerificationCode(ctx context.Context, flowID uuid.UUID, u
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseVerificationCode")
defer otelx.End(span, &err)
codeRow, err := useOneTimeCode[code.VerificationCode](ctx, p, flowID, userProvidedCode, new(verification.Flow).TableName(ctx), "selfservice_verification_flow_id")
codeRow, err := useOneTimeCode[code.VerificationCode](ctx, p, flowID, userProvidedCode, verification.Flow{}.TableName(), "selfservice_verification_flow_id")
if err != nil {
return nil, err
}

View File

@ -7,7 +7,6 @@ import (
"context"
"fmt"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/trace"
@ -17,12 +16,7 @@ import (
"github.com/ory/x/sqlcon"
)
type Model interface {
GetID() uuid.UUID
GetNID() uuid.UUID
}
func Generic(ctx context.Context, c *pop.Connection, tracer trace.Tracer, v Model, columnNames ...string) (err error) {
func Generic(ctx context.Context, c *pop.Connection, tracer trace.Tracer, v any, columnNames ...string) (err error) {
ctx, span := tracer.Start(ctx, "persistence.sql.update")
defer otelx.End(span, &err)

View File

@ -4,9 +4,9 @@
package login
import (
"cmp"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
@ -14,7 +14,6 @@ import (
"time"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/tidwall/gjson"
@ -28,7 +27,6 @@ import (
"github.com/ory/kratos/x/redir"
"github.com/ory/pop/v6"
"github.com/ory/x/sqlxx"
"github.com/ory/x/stringsx"
"github.com/ory/x/urlx"
)
@ -150,7 +148,7 @@ type Flow struct {
// ReturnToVerification contains the redirect URL for the verification flow.
ReturnToVerification string `json:"-" db:"-"`
isAccountLinkingFlow bool `json:"-" db:"-"`
isAccountLinkingFlow bool `db:"-"`
// IdentitySchema optionally holds the ID of the identity schema that is used
// for this flow. This value can be set by the user when creating the flow and
@ -158,7 +156,8 @@ type Flow struct {
IdentitySchema flow.IdentitySchema `json:"identity_schema,omitempty" faker:"-" db:"identity_schema_id"`
}
var _ flow.Flow = new(Flow)
var _ flow.Flow = (*Flow)(nil)
var _ flow.FlowWithContinueWith = (*Flow)(nil)
func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Request, flowType flow.Type) (*Flow, error) {
now := time.Now().UTC()
@ -204,7 +203,7 @@ func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Reques
CSRFToken: csrf,
Type: flowType,
Refresh: refresh,
RequestedAAL: identity.AuthenticatorAssuranceLevel(strings.ToLower(stringsx.Coalesce(
RequestedAAL: identity.AuthenticatorAssuranceLevel(strings.ToLower(cmp.Or(
r.URL.Query().Get("aal"),
string(identity.AuthenticatorAssuranceLevel1)))),
InternalContext: []byte("{}"),
@ -213,21 +212,25 @@ func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Reques
}, nil
}
func (f *Flow) GetType() flow.Type {
return f.Type
}
func (f *Flow) GetType() flow.Type { return f.Type }
func (f *Flow) GetRequestURL() string { return f.RequestURL }
func (f *Flow) GetID() uuid.UUID { return f.ID }
func (f *Flow) GetInternalContext() sqlxx.JSONRawMessage { return f.InternalContext }
func (f *Flow) SetInternalContext(bytes sqlxx.JSONRawMessage) { f.InternalContext = bytes }
func (f *Flow) GetUI() *container.Container { return f.UI }
func (f *Flow) GetState() flow.State { return f.State }
func (_ *Flow) GetFlowName() flow.FlowName { return flow.LoginFlow }
func (_ Flow) TableName() string { return "selfservice_login_flows" }
func (f *Flow) ContinueWith() []flow.ContinueWith { return f.ContinueWithItems }
func (f *Flow) SetReturnToVerification(to string) { f.ReturnToVerification = to }
func (f *Flow) GetOAuth2LoginChallenge() sqlxx.NullString { return f.OAuth2LoginChallenge }
func (f *Flow) AppendTo(src *url.URL) *url.URL { return flow.AppendFlowTo(src, f.ID) }
func (f *Flow) SetState(state flow.State) { f.State = state }
func (f *Flow) GetTransientPayload() json.RawMessage { return f.TransientPayload }
func (f *Flow) GetRequestURL() string {
return f.RequestURL
}
func (f Flow) TableName(ctx context.Context) string {
return "selfservice_login_flows"
}
func (f Flow) WhereID(ctx context.Context, alias string) string {
return fmt.Sprintf("%s.%s = ? AND %s.%s = ?", alias, "id", alias, "nid")
}
// IsRefresh returns true if the login flow was triggered to re-authenticate the user.
// This is the case if the refresh query parameter is set to true.
func (f *Flow) IsRefresh() bool { return f.Refresh }
func (f *Flow) Valid() error {
if f.ExpiresAt.Before(time.Now()) {
@ -236,38 +239,12 @@ func (f *Flow) Valid() error {
return nil
}
func (f Flow) GetID() uuid.UUID {
return f.ID
}
// IsRefresh returns true if the login flow was triggered to re-authenticate the user.
// This is the case if the refresh query parameter is set to true.
func (f *Flow) IsRefresh() bool {
return f.Refresh
}
func (f *Flow) AppendTo(src *url.URL) *url.URL {
return flow.AppendFlowTo(src, f.ID)
}
func (f Flow) GetNID() uuid.UUID {
return f.NID
}
func (f *Flow) EnsureInternalContext() {
if !gjson.ParseBytes(f.InternalContext).IsObject() {
f.InternalContext = []byte("{}")
}
}
func (f *Flow) GetInternalContext() sqlxx.JSONRawMessage {
return f.InternalContext
}
func (f *Flow) SetInternalContext(bytes sqlxx.JSONRawMessage) {
f.InternalContext = bytes
}
func (f Flow) MarshalJSON() ([]byte, error) {
type local Flow
f.SetReturnTo()
@ -294,10 +271,6 @@ func (f *Flow) AfterSave(*pop.Connection) error {
return nil
}
func (f *Flow) GetUI() *container.Container {
return f.UI
}
func (f *Flow) SecureRedirectToOpts(ctx context.Context, cfg config.Provider) (opts []redir.SecureRedirectOption) {
return []redir.SecureRedirectOption{
redir.SecureRedirectReturnTo(f.ReturnTo),
@ -308,41 +281,15 @@ func (f *Flow) SecureRedirectToOpts(ctx context.Context, cfg config.Provider) (o
}
}
func (f *Flow) GetState() flow.State {
return flow.State(f.State)
}
func (f *Flow) GetFlowName() flow.FlowName {
return flow.LoginFlow
}
func (f *Flow) SetState(state flow.State) {
f.State = State(state)
}
func (t *Flow) GetTransientPayload() json.RawMessage {
return t.TransientPayload
}
var _ flow.FlowWithContinueWith = new(Flow)
func (f *Flow) AddContinueWith(c flow.ContinueWith) {
f.ContinueWithItems = append(f.ContinueWithItems, c)
}
func (f *Flow) ContinueWith() []flow.ContinueWith {
return f.ContinueWithItems
}
func (f *Flow) SetReturnToVerification(to string) {
f.ReturnToVerification = to
}
func (f *Flow) ToLoggerField() map[string]interface{} {
func (f *Flow) ToLoggerField() map[string]any {
if f == nil {
return map[string]interface{}{}
return map[string]any{}
}
return map[string]interface{}{
return map[string]any{
"id": f.ID.String(),
"return_to": f.ReturnTo,
"request_url": f.RequestURL,
@ -354,7 +301,3 @@ func (f *Flow) ToLoggerField() map[string]interface{} {
"requested_aal": f.RequestedAAL,
}
}
func (f *Flow) GetOAuth2LoginChallenge() sqlxx.NullString {
return f.OAuth2LoginChallenge
}

View File

@ -4,7 +4,6 @@
package recovery
import (
"context"
"encoding/json"
"net/http"
"net/url"
@ -112,7 +111,7 @@ type Flow struct {
TransientPayload json.RawMessage `json:"transient_payload,omitempty" faker:"-" db:"-"`
}
var _ flow.Flow = new(Flow)
var _ flow.Flow = (*Flow)(nil)
func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Request, strategy Strategy, ft flow.Type) (*Flow, error) {
now := time.Now().UTC()
@ -135,7 +134,7 @@ func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Reques
state = flow.StateRecoveryAwaitingAddress
}
flow := &Flow{
f := &Flow{
ID: id,
ExpiresAt: now.Add(exp),
IssuedAt: now,
@ -150,13 +149,13 @@ func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Reques
}
if strategy != nil {
flow.Active = sqlxx.NullString(strategy.NodeGroup())
if err := strategy.PopulateRecoveryMethod(r, flow); err != nil {
f.Active = sqlxx.NullString(strategy.NodeGroup())
if err := strategy.PopulateRecoveryMethod(r, f); err != nil {
return nil, err
}
}
return flow, nil
return f, nil
}
func FromOldFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Request, strategy Strategy, of Flow) (*Flow, error) {
@ -174,25 +173,15 @@ func FromOldFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Re
return nf, nil
}
func (f *Flow) GetType() flow.Type {
return f.Type
}
func (f *Flow) GetRequestURL() string {
return f.RequestURL
}
func (f Flow) TableName(ctx context.Context) string {
return "selfservice_recovery_flows"
}
func (f Flow) GetID() uuid.UUID {
return f.ID
}
func (f Flow) GetNID() uuid.UUID {
return f.NID
}
func (f *Flow) GetType() flow.Type { return f.Type }
func (f *Flow) GetRequestURL() string { return f.RequestURL }
func (_ Flow) TableName() string { return "selfservice_recovery_flows" }
func (f Flow) GetID() uuid.UUID { return f.ID }
func (f *Flow) GetUI() *container.Container { return f.UI }
func (f *Flow) GetState() State { return f.State }
func (_ *Flow) GetFlowName() flow.FlowName { return flow.RecoveryFlow }
func (f *Flow) SetState(state State) { f.State = state }
func (f *Flow) GetTransientPayload() json.RawMessage { return f.TransientPayload }
func (f *Flow) Valid() error {
if f.ExpiresAt.Before(time.Now().UTC()) {
@ -236,31 +225,11 @@ func (f *Flow) AfterSave(*pop.Connection) error {
return nil
}
func (f *Flow) GetUI() *container.Container {
return f.UI
}
func (f *Flow) GetState() State {
return f.State
}
func (f *Flow) GetFlowName() flow.FlowName {
return flow.RecoveryFlow
}
func (f *Flow) SetState(state State) {
f.State = state
}
func (t *Flow) GetTransientPayload() json.RawMessage {
return t.TransientPayload
}
func (f *Flow) ToLoggerField() map[string]interface{} {
func (f *Flow) ToLoggerField() map[string]any {
if f == nil {
return map[string]interface{}{}
return map[string]any{}
}
return map[string]interface{}{
return map[string]any{
"id": f.ID.String(),
"return_to": f.ReturnTo,
"request_url": f.RequestURL,

View File

@ -182,17 +182,20 @@ func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Reques
}, nil
}
func (f Flow) TableName(context.Context) string {
return "selfservice_registration_flows"
}
func (f Flow) GetID() uuid.UUID {
return f.ID
}
func (f Flow) GetNID() uuid.UUID {
return f.NID
}
func (_ Flow) TableName() string { return "selfservice_registration_flows" }
func (f Flow) GetID() uuid.UUID { return f.ID }
func (f *Flow) AppendTo(src *url.URL) *url.URL { return flow.AppendFlowTo(src, f.ID) }
func (f *Flow) GetType() flow.Type { return f.Type }
func (f *Flow) GetRequestURL() string { return f.RequestURL }
func (f *Flow) GetInternalContext() sqlxx.JSONRawMessage { return f.InternalContext }
func (f *Flow) SetInternalContext(bytes sqlxx.JSONRawMessage) { f.InternalContext = bytes }
func (f *Flow) GetUI() *container.Container { return f.UI }
func (f *Flow) GetState() State { return f.State }
func (_ *Flow) GetFlowName() flow.FlowName { return flow.RegistrationFlow }
func (f *Flow) SetState(state State) { f.State = state }
func (f *Flow) GetTransientPayload() json.RawMessage { return f.TransientPayload }
func (f *Flow) SetReturnToVerification(to string) { f.ReturnToVerification = to }
func (f *Flow) GetOAuth2LoginChallenge() sqlxx.NullString { return f.OAuth2LoginChallenge }
func (f *Flow) Valid() error {
if f.ExpiresAt.Before(time.Now()) {
@ -201,32 +204,12 @@ func (f *Flow) Valid() error {
return nil
}
func (f *Flow) AppendTo(src *url.URL) *url.URL {
return flow.AppendFlowTo(src, f.ID)
}
func (f *Flow) GetType() flow.Type {
return f.Type
}
func (f *Flow) GetRequestURL() string {
return f.RequestURL
}
func (f *Flow) EnsureInternalContext() {
if !gjson.ParseBytes(f.InternalContext).IsObject() {
f.InternalContext = []byte("{}")
}
}
func (f *Flow) GetInternalContext() sqlxx.JSONRawMessage {
return f.InternalContext
}
func (f *Flow) SetInternalContext(bytes sqlxx.JSONRawMessage) {
f.InternalContext = bytes
}
func (f Flow) MarshalJSON() ([]byte, error) {
type local Flow
f.SetReturnTo()
@ -253,17 +236,11 @@ func (f *Flow) AfterSave(*pop.Connection) error {
return nil
}
func (f *Flow) GetUI() *container.Container {
return f.UI
}
func (f *Flow) AddContinueWith(c flow.ContinueWith) {
f.ContinueWithItems = append(f.ContinueWithItems, c)
}
func (f *Flow) ContinueWith() []flow.ContinueWith {
return f.ContinueWithItems
}
func (f *Flow) ContinueWith() []flow.ContinueWith { return f.ContinueWithItems }
func (f *Flow) SecureRedirectToOpts(ctx context.Context, cfg config.Provider) (opts []redir.SecureRedirectOption) {
return []redir.SecureRedirectOption{
@ -275,31 +252,11 @@ func (f *Flow) SecureRedirectToOpts(ctx context.Context, cfg config.Provider) (o
}
}
func (f *Flow) GetState() State {
return f.State
}
func (f *Flow) GetFlowName() flow.FlowName {
return flow.RegistrationFlow
}
func (f *Flow) SetState(state State) {
f.State = state
}
func (f *Flow) GetTransientPayload() json.RawMessage {
return f.TransientPayload
}
func (f *Flow) SetReturnToVerification(to string) {
f.ReturnToVerification = to
}
func (f *Flow) ToLoggerField() map[string]interface{} {
func (f *Flow) ToLoggerField() map[string]any {
if f == nil {
return map[string]interface{}{}
return map[string]any{}
}
return map[string]interface{}{
return map[string]any{
"id": f.ID.String(),
"return_to": f.ReturnTo,
"request_url": f.RequestURL,
@ -309,7 +266,3 @@ func (f *Flow) ToLoggerField() map[string]interface{} {
"state": f.State,
}
}
func (f *Flow) GetOAuth2LoginChallenge() sqlxx.NullString {
return f.OAuth2LoginChallenge
}

View File

@ -4,39 +4,29 @@
package settings
import (
"context"
"encoding/json"
"net/http"
"net/url"
"time"
"github.com/ory/kratos/x/redir"
"github.com/ory/pop/v6"
"github.com/ory/kratos/text"
"github.com/tidwall/gjson"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/ui/container"
"github.com/ory/x/urlx"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/ory/x/sqlxx"
"github.com/tidwall/gjson"
"github.com/ory/herodot"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/session"
"github.com/ory/kratos/text"
"github.com/ory/kratos/ui/container"
"github.com/ory/kratos/x"
"github.com/ory/kratos/x/redir"
"github.com/ory/pop/v6"
"github.com/ory/x/sqlxx"
"github.com/ory/x/urlx"
)
var _ flow.InternalContexter = (*Flow)(nil)
// Flow represents a Settings Flow
//
// This flow is used when an identity wants to update settings
@ -130,15 +120,8 @@ type Flow struct {
TransientPayload json.RawMessage `json:"transient_payload,omitempty" faker:"-" db:"-"`
}
func (f *Flow) GetInternalContext() sqlxx.JSONRawMessage {
return f.InternalContext
}
func (f *Flow) SetInternalContext(message sqlxx.JSONRawMessage) {
f.InternalContext = message
}
var _ flow.Flow = new(Flow)
var _ flow.Flow = (*Flow)(nil)
var _ flow.InternalContexter = (*Flow)(nil)
func MustNewFlow(conf *config.Config, exp time.Duration, r *http.Request, i *identity.Identity, ft flow.Type) *Flow {
f, err := NewFlow(conf, exp, r, i, ft)
@ -181,29 +164,19 @@ func NewFlow(conf *config.Config, exp time.Duration, r *http.Request, i *identit
}, nil
}
func (f *Flow) GetType() flow.Type {
return f.Type
}
func (f *Flow) GetRequestURL() string {
return f.RequestURL
}
func (f Flow) TableName(ctx context.Context) string {
return "selfservice_settings_flows"
}
func (f Flow) GetID() uuid.UUID {
return f.ID
}
func (f Flow) GetNID() uuid.UUID {
return f.NID
}
func (f *Flow) AppendTo(src *url.URL) *url.URL {
return flow.AppendFlowTo(src, f.ID)
}
func (f *Flow) GetInternalContext() sqlxx.JSONRawMessage { return f.InternalContext }
func (f *Flow) SetInternalContext(message sqlxx.JSONRawMessage) { f.InternalContext = message }
func (f *Flow) GetType() flow.Type { return f.Type }
func (f *Flow) GetRequestURL() string { return f.RequestURL }
func (_ Flow) TableName() string { return "selfservice_settings_flows" }
func (f Flow) GetID() uuid.UUID { return f.ID }
func (f *Flow) AppendTo(src *url.URL) *url.URL { return flow.AppendFlowTo(src, f.ID) }
func (f *Flow) GetUI() *container.Container { return f.UI }
func (f *Flow) ContinueWith() []flow.ContinueWith { return f.ContinueWithItems }
func (f *Flow) GetState() State { return f.State }
func (_ *Flow) GetFlowName() flow.FlowName { return flow.SettingsFlow }
func (f *Flow) SetState(state State) { f.State = state }
func (f *Flow) GetTransientPayload() json.RawMessage { return f.TransientPayload }
func (f *Flow) Valid(s *session.Session) error {
if f.ExpiresAt.Before(time.Now().UTC()) {
@ -250,39 +223,15 @@ func (f *Flow) AfterSave(*pop.Connection) error {
return nil
}
func (f *Flow) GetUI() *container.Container {
return f.UI
}
func (f *Flow) AddContinueWith(c flow.ContinueWith) {
f.ContinueWithItems = append(f.ContinueWithItems, c)
}
func (f *Flow) ContinueWith() []flow.ContinueWith {
return f.ContinueWithItems
}
func (f *Flow) GetState() State {
return f.State
}
func (f *Flow) GetFlowName() flow.FlowName {
return flow.SettingsFlow
}
func (f *Flow) SetState(state State) {
f.State = state
}
func (t *Flow) GetTransientPayload() json.RawMessage {
return t.TransientPayload
}
func (f *Flow) ToLoggerField() map[string]interface{} {
func (f *Flow) ToLoggerField() map[string]any {
if f == nil {
return map[string]interface{}{}
return map[string]any{}
}
return map[string]interface{}{
return map[string]any{
"id": f.ID.String(),
"return_to": f.ReturnTo,
"request_url": f.RequestURL,

View File

@ -26,8 +26,6 @@ import (
"github.com/ory/x/urlx"
)
var _ flow.Flow = new(Flow)
// A Verification Flow
//
// Used to verify an out-of-band communication
@ -113,19 +111,7 @@ type OAuth2LoginChallengeParams struct {
AMR session.AuthenticationMethods `db:"authentication_methods" json:"-"`
}
var _ flow.Flow = new(Flow)
func (f *Flow) GetType() flow.Type {
return f.Type
}
func (f *Flow) GetRequestURL() string {
return f.RequestURL
}
func (f Flow) TableName(context.Context) string {
return "selfservice_verification_flows"
}
var _ flow.Flow = (*Flow)(nil)
func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Request, strategy Strategy, ft flow.Type) (*Flow, error) {
now := time.Now().UTC()
@ -209,6 +195,17 @@ func NewPostHookFlow(conf *config.Config, exp time.Duration, csrf string, r *htt
return f, nil
}
func (f *Flow) GetType() flow.Type { return f.Type }
func (f *Flow) GetRequestURL() string { return f.RequestURL }
func (_ Flow) TableName() string { return "selfservice_verification_flows" }
func (f Flow) GetID() uuid.UUID { return f.ID }
func (f *Flow) GetState() State { return f.State }
func (_ *Flow) GetFlowName() flow.FlowName { return flow.VerificationFlow }
func (f *Flow) SetState(state State) { f.State = state }
func (f *Flow) GetTransientPayload() json.RawMessage { return f.TransientPayload }
func (f *Flow) GetOAuth2LoginChallenge() sqlxx.NullString { return f.OAuth2LoginChallenge }
func (f *Flow) GetUI() *container.Container { return f.UI }
func (f *Flow) Valid() error {
if f.ExpiresAt.Before(time.Now()) {
return errors.WithStack(flow.NewFlowExpiredError(f.ExpiresAt))
@ -222,14 +219,6 @@ func (f *Flow) AppendTo(src *url.URL) *url.URL {
return urlx.CopyWithQuery(src, values)
}
func (f Flow) GetID() uuid.UUID {
return f.ID
}
func (f Flow) GetNID() uuid.UUID {
return f.NID
}
func (f *Flow) SetCSRFToken(token string) {
f.CSRFToken = token
f.UI.SetCSRF(token)
@ -257,10 +246,6 @@ func (f *Flow) AfterSave(*pop.Connection) error {
return nil
}
func (f *Flow) GetUI() *container.Container {
return f.UI
}
// ContinueURL generates the URL to show on the continue screen after succesful verification
//
// It follows the following precedence:
@ -290,22 +275,6 @@ func (f *Flow) ContinueURL(ctx context.Context, config *config.Config) *url.URL
return returnTo
}
func (f *Flow) GetState() State {
return f.State
}
func (f *Flow) GetFlowName() flow.FlowName {
return flow.VerificationFlow
}
func (f *Flow) SetState(state State) {
f.State = state
}
func (t *Flow) GetTransientPayload() json.RawMessage {
return t.TransientPayload
}
func (f *Flow) ToLoggerField() map[string]interface{} {
if f == nil {
return map[string]interface{}{}
@ -320,7 +289,3 @@ func (f *Flow) ToLoggerField() map[string]interface{} {
"state": f.State,
}
}
func (f *Flow) GetOAuth2LoginChallenge() sqlxx.NullString {
return f.OAuth2LoginChallenge
}

View File

@ -31,9 +31,7 @@ type Exchanger struct {
UpdatedAt time.Time `db:"updated_at"`
}
func (e *Exchanger) TableName() string {
return "session_token_exchanges"
}
func (_ Exchanger) TableName() string { return "session_token_exchanges" }
type (
Persister interface {

View File

@ -12,18 +12,15 @@ import (
"strings"
"time"
"github.com/ory/kratos/x"
"github.com/ory/x/httpx"
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/pointerx"
"github.com/pkg/errors"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/ory/herodot"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/x"
"github.com/ory/x/httpx"
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/pointerx"
"github.com/ory/x/randx"
)
@ -67,9 +64,7 @@ type Device struct {
NID uuid.UUID `json:"-" faker:"-" db:"nid"`
}
func (m Device) TableName(ctx context.Context) string {
return "session_devices"
}
func (Device) TableName() string { return "session_devices" }
// A Session
//
@ -166,9 +161,7 @@ func (m Session) DefaultPageToken() keysetpagination.PageToken {
}
}
func (s Session) TableName(ctx context.Context) string {
return "sessions"
}
func (s Session) TableName() string { return "sessions" }
func (s *Session) CompletedLoginForMethod(method AuthenticationMethod) {
method.CompletedAt = time.Now().UTC()

View File

@ -7,12 +7,8 @@ import (
"net/http"
"path"
"strings"
"github.com/urfave/negroni"
)
var _ negroni.Handler
const AdminPrefix = "/admin"
func RedirectAdminMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {