mirror of https://github.com/ory/kratos
chore(kratos): simplify internal APIs
GitOrigin-RevId: 1f209ed68a7e72222b2a29b37bc017ffc5c0cdb4
This commit is contained in:
parent
6b2b90e53c
commit
56525d4ef9
|
|
@ -4,7 +4,6 @@
|
|||
package continuity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
|
|
@ -43,9 +42,7 @@ func (c *Container) UTC() *Container {
|
|||
return c
|
||||
}
|
||||
|
||||
func (c Container) TableName(ctx context.Context) string {
|
||||
return "continuity_containers"
|
||||
}
|
||||
func (_ Container) TableName() string { return "continuity_containers" }
|
||||
|
||||
func NewContainer(name string, o managerOptions) *Container {
|
||||
return &Container{
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
package courier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
|
|
@ -220,14 +219,5 @@ func (m Message) DefaultPageToken() keysetpagination.PageToken {
|
|||
}
|
||||
}
|
||||
|
||||
func (m Message) TableName(context.Context) string {
|
||||
return "courier_messages"
|
||||
}
|
||||
|
||||
func (m *Message) GetID() uuid.UUID {
|
||||
return m.ID
|
||||
}
|
||||
|
||||
func (m *Message) GetNID() uuid.UUID {
|
||||
return m.NID
|
||||
}
|
||||
func (m Message) TableName() string { return "courier_messages" }
|
||||
func (m *Message) GetID() uuid.UUID { return m.ID }
|
||||
|
|
|
|||
|
|
@ -671,7 +671,7 @@ func (m *RegistryDefault) Init(ctx context.Context, ctxer contextx.Contextualize
|
|||
m.Logger().WithError(err).Warnf("Unable to open database, retrying.")
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
p, err := sql.NewPersister(ctx, m, c,
|
||||
p, err := sql.NewPersister(m, c,
|
||||
sql.WithExtraMigrations(o.extraMigrations...),
|
||||
sql.WithExtraGoMigrations(o.extraGoMigrations...),
|
||||
sql.WithDisabledLogging(o.disableMigrationLogging))
|
||||
|
|
|
|||
|
|
@ -357,10 +357,6 @@ func (i Identity) GetID() uuid.UUID {
|
|||
return i.ID
|
||||
}
|
||||
|
||||
func (i Identity) GetNID() uuid.UUID {
|
||||
return i.NID
|
||||
}
|
||||
|
||||
func (i Identity) MarshalJSON() ([]byte, error) {
|
||||
type localIdentity Identity
|
||||
i.Credentials = nil
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
package identity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
|
@ -54,17 +53,8 @@ func (v RecoveryAddressType) HTMLFormInputType() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (a RecoveryAddress) TableName(ctx context.Context) string {
|
||||
return "identity_recovery_addresses"
|
||||
}
|
||||
|
||||
func (a RecoveryAddress) ValidateNID() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a RecoveryAddress) GetID() uuid.UUID {
|
||||
return a.ID
|
||||
}
|
||||
func (a RecoveryAddress) TableName() string { return "identity_recovery_addresses" }
|
||||
func (a RecoveryAddress) GetID() uuid.UUID { return a.ID }
|
||||
|
||||
// Hash returns a unique string representation for the recovery address.
|
||||
func (a RecoveryAddress) Hash() string {
|
||||
|
|
|
|||
|
|
@ -108,14 +108,6 @@ func (a VerifiableAddress) GetID() uuid.UUID {
|
|||
return a.ID
|
||||
}
|
||||
|
||||
func (a VerifiableAddress) GetNID() uuid.UUID {
|
||||
return a.NID
|
||||
}
|
||||
|
||||
func (a VerifiableAddress) ValidateNID() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Hash returns a unique string representation for the recovery address.
|
||||
func (a VerifiableAddress) Hash() string {
|
||||
return fmt.Sprintf("%v|%v|%v|%v|%v|%v|%v", a.Value, a.Verified, a.Via, a.Status, a.VerifiedAt, a.IdentityID, a.NID)
|
||||
|
|
|
|||
|
|
@ -56,34 +56,34 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
type persisterOptions struct {
|
||||
type options struct {
|
||||
extraMigrations []fs.FS
|
||||
extraGoMigrations popx.Migrations
|
||||
disableLogging bool
|
||||
}
|
||||
|
||||
type persisterOption func(o *persisterOptions)
|
||||
type Option = func(o *options)
|
||||
|
||||
func WithExtraMigrations(fss ...fs.FS) persisterOption {
|
||||
return func(o *persisterOptions) {
|
||||
func WithExtraMigrations(fss ...fs.FS) Option {
|
||||
return func(o *options) {
|
||||
o.extraMigrations = fss
|
||||
}
|
||||
}
|
||||
|
||||
func WithExtraGoMigrations(ms ...popx.Migration) persisterOption {
|
||||
return func(o *persisterOptions) {
|
||||
func WithExtraGoMigrations(ms ...popx.Migration) Option {
|
||||
return func(o *options) {
|
||||
o.extraGoMigrations = ms
|
||||
}
|
||||
}
|
||||
|
||||
func WithDisabledLogging(v bool) persisterOption {
|
||||
return func(o *persisterOptions) {
|
||||
func WithDisabledLogging(v bool) Option {
|
||||
return func(o *options) {
|
||||
o.disableLogging = v
|
||||
}
|
||||
}
|
||||
|
||||
func NewPersister(ctx context.Context, r persisterDependencies, c *pop.Connection, opts ...persisterOption) (*Persister, error) {
|
||||
o := &persisterOptions{}
|
||||
func NewPersister(r persisterDependencies, c *pop.Connection, opts ...Option) (*Persister, error) {
|
||||
o := &options{}
|
||||
for _, f := range opts {
|
||||
f(o)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ func withCheckIdentityID(id uuid.UUID) codeOption {
|
|||
func useOneTimeCode[P any, U interface {
|
||||
*P
|
||||
oneTimeCodeProvider
|
||||
}](ctx context.Context, p *Persister, flowID uuid.UUID, userProvidedCode string, flowTableName string, foreignKeyName string, opts ...codeOption,
|
||||
}](ctx context.Context, p *Persister, flowID uuid.UUID, userProvidedCode, flowTableName, foreignKeyName string, opts ...codeOption,
|
||||
) (target U, err error) {
|
||||
maxSubmissions := p.r.Config().SelfServiceCodeMethodMaxSubmissions(ctx)
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.useOneTimeCode", trace.WithAttributes(attribute.Int("max_submissions", maxSubmissions)))
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ func (p *Persister) DeleteContinuitySession(ctx context.Context, id uuid.UUID) (
|
|||
if count, err := p.GetConnection(ctx).RawQuery(
|
||||
//#nosec G201 -- TableName is static
|
||||
fmt.Sprintf("DELETE FROM %s WHERE id=? AND nid=?",
|
||||
new(continuity.Container).TableName(ctx)), id, p.NetworkID(ctx)).ExecWithCount(); err != nil {
|
||||
continuity.Container{}.TableName()), id, p.NetworkID(ctx)).ExecWithCount(); err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
} else if count == 0 {
|
||||
return errors.WithStack(sqlcon.ErrNoRows)
|
||||
|
|
@ -76,16 +76,13 @@ func (p *Persister) DeleteExpiredContinuitySessions(ctx context.Context, expires
|
|||
defer otelx.End(span, &err)
|
||||
//#nosec G201 -- TableName is static
|
||||
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
new(continuity.Container).TableName(ctx),
|
||||
new(continuity.Container).TableName(ctx),
|
||||
limit,
|
||||
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
|
||||
continuity.Container{}.TableName(),
|
||||
),
|
||||
expiresAt,
|
||||
p.NetworkID(ctx),
|
||||
limit,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
return nil
|
||||
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ func TestPersisterHMAC(t *testing.T) {
|
|||
conf := config.MustNew(t, logrusx.New("", ""), contextx.NewTestConfigProvider(embedx.ConfigSchema, opts...), opts...)
|
||||
c, err := pop.NewConnection(&pop.ConnectionDetails{URL: "sqlite://foo?mode=memory"})
|
||||
require.NoError(t, err)
|
||||
p, err := NewPersister(ctx, &logRegistryOnly{c: conf}, c)
|
||||
p, err := NewPersister(&logRegistryOnly{c: conf}, c)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("case=behaves deterministically", func(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -72,16 +72,13 @@ func (p *Persister) DeleteExpiredLoginFlows(ctx context.Context, expiresAt time.
|
|||
defer otelx.End(span, &err)
|
||||
//#nosec G201 -- TableName is static
|
||||
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
new(login.Flow).TableName(ctx),
|
||||
new(login.Flow).TableName(ctx),
|
||||
limit,
|
||||
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
|
||||
login.Flow{}.TableName(),
|
||||
),
|
||||
expiresAt,
|
||||
p.NetworkID(ctx),
|
||||
limit,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
return nil
|
||||
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ func (p *Persister) UseLoginCode(ctx context.Context, flowID uuid.UUID, identity
|
|||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseLoginCode")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
codeRow, err := useOneTimeCode[code.LoginCode](ctx, p, flowID, userProvidedCode, new(login.Flow).TableName(ctx), "selfservice_login_flow_id", withCheckIdentityID(identityID))
|
||||
codeRow, err := useOneTimeCode[code.LoginCode](ctx, p, flowID, userProvidedCode, login.Flow{}.TableName(), "selfservice_login_flow_id", withCheckIdentityID(identityID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -123,16 +123,13 @@ func (p *Persister) DeleteExpiredRecoveryFlows(ctx context.Context, expiresAt ti
|
|||
defer otelx.End(span, &err)
|
||||
//#nosec G201 -- TableName is static
|
||||
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
new(recovery.Flow).TableName(ctx),
|
||||
new(recovery.Flow).TableName(ctx),
|
||||
limit,
|
||||
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
|
||||
recovery.Flow{}.TableName(),
|
||||
),
|
||||
expiresAt,
|
||||
p.NetworkID(ctx),
|
||||
limit,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
return nil
|
||||
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ func (p *Persister) UseRecoveryCode(ctx context.Context, flowID uuid.UUID, userP
|
|||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRecoveryCode")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
codeRow, err := useOneTimeCode[code.RecoveryCode](ctx, p, flowID, userProvidedCode, new(recovery.Flow).TableName(ctx), "selfservice_recovery_flow_id")
|
||||
codeRow, err := useOneTimeCode[code.RecoveryCode](ctx, p, flowID, userProvidedCode, recovery.Flow{}.TableName(), "selfservice_recovery_flow_id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,16 +54,13 @@ func (p *Persister) DeleteExpiredRegistrationFlows(ctx context.Context, expiresA
|
|||
defer otelx.End(span, &err)
|
||||
//#nosec G201 -- TableName is static
|
||||
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
new(registration.Flow).TableName(ctx),
|
||||
new(registration.Flow).TableName(ctx),
|
||||
limit,
|
||||
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
|
||||
registration.Flow{}.TableName(),
|
||||
),
|
||||
expiresAt,
|
||||
p.NetworkID(ctx),
|
||||
limit,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
return nil
|
||||
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ func (p *Persister) UseRegistrationCode(ctx context.Context, flowID uuid.UUID, u
|
|||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRegistrationCode")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
codeRow, err := useOneTimeCode[code.RegistrationCode](ctx, p, flowID, userProvidedCode, new(registration.Flow).TableName(ctx), "selfservice_registration_flow_id")
|
||||
codeRow, err := useOneTimeCode[code.RegistrationCode](ctx, p, flowID, userProvidedCode, registration.Flow{}.TableName(), "selfservice_registration_flow_id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -304,7 +304,7 @@ func (p *Persister) DeleteSession(ctx context.Context, sid uuid.UUID) (err error
|
|||
|
||||
nid := p.NetworkID(ctx)
|
||||
//#nosec G201 -- TableName is static
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ? AND nid = ?", new(session.Session).TableName(ctx)),
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ? AND nid = ?", session.Session{}.TableName()),
|
||||
sid,
|
||||
nid,
|
||||
).ExecWithCount()
|
||||
|
|
@ -324,7 +324,7 @@ func (p *Persister) DeleteSessionsByIdentity(ctx context.Context, identityID uui
|
|||
//#nosec G201 -- TableName is static
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE identity_id = ? AND nid = ?",
|
||||
new(session.Session).TableName(ctx),
|
||||
session.Session{}.TableName(),
|
||||
),
|
||||
identityID,
|
||||
p.NetworkID(ctx),
|
||||
|
|
@ -390,7 +390,7 @@ func (p *Persister) DeleteSessionByToken(ctx context.Context, token string) (err
|
|||
//#nosec G201 -- TableName is static
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE token = ? AND nid = ?",
|
||||
new(session.Session).TableName(ctx),
|
||||
session.Session{}.TableName(),
|
||||
),
|
||||
token,
|
||||
p.NetworkID(ctx),
|
||||
|
|
@ -411,7 +411,7 @@ func (p *Persister) RevokeSessionByToken(ctx context.Context, token string) (err
|
|||
//#nosec G201 -- TableName is static
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"UPDATE %s SET active = false WHERE token = ? AND nid = ?",
|
||||
new(session.Session).TableName(ctx),
|
||||
session.Session{}.TableName(),
|
||||
),
|
||||
token,
|
||||
p.NetworkID(ctx),
|
||||
|
|
@ -433,7 +433,7 @@ func (p *Persister) RevokeSessionById(ctx context.Context, sID uuid.UUID) (err e
|
|||
//#nosec G201 -- TableName is static
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"UPDATE %s SET active = false WHERE id = ? AND nid = ?",
|
||||
new(session.Session).TableName(ctx),
|
||||
session.Session{}.TableName(),
|
||||
),
|
||||
sID,
|
||||
p.NetworkID(ctx),
|
||||
|
|
@ -456,7 +456,7 @@ func (p *Persister) RevokeSession(ctx context.Context, iID, sID uuid.UUID) (err
|
|||
//#nosec G201 -- TableName is static
|
||||
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"UPDATE %s SET active = false WHERE id = ? AND identity_id = ? AND nid = ?",
|
||||
new(session.Session).TableName(ctx),
|
||||
session.Session{}.TableName(),
|
||||
),
|
||||
sID,
|
||||
iID,
|
||||
|
|
@ -476,7 +476,7 @@ func (p *Persister) RevokeSessionsIdentityExcept(ctx context.Context, iID, sID u
|
|||
//#nosec G201 -- TableName is static
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"UPDATE %s SET active = false WHERE identity_id = ? AND id != ? AND nid = ?",
|
||||
new(session.Session).TableName(ctx),
|
||||
session.Session{}.TableName(),
|
||||
),
|
||||
iID,
|
||||
sID,
|
||||
|
|
@ -494,16 +494,13 @@ func (p *Persister) DeleteExpiredSessions(ctx context.Context, expiresAt time.Ti
|
|||
|
||||
//#nosec G201 -- TableName is static
|
||||
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
new(session.Session).TableName(ctx),
|
||||
new(session.Session).TableName(ctx),
|
||||
limit,
|
||||
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
|
||||
session.Session{}.TableName(),
|
||||
),
|
||||
expiresAt,
|
||||
p.NetworkID(ctx),
|
||||
limit,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
return nil
|
||||
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -114,13 +114,12 @@ func (p *Persister) DeleteExpiredExchangers(ctx context.Context, at time.Time, l
|
|||
|
||||
//#nosec G201 -- TableName is static
|
||||
err := conn.RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE created_at <= ? and nid = ? ORDER BY created_at ASC LIMIT %d ) AS s )",
|
||||
conn.Dialect.Quote(new(sessiontokenexchange.Exchanger).TableName()),
|
||||
conn.Dialect.Quote(new(sessiontokenexchange.Exchanger).TableName()),
|
||||
limit,
|
||||
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE created_at <= ? and nid = ? ORDER BY created_at ASC LIMIT ?) AS s)",
|
||||
sessiontokenexchange.Exchanger{}.TableName(),
|
||||
),
|
||||
expiredAfter,
|
||||
p.NetworkID(ctx),
|
||||
limit,
|
||||
).Exec()
|
||||
|
||||
return sqlcon.HandleError(err)
|
||||
|
|
|
|||
|
|
@ -64,16 +64,13 @@ func (p *Persister) DeleteExpiredSettingsFlows(ctx context.Context, expiresAt ti
|
|||
defer otelx.End(span, &err)
|
||||
//#nosec G201 -- TableName is static
|
||||
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
new(settings.Flow).TableName(ctx),
|
||||
new(settings.Flow).TableName(ctx),
|
||||
limit,
|
||||
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
|
||||
settings.Flow{}.TableName(),
|
||||
),
|
||||
expiresAt,
|
||||
p.NetworkID(ctx),
|
||||
limit,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
return nil
|
||||
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ import (
|
|||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/pop/v6"
|
||||
"github.com/ory/pop/v6/logging"
|
||||
"github.com/ory/x/popx"
|
||||
"github.com/ory/x/sqlcon"
|
||||
"github.com/ory/x/sqlcon/dockertest"
|
||||
"github.com/ory/x/sqlxx"
|
||||
|
|
@ -315,7 +316,7 @@ func TestPersister_Transaction(t *testing.T) {
|
|||
ID: x.NewUUID(),
|
||||
}
|
||||
err := c.Transaction(func(tx *pop.Connection) error {
|
||||
ctx := sql.WithTransaction(context.Background(), tx)
|
||||
ctx := popx.WithTransaction(context.Background(), tx)
|
||||
require.NoError(t, p.CreateLoginFlow(ctx, lr), "%+v", lr)
|
||||
require.NoError(t, getErr(p.GetLoginFlow(ctx, lr.ID)), "%+v", lr)
|
||||
return errors.New(errMessage)
|
||||
|
|
@ -334,9 +335,7 @@ func Benchmark_BatchCreateIdentities(b *testing.B) {
|
|||
batchSizes := []int{1, 10, 100, 500, 800, 900, 1000, 2000, 3000}
|
||||
parallelRequests := []int{1, 4, 8, 16}
|
||||
|
||||
for name := range conns {
|
||||
name := name
|
||||
reg := conns[name]
|
||||
for name, reg := range conns {
|
||||
b.Run(fmt.Sprintf("database=%s", name), func(b *testing.B) {
|
||||
conf := reg.Config()
|
||||
_, p := testhelpers.NewNetwork(b, ctx, reg.Persister())
|
||||
|
|
|
|||
|
|
@ -11,10 +11,6 @@ import (
|
|||
"github.com/ory/pop/v6"
|
||||
)
|
||||
|
||||
func WithTransaction(ctx context.Context, tx *pop.Connection) context.Context {
|
||||
return popx.WithTransaction(ctx, tx)
|
||||
}
|
||||
|
||||
func (p *Persister) Transaction(ctx context.Context, callback func(ctx context.Context, connection *pop.Connection) error) error {
|
||||
return popx.Transaction(ctx, p.c.WithContext(ctx), callback)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -121,16 +121,13 @@ func (p *Persister) DeleteExpiredVerificationFlows(ctx context.Context, expiresA
|
|||
defer otelx.End(span, &err)
|
||||
//#nosec G201 -- TableName is static
|
||||
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
new(verification.Flow).TableName(ctx),
|
||||
new(verification.Flow).TableName(ctx),
|
||||
limit,
|
||||
"DELETE FROM %[1]s WHERE id in (SELECT id FROM (SELECT id FROM %[1]s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT ?) AS s)",
|
||||
verification.Flow{}.TableName(),
|
||||
),
|
||||
expiresAt,
|
||||
p.NetworkID(ctx),
|
||||
limit,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
return nil
|
||||
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ func (p *Persister) UseVerificationCode(ctx context.Context, flowID uuid.UUID, u
|
|||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseVerificationCode")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
codeRow, err := useOneTimeCode[code.VerificationCode](ctx, p, flowID, userProvidedCode, new(verification.Flow).TableName(ctx), "selfservice_verification_flow_id")
|
||||
codeRow, err := useOneTimeCode[code.VerificationCode](ctx, p, flowID, userProvidedCode, verification.Flow{}.TableName(), "selfservice_verification_flow_id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
|
|
@ -17,12 +16,7 @@ import (
|
|||
"github.com/ory/x/sqlcon"
|
||||
)
|
||||
|
||||
type Model interface {
|
||||
GetID() uuid.UUID
|
||||
GetNID() uuid.UUID
|
||||
}
|
||||
|
||||
func Generic(ctx context.Context, c *pop.Connection, tracer trace.Tracer, v Model, columnNames ...string) (err error) {
|
||||
func Generic(ctx context.Context, c *pop.Connection, tracer trace.Tracer, v any, columnNames ...string) (err error) {
|
||||
ctx, span := tracer.Start(ctx, "persistence.sql.update")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@
|
|||
package login
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
|
@ -14,7 +14,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
|
|
@ -28,7 +27,6 @@ import (
|
|||
"github.com/ory/kratos/x/redir"
|
||||
"github.com/ory/pop/v6"
|
||||
"github.com/ory/x/sqlxx"
|
||||
"github.com/ory/x/stringsx"
|
||||
"github.com/ory/x/urlx"
|
||||
)
|
||||
|
||||
|
|
@ -150,7 +148,7 @@ type Flow struct {
|
|||
// ReturnToVerification contains the redirect URL for the verification flow.
|
||||
ReturnToVerification string `json:"-" db:"-"`
|
||||
|
||||
isAccountLinkingFlow bool `json:"-" db:"-"`
|
||||
isAccountLinkingFlow bool `db:"-"`
|
||||
|
||||
// IdentitySchema optionally holds the ID of the identity schema that is used
|
||||
// for this flow. This value can be set by the user when creating the flow and
|
||||
|
|
@ -158,7 +156,8 @@ type Flow struct {
|
|||
IdentitySchema flow.IdentitySchema `json:"identity_schema,omitempty" faker:"-" db:"identity_schema_id"`
|
||||
}
|
||||
|
||||
var _ flow.Flow = new(Flow)
|
||||
var _ flow.Flow = (*Flow)(nil)
|
||||
var _ flow.FlowWithContinueWith = (*Flow)(nil)
|
||||
|
||||
func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Request, flowType flow.Type) (*Flow, error) {
|
||||
now := time.Now().UTC()
|
||||
|
|
@ -204,7 +203,7 @@ func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Reques
|
|||
CSRFToken: csrf,
|
||||
Type: flowType,
|
||||
Refresh: refresh,
|
||||
RequestedAAL: identity.AuthenticatorAssuranceLevel(strings.ToLower(stringsx.Coalesce(
|
||||
RequestedAAL: identity.AuthenticatorAssuranceLevel(strings.ToLower(cmp.Or(
|
||||
r.URL.Query().Get("aal"),
|
||||
string(identity.AuthenticatorAssuranceLevel1)))),
|
||||
InternalContext: []byte("{}"),
|
||||
|
|
@ -213,21 +212,25 @@ func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Reques
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (f *Flow) GetType() flow.Type {
|
||||
return f.Type
|
||||
}
|
||||
func (f *Flow) GetType() flow.Type { return f.Type }
|
||||
func (f *Flow) GetRequestURL() string { return f.RequestURL }
|
||||
func (f *Flow) GetID() uuid.UUID { return f.ID }
|
||||
func (f *Flow) GetInternalContext() sqlxx.JSONRawMessage { return f.InternalContext }
|
||||
func (f *Flow) SetInternalContext(bytes sqlxx.JSONRawMessage) { f.InternalContext = bytes }
|
||||
func (f *Flow) GetUI() *container.Container { return f.UI }
|
||||
func (f *Flow) GetState() flow.State { return f.State }
|
||||
func (_ *Flow) GetFlowName() flow.FlowName { return flow.LoginFlow }
|
||||
func (_ Flow) TableName() string { return "selfservice_login_flows" }
|
||||
func (f *Flow) ContinueWith() []flow.ContinueWith { return f.ContinueWithItems }
|
||||
func (f *Flow) SetReturnToVerification(to string) { f.ReturnToVerification = to }
|
||||
func (f *Flow) GetOAuth2LoginChallenge() sqlxx.NullString { return f.OAuth2LoginChallenge }
|
||||
func (f *Flow) AppendTo(src *url.URL) *url.URL { return flow.AppendFlowTo(src, f.ID) }
|
||||
func (f *Flow) SetState(state flow.State) { f.State = state }
|
||||
func (f *Flow) GetTransientPayload() json.RawMessage { return f.TransientPayload }
|
||||
|
||||
func (f *Flow) GetRequestURL() string {
|
||||
return f.RequestURL
|
||||
}
|
||||
|
||||
func (f Flow) TableName(ctx context.Context) string {
|
||||
return "selfservice_login_flows"
|
||||
}
|
||||
|
||||
func (f Flow) WhereID(ctx context.Context, alias string) string {
|
||||
return fmt.Sprintf("%s.%s = ? AND %s.%s = ?", alias, "id", alias, "nid")
|
||||
}
|
||||
// IsRefresh returns true if the login flow was triggered to re-authenticate the user.
|
||||
// This is the case if the refresh query parameter is set to true.
|
||||
func (f *Flow) IsRefresh() bool { return f.Refresh }
|
||||
|
||||
func (f *Flow) Valid() error {
|
||||
if f.ExpiresAt.Before(time.Now()) {
|
||||
|
|
@ -236,38 +239,12 @@ func (f *Flow) Valid() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (f Flow) GetID() uuid.UUID {
|
||||
return f.ID
|
||||
}
|
||||
|
||||
// IsRefresh returns true if the login flow was triggered to re-authenticate the user.
|
||||
// This is the case if the refresh query parameter is set to true.
|
||||
func (f *Flow) IsRefresh() bool {
|
||||
return f.Refresh
|
||||
}
|
||||
|
||||
func (f *Flow) AppendTo(src *url.URL) *url.URL {
|
||||
return flow.AppendFlowTo(src, f.ID)
|
||||
}
|
||||
|
||||
func (f Flow) GetNID() uuid.UUID {
|
||||
return f.NID
|
||||
}
|
||||
|
||||
func (f *Flow) EnsureInternalContext() {
|
||||
if !gjson.ParseBytes(f.InternalContext).IsObject() {
|
||||
f.InternalContext = []byte("{}")
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Flow) GetInternalContext() sqlxx.JSONRawMessage {
|
||||
return f.InternalContext
|
||||
}
|
||||
|
||||
func (f *Flow) SetInternalContext(bytes sqlxx.JSONRawMessage) {
|
||||
f.InternalContext = bytes
|
||||
}
|
||||
|
||||
func (f Flow) MarshalJSON() ([]byte, error) {
|
||||
type local Flow
|
||||
f.SetReturnTo()
|
||||
|
|
@ -294,10 +271,6 @@ func (f *Flow) AfterSave(*pop.Connection) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (f *Flow) GetUI() *container.Container {
|
||||
return f.UI
|
||||
}
|
||||
|
||||
func (f *Flow) SecureRedirectToOpts(ctx context.Context, cfg config.Provider) (opts []redir.SecureRedirectOption) {
|
||||
return []redir.SecureRedirectOption{
|
||||
redir.SecureRedirectReturnTo(f.ReturnTo),
|
||||
|
|
@ -308,41 +281,15 @@ func (f *Flow) SecureRedirectToOpts(ctx context.Context, cfg config.Provider) (o
|
|||
}
|
||||
}
|
||||
|
||||
func (f *Flow) GetState() flow.State {
|
||||
return flow.State(f.State)
|
||||
}
|
||||
|
||||
func (f *Flow) GetFlowName() flow.FlowName {
|
||||
return flow.LoginFlow
|
||||
}
|
||||
|
||||
func (f *Flow) SetState(state flow.State) {
|
||||
f.State = State(state)
|
||||
}
|
||||
|
||||
func (t *Flow) GetTransientPayload() json.RawMessage {
|
||||
return t.TransientPayload
|
||||
}
|
||||
|
||||
var _ flow.FlowWithContinueWith = new(Flow)
|
||||
|
||||
func (f *Flow) AddContinueWith(c flow.ContinueWith) {
|
||||
f.ContinueWithItems = append(f.ContinueWithItems, c)
|
||||
}
|
||||
|
||||
func (f *Flow) ContinueWith() []flow.ContinueWith {
|
||||
return f.ContinueWithItems
|
||||
}
|
||||
|
||||
func (f *Flow) SetReturnToVerification(to string) {
|
||||
f.ReturnToVerification = to
|
||||
}
|
||||
|
||||
func (f *Flow) ToLoggerField() map[string]interface{} {
|
||||
func (f *Flow) ToLoggerField() map[string]any {
|
||||
if f == nil {
|
||||
return map[string]interface{}{}
|
||||
return map[string]any{}
|
||||
}
|
||||
return map[string]interface{}{
|
||||
return map[string]any{
|
||||
"id": f.ID.String(),
|
||||
"return_to": f.ReturnTo,
|
||||
"request_url": f.RequestURL,
|
||||
|
|
@ -354,7 +301,3 @@ func (f *Flow) ToLoggerField() map[string]interface{} {
|
|||
"requested_aal": f.RequestedAAL,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Flow) GetOAuth2LoginChallenge() sqlxx.NullString {
|
||||
return f.OAuth2LoginChallenge
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
package recovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
|
@ -112,7 +111,7 @@ type Flow struct {
|
|||
TransientPayload json.RawMessage `json:"transient_payload,omitempty" faker:"-" db:"-"`
|
||||
}
|
||||
|
||||
var _ flow.Flow = new(Flow)
|
||||
var _ flow.Flow = (*Flow)(nil)
|
||||
|
||||
func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Request, strategy Strategy, ft flow.Type) (*Flow, error) {
|
||||
now := time.Now().UTC()
|
||||
|
|
@ -135,7 +134,7 @@ func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Reques
|
|||
state = flow.StateRecoveryAwaitingAddress
|
||||
}
|
||||
|
||||
flow := &Flow{
|
||||
f := &Flow{
|
||||
ID: id,
|
||||
ExpiresAt: now.Add(exp),
|
||||
IssuedAt: now,
|
||||
|
|
@ -150,13 +149,13 @@ func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Reques
|
|||
}
|
||||
|
||||
if strategy != nil {
|
||||
flow.Active = sqlxx.NullString(strategy.NodeGroup())
|
||||
if err := strategy.PopulateRecoveryMethod(r, flow); err != nil {
|
||||
f.Active = sqlxx.NullString(strategy.NodeGroup())
|
||||
if err := strategy.PopulateRecoveryMethod(r, f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return flow, nil
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func FromOldFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Request, strategy Strategy, of Flow) (*Flow, error) {
|
||||
|
|
@ -174,25 +173,15 @@ func FromOldFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Re
|
|||
return nf, nil
|
||||
}
|
||||
|
||||
func (f *Flow) GetType() flow.Type {
|
||||
return f.Type
|
||||
}
|
||||
|
||||
func (f *Flow) GetRequestURL() string {
|
||||
return f.RequestURL
|
||||
}
|
||||
|
||||
func (f Flow) TableName(ctx context.Context) string {
|
||||
return "selfservice_recovery_flows"
|
||||
}
|
||||
|
||||
func (f Flow) GetID() uuid.UUID {
|
||||
return f.ID
|
||||
}
|
||||
|
||||
func (f Flow) GetNID() uuid.UUID {
|
||||
return f.NID
|
||||
}
|
||||
func (f *Flow) GetType() flow.Type { return f.Type }
|
||||
func (f *Flow) GetRequestURL() string { return f.RequestURL }
|
||||
func (_ Flow) TableName() string { return "selfservice_recovery_flows" }
|
||||
func (f Flow) GetID() uuid.UUID { return f.ID }
|
||||
func (f *Flow) GetUI() *container.Container { return f.UI }
|
||||
func (f *Flow) GetState() State { return f.State }
|
||||
func (_ *Flow) GetFlowName() flow.FlowName { return flow.RecoveryFlow }
|
||||
func (f *Flow) SetState(state State) { f.State = state }
|
||||
func (f *Flow) GetTransientPayload() json.RawMessage { return f.TransientPayload }
|
||||
|
||||
func (f *Flow) Valid() error {
|
||||
if f.ExpiresAt.Before(time.Now().UTC()) {
|
||||
|
|
@ -236,31 +225,11 @@ func (f *Flow) AfterSave(*pop.Connection) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (f *Flow) GetUI() *container.Container {
|
||||
return f.UI
|
||||
}
|
||||
|
||||
func (f *Flow) GetState() State {
|
||||
return f.State
|
||||
}
|
||||
|
||||
func (f *Flow) GetFlowName() flow.FlowName {
|
||||
return flow.RecoveryFlow
|
||||
}
|
||||
|
||||
func (f *Flow) SetState(state State) {
|
||||
f.State = state
|
||||
}
|
||||
|
||||
func (t *Flow) GetTransientPayload() json.RawMessage {
|
||||
return t.TransientPayload
|
||||
}
|
||||
|
||||
func (f *Flow) ToLoggerField() map[string]interface{} {
|
||||
func (f *Flow) ToLoggerField() map[string]any {
|
||||
if f == nil {
|
||||
return map[string]interface{}{}
|
||||
return map[string]any{}
|
||||
}
|
||||
return map[string]interface{}{
|
||||
return map[string]any{
|
||||
"id": f.ID.String(),
|
||||
"return_to": f.ReturnTo,
|
||||
"request_url": f.RequestURL,
|
||||
|
|
|
|||
|
|
@ -182,17 +182,20 @@ func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Reques
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (f Flow) TableName(context.Context) string {
|
||||
return "selfservice_registration_flows"
|
||||
}
|
||||
|
||||
func (f Flow) GetID() uuid.UUID {
|
||||
return f.ID
|
||||
}
|
||||
|
||||
func (f Flow) GetNID() uuid.UUID {
|
||||
return f.NID
|
||||
}
|
||||
func (_ Flow) TableName() string { return "selfservice_registration_flows" }
|
||||
func (f Flow) GetID() uuid.UUID { return f.ID }
|
||||
func (f *Flow) AppendTo(src *url.URL) *url.URL { return flow.AppendFlowTo(src, f.ID) }
|
||||
func (f *Flow) GetType() flow.Type { return f.Type }
|
||||
func (f *Flow) GetRequestURL() string { return f.RequestURL }
|
||||
func (f *Flow) GetInternalContext() sqlxx.JSONRawMessage { return f.InternalContext }
|
||||
func (f *Flow) SetInternalContext(bytes sqlxx.JSONRawMessage) { f.InternalContext = bytes }
|
||||
func (f *Flow) GetUI() *container.Container { return f.UI }
|
||||
func (f *Flow) GetState() State { return f.State }
|
||||
func (_ *Flow) GetFlowName() flow.FlowName { return flow.RegistrationFlow }
|
||||
func (f *Flow) SetState(state State) { f.State = state }
|
||||
func (f *Flow) GetTransientPayload() json.RawMessage { return f.TransientPayload }
|
||||
func (f *Flow) SetReturnToVerification(to string) { f.ReturnToVerification = to }
|
||||
func (f *Flow) GetOAuth2LoginChallenge() sqlxx.NullString { return f.OAuth2LoginChallenge }
|
||||
|
||||
func (f *Flow) Valid() error {
|
||||
if f.ExpiresAt.Before(time.Now()) {
|
||||
|
|
@ -201,32 +204,12 @@ func (f *Flow) Valid() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (f *Flow) AppendTo(src *url.URL) *url.URL {
|
||||
return flow.AppendFlowTo(src, f.ID)
|
||||
}
|
||||
|
||||
func (f *Flow) GetType() flow.Type {
|
||||
return f.Type
|
||||
}
|
||||
|
||||
func (f *Flow) GetRequestURL() string {
|
||||
return f.RequestURL
|
||||
}
|
||||
|
||||
func (f *Flow) EnsureInternalContext() {
|
||||
if !gjson.ParseBytes(f.InternalContext).IsObject() {
|
||||
f.InternalContext = []byte("{}")
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Flow) GetInternalContext() sqlxx.JSONRawMessage {
|
||||
return f.InternalContext
|
||||
}
|
||||
|
||||
func (f *Flow) SetInternalContext(bytes sqlxx.JSONRawMessage) {
|
||||
f.InternalContext = bytes
|
||||
}
|
||||
|
||||
func (f Flow) MarshalJSON() ([]byte, error) {
|
||||
type local Flow
|
||||
f.SetReturnTo()
|
||||
|
|
@ -253,17 +236,11 @@ func (f *Flow) AfterSave(*pop.Connection) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (f *Flow) GetUI() *container.Container {
|
||||
return f.UI
|
||||
}
|
||||
|
||||
func (f *Flow) AddContinueWith(c flow.ContinueWith) {
|
||||
f.ContinueWithItems = append(f.ContinueWithItems, c)
|
||||
}
|
||||
|
||||
func (f *Flow) ContinueWith() []flow.ContinueWith {
|
||||
return f.ContinueWithItems
|
||||
}
|
||||
func (f *Flow) ContinueWith() []flow.ContinueWith { return f.ContinueWithItems }
|
||||
|
||||
func (f *Flow) SecureRedirectToOpts(ctx context.Context, cfg config.Provider) (opts []redir.SecureRedirectOption) {
|
||||
return []redir.SecureRedirectOption{
|
||||
|
|
@ -275,31 +252,11 @@ func (f *Flow) SecureRedirectToOpts(ctx context.Context, cfg config.Provider) (o
|
|||
}
|
||||
}
|
||||
|
||||
func (f *Flow) GetState() State {
|
||||
return f.State
|
||||
}
|
||||
|
||||
func (f *Flow) GetFlowName() flow.FlowName {
|
||||
return flow.RegistrationFlow
|
||||
}
|
||||
|
||||
func (f *Flow) SetState(state State) {
|
||||
f.State = state
|
||||
}
|
||||
|
||||
func (f *Flow) GetTransientPayload() json.RawMessage {
|
||||
return f.TransientPayload
|
||||
}
|
||||
|
||||
func (f *Flow) SetReturnToVerification(to string) {
|
||||
f.ReturnToVerification = to
|
||||
}
|
||||
|
||||
func (f *Flow) ToLoggerField() map[string]interface{} {
|
||||
func (f *Flow) ToLoggerField() map[string]any {
|
||||
if f == nil {
|
||||
return map[string]interface{}{}
|
||||
return map[string]any{}
|
||||
}
|
||||
return map[string]interface{}{
|
||||
return map[string]any{
|
||||
"id": f.ID.String(),
|
||||
"return_to": f.ReturnTo,
|
||||
"request_url": f.RequestURL,
|
||||
|
|
@ -309,7 +266,3 @@ func (f *Flow) ToLoggerField() map[string]interface{} {
|
|||
"state": f.State,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Flow) GetOAuth2LoginChallenge() sqlxx.NullString {
|
||||
return f.OAuth2LoginChallenge
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,39 +4,29 @@
|
|||
package settings
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/ory/kratos/x/redir"
|
||||
|
||||
"github.com/ory/pop/v6"
|
||||
|
||||
"github.com/ory/kratos/text"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/ui/container"
|
||||
"github.com/ory/x/urlx"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/x/sqlxx"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/ory/herodot"
|
||||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/selfservice/flow"
|
||||
"github.com/ory/kratos/session"
|
||||
"github.com/ory/kratos/text"
|
||||
"github.com/ory/kratos/ui/container"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/kratos/x/redir"
|
||||
"github.com/ory/pop/v6"
|
||||
"github.com/ory/x/sqlxx"
|
||||
"github.com/ory/x/urlx"
|
||||
)
|
||||
|
||||
var _ flow.InternalContexter = (*Flow)(nil)
|
||||
|
||||
// Flow represents a Settings Flow
|
||||
//
|
||||
// This flow is used when an identity wants to update settings
|
||||
|
|
@ -130,15 +120,8 @@ type Flow struct {
|
|||
TransientPayload json.RawMessage `json:"transient_payload,omitempty" faker:"-" db:"-"`
|
||||
}
|
||||
|
||||
func (f *Flow) GetInternalContext() sqlxx.JSONRawMessage {
|
||||
return f.InternalContext
|
||||
}
|
||||
|
||||
func (f *Flow) SetInternalContext(message sqlxx.JSONRawMessage) {
|
||||
f.InternalContext = message
|
||||
}
|
||||
|
||||
var _ flow.Flow = new(Flow)
|
||||
var _ flow.Flow = (*Flow)(nil)
|
||||
var _ flow.InternalContexter = (*Flow)(nil)
|
||||
|
||||
func MustNewFlow(conf *config.Config, exp time.Duration, r *http.Request, i *identity.Identity, ft flow.Type) *Flow {
|
||||
f, err := NewFlow(conf, exp, r, i, ft)
|
||||
|
|
@ -181,29 +164,19 @@ func NewFlow(conf *config.Config, exp time.Duration, r *http.Request, i *identit
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (f *Flow) GetType() flow.Type {
|
||||
return f.Type
|
||||
}
|
||||
|
||||
func (f *Flow) GetRequestURL() string {
|
||||
return f.RequestURL
|
||||
}
|
||||
|
||||
func (f Flow) TableName(ctx context.Context) string {
|
||||
return "selfservice_settings_flows"
|
||||
}
|
||||
|
||||
func (f Flow) GetID() uuid.UUID {
|
||||
return f.ID
|
||||
}
|
||||
|
||||
func (f Flow) GetNID() uuid.UUID {
|
||||
return f.NID
|
||||
}
|
||||
|
||||
func (f *Flow) AppendTo(src *url.URL) *url.URL {
|
||||
return flow.AppendFlowTo(src, f.ID)
|
||||
}
|
||||
func (f *Flow) GetInternalContext() sqlxx.JSONRawMessage { return f.InternalContext }
|
||||
func (f *Flow) SetInternalContext(message sqlxx.JSONRawMessage) { f.InternalContext = message }
|
||||
func (f *Flow) GetType() flow.Type { return f.Type }
|
||||
func (f *Flow) GetRequestURL() string { return f.RequestURL }
|
||||
func (_ Flow) TableName() string { return "selfservice_settings_flows" }
|
||||
func (f Flow) GetID() uuid.UUID { return f.ID }
|
||||
func (f *Flow) AppendTo(src *url.URL) *url.URL { return flow.AppendFlowTo(src, f.ID) }
|
||||
func (f *Flow) GetUI() *container.Container { return f.UI }
|
||||
func (f *Flow) ContinueWith() []flow.ContinueWith { return f.ContinueWithItems }
|
||||
func (f *Flow) GetState() State { return f.State }
|
||||
func (_ *Flow) GetFlowName() flow.FlowName { return flow.SettingsFlow }
|
||||
func (f *Flow) SetState(state State) { f.State = state }
|
||||
func (f *Flow) GetTransientPayload() json.RawMessage { return f.TransientPayload }
|
||||
|
||||
func (f *Flow) Valid(s *session.Session) error {
|
||||
if f.ExpiresAt.Before(time.Now().UTC()) {
|
||||
|
|
@ -250,39 +223,15 @@ func (f *Flow) AfterSave(*pop.Connection) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (f *Flow) GetUI() *container.Container {
|
||||
return f.UI
|
||||
}
|
||||
|
||||
func (f *Flow) AddContinueWith(c flow.ContinueWith) {
|
||||
f.ContinueWithItems = append(f.ContinueWithItems, c)
|
||||
}
|
||||
|
||||
func (f *Flow) ContinueWith() []flow.ContinueWith {
|
||||
return f.ContinueWithItems
|
||||
}
|
||||
|
||||
func (f *Flow) GetState() State {
|
||||
return f.State
|
||||
}
|
||||
|
||||
func (f *Flow) GetFlowName() flow.FlowName {
|
||||
return flow.SettingsFlow
|
||||
}
|
||||
|
||||
func (f *Flow) SetState(state State) {
|
||||
f.State = state
|
||||
}
|
||||
|
||||
func (t *Flow) GetTransientPayload() json.RawMessage {
|
||||
return t.TransientPayload
|
||||
}
|
||||
|
||||
func (f *Flow) ToLoggerField() map[string]interface{} {
|
||||
func (f *Flow) ToLoggerField() map[string]any {
|
||||
if f == nil {
|
||||
return map[string]interface{}{}
|
||||
return map[string]any{}
|
||||
}
|
||||
return map[string]interface{}{
|
||||
return map[string]any{
|
||||
"id": f.ID.String(),
|
||||
"return_to": f.ReturnTo,
|
||||
"request_url": f.RequestURL,
|
||||
|
|
|
|||
|
|
@ -26,8 +26,6 @@ import (
|
|||
"github.com/ory/x/urlx"
|
||||
)
|
||||
|
||||
var _ flow.Flow = new(Flow)
|
||||
|
||||
// A Verification Flow
|
||||
//
|
||||
// Used to verify an out-of-band communication
|
||||
|
|
@ -113,19 +111,7 @@ type OAuth2LoginChallengeParams struct {
|
|||
AMR session.AuthenticationMethods `db:"authentication_methods" json:"-"`
|
||||
}
|
||||
|
||||
var _ flow.Flow = new(Flow)
|
||||
|
||||
func (f *Flow) GetType() flow.Type {
|
||||
return f.Type
|
||||
}
|
||||
|
||||
func (f *Flow) GetRequestURL() string {
|
||||
return f.RequestURL
|
||||
}
|
||||
|
||||
func (f Flow) TableName(context.Context) string {
|
||||
return "selfservice_verification_flows"
|
||||
}
|
||||
var _ flow.Flow = (*Flow)(nil)
|
||||
|
||||
func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Request, strategy Strategy, ft flow.Type) (*Flow, error) {
|
||||
now := time.Now().UTC()
|
||||
|
|
@ -209,6 +195,17 @@ func NewPostHookFlow(conf *config.Config, exp time.Duration, csrf string, r *htt
|
|||
return f, nil
|
||||
}
|
||||
|
||||
func (f *Flow) GetType() flow.Type { return f.Type }
|
||||
func (f *Flow) GetRequestURL() string { return f.RequestURL }
|
||||
func (_ Flow) TableName() string { return "selfservice_verification_flows" }
|
||||
func (f Flow) GetID() uuid.UUID { return f.ID }
|
||||
func (f *Flow) GetState() State { return f.State }
|
||||
func (_ *Flow) GetFlowName() flow.FlowName { return flow.VerificationFlow }
|
||||
func (f *Flow) SetState(state State) { f.State = state }
|
||||
func (f *Flow) GetTransientPayload() json.RawMessage { return f.TransientPayload }
|
||||
func (f *Flow) GetOAuth2LoginChallenge() sqlxx.NullString { return f.OAuth2LoginChallenge }
|
||||
func (f *Flow) GetUI() *container.Container { return f.UI }
|
||||
|
||||
func (f *Flow) Valid() error {
|
||||
if f.ExpiresAt.Before(time.Now()) {
|
||||
return errors.WithStack(flow.NewFlowExpiredError(f.ExpiresAt))
|
||||
|
|
@ -222,14 +219,6 @@ func (f *Flow) AppendTo(src *url.URL) *url.URL {
|
|||
return urlx.CopyWithQuery(src, values)
|
||||
}
|
||||
|
||||
func (f Flow) GetID() uuid.UUID {
|
||||
return f.ID
|
||||
}
|
||||
|
||||
func (f Flow) GetNID() uuid.UUID {
|
||||
return f.NID
|
||||
}
|
||||
|
||||
func (f *Flow) SetCSRFToken(token string) {
|
||||
f.CSRFToken = token
|
||||
f.UI.SetCSRF(token)
|
||||
|
|
@ -257,10 +246,6 @@ func (f *Flow) AfterSave(*pop.Connection) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (f *Flow) GetUI() *container.Container {
|
||||
return f.UI
|
||||
}
|
||||
|
||||
// ContinueURL generates the URL to show on the continue screen after succesful verification
|
||||
//
|
||||
// It follows the following precedence:
|
||||
|
|
@ -290,22 +275,6 @@ func (f *Flow) ContinueURL(ctx context.Context, config *config.Config) *url.URL
|
|||
return returnTo
|
||||
}
|
||||
|
||||
func (f *Flow) GetState() State {
|
||||
return f.State
|
||||
}
|
||||
|
||||
func (f *Flow) GetFlowName() flow.FlowName {
|
||||
return flow.VerificationFlow
|
||||
}
|
||||
|
||||
func (f *Flow) SetState(state State) {
|
||||
f.State = state
|
||||
}
|
||||
|
||||
func (t *Flow) GetTransientPayload() json.RawMessage {
|
||||
return t.TransientPayload
|
||||
}
|
||||
|
||||
func (f *Flow) ToLoggerField() map[string]interface{} {
|
||||
if f == nil {
|
||||
return map[string]interface{}{}
|
||||
|
|
@ -320,7 +289,3 @@ func (f *Flow) ToLoggerField() map[string]interface{} {
|
|||
"state": f.State,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Flow) GetOAuth2LoginChallenge() sqlxx.NullString {
|
||||
return f.OAuth2LoginChallenge
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,9 +31,7 @@ type Exchanger struct {
|
|||
UpdatedAt time.Time `db:"updated_at"`
|
||||
}
|
||||
|
||||
func (e *Exchanger) TableName() string {
|
||||
return "session_token_exchanges"
|
||||
}
|
||||
func (_ Exchanger) TableName() string { return "session_token_exchanges" }
|
||||
|
||||
type (
|
||||
Persister interface {
|
||||
|
|
|
|||
|
|
@ -12,18 +12,15 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ory/kratos/x"
|
||||
|
||||
"github.com/ory/x/httpx"
|
||||
"github.com/ory/x/pagination/keysetpagination"
|
||||
"github.com/ory/x/pointerx"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/herodot"
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/x/httpx"
|
||||
"github.com/ory/x/pagination/keysetpagination"
|
||||
"github.com/ory/x/pointerx"
|
||||
"github.com/ory/x/randx"
|
||||
)
|
||||
|
||||
|
|
@ -67,9 +64,7 @@ type Device struct {
|
|||
NID uuid.UUID `json:"-" faker:"-" db:"nid"`
|
||||
}
|
||||
|
||||
func (m Device) TableName(ctx context.Context) string {
|
||||
return "session_devices"
|
||||
}
|
||||
func (Device) TableName() string { return "session_devices" }
|
||||
|
||||
// A Session
|
||||
//
|
||||
|
|
@ -166,9 +161,7 @@ func (m Session) DefaultPageToken() keysetpagination.PageToken {
|
|||
}
|
||||
}
|
||||
|
||||
func (s Session) TableName(ctx context.Context) string {
|
||||
return "sessions"
|
||||
}
|
||||
func (s Session) TableName() string { return "sessions" }
|
||||
|
||||
func (s *Session) CompletedLoginForMethod(method AuthenticationMethod) {
|
||||
method.CompletedAt = time.Now().UTC()
|
||||
|
|
|
|||
|
|
@ -7,12 +7,8 @@ import (
|
|||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
var _ negroni.Handler
|
||||
|
||||
const AdminPrefix = "/admin"
|
||||
|
||||
func RedirectAdminMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue