chore: simplify consent store

GitOrigin-RevId: e0b035133b185c3a5c52fbcb634cd3bb9a71c089
This commit is contained in:
Patrik 2025-12-10 18:04:24 +01:00 committed by ory-bot
parent c72853fb3a
commit 028908f0ab
6 changed files with 124 additions and 119 deletions

View File

@ -27,7 +27,7 @@ import (
"github.com/ory/x/uuidx"
)
func mockConsentRequest(remember bool, rememberFor int, skip bool) *flow.Flow {
func MockConsentFlow(remember bool, rememberFor int, skip bool) *flow.Flow {
return &flow.Flow{
ID: uuidx.NewV4().String(),
Client: &client.Client{ID: uuidx.NewV4().String()},
@ -219,8 +219,8 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
{"6", true, 120, true, false},
} {
t.Run("key="+tc.key, func(t *testing.T) {
f := mockConsentRequest(tc.remember, tc.rememberFor, tc.skip)
_ = clientManager.CreateClient(t.Context(), f.Client) // Ignore errors that are caused by duplication
f := MockConsentFlow(tc.remember, tc.rememberFor, tc.skip)
require.NoError(t, clientManager.CreateClient(t.Context(), f.Client))
f.NID = deps.Networker().NetworkID(t.Context())
require.NoError(t, m.CreateConsentSession(t.Context(), f))
@ -234,7 +234,7 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
// unfortunately the interface does not allow us to set the absolute time, so we have to wait
time.Sleep(2 * time.Second)
}
actual, err := m.FindGrantedAndRememberedConsentRequest(t.Context(), f.ClientID, f.Subject)
actual, err := m.FindGrantedAndRememberedConsentRequest(t.Context(), f.Client.ID, f.Subject)
if !tc.expectRemembered {
assert.Nil(t, actual)
assert.ErrorIs(t, err, consent.ErrNoPreviousConsentFound)
@ -260,36 +260,45 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
t.Run("case=revoke consent request", func(t *testing.T) {
type tc struct {
at, rt, subject, client string
revoke func(*testing.T)
f *flow.Flow
at, rt string
revoke func(*testing.T, tc)
}
tcs := make([]tc, 2)
revokeFuncs := []func(*testing.T, tc){
func(t *testing.T, c tc) {
require.NoError(t, m.RevokeSubjectConsentSession(t.Context(), c.f.Subject))
},
func(t *testing.T, c tc) {
require.NoError(t, m.RevokeSubjectClientConsentSession(t.Context(), c.f.Subject, c.f.Client.ID))
},
func(t *testing.T, c tc) {
require.NoError(t, m.RevokeConsentSessionByID(t.Context(), c.f.ConsentRequestID.String()))
},
}
tcs := make([]tc, 2*len(revokeFuncs))
for i := range tcs {
f := mockConsentRequest(true, 0, false)
f := MockConsentFlow(i < len(revokeFuncs), 0, true)
f.NID = deps.Networker().NetworkID(t.Context())
tcs[i] = tc{
subject: f.Subject,
client: f.Client.ID,
at: uuidx.NewV4().String(),
rt: uuidx.NewV4().String(),
f: f,
at: uuidx.NewV4().String(),
rt: uuidx.NewV4().String(),
revoke: revokeFuncs[i%len(revokeFuncs)],
}
require.NoError(t, clientManager.CreateClient(t.Context(), f.Client))
require.NoError(t, m.CreateConsentSession(t.Context(), f))
sess := &oauth2.Session{DefaultSession: openid.NewDefaultSession()}
sess.Subject = f.Subject
sess.ConsentChallenge = f.ConsentRequestID.String()
require.NoError(t, fositeManager.CreateAccessTokenSession(t.Context(), tcs[i].at,
&fosite.Request{Client: f.Client, ID: f.ConsentRequestID.String(), RequestedAt: time.Now(), Session: &oauth2.Session{DefaultSession: openid.NewDefaultSession()}},
))
&fosite.Request{Client: f.Client, ID: f.ConsentRequestID.String(), RequestedAt: time.Now(), Session: sess}),
)
require.NoError(t, fositeManager.CreateRefreshTokenSession(t.Context(), tcs[i].rt, tcs[i].at,
&fosite.Request{Client: f.Client, ID: f.ConsentRequestID.String(), RequestedAt: time.Now(), Session: &oauth2.Session{DefaultSession: openid.NewDefaultSession()}},
&fosite.Request{Client: f.Client, ID: f.ConsentRequestID.String(), RequestedAt: time.Now(), Session: sess},
))
}
tcs[0].revoke = func(t *testing.T) {
require.NoError(t, m.RevokeSubjectConsentSession(t.Context(), tcs[0].subject))
}
tcs[1].revoke = func(t *testing.T) {
require.NoError(t, m.RevokeSubjectClientConsentSession(t.Context(), tcs[1].subject, tcs[1].client))
}
for i, tc := range tcs {
t.Run(fmt.Sprintf("run=%d", i), func(t *testing.T) {
@ -298,7 +307,7 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
_, err = fositeManager.GetRefreshTokenSession(t.Context(), tc.rt, nil)
require.NoError(t, err)
tc.revoke(t)
tc.revoke(t, tc)
r, err := fositeManager.GetAccessTokenSession(t.Context(), tc.at, nil)
assert.ErrorIsf(t, err, fosite.ErrNotFound, "%+v", r)
@ -316,7 +325,7 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
t.Run("case=list consents", func(t *testing.T) {
flows := make([]*flow.Flow, 2)
for i := range flows {
f := mockConsentRequest(true, 0, false)
f := MockConsentFlow(true, 0, false)
f.NID = deps.Networker().NetworkID(t.Context())
f.SessionID = sqlxx.NullString(uuidx.NewV4().String())
flows[i] = f
@ -345,8 +354,8 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
}
t.Run("random subject", func(t *testing.T) {
_, _, err := m.FindSubjectsSessionGrantedConsentRequests(t.Context(), uuidx.NewV4().String(), flows[0].SessionID.String())
assert.ErrorIs(t, err, consent.ErrNoPreviousConsentFound)
res, _, err := m.FindSubjectsSessionGrantedConsentRequests(t.Context(), uuidx.NewV4().String(), flows[0].SessionID.String())
assert.ErrorIsf(t, err, consent.ErrNoPreviousConsentFound, "%+v", res)
})
})
@ -426,6 +435,7 @@ func ConsentManagerTests(t *testing.T, deps Deps, m consent.Manager, loginManage
SessionID: sqlxx.NullString(ls.ID),
ConsentRequestID: sqlxx.NullString(uuid.Must(uuid.NewV4()).String()),
GrantedScope: sqlxx.StringSliceJSONFormat{"scopea", "scopeb"},
ConsentRemember: true,
ConsentRememberFor: pointerx.Ptr(0),
}

View File

@ -4,6 +4,7 @@
package driver
import (
"github.com/ory/hydra/v2/consent"
"github.com/ory/hydra/v2/fosite"
"github.com/ory/hydra/v2/fosite/handler/oauth2"
"github.com/ory/hydra/v2/jwk"
@ -58,3 +59,10 @@ func RegistryWithAuthorizeCodeStorage(s func(r *RegistrySQL) oauth2.AuthorizeCod
return nil
}
}
func RegistryWithConsentManager(cm func(r *RegistrySQL) (consent.Manager, error)) RegistryModifier {
return func(r *RegistrySQL) (err error) {
r.consentManager, err = cm(r)
return err
}
}

View File

@ -92,7 +92,9 @@ type RegistrySQL struct {
migrator *sql.MigrationManager
dbOptsModifier []func(details *pop.ConnectionDetails)
keyManager jwk.Manager
keyManager jwk.Manager
consentManager consent.Manager
initialPing func(ctx context.Context, l *logrusx.Logger, p *sql.BasePersister) error
middlewares []negroni.Handler
}
@ -263,7 +265,12 @@ func (m *RegistrySQL) PingContext(ctx context.Context) error { return m.basePers
func (m *RegistrySQL) BasePersister() *sql.BasePersister { return m.basePersister }
func (m *RegistrySQL) ClientManager() client.Manager { return m.Persister() }
func (m *RegistrySQL) ConsentManager() consent.Manager { return m.Persister() }
func (m *RegistrySQL) ConsentManager() consent.Manager {
if m.consentManager != nil {
return m.consentManager
}
return &sql.ConsentPersister{BasePersister: m.basePersister}
}
func (m *RegistrySQL) ObfuscatedSubjectManager() consent.ObfuscatedSubjectManager {
return m.Persister()
}
@ -380,8 +387,14 @@ func (m *RegistrySQL) HealthHandler() *healthx.Handler {
}
if status.HasPending() {
err := errors.Errorf("migrations have not yet been fully applied: %+v", status)
m.Logger().WithField("status", fmt.Sprintf("%+v", status)).WithError(err).Warn("Instance is not yet ready because migrations have not yet been fully applied.")
var notApplied []string
for _, s := range status {
if s.State != "Applied" {
notApplied = append(notApplied, s.Version)
}
}
err := errors.Errorf("migrations have not yet been fully applied: %+v", notApplied)
m.Logger().WithField("not_applied", fmt.Sprintf("%+v", notApplied)).WithError(err).Warn("Instance is not yet ready because migrations have not yet been fully applied.")
return err
}
return nil

View File

@ -16,7 +16,6 @@ import (
type (
Persister interface {
consent.Manager
consent.ObfuscatedSubjectManager
consent.LoginManager
consent.LogoutManager

View File

@ -21,37 +21,45 @@ import (
"github.com/ory/pop/v6"
"github.com/ory/x/otelx"
keysetpagination "github.com/ory/x/pagination/keysetpagination_v2"
"github.com/ory/x/pointerx"
"github.com/ory/x/popx"
"github.com/ory/x/sqlcon"
"github.com/ory/x/sqlxx"
)
var _ consent.Manager = (*Persister)(nil)
var (
_ consent.Manager = (*ConsentPersister)(nil)
_ consent.LoginManager = (*Persister)(nil)
_ consent.LogoutManager = (*Persister)(nil)
_ consent.ObfuscatedSubjectManager = (*Persister)(nil)
)
func (p *Persister) RevokeSubjectConsentSession(ctx context.Context, user string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectConsentSession")
type ConsentPersister struct {
*BasePersister
}
func (p *ConsentPersister) RevokeSubjectConsentSession(ctx context.Context, user string) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectConsentSession")
defer otelx.End(span, &err)
return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ?", user))
}
func (p *Persister) RevokeSubjectClientConsentSession(ctx context.Context, user, client string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectClientConsentSession", trace.WithAttributes(attribute.String("client", client)))
func (p *ConsentPersister) RevokeSubjectClientConsentSession(ctx context.Context, user, client string) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectClientConsentSession", trace.WithAttributes(attribute.String("client", client)))
defer otelx.End(span, &err)
return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ? AND client_id = ?", user, client))
}
func (p *Persister) RevokeConsentSessionByID(ctx context.Context, consentRequestID string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeConsentSessionByID",
func (p *ConsentPersister) RevokeConsentSessionByID(ctx context.Context, consentRequestID string) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeConsentSessionByID",
trace.WithAttributes(attribute.String("consent_challenge_id", consentRequestID)))
defer otelx.End(span, &err)
return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id = ?", consentRequestID))
}
func (p *Persister) revokeConsentSession(whereStmt string, whereArgs ...interface{}) func(context.Context, *pop.Connection) error {
func (p *ConsentPersister) revokeConsentSession(whereStmt string, whereArgs ...interface{}) func(context.Context, *pop.Connection) error {
return func(ctx context.Context, c *pop.Connection) error {
fs := make([]*flow.Flow, 0)
if err := p.QueryWithNetwork(ctx).
@ -168,49 +176,14 @@ func (p *Persister) GetForcedObfuscatedLoginSession(ctx context.Context, client,
return &s, nil
}
type FlowWithConstantColumns struct {
*flow.Flow
State flow.State `db:"state"`
// we need to write these columns because of the check constraint, but we will soon switch to a new table anyway that will not have them at all
LoginRemember bool `db:"login_remember"`
LoginRememberFor int `db:"login_remember_for"`
LoginError string `db:"login_error"`
LoginUsed bool `db:"login_was_used"`
ConsentVerifier string `db:"consent_verifier"`
ConsentCSRF string `db:"consent_csrf"`
ConsentError string `db:"consent_error"`
ConsentUsed bool `db:"consent_was_used"`
// these columns have NOT NULL constraints, but are not required to be stored
LoginVerifier string `db:"login_verifier"`
LoginCSRF string `db:"login_csrf"`
LoginSkip bool `db:"login_skip"`
}
func (p *Persister) CreateConsentSession(ctx context.Context, f *flow.Flow) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateConsentSession")
func (p *ConsentPersister) CreateConsentSession(ctx context.Context, f *flow.Flow) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateConsentSession")
defer otelx.End(span, &err)
if f.NID != p.NetworkID(ctx) {
return errors.WithStack(sqlcon.ErrNoRows)
}
if f.ConsentRememberFor == nil {
// This is really stupid: we treat 0 the same as NULL, which means the flow does not expire.
// However, for some reason it is part of the check constraint and required to be NOT NULL.
f.ConsentRememberFor = pointerx.Ptr(0)
}
fx := &FlowWithConstantColumns{
Flow: f,
State: flow.FlowStateConsentUsed, // if this was another state, we'd not store it in the DB
ConsentUsed: true,
LoginUsed: true,
LoginError: "{}",
ConsentError: "{}",
}
return sqlcon.HandleError(p.Connection(ctx).Create(fx))
return sqlcon.HandleError(p.Connection(ctx).Create(f))
}
func (p *Persister) GetRememberedLoginSession(ctx context.Context, id string) (_ *flow.LoginSession, err error) {
@ -310,8 +283,8 @@ func (p *Persister) mySQLDeleteLoginSession(ctx context.Context, id string) (_ *
return &session, nil
}
func (p *Persister) FindGrantedAndRememberedConsentRequest(ctx context.Context, client, subject string) (_ *flow.Flow, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindGrantedAndRememberedConsentRequest")
func (p *ConsentPersister) FindGrantedAndRememberedConsentRequest(ctx context.Context, client, subject string) (_ *flow.Flow, err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindGrantedAndRememberedConsentRequest")
defer otelx.End(span, &err)
f := flow.Flow{}
@ -327,11 +300,10 @@ func (p *Persister) FindGrantedAndRememberedConsentRequest(ctx context.Context,
q := fmt.Sprintf(`
SELECT %s FROM %s
WHERE nid = ?
AND state = ?
AND (state = ? OR state IS NULL)
AND subject = ?
AND client_id = ?
AND consent_skip = FALSE
AND consent_error = '{}'
AND consent_remember = TRUE
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)
ORDER BY requested_at DESC
@ -371,8 +343,8 @@ func applyTableNameWithIndexHint(conn *pop.Connection, table string, index strin
}
}
func (p *Persister) FindSubjectsGrantedConsentRequests(ctx context.Context, subject string, pageOpts ...keysetpagination.Option) (_ []flow.Flow, _ *keysetpagination.Paginator, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsGrantedConsentRequests")
func (p *ConsentPersister) FindSubjectsGrantedConsentRequests(ctx context.Context, subject string, pageOpts ...keysetpagination.Option) (_ []flow.Flow, _ *keysetpagination.Paginator, err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsGrantedConsentRequests")
defer otelx.End(span, &err)
paginator := keysetpagination.NewPaginator(append(pageOpts,
@ -381,10 +353,9 @@ func (p *Persister) FindSubjectsGrantedConsentRequests(ctx context.Context, subj
var fs []flow.Flow
err = p.QueryWithNetwork(ctx).
Where("state IN (?, ?)", flow.FlowStateConsentUsed, flow.FlowStateConsentUnused).
Where("(state IN (?, ?) OR state IS NULL)", flow.FlowStateConsentUsed, flow.FlowStateConsentUnused).
Where("subject = ?", subject).
Where("consent_skip = FALSE").
Where("consent_error = '{}'").
Where("(expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)").
Scope(keysetpagination.Paginate[flow.Flow](paginator)).
All(&fs)
@ -399,8 +370,8 @@ func (p *Persister) FindSubjectsGrantedConsentRequests(ctx context.Context, subj
return fs, nextPage, nil
}
func (p *Persister) FindSubjectsSessionGrantedConsentRequests(ctx context.Context, subject, sid string, pageOpts ...keysetpagination.Option) (_ []flow.Flow, _ *keysetpagination.Paginator, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsSessionGrantedConsentRequests", trace.WithAttributes(attribute.String("sid", sid)))
func (p *ConsentPersister) FindSubjectsSessionGrantedConsentRequests(ctx context.Context, subject, sid string, pageOpts ...keysetpagination.Option) (_ []flow.Flow, _ *keysetpagination.Paginator, err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsSessionGrantedConsentRequests", trace.WithAttributes(attribute.String("sid", sid)))
defer otelx.End(span, &err)
paginator := keysetpagination.NewPaginator(append(pageOpts,
@ -409,11 +380,10 @@ func (p *Persister) FindSubjectsSessionGrantedConsentRequests(ctx context.Contex
var fs []flow.Flow
err = p.QueryWithNetwork(ctx).
Where("state IN (?, ?)", flow.FlowStateConsentUsed, flow.FlowStateConsentUnused).
Where("(state IN (?, ?) OR state IS NULL)", flow.FlowStateConsentUsed, flow.FlowStateConsentUnused).
Where("subject = ?", subject).
Where("login_session_id = ?", sid).
Where("consent_skip = FALSE").
Where("consent_error = '{}'").
Where("(expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)").
Scope(keysetpagination.Paginate[flow.Flow](paginator)).
All(&fs)
@ -428,22 +398,22 @@ func (p *Persister) FindSubjectsSessionGrantedConsentRequests(ctx context.Contex
return fs, nextPage, nil
}
func (p *Persister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithFrontChannelLogout")
func (p *ConsentPersister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithFrontChannelLogout")
defer otelx.End(span, &err)
return p.listUserAuthenticatedClients(ctx, subject, sid, "front")
}
func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithBackChannelLogout")
func (p *ConsentPersister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithBackChannelLogout")
defer otelx.End(span, &err)
return p.listUserAuthenticatedClients(ctx, subject, sid, "back")
}
func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) (cs []client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.listUserAuthenticatedClients",
func (p *ConsentPersister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) (cs []client.Client, err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.listUserAuthenticatedClients",
trace.WithAttributes(attribute.String("sid", sid)))
defer otelx.End(span, &err)
@ -574,7 +544,7 @@ func (p *Persister) FlushInactiveLoginConsentRequests(ctx context.Context, notAf
SELECT login_challenge
FROM hydra_oauth2_flow
WHERE (
(state != ?)
(state != ? AND state IS NOT NULL)
OR (login_error IS NOT NULL AND login_error <> '{}' AND login_error <> '')
OR (consent_error IS NOT NULL AND consent_error <> '{}' AND consent_error <> '')
)

View File

@ -598,13 +598,13 @@ func (s *PersisterTestSuite) TestFindGrantedAndRememberedConsentRequests() {
}))
f.State = flow.FlowStateConsentUsed
require.NoError(t, r.Persister().CreateConsentSession(s.t1, f))
require.NoError(t, r.ConsentManager().CreateConsentSession(s.t1, f))
actual, err := r.Persister().FindGrantedAndRememberedConsentRequest(s.t2, cl.ID, f.Subject)
actual, err := r.ConsentManager().FindGrantedAndRememberedConsentRequest(s.t2, cl.ID, f.Subject)
require.ErrorIs(t, err, consent.ErrNoPreviousConsentFound)
assert.Nil(t, actual)
actual, err = r.Persister().FindGrantedAndRememberedConsentRequest(s.t1, cl.ID, f.Subject)
actual, err = r.ConsentManager().FindGrantedAndRememberedConsentRequest(s.t1, cl.ID, f.Subject)
require.NoError(t, err)
assert.EqualValues(t, req.ConsentRequestID, actual.ConsentRequestID)
})
@ -619,12 +619,12 @@ func (s *PersisterTestSuite) TestFindSubjectsGrantedConsentRequests() {
f := newFlow(s.t1NID, cl.ID, "sub", sqlxx.NullString(sessionID))
persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, cl))
require.NoError(t, r.Persister().CreateConsentSession(s.t1, f))
require.NoError(t, r.ConsentManager().CreateConsentSession(s.t1, f))
_, _, err := r.Persister().FindSubjectsGrantedConsentRequests(s.t2, f.Subject)
_, _, err := r.ConsentManager().FindSubjectsGrantedConsentRequests(s.t2, f.Subject)
require.ErrorIs(t, err, consent.ErrNoPreviousConsentFound)
actual, nextPage, err := r.Persister().FindSubjectsGrantedConsentRequests(s.t1, f.Subject)
actual, nextPage, err := r.ConsentManager().FindSubjectsGrantedConsentRequests(s.t1, f.Subject)
require.NoError(t, err)
require.Len(t, actual, 1)
assert.Equal(t, f.ConsentRequestID.String(), actual[0].ConsentRequestID.String())
@ -700,7 +700,12 @@ func (s *PersisterTestSuite) TestFlushInactiveLoginConsentRequests() {
f.RequestedAt = time.Now().Add(-24 * time.Hour)
persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, cl))
require.NoError(t, r.Persister().Connection(s.t1).Create(&persistencesql.FlowWithConstantColumns{Flow: f, State: f.State}))
type legacyFlow struct {
*flow.Flow
State flow.State `db:"state"`
}
require.NoError(t, r.Persister().Connection(s.t1).Create(&legacyFlow{Flow: f, State: f.State}))
actual := flow.Flow{}
@ -1104,7 +1109,7 @@ func (s *PersisterTestSuite) TestHandleConsentRequest() {
f.ConsentRequestID = "consent-request-id"
actual, err := r.Persister().FindGrantedAndRememberedConsentRequest(s.t1, c1.ID, f.Subject)
actual, err := r.ConsentManager().FindGrantedAndRememberedConsentRequest(s.t1, c1.ID, f.Subject)
require.ErrorIs(t, err, consent.ErrNoPreviousConsentFound)
assert.Nil(t, actual)
@ -1114,8 +1119,8 @@ func (s *PersisterTestSuite) TestHandleConsentRequest() {
f.State = flow.FlowStateConsentUsed
require.NoError(t, r.Persister().CreateConsentSession(s.t1, f))
actual, err = r.Persister().FindGrantedAndRememberedConsentRequest(s.t1, c1.ID, f.Subject)
require.NoError(t, r.ConsentManager().CreateConsentSession(s.t1, f))
actual, err = r.ConsentManager().FindGrantedAndRememberedConsentRequest(s.t1, c1.ID, f.Subject)
require.NoError(t, err)
assert.EqualValues(t, f.ConsentRequestID, actual.ConsentRequestID)
})
@ -1184,19 +1189,19 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithBackChannelLogo
persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: t1f1.SessionID.String()})
require.NoError(t, r.Persister().CreateConsentSession(s.t1, t1f1))
require.NoError(t, r.Persister().CreateConsentSession(s.t2, t2f1))
require.NoError(t, r.Persister().CreateConsentSession(s.t2, t2f2))
require.NoError(t, r.ConsentManager().CreateConsentSession(s.t1, t1f1))
require.NoError(t, r.ConsentManager().CreateConsentSession(s.t2, t2f1))
require.NoError(t, r.ConsentManager().CreateConsentSession(s.t2, t2f2))
t1f1.ConsentRequestID = sqlxx.NullString(t1f1.ID)
t2f1.ConsentRequestID = sqlxx.NullString(t2f1.ID)
t2f2.ConsentRequestID = sqlxx.NullString(t2f2.ID)
cs, err := r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t1, "sub", t1f1.SessionID.String())
cs, err := r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(s.t1, "sub", t1f1.SessionID.String())
require.NoError(t, err)
require.Equal(t, 1, len(cs))
cs, err = r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t2, "sub", t1f1.SessionID.String())
cs, err = r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(s.t2, "sub", t1f1.SessionID.String())
require.NoError(t, err)
require.Equal(t, 2, len(cs))
})
@ -1223,19 +1228,19 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithFrontChannelLog
persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: t1f1.SessionID.String()})
require.NoError(t, r.Persister().CreateConsentSession(s.t1, t1f1))
require.NoError(t, r.Persister().CreateConsentSession(s.t2, t2f1))
require.NoError(t, r.Persister().CreateConsentSession(s.t2, t2f2))
require.NoError(t, r.ConsentManager().CreateConsentSession(s.t1, t1f1))
require.NoError(t, r.ConsentManager().CreateConsentSession(s.t2, t2f1))
require.NoError(t, r.ConsentManager().CreateConsentSession(s.t2, t2f2))
t1f1.ConsentRequestID = sqlxx.NullString(t1f1.ID)
t2f1.ConsentRequestID = sqlxx.NullString(t2f1.ID)
t2f2.ConsentRequestID = sqlxx.NullString(t2f2.ID)
cs, err := r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t1, "sub", t1f1.SessionID.String())
cs, err := r.ConsentManager().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t1, "sub", t1f1.SessionID.String())
require.NoError(t, err)
require.Equal(t, 1, len(cs))
cs, err = r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t2, "sub", t1f1.SessionID.String())
cs, err = r.ConsentManager().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t2, "sub", t1f1.SessionID.String())
require.NoError(t, err)
require.Equal(t, 2, len(cs))
})
@ -1453,13 +1458,13 @@ func (s *PersisterTestSuite) TestRevokeSubjectClientConsentSession() {
f.RequestedAt = time.Now().Add(-24 * time.Hour)
persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, cl))
require.NoError(t, r.Persister().CreateConsentSession(s.t1, f))
require.NoError(t, r.ConsentManager().CreateConsentSession(s.t1, f))
actual := flow.Flow{}
require.NoError(t, r.Persister().RevokeSubjectClientConsentSession(s.t2, "sub", cl.ID), "should not error if nothing was found")
require.NoError(t, r.ConsentManager().RevokeSubjectClientConsentSession(s.t2, "sub", cl.ID), "should not error if nothing was found")
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, f.ID))
require.NoError(t, r.Persister().RevokeSubjectClientConsentSession(s.t1, "sub", cl.ID))
require.NoError(t, r.ConsentManager().RevokeSubjectClientConsentSession(s.t1, "sub", cl.ID))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, f.ID))
})
}