diff --git a/consent/manager.go b/consent/manager.go index fe4b01835..01f5603d2 100644 --- a/consent/manager.go +++ b/consent/manager.go @@ -52,8 +52,7 @@ type ( CreateForcedObfuscatedLoginSession(ctx context.Context, session *ForcedObfuscatedLoginSession) error GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedLoginSession, error) - ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) - ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) + ListClientsWithLogoutURLsForSubjectAndSID(ctx context.Context, subject, sid string) (withFrontChannelURL, withBackChannelURL []client.Client, err error) CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) error GetLogoutRequest(ctx context.Context, challenge string) (*flow.LogoutRequest, error) diff --git a/consent/strategy_default.go b/consent/strategy_default.go index 8acf7f7c2..e8bc95d1f 100644 --- a/consent/strategy_default.go +++ b/consent/strategy_default.go @@ -19,16 +19,16 @@ import ( "github.com/pborman/uuid" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - "github.com/ory/hydra/v2/flow" - "github.com/ory/hydra/v2/oauth2/flowctx" - "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/token/jwt" "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver/config" + "github.com/ory/hydra/v2/flow" + "github.com/ory/hydra/v2/oauth2/flowctx" "github.com/ory/hydra/v2/x" "github.com/ory/x/errorsx" "github.com/ory/x/mapx" @@ -705,13 +705,7 @@ func (s *DefaultStrategy) verifyConsent(ctx context.Context, _ http.ResponseWrit return session, f, nil } -func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, subject, sid string) ([]string, error) { - clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, subject, sid) - if err != nil { - return nil, err - } - - var urls []string +func generateFrontChannelLogoutURLs(clients []client.Client, iss, sid string) (urls []string, _ error) { for _, c := range clients { u, err := url.Parse(c.FrontChannelLogoutURI) if err != nil { @@ -719,7 +713,7 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su } urls = append(urls, urlx.SetQuery(u, url.Values{ - "iss": {s.c.IssuerURL(ctx).String()}, + "iss": {iss}, "sid": {sid}, }).String()) } @@ -727,13 +721,17 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su return urls, nil } -func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid string) error { - ctx := r.Context() - clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid) - if err != nil { - return err +func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, clients []client.Client, sid string) (err error) { + if len(clients) == 0 { + return nil } + ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.executeBackChannelLogout", + trace.WithAttributes( + attribute.Int("clients", len(clients)), + attribute.String("sid", sid))) + defer otelx.End(span, &err) + openIDKeyID, err := s.r.OpenIDJWTStrategy().GetPublicKeyID(ctx) if err != nil { return err @@ -771,15 +769,19 @@ func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid tasks = append(tasks, task{url: c.BackChannelLogoutURI, clientID: c.GetID(), token: t}) } - span := trace.SpanFromContext(ctx) cl := s.r.HTTPClient(ctx) - execute := func(t task) { - log := s.r.Logger().WithRequest(r). + cl.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + execute := func(ctx context.Context, t task) { + log := s.r.Logger(). WithField("client_id", t.clientID). WithField("backchannel_logout_url", t.url) body := url.Values{"logout_token": {t.token}}.Encode() - req, err := retryablehttp.NewRequestWithContext(trace.ContextWithSpan(context.Background(), span), "POST", t.url, []byte(body)) + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + req, err := retryablehttp.NewRequestWithContext(ctx, "POST", t.url, []byte(body)) if err != nil { log.WithError(err).Error("Unable to construct OpenID Connect Back-Channel Logout Request") return @@ -803,7 +805,7 @@ func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid } for _, t := range tasks { - go execute(t) + go execute(context.WithoutCancel(ctx), t) } return nil @@ -999,9 +1001,8 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon return nil, errorsx.WithStack(ErrAbortOAuth2Request) } -func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Request, subject string, sid string) error { - ctx := r.Context() - if err := s.executeBackChannelLogout(r, subject, sid); err != nil { +func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(ctx context.Context, clients []client.Client, sid string) error { + if err := s.executeBackChannelLogout(ctx, clients, sid); err != nil { return err } @@ -1028,7 +1029,7 @@ func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Reque func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) { verifier := r.URL.Query().Get("logout_verifier") - lr, err := s.r.ConsentManager().VerifyAndInvalidateLogoutRequest(r.Context(), verifier) + lr, err := s.r.ConsentManager().VerifyAndInvalidateLogoutRequest(ctx, verifier) if err != nil { return nil, err } @@ -1069,12 +1070,17 @@ func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWri _, _ = s.revokeAuthenticationCookie(w, r, store) // Cookie removal is optional - urls, err := s.generateFrontChannelLogoutURLs(r.Context(), lr.Subject, lr.SessionID) + frontChannelClients, backChannelClients, err := s.r.ConsentManager().ListClientsWithLogoutURLsForSubjectAndSID(ctx, lr.Subject, lr.SessionID) if err != nil { return nil, err } - if err := s.performBackChannelLogoutAndDeleteSession(r, lr.Subject, lr.SessionID); err != nil { + urls, err := generateFrontChannelLogoutURLs(frontChannelClients, s.c.IssuerURL(ctx).String(), lr.SessionID) + if err != nil { + return nil, err + } + + if err := s.performBackChannelLogoutAndDeleteSession(ctx, backChannelClients, lr.SessionID); err != nil { return nil, err } @@ -1110,7 +1116,12 @@ func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, _ http.Respo return lsErr } - if err := s.performBackChannelLogoutAndDeleteSession(r, loginSession.Subject, sid); err != nil { + _, clients, err := s.r.ConsentManager().ListClientsWithLogoutURLsForSubjectAndSID(ctx, loginSession.Subject, sid) + if err != nil { + return err + } + + if err := s.performBackChannelLogoutAndDeleteSession(ctx, clients, sid); err != nil { return err } diff --git a/consent/test/manager_test_helpers.go b/consent/test/manager_test_helpers.go index 986b4f314..5ca046264 100644 --- a/consent/test/manager_test_helpers.go +++ b/consent/test/manager_test_helpers.go @@ -974,16 +974,11 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo } } - t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) { - actual, err := m.ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, ls.Subject, ls.ID) + t.Run(fmt.Sprintf("method=ListClientsWithLogoutURLsForSubjectAndSID/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) { + front, back, err := m.ListClientsWithLogoutURLsForSubjectAndSID(ctx, ls.Subject, ls.ID) require.NoError(t, err) - check(t, frontChannels, actual) - }) - - t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) { - actual, err := m.ListUserAuthenticatedClientsWithBackChannelLogout(ctx, ls.Subject, ls.ID) - require.NoError(t, err) - check(t, backChannels, actual) + check(t, frontChannels, front) + check(t, backChannels, back) }) } }) diff --git a/persistence/sql/migrations/20250206111700000000_flow_login_session_id_idx.cockroach.autocommit.down.sql b/persistence/sql/migrations/20250206111700000000_flow_login_session_id_idx.cockroach.autocommit.down.sql new file mode 100644 index 000000000..baafe85d6 --- /dev/null +++ b/persistence/sql/migrations/20250206111700000000_flow_login_session_id_idx.cockroach.autocommit.down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS hydra_oauth2_flow@hydra_oauth2_flow_nid_sid_subject_idx; diff --git a/persistence/sql/migrations/20250206111700000000_flow_login_session_id_idx.cockroach.autocommit.up.sql b/persistence/sql/migrations/20250206111700000000_flow_login_session_id_idx.cockroach.autocommit.up.sql new file mode 100644 index 000000000..028b04582 --- /dev/null +++ b/persistence/sql/migrations/20250206111700000000_flow_login_session_id_idx.cockroach.autocommit.up.sql @@ -0,0 +1 @@ +CREATE INDEX IF NOT EXISTS hydra_oauth2_flow_nid_sid_subject_idx ON hydra_oauth2_flow (nid, login_session_id, subject) STORING (client_id) WHERE login_session_id IS NOT NULL; diff --git a/persistence/sql/persister_consent.go b/persistence/sql/persister_consent.go index ad3c25b3d..1756b8863 100644 --- a/persistence/sql/persister_consent.go +++ b/persistence/sql/persister_consent.go @@ -618,49 +618,72 @@ func (p *Persister) filterExpiredConsentRequests(ctx context.Context, requests [ return result, nil } -func (p *Persister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithFrontChannelLogout") - defer otelx.End(span, &err) - - return p.listUserAuthenticatedClients(ctx, subject, sid, "front") -} - -func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithBackChannelLogout") - defer otelx.End(span, &err) - - return p.listUserAuthenticatedClients(ctx, subject, sid, "back") -} - -func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) (cs []client.Client, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.listUserAuthenticatedClients", +func (p *Persister) ListClientsWithLogoutURLsForSubjectAndSID(ctx context.Context, subject, sid string) (withFrontChannelURL, withBackChannelURL []client.Client, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListClientsWithLogoutURLsForSubjectAndSID", trace.WithAttributes(attribute.String("sid", sid))) defer otelx.End(span, &err) - if err := p.Connection(ctx).RawQuery( - /* #nosec G201 - channel can either be "front" or "back" */ - fmt.Sprintf(` -SELECT DISTINCT c.* FROM hydra_client as c -JOIN hydra_oauth2_flow as f ON (c.id = f.client_id AND c.nid = f.nid) -WHERE - f.subject=? AND - c.%schannel_logout_uri != '' AND - c.%schannel_logout_uri IS NOT NULL AND - f.login_session_id = ? AND - f.nid = ? AND - c.nid = ?`, - channel, - channel, - ), - subject, - sid, - p.NetworkID(ctx), - p.NetworkID(ctx), - ).All(&cs); err != nil { - return nil, sqlcon.HandleError(err) + defer func() { + span.SetAttributes( + attribute.Int("withFrontChannelURL", len(withFrontChannelURL)), + attribute.Int("withBackChannelURL", len(withBackChannelURL))) + }() + + var ( + cols = pop.NewModel(new(client.Client), ctx).Columns().Readable() + clientTable, flowTable = p.clientFlowTableNamesWithQueryHint(p.Connection(ctx).Dialect.Name()) + + q = fmt.Sprintf(` + SELECT %s FROM %s c + WHERE id IN ( + SELECT DISTINCT client_id + FROM %s f + WHERE + f.nid = ? + AND f.login_session_id = ? + AND f.subject = ? + ) + AND c.nid = ? + AND ( + (c.frontchannel_logout_uri IS NOT NULL AND c.frontchannel_logout_uri != '') + OR c.backchannel_logout_uri != '' + )`, + cols.QuotedString(p.Connection(ctx).Dialect), + clientTable, + flowTable) + + nid = p.NetworkID(ctx) + cs []client.Client + ) + + err = p.Connection(ctx).RawQuery(q, nid, sid, subject, nid).All(&cs) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil, nil + } + if err != nil { + return nil, nil, sqlcon.HandleError(err) } - return cs, nil + for _, c := range cs { + if c.FrontChannelLogoutURI != "" { + withFrontChannelURL = append(withFrontChannelURL, c) + } + if c.BackChannelLogoutURI != "" { + withBackChannelURL = append(withBackChannelURL, c) + } + } + + return withFrontChannelURL, withBackChannelURL, nil +} + +func (p *Persister) clientFlowTableNamesWithQueryHint(dialect string) (clientTable, flowTable string) { + switch dialect { + case "cockroach": + return "hydra_client@primary", "hydra_oauth2_flow@hydra_oauth2_flow_nid_sid_subject_idx" + // TODO: more + default: + return "hydra_client", "hydra_oauth2_flow" + } } func (p *Persister) CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) (err error) { diff --git a/persistence/sql/persister_nid_test.go b/persistence/sql/persister_nid_test.go index 93bccdcfe..779006266 100644 --- a/persistence/sql/persister_nid_test.go +++ b/persistence/sql/persister_nid_test.go @@ -1584,11 +1584,11 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithBackChannelLogo }) require.NoError(t, err) - cs, err := r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t1, "sub", t1f1.SessionID.String()) + _, cs, err := r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t1, "sub", t1f1.SessionID.String()) require.NoError(t, err) require.Equal(t, 1, len(cs)) - cs, err = r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t2, "sub", t1f1.SessionID.String()) + _, cs, err = r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t2, "sub", t1f1.SessionID.String()) require.NoError(t, err) require.Equal(t, 2, len(cs)) }) @@ -1667,11 +1667,11 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithFrontChannelLog }) require.NoError(t, err) - cs, err := r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t1, "sub", t1f1.SessionID.String()) + cs, _, err := r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t1, "sub", t1f1.SessionID.String()) require.NoError(t, err) require.Equal(t, 1, len(cs)) - cs, err = r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t2, "sub", t1f1.SessionID.String()) + cs, _, err = r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t2, "sub", t1f1.SessionID.String()) require.NoError(t, err) require.Equal(t, 2, len(cs)) }) diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index b4c88ef01..95a02c472 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -38,7 +38,7 @@ func init() { func testRegistry(t *testing.T, ctx context.Context, k string, t1 driver.Registry, t2 driver.Registry) { t.Run("package=client/manager="+k, func(t *testing.T) { - t.Run("case=create-get-update-delete", client.TestHelperCreateGetUpdateDeleteClient(k, t1.Persister().Connection(context.Background()), t1.ClientManager(), t2.ClientManager())) + t.Run("case=create-get-update-delete", client.TestHelperCreateGetUpdateDeleteClient(k, t1.Persister().Connection(ctx), t1.ClientManager(), t2.ClientManager())) t.Run("case=autogenerate-key", client.TestHelperClientAutoGenerateKey(k, t1.ClientManager()))