feat: faster OIDC front/back-channel logout

This commit is contained in:
Arne Luenser 2025-02-05 23:28:53 +01:00
parent a7579b8d1d
commit aa1ee32d6a
No known key found for this signature in database
8 changed files with 112 additions and 82 deletions

View File

@ -52,8 +52,7 @@ type (
CreateForcedObfuscatedLoginSession(ctx context.Context, session *ForcedObfuscatedLoginSession) error
GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedLoginSession, error)
ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
ListClientsWithLogoutURLsForSubjectAndSID(ctx context.Context, subject, sid string) (withFrontChannelURL, withBackChannelURL []client.Client, err error)
CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) error
GetLogoutRequest(ctx context.Context, challenge string) (*flow.LogoutRequest, error)

View File

@ -19,16 +19,16 @@ import (
"github.com/pborman/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/oauth2/flowctx"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/oauth2/flowctx"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/errorsx"
"github.com/ory/x/mapx"
@ -705,13 +705,7 @@ func (s *DefaultStrategy) verifyConsent(ctx context.Context, _ http.ResponseWrit
return session, f, nil
}
func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, subject, sid string) ([]string, error) {
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, subject, sid)
if err != nil {
return nil, err
}
var urls []string
func generateFrontChannelLogoutURLs(clients []client.Client, iss, sid string) (urls []string, _ error) {
for _, c := range clients {
u, err := url.Parse(c.FrontChannelLogoutURI)
if err != nil {
@ -719,7 +713,7 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su
}
urls = append(urls, urlx.SetQuery(u, url.Values{
"iss": {s.c.IssuerURL(ctx).String()},
"iss": {iss},
"sid": {sid},
}).String())
}
@ -727,13 +721,17 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su
return urls, nil
}
func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid string) error {
ctx := r.Context()
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
if err != nil {
return err
func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, clients []client.Client, sid string) (err error) {
if len(clients) == 0 {
return nil
}
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.executeBackChannelLogout",
trace.WithAttributes(
attribute.Int("clients", len(clients)),
attribute.String("sid", sid)))
defer otelx.End(span, &err)
openIDKeyID, err := s.r.OpenIDJWTStrategy().GetPublicKeyID(ctx)
if err != nil {
return err
@ -771,15 +769,19 @@ func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid
tasks = append(tasks, task{url: c.BackChannelLogoutURI, clientID: c.GetID(), token: t})
}
span := trace.SpanFromContext(ctx)
cl := s.r.HTTPClient(ctx)
execute := func(t task) {
log := s.r.Logger().WithRequest(r).
cl.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
execute := func(ctx context.Context, t task) {
log := s.r.Logger().
WithField("client_id", t.clientID).
WithField("backchannel_logout_url", t.url)
body := url.Values{"logout_token": {t.token}}.Encode()
req, err := retryablehttp.NewRequestWithContext(trace.ContextWithSpan(context.Background(), span), "POST", t.url, []byte(body))
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
req, err := retryablehttp.NewRequestWithContext(ctx, "POST", t.url, []byte(body))
if err != nil {
log.WithError(err).Error("Unable to construct OpenID Connect Back-Channel Logout Request")
return
@ -803,7 +805,7 @@ func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid
}
for _, t := range tasks {
go execute(t)
go execute(context.WithoutCancel(ctx), t)
}
return nil
@ -999,9 +1001,8 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon
return nil, errorsx.WithStack(ErrAbortOAuth2Request)
}
func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Request, subject string, sid string) error {
ctx := r.Context()
if err := s.executeBackChannelLogout(r, subject, sid); err != nil {
func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(ctx context.Context, clients []client.Client, sid string) error {
if err := s.executeBackChannelLogout(ctx, clients, sid); err != nil {
return err
}
@ -1028,7 +1029,7 @@ func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Reque
func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) {
verifier := r.URL.Query().Get("logout_verifier")
lr, err := s.r.ConsentManager().VerifyAndInvalidateLogoutRequest(r.Context(), verifier)
lr, err := s.r.ConsentManager().VerifyAndInvalidateLogoutRequest(ctx, verifier)
if err != nil {
return nil, err
}
@ -1069,12 +1070,17 @@ func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWri
_, _ = s.revokeAuthenticationCookie(w, r, store) // Cookie removal is optional
urls, err := s.generateFrontChannelLogoutURLs(r.Context(), lr.Subject, lr.SessionID)
frontChannelClients, backChannelClients, err := s.r.ConsentManager().ListClientsWithLogoutURLsForSubjectAndSID(ctx, lr.Subject, lr.SessionID)
if err != nil {
return nil, err
}
if err := s.performBackChannelLogoutAndDeleteSession(r, lr.Subject, lr.SessionID); err != nil {
urls, err := generateFrontChannelLogoutURLs(frontChannelClients, s.c.IssuerURL(ctx).String(), lr.SessionID)
if err != nil {
return nil, err
}
if err := s.performBackChannelLogoutAndDeleteSession(ctx, backChannelClients, lr.SessionID); err != nil {
return nil, err
}
@ -1110,7 +1116,12 @@ func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, _ http.Respo
return lsErr
}
if err := s.performBackChannelLogoutAndDeleteSession(r, loginSession.Subject, sid); err != nil {
_, clients, err := s.r.ConsentManager().ListClientsWithLogoutURLsForSubjectAndSID(ctx, loginSession.Subject, sid)
if err != nil {
return err
}
if err := s.performBackChannelLogoutAndDeleteSession(ctx, clients, sid); err != nil {
return err
}

View File

@ -974,16 +974,11 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo
}
}
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
actual, err := m.ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, ls.Subject, ls.ID)
t.Run(fmt.Sprintf("method=ListClientsWithLogoutURLsForSubjectAndSID/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
front, back, err := m.ListClientsWithLogoutURLsForSubjectAndSID(ctx, ls.Subject, ls.ID)
require.NoError(t, err)
check(t, frontChannels, actual)
})
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) {
actual, err := m.ListUserAuthenticatedClientsWithBackChannelLogout(ctx, ls.Subject, ls.ID)
require.NoError(t, err)
check(t, backChannels, actual)
check(t, frontChannels, front)
check(t, backChannels, back)
})
}
})

View File

@ -0,0 +1 @@
DROP INDEX IF EXISTS hydra_oauth2_flow@hydra_oauth2_flow_nid_sid_subject_idx;

View File

@ -0,0 +1 @@
CREATE INDEX IF NOT EXISTS hydra_oauth2_flow_nid_sid_subject_idx ON hydra_oauth2_flow (nid, login_session_id, subject) STORING (client_id) WHERE login_session_id IS NOT NULL;

View File

@ -618,49 +618,72 @@ func (p *Persister) filterExpiredConsentRequests(ctx context.Context, requests [
return result, nil
}
func (p *Persister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithFrontChannelLogout")
defer otelx.End(span, &err)
return p.listUserAuthenticatedClients(ctx, subject, sid, "front")
}
func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithBackChannelLogout")
defer otelx.End(span, &err)
return p.listUserAuthenticatedClients(ctx, subject, sid, "back")
}
func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) (cs []client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.listUserAuthenticatedClients",
func (p *Persister) ListClientsWithLogoutURLsForSubjectAndSID(ctx context.Context, subject, sid string) (withFrontChannelURL, withBackChannelURL []client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListClientsWithLogoutURLsForSubjectAndSID",
trace.WithAttributes(attribute.String("sid", sid)))
defer otelx.End(span, &err)
if err := p.Connection(ctx).RawQuery(
/* #nosec G201 - channel can either be "front" or "back" */
fmt.Sprintf(`
SELECT DISTINCT c.* FROM hydra_client as c
JOIN hydra_oauth2_flow as f ON (c.id = f.client_id AND c.nid = f.nid)
WHERE
f.subject=? AND
c.%schannel_logout_uri != '' AND
c.%schannel_logout_uri IS NOT NULL AND
f.login_session_id = ? AND
f.nid = ? AND
c.nid = ?`,
channel,
channel,
),
subject,
sid,
p.NetworkID(ctx),
p.NetworkID(ctx),
).All(&cs); err != nil {
return nil, sqlcon.HandleError(err)
defer func() {
span.SetAttributes(
attribute.Int("withFrontChannelURL", len(withFrontChannelURL)),
attribute.Int("withBackChannelURL", len(withBackChannelURL)))
}()
var (
cols = pop.NewModel(new(client.Client), ctx).Columns().Readable()
clientTable, flowTable = p.clientFlowTableNamesWithQueryHint(p.Connection(ctx).Dialect.Name())
q = fmt.Sprintf(`
SELECT %s FROM %s c
WHERE id IN (
SELECT DISTINCT client_id
FROM %s f
WHERE
f.nid = ?
AND f.login_session_id = ?
AND f.subject = ?
)
AND c.nid = ?
AND (
(c.frontchannel_logout_uri IS NOT NULL AND c.frontchannel_logout_uri != '')
OR c.backchannel_logout_uri != ''
)`,
cols.QuotedString(p.Connection(ctx).Dialect),
clientTable,
flowTable)
nid = p.NetworkID(ctx)
cs []client.Client
)
err = p.Connection(ctx).RawQuery(q, nid, sid, subject, nid).All(&cs)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil, nil
}
if err != nil {
return nil, nil, sqlcon.HandleError(err)
}
return cs, nil
for _, c := range cs {
if c.FrontChannelLogoutURI != "" {
withFrontChannelURL = append(withFrontChannelURL, c)
}
if c.BackChannelLogoutURI != "" {
withBackChannelURL = append(withBackChannelURL, c)
}
}
return withFrontChannelURL, withBackChannelURL, nil
}
func (p *Persister) clientFlowTableNamesWithQueryHint(dialect string) (clientTable, flowTable string) {
switch dialect {
case "cockroach":
return "hydra_client@primary", "hydra_oauth2_flow@hydra_oauth2_flow_nid_sid_subject_idx"
// TODO: more
default:
return "hydra_client", "hydra_oauth2_flow"
}
}
func (p *Persister) CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) (err error) {

View File

@ -1584,11 +1584,11 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithBackChannelLogo
})
require.NoError(t, err)
cs, err := r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t1, "sub", t1f1.SessionID.String())
_, cs, err := r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t1, "sub", t1f1.SessionID.String())
require.NoError(t, err)
require.Equal(t, 1, len(cs))
cs, err = r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t2, "sub", t1f1.SessionID.String())
_, cs, err = r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t2, "sub", t1f1.SessionID.String())
require.NoError(t, err)
require.Equal(t, 2, len(cs))
})
@ -1667,11 +1667,11 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithFrontChannelLog
})
require.NoError(t, err)
cs, err := r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t1, "sub", t1f1.SessionID.String())
cs, _, err := r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t1, "sub", t1f1.SessionID.String())
require.NoError(t, err)
require.Equal(t, 1, len(cs))
cs, err = r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t2, "sub", t1f1.SessionID.String())
cs, _, err = r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t2, "sub", t1f1.SessionID.String())
require.NoError(t, err)
require.Equal(t, 2, len(cs))
})

View File

@ -38,7 +38,7 @@ func init() {
func testRegistry(t *testing.T, ctx context.Context, k string, t1 driver.Registry, t2 driver.Registry) {
t.Run("package=client/manager="+k, func(t *testing.T) {
t.Run("case=create-get-update-delete", client.TestHelperCreateGetUpdateDeleteClient(k, t1.Persister().Connection(context.Background()), t1.ClientManager(), t2.ClientManager()))
t.Run("case=create-get-update-delete", client.TestHelperCreateGetUpdateDeleteClient(k, t1.Persister().Connection(ctx), t1.ClientManager(), t2.ClientManager()))
t.Run("case=autogenerate-key", client.TestHelperClientAutoGenerateKey(k, t1.ClientManager()))