mirror of https://github.com/ory/hydra
feat: faster OIDC front/back-channel logout
This commit is contained in:
parent
a7579b8d1d
commit
aa1ee32d6a
|
|
@ -52,8 +52,7 @@ type (
|
|||
CreateForcedObfuscatedLoginSession(ctx context.Context, session *ForcedObfuscatedLoginSession) error
|
||||
GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedLoginSession, error)
|
||||
|
||||
ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
|
||||
ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
|
||||
ListClientsWithLogoutURLsForSubjectAndSID(ctx context.Context, subject, sid string) (withFrontChannelURL, withBackChannelURL []client.Client, err error)
|
||||
|
||||
CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) error
|
||||
GetLogoutRequest(ctx context.Context, challenge string) (*flow.LogoutRequest, error)
|
||||
|
|
|
|||
|
|
@ -19,16 +19,16 @@ import (
|
|||
"github.com/pborman/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/ory/hydra/v2/flow"
|
||||
"github.com/ory/hydra/v2/oauth2/flowctx"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/handler/openid"
|
||||
"github.com/ory/fosite/token/jwt"
|
||||
"github.com/ory/hydra/v2/client"
|
||||
"github.com/ory/hydra/v2/driver/config"
|
||||
"github.com/ory/hydra/v2/flow"
|
||||
"github.com/ory/hydra/v2/oauth2/flowctx"
|
||||
"github.com/ory/hydra/v2/x"
|
||||
"github.com/ory/x/errorsx"
|
||||
"github.com/ory/x/mapx"
|
||||
|
|
@ -705,13 +705,7 @@ func (s *DefaultStrategy) verifyConsent(ctx context.Context, _ http.ResponseWrit
|
|||
return session, f, nil
|
||||
}
|
||||
|
||||
func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, subject, sid string) ([]string, error) {
|
||||
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, subject, sid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var urls []string
|
||||
func generateFrontChannelLogoutURLs(clients []client.Client, iss, sid string) (urls []string, _ error) {
|
||||
for _, c := range clients {
|
||||
u, err := url.Parse(c.FrontChannelLogoutURI)
|
||||
if err != nil {
|
||||
|
|
@ -719,7 +713,7 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su
|
|||
}
|
||||
|
||||
urls = append(urls, urlx.SetQuery(u, url.Values{
|
||||
"iss": {s.c.IssuerURL(ctx).String()},
|
||||
"iss": {iss},
|
||||
"sid": {sid},
|
||||
}).String())
|
||||
}
|
||||
|
|
@ -727,13 +721,17 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su
|
|||
return urls, nil
|
||||
}
|
||||
|
||||
func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid string) error {
|
||||
ctx := r.Context()
|
||||
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
|
||||
if err != nil {
|
||||
return err
|
||||
func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, clients []client.Client, sid string) (err error) {
|
||||
if len(clients) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.executeBackChannelLogout",
|
||||
trace.WithAttributes(
|
||||
attribute.Int("clients", len(clients)),
|
||||
attribute.String("sid", sid)))
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
openIDKeyID, err := s.r.OpenIDJWTStrategy().GetPublicKeyID(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -771,15 +769,19 @@ func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid
|
|||
tasks = append(tasks, task{url: c.BackChannelLogoutURI, clientID: c.GetID(), token: t})
|
||||
}
|
||||
|
||||
span := trace.SpanFromContext(ctx)
|
||||
cl := s.r.HTTPClient(ctx)
|
||||
execute := func(t task) {
|
||||
log := s.r.Logger().WithRequest(r).
|
||||
cl.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
execute := func(ctx context.Context, t task) {
|
||||
log := s.r.Logger().
|
||||
WithField("client_id", t.clientID).
|
||||
WithField("backchannel_logout_url", t.url)
|
||||
|
||||
body := url.Values{"logout_token": {t.token}}.Encode()
|
||||
req, err := retryablehttp.NewRequestWithContext(trace.ContextWithSpan(context.Background(), span), "POST", t.url, []byte(body))
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
req, err := retryablehttp.NewRequestWithContext(ctx, "POST", t.url, []byte(body))
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Unable to construct OpenID Connect Back-Channel Logout Request")
|
||||
return
|
||||
|
|
@ -803,7 +805,7 @@ func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid
|
|||
}
|
||||
|
||||
for _, t := range tasks {
|
||||
go execute(t)
|
||||
go execute(context.WithoutCancel(ctx), t)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -999,9 +1001,8 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon
|
|||
return nil, errorsx.WithStack(ErrAbortOAuth2Request)
|
||||
}
|
||||
|
||||
func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Request, subject string, sid string) error {
|
||||
ctx := r.Context()
|
||||
if err := s.executeBackChannelLogout(r, subject, sid); err != nil {
|
||||
func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(ctx context.Context, clients []client.Client, sid string) error {
|
||||
if err := s.executeBackChannelLogout(ctx, clients, sid); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -1028,7 +1029,7 @@ func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Reque
|
|||
func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) {
|
||||
verifier := r.URL.Query().Get("logout_verifier")
|
||||
|
||||
lr, err := s.r.ConsentManager().VerifyAndInvalidateLogoutRequest(r.Context(), verifier)
|
||||
lr, err := s.r.ConsentManager().VerifyAndInvalidateLogoutRequest(ctx, verifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -1069,12 +1070,17 @@ func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWri
|
|||
|
||||
_, _ = s.revokeAuthenticationCookie(w, r, store) // Cookie removal is optional
|
||||
|
||||
urls, err := s.generateFrontChannelLogoutURLs(r.Context(), lr.Subject, lr.SessionID)
|
||||
frontChannelClients, backChannelClients, err := s.r.ConsentManager().ListClientsWithLogoutURLsForSubjectAndSID(ctx, lr.Subject, lr.SessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.performBackChannelLogoutAndDeleteSession(r, lr.Subject, lr.SessionID); err != nil {
|
||||
urls, err := generateFrontChannelLogoutURLs(frontChannelClients, s.c.IssuerURL(ctx).String(), lr.SessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.performBackChannelLogoutAndDeleteSession(ctx, backChannelClients, lr.SessionID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -1110,7 +1116,12 @@ func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, _ http.Respo
|
|||
return lsErr
|
||||
}
|
||||
|
||||
if err := s.performBackChannelLogoutAndDeleteSession(r, loginSession.Subject, sid); err != nil {
|
||||
_, clients, err := s.r.ConsentManager().ListClientsWithLogoutURLsForSubjectAndSID(ctx, loginSession.Subject, sid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.performBackChannelLogoutAndDeleteSession(ctx, clients, sid); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -974,16 +974,11 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo
|
|||
}
|
||||
}
|
||||
|
||||
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
|
||||
actual, err := m.ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, ls.Subject, ls.ID)
|
||||
t.Run(fmt.Sprintf("method=ListClientsWithLogoutURLsForSubjectAndSID/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
|
||||
front, back, err := m.ListClientsWithLogoutURLsForSubjectAndSID(ctx, ls.Subject, ls.ID)
|
||||
require.NoError(t, err)
|
||||
check(t, frontChannels, actual)
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) {
|
||||
actual, err := m.ListUserAuthenticatedClientsWithBackChannelLogout(ctx, ls.Subject, ls.ID)
|
||||
require.NoError(t, err)
|
||||
check(t, backChannels, actual)
|
||||
check(t, frontChannels, front)
|
||||
check(t, backChannels, back)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
DROP INDEX IF EXISTS hydra_oauth2_flow@hydra_oauth2_flow_nid_sid_subject_idx;
|
||||
|
|
@ -0,0 +1 @@
|
|||
CREATE INDEX IF NOT EXISTS hydra_oauth2_flow_nid_sid_subject_idx ON hydra_oauth2_flow (nid, login_session_id, subject) STORING (client_id) WHERE login_session_id IS NOT NULL;
|
||||
|
|
@ -618,49 +618,72 @@ func (p *Persister) filterExpiredConsentRequests(ctx context.Context, requests [
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func (p *Persister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithFrontChannelLogout")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
return p.listUserAuthenticatedClients(ctx, subject, sid, "front")
|
||||
}
|
||||
|
||||
func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithBackChannelLogout")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
return p.listUserAuthenticatedClients(ctx, subject, sid, "back")
|
||||
}
|
||||
|
||||
func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) (cs []client.Client, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.listUserAuthenticatedClients",
|
||||
func (p *Persister) ListClientsWithLogoutURLsForSubjectAndSID(ctx context.Context, subject, sid string) (withFrontChannelURL, withBackChannelURL []client.Client, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListClientsWithLogoutURLsForSubjectAndSID",
|
||||
trace.WithAttributes(attribute.String("sid", sid)))
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
if err := p.Connection(ctx).RawQuery(
|
||||
/* #nosec G201 - channel can either be "front" or "back" */
|
||||
fmt.Sprintf(`
|
||||
SELECT DISTINCT c.* FROM hydra_client as c
|
||||
JOIN hydra_oauth2_flow as f ON (c.id = f.client_id AND c.nid = f.nid)
|
||||
WHERE
|
||||
f.subject=? AND
|
||||
c.%schannel_logout_uri != '' AND
|
||||
c.%schannel_logout_uri IS NOT NULL AND
|
||||
f.login_session_id = ? AND
|
||||
f.nid = ? AND
|
||||
c.nid = ?`,
|
||||
channel,
|
||||
channel,
|
||||
),
|
||||
subject,
|
||||
sid,
|
||||
p.NetworkID(ctx),
|
||||
p.NetworkID(ctx),
|
||||
).All(&cs); err != nil {
|
||||
return nil, sqlcon.HandleError(err)
|
||||
defer func() {
|
||||
span.SetAttributes(
|
||||
attribute.Int("withFrontChannelURL", len(withFrontChannelURL)),
|
||||
attribute.Int("withBackChannelURL", len(withBackChannelURL)))
|
||||
}()
|
||||
|
||||
var (
|
||||
cols = pop.NewModel(new(client.Client), ctx).Columns().Readable()
|
||||
clientTable, flowTable = p.clientFlowTableNamesWithQueryHint(p.Connection(ctx).Dialect.Name())
|
||||
|
||||
q = fmt.Sprintf(`
|
||||
SELECT %s FROM %s c
|
||||
WHERE id IN (
|
||||
SELECT DISTINCT client_id
|
||||
FROM %s f
|
||||
WHERE
|
||||
f.nid = ?
|
||||
AND f.login_session_id = ?
|
||||
AND f.subject = ?
|
||||
)
|
||||
AND c.nid = ?
|
||||
AND (
|
||||
(c.frontchannel_logout_uri IS NOT NULL AND c.frontchannel_logout_uri != '')
|
||||
OR c.backchannel_logout_uri != ''
|
||||
)`,
|
||||
cols.QuotedString(p.Connection(ctx).Dialect),
|
||||
clientTable,
|
||||
flowTable)
|
||||
|
||||
nid = p.NetworkID(ctx)
|
||||
cs []client.Client
|
||||
)
|
||||
|
||||
err = p.Connection(ctx).RawQuery(q, nid, sid, subject, nid).All(&cs)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, sqlcon.HandleError(err)
|
||||
}
|
||||
|
||||
return cs, nil
|
||||
for _, c := range cs {
|
||||
if c.FrontChannelLogoutURI != "" {
|
||||
withFrontChannelURL = append(withFrontChannelURL, c)
|
||||
}
|
||||
if c.BackChannelLogoutURI != "" {
|
||||
withBackChannelURL = append(withBackChannelURL, c)
|
||||
}
|
||||
}
|
||||
|
||||
return withFrontChannelURL, withBackChannelURL, nil
|
||||
}
|
||||
|
||||
func (p *Persister) clientFlowTableNamesWithQueryHint(dialect string) (clientTable, flowTable string) {
|
||||
switch dialect {
|
||||
case "cockroach":
|
||||
return "hydra_client@primary", "hydra_oauth2_flow@hydra_oauth2_flow_nid_sid_subject_idx"
|
||||
// TODO: more
|
||||
default:
|
||||
return "hydra_client", "hydra_oauth2_flow"
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Persister) CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) (err error) {
|
||||
|
|
|
|||
|
|
@ -1584,11 +1584,11 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithBackChannelLogo
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cs, err := r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t1, "sub", t1f1.SessionID.String())
|
||||
_, cs, err := r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t1, "sub", t1f1.SessionID.String())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(cs))
|
||||
|
||||
cs, err = r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t2, "sub", t1f1.SessionID.String())
|
||||
_, cs, err = r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t2, "sub", t1f1.SessionID.String())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(cs))
|
||||
})
|
||||
|
|
@ -1667,11 +1667,11 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithFrontChannelLog
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cs, err := r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t1, "sub", t1f1.SessionID.String())
|
||||
cs, _, err := r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t1, "sub", t1f1.SessionID.String())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(cs))
|
||||
|
||||
cs, err = r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t2, "sub", t1f1.SessionID.String())
|
||||
cs, _, err = r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t2, "sub", t1f1.SessionID.String())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(cs))
|
||||
})
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ func init() {
|
|||
|
||||
func testRegistry(t *testing.T, ctx context.Context, k string, t1 driver.Registry, t2 driver.Registry) {
|
||||
t.Run("package=client/manager="+k, func(t *testing.T) {
|
||||
t.Run("case=create-get-update-delete", client.TestHelperCreateGetUpdateDeleteClient(k, t1.Persister().Connection(context.Background()), t1.ClientManager(), t2.ClientManager()))
|
||||
t.Run("case=create-get-update-delete", client.TestHelperCreateGetUpdateDeleteClient(k, t1.Persister().Connection(ctx), t1.ClientManager(), t2.ClientManager()))
|
||||
|
||||
t.Run("case=autogenerate-key", client.TestHelperClientAutoGenerateKey(k, t1.ClientManager()))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue