chore: simplify consent store

GitOrigin-RevId: e0b035133b185c3a5c52fbcb634cd3bb9a71c089
This commit is contained in:
Patrik 2025-12-10 18:04:24 +01:00 committed by ory-bot
parent c72853fb3a
commit 028908f0ab
6 changed files with 124 additions and 119 deletions

View File

@ -27,7 +27,7 @@ import (
"github.com/ory/x/uuidx" "github.com/ory/x/uuidx"
) )
func mockConsentRequest(remember bool, rememberFor int, skip bool) *flow.Flow { func MockConsentFlow(remember bool, rememberFor int, skip bool) *flow.Flow {
return &flow.Flow{ return &flow.Flow{
ID: uuidx.NewV4().String(), ID: uuidx.NewV4().String(),
Client: &client.Client{ID: uuidx.NewV4().String()}, Client: &client.Client{ID: uuidx.NewV4().String()},
@ -219,8 +219,8 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
{"6", true, 120, true, false}, {"6", true, 120, true, false},
} { } {
t.Run("key="+tc.key, func(t *testing.T) { t.Run("key="+tc.key, func(t *testing.T) {
f := mockConsentRequest(tc.remember, tc.rememberFor, tc.skip) f := MockConsentFlow(tc.remember, tc.rememberFor, tc.skip)
_ = clientManager.CreateClient(t.Context(), f.Client) // Ignore errors that are caused by duplication require.NoError(t, clientManager.CreateClient(t.Context(), f.Client))
f.NID = deps.Networker().NetworkID(t.Context()) f.NID = deps.Networker().NetworkID(t.Context())
require.NoError(t, m.CreateConsentSession(t.Context(), f)) require.NoError(t, m.CreateConsentSession(t.Context(), f))
@ -234,7 +234,7 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
// unfortunately the interface does not allow us to set the absolute time, so we have to wait // unfortunately the interface does not allow us to set the absolute time, so we have to wait
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
} }
actual, err := m.FindGrantedAndRememberedConsentRequest(t.Context(), f.ClientID, f.Subject) actual, err := m.FindGrantedAndRememberedConsentRequest(t.Context(), f.Client.ID, f.Subject)
if !tc.expectRemembered { if !tc.expectRemembered {
assert.Nil(t, actual) assert.Nil(t, actual)
assert.ErrorIs(t, err, consent.ErrNoPreviousConsentFound) assert.ErrorIs(t, err, consent.ErrNoPreviousConsentFound)
@ -260,36 +260,45 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
t.Run("case=revoke consent request", func(t *testing.T) { t.Run("case=revoke consent request", func(t *testing.T) {
type tc struct { type tc struct {
at, rt, subject, client string f *flow.Flow
revoke func(*testing.T) at, rt string
revoke func(*testing.T, tc)
} }
tcs := make([]tc, 2) revokeFuncs := []func(*testing.T, tc){
func(t *testing.T, c tc) {
require.NoError(t, m.RevokeSubjectConsentSession(t.Context(), c.f.Subject))
},
func(t *testing.T, c tc) {
require.NoError(t, m.RevokeSubjectClientConsentSession(t.Context(), c.f.Subject, c.f.Client.ID))
},
func(t *testing.T, c tc) {
require.NoError(t, m.RevokeConsentSessionByID(t.Context(), c.f.ConsentRequestID.String()))
},
}
tcs := make([]tc, 2*len(revokeFuncs))
for i := range tcs { for i := range tcs {
f := mockConsentRequest(true, 0, false) f := MockConsentFlow(i < len(revokeFuncs), 0, true)
f.NID = deps.Networker().NetworkID(t.Context()) f.NID = deps.Networker().NetworkID(t.Context())
tcs[i] = tc{ tcs[i] = tc{
subject: f.Subject, f: f,
client: f.Client.ID, at: uuidx.NewV4().String(),
at: uuidx.NewV4().String(), rt: uuidx.NewV4().String(),
rt: uuidx.NewV4().String(), revoke: revokeFuncs[i%len(revokeFuncs)],
} }
require.NoError(t, clientManager.CreateClient(t.Context(), f.Client)) require.NoError(t, clientManager.CreateClient(t.Context(), f.Client))
require.NoError(t, m.CreateConsentSession(t.Context(), f)) require.NoError(t, m.CreateConsentSession(t.Context(), f))
sess := &oauth2.Session{DefaultSession: openid.NewDefaultSession()}
sess.Subject = f.Subject
sess.ConsentChallenge = f.ConsentRequestID.String()
require.NoError(t, fositeManager.CreateAccessTokenSession(t.Context(), tcs[i].at, require.NoError(t, fositeManager.CreateAccessTokenSession(t.Context(), tcs[i].at,
&fosite.Request{Client: f.Client, ID: f.ConsentRequestID.String(), RequestedAt: time.Now(), Session: &oauth2.Session{DefaultSession: openid.NewDefaultSession()}}, &fosite.Request{Client: f.Client, ID: f.ConsentRequestID.String(), RequestedAt: time.Now(), Session: sess}),
)) )
require.NoError(t, fositeManager.CreateRefreshTokenSession(t.Context(), tcs[i].rt, tcs[i].at, require.NoError(t, fositeManager.CreateRefreshTokenSession(t.Context(), tcs[i].rt, tcs[i].at,
&fosite.Request{Client: f.Client, ID: f.ConsentRequestID.String(), RequestedAt: time.Now(), Session: &oauth2.Session{DefaultSession: openid.NewDefaultSession()}}, &fosite.Request{Client: f.Client, ID: f.ConsentRequestID.String(), RequestedAt: time.Now(), Session: sess},
)) ))
} }
tcs[0].revoke = func(t *testing.T) {
require.NoError(t, m.RevokeSubjectConsentSession(t.Context(), tcs[0].subject))
}
tcs[1].revoke = func(t *testing.T) {
require.NoError(t, m.RevokeSubjectClientConsentSession(t.Context(), tcs[1].subject, tcs[1].client))
}
for i, tc := range tcs { for i, tc := range tcs {
t.Run(fmt.Sprintf("run=%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("run=%d", i), func(t *testing.T) {
@ -298,7 +307,7 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
_, err = fositeManager.GetRefreshTokenSession(t.Context(), tc.rt, nil) _, err = fositeManager.GetRefreshTokenSession(t.Context(), tc.rt, nil)
require.NoError(t, err) require.NoError(t, err)
tc.revoke(t) tc.revoke(t, tc)
r, err := fositeManager.GetAccessTokenSession(t.Context(), tc.at, nil) r, err := fositeManager.GetAccessTokenSession(t.Context(), tc.at, nil)
assert.ErrorIsf(t, err, fosite.ErrNotFound, "%+v", r) assert.ErrorIsf(t, err, fosite.ErrNotFound, "%+v", r)
@ -316,7 +325,7 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
t.Run("case=list consents", func(t *testing.T) { t.Run("case=list consents", func(t *testing.T) {
flows := make([]*flow.Flow, 2) flows := make([]*flow.Flow, 2)
for i := range flows { for i := range flows {
f := mockConsentRequest(true, 0, false) f := MockConsentFlow(true, 0, false)
f.NID = deps.Networker().NetworkID(t.Context()) f.NID = deps.Networker().NetworkID(t.Context())
f.SessionID = sqlxx.NullString(uuidx.NewV4().String()) f.SessionID = sqlxx.NullString(uuidx.NewV4().String())
flows[i] = f flows[i] = f
@ -345,8 +354,8 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
} }
t.Run("random subject", func(t *testing.T) { t.Run("random subject", func(t *testing.T) {
_, _, err := m.FindSubjectsSessionGrantedConsentRequests(t.Context(), uuidx.NewV4().String(), flows[0].SessionID.String()) res, _, err := m.FindSubjectsSessionGrantedConsentRequests(t.Context(), uuidx.NewV4().String(), flows[0].SessionID.String())
assert.ErrorIs(t, err, consent.ErrNoPreviousConsentFound) assert.ErrorIsf(t, err, consent.ErrNoPreviousConsentFound, "%+v", res)
}) })
}) })
@ -426,6 +435,7 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
SessionID: sqlxx.NullString(ls.ID), SessionID: sqlxx.NullString(ls.ID),
ConsentRequestID: sqlxx.NullString(uuid.Must(uuid.NewV4()).String()), ConsentRequestID: sqlxx.NullString(uuid.Must(uuid.NewV4()).String()),
GrantedScope: sqlxx.StringSliceJSONFormat{"scopea", "scopeb"}, GrantedScope: sqlxx.StringSliceJSONFormat{"scopea", "scopeb"},
ConsentRemember: true,
ConsentRememberFor: pointerx.Ptr(0), ConsentRememberFor: pointerx.Ptr(0),
} }

View File

@ -4,6 +4,7 @@
package driver package driver
import ( import (
"github.com/ory/hydra/v2/consent"
"github.com/ory/hydra/v2/fosite" "github.com/ory/hydra/v2/fosite"
"github.com/ory/hydra/v2/fosite/handler/oauth2" "github.com/ory/hydra/v2/fosite/handler/oauth2"
"github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/jwk"
@ -58,3 +59,10 @@ func RegistryWithAuthorizeCodeStorage(s func(r *RegistrySQL) oauth2.AuthorizeCod
return nil return nil
} }
} }
func RegistryWithConsentManager(cm func(r *RegistrySQL) (consent.Manager, error)) RegistryModifier {
return func(r *RegistrySQL) (err error) {
r.consentManager, err = cm(r)
return err
}
}

View File

@ -92,7 +92,9 @@ type RegistrySQL struct {
migrator *sql.MigrationManager migrator *sql.MigrationManager
dbOptsModifier []func(details *pop.ConnectionDetails) dbOptsModifier []func(details *pop.ConnectionDetails)
keyManager jwk.Manager keyManager jwk.Manager
consentManager consent.Manager
initialPing func(ctx context.Context, l *logrusx.Logger, p *sql.BasePersister) error initialPing func(ctx context.Context, l *logrusx.Logger, p *sql.BasePersister) error
middlewares []negroni.Handler middlewares []negroni.Handler
} }
@ -263,7 +265,12 @@ func (m *RegistrySQL) PingContext(ctx context.Context) error { return m.basePers
func (m *RegistrySQL) BasePersister() *sql.BasePersister { return m.basePersister } func (m *RegistrySQL) BasePersister() *sql.BasePersister { return m.basePersister }
func (m *RegistrySQL) ClientManager() client.Manager { return m.Persister() } func (m *RegistrySQL) ClientManager() client.Manager { return m.Persister() }
func (m *RegistrySQL) ConsentManager() consent.Manager { return m.Persister() } func (m *RegistrySQL) ConsentManager() consent.Manager {
if m.consentManager != nil {
return m.consentManager
}
return &sql.ConsentPersister{BasePersister: m.basePersister}
}
func (m *RegistrySQL) ObfuscatedSubjectManager() consent.ObfuscatedSubjectManager { func (m *RegistrySQL) ObfuscatedSubjectManager() consent.ObfuscatedSubjectManager {
return m.Persister() return m.Persister()
} }
@ -380,8 +387,14 @@ func (m *RegistrySQL) HealthHandler() *healthx.Handler {
} }
if status.HasPending() { if status.HasPending() {
err := errors.Errorf("migrations have not yet been fully applied: %+v", status) var notApplied []string
m.Logger().WithField("status", fmt.Sprintf("%+v", status)).WithError(err).Warn("Instance is not yet ready because migrations have not yet been fully applied.") for _, s := range status {
if s.State != "Applied" {
notApplied = append(notApplied, s.Version)
}
}
err := errors.Errorf("migrations have not yet been fully applied: %+v", notApplied)
m.Logger().WithField("not_applied", fmt.Sprintf("%+v", notApplied)).WithError(err).Warn("Instance is not yet ready because migrations have not yet been fully applied.")
return err return err
} }
return nil return nil

View File

@ -16,7 +16,6 @@ import (
type ( type (
Persister interface { Persister interface {
consent.Manager
consent.ObfuscatedSubjectManager consent.ObfuscatedSubjectManager
consent.LoginManager consent.LoginManager
consent.LogoutManager consent.LogoutManager

View File

@ -21,37 +21,45 @@ import (
"github.com/ory/pop/v6" "github.com/ory/pop/v6"
"github.com/ory/x/otelx" "github.com/ory/x/otelx"
keysetpagination "github.com/ory/x/pagination/keysetpagination_v2" keysetpagination "github.com/ory/x/pagination/keysetpagination_v2"
"github.com/ory/x/pointerx"
"github.com/ory/x/popx" "github.com/ory/x/popx"
"github.com/ory/x/sqlcon" "github.com/ory/x/sqlcon"
"github.com/ory/x/sqlxx" "github.com/ory/x/sqlxx"
) )
var _ consent.Manager = (*Persister)(nil) var (
_ consent.Manager = (*ConsentPersister)(nil)
_ consent.LoginManager = (*Persister)(nil)
_ consent.LogoutManager = (*Persister)(nil)
_ consent.ObfuscatedSubjectManager = (*Persister)(nil)
)
func (p *Persister) RevokeSubjectConsentSession(ctx context.Context, user string) (err error) { type ConsentPersister struct {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectConsentSession") *BasePersister
}
func (p *ConsentPersister) RevokeSubjectConsentSession(ctx context.Context, user string) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectConsentSession")
defer otelx.End(span, &err) defer otelx.End(span, &err)
return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ?", user)) return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ?", user))
} }
func (p *Persister) RevokeSubjectClientConsentSession(ctx context.Context, user, client string) (err error) { func (p *ConsentPersister) RevokeSubjectClientConsentSession(ctx context.Context, user, client string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectClientConsentSession", trace.WithAttributes(attribute.String("client", client))) ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectClientConsentSession", trace.WithAttributes(attribute.String("client", client)))
defer otelx.End(span, &err) defer otelx.End(span, &err)
return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ? AND client_id = ?", user, client)) return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ? AND client_id = ?", user, client))
} }
func (p *Persister) RevokeConsentSessionByID(ctx context.Context, consentRequestID string) (err error) { func (p *ConsentPersister) RevokeConsentSessionByID(ctx context.Context, consentRequestID string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeConsentSessionByID", ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeConsentSessionByID",
trace.WithAttributes(attribute.String("consent_challenge_id", consentRequestID))) trace.WithAttributes(attribute.String("consent_challenge_id", consentRequestID)))
defer otelx.End(span, &err) defer otelx.End(span, &err)
return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id = ?", consentRequestID)) return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id = ?", consentRequestID))
} }
func (p *Persister) revokeConsentSession(whereStmt string, whereArgs ...interface{}) func(context.Context, *pop.Connection) error { func (p *ConsentPersister) revokeConsentSession(whereStmt string, whereArgs ...interface{}) func(context.Context, *pop.Connection) error {
return func(ctx context.Context, c *pop.Connection) error { return func(ctx context.Context, c *pop.Connection) error {
fs := make([]*flow.Flow, 0) fs := make([]*flow.Flow, 0)
if err := p.QueryWithNetwork(ctx). if err := p.QueryWithNetwork(ctx).
@ -168,49 +176,14 @@ func (p *Persister) GetForcedObfuscatedLoginSession(ctx context.Context, client,
return &s, nil return &s, nil
} }
type FlowWithConstantColumns struct { func (p *ConsentPersister) CreateConsentSession(ctx context.Context, f *flow.Flow) (err error) {
*flow.Flow ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateConsentSession")
State flow.State `db:"state"`
// we need to write these columns because of the check constraint, but we will soon switch to a new table anyway that will not have them at all
LoginRemember bool `db:"login_remember"`
LoginRememberFor int `db:"login_remember_for"`
LoginError string `db:"login_error"`
LoginUsed bool `db:"login_was_used"`
ConsentVerifier string `db:"consent_verifier"`
ConsentCSRF string `db:"consent_csrf"`
ConsentError string `db:"consent_error"`
ConsentUsed bool `db:"consent_was_used"`
// these columns have NOT NULL constraints, but are not required to be stored
LoginVerifier string `db:"login_verifier"`
LoginCSRF string `db:"login_csrf"`
LoginSkip bool `db:"login_skip"`
}
func (p *Persister) CreateConsentSession(ctx context.Context, f *flow.Flow) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateConsentSession")
defer otelx.End(span, &err) defer otelx.End(span, &err)
if f.NID != p.NetworkID(ctx) { if f.NID != p.NetworkID(ctx) {
return errors.WithStack(sqlcon.ErrNoRows) return errors.WithStack(sqlcon.ErrNoRows)
} }
if f.ConsentRememberFor == nil { return sqlcon.HandleError(p.Connection(ctx).Create(f))
// This is really stupid: we treat 0 the same as NULL, which means the flow does not expire.
// However, for some reason it is part of the check constraint and required to be NOT NULL.
f.ConsentRememberFor = pointerx.Ptr(0)
}
fx := &FlowWithConstantColumns{
Flow: f,
State: flow.FlowStateConsentUsed, // if this was another state, we'd not store it in the DB
ConsentUsed: true,
LoginUsed: true,
LoginError: "{}",
ConsentError: "{}",
}
return sqlcon.HandleError(p.Connection(ctx).Create(fx))
} }
func (p *Persister) GetRememberedLoginSession(ctx context.Context, id string) (_ *flow.LoginSession, err error) { func (p *Persister) GetRememberedLoginSession(ctx context.Context, id string) (_ *flow.LoginSession, err error) {
@ -310,8 +283,8 @@ func (p *Persister) mySQLDeleteLoginSession(ctx context.Context, id string) (_ *
return &session, nil return &session, nil
} }
func (p *Persister) FindGrantedAndRememberedConsentRequest(ctx context.Context, client, subject string) (_ *flow.Flow, err error) { func (p *ConsentPersister) FindGrantedAndRememberedConsentRequest(ctx context.Context, client, subject string) (_ *flow.Flow, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindGrantedAndRememberedConsentRequest") ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindGrantedAndRememberedConsentRequest")
defer otelx.End(span, &err) defer otelx.End(span, &err)
f := flow.Flow{} f := flow.Flow{}
@ -327,11 +300,10 @@ func (p *Persister) FindGrantedAndRememberedConsentRequest(ctx context.Context,
q := fmt.Sprintf(` q := fmt.Sprintf(`
SELECT %s FROM %s SELECT %s FROM %s
WHERE nid = ? WHERE nid = ?
AND state = ? AND (state = ? OR state IS NULL)
AND subject = ? AND subject = ?
AND client_id = ? AND client_id = ?
AND consent_skip = FALSE AND consent_skip = FALSE
AND consent_error = '{}'
AND consent_remember = TRUE AND consent_remember = TRUE
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP) AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)
ORDER BY requested_at DESC ORDER BY requested_at DESC
@ -371,8 +343,8 @@ func applyTableNameWithIndexHint(conn *pop.Connection, table string, index strin
} }
} }
func (p *Persister) FindSubjectsGrantedConsentRequests(ctx context.Context, subject string, pageOpts ...keysetpagination.Option) (_ []flow.Flow, _ *keysetpagination.Paginator, err error) { func (p *ConsentPersister) FindSubjectsGrantedConsentRequests(ctx context.Context, subject string, pageOpts ...keysetpagination.Option) (_ []flow.Flow, _ *keysetpagination.Paginator, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsGrantedConsentRequests") ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsGrantedConsentRequests")
defer otelx.End(span, &err) defer otelx.End(span, &err)
paginator := keysetpagination.NewPaginator(append(pageOpts, paginator := keysetpagination.NewPaginator(append(pageOpts,
@ -381,10 +353,9 @@ func (p *Persister) FindSubjectsGrantedConsentRequests(ctx context.Context, subj
var fs []flow.Flow var fs []flow.Flow
err = p.QueryWithNetwork(ctx). err = p.QueryWithNetwork(ctx).
Where("state IN (?, ?)", flow.FlowStateConsentUsed, flow.FlowStateConsentUnused). Where("(state IN (?, ?) OR state IS NULL)", flow.FlowStateConsentUsed, flow.FlowStateConsentUnused).
Where("subject = ?", subject). Where("subject = ?", subject).
Where("consent_skip = FALSE"). Where("consent_skip = FALSE").
Where("consent_error = '{}'").
Where("(expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)"). Where("(expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)").
Scope(keysetpagination.Paginate[flow.Flow](paginator)). Scope(keysetpagination.Paginate[flow.Flow](paginator)).
All(&fs) All(&fs)
@ -399,8 +370,8 @@ func (p *Persister) FindSubjectsGrantedConsentRequests(ctx context.Context, subj
return fs, nextPage, nil return fs, nextPage, nil
} }
func (p *Persister) FindSubjectsSessionGrantedConsentRequests(ctx context.Context, subject, sid string, pageOpts ...keysetpagination.Option) (_ []flow.Flow, _ *keysetpagination.Paginator, err error) { func (p *ConsentPersister) FindSubjectsSessionGrantedConsentRequests(ctx context.Context, subject, sid string, pageOpts ...keysetpagination.Option) (_ []flow.Flow, _ *keysetpagination.Paginator, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsSessionGrantedConsentRequests", trace.WithAttributes(attribute.String("sid", sid))) ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsSessionGrantedConsentRequests", trace.WithAttributes(attribute.String("sid", sid)))
defer otelx.End(span, &err) defer otelx.End(span, &err)
paginator := keysetpagination.NewPaginator(append(pageOpts, paginator := keysetpagination.NewPaginator(append(pageOpts,
@ -409,11 +380,10 @@ func (p *Persister) FindSubjectsSessionGrantedConsentRequests(ctx context.Contex
var fs []flow.Flow var fs []flow.Flow
err = p.QueryWithNetwork(ctx). err = p.QueryWithNetwork(ctx).
Where("state IN (?, ?)", flow.FlowStateConsentUsed, flow.FlowStateConsentUnused). Where("(state IN (?, ?) OR state IS NULL)", flow.FlowStateConsentUsed, flow.FlowStateConsentUnused).
Where("subject = ?", subject). Where("subject = ?", subject).
Where("login_session_id = ?", sid). Where("login_session_id = ?", sid).
Where("consent_skip = FALSE"). Where("consent_skip = FALSE").
Where("consent_error = '{}'").
Where("(expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)"). Where("(expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)").
Scope(keysetpagination.Paginate[flow.Flow](paginator)). Scope(keysetpagination.Paginate[flow.Flow](paginator)).
All(&fs) All(&fs)
@ -428,22 +398,22 @@ func (p *Persister) FindSubjectsSessionGrantedConsentRequests(ctx context.Contex
return fs, nextPage, nil return fs, nextPage, nil
} }
func (p *Persister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) { func (p *ConsentPersister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithFrontChannelLogout") ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithFrontChannelLogout")
defer otelx.End(span, &err) defer otelx.End(span, &err)
return p.listUserAuthenticatedClients(ctx, subject, sid, "front") return p.listUserAuthenticatedClients(ctx, subject, sid, "front")
} }
func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) { func (p *ConsentPersister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithBackChannelLogout") ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithBackChannelLogout")
defer otelx.End(span, &err) defer otelx.End(span, &err)
return p.listUserAuthenticatedClients(ctx, subject, sid, "back") return p.listUserAuthenticatedClients(ctx, subject, sid, "back")
} }
func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) (cs []client.Client, err error) { func (p *ConsentPersister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) (cs []client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.listUserAuthenticatedClients", ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.listUserAuthenticatedClients",
trace.WithAttributes(attribute.String("sid", sid))) trace.WithAttributes(attribute.String("sid", sid)))
defer otelx.End(span, &err) defer otelx.End(span, &err)
@ -574,7 +544,7 @@ func (p *Persister) FlushInactiveLoginConsentRequests(ctx context.Context, notAf
SELECT login_challenge SELECT login_challenge
FROM hydra_oauth2_flow FROM hydra_oauth2_flow
WHERE ( WHERE (
(state != ?) (state != ? AND state IS NOT NULL)
OR (login_error IS NOT NULL AND login_error <> '{}' AND login_error <> '') OR (login_error IS NOT NULL AND login_error <> '{}' AND login_error <> '')
OR (consent_error IS NOT NULL AND consent_error <> '{}' AND consent_error <> '') OR (consent_error IS NOT NULL AND consent_error <> '{}' AND consent_error <> '')
) )

View File

@ -598,13 +598,13 @@ func (s *PersisterTestSuite) TestFindGrantedAndRememberedConsentRequests() {
})) }))
f.State = flow.FlowStateConsentUsed f.State = flow.FlowStateConsentUsed
require.NoError(t, r.Persister().CreateConsentSession(s.t1, f)) require.NoError(t, r.ConsentManager().CreateConsentSession(s.t1, f))
actual, err := r.Persister().FindGrantedAndRememberedConsentRequest(s.t2, cl.ID, f.Subject) actual, err := r.ConsentManager().FindGrantedAndRememberedConsentRequest(s.t2, cl.ID, f.Subject)
require.ErrorIs(t, err, consent.ErrNoPreviousConsentFound) require.ErrorIs(t, err, consent.ErrNoPreviousConsentFound)
assert.Nil(t, actual) assert.Nil(t, actual)
actual, err = r.Persister().FindGrantedAndRememberedConsentRequest(s.t1, cl.ID, f.Subject) actual, err = r.ConsentManager().FindGrantedAndRememberedConsentRequest(s.t1, cl.ID, f.Subject)
require.NoError(t, err) require.NoError(t, err)
assert.EqualValues(t, req.ConsentRequestID, actual.ConsentRequestID) assert.EqualValues(t, req.ConsentRequestID, actual.ConsentRequestID)
}) })
@ -619,12 +619,12 @@ func (s *PersisterTestSuite) TestFindSubjectsGrantedConsentRequests() {
f := newFlow(s.t1NID, cl.ID, "sub", sqlxx.NullString(sessionID)) f := newFlow(s.t1NID, cl.ID, "sub", sqlxx.NullString(sessionID))
persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, cl)) require.NoError(t, r.Persister().CreateClient(s.t1, cl))
require.NoError(t, r.Persister().CreateConsentSession(s.t1, f)) require.NoError(t, r.ConsentManager().CreateConsentSession(s.t1, f))
_, _, err := r.Persister().FindSubjectsGrantedConsentRequests(s.t2, f.Subject) _, _, err := r.ConsentManager().FindSubjectsGrantedConsentRequests(s.t2, f.Subject)
require.ErrorIs(t, err, consent.ErrNoPreviousConsentFound) require.ErrorIs(t, err, consent.ErrNoPreviousConsentFound)
actual, nextPage, err := r.Persister().FindSubjectsGrantedConsentRequests(s.t1, f.Subject) actual, nextPage, err := r.ConsentManager().FindSubjectsGrantedConsentRequests(s.t1, f.Subject)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, actual, 1) require.Len(t, actual, 1)
assert.Equal(t, f.ConsentRequestID.String(), actual[0].ConsentRequestID.String()) assert.Equal(t, f.ConsentRequestID.String(), actual[0].ConsentRequestID.String())
@ -700,7 +700,12 @@ func (s *PersisterTestSuite) TestFlushInactiveLoginConsentRequests() {
f.RequestedAt = time.Now().Add(-24 * time.Hour) f.RequestedAt = time.Now().Add(-24 * time.Hour)
persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, cl)) require.NoError(t, r.Persister().CreateClient(s.t1, cl))
require.NoError(t, r.Persister().Connection(s.t1).Create(&persistencesql.FlowWithConstantColumns{Flow: f, State: f.State}))
type legacyFlow struct {
*flow.Flow
State flow.State `db:"state"`
}
require.NoError(t, r.Persister().Connection(s.t1).Create(&legacyFlow{Flow: f, State: f.State}))
actual := flow.Flow{} actual := flow.Flow{}
@ -1104,7 +1109,7 @@ func (s *PersisterTestSuite) TestHandleConsentRequest() {
f.ConsentRequestID = "consent-request-id" f.ConsentRequestID = "consent-request-id"
actual, err := r.Persister().FindGrantedAndRememberedConsentRequest(s.t1, c1.ID, f.Subject) actual, err := r.ConsentManager().FindGrantedAndRememberedConsentRequest(s.t1, c1.ID, f.Subject)
require.ErrorIs(t, err, consent.ErrNoPreviousConsentFound) require.ErrorIs(t, err, consent.ErrNoPreviousConsentFound)
assert.Nil(t, actual) assert.Nil(t, actual)
@ -1114,8 +1119,8 @@ func (s *PersisterTestSuite) TestHandleConsentRequest() {
f.State = flow.FlowStateConsentUsed f.State = flow.FlowStateConsentUsed
require.NoError(t, r.Persister().CreateConsentSession(s.t1, f)) require.NoError(t, r.ConsentManager().CreateConsentSession(s.t1, f))
actual, err = r.Persister().FindGrantedAndRememberedConsentRequest(s.t1, c1.ID, f.Subject) actual, err = r.ConsentManager().FindGrantedAndRememberedConsentRequest(s.t1, c1.ID, f.Subject)
require.NoError(t, err) require.NoError(t, err)
assert.EqualValues(t, f.ConsentRequestID, actual.ConsentRequestID) assert.EqualValues(t, f.ConsentRequestID, actual.ConsentRequestID)
}) })
@ -1184,19 +1189,19 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithBackChannelLogo
persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: t1f1.SessionID.String()}) persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: t1f1.SessionID.String()})
require.NoError(t, r.Persister().CreateConsentSession(s.t1, t1f1)) require.NoError(t, r.ConsentManager().CreateConsentSession(s.t1, t1f1))
require.NoError(t, r.Persister().CreateConsentSession(s.t2, t2f1)) require.NoError(t, r.ConsentManager().CreateConsentSession(s.t2, t2f1))
require.NoError(t, r.Persister().CreateConsentSession(s.t2, t2f2)) require.NoError(t, r.ConsentManager().CreateConsentSession(s.t2, t2f2))
t1f1.ConsentRequestID = sqlxx.NullString(t1f1.ID) t1f1.ConsentRequestID = sqlxx.NullString(t1f1.ID)
t2f1.ConsentRequestID = sqlxx.NullString(t2f1.ID) t2f1.ConsentRequestID = sqlxx.NullString(t2f1.ID)
t2f2.ConsentRequestID = sqlxx.NullString(t2f2.ID) t2f2.ConsentRequestID = sqlxx.NullString(t2f2.ID)
cs, err := r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t1, "sub", t1f1.SessionID.String()) cs, err := r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(s.t1, "sub", t1f1.SessionID.String())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(cs)) require.Equal(t, 1, len(cs))
cs, err = r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t2, "sub", t1f1.SessionID.String()) cs, err = r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(s.t2, "sub", t1f1.SessionID.String())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 2, len(cs)) require.Equal(t, 2, len(cs))
}) })
@ -1223,19 +1228,19 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithFrontChannelLog
persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: t1f1.SessionID.String()}) persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: t1f1.SessionID.String()})
require.NoError(t, r.Persister().CreateConsentSession(s.t1, t1f1)) require.NoError(t, r.ConsentManager().CreateConsentSession(s.t1, t1f1))
require.NoError(t, r.Persister().CreateConsentSession(s.t2, t2f1)) require.NoError(t, r.ConsentManager().CreateConsentSession(s.t2, t2f1))
require.NoError(t, r.Persister().CreateConsentSession(s.t2, t2f2)) require.NoError(t, r.ConsentManager().CreateConsentSession(s.t2, t2f2))
t1f1.ConsentRequestID = sqlxx.NullString(t1f1.ID) t1f1.ConsentRequestID = sqlxx.NullString(t1f1.ID)
t2f1.ConsentRequestID = sqlxx.NullString(t2f1.ID) t2f1.ConsentRequestID = sqlxx.NullString(t2f1.ID)
t2f2.ConsentRequestID = sqlxx.NullString(t2f2.ID) t2f2.ConsentRequestID = sqlxx.NullString(t2f2.ID)
cs, err := r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t1, "sub", t1f1.SessionID.String()) cs, err := r.ConsentManager().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t1, "sub", t1f1.SessionID.String())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(cs)) require.Equal(t, 1, len(cs))
cs, err = r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t2, "sub", t1f1.SessionID.String()) cs, err = r.ConsentManager().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t2, "sub", t1f1.SessionID.String())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 2, len(cs)) require.Equal(t, 2, len(cs))
}) })
@ -1453,13 +1458,13 @@ func (s *PersisterTestSuite) TestRevokeSubjectClientConsentSession() {
f.RequestedAt = time.Now().Add(-24 * time.Hour) f.RequestedAt = time.Now().Add(-24 * time.Hour)
persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, cl)) require.NoError(t, r.Persister().CreateClient(s.t1, cl))
require.NoError(t, r.Persister().CreateConsentSession(s.t1, f)) require.NoError(t, r.ConsentManager().CreateConsentSession(s.t1, f))
actual := flow.Flow{} actual := flow.Flow{}
require.NoError(t, r.Persister().RevokeSubjectClientConsentSession(s.t2, "sub", cl.ID), "should not error if nothing was found") require.NoError(t, r.ConsentManager().RevokeSubjectClientConsentSession(s.t2, "sub", cl.ID), "should not error if nothing was found")
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, f.ID)) require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, f.ID))
require.NoError(t, r.Persister().RevokeSubjectClientConsentSession(s.t1, "sub", cl.ID)) require.NoError(t, r.ConsentManager().RevokeSubjectClientConsentSession(s.t1, "sub", cl.ID))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, f.ID)) require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, f.ID))
}) })
} }