From db17987979a06f85e8da4ea25f441fa34f421a95 Mon Sep 17 00:00:00 2001 From: shaunn Date: Fri, 14 Nov 2025 09:23:51 -0800 Subject: [PATCH] chore: fosite and hydra interface enhancements GitOrigin-RevId: f149949f3fdd7b1264ce78c011d49dee61af52a2 --- client/manager.go | 2 +- client/registry.go | 4 +- consent/handler.go | 4 +- consent/handler_test.go | 26 +- driver/di.go | 22 + driver/registry_sql.go | 157 ++- fosite/access_request_handler_test.go | 20 +- fosite/authorize_request_handler.go | 2 +- fosite/authorize_request_handler_test.go | 42 +- fosite/client_authentication.go | 8 +- fosite/compose/compose.go | 8 +- fosite/compose/compose_oauth2.go | 8 +- fosite/compose/compose_strategy.go | 12 +- fosite/compose/compose_userinfo_vc.go | 4 +- fosite/device_request_handler_test.go | 16 +- fosite/generate-mocks.sh | 103 +- fosite/generate.go | 1 + fosite/handler/oauth2/flow_resource_owner.go | 8 +- .../oauth2/flow_resource_owner_storage.go | 8 +- .../oauth2/flow_resource_owner_test.go | 55 +- fosite/handler/oauth2/storage.go | 6 - fosite/handler/rfc8628/strategy.go | 7 - fosite/handler/verifiable/handler.go | 4 +- fosite/handler/verifiable/handler_test.go | 8 +- fosite/handler/verifiable/nonce.go | 4 + ...wner_password_credentials_grant_storage.go | 30 - ...word_credentials_grant_storage_provider.go | 58 + fosite/internal/storage.go | 15 +- fosite/introspection_request_handler.go | 2 +- .../pushed_authorize_request_handler_test.go | 40 +- fosite/revoke_handler_test.go | 14 +- fosite/storage.go | 2 +- fosite/storage/memory.go | 6 +- fositex/config.go | 4 +- oauth2/fosite_store_helpers_test.go | 14 +- oauth2/oauth2_device_code_test.go | 7 +- oauth2/registry.go | 4 +- persistence/sql/persister.go | 123 +- persistence/sql/persister_authenticate.go | 1 + persistence/sql/persister_client.go | 107 +- persistence/sql/persister_consent.go | 2 +- persistence/sql/persister_device.go | 119 +- persistence/sql/persister_grant_jwk.go | 95 +- persistence/sql/persister_jwk.go | 11 +- persistence/sql/persister_nid_test.go | 18 +- persistence/sql/persister_nonce.go | 5 + persistence/sql/persister_oauth2.go | 1071 ++++++++--------- x/fosite_storer.go | 6 +- 48 files changed, 1242 insertions(+), 1051 deletions(-) create mode 100644 fosite/internal/resource_owner_password_credentials_grant_storage_provider.go diff --git a/client/manager.go b/client/manager.go index 4219450ef..b616797fc 100644 --- a/client/manager.go +++ b/client/manager.go @@ -25,7 +25,7 @@ type Manager interface { } type Storage interface { - GetClient(ctx context.Context, id string) (fosite.Client, error) + fosite.ClientManager CreateClient(ctx context.Context, c *Client) error diff --git a/client/registry.go b/client/registry.go index 60d55f0ab..4691fecc5 100644 --- a/client/registry.go +++ b/client/registry.go @@ -26,6 +26,8 @@ type Registry interface { OpenIDJWTStrategy() jwk.JWTSigner OAuth2HMACStrategy() foauth2.CoreStrategy OAuth2EnigmaStrategy() *enigma.HMACStrategy - RFC8628HMACStrategy() rfc8628.RFC8628CodeStrategy + rfc8628.DeviceRateLimitStrategyProvider + rfc8628.DeviceCodeStrategyProvider + rfc8628.UserCodeStrategyProvider config.Provider } diff --git a/consent/handler.go b/consent/handler.go index de7ea31a2..ee17592fe 100644 --- a/consent/handler.go +++ b/consent/handler.go @@ -1032,7 +1032,7 @@ func (h *Handler) acceptUserCodeRequest(w http.ResponseWriter, r *http.Request) return } - userCodeSignature, err := h.r.RFC8628HMACStrategy().UserCodeSignature(r.Context(), reqBody.UserCode) + userCodeSignature, err := h.r.UserCodeStrategy().UserCodeSignature(r.Context(), reqBody.UserCode) if err != nil { h.r.Writer().WriteError(w, r, fosite.ErrServerError.WithWrap(err).WithHint(`The 'user_code' signature could not be computed.`)) return @@ -1044,7 +1044,7 @@ func (h *Handler) acceptUserCodeRequest(w http.ResponseWriter, r *http.Request) return } - if err := h.r.RFC8628HMACStrategy().ValidateUserCode(ctx, userCodeRequest, reqBody.UserCode); err != nil { + if err := h.r.UserCodeStrategy().ValidateUserCode(ctx, userCodeRequest, reqBody.UserCode); err != nil { h.r.Writer().WriteError(w, r, fosite.ErrInvalidRequest.WithWrap(err).WithHint(`The 'user_code' session could not be found or has expired or is otherwise malformed.`)) return } diff --git a/consent/handler_test.go b/consent/handler_test.go index 4ea44e97a..6a46a39ba 100644 --- a/consent/handler_test.go +++ b/consent/handler_test.go @@ -310,9 +310,9 @@ func TestAcceptCodeDeviceRequest(t *testing.T) { deviceRequest.Client = cl deviceRequest.SetSession(oauth2.NewTestSession(t, "test-subject")) - _, deviceCodeSig, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(t.Context()) + _, deviceCodeSig, err := reg.DeviceCodeStrategy().GenerateDeviceCode(t.Context()) require.NoError(t, err) - userCode, sig, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context()) + userCode, sig, err := reg.UserCodeStrategy().GenerateUserCode(t.Context()) require.NoError(t, err) require.NoError(t, reg.OAuth2Storage().CreateDeviceAuthSession(t.Context(), deviceCodeSig, sig, deviceRequest)) @@ -337,7 +337,7 @@ func TestAcceptCodeDeviceRequest(t *testing.T) { }) t.Run("case=random user_code, not persisted in the database", func(t *testing.T) { - userCode, _, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context()) + userCode, _, err := reg.UserCodeStrategy().GenerateUserCode(t.Context()) require.NoError(t, err) resp := submitCode(t, &flow.AcceptDeviceUserCodeRequest{UserCode: userCode}, challenge) @@ -360,7 +360,7 @@ func TestAcceptCodeDeviceRequest(t *testing.T) { }) t.Run("case=empty challenge", func(t *testing.T) { - userCode, _, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context()) + userCode, _, err := reg.UserCodeStrategy().GenerateUserCode(t.Context()) require.NoError(t, err) resp := submitCode(t, &flow.AcceptDeviceUserCodeRequest{UserCode: userCode}, "") require.EqualValues(t, http.StatusBadRequest, resp.StatusCode) @@ -372,7 +372,7 @@ func TestAcceptCodeDeviceRequest(t *testing.T) { }) t.Run("case=invalid challenge", func(t *testing.T) { - userCode, _, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context()) + userCode, _, err := reg.UserCodeStrategy().GenerateUserCode(t.Context()) require.NoError(t, err) resp := submitCode(t, &hydra.AcceptDeviceUserCodeRequest{UserCode: &userCode}, "invalid-challenge") require.EqualValues(t, http.StatusNotFound, resp.StatusCode) @@ -388,9 +388,9 @@ func TestAcceptCodeDeviceRequest(t *testing.T) { deviceRequest.SetSession(oauth2.NewTestSession(t, "test-subject")) deviceRequest.Session.SetExpiresAt(fosite.UserCode, time.Now().Add(-time.Hour).UTC()) - _, deviceCodeSig, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(t.Context()) + _, deviceCodeSig, err := reg.DeviceCodeStrategy().GenerateDeviceCode(t.Context()) require.NoError(t, err) - userCode, sig, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context()) + userCode, sig, err := reg.UserCodeStrategy().GenerateUserCode(t.Context()) require.NoError(t, err) require.NoError(t, reg.OAuth2Storage().CreateDeviceAuthSession(t.Context(), deviceCodeSig, sig, deviceRequest)) @@ -409,9 +409,9 @@ func TestAcceptCodeDeviceRequest(t *testing.T) { deviceRequest.SetSession(oauth2.NewTestSession(t, "test-subject")) deviceRequest.UserCodeState = fosite.UserCodeAccepted - _, deviceCodeSig, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(t.Context()) + _, deviceCodeSig, err := reg.DeviceCodeStrategy().GenerateDeviceCode(t.Context()) require.NoError(t, err) - userCode, sig, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context()) + userCode, sig, err := reg.UserCodeStrategy().GenerateUserCode(t.Context()) require.NoError(t, err) require.NoError(t, reg.OAuth2Storage().CreateDeviceAuthSession(t.Context(), deviceCodeSig, sig, deviceRequest)) @@ -430,9 +430,9 @@ func TestAcceptCodeDeviceRequest(t *testing.T) { deviceRequest.SetSession(oauth2.NewTestSession(t, "test-subject")) deviceRequest.UserCodeState = fosite.UserCodeRejected - _, deviceCodesig, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(t.Context()) + _, deviceCodesig, err := reg.DeviceCodeStrategy().GenerateDeviceCode(t.Context()) require.NoError(t, err) - userCode, sig, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context()) + userCode, sig, err := reg.UserCodeStrategy().GenerateUserCode(t.Context()) require.NoError(t, err) require.NoError(t, reg.OAuth2Storage().CreateDeviceAuthSession(t.Context(), deviceCodesig, sig, deviceRequest)) @@ -450,9 +450,9 @@ func TestAcceptCodeDeviceRequest(t *testing.T) { deviceRequest.Client = cl deviceRequest.SetSession(oauth2.NewTestSession(t, "test-subject")) - _, deviceCodeSig, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(t.Context()) + _, deviceCodeSig, err := reg.DeviceCodeStrategy().GenerateDeviceCode(t.Context()) require.NoError(t, err) - userCode, sig, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context()) + userCode, sig, err := reg.UserCodeStrategy().GenerateUserCode(t.Context()) require.NoError(t, err) require.NoError(t, reg.OAuth2Storage().CreateDeviceAuthSession(t.Context(), deviceCodeSig, sig, deviceRequest)) diff --git a/driver/di.go b/driver/di.go index 61246d71b..6b3467084 100644 --- a/driver/di.go +++ b/driver/di.go @@ -4,6 +4,7 @@ package driver import ( + "github.com/ory/hydra/v2/fosite" "github.com/ory/hydra/v2/fosite/handler/oauth2" "github.com/ory/hydra/v2/jwk" ) @@ -29,3 +30,24 @@ func RegistryWithKeyManager(km func(r *RegistrySQL) (jwk.Manager, error)) Regist return err } } + +func RegistryWithOAuth2Provider(pr func(r *RegistrySQL) fosite.OAuth2Provider) RegistryModifier { + return func(r *RegistrySQL) error { + r.fop = pr(r) + return nil + } +} + +func RegistryWithAccessTokenStorage(s func(r *RegistrySQL) oauth2.AccessTokenStorage) RegistryModifier { + return func(r *RegistrySQL) error { + r.accessTokenStorage = s(r) + return nil + } +} + +func RegistryWithAuthorizeCodeStorage(s func(r *RegistrySQL) oauth2.AuthorizeCodeStorage) RegistryModifier { + return func(r *RegistrySQL) error { + r.authorizeCodeStorage = s(r) + return nil + } +} diff --git a/driver/registry_sql.go b/driver/registry_sql.go index 7c1987d25..55cf7e0f9 100644 --- a/driver/registry_sql.go +++ b/driver/registry_sql.go @@ -12,14 +12,13 @@ import ( "strings" "time" - "github.com/urfave/negroni" - "github.com/gorilla/sessions" "github.com/hashicorp/go-retryablehttp" _ "github.com/jackc/pgx/v5/stdlib" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rs/cors" + "github.com/urfave/negroni" "go.uber.org/automaxprocs/maxprocs" "github.com/ory/herodot" @@ -31,7 +30,10 @@ import ( "github.com/ory/hydra/v2/fosite/compose" foauth2 "github.com/ory/hydra/v2/fosite/handler/oauth2" "github.com/ory/hydra/v2/fosite/handler/openid" + "github.com/ory/hydra/v2/fosite/handler/pkce" + "github.com/ory/hydra/v2/fosite/handler/rfc7523" "github.com/ory/hydra/v2/fosite/handler/rfc8628" + "github.com/ory/hydra/v2/fosite/handler/verifiable" "github.com/ory/hydra/v2/fosite/token/hmac" "github.com/ory/hydra/v2/fositex" "github.com/ory/hydra/v2/hsm" @@ -58,35 +60,37 @@ import ( ) type RegistrySQL struct { - l, al *logrusx.Logger - conf *config.DefaultProvider - fh fosite.Hasher - cv *client.Validator - ctxer contextx.Contextualizer - hh *healthx.Handler - kc *aead.AESGCM - flowc *aead.XChaCha20Poly1305 - cos consent.Strategy - writer herodot.Writer - hsm hsm.Context - forv *openid.OpenIDConnectRequestValidator - fop fosite.OAuth2Provider - trc *otelx.Tracer - tracerWrapper func(*otelx.Tracer) *otelx.Tracer - arhs []oauth2.AccessRequestHook - basePersister *sql.BasePersister - oc fosite.Configurator - oidcs jwk.JWTSigner - ats jwk.JWTSigner - hmacs foauth2.CoreStrategy - enigmaHMAC *hmac.HMACStrategy - deviceHmac rfc8628.RFC8628CodeStrategy - fc *fositex.Config - publicCORS *cors.Cors - kratos kratos.Client - fositeFactories []fositex.Factory - migrator *sql.MigrationManager - dbOptsModifier []func(details *pop.ConnectionDetails) + l, al *logrusx.Logger + conf *config.DefaultProvider + fh fosite.Hasher + cv *client.Validator + ctxer contextx.Contextualizer + hh *healthx.Handler + kc *aead.AESGCM + flowc *aead.XChaCha20Poly1305 + cos consent.Strategy + writer herodot.Writer + hsm hsm.Context + forv *openid.OpenIDConnectRequestValidator + fop fosite.OAuth2Provider + trc *otelx.Tracer + tracerWrapper func(*otelx.Tracer) *otelx.Tracer + arhs []oauth2.AccessRequestHook + basePersister *sql.BasePersister + accessTokenStorage foauth2.AccessTokenStorage + authorizeCodeStorage foauth2.AuthorizeCodeStorage + oc fosite.Configurator + oidcs jwk.JWTSigner + ats jwk.JWTSigner + hmacs foauth2.CoreStrategy + enigmaHMAC *hmac.HMACStrategy + deviceHmac *rfc8628.DefaultDeviceStrategy + fc *fositex.Config + publicCORS *cors.Cors + kratos kratos.Client + fositeFactories []fositex.Factory + migrator *sql.MigrationManager + dbOptsModifier []func(details *pop.ConnectionDetails) keyManager jwk.Manager initialPing func(ctx context.Context, l *logrusx.Logger, p *sql.BasePersister) error @@ -98,6 +102,66 @@ var ( _ registry = (*RegistrySQL)(nil) ) +func (m *RegistrySQL) FositeClientManager() fosite.ClientManager { + return m.OAuth2Storage() +} + +// AuthorizeCodeStorage implements foauth2.AuthorizeCodeStorageProvider +func (m *RegistrySQL) AuthorizeCodeStorage() foauth2.AuthorizeCodeStorage { + if m.authorizeCodeStorage != nil { + return m.authorizeCodeStorage + } + return m.OAuth2Storage() +} + +// AccessTokenStorage implements foauth2.AccessTokenStorageProvider +func (m *RegistrySQL) AccessTokenStorage() foauth2.AccessTokenStorage { + if m.accessTokenStorage != nil { + return m.accessTokenStorage + } + return m.OAuth2Storage() +} + +// RefreshTokenStorage implements foauth2.RefreshTokenStorageProvider +func (m *RegistrySQL) RefreshTokenStorage() foauth2.RefreshTokenStorage { + return m.OAuth2Storage() +} + +// TokenRevocationStorage implements foauth2.TokenRevocationStorageProvider +func (m *RegistrySQL) TokenRevocationStorage() foauth2.TokenRevocationStorage { + return m.OAuth2Storage() +} + +// ResourceOwnerPasswordCredentialsGrantStorage implements foauth2.ResourceOwnerPasswordCredentialsGrantStorage +func (m *RegistrySQL) ResourceOwnerPasswordCredentialsGrantStorage() foauth2.ResourceOwnerPasswordCredentialsGrantStorage { + return m.OAuth2Storage() +} + +// OpenIDConnectRequestStorage implements openid.OIDCRequestStorageProvider +func (m *RegistrySQL) OpenIDConnectRequestStorage() openid.OpenIDConnectRequestStorage { + return m.OAuth2Storage() +} + +// PKCERequestStorage implements pkce.PKCERequestStorageProvider +func (m *RegistrySQL) PKCERequestStorage() pkce.PKCERequestStorage { + return m.OAuth2Storage() +} + +// DeviceAuthStorage implements rfc8628.DeviceAuthStorageProvider +func (m *RegistrySQL) DeviceAuthStorage() rfc8628.DeviceAuthStorage { + return m.OAuth2Storage() +} + +// RFC7523KeyStorage implements rfc7523.RFC7523KeyStorageProvider +func (m *RegistrySQL) RFC7523KeyStorage() rfc7523.RFC7523KeyStorage { + return m.OAuth2Storage() +} + +// NonceManager implements verifiable.NonceManager +func (m *RegistrySQL) NonceManager() verifiable.NonceManager { + return m.OAuth2Storage() +} + // defaultInitialPing is the default function that will be called within RegistrySQL.Init to make sure // the database is reachable. It can be injected for test purposes by changing the value // of RegistrySQL.initialPing. @@ -209,7 +273,9 @@ func (m *RegistrySQL) ConsentManager() consent.Manager { func (m *RegistrySQL) ObfuscatedSubjectManager() consent.ObfuscatedSubjectManager { return m.Persister() } -func (m *RegistrySQL) LoginManager() consent.LoginManager { return m.Persister() } + +func (m *RegistrySQL) LoginManager() consent.LoginManager { return m.Persister() } + func (m *RegistrySQL) LogoutManager() consent.LogoutManager { return m.Persister() } func (m *RegistrySQL) OAuth2Storage() x.FositeStorer { @@ -412,7 +478,7 @@ func (m *RegistrySQL) HTTPClient(ctx context.Context, opts ...httpx.ResilientOpt func (m *RegistrySQL) OAuth2Provider() fosite.OAuth2Provider { if m.fop == nil { - m.fop = fosite.NewOAuth2Provider(m.OAuth2Storage(), m.OAuth2ProviderConfig()) + m.fop = fosite.NewOAuth2Provider(m, m.OAuth2ProviderConfig()) } return m.fop } @@ -445,14 +511,29 @@ func (m *RegistrySQL) OAuth2HMACStrategy() foauth2.CoreStrategy { return m.hmacs } -// RFC8628HMACStrategy returns the rfc8628 strategy -func (m *RegistrySQL) RFC8628HMACStrategy() rfc8628.RFC8628CodeStrategy { +// rfc8628HMACStrategy returns the rfc8628 strategy +func (m *RegistrySQL) rfc8628HMACStrategy() *rfc8628.DefaultDeviceStrategy { if m.deviceHmac == nil { m.deviceHmac = compose.NewDeviceStrategy(m.OAuth2Config()) } return m.deviceHmac } +// DeviceRateLimitStrategy implements rfc8628.DeviceRateLimitStrategyProvider +func (m *RegistrySQL) DeviceRateLimitStrategy() rfc8628.DeviceRateLimitStrategy { + return m.rfc8628HMACStrategy() +} + +// DeviceCodeStrategy implements rfc8628.DeviceCodeStrategyProvider +func (m *RegistrySQL) DeviceCodeStrategy() rfc8628.DeviceCodeStrategy { + return m.rfc8628HMACStrategy() +} + +// UserCodeStrategy implements rfc8628.UserCodeStrategyProvider +func (m *RegistrySQL) UserCodeStrategy() rfc8628.UserCodeStrategy { + return m.rfc8628HMACStrategy() +} + func (m *RegistrySQL) OAuth2Config() *fositex.Config { if m.fc == nil { m.fc = fositex.NewConfig(m) @@ -471,7 +552,7 @@ func (m *RegistrySQL) OAuth2ProviderConfig() fosite.Configurator { conf := m.OAuth2Config() hmacAtStrategy := m.OAuth2HMACStrategy() - deviceHmacAtStrategy := m.RFC8628HMACStrategy() + deviceHmacAtStrategy := m.rfc8628HMACStrategy() oidcSigner := m.OpenIDJWTStrategy() atSigner := m.AccessTokenJWTStrategy() jwtAtStrategy := &foauth2.DefaultJWTStrategy{ @@ -480,13 +561,13 @@ func (m *RegistrySQL) OAuth2ProviderConfig() fosite.Configurator { Config: conf, } - conf.LoadDefaultHandlers(&compose.CommonStrategy{ + conf.LoadDefaultHandlers(m, &compose.CommonStrategy{ CoreStrategy: fositex.NewTokenStrategy(m.Config(), hmacAtStrategy, &foauth2.DefaultJWTStrategy{ Signer: jwtAtStrategy, Strategy: hmacAtStrategy, Config: conf, }), - RFC8628CodeStrategy: deviceHmacAtStrategy, + DeviceStrategy: deviceHmacAtStrategy, OIDCTokenStrategy: &openid.DefaultStrategy{ Config: conf, Signer: oidcSigner, diff --git a/fosite/access_request_handler_test.go b/fosite/access_request_handler_test.go index 58208589a..ac6c910da 100644 --- a/fosite/access_request_handler_test.go +++ b/fosite/access_request_handler_test.go @@ -77,7 +77,7 @@ func TestNewAccessRequest(t *testing.T) { }, expectErr: ErrInvalidClient, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New("")) }, handlers: TokenEndpointHandlers{handler}, @@ -103,7 +103,7 @@ func TestNewAccessRequest(t *testing.T) { }, expectErr: ErrInvalidClient, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New("")) }, handlers: TokenEndpointHandlers{handler}, @@ -118,7 +118,7 @@ func TestNewAccessRequest(t *testing.T) { }, expectErr: ErrInvalidClient, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") @@ -136,7 +136,7 @@ func TestNewAccessRequest(t *testing.T) { }, expectErr: ErrServerError, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") @@ -154,7 +154,7 @@ func TestNewAccessRequest(t *testing.T) { "grant_type": {"foo"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") @@ -178,7 +178,7 @@ func TestNewAccessRequest(t *testing.T) { "grant_type": {"foo"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = true handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) @@ -269,7 +269,7 @@ func TestNewAccessRequestWithoutClientAuth(t *testing.T) { }, mock: func() { // despite error from storage, we should success, because client auth is not required - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "foo").Return(nil, errors.New("no client")).Times(1) handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) }, @@ -308,7 +308,7 @@ func TestNewAccessRequestWithoutClientAuth(t *testing.T) { "grant_type": {"foo"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "foo").Return(anotherClient, nil).Times(1) hasher.EXPECT().Compare(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1) handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) @@ -383,7 +383,7 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) { "grant_type": {"foo"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") @@ -402,7 +402,7 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) { "grant_type": {"foo"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") diff --git a/fosite/authorize_request_handler.go b/fosite/authorize_request_handler.go index 0b1176ef6..14f99e067 100644 --- a/fosite/authorize_request_handler.go +++ b/fosite/authorize_request_handler.go @@ -357,7 +357,7 @@ func (f *Fosite) newAuthorizeRequest(ctx context.Context, r *http.Request, isPAR } } - client, err := f.Store.ClientManager().GetClient(ctx, request.GetRequestForm().Get("client_id")) + client, err := f.Store.FositeClientManager().GetClient(ctx, request.GetRequestForm().Get("client_id")) if err != nil { return request, errorsx.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 Client does not exist.").WithWrap(err).WithDebug(err.Error())) } diff --git a/fosite/authorize_request_handler_test.go b/fosite/authorize_request_handler_test.go index ff339c77a..952fba00a 100644 --- a/fosite/authorize_request_handler_test.go +++ b/fosite/authorize_request_handler_test.go @@ -47,7 +47,7 @@ func TestNewAuthorizeRequest(t *testing.T) { r: &http.Request{}, expectedError: ErrInvalidClient, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo")) }, }, @@ -58,7 +58,7 @@ func TestNewAuthorizeRequest(t *testing.T) { query: url.Values{"redirect_uri": []string{"invalid"}}, expectedError: ErrInvalidClient, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo")) }, }, @@ -69,7 +69,7 @@ func TestNewAuthorizeRequest(t *testing.T) { query: url.Values{"redirect_uri": []string{"https://foo.bar/cb"}}, expectedError: ErrInvalidClient, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo")) }, }, @@ -82,7 +82,7 @@ func TestNewAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidRequest, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}}, nil) }, }, @@ -96,7 +96,7 @@ func TestNewAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidRequest, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}}, nil) }, }, @@ -110,7 +110,7 @@ func TestNewAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidRequest, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}}, nil) }, }, @@ -125,7 +125,7 @@ func TestNewAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidState, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{}}, nil) }, }, @@ -141,7 +141,7 @@ func TestNewAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidState, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{}}, nil) }, }, @@ -157,7 +157,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "scope": {"foo bar baz"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}}, nil) }, expectedError: ErrInvalidScope, @@ -175,7 +175,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, Audience: []string{"https://cloud.ory.sh/api"}, @@ -196,7 +196,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ ResponseTypes: []string{"code token"}, RedirectURIs: []string{"https://foo.bar/cb"}, @@ -232,7 +232,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ ResponseTypes: []string{"code token"}, RedirectURIs: []string{"https://foo.bar/cb"}, @@ -268,7 +268,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "audience": {"test value", ""}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ ResponseTypes: []string{"code token"}, RedirectURIs: []string{"https://foo.bar/cb"}, @@ -304,7 +304,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ ResponseTypes: []string{"code token"}, RedirectURIs: []string{"web+application://callback"}, @@ -340,7 +340,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ ResponseTypes: []string{"code token"}, RedirectURIs: []string{"https://foo.bar/cb"}, @@ -376,7 +376,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "response_mode": {"unknown"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}}, nil) }, expectedError: ErrUnsupportedResponseMode, @@ -394,7 +394,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "response_mode": {"form_post"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}}, nil) }, expectedError: ErrUnsupportedResponseMode, @@ -412,7 +412,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "response_mode": {"form_post"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ DefaultClient: &DefaultClient{ RedirectURIs: []string{"https://foo.bar/cb"}, @@ -438,7 +438,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ DefaultClient: &DefaultClient{ RedirectURIs: []string{"https://foo.bar/cb"}, @@ -481,7 +481,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ DefaultClient: &DefaultClient{ RedirectURIs: []string{"https://foo.bar/cb"}, @@ -524,7 +524,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ DefaultClient: &DefaultClient{ RedirectURIs: []string{"https://foo.bar/cb"}, diff --git a/fosite/client_authentication.go b/fosite/client_authentication.go index 25eaaeedb..a98cd659d 100644 --- a/fosite/client_authentication.go +++ b/fosite/client_authentication.go @@ -91,7 +91,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt } } - client, err = f.Store.ClientManager().GetClient(ctx, clientID) + client, err = f.Store.FositeClientManager().GetClient(ctx, clientID) if err != nil { return nil, errorsx.WithStack(ErrInvalidClient.WithWrap(err).WithDebug(err.Error())) } @@ -156,7 +156,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) } else if jti, ok = claims["jti"].(string); !ok || len(jti) == 0 { return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'jti' from 'client_assertion' must be set but is not.")) - } else if f.Store.ClientManager().ClientAssertionJWTValid(ctx, jti) != nil { + } else if f.Store.FositeClientManager().ClientAssertionJWTValid(ctx, jti) != nil { return nil, errorsx.WithStack(ErrJTIKnown.WithHint("Claim 'jti' from 'client_assertion' MUST only be used once.")) } @@ -177,7 +177,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt if err != nil { return nil, errorsx.WithStack(err) } - if err := f.Store.ClientManager().SetClientAssertionJWT(ctx, jti, time.Unix(expiry, 0)); err != nil { + if err := f.Store.FositeClientManager().SetClientAssertionJWT(ctx, jti, time.Unix(expiry, 0)); err != nil { return nil, err } @@ -197,7 +197,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt return nil, err } - client, err := f.Store.ClientManager().GetClient(ctx, clientID) + client, err := f.Store.FositeClientManager().GetClient(ctx, clientID) if err != nil { return nil, errorsx.WithStack(ErrInvalidClient.WithWrap(err).WithDebug(err.Error())) } diff --git a/fosite/compose/compose.go b/fosite/compose/compose.go index 5ec242f0e..2c00edc7a 100644 --- a/fosite/compose/compose.go +++ b/fosite/compose/compose.go @@ -70,10 +70,10 @@ func ComposeAllEnabled(config *fosite.Config, storage fosite.Storage, key interf config, storage, &CommonStrategy{ - CoreStrategy: NewOAuth2HMACStrategy(config), - RFC8628CodeStrategy: NewDeviceStrategy(config), - OIDCTokenStrategy: NewOpenIDConnectStrategy(keyGetter, config), - Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter}, + CoreStrategy: NewOAuth2HMACStrategy(config), + DeviceStrategy: NewDeviceStrategy(config), + OIDCTokenStrategy: NewOpenIDConnectStrategy(keyGetter, config), + Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter}, }, OAuth2AuthorizeExplicitFactory, OAuth2AuthorizeImplicitFactory, diff --git a/fosite/compose/compose_oauth2.go b/fosite/compose/compose_oauth2.go index a8a40775a..29023d424 100644 --- a/fosite/compose/compose_oauth2.go +++ b/fosite/compose/compose_oauth2.go @@ -77,8 +77,12 @@ func OAuth2ResourceOwnerPasswordCredentialsFactory(config fosite.Configurator, s oauth2.AccessTokenStrategyProvider oauth2.RefreshTokenStrategyProvider }), - Storage: storage.(oauth2.ResourceOwnerPasswordCredentialsGrantStorage), - Config: config, + Storage: storage.(interface { + oauth2.ResourceOwnerPasswordCredentialsGrantStorageProvider + oauth2.AccessTokenStorageProvider + oauth2.RefreshTokenStorageProvider + }), + Config: config, } } diff --git a/fosite/compose/compose_strategy.go b/fosite/compose/compose_strategy.go index ac5b75a54..a867ddaf5 100644 --- a/fosite/compose/compose_strategy.go +++ b/fosite/compose/compose_strategy.go @@ -15,9 +15,9 @@ import ( ) type CommonStrategy struct { - CoreStrategy oauth2.CoreStrategy - RFC8628CodeStrategy rfc8628.RFC8628CodeStrategy - OIDCTokenStrategy openid.OpenIDConnectTokenStrategy + CoreStrategy oauth2.CoreStrategy + DeviceStrategy *rfc8628.DefaultDeviceStrategy + OIDCTokenStrategy openid.OpenIDConnectTokenStrategy jwt.Signer } @@ -41,15 +41,15 @@ func (s *CommonStrategy) OpenIDConnectTokenStrategy() openid.OpenIDConnectTokenS // RFC8628 Device Strategy Providers func (s *CommonStrategy) DeviceRateLimitStrategy() rfc8628.DeviceRateLimitStrategy { - return s.RFC8628CodeStrategy + return s.DeviceStrategy } func (s *CommonStrategy) DeviceCodeStrategy() rfc8628.DeviceCodeStrategy { - return s.RFC8628CodeStrategy + return s.DeviceStrategy } func (s *CommonStrategy) UserCodeStrategy() rfc8628.UserCodeStrategy { - return s.RFC8628CodeStrategy + return s.DeviceStrategy } type HMACSHAStrategyConfigurator interface { diff --git a/fosite/compose/compose_userinfo_vc.go b/fosite/compose/compose_userinfo_vc.go index a333afbc4..dcea03f68 100644 --- a/fosite/compose/compose_userinfo_vc.go +++ b/fosite/compose/compose_userinfo_vc.go @@ -12,7 +12,7 @@ import ( // handler. func OIDCUserinfoVerifiableCredentialFactory(config fosite.Configurator, storage fosite.Storage, strategy any) any { return &verifiable.Handler{ - NonceManager: storage.(verifiable.NonceManager), - Config: config, + NonceManagerProvider: storage.(verifiable.NonceManagerProvider), + Config: config, } } diff --git a/fosite/device_request_handler_test.go b/fosite/device_request_handler_test.go index 4c4147a1c..beb5d9cc9 100644 --- a/fosite/device_request_handler_test.go +++ b/fosite/device_request_handler_test.go @@ -64,7 +64,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) { expectedError: ErrInvalidClient, method: "POST", mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(nil, errors.New("")) }, }, { @@ -75,7 +75,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) { }, method: "POST", mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil) }, expectedError: ErrInvalidScope, @@ -88,7 +88,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) { }, method: "POST", mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil) }, expectedError: ErrInvalidRequest, @@ -100,7 +100,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) { }, method: "POST", mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id_2")).Return(authCodeClient, nil) }, expectedError: ErrInvalidGrant, @@ -112,7 +112,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) { }, method: "POST", mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil) }, }} { @@ -166,7 +166,7 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) { expectedError: ErrInvalidClient, method: "POST", mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) hasher.EXPECT().Compare(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("")) }, @@ -183,7 +183,7 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) { expectedError: ErrInvalidRequest, method: "POST", mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("client_secret")), gomock.Eq([]byte("client_secret"))).Return(nil) }, @@ -199,7 +199,7 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) { }, method: "POST", mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("client_secret")), gomock.Eq([]byte("client_secret"))).Return(nil) }, diff --git a/fosite/generate-mocks.sh b/fosite/generate-mocks.sh index 44778924d..222c8564f 100755 --- a/fosite/generate-mocks.sh +++ b/fosite/generate-mocks.sh @@ -1,55 +1,56 @@ #!/bin/bash -mockgen -package internal -destination internal/access_request.go github.com/ory/fosite AccessRequester -mockgen -package internal -destination internal/access_response.go github.com/ory/fosite AccessResponder -mockgen -package internal -destination internal/access_token_storage.go github.com/ory/fosite/handler/oauth2 AccessTokenStorage -mockgen -package internal -destination internal/access_token_storage_provider.go github.com/ory/fosite/handler/oauth2 AccessTokenStorageProvider -mockgen -package internal -destination internal/access_token_strategy.go github.com/ory/fosite/handler/oauth2 AccessTokenStrategy -mockgen -package internal -destination internal/access_token_strategy_provider.go github.com/ory/fosite/handler/oauth2 AccessTokenStrategyProvider -mockgen -package internal -destination internal/authorize_code_storage.go github.com/ory/fosite/handler/oauth2 AuthorizeCodeStorage -mockgen -package internal -destination internal/authorize_code_storage_provider.go github.com/ory/fosite/handler/oauth2 AuthorizeCodeStorageProvider -mockgen -package internal -destination internal/authorize_code_strategy.go github.com/ory/fosite/handler/oauth2 AuthorizeCodeStrategy -mockgen -package internal -destination internal/authorize_code_strategy_provider.go github.com/ory/fosite/handler/oauth2 AuthorizeCodeStrategyProvider -mockgen -package internal -destination internal/authorize_endpoint_handler.go github.com/ory/fosite AuthorizeEndpointHandler -mockgen -package internal -destination internal/authorize_endpoint_handlers_provider.go github.com/ory/fosite AuthorizeEndpointHandlersProvider -mockgen -package internal -destination internal/authorize_request.go github.com/ory/fosite AuthorizeRequester -mockgen -package internal -destination internal/authorize_response.go github.com/ory/fosite AuthorizeResponder -mockgen -package internal -destination internal/client.go github.com/ory/fosite Client -mockgen -package internal -destination internal/client_manager.go github.com/ory/fosite ClientManager -mockgen -package internal -destination internal/oauth2_storage.go github.com/ory/fosite/handler/oauth2 CoreStorage -mockgen -package internal -destination internal/oauth2_strategy.go github.com/ory/fosite/handler/oauth2 CoreStrategy -mockgen -package internal -destination internal/device_auth_storage.go github.com/ory/fosite/handler/rfc8628 DeviceAuthStorage -mockgen -package internal -destination internal/device_auth_storage_provider.go github.com/ory/fosite/handler/rfc8628 DeviceAuthStorageProvider -mockgen -package internal -destination internal/device_code_strategy.go github.com/ory/fosite/handler/rfc8628 DeviceCodeStrategy -mockgen -package internal -destination internal/device_code_strategy_provider.go github.com/ory/fosite/handler/rfc8628 DeviceCodeStrategyProvider -mockgen -package internal -destination internal/device_rate_limit_strategy.go github.com/ory/fosite/handler/rfc8628 DeviceRateLimitStrategy -mockgen -package internal -destination internal/device_rate_limit_strategy_provider.go github.com/ory/fosite/handler/rfc8628 DeviceRateLimitStrategyProvider -mockgen -package internal -destination internal/hash.go github.com/ory/fosite Hasher -mockgen -package internal -destination internal/open_id_connect_token_strategy.go github.com/ory/fosite/handler/openid OpenIDConnectTokenStrategy -mockgen -package internal -destination internal/open_id_connect_token_strategy_provider.go github.com/ory/fosite/handler/openid OpenIDConnectTokenStrategyProvider -mockgen -package internal -destination internal/open_id_connect_request_storage.go github.com/ory/fosite/handler/openid OpenIDConnectRequestStorage -mockgen -package internal -destination internal/open_id_connect_request_storage_provider.go github.com/ory/fosite/handler/openid OpenIDConnectRequestStorageProvider -mockgen -package internal -destination internal/par_storage.go github.com/ory/fosite PARStorage -mockgen -package internal -destination internal/par_storage_provider.go github.com/ory/fosite PARStorageProvider -mockgen -package internal -destination internal/pkce_request_storage.go github.com/ory/fosite/handler/pkce PKCERequestStorage -mockgen -package internal -destination internal/pkce_request_storage_provider.go github.com/ory/fosite/handler/pkce PKCERequestStorageProvider -mockgen -package internal -destination internal/refresh_token_storage.go github.com/ory/fosite/handler/oauth2 RefreshTokenStorage -mockgen -package internal -destination internal/refresh_token_storage_provider.go github.com/ory/fosite/handler/oauth2 RefreshTokenStorageProvider -mockgen -package internal -destination internal/refresh_token_strategy.go github.com/ory/fosite/handler/oauth2 RefreshTokenStrategy -mockgen -package internal -destination internal/refresh_token_strategy_provider.go github.com/ory/fosite/handler/oauth2 RefreshTokenStrategyProvider -mockgen -package internal -destination internal/request.go github.com/ory/fosite Requester -mockgen -package internal -destination internal/resource_owner_password_credentials_grant_storage.go github.com/ory/fosite/handler/oauth2 ResourceOwnerPasswordCredentialsGrantStorage -mockgen -package internal -destination internal/revocation_handler.go github.com/ory/fosite RevocationHandler -mockgen -package internal -destination internal/revocation_handlers_provider.go github.com/ory/fosite RevocationHandlersProvider -mockgen -package internal -destination internal/rfc7523_key_storage.go github.com/ory/fosite/handler/rfc7523 RFC7523KeyStorage -mockgen -package internal -destination internal/rfc7523_key_storage_provider.go github.com/ory/fosite/handler/rfc7523 RFC7523KeyStorageProvider -mockgen -package internal -destination internal/storage.go github.com/ory/fosite Storage -mockgen -package internal -destination internal/token_endpoint_handler.go github.com/ory/fosite TokenEndpointHandler -mockgen -package internal -destination internal/token_introspector.go github.com/ory/fosite TokenIntrospector -mockgen -package internal -destination internal/token_revocation_storage.go github.com/ory/fosite/handler/oauth2 TokenRevocationStorage -mockgen -package internal -destination internal/token_revocation_storage_provider.go github.com/ory/fosite/handler/oauth2 TokenRevocationStorageProvider -mockgen -package internal -destination internal/transactional.go github.com/ory/fosite Transactional -mockgen -package internal -destination internal/user_code_strategy.go github.com/ory/fosite/handler/rfc8628 UserCodeStrategy -mockgen -package internal -destination internal/user_code_strategy_provider.go github.com/ory/fosite/handler/rfc8628 UserCodeStrategyProvider +mockgen -package internal -destination internal/access_request.go github.com/ory/hydra/v2/fosite AccessRequester +mockgen -package internal -destination internal/access_response.go github.com/ory/hydra/v2/fosite AccessResponder +mockgen -package internal -destination internal/access_token_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 AccessTokenStorage +mockgen -package internal -destination internal/access_token_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 AccessTokenStorageProvider +mockgen -package internal -destination internal/access_token_strategy.go github.com/ory/hydra/v2/fosite/handler/oauth2 AccessTokenStrategy +mockgen -package internal -destination internal/access_token_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 AccessTokenStrategyProvider +mockgen -package internal -destination internal/authorize_code_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 AuthorizeCodeStorage +mockgen -package internal -destination internal/authorize_code_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 AuthorizeCodeStorageProvider +mockgen -package internal -destination internal/authorize_code_strategy.go github.com/ory/hydra/v2/fosite/handler/oauth2 AuthorizeCodeStrategy +mockgen -package internal -destination internal/authorize_code_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 AuthorizeCodeStrategyProvider +mockgen -package internal -destination internal/authorize_endpoint_handler.go github.com/ory/hydra/v2/fosite AuthorizeEndpointHandler +mockgen -package internal -destination internal/authorize_endpoint_handlers_provider.go github.com/ory/hydra/v2/fosite AuthorizeEndpointHandlersProvider +mockgen -package internal -destination internal/authorize_request.go github.com/ory/hydra/v2/fosite AuthorizeRequester +mockgen -package internal -destination internal/authorize_response.go github.com/ory/hydra/v2/fosite AuthorizeResponder +mockgen -package internal -destination internal/client.go github.com/ory/hydra/v2/fosite Client +mockgen -package internal -destination internal/client_manager.go github.com/ory/hydra/v2/fosite ClientManager +mockgen -package internal -destination internal/oauth2_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 CoreStorage +mockgen -package internal -destination internal/oauth2_strategy.go github.com/ory/hydra/v2/fosite/handler/oauth2 CoreStrategy +mockgen -package internal -destination internal/device_auth_storage.go github.com/ory/hydra/v2/fosite/handler/rfc8628 DeviceAuthStorage +mockgen -package internal -destination internal/device_auth_storage_provider.go github.com/ory/hydra/v2/fosite/handler/rfc8628 DeviceAuthStorageProvider +mockgen -package internal -destination internal/device_code_strategy.go github.com/ory/hydra/v2/fosite/handler/rfc8628 DeviceCodeStrategy +mockgen -package internal -destination internal/device_code_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/rfc8628 DeviceCodeStrategyProvider +mockgen -package internal -destination internal/device_rate_limit_strategy.go github.com/ory/hydra/v2/fosite/handler/rfc8628 DeviceRateLimitStrategy +mockgen -package internal -destination internal/device_rate_limit_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/rfc8628 DeviceRateLimitStrategyProvider +mockgen -package internal -destination internal/hash.go github.com/ory/hydra/v2/fosite Hasher +mockgen -package internal -destination internal/open_id_connect_token_strategy.go github.com/ory/hydra/v2/fosite/handler/openid OpenIDConnectTokenStrategy +mockgen -package internal -destination internal/open_id_connect_token_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/openid OpenIDConnectTokenStrategyProvider +mockgen -package internal -destination internal/open_id_connect_request_storage.go github.com/ory/hydra/v2/fosite/handler/openid OpenIDConnectRequestStorage +mockgen -package internal -destination internal/open_id_connect_request_storage_provider.go github.com/ory/hydra/v2/fosite/handler/openid OpenIDConnectRequestStorageProvider +mockgen -package internal -destination internal/par_storage.go github.com/ory/hydra/v2/fosite PARStorage +mockgen -package internal -destination internal/par_storage_provider.go github.com/ory/hydra/v2/fosite PARStorageProvider +mockgen -package internal -destination internal/pkce_request_storage.go github.com/ory/hydra/v2/fosite/handler/pkce PKCERequestStorage +mockgen -package internal -destination internal/pkce_request_storage_provider.go github.com/ory/hydra/v2/fosite/handler/pkce PKCERequestStorageProvider +mockgen -package internal -destination internal/refresh_token_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 RefreshTokenStorage +mockgen -package internal -destination internal/refresh_token_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 RefreshTokenStorageProvider +mockgen -package internal -destination internal/refresh_token_strategy.go github.com/ory/hydra/v2/fosite/handler/oauth2 RefreshTokenStrategy +mockgen -package internal -destination internal/refresh_token_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 RefreshTokenStrategyProvider +mockgen -package internal -destination internal/request.go github.com/ory/hydra/v2/fosite Requester +mockgen -package internal -destination internal/resource_owner_password_credentials_grant_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 ResourceOwnerPasswordCredentialsGrantStorage +mockgen -package internal -destination internal/resource_owner_password_credentials_grant_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 ResourceOwnerPasswordCredentialsGrantStorageProvider +mockgen -package internal -destination internal/revocation_handler.go github.com/ory/hydra/v2/fosite RevocationHandler +mockgen -package internal -destination internal/revocation_handlers_provider.go github.com/ory/hydra/v2/fosite RevocationHandlersProvider +mockgen -package internal -destination internal/rfc7523_key_storage.go github.com/ory/hydra/v2/fosite/handler/rfc7523 RFC7523KeyStorage +mockgen -package internal -destination internal/rfc7523_key_storage_provider.go github.com/ory/hydra/v2/fosite/handler/rfc7523 RFC7523KeyStorageProvider +mockgen -package internal -destination internal/storage.go github.com/ory/hydra/v2/fosite Storage +mockgen -package internal -destination internal/token_endpoint_handler.go github.com/ory/hydra/v2/fosite TokenEndpointHandler +mockgen -package internal -destination internal/token_introspector.go github.com/ory/hydra/v2/fosite TokenIntrospector +mockgen -package internal -destination internal/token_revocation_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 TokenRevocationStorage +mockgen -package internal -destination internal/token_revocation_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 TokenRevocationStorageProvider +mockgen -package internal -destination internal/transactional.go github.com/ory/hydra/v2/fosite Transactional +mockgen -package internal -destination internal/user_code_strategy.go github.com/ory/hydra/v2/fosite/handler/rfc8628 UserCodeStrategy +mockgen -package internal -destination internal/user_code_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/rfc8628 UserCodeStrategyProvider goimports -w internal/ diff --git a/fosite/generate.go b/fosite/generate.go index b0d3b0790..1c9af1dad 100644 --- a/fosite/generate.go +++ b/fosite/generate.go @@ -42,6 +42,7 @@ package fosite //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/refresh_token_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 RefreshTokenStrategyProvider //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/request.go github.com/ory/hydra/v2/fosite Requester //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/resource_owner_password_credentials_grant_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 ResourceOwnerPasswordCredentialsGrantStorage +//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/resource_owner_password_credentials_grant_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 ResourceOwnerPasswordCredentialsGrantStorageProvider //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/revocation_handler.go github.com/ory/hydra/v2/fosite RevocationHandler //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/revocation_handlers_provider.go github.com/ory/hydra/v2/fosite RevocationHandlersProvider //go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/rfc7523_key_storage.go github.com/ory/hydra/v2/fosite/handler/rfc7523 RFC7523KeyStorage diff --git a/fosite/handler/oauth2/flow_resource_owner.go b/fosite/handler/oauth2/flow_resource_owner.go index 1ff973e96..28fdd944f 100644 --- a/fosite/handler/oauth2/flow_resource_owner.go +++ b/fosite/handler/oauth2/flow_resource_owner.go @@ -20,7 +20,11 @@ var _ fosite.TokenEndpointHandler = (*ResourceOwnerPasswordCredentialsGrantHandl // is at the time of this writing going to be omitted in the OAuth 2.1 spec. For more information on why this grant type // is discouraged see: https://www.scottbrady91.com/oauth/why-the-resource-owner-password-credentials-grant-type-is-not-authentication-nor-suitable-for-modern-applications type ResourceOwnerPasswordCredentialsGrantHandler struct { - Storage ResourceOwnerPasswordCredentialsGrantStorage + Storage interface { + ResourceOwnerPasswordCredentialsGrantStorageProvider + AccessTokenStorageProvider + RefreshTokenStorageProvider + } Strategy interface { AccessTokenStrategyProvider RefreshTokenStrategyProvider @@ -64,7 +68,7 @@ func (c *ResourceOwnerPasswordCredentialsGrantHandler) HandleTokenEndpointReques password := request.GetRequestForm().Get("password") if username == "" || password == "" { return errorsx.WithStack(fosite.ErrInvalidRequest.WithHint("Username or password are missing from the POST body.")) - } else if sub, err := c.Storage.Authenticate(ctx, username, password); errors.Is(err, fosite.ErrNotFound) { + } else if sub, err := c.Storage.ResourceOwnerPasswordCredentialsGrantStorage().Authenticate(ctx, username, password); errors.Is(err, fosite.ErrNotFound) { return errorsx.WithStack(fosite.ErrInvalidGrant.WithHint("Unable to authenticate the provided username and password credentials.").WithWrap(err).WithDebug(err.Error())) } else if err != nil { return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) diff --git a/fosite/handler/oauth2/flow_resource_owner_storage.go b/fosite/handler/oauth2/flow_resource_owner_storage.go index e5bac1327..31b1f57a4 100644 --- a/fosite/handler/oauth2/flow_resource_owner_storage.go +++ b/fosite/handler/oauth2/flow_resource_owner_storage.go @@ -7,8 +7,12 @@ import ( "context" ) +// ResourceOwnerPasswordCredentialsGrantStorage provides storage for the resource owner password credentials grant. type ResourceOwnerPasswordCredentialsGrantStorage interface { Authenticate(ctx context.Context, name string, secret string) (subject string, err error) - AccessTokenStorageProvider - RefreshTokenStorageProvider +} + +// ResourceOwnerPasswordCredentialsGrantStorageProvider provides the resource owner password credentials grant storage. +type ResourceOwnerPasswordCredentialsGrantStorageProvider interface { + ResourceOwnerPasswordCredentialsGrantStorage() ResourceOwnerPasswordCredentialsGrantStorage } diff --git a/fosite/handler/oauth2/flow_resource_owner_test.go b/fosite/handler/oauth2/flow_resource_owner_test.go index a9ddc46b4..90649b9ee 100644 --- a/fosite/handler/oauth2/flow_resource_owner_test.go +++ b/fosite/handler/oauth2/flow_resource_owner_test.go @@ -22,7 +22,10 @@ import ( func TestResourceOwnerFlow_HandleTokenEndpointRequest(t *testing.T) { ctrl := gomock.NewController(t) - store := internal.NewMockResourceOwnerPasswordCredentialsGrantStorage(ctrl) + mockRopcgStorageProvider := internal.NewMockResourceOwnerPasswordCredentialsGrantStorageProvider(ctrl) + mockRopcgStorage := internal.NewMockResourceOwnerPasswordCredentialsGrantStorage(ctrl) + mockAccessTokenStorageProvider := internal.NewMockAccessTokenStorageProvider(ctrl) + mockRefreshTokenStorageProvider := internal.NewMockRefreshTokenStorageProvider(ctrl) t.Cleanup(ctrl.Finish) areq := fosite.NewAccessRequest(new(fosite.DefaultSession)) @@ -72,21 +75,24 @@ func TestResourceOwnerFlow_HandleTokenEndpointRequest(t *testing.T) { areq.Form.Set("password", "pan") areq.Client = &fosite.DefaultClient{GrantTypes: fosite.Arguments{"password"}, Scopes: []string{"foo-scope"}, Audience: []string{"https://www.ory.sh/api"}} - store.EXPECT().Authenticate(gomock.Any(), "peter", "pan").Return("", fosite.ErrNotFound) + mockRopcgStorageProvider.EXPECT().ResourceOwnerPasswordCredentialsGrantStorage().Return(mockRopcgStorage).Times(1) + mockRopcgStorage.EXPECT().Authenticate(gomock.Any(), "peter", "pan").Return("", fosite.ErrNotFound) }, expectErr: fosite.ErrInvalidGrant, }, { description: "should fail because error on lookup", setup: func(config *fosite.Config) { - store.EXPECT().Authenticate(gomock.Any(), "peter", "pan").Return("", errors.New("")) + mockRopcgStorageProvider.EXPECT().ResourceOwnerPasswordCredentialsGrantStorage().Return(mockRopcgStorage).Times(1) + mockRopcgStorage.EXPECT().Authenticate(gomock.Any(), "peter", "pan").Return("", errors.New("")) }, expectErr: fosite.ErrServerError, }, { description: "should pass", setup: func(config *fosite.Config) { - store.EXPECT().Authenticate(gomock.Any(), "peter", "pan").Return("", nil) + mockRopcgStorageProvider.EXPECT().ResourceOwnerPasswordCredentialsGrantStorage().Return(mockRopcgStorage).Times(1) + mockRopcgStorage.EXPECT().Authenticate(gomock.Any(), "peter", "pan").Return("", nil) }, check: func(areq *fosite.AccessRequest) { // assert.NotEmpty(t, areq.GetSession().GetExpiresAt(fosite.AccessToken)) @@ -102,8 +108,17 @@ func TestResourceOwnerFlow_HandleTokenEndpointRequest(t *testing.T) { ScopeStrategy: fosite.HierarchicScopeStrategy, AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, } + mockStorage := struct { + *internal.MockResourceOwnerPasswordCredentialsGrantStorageProvider + *internal.MockAccessTokenStorageProvider + *internal.MockRefreshTokenStorageProvider + }{ + MockResourceOwnerPasswordCredentialsGrantStorageProvider: mockRopcgStorageProvider, + MockAccessTokenStorageProvider: mockAccessTokenStorageProvider, + MockRefreshTokenStorageProvider: mockRefreshTokenStorageProvider, + } h := oauth2.ResourceOwnerPasswordCredentialsGrantHandler{ - Storage: store, + Storage: mockStorage, Config: config, } c.setup(config) @@ -123,8 +138,10 @@ func TestResourceOwnerFlow_HandleTokenEndpointRequest(t *testing.T) { func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) { var ( - mockRopcgStorage *internal.MockResourceOwnerPasswordCredentialsGrantStorage + mockRopcgStorageProvider *internal.MockResourceOwnerPasswordCredentialsGrantStorageProvider + mockAccessTokenStorageProvider *internal.MockAccessTokenStorageProvider mockAccessTokenStorage *internal.MockAccessTokenStorage + mockRefreshTokenStorageProvider *internal.MockRefreshTokenStorageProvider mockRefreshTokenStorage *internal.MockRefreshTokenStorage mockAccessTokenStrategyProvider *internal.MockAccessTokenStrategyProvider mockAccessTokenStrategy *internal.MockAccessTokenStrategy @@ -161,7 +178,7 @@ func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) { areq.GrantTypes = fosite.Arguments{"password"} mockAccessTokenStrategyProvider.EXPECT().AccessTokenStrategy().Return(mockAccessTokenStrategy).Times(1) mockAccessTokenStrategy.EXPECT().GenerateAccessToken(gomock.Any(), areq).Return(mockAT, "bar", nil) - mockRopcgStorage.EXPECT().AccessTokenStorage().Return(mockAccessTokenStorage).Times(1) + mockAccessTokenStorageProvider.EXPECT().AccessTokenStorage().Return(mockAccessTokenStorage).Times(1) mockAccessTokenStorage.EXPECT().CreateAccessTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) }, expect: func() { @@ -175,11 +192,11 @@ func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) { areq.GrantScope("offline") mockRefreshTokenStrategyProvider.EXPECT().RefreshTokenStrategy().Return(mockRefreshTokenStrategy).Times(1) mockRefreshTokenStrategy.EXPECT().GenerateRefreshToken(gomock.Any(), areq).Return(mockRT, "bar", nil) - mockRopcgStorage.EXPECT().RefreshTokenStorage().Return(mockRefreshTokenStorage).Times(1) + mockRefreshTokenStorageProvider.EXPECT().RefreshTokenStorage().Return(mockRefreshTokenStorage).Times(1) mockRefreshTokenStorage.EXPECT().CreateRefreshTokenSession(gomock.Any(), "bar", "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) mockAccessTokenStrategyProvider.EXPECT().AccessTokenStrategy().Return(mockAccessTokenStrategy).Times(1) mockAccessTokenStrategy.EXPECT().GenerateAccessToken(gomock.Any(), areq).Return(mockAT, "bar", nil) - mockRopcgStorage.EXPECT().AccessTokenStorage().Return(mockAccessTokenStorage).Times(1) + mockAccessTokenStorageProvider.EXPECT().AccessTokenStorage().Return(mockAccessTokenStorage).Times(1) mockAccessTokenStorage.EXPECT().CreateAccessTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) }, expect: func() { @@ -193,11 +210,11 @@ func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) { areq.GrantTypes = fosite.Arguments{"password"} mockAccessTokenStrategyProvider.EXPECT().AccessTokenStrategy().Return(mockAccessTokenStrategy).Times(1) mockAccessTokenStrategy.EXPECT().GenerateAccessToken(gomock.Any(), areq).Return(mockAT, "bar", nil) - mockRopcgStorage.EXPECT().AccessTokenStorage().Return(mockAccessTokenStorage).Times(1) + mockAccessTokenStorageProvider.EXPECT().AccessTokenStorage().Return(mockAccessTokenStorage).Times(1) mockAccessTokenStorage.EXPECT().CreateAccessTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) mockRefreshTokenStrategyProvider.EXPECT().RefreshTokenStrategy().Return(mockRefreshTokenStrategy).Times(1) mockRefreshTokenStrategy.EXPECT().GenerateRefreshToken(gomock.Any(), areq).Return(mockRT, "bar", nil) - mockRopcgStorage.EXPECT().RefreshTokenStorage().Return(mockRefreshTokenStorage).Times(1) + mockRefreshTokenStorageProvider.EXPECT().RefreshTokenStorage().Return(mockRefreshTokenStorage).Times(1) mockRefreshTokenStorage.EXPECT().CreateRefreshTokenSession(gomock.Any(), "bar", "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) }, expect: func() { @@ -218,14 +235,26 @@ func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) { AccessTokenLifespan: time.Hour, } - mockRopcgStorage = internal.NewMockResourceOwnerPasswordCredentialsGrantStorage(ctrl) + mockRopcgStorageProvider = internal.NewMockResourceOwnerPasswordCredentialsGrantStorageProvider(ctrl) + mockAccessTokenStorageProvider = internal.NewMockAccessTokenStorageProvider(ctrl) mockAccessTokenStorage = internal.NewMockAccessTokenStorage(ctrl) + mockRefreshTokenStorageProvider = internal.NewMockRefreshTokenStorageProvider(ctrl) mockRefreshTokenStorage = internal.NewMockRefreshTokenStorage(ctrl) mockAccessTokenStrategyProvider = internal.NewMockAccessTokenStrategyProvider(ctrl) mockAccessTokenStrategy = internal.NewMockAccessTokenStrategy(ctrl) mockRefreshTokenStrategyProvider = internal.NewMockRefreshTokenStrategyProvider(ctrl) mockRefreshTokenStrategy = internal.NewMockRefreshTokenStrategy(ctrl) + mockStorage := struct { + *internal.MockResourceOwnerPasswordCredentialsGrantStorageProvider + *internal.MockAccessTokenStorageProvider + *internal.MockRefreshTokenStorageProvider + }{ + MockResourceOwnerPasswordCredentialsGrantStorageProvider: mockRopcgStorageProvider, + MockAccessTokenStorageProvider: mockAccessTokenStorageProvider, + MockRefreshTokenStorageProvider: mockRefreshTokenStorageProvider, + } + mockStrategy := struct { *internal.MockAccessTokenStrategyProvider *internal.MockRefreshTokenStrategyProvider @@ -235,7 +264,7 @@ func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) { } h = oauth2.ResourceOwnerPasswordCredentialsGrantHandler{ - Storage: mockRopcgStorage, + Storage: mockStorage, Strategy: mockStrategy, Config: config, } diff --git a/fosite/handler/oauth2/storage.go b/fosite/handler/oauth2/storage.go index 2763d77f9..e9e939096 100644 --- a/fosite/handler/oauth2/storage.go +++ b/fosite/handler/oauth2/storage.go @@ -9,12 +9,6 @@ import ( "github.com/ory/hydra/v2/fosite" ) -type CoreStorage interface { - AuthorizeCodeStorageProvider - AccessTokenStorageProvider - RefreshTokenStorageProvider -} - // AuthorizeCodeStorage handles storage requests related to authorization codes. type AuthorizeCodeStorage interface { // CreateAuthorizeCodeSession stores the authorization request for a given authorization code. diff --git a/fosite/handler/rfc8628/strategy.go b/fosite/handler/rfc8628/strategy.go index dcb922f94..b224b452b 100644 --- a/fosite/handler/rfc8628/strategy.go +++ b/fosite/handler/rfc8628/strategy.go @@ -9,13 +9,6 @@ import ( "github.com/ory/hydra/v2/fosite" ) -// RFC8628CodeStrategy is the code strategy needed for the DeviceAuthHandler -type RFC8628CodeStrategy interface { - DeviceRateLimitStrategy - DeviceCodeStrategy - UserCodeStrategy -} - // DeviceRateLimitStrategy handles the rate limiting strategy type DeviceRateLimitStrategy interface { // ShouldRateLimit checks whether the token request should be rate-limited diff --git a/fosite/handler/verifiable/handler.go b/fosite/handler/verifiable/handler.go index f338f00cd..805b3eaa2 100644 --- a/fosite/handler/verifiable/handler.go +++ b/fosite/handler/verifiable/handler.go @@ -21,7 +21,7 @@ type Handler struct { Config interface { fosite.VerifiableCredentialsNonceLifespanProvider } - NonceManager + NonceManagerProvider } var _ fosite.TokenEndpointHandler = (*Handler)(nil) @@ -45,7 +45,7 @@ func (c *Handler) PopulateTokenEndpointResponse( lifespan := c.Config.GetVerifiableCredentialsNonceLifespan(ctx) expiry := time.Now().UTC().Add(lifespan) - nonce, err := c.NewNonce(ctx, response.GetAccessToken(), expiry) + nonce, err := c.NonceManager().NewNonce(ctx, response.GetAccessToken(), expiry) if err != nil { return err } diff --git a/fosite/handler/verifiable/handler_test.go b/fosite/handler/verifiable/handler_test.go index 3a6b8615f..b882e7b18 100644 --- a/fosite/handler/verifiable/handler_test.go +++ b/fosite/handler/verifiable/handler_test.go @@ -15,6 +15,10 @@ import ( "github.com/ory/hydra/v2/fosite/internal" ) +type mockNonceManagerProvider struct{ n NonceManager } + +func (m mockNonceManagerProvider) NonceManager() NonceManager { return m.n } + type mockNonceManager struct{ t *testing.T } func (m *mockNonceManager) NewNonce(ctx context.Context, accessToken string, expiresAt time.Time) (string, error) { @@ -67,7 +71,7 @@ func TestHandler(t *testing.T) { func newHandler(t *testing.T) *Handler { return &Handler{ - Config: new(fosite.Config), - NonceManager: &mockNonceManager{t: t}, + Config: new(fosite.Config), + NonceManagerProvider: mockNonceManagerProvider{n: &mockNonceManager{t: t}}, } } diff --git a/fosite/handler/verifiable/nonce.go b/fosite/handler/verifiable/nonce.go index 754d5645f..4bc548042 100644 --- a/fosite/handler/verifiable/nonce.go +++ b/fosite/handler/verifiable/nonce.go @@ -15,3 +15,7 @@ type NonceManager interface { // IsNonceValid checks if the given nonce is valid for the given access token and not expired. IsNonceValid(ctx context.Context, accessToken string, nonce string) error } + +type NonceManagerProvider interface { + NonceManager() NonceManager +} diff --git a/fosite/internal/resource_owner_password_credentials_grant_storage.go b/fosite/internal/resource_owner_password_credentials_grant_storage.go index e13f145c4..4433de9be 100644 --- a/fosite/internal/resource_owner_password_credentials_grant_storage.go +++ b/fosite/internal/resource_owner_password_credentials_grant_storage.go @@ -17,8 +17,6 @@ import ( reflect "reflect" gomock "go.uber.org/mock/gomock" - - oauth2 "github.com/ory/hydra/v2/fosite/handler/oauth2" ) // MockResourceOwnerPasswordCredentialsGrantStorage is a mock of ResourceOwnerPasswordCredentialsGrantStorage interface. @@ -45,20 +43,6 @@ func (m *MockResourceOwnerPasswordCredentialsGrantStorage) EXPECT() *MockResourc return m.recorder } -// AccessTokenStorage mocks base method. -func (m *MockResourceOwnerPasswordCredentialsGrantStorage) AccessTokenStorage() oauth2.AccessTokenStorage { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AccessTokenStorage") - ret0, _ := ret[0].(oauth2.AccessTokenStorage) - return ret0 -} - -// AccessTokenStorage indicates an expected call of AccessTokenStorage. -func (mr *MockResourceOwnerPasswordCredentialsGrantStorageMockRecorder) AccessTokenStorage() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenStorage", reflect.TypeOf((*MockResourceOwnerPasswordCredentialsGrantStorage)(nil).AccessTokenStorage)) -} - // Authenticate mocks base method. func (m *MockResourceOwnerPasswordCredentialsGrantStorage) Authenticate(ctx context.Context, name, secret string) (string, error) { m.ctrl.T.Helper() @@ -73,17 +57,3 @@ func (mr *MockResourceOwnerPasswordCredentialsGrantStorageMockRecorder) Authenti mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authenticate", reflect.TypeOf((*MockResourceOwnerPasswordCredentialsGrantStorage)(nil).Authenticate), ctx, name, secret) } - -// RefreshTokenStorage mocks base method. -func (m *MockResourceOwnerPasswordCredentialsGrantStorage) RefreshTokenStorage() oauth2.RefreshTokenStorage { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RefreshTokenStorage") - ret0, _ := ret[0].(oauth2.RefreshTokenStorage) - return ret0 -} - -// RefreshTokenStorage indicates an expected call of RefreshTokenStorage. -func (mr *MockResourceOwnerPasswordCredentialsGrantStorageMockRecorder) RefreshTokenStorage() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshTokenStorage", reflect.TypeOf((*MockResourceOwnerPasswordCredentialsGrantStorage)(nil).RefreshTokenStorage)) -} diff --git a/fosite/internal/resource_owner_password_credentials_grant_storage_provider.go b/fosite/internal/resource_owner_password_credentials_grant_storage_provider.go new file mode 100644 index 000000000..341091cff --- /dev/null +++ b/fosite/internal/resource_owner_password_credentials_grant_storage_provider.go @@ -0,0 +1,58 @@ +// Copyright © 2025 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ory/hydra/v2/fosite/handler/oauth2 (interfaces: ResourceOwnerPasswordCredentialsGrantStorageProvider) +// +// Generated by this command: +// +// mockgen -package internal -destination internal/resource_owner_password_credentials_grant_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 ResourceOwnerPasswordCredentialsGrantStorageProvider +// + +// Package internal is a generated GoMock package. +package internal + +import ( + reflect "reflect" + + oauth2 "github.com/ory/hydra/v2/fosite/handler/oauth2" + gomock "go.uber.org/mock/gomock" +) + +// MockResourceOwnerPasswordCredentialsGrantStorageProvider is a mock of ResourceOwnerPasswordCredentialsGrantStorageProvider interface. +type MockResourceOwnerPasswordCredentialsGrantStorageProvider struct { + ctrl *gomock.Controller + recorder *MockResourceOwnerPasswordCredentialsGrantStorageProviderMockRecorder + isgomock struct{} +} + +// MockResourceOwnerPasswordCredentialsGrantStorageProviderMockRecorder is the mock recorder for MockResourceOwnerPasswordCredentialsGrantStorageProvider. +type MockResourceOwnerPasswordCredentialsGrantStorageProviderMockRecorder struct { + mock *MockResourceOwnerPasswordCredentialsGrantStorageProvider +} + +// NewMockResourceOwnerPasswordCredentialsGrantStorageProvider creates a new mock instance. +func NewMockResourceOwnerPasswordCredentialsGrantStorageProvider(ctrl *gomock.Controller) *MockResourceOwnerPasswordCredentialsGrantStorageProvider { + mock := &MockResourceOwnerPasswordCredentialsGrantStorageProvider{ctrl: ctrl} + mock.recorder = &MockResourceOwnerPasswordCredentialsGrantStorageProviderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResourceOwnerPasswordCredentialsGrantStorageProvider) EXPECT() *MockResourceOwnerPasswordCredentialsGrantStorageProviderMockRecorder { + return m.recorder +} + +// ResourceOwnerPasswordCredentialsGrantStorage mocks base method. +func (m *MockResourceOwnerPasswordCredentialsGrantStorageProvider) ResourceOwnerPasswordCredentialsGrantStorage() oauth2.ResourceOwnerPasswordCredentialsGrantStorage { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResourceOwnerPasswordCredentialsGrantStorage") + ret0, _ := ret[0].(oauth2.ResourceOwnerPasswordCredentialsGrantStorage) + return ret0 +} + +// ResourceOwnerPasswordCredentialsGrantStorage indicates an expected call of ResourceOwnerPasswordCredentialsGrantStorage. +func (mr *MockResourceOwnerPasswordCredentialsGrantStorageProviderMockRecorder) ResourceOwnerPasswordCredentialsGrantStorage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResourceOwnerPasswordCredentialsGrantStorage", reflect.TypeOf((*MockResourceOwnerPasswordCredentialsGrantStorageProvider)(nil).ResourceOwnerPasswordCredentialsGrantStorage)) +} diff --git a/fosite/internal/storage.go b/fosite/internal/storage.go index 9e0ce19f9..c10ddc242 100644 --- a/fosite/internal/storage.go +++ b/fosite/internal/storage.go @@ -15,9 +15,8 @@ package internal import ( reflect "reflect" - gomock "go.uber.org/mock/gomock" - fosite "github.com/ory/hydra/v2/fosite" + gomock "go.uber.org/mock/gomock" ) // MockStorage is a mock of Storage interface. @@ -44,16 +43,16 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder { return m.recorder } -// ClientManager mocks base method. -func (m *MockStorage) ClientManager() fosite.ClientManager { +// FositeClientManager mocks base method. +func (m *MockStorage) FositeClientManager() fosite.ClientManager { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ClientManager") + ret := m.ctrl.Call(m, "FositeClientManager") ret0, _ := ret[0].(fosite.ClientManager) return ret0 } -// ClientManager indicates an expected call of ClientManager. -func (mr *MockStorageMockRecorder) ClientManager() *gomock.Call { +// FositeClientManager indicates an expected call of FositeClientManager. +func (mr *MockStorageMockRecorder) FositeClientManager() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClientManager", reflect.TypeOf((*MockStorage)(nil).ClientManager)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FositeClientManager", reflect.TypeOf((*MockStorage)(nil).FositeClientManager)) } diff --git a/fosite/introspection_request_handler.go b/fosite/introspection_request_handler.go index 6aa3da33f..e0b4c51cb 100644 --- a/fosite/introspection_request_handler.go +++ b/fosite/introspection_request_handler.go @@ -138,7 +138,7 @@ func (f *Fosite) NewIntrospectionRequest(ctx context.Context, r *http.Request, s return &IntrospectionResponse{Active: false}, errorsx.WithStack(ErrRequestUnauthorized.WithHint("Unable to decode OAuth 2.0 Client Secret from HTTP basic authorization header, make sure it is properly encoded.").WithWrap(err).WithDebug(err.Error())) } - client, err := f.Store.ClientManager().GetClient(ctx, clientID) + client, err := f.Store.FositeClientManager().GetClient(ctx, clientID) if err != nil { return &IntrospectionResponse{Active: false}, errorsx.WithStack(ErrRequestUnauthorized.WithHint("Unable to find OAuth 2.0 Client from HTTP basic authorization header.").WithWrap(err).WithDebug(err.Error())) } diff --git a/fosite/pushed_authorize_request_handler_test.go b/fosite/pushed_authorize_request_handler_test.go index bcc13536c..cd0db2dec 100644 --- a/fosite/pushed_authorize_request_handler_test.go +++ b/fosite/pushed_authorize_request_handler_test.go @@ -89,7 +89,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidRequest, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2) hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) }, @@ -105,7 +105,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidRequest, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2) hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) }, @@ -121,7 +121,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidRequest, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2) hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) }, @@ -138,7 +138,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidState, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2) hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) }, @@ -156,7 +156,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidState, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2) hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) }, @@ -174,7 +174,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "scope": {"foo bar baz"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, Secret: []byte("1234")}, nil).MaxTimes(2) hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) }, @@ -194,7 +194,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, Audience: []string{"https://cloud.ory.sh/api"}, @@ -218,7 +218,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ ResponseTypes: []string{"code token"}, RedirectURIs: []string{"https://foo.bar/cb"}, @@ -258,7 +258,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ ResponseTypes: []string{"code token"}, RedirectURIs: []string{"https://foo.bar/cb"}, @@ -298,7 +298,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "audience": {"test value", ""}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ ResponseTypes: []string{"code token"}, RedirectURIs: []string{"https://foo.bar/cb"}, @@ -338,7 +338,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ ResponseTypes: []string{"code token"}, RedirectURIs: []string{"web+application://callback"}, @@ -378,7 +378,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ ResponseTypes: []string{"code token"}, RedirectURIs: []string{"https://foo.bar/cb"}, @@ -418,7 +418,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "response_mode": {"unknown"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}, Secret: []byte("1234")}, nil).MaxTimes(2) hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) }, @@ -438,7 +438,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "response_mode": {"form_post"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}, Secret: []byte("1234")}, nil).MaxTimes(2) hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) }, @@ -458,7 +458,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "response_mode": {"form_post"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ DefaultClient: &DefaultClient{ RedirectURIs: []string{"https://foo.bar/cb"}, @@ -487,7 +487,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ DefaultClient: &DefaultClient{ RedirectURIs: []string{"https://foo.bar/cb"}, @@ -534,7 +534,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ DefaultClient: &DefaultClient{ RedirectURIs: []string{"https://foo.bar/cb"}, @@ -581,7 +581,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ DefaultClient: &DefaultClient{ RedirectURIs: []string{"https://foo.bar/cb"}, @@ -629,7 +629,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "response_mode": {"form_post"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(2) + store.EXPECT().FositeClientManager().Return(clientManager).Times(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}, Secret: []byte("1234")}, nil).MaxTimes(2) hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) }, @@ -650,7 +650,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) { "response_mode": {"form_post"}, }, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).MaxTimes(2) + store.EXPECT().FositeClientManager().Return(clientManager).MaxTimes(2) clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}, Secret: []byte("1234")}, nil).MaxTimes(2) hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("4321"))).Return(fmt.Errorf("invalid hash")) }, diff --git a/fosite/revoke_handler_test.go b/fosite/revoke_handler_test.go index f059e5e1a..7dda31b8f 100644 --- a/fosite/revoke_handler_test.go +++ b/fosite/revoke_handler_test.go @@ -70,7 +70,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: ErrInvalidClient, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New("")) }, }, @@ -84,7 +84,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: ErrInvalidClient, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Secret = []byte("foo") client.Public = false @@ -101,7 +101,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: nil, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Secret = []byte("foo") client.Public = false @@ -121,7 +121,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: nil, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Secret = []byte("foo") client.Public = false @@ -141,7 +141,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: nil, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = true handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) @@ -159,7 +159,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: nil, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Secret = []byte("foo") client.Public = false @@ -179,7 +179,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: nil, mock: func() { - store.EXPECT().ClientManager().Return(clientManager).Times(1) + store.EXPECT().FositeClientManager().Return(clientManager).Times(1) clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Secret = []byte("foo") client.Public = false diff --git a/fosite/storage.go b/fosite/storage.go index 47645eef8..979a5fe08 100644 --- a/fosite/storage.go +++ b/fosite/storage.go @@ -9,7 +9,7 @@ import ( // Storage defines fosite's minimal storage interface. type Storage interface { - ClientManager() ClientManager + FositeClientManager() ClientManager } type PARStorageProvider interface { diff --git a/fosite/storage/memory.go b/fosite/storage/memory.go index 7deacf4b4..fb2b732f0 100644 --- a/fosite/storage/memory.go +++ b/fosite/storage/memory.go @@ -101,7 +101,7 @@ func NewMemoryStore() *MemoryStore { } } -func (s *MemoryStore) ClientManager() fosite.ClientManager { +func (s *MemoryStore) FositeClientManager() fosite.ClientManager { return s } @@ -121,6 +121,10 @@ func (s *MemoryStore) TokenRevocationStorage() oauth2.TokenRevocationStorage { return s } +func (s *MemoryStore) ResourceOwnerPasswordCredentialsGrantStorage() oauth2.ResourceOwnerPasswordCredentialsGrantStorage { + return s +} + func (s *MemoryStore) OpenIDConnectRequestStorage() openid.OpenIDConnectRequestStorage { return s } diff --git a/fositex/config.go b/fositex/config.go index 39e1a3640..0683d5bf4 100644 --- a/fositex/config.go +++ b/fositex/config.go @@ -76,10 +76,10 @@ func NewConfig(deps configDependencies) *Config { } } -func (c *Config) LoadDefaultHandlers(strategy interface{}) { +func (c *Config) LoadDefaultHandlers(storage fosite.Storage, strategy interface{}) { factories := append(defaultFactories, c.deps.ExtraFositeFactories()...) for _, factory := range factories { - res := factory(c, c.deps.Persister(), strategy) + res := factory(c, storage, strategy) if ah, ok := res.(fosite.AuthorizeEndpointHandler); ok { c.authorizeEndpointHandlers.Append(ah) } diff --git a/oauth2/fosite_store_helpers_test.go b/oauth2/fosite_store_helpers_test.go index c950cd7eb..8c6ed0d68 100644 --- a/oauth2/fosite_store_helpers_test.go +++ b/oauth2/fosite_store_helpers_test.go @@ -908,7 +908,7 @@ func testFositeStoreSetClientAssertionJWT(m *driver.RegistrySQL) func(*testing.T require.True(t, ok) jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(time.Minute)) - require.NoError(t, store.ClientManager().SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry)) + require.NoError(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry)) cmp, err := store.GetClientAssertionJWT(context.Background(), jti.JTI) require.NotEqual(t, cmp.NID, uuid.Nil) @@ -923,7 +923,7 @@ func testFositeStoreSetClientAssertionJWT(m *driver.RegistrySQL) func(*testing.T jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) - assert.ErrorIs(t, store.ClientManager().SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry), fosite.ErrJTIKnown) + assert.ErrorIs(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry), fosite.ErrJTIKnown) }) t.Run("case=deletes expired JTIs", func(t *testing.T) { @@ -933,7 +933,7 @@ func testFositeStoreSetClientAssertionJWT(m *driver.RegistrySQL) func(*testing.T require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), expiredJTI)) newJTI := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(time.Minute)) - require.NoError(t, store.ClientManager().SetClientAssertionJWT(context.Background(), newJTI.JTI, newJTI.Expiry)) + require.NoError(t, store.SetClientAssertionJWT(context.Background(), newJTI.JTI, newJTI.Expiry)) _, err := store.GetClientAssertionJWT(context.Background(), expiredJTI.JTI) assert.True(t, errors.Is(err, sqlcon.ErrNoRows)) @@ -951,7 +951,7 @@ func testFositeStoreSetClientAssertionJWT(m *driver.RegistrySQL) func(*testing.T require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) jti.Expiry = jti.Expiry.Add(2 * time.Minute) - assert.NoError(t, store.ClientManager().SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry)) + assert.NoError(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry)) cmp, err := store.GetClientAssertionJWT(context.Background(), jti.JTI) assert.NoError(t, err) assert.Equal(t, jti, cmp) @@ -965,7 +965,7 @@ func testFositeStoreClientAssertionJWTValid(m *driver.RegistrySQL) func(*testing store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) - assert.NoError(t, store.ClientManager().ClientAssertionJWTValid(context.Background(), uuid.Must(uuid.NewV4()).String())) + assert.NoError(t, store.ClientAssertionJWTValid(context.Background(), uuid.Must(uuid.NewV4()).String())) }) t.Run("case=returns invalid on known JTI", func(t *testing.T) { @@ -975,7 +975,7 @@ func testFositeStoreClientAssertionJWTValid(m *driver.RegistrySQL) func(*testing require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) - assert.True(t, errors.Is(store.ClientManager().ClientAssertionJWTValid(context.Background(), jti.JTI), fosite.ErrJTIKnown)) + assert.True(t, errors.Is(store.ClientAssertionJWTValid(context.Background(), jti.JTI), fosite.ErrJTIKnown)) }) t.Run("case=returns valid on expired JTI", func(t *testing.T) { @@ -985,7 +985,7 @@ func testFositeStoreClientAssertionJWTValid(m *driver.RegistrySQL) func(*testing require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) - assert.NoError(t, store.ClientManager().ClientAssertionJWTValid(context.Background(), jti.JTI)) + assert.NoError(t, store.ClientAssertionJWTValid(context.Background(), jti.JTI)) }) } } diff --git a/oauth2/oauth2_device_code_test.go b/oauth2/oauth2_device_code_test.go index 5f8a61a41..a5241be4a 100644 --- a/oauth2/oauth2_device_code_test.go +++ b/oauth2/oauth2_device_code_test.go @@ -205,9 +205,9 @@ func TestDeviceTokenRequest(t *testing.T) { for _, testCase := range testCases { t.Run("case="+testCase.description, func(t *testing.T) { - code, signature, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(context.TODO()) + code, signature, err := reg.DeviceCodeStrategy().GenerateDeviceCode(context.TODO()) require.NoError(t, err) - _, userCodeSignature, err := reg.RFC8628HMACStrategy().GenerateUserCode(context.TODO()) + _, userCodeSignature, err := reg.UserCodeStrategy().GenerateUserCode(context.TODO()) require.NoError(t, err) if testCase.setUp != nil { @@ -428,7 +428,6 @@ func TestDeviceCodeWithDefaultStrategy(t *testing.T) { subject := "aeneas-rekkas" nonce := uuid.New() t.Run("case=perform device flow without ID and refresh tokens", func(t *testing.T) { - c, conf := newDeviceClient(t, reg) conf.Scopes = []string{"hydra"} testhelpers.NewDeviceLoginConsentUI(t, reg.Config(), @@ -453,7 +452,6 @@ func TestDeviceCodeWithDefaultStrategy(t *testing.T) { assert.Empty(t, token.RefreshToken) }) t.Run("case=perform device flow with ID token", func(t *testing.T) { - c, conf := newDeviceClient(t, reg) conf.Scopes = []string{"openid", "hydra"} testhelpers.NewDeviceLoginConsentUI(t, reg.Config(), @@ -479,7 +477,6 @@ func TestDeviceCodeWithDefaultStrategy(t *testing.T) { assert.Empty(t, token.RefreshToken) }) t.Run("case=perform device flow with refresh token", func(t *testing.T) { - c, conf := newDeviceClient(t, reg) conf.Scopes = []string{"hydra", "offline"} testhelpers.NewDeviceLoginConsentUI(t, reg.Config(), diff --git a/oauth2/registry.go b/oauth2/registry.go index a7ac0cd84..50235e776 100644 --- a/oauth2/registry.go +++ b/oauth2/registry.go @@ -35,5 +35,7 @@ type Registry interface { OpenIDConnectRequestValidator() *openid.OpenIDConnectRequestValidator AccessRequestHooks() []AccessRequestHook OAuth2ProviderConfig() fosite.Configurator - RFC8628HMACStrategy() rfc8628.RFC8628CodeStrategy + rfc8628.DeviceRateLimitStrategyProvider + rfc8628.DeviceCodeStrategyProvider + rfc8628.UserCodeStrategyProvider } diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index 543bc8e93..3a95b15b6 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -16,6 +16,7 @@ import ( "github.com/ory/hydra/v2/fosite" "github.com/ory/hydra/v2/internal/kratos" "github.com/ory/hydra/v2/jwk" + "github.com/ory/hydra/v2/oauth2" "github.com/ory/hydra/v2/persistence" "github.com/ory/hydra/v2/x" "github.com/ory/pop/v6" @@ -27,8 +28,11 @@ import ( ) var ( - _ persistence.Persister = (*Persister)(nil) - _ fosite.Transactional = (*Persister)(nil) + _ persistence.Persister = (*Persister)(nil) + _ fosite.Transactional = (*Persister)(nil) + _ fosite.ClientManager = (*Persister)(nil) + _ oauth2.AssertionJWTReader = (*Persister)(nil) + _ x.FositeStorer = (*Persister)(nil) ) var ErrNoTransactionOpen = errors.New("There is no Transaction in this context.") @@ -67,62 +71,6 @@ type ( } ) -func (p *BasePersister) BeginTX(ctx context.Context) (_ context.Context, err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.BeginTX") - defer otelx.End(span, &err) - - fallback := &pop.Connection{TX: &pop.Tx{}} - if popx.GetConnection(ctx, fallback).TX != fallback.TX { - return context.WithValue(ctx, skipCommitKey, true), nil // no-op - } - - tx, err := p.c.Store.TransactionContextOptions(ctx, &sql.TxOptions{ - Isolation: sql.LevelRepeatableRead, - ReadOnly: false, - }) - c := &pop.Connection{ - TX: tx, - Store: tx, - ID: uuid.Must(uuid.NewV4()).String(), - Dialect: p.c.Dialect, - } - return popx.WithTransaction(ctx, c), err -} - -func (p *BasePersister) Commit(ctx context.Context) (err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Commit") - defer otelx.End(span, &err) - - if skip, ok := ctx.Value(skipCommitKey).(bool); ok && skip { - return nil // we skipped BeginTX, so we also skip Commit - } - - fallback := &pop.Connection{TX: &pop.Tx{}} - tx := popx.GetConnection(ctx, fallback) - if tx.TX == fallback.TX || tx.TX == nil { - return errors.WithStack(ErrNoTransactionOpen) - } - - return errors.WithStack(tx.TX.Commit()) -} - -func (p *BasePersister) Rollback(ctx context.Context) (err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Rollback") - defer otelx.End(span, &err) - - if skip, ok := ctx.Value(skipCommitKey).(bool); ok && skip { - return nil // we skipped BeginTX, so we also skip Rollback - } - - fallback := &pop.Connection{TX: &pop.Tx{}} - tx := popx.GetConnection(ctx, fallback) - if tx.TX == fallback.TX || tx.TX == nil { - return errors.WithStack(ErrNoTransactionOpen) - } - - return errors.WithStack(tx.TX.Rollback()) -} - func NewPersister(base *BasePersister, r Dependencies) *Persister { return &Persister{ BasePersister: base, @@ -192,3 +140,62 @@ func (p *BasePersister) mustSetNetwork(ctx context.Context, v interface{}) { func (p *BasePersister) Transaction(ctx context.Context, f func(ctx context.Context, c *pop.Connection) error) error { return popx.Transaction(ctx, p.c, f) } + +// BeginTX implements Transactional. +func (p *BasePersister) BeginTX(ctx context.Context) (_ context.Context, err error) { + ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.BeginTX") + defer otelx.End(span, &err) + + fallback := &pop.Connection{TX: &pop.Tx{}} + if popx.GetConnection(ctx, fallback).TX != fallback.TX { + return context.WithValue(ctx, skipCommitKey, true), nil // no-op + } + + tx, err := p.c.Store.TransactionContextOptions(ctx, &sql.TxOptions{ + Isolation: sql.LevelRepeatableRead, + ReadOnly: false, + }) + c := &pop.Connection{ + TX: tx, + Store: tx, + ID: uuid.Must(uuid.NewV4()).String(), + Dialect: p.c.Dialect, + } + return popx.WithTransaction(ctx, c), err +} + +// Commit implements Transactional. +func (p *BasePersister) Commit(ctx context.Context) (err error) { + ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Commit") + defer otelx.End(span, &err) + + if skip, ok := ctx.Value(skipCommitKey).(bool); ok && skip { + return nil // we skipped BeginTX, so we also skip Commit + } + + fallback := &pop.Connection{TX: &pop.Tx{}} + tx := popx.GetConnection(ctx, fallback) + if tx.TX == fallback.TX || tx.TX == nil { + return errors.WithStack(ErrNoTransactionOpen) + } + + return errors.WithStack(tx.TX.Commit()) +} + +// Rollback implements Transactional. +func (p *BasePersister) Rollback(ctx context.Context) (err error) { + ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Rollback") + defer otelx.End(span, &err) + + if skip, ok := ctx.Value(skipCommitKey).(bool); ok && skip { + return nil // we skipped BeginTX, so we also skip Rollback + } + + fallback := &pop.Connection{TX: &pop.Tx{}} + tx := popx.GetConnection(ctx, fallback) + if tx.TX == fallback.TX || tx.TX == nil { + return errors.WithStack(ErrNoTransactionOpen) + } + + return errors.WithStack(tx.TX.Rollback()) +} diff --git a/persistence/sql/persister_authenticate.go b/persistence/sql/persister_authenticate.go index 013ccc300..be95dff63 100644 --- a/persistence/sql/persister_authenticate.go +++ b/persistence/sql/persister_authenticate.go @@ -7,6 +7,7 @@ import ( "context" ) +// Authenticate implements ResourceOwnerPasswordCredentialsGrantStorage. func (p *Persister) Authenticate(ctx context.Context, name, secret string) (subject string, err error) { session, err := p.r.Kratos().Authenticate(ctx, name, secret) if err != nil { diff --git a/persistence/sql/persister_client.go b/persistence/sql/persister_client.go index b7e159ab7..24dd672aa 100644 --- a/persistence/sql/persister_client.go +++ b/persistence/sql/persister_client.go @@ -18,23 +18,51 @@ import ( "github.com/ory/x/sqlcon" ) -func (p *Persister) GetConcreteClient(ctx context.Context, id string) (c *client.Client, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConcreteClient", +// AuthenticateClient implements client.Manager. +func (p *Persister) AuthenticateClient(ctx context.Context, id string, secret []byte) (_ *client.Client, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AuthenticateClient", trace.WithAttributes(events.ClientID(id)), ) defer otelx.End(span, &err) - var cl client.Client - if err := p.QueryWithNetwork(ctx).Where("id = ?", id).First(&cl); err != nil { - return nil, sqlcon.HandleError(err) + c, err := p.GetConcreteClient(ctx, id) + if err != nil { + return nil, err } - return &cl, nil + + if err := p.r.ClientHasher().Compare(ctx, c.GetHashedSecret(), secret); err != nil { + return nil, err + } + + return c, nil } -func (p *Persister) GetClient(ctx context.Context, id string) (fosite.Client, error) { - return p.GetConcreteClient(ctx, id) +// CreateClient implements client.Storage. +func (p *Persister) CreateClient(ctx context.Context, c *client.Client) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateClient") + defer otelx.End(span, &err) + + h, err := p.r.ClientHasher().Hash(ctx, []byte(c.Secret)) + if err != nil { + return err + } + + c.Secret = string(h) + if c.ID == "" { + c.ID = uuid.Must(uuid.NewV4()).String() + } + if err := sqlcon.HandleError(p.CreateWithNetwork(ctx, c)); err != nil { + return err + } + + events.Trace(ctx, events.ClientCreated, + events.WithClientID(c.ID), + events.WithClientName(c.Name)) + + return nil } +// UpdateClient implements client.Storage. func (p *Persister) UpdateClient(ctx context.Context, cl *client.Client) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateClient", trace.WithAttributes(events.ClientID(cl.ID)), @@ -79,48 +107,7 @@ func (p *Persister) UpdateClient(ctx context.Context, cl *client.Client) (err er }) } -func (p *Persister) AuthenticateClient(ctx context.Context, id string, secret []byte) (_ *client.Client, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AuthenticateClient", - trace.WithAttributes(events.ClientID(id)), - ) - defer otelx.End(span, &err) - - c, err := p.GetConcreteClient(ctx, id) - if err != nil { - return nil, err - } - - if err := p.r.ClientHasher().Compare(ctx, c.GetHashedSecret(), secret); err != nil { - return nil, err - } - - return c, nil -} - -func (p *Persister) CreateClient(ctx context.Context, c *client.Client) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateClient") - defer otelx.End(span, &err) - - h, err := p.r.ClientHasher().Hash(ctx, []byte(c.Secret)) - if err != nil { - return err - } - - c.Secret = string(h) - if c.ID == "" { - c.ID = uuid.Must(uuid.NewV4()).String() - } - if err := sqlcon.HandleError(p.CreateWithNetwork(ctx, c)); err != nil { - return err - } - - events.Trace(ctx, events.ClientCreated, - events.WithClientID(c.ID), - events.WithClientName(c.Name)) - - return nil -} - +// DeleteClient implements client.Storage. func (p *Persister) DeleteClient(ctx context.Context, id string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteClient", trace.WithAttributes(events.ClientID(id)), @@ -143,6 +130,7 @@ func (p *Persister) DeleteClient(ctx context.Context, id string) (err error) { return nil } +// GetClients implements client.Storage. func (p *Persister) GetClients(ctx context.Context, filters client.Filter) (cs []client.Client, _ *keysetpagination.Paginator, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetClients") defer otelx.End(span, &err) @@ -170,3 +158,22 @@ func (p *Persister) GetClients(ctx context.Context, filters client.Filter) (cs [ cs, nextPage := keysetpagination.Result(cs, paginator) return cs, nextPage, nil } + +// GetConcreteClient implements client.Storage. +func (p *Persister) GetConcreteClient(ctx context.Context, id string) (c *client.Client, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConcreteClient", + trace.WithAttributes(events.ClientID(id)), + ) + defer otelx.End(span, &err) + + var cl client.Client + if err := p.QueryWithNetwork(ctx).Where("id = ?", id).First(&cl); err != nil { + return nil, sqlcon.HandleError(err) + } + return &cl, nil +} + +// GetClient implements fosite.ClientManager. +func (p *Persister) GetClient(ctx context.Context, id string) (fosite.Client, error) { + return p.GetConcreteClient(ctx, id) +} diff --git a/persistence/sql/persister_consent.go b/persistence/sql/persister_consent.go index 8806a4ec7..de078033a 100644 --- a/persistence/sql/persister_consent.go +++ b/persistence/sql/persister_consent.go @@ -27,7 +27,7 @@ import ( "github.com/ory/x/sqlxx" ) -var _ consent.Manager = &Persister{} +var _ consent.Manager = (*Persister)(nil) func (p *Persister) RevokeSubjectConsentSession(ctx context.Context, user string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectConsentSession") diff --git a/persistence/sql/persister_device.go b/persistence/sql/persister_device.go index 4311cddc2..fbcfc340c 100644 --- a/persistence/sql/persister_device.go +++ b/persistence/sql/persister_device.go @@ -17,7 +17,6 @@ import ( "github.com/tidwall/gjson" "github.com/ory/hydra/v2/fosite" - "github.com/ory/hydra/v2/fosite/handler/rfc8628" "github.com/ory/hydra/v2/oauth2" "github.com/ory/x/otelx" "github.com/ory/x/sqlcon" @@ -102,10 +101,6 @@ func (r *DeviceRequestSQL) toRequest(ctx context.Context, session fosite.Session }, nil } -func (p *Persister) DeviceAuthStorage() rfc8628.DeviceAuthStorage { - return p -} - func (p *Persister) sqlDeviceSchemaFromRequest(ctx context.Context, deviceCodeSignature, userCodeSignature string, r fosite.DeviceRequester, expiresAt time.Time) (*DeviceRequestSQL, error) { subject := "" if r.GetSession() == nil { @@ -157,7 +152,7 @@ func (p *Persister) sqlDeviceSchemaFromRequest(ctx context.Context, deviceCodeSi }, nil } -// CreateDeviceCodeSession creates a new device code session and stores it in the database +// CreateDeviceCodeSession creates a new device code session and stores it in the database. Implements DeviceAuthStorage. func (p *Persister) CreateDeviceAuthSession(ctx context.Context, deviceCodeSignature, userCodeSignature string, requester fosite.DeviceRequester) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateDeviceCodeSession") defer otelx.End(span, &err) @@ -178,33 +173,7 @@ func (p *Persister) CreateDeviceAuthSession(ctx context.Context, deviceCodeSigna return nil } -// UpdateDeviceCodeSessionBySignature updates a device code session by the device_code signature -func (p *Persister) UpdateDeviceCodeSessionBySignature(ctx context.Context, signature string, requester fosite.DeviceRequester) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateDeviceCodeSessionBySignature") - defer otelx.End(span, &err) - - req, err := p.sqlDeviceSchemaFromRequest(ctx, signature, "", requester, requester.GetSession().GetExpiresAt(fosite.DeviceCode).UTC()) - if err != nil { - return err - } - - stmt := fmt.Sprintf( - "UPDATE %s SET granted_scope=?, granted_audience=?, session_data=?, user_code_state=?, subject=?, challenge_id=? WHERE device_code_signature=? AND nid = ?", - sqlTableDeviceAuthCodes, - ) - - /* #nosec G201 table is static */ - return sqlcon.HandleError( - p.Connection(ctx).RawQuery(stmt, - req.GrantedScope, req.GrantedAudience, - req.Session, req.UserCodeState, - req.Subject, req.ConsentChallenge, - signature, p.NetworkID(ctx), - ).Exec(), - ) -} - -// GetDeviceCodeSession returns a device code session from the database +// GetDeviceCodeSession returns a device code session from the database. Implements DeviceAuthStorage. func (p *Persister) GetDeviceCodeSession(ctx context.Context, signature string, session fosite.Session) (_ fosite.DeviceRequester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetDeviceCodeSession") defer otelx.End(span, &err) @@ -227,34 +196,7 @@ func (p *Persister) GetDeviceCodeSession(ctx context.Context, signature string, return r.toRequest(ctx, session, p) } -// GetDeviceCodeSessionByRequestID returns a device code session from the database -func (p *Persister) GetDeviceCodeSessionByRequestID(ctx context.Context, requestID string, session fosite.Session) (_ fosite.DeviceRequester, deviceCodeSignature string, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetDeviceCodeSessionByRequestID") - defer otelx.End(span, &err) - - r := DeviceRequestSQL{} - if err = p.QueryWithNetwork(ctx).Where("request_id = ?", requestID).First(&r); errors.Is(err, sql.ErrNoRows) { - return nil, "", errors.WithStack(fosite.ErrNotFound) - } else if err != nil { - return nil, "", sqlcon.HandleError(err) - } - - if !r.DeviceCodeActive { - fr, err := r.toRequest(ctx, session, p) - if err != nil { - return nil, "", err - } - return fr, r.ID, errors.WithStack(fosite.ErrInactiveToken) - } - - fr, err := r.toRequest(ctx, session, p) - if err != nil { - return nil, "", err - } - return fr, r.ID, nil -} - -// InvalidateDeviceCodeSession invalidates a device code session +// InvalidateDeviceCodeSession invalidates a device code session. Implements DeviceAuthStorage. func (p *Persister) InvalidateDeviceCodeSession(ctx context.Context, signature string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InvalidateDeviceCodeSession") defer otelx.End(span, &err) @@ -266,7 +208,7 @@ func (p *Persister) InvalidateDeviceCodeSession(ctx context.Context, signature s Delete(DeviceRequestSQL{})) } -// GetUserCodeSession returns a user code session from the database +// GetUserCodeSession returns a user code session from the database. Implements FositeStorer. func (p *Persister) GetUserCodeSession(ctx context.Context, signature string, session fosite.Session) (_ fosite.DeviceRequester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetUserCodeSession") defer otelx.End(span, &err) @@ -293,3 +235,56 @@ func (p *Persister) GetUserCodeSession(ctx context.Context, signature string, se return fr, err } + +// GetDeviceCodeSessionByRequestID returns a device code session from the database. Implements FositeStorer. +func (p *Persister) GetDeviceCodeSessionByRequestID(ctx context.Context, requestID string, session fosite.Session) (_ fosite.DeviceRequester, deviceCodeSignature string, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetDeviceCodeSessionByRequestID") + defer otelx.End(span, &err) + + r := DeviceRequestSQL{} + if err = p.QueryWithNetwork(ctx).Where("request_id = ?", requestID).First(&r); errors.Is(err, sql.ErrNoRows) { + return nil, "", errors.WithStack(fosite.ErrNotFound) + } else if err != nil { + return nil, "", sqlcon.HandleError(err) + } + + if !r.DeviceCodeActive { + fr, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, "", err + } + return fr, r.ID, errors.WithStack(fosite.ErrInactiveToken) + } + + fr, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, "", err + } + return fr, r.ID, nil +} + +// UpdateDeviceCodeSessionBySignature updates a device code session by the device_code signature. Implements FositeStorer. +func (p *Persister) UpdateDeviceCodeSessionBySignature(ctx context.Context, signature string, requester fosite.DeviceRequester) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateDeviceCodeSessionBySignature") + defer otelx.End(span, &err) + + req, err := p.sqlDeviceSchemaFromRequest(ctx, signature, "", requester, requester.GetSession().GetExpiresAt(fosite.DeviceCode).UTC()) + if err != nil { + return err + } + + stmt := fmt.Sprintf( + "UPDATE %s SET granted_scope=?, granted_audience=?, session_data=?, user_code_state=?, subject=?, challenge_id=? WHERE device_code_signature=? AND nid = ?", + sqlTableDeviceAuthCodes, + ) + + /* #nosec G201 table is static */ + return sqlcon.HandleError( + p.Connection(ctx).RawQuery(stmt, + req.GrantedScope, req.GrantedAudience, + req.Session, req.UserCodeState, + req.Subject, req.ConsentChallenge, + signature, p.NetworkID(ctx), + ).Exec(), + ) +} diff --git a/persistence/sql/persister_grant_jwk.go b/persistence/sql/persister_grant_jwk.go index 221bd8a88..001167ad0 100644 --- a/persistence/sql/persister_grant_jwk.go +++ b/persistence/sql/persister_grant_jwk.go @@ -12,7 +12,6 @@ import ( "github.com/gofrs/uuid" "github.com/pkg/errors" - "github.com/ory/hydra/v2/fosite/handler/rfc7523" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/oauth2/trust" "github.com/ory/pop/v6" @@ -22,7 +21,7 @@ import ( "github.com/ory/x/sqlxx" ) -var _ trust.GrantManager = &Persister{} +var _ trust.GrantManager = (*Persister)(nil) type SQLGrant struct { ID uuid.UUID `db:"id"` @@ -41,10 +40,37 @@ func (SQLGrant) TableName() string { return "hydra_oauth2_trusted_jwt_bearer_issuer" } -func (p *Persister) RFC7523KeyStorage() rfc7523.RFC7523KeyStorage { - return p +func (SQLGrant) fromGrant(g trust.Grant) SQLGrant { + return SQLGrant{ + ID: g.ID, + Issuer: g.Issuer, + Subject: g.Subject, + AllowAnySubject: g.AllowAnySubject, + Scope: g.Scope, + KeySet: g.PublicKey.Set, + KeyID: g.PublicKey.KeyID, + CreatedAt: g.CreatedAt, + ExpiresAt: g.ExpiresAt, + } } +func (d SQLGrant) toGrant() trust.Grant { + return trust.Grant{ + ID: d.ID, + Issuer: d.Issuer, + Subject: d.Subject, + AllowAnySubject: d.AllowAnySubject, + Scope: d.Scope, + PublicKey: trust.PublicKey{ + Set: d.KeySet, + KeyID: d.KeyID, + }, + CreatedAt: d.CreatedAt, + ExpiresAt: d.ExpiresAt, + } +} + +// CreateGrant implements GrantManager func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jose.JSONWebKey) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateGrant") defer otelx.End(span, &err) @@ -66,6 +92,7 @@ func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jo }) } +// GetConcreteGrant implements GrantManager func (p *Persister) GetConcreteGrant(ctx context.Context, id uuid.UUID) (_ trust.Grant, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConcreteGrant") defer otelx.End(span, &err) @@ -78,6 +105,7 @@ func (p *Persister) GetConcreteGrant(ctx context.Context, id uuid.UUID) (_ trust return data.toGrant(), nil } +// DeleteGrant implements GrantManager func (p *Persister) DeleteGrant(ctx context.Context, id uuid.UUID) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteGrant") defer otelx.End(span, &err) @@ -96,6 +124,7 @@ func (p *Persister) DeleteGrant(ctx context.Context, id uuid.UUID) (err error) { }) } +// GetGrants implements GrantManager func (p *Persister) GetGrants(ctx context.Context, optionalIssuer string, pageOpts ...keysetpagination.Option) (_ []trust.Grant, _ *keysetpagination.Paginator, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetGrants") defer otelx.End(span, &err) @@ -123,6 +152,19 @@ func (p *Persister) GetGrants(ctx context.Context, optionalIssuer string, pageOp return grants, nextPage, nil } +// FlushInactiveGrants implements GrantManager +func (p *Persister) FlushInactiveGrants(ctx context.Context, notAfter time.Time, _ int, _ int) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveGrants") + defer otelx.End(span, &err) + + deleteUntil := time.Now().UTC() + if deleteUntil.After(notAfter) { + deleteUntil = notAfter + } + return sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("expires_at < ?", deleteUntil).Delete(&SQLGrant{})) +} + +// GetPublicKey implements RFC7523KeyStorage func (p *Persister) GetPublicKey(ctx context.Context, issuer string, subject string, keyId string) (_ *jose.JSONWebKey, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKey") defer otelx.End(span, &err) @@ -151,6 +193,7 @@ func (p *Persister) GetPublicKey(ctx context.Context, issuer string, subject str return &keySet.Keys[0], nil } +// GetPublicKeys implements RFC7523KeyStorage func (p *Persister) GetPublicKeys(ctx context.Context, issuer string, subject string) (_ *jose.JSONWebKeySet, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKeys") defer otelx.End(span, &err) @@ -212,6 +255,7 @@ func (p *Persister) GetPublicKeys(ctx context.Context, issuer string, subject st return js.ToJWK(ctx, p.r.KeyCipher()) } +// GetPublicKeyScopes implements RFC7523KeyStorage func (p *Persister) GetPublicKeyScopes(ctx context.Context, issuer string, subject string, keyId string) (_ []string, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKeyScopes") defer otelx.End(span, &err) @@ -234,6 +278,7 @@ func (p *Persister) GetPublicKeyScopes(ctx context.Context, issuer string, subje return scopes, nil } +// IsJWTUsed implements RFC7523KeyStorage func (p *Persister) IsJWTUsed(ctx context.Context, jti string) (ok bool, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.IsJWTUsed") defer otelx.End(span, &err) @@ -246,50 +291,10 @@ func (p *Persister) IsJWTUsed(ctx context.Context, jti string) (ok bool, err err return false, nil } +// MarkJWTUsedForTime implements RFC7523KeyStorage func (p *Persister) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MarkJWTUsedForTime") defer otelx.End(span, &err) return p.SetClientAssertionJWT(ctx, jti, exp) } - -func (SQLGrant) fromGrant(g trust.Grant) SQLGrant { - return SQLGrant{ - ID: g.ID, - Issuer: g.Issuer, - Subject: g.Subject, - AllowAnySubject: g.AllowAnySubject, - Scope: g.Scope, - KeySet: g.PublicKey.Set, - KeyID: g.PublicKey.KeyID, - CreatedAt: g.CreatedAt, - ExpiresAt: g.ExpiresAt, - } -} - -func (d SQLGrant) toGrant() trust.Grant { - return trust.Grant{ - ID: d.ID, - Issuer: d.Issuer, - Subject: d.Subject, - AllowAnySubject: d.AllowAnySubject, - Scope: d.Scope, - PublicKey: trust.PublicKey{ - Set: d.KeySet, - KeyID: d.KeyID, - }, - CreatedAt: d.CreatedAt, - ExpiresAt: d.ExpiresAt, - } -} - -func (p *Persister) FlushInactiveGrants(ctx context.Context, notAfter time.Time, _ int, _ int) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveGrants") - defer otelx.End(span, &err) - - deleteUntil := time.Now().UTC() - if deleteUntil.After(notAfter) { - deleteUntil = notAfter - } - return sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("expires_at < ?", deleteUntil).Delete(&SQLGrant{})) -} diff --git a/persistence/sql/persister_jwk.go b/persistence/sql/persister_jwk.go index 1d3d0bbf3..078d125d9 100644 --- a/persistence/sql/persister_jwk.go +++ b/persistence/sql/persister_jwk.go @@ -26,6 +26,7 @@ type JWKPersister struct { *BasePersister } +// GenerateAndPersistKeySet implements jwk.Manager. func (p *JWKPersister) GenerateAndPersistKeySet(ctx context.Context, set, kid, alg, use string) (_ *jose.JSONWebKeySet, err error) { ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GenerateAndPersistKeySet", trace.WithAttributes( @@ -51,6 +52,7 @@ func (p *JWKPersister) GenerateAndPersistKeySet(ctx context.Context, set, kid, a return keys, nil } +// AddKey implements jwk.Manager. func (p *JWKPersister) AddKey(ctx context.Context, set string, key *jose.JSONWebKey) (err error) { ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AddKey", trace.WithAttributes( @@ -76,6 +78,7 @@ func (p *JWKPersister) AddKey(ctx context.Context, set string, key *jose.JSONWeb })) } +// AddKeySet implements jwk.Manager. func (p *JWKPersister) AddKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) (err error) { ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AddKeySet", trace.WithAttributes(attribute.String("set", set))) defer otelx.End(span, &err) @@ -105,7 +108,7 @@ func (p *JWKPersister) AddKeySet(ctx context.Context, set string, keys *jose.JSO }) } -// UpdateKey updates or creates the key. +// UpdateKey updates or creates the key. Implements jwk.Manager. func (p *JWKPersister) UpdateKey(ctx context.Context, set string, key *jose.JSONWebKey) (err error) { ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateKey", trace.WithAttributes( @@ -124,7 +127,7 @@ func (p *JWKPersister) UpdateKey(ctx context.Context, set string, key *jose.JSON }) } -// UpdateKeySet updates or creates the key set. +// UpdateKeySet updates or creates the key set. Implements jwk.Manager. func (p *JWKPersister) UpdateKeySet(ctx context.Context, set string, keySet *jose.JSONWebKeySet) (err error) { ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateKeySet", trace.WithAttributes(attribute.String("set", set))) defer otelx.End(span, &err) @@ -140,6 +143,7 @@ func (p *JWKPersister) UpdateKeySet(ctx context.Context, set string, keySet *jos }) } +// GetKey implements jwk.Manager. func (p *JWKPersister) GetKey(ctx context.Context, set, kid string) (_ *jose.JSONWebKeySet, err error) { ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetKey", trace.WithAttributes( @@ -170,6 +174,7 @@ func (p *JWKPersister) GetKey(ctx context.Context, set, kid string) (_ *jose.JSO }, nil } +// GetKeySet implements jwk.Manager. func (p *JWKPersister) GetKeySet(ctx context.Context, set string) (keys *jose.JSONWebKeySet, err error) { ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetKeySet", trace.WithAttributes(attribute.String("set", set))) defer otelx.End(span, &err) @@ -185,6 +190,7 @@ func (p *JWKPersister) GetKeySet(ctx context.Context, set string) (keys *jose.JS return js.ToJWK(ctx, aead.NewAESGCM(p.d.Config())) } +// DeleteKey implements jwk.Manager. func (p *JWKPersister) DeleteKey(ctx context.Context, set, kid string) (err error) { ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteKey", trace.WithAttributes( @@ -196,6 +202,7 @@ func (p *JWKPersister) DeleteKey(ctx context.Context, set, kid string) (err erro return sqlcon.HandleError(err) } +// DeleteKeySet implements jwk.Manager. func (p *JWKPersister) DeleteKeySet(ctx context.Context, set string) (err error) { ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteKeySet", trace.WithAttributes(attribute.String("set", set))) defer otelx.End(span, &err) diff --git a/persistence/sql/persister_nid_test.go b/persistence/sql/persister_nid_test.go index af2d7b645..09e26e07b 100644 --- a/persistence/sql/persister_nid_test.go +++ b/persistence/sql/persister_nid_test.go @@ -163,10 +163,10 @@ func (s *PersisterTestSuite) TestClientAssertionJWTValid() { for k, r := range s.registries { s.T().Run(k, func(t *testing.T) { jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(24*time.Hour)) - require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t1, jti.JTI, jti.Expiry)) + require.NoError(t, r.Persister().SetClientAssertionJWT(s.t1, jti.JTI, jti.Expiry)) - require.NoError(t, r.Persister().ClientManager().ClientAssertionJWTValid(s.t2, jti.JTI)) - require.Error(t, r.Persister().ClientManager().ClientAssertionJWTValid(s.t1, jti.JTI)) + require.NoError(t, r.Persister().ClientAssertionJWTValid(s.t2, jti.JTI)) + require.Error(t, r.Persister().ClientAssertionJWTValid(s.t1, jti.JTI)) }) } } @@ -799,7 +799,7 @@ func (s *PersisterTestSuite) TestGetClientAssertionJWT() { store, ok := r.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) expected := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(24*time.Hour)) - require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t1, expected.JTI, expected.Expiry)) + require.NoError(t, r.Persister().SetClientAssertionJWT(s.t1, expected.JTI, expected.Expiry)) _, err := store.GetClientAssertionJWT(s.t2, expected.JTI) require.Error(t, err) @@ -1151,7 +1151,7 @@ func (s *PersisterTestSuite) TestIsJWTUsed() { for k, r := range s.registries { s.T().Run(k, func(t *testing.T) { jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(24*time.Hour)) - require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t1, jti.JTI, jti.Expiry)) + require.NoError(t, r.Persister().SetClientAssertionJWT(s.t1, jti.JTI, jti.Expiry)) actual, err := r.Persister().IsJWTUsed(s.t2, jti.JTI) require.NoError(t, err) @@ -1245,9 +1245,9 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithFrontChannelLog func (s *PersisterTestSuite) TestMarkJWTUsedForTime() { for k, r := range s.registries { s.T().Run(k, func(t *testing.T) { - require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t1, "a", time.Now().Add(-24*time.Hour))) - require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t2, "a", time.Now().Add(-24*time.Hour))) - require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t2, "b", time.Now().Add(-24*time.Hour))) + require.NoError(t, r.Persister().SetClientAssertionJWT(s.t1, "a", time.Now().Add(-24*time.Hour))) + require.NoError(t, r.Persister().SetClientAssertionJWT(s.t2, "a", time.Now().Add(-24*time.Hour))) + require.NoError(t, r.Persister().SetClientAssertionJWT(s.t2, "b", time.Now().Add(-24*time.Hour))) require.NoError(t, r.Persister().MarkJWTUsedForTime(s.t2, "a", time.Now().Add(48*time.Hour))) @@ -1469,7 +1469,7 @@ func (s *PersisterTestSuite) TestSetClientAssertionJWT() { for k, r := range s.registries { s.T().Run(k, func(t *testing.T) { jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(24*time.Hour)) - require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t1, jti.JTI, jti.Expiry)) + require.NoError(t, r.Persister().SetClientAssertionJWT(s.t1, jti.JTI, jti.Expiry)) actual := &oauth2.BlacklistedJTI{} require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, jti.ID)) diff --git a/persistence/sql/persister_nonce.go b/persistence/sql/persister_nonce.go index a765d4d4b..e31a57f5a 100644 --- a/persistence/sql/persister_nonce.go +++ b/persistence/sql/persister_nonce.go @@ -10,13 +10,17 @@ import ( "github.com/pkg/errors" "github.com/ory/hydra/v2/fosite" + "github.com/ory/hydra/v2/fosite/handler/verifiable" "github.com/ory/hydra/v2/x" "github.com/ory/x/otelx" ) +var _ verifiable.NonceManager = (*Persister)(nil) + // Set the aadAccessTokenPrefix to something unique to avoid ciphertext confusion with other usages of the AEAD cipher. var aadAccessTokenPrefix = "vc-nonce-at:" // nolint:gosec +// NewNonce implements NonceManager. func (p *Persister) NewNonce(ctx context.Context, accessToken string, expiresIn time.Time) (res string, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.NewNonce") defer otelx.End(span, &err) @@ -27,6 +31,7 @@ func (p *Persister) NewNonce(ctx context.Context, accessToken string, expiresIn return p.r.FlowCipher().Encrypt(ctx, plaintext, aad) } +// IsNonceValid implements NonceManager. func (p *Persister) IsNonceValid(ctx context.Context, accessToken, nonce string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.IsNonceValid") defer otelx.End(span, &err) diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 1c670c06d..f64289e5a 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -21,9 +21,6 @@ import ( "go.opentelemetry.io/otel/trace" "github.com/ory/hydra/v2/fosite" - foauth2 "github.com/ory/hydra/v2/fosite/handler/oauth2" - "github.com/ory/hydra/v2/fosite/handler/openid" - "github.com/ory/hydra/v2/fosite/handler/pkce" "github.com/ory/hydra/v2/oauth2" "github.com/ory/hydra/v2/x" "github.com/ory/hydra/v2/x/events" @@ -34,11 +31,6 @@ import ( "github.com/ory/x/stringsx" ) -var ( - _ oauth2.AssertionJWTReader = &Persister{} - _ fosite.Transactional = &Persister{} -) - type ( tableName string OAuth2RequestSQL struct { @@ -76,62 +68,12 @@ const ( sqlTablePKCE tableName = "pkce" ) -func (r OAuth2RequestSQL) TableName() string { - return "hydra_oauth2_" + string(r.Table) -} - func (r OAuth2RefreshTable) TableName() string { return "hydra_oauth2_refresh" } -func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, r fosite.Requester, table tableName, expiresAt time.Time) (*OAuth2RequestSQL, error) { - subject := "" - if r.GetSession() == nil { - p.l.Debugf("Got an empty session in sqlSchemaFromRequest") - } else { - subject = r.GetSession().GetSubject() - } - - var challenge sql.NullString - rr, ok := r.GetSession().(*oauth2.Session) - if !ok && r.GetSession() != nil { - return nil, errors.Errorf("Expected request to be of type *Session, but got: %T", r.GetSession()) - } else if ok { - if len(rr.ConsentChallenge) > 0 { - challenge = sql.NullString{Valid: true, String: rr.ConsentChallenge} - } - } - - session, err := json.Marshal(rr) - if err != nil { - return nil, errors.WithStack(err) - } - - if p.r.Config().EncryptSessionData(ctx) { - ciphertext, err := p.r.KeyCipher().Encrypt(ctx, session, nil) - if err != nil { - return nil, err - } - session = []byte(ciphertext) - } - - return &OAuth2RequestSQL{ - Request: r.GetID(), - ConsentChallenge: challenge, - ID: signature, - RequestedAt: r.GetRequestedAt(), - InternalExpiresAt: sqlxx.NullTime(expiresAt), - Client: r.GetClient().GetID(), - Scopes: strings.Join(r.GetRequestedScopes(), "|"), - GrantedScope: strings.Join(r.GetGrantedScopes(), "|"), - GrantedAudience: strings.Join(r.GetGrantedAudience(), "|"), - RequestedAudience: strings.Join(r.GetRequestedAudience(), "|"), - Form: r.GetRequestForm().Encode(), - Session: session, - Subject: subject, - Active: true, - Table: table, - }, nil +func (r OAuth2RequestSQL) TableName() string { + return "hydra_oauth2_" + string(r.Table) } func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session, p *Persister) (_ *fosite.Request, err error) { @@ -179,483 +121,6 @@ func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session }, nil } -func (p *Persister) ClientManager() fosite.ClientManager { - return p -} - -func (p *Persister) AccessTokenStorage() foauth2.AccessTokenStorage { - return p -} - -func (p *Persister) RefreshTokenStorage() foauth2.RefreshTokenStorage { - return p -} - -func (p *Persister) AuthorizeCodeStorage() foauth2.AuthorizeCodeStorage { - return p -} - -func (p *Persister) TokenRevocationStorage() foauth2.TokenRevocationStorage { - return p -} - -func (p *Persister) OpenIDConnectRequestStorage() openid.OpenIDConnectRequestStorage { - return p -} - -func (p *Persister) PKCERequestStorage() pkce.PKCERequestStorage { - return p -} - -func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ClientAssertionJWTValid") - defer otelx.End(span, &err) - - j, err := p.GetClientAssertionJWT(ctx, jti) - if errors.Is(err, sqlcon.ErrNoRows) { - // the jti is not known => valid - return nil - } else if err != nil { - return err - } - if j.Expiry.After(time.Now()) { - // the jti is not expired yet => invalid - return errors.WithStack(fosite.ErrJTIKnown) - } - // the jti is expired => valid - return nil -} - -func (p *Persister) SetClientAssertionJWT(ctx context.Context, jti string, exp time.Time) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.SetClientAssertionJWT") - defer otelx.End(span, &err) - - // delete expired; this cleanup spares us the need for a background worker - if err := p.QueryWithNetwork(ctx).Where("expires_at < CURRENT_TIMESTAMP").Delete(&oauth2.BlacklistedJTI{}); err != nil { - return sqlcon.HandleError(err) - } - - if err := p.SetClientAssertionJWTRaw(ctx, oauth2.NewBlacklistedJTI(jti, exp)); errors.Is(err, sqlcon.ErrUniqueViolation) { - // found a jti - return errors.WithStack(fosite.ErrJTIKnown) - } else if err != nil { - return err - } - - // setting worked without a problem - return nil -} - -func (p *Persister) GetClientAssertionJWT(ctx context.Context, j string) (_ *oauth2.BlacklistedJTI, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetClientAssertionJWT") - defer otelx.End(span, &err) - - jti := oauth2.NewBlacklistedJTI(j, time.Time{}) - return jti, sqlcon.HandleError(p.QueryWithNetwork(ctx).Find(jti, jti.ID)) -} - -func (p *Persister) SetClientAssertionJWTRaw(ctx context.Context, jti *oauth2.BlacklistedJTI) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.SetClientAssertionJWTRaw") - defer otelx.End(span, &err) - - return sqlcon.HandleError(p.CreateWithNetwork(ctx, jti)) -} - -func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName, expiresAt time.Time) error { - req, err := p.sqlSchemaFromRequest(ctx, signature, requester, table, expiresAt) - if err != nil { - return err - } - - if err = sqlcon.HandleError(p.CreateWithNetwork(ctx, req)); errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return fosite.ErrSerializationFailure.WithWrap(err) - } else if err != nil { - return err - } - return nil -} - -func (p *Persister) findSessionBySignature(ctx context.Context, signature string, session fosite.Session, table tableName) (fosite.Requester, error) { - r := OAuth2RequestSQL{Table: table} - err := p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r) - if errors.Is(err, sql.ErrNoRows) { - return nil, errors.WithStack(fosite.ErrNotFound) - } - if err != nil { - return nil, sqlcon.HandleError(err) - } - if !r.Active { - fr, err := r.toRequest(ctx, session, p) - if err != nil { - return nil, err - } - if table == sqlTableCode { - return fr, errors.WithStack(fosite.ErrInvalidatedAuthorizeCode) - } - return fr, errors.WithStack(fosite.ErrInactiveToken) - } - - return r.toRequest(ctx, session, p) -} - -func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) error { - err := sqlcon.HandleError( - p.QueryWithNetwork(ctx). - Where("signature = ?", signature). - Delete(OAuth2RequestSQL{Table: table}.TableName())) - if errors.Is(err, sqlcon.ErrNoRows) { - return errors.WithStack(fosite.ErrNotFound) - } - if errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return fosite.ErrSerializationFailure.WithWrap(err) - } - return err -} - -func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, table tableName) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionByRequestID") - defer otelx.End(span, &err) - - err = p.QueryWithNetwork(ctx). - Where("request_id=?", id). - Delete(OAuth2RequestSQL{Table: table}.TableName()) - if errors.Is(err, sql.ErrNoRows) { - return errors.WithStack(fosite.ErrNotFound) - } - if err := sqlcon.HandleError(err); err != nil { - if errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return fosite.ErrSerializationFailure.WithWrap(err) - } - if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock? - return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) - } - return err - } - return nil -} - -func (p *Persister) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) error { - return otelx.WithSpan(ctx, "persistence.sql.CreateAuthorizeCodeSession", func(ctx context.Context) error { - return p.createSession(ctx, signature, requester, sqlTableCode, requester.GetSession().GetExpiresAt(fosite.AuthorizeCode).UTC()) - }) -} - -func (p *Persister) GetAuthorizeCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetAuthorizeCodeSession") - defer otelx.End(span, &err) - - return p.findSessionBySignature(ctx, signature, session, sqlTableCode) -} - -func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InvalidateAuthorizeCodeSession") - defer otelx.End(span, &err) - - /* #nosec G201 table is static */ - return sqlcon.HandleError( - p.Connection(ctx). - RawQuery( - fmt.Sprintf( - "UPDATE %s SET active = false, expires_at = ? WHERE signature = ? AND nid = ?", - OAuth2RequestSQL{Table: sqlTableCode}.TableName(), - ), - // We don't expire immediately, but in 30 minutes to avoid prematurely removing - // rows while they may still be needed (e.g. for reuse detection). - newUsedExpiry(), - signature, - p.NetworkID(ctx), - ). - Exec(), - ) -} - -func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateAccessTokenSession", - trace.WithAttributes(events.AccessTokenSignature(signature)), - ) - defer otelx.End(span, &err) - - events.Trace(ctx, events.AccessTokenIssued, - append(toEventOptions(requester), events.WithGrantType(requester.GetRequestForm().Get("grant_type")))..., - ) - - return p.createSession(ctx, x.SignatureHash(signature), requester, sqlTableAccess, requester.GetSession().GetExpiresAt(fosite.AccessToken).UTC()) -} - -func (p *Persister) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetAccessTokenSession", - trace.WithAttributes(events.AccessTokenSignature(signature)), - ) - defer otelx.End(span, &err) - - r := OAuth2RequestSQL{Table: sqlTableAccess} - err = p.QueryWithNetwork(ctx).Where("signature = ?", x.SignatureHash(signature)).First(&r) - if errors.Is(err, sql.ErrNoRows) { - // Backwards compatibility: we previously did not always hash the - // signature before inserting. In case there are still very old (but - // valid) access tokens in the database, this should get them. - err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r) - if errors.Is(err, sql.ErrNoRows) { - return nil, errors.WithStack(fosite.ErrNotFound) - } - } - if err != nil { - return nil, sqlcon.HandleError(err) - } - if !r.Active { - fr, err := r.toRequest(ctx, session, p) - if err != nil { - return nil, err - } - return fr, errors.WithStack(fosite.ErrInactiveToken) - } - - return r.toRequest(ctx, session, p) -} - -func (p *Persister) DeleteAccessTokenSession(ctx context.Context, signature string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokenSession", - trace.WithAttributes(events.AccessTokenSignature(signature)), - ) - defer otelx.End(span, &err) - - err = sqlcon.HandleError( - p.QueryWithNetwork(ctx). - Where("signature = ?", x.SignatureHash(signature)). - Delete(OAuth2RequestSQL{Table: sqlTableAccess}.TableName())) - if errors.Is(err, sqlcon.ErrNoRows) { - // Backwards compatibility: we previously did not always hash the - // signature before inserting. In case there are still very old (but - // valid) access tokens in the database, this should get them. - err = sqlcon.HandleError( - p.QueryWithNetwork(ctx). - Where("signature = ?", signature). - Delete(OAuth2RequestSQL{Table: sqlTableAccess}.TableName())) - if errors.Is(err, sqlcon.ErrNoRows) { - return errors.WithStack(fosite.ErrNotFound) - } - } - if errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return fosite.ErrSerializationFailure.WithWrap(err) - } - return err -} - -func toEventOptions(requester fosite.Requester) []trace.EventOption { - sub := "" - if requester.GetSession() != nil { - hash := sha256.Sum256([]byte(requester.GetSession().GetSubject())) - sub = hex.EncodeToString(hash[:]) - } - return []trace.EventOption{ - events.WithGrantType(requester.GetRequestForm().Get("grant_type")), - events.WithSubject(sub), - events.WithRequest(requester), - events.WithClientID(requester.GetClient().GetID()), - } -} - -func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature string, accessTokenSignature string, requester fosite.Requester) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateRefreshTokenSession", - trace.WithAttributes(events.RefreshTokenSignature(signature)), - ) - defer otelx.End(span, &err) - events.Trace(ctx, events.RefreshTokenIssued, toEventOptions(requester)...) - - req, err := p.sqlSchemaFromRequest(ctx, signature, requester, sqlTableRefresh, requester.GetSession().GetExpiresAt(fosite.RefreshToken).UTC()) - if err != nil { - return err - } - - var sig sql.NullString - if len(accessTokenSignature) > 0 { - sig = sql.NullString{ - Valid: true, - String: x.SignatureHash(accessTokenSignature), - } - } - - if err = sqlcon.HandleError(p.CreateWithNetwork(ctx, &OAuth2RefreshTable{ - OAuth2RequestSQL: *req, - AccessTokenSignature: sig, - })); errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return fosite.ErrSerializationFailure.WithWrap(err) - } else if err != nil { - return err - } - - return nil -} - -func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession", - trace.WithAttributes(events.RefreshTokenSignature(signature)), - ) - defer otelx.End(span, &err) - - var row OAuth2RefreshTable - if err := p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&row); errors.Is(err, sql.ErrNoRows) { - return nil, errors.WithStack(fosite.ErrNotFound) - } else if err != nil { - return nil, sqlcon.HandleError(err) - } - - if row.Active { - // Token is active - return row.toRequest(ctx, session, p) - } - - if graceful := p.r.Config().GracefulRefreshTokenRotation(ctx); graceful.Period > 0 && - row.FirstUsedAt.Valid && - row.FirstUsedAt.Time.Add(graceful.Period).After(time.Now()) && - (graceful.Count == 0 || // no limit - (row.UsedTimes.Int32 < graceful.Count)) { - // We return the request as is, which indicates that the token is active (because we are in the grace period still). - return row.toRequest(ctx, session, p) - } - - fositeRequest, err := row.toRequest(ctx, session, p) - if err != nil { - return nil, err - } - - return fositeRequest, errors.WithStack(fosite.ErrInactiveToken) -} - -func (p *Persister) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRefreshTokenSession", - trace.WithAttributes(events.RefreshTokenSignature(signature)), - ) - defer otelx.End(span, &err) - return p.deleteSessionBySignature(ctx, signature, sqlTableRefresh) -} - -func (p *Persister) CreateOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateOpenIDConnectSession") - defer otelx.End(span, &err) - events.Trace(ctx, events.IdentityTokenIssued, toEventOptions(requester)...) - // The expiry of an OIDC session is equal to the expiry of the authorization code. If the code is invalid, so is this OIDC request. - return p.createSession(ctx, signature, requester, sqlTableOpenID, requester.GetSession().GetExpiresAt(fosite.AuthorizeCode).UTC()) -} - -func (p *Persister) GetOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) (_ fosite.Requester, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetOpenIDConnectSession") - defer otelx.End(span, &err) - return p.findSessionBySignature(ctx, signature, requester.GetSession(), sqlTableOpenID) -} - -func (p *Persister) DeleteOpenIDConnectSession(ctx context.Context, signature string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteOpenIDConnectSession") - defer otelx.End(span, &err) - return p.deleteSessionBySignature(ctx, signature, sqlTableOpenID) -} - -func (p *Persister) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (_ fosite.Requester, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPKCERequestSession") - defer otelx.End(span, &err) - return p.findSessionBySignature(ctx, signature, session, sqlTablePKCE) -} - -func (p *Persister) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreatePKCERequestSession") - defer otelx.End(span, &err) - // The expiry of a PKCE session is equal to the expiry of the authorization code. If the code is invalid, so is this PKCE request. - return p.createSession(ctx, signature, requester, sqlTablePKCE, requester.GetSession().GetExpiresAt(fosite.AuthorizeCode).UTC()) -} - -func (p *Persister) DeletePKCERequestSession(ctx context.Context, signature string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeletePKCERequestSession") - defer otelx.End(span, &err) - return p.deleteSessionBySignature(ctx, signature, sqlTablePKCE) -} - -func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshToken", - trace.WithAttributes(events.ConsentRequestID(id)), - ) - defer otelx.End(span, &err) - return p.deleteSessionByRequestID(ctx, id, sqlTableRefresh) -} - -func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeAccessToken", - trace.WithAttributes(events.ConsentRequestID(id)), - ) - defer otelx.End(span, &err) - return p.deleteSessionByRequestID(ctx, id, sqlTableAccess) -} - -func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int, table tableName, lifespan time.Duration) (err error) { - /* #nosec G201 table is static */ - // The value of notAfter should be the minimum between input parameter and token max expire based on its configured age - requestMaxExpire := time.Now().Add(-lifespan) - if requestMaxExpire.Before(notAfter) { - notAfter = requestMaxExpire - } - - totalDeletedCount := 0 - for deletedRecords := batchSize; totalDeletedCount < limit && deletedRecords == batchSize; { - d := batchSize - if limit-totalDeletedCount < batchSize { - d = limit - totalDeletedCount - } - // Delete in batches - // The outer SELECT is necessary because our version of MySQL doesn't yet support 'LIMIT & IN/ALL/ANY/SOME subquery - deletedRecords, err = p.Connection(ctx).RawQuery( - fmt.Sprintf(`DELETE FROM %s WHERE signature in ( - SELECT signature FROM (SELECT signature FROM %s hoa WHERE requested_at < ? and nid = ? ORDER BY requested_at LIMIT %d ) as s - )`, OAuth2RequestSQL{Table: table}.TableName(), OAuth2RequestSQL{Table: table}.TableName(), d), - notAfter, - p.NetworkID(ctx), - ).ExecWithCount() - totalDeletedCount += deletedRecords - - if err != nil { - break - } - p.l.Debugf("Flushing tokens...: %d/%d", totalDeletedCount, limit) - } - p.l.Debugf("Flush Refresh Tokens flushed_records: %d", totalDeletedCount) - return sqlcon.HandleError(err) -} - -func (p *Persister) FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveAccessTokens") - defer otelx.End(span, &err) - return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableAccess, p.r.Config().GetAccessTokenLifespan(ctx)) -} - -func (p *Persister) FlushInactiveRefreshTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveRefreshTokens") - defer otelx.End(span, &err) - return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableRefresh, p.r.Config().GetRefreshTokenLifespan(ctx)) -} - -func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokens", - trace.WithAttributes(events.ClientID(clientID)), - ) - defer otelx.End(span, &err) - /* #nosec G201 table is static */ - return sqlcon.HandleError( - p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}), - ) -} - -func handleRetryError(err error) error { - if err == nil { - return nil - } - - if errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return fosite.ErrSerializationFailure.WithWrap(err) - } - if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock - return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) - } - return err -} - // strictRefreshRotation implements the strict refresh token rotation strategy. In strict rotation, we disable all // refresh and access tokens associated with a request ID and subsequently create the only valid, new token pair. func (p *Persister) strictRefreshRotation(ctx context.Context, requestID string) (err error) { @@ -780,6 +245,447 @@ WHERE signature = ? AND nid = ?` return nil } +func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, r fosite.Requester, table tableName, expiresAt time.Time) (*OAuth2RequestSQL, error) { + subject := "" + if r.GetSession() == nil { + p.l.Debugf("Got an empty session in sqlSchemaFromRequest") + } else { + subject = r.GetSession().GetSubject() + } + + var challenge sql.NullString + rr, ok := r.GetSession().(*oauth2.Session) + if !ok && r.GetSession() != nil { + return nil, errors.Errorf("Expected request to be of type *Session, but got: %T", r.GetSession()) + } else if ok { + if len(rr.ConsentChallenge) > 0 { + challenge = sql.NullString{Valid: true, String: rr.ConsentChallenge} + } + } + + session, err := json.Marshal(rr) + if err != nil { + return nil, errors.WithStack(err) + } + + if p.r.Config().EncryptSessionData(ctx) { + ciphertext, err := p.r.KeyCipher().Encrypt(ctx, session, nil) + if err != nil { + return nil, err + } + session = []byte(ciphertext) + } + + return &OAuth2RequestSQL{ + Request: r.GetID(), + ConsentChallenge: challenge, + ID: signature, + RequestedAt: r.GetRequestedAt(), + InternalExpiresAt: sqlxx.NullTime(expiresAt), + Client: r.GetClient().GetID(), + Scopes: strings.Join(r.GetRequestedScopes(), "|"), + GrantedScope: strings.Join(r.GetGrantedScopes(), "|"), + GrantedAudience: strings.Join(r.GetGrantedAudience(), "|"), + RequestedAudience: strings.Join(r.GetRequestedAudience(), "|"), + Form: r.GetRequestForm().Encode(), + Session: session, + Subject: subject, + Active: true, + Table: table, + }, nil +} + +func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName, expiresAt time.Time) error { + req, err := p.sqlSchemaFromRequest(ctx, signature, requester, table, expiresAt) + if err != nil { + return err + } + + if err = sqlcon.HandleError(p.CreateWithNetwork(ctx, req)); errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } else if err != nil { + return err + } + return nil +} + +func (p *Persister) findSessionBySignature(ctx context.Context, signature string, session fosite.Session, table tableName) (fosite.Requester, error) { + r := OAuth2RequestSQL{Table: table} + err := p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r) + if errors.Is(err, sql.ErrNoRows) { + return nil, errors.WithStack(fosite.ErrNotFound) + } + if err != nil { + return nil, sqlcon.HandleError(err) + } + if !r.Active { + fr, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, err + } + if table == sqlTableCode { + return fr, errors.WithStack(fosite.ErrInvalidatedAuthorizeCode) + } + return fr, errors.WithStack(fosite.ErrInactiveToken) + } + + return r.toRequest(ctx, session, p) +} + +func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) error { + err := sqlcon.HandleError( + p.QueryWithNetwork(ctx). + Where("signature = ?", signature). + Delete(OAuth2RequestSQL{Table: table}.TableName())) + if errors.Is(err, sqlcon.ErrNoRows) { + return errors.WithStack(fosite.ErrNotFound) + } + if errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } + return err +} + +func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, table tableName) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionByRequestID") + defer otelx.End(span, &err) + + err = p.QueryWithNetwork(ctx). + Where("request_id=?", id). + Delete(OAuth2RequestSQL{Table: table}.TableName()) + if errors.Is(err, sql.ErrNoRows) { + return errors.WithStack(fosite.ErrNotFound) + } + if err := sqlcon.HandleError(err); err != nil { + if errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } + if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock? + return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + } + return err + } + return nil +} + +func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int, table tableName, lifespan time.Duration) (err error) { + /* #nosec G201 table is static */ + // The value of notAfter should be the minimum between input parameter and token max expire based on its configured age + requestMaxExpire := time.Now().Add(-lifespan) + if requestMaxExpire.Before(notAfter) { + notAfter = requestMaxExpire + } + + totalDeletedCount := 0 + for deletedRecords := batchSize; totalDeletedCount < limit && deletedRecords == batchSize; { + d := batchSize + if limit-totalDeletedCount < batchSize { + d = limit - totalDeletedCount + } + // Delete in batches + // The outer SELECT is necessary because our version of MySQL doesn't yet support 'LIMIT & IN/ALL/ANY/SOME subquery + deletedRecords, err = p.Connection(ctx).RawQuery( + fmt.Sprintf(`DELETE FROM %s WHERE signature in ( + SELECT signature FROM (SELECT signature FROM %s hoa WHERE requested_at < ? and nid = ? ORDER BY requested_at LIMIT %d ) as s + )`, OAuth2RequestSQL{Table: table}.TableName(), OAuth2RequestSQL{Table: table}.TableName(), d), + notAfter, + p.NetworkID(ctx), + ).ExecWithCount() + totalDeletedCount += deletedRecords + + if err != nil { + break + } + p.l.Debugf("Flushing tokens...: %d/%d", totalDeletedCount, limit) + } + p.l.Debugf("Flush Refresh Tokens flushed_records: %d", totalDeletedCount) + return sqlcon.HandleError(err) +} + +func toEventOptions(requester fosite.Requester) []trace.EventOption { + sub := "" + if requester.GetSession() != nil { + hash := sha256.Sum256([]byte(requester.GetSession().GetSubject())) + sub = hex.EncodeToString(hash[:]) + } + return []trace.EventOption{ + events.WithGrantType(requester.GetRequestForm().Get("grant_type")), + events.WithSubject(sub), + events.WithRequest(requester), + events.WithClientID(requester.GetClient().GetID()), + } +} + +func handleRetryError(err error) error { + if err == nil { + return nil + } + + if errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } + if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock + return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + } + return err +} + +func newUsedExpiry() time.Time { + // Reuse detection is racy and would generally happen within seconds. Using 30 minutes here is a paranoid + // setting but ensures that we do not prematurely remove rows while they may still be needed (e.g. for reuse detection). + return time.Now().UTC().Round(time.Millisecond).Add(time.Minute * 30) +} + +// ClientAssertionJWTValid implements fosite.ClientManager +func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ClientAssertionJWTValid") + defer otelx.End(span, &err) + + j, err := p.GetClientAssertionJWT(ctx, jti) + if errors.Is(err, sqlcon.ErrNoRows) { + // the jti is not known => valid + return nil + } else if err != nil { + return err + } + if j.Expiry.After(time.Now()) { + // the jti is not expired yet => invalid + return errors.WithStack(fosite.ErrJTIKnown) + } + // the jti is expired => valid + return nil +} + +// SetClientAssertionJWT implements fosite.ClientManager +func (p *Persister) SetClientAssertionJWT(ctx context.Context, jti string, exp time.Time) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.SetClientAssertionJWT") + defer otelx.End(span, &err) + + // delete expired; this cleanup spares us the need for a background worker + if err := p.QueryWithNetwork(ctx).Where("expires_at < CURRENT_TIMESTAMP").Delete(&oauth2.BlacklistedJTI{}); err != nil { + return sqlcon.HandleError(err) + } + + if err := p.SetClientAssertionJWTRaw(ctx, oauth2.NewBlacklistedJTI(jti, exp)); errors.Is(err, sqlcon.ErrUniqueViolation) { + // found a jti + return errors.WithStack(fosite.ErrJTIKnown) + } else if err != nil { + return err + } + + // setting worked without a problem + return nil +} + +// GetClientAssertionJWT implements AssertionJWTReader +func (p *Persister) GetClientAssertionJWT(ctx context.Context, j string) (_ *oauth2.BlacklistedJTI, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetClientAssertionJWT") + defer otelx.End(span, &err) + + jti := oauth2.NewBlacklistedJTI(j, time.Time{}) + return jti, sqlcon.HandleError(p.QueryWithNetwork(ctx).Find(jti, jti.ID)) +} + +// SetClientAssertionJWTRaw implements AssertionJWTReader +func (p *Persister) SetClientAssertionJWTRaw(ctx context.Context, jti *oauth2.BlacklistedJTI) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.SetClientAssertionJWTRaw") + defer otelx.End(span, &err) + + return sqlcon.HandleError(p.CreateWithNetwork(ctx, jti)) +} + +// CreateAuthorizeCodeSession implements AuthorizeCodeStorage +func (p *Persister) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) error { + return otelx.WithSpan(ctx, "persistence.sql.CreateAuthorizeCodeSession", func(ctx context.Context) error { + return p.createSession(ctx, signature, requester, sqlTableCode, requester.GetSession().GetExpiresAt(fosite.AuthorizeCode).UTC()) + }) +} + +// GetAuthorizeCodeSession implements AuthorizeCodeStorage +func (p *Persister) GetAuthorizeCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetAuthorizeCodeSession") + defer otelx.End(span, &err) + + return p.findSessionBySignature(ctx, signature, session, sqlTableCode) +} + +// InvalidateAuthorizeCodeSession implements AuthorizeCodeStorage +func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InvalidateAuthorizeCodeSession") + defer otelx.End(span, &err) + + /* #nosec G201 table is static */ + return sqlcon.HandleError( + p.Connection(ctx). + RawQuery( + fmt.Sprintf( + "UPDATE %s SET active = false, expires_at = ? WHERE signature = ? AND nid = ?", + OAuth2RequestSQL{Table: sqlTableCode}.TableName(), + ), + // We don't expire immediately, but in 30 minutes to avoid prematurely removing + // rows while they may still be needed (e.g. for reuse detection). + newUsedExpiry(), + signature, + p.NetworkID(ctx), + ). + Exec(), + ) +} + +// CreateAccessTokenSession implements AccessTokenStorage +func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateAccessTokenSession", + trace.WithAttributes(events.AccessTokenSignature(signature)), + ) + defer otelx.End(span, &err) + + events.Trace(ctx, events.AccessTokenIssued, + append(toEventOptions(requester), events.WithGrantType(requester.GetRequestForm().Get("grant_type")))..., + ) + + return p.createSession(ctx, x.SignatureHash(signature), requester, sqlTableAccess, requester.GetSession().GetExpiresAt(fosite.AccessToken).UTC()) +} + +// GetAccessTokenSession implements AccessTokenStorage +func (p *Persister) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetAccessTokenSession", + trace.WithAttributes(events.AccessTokenSignature(signature)), + ) + defer otelx.End(span, &err) + + r := OAuth2RequestSQL{Table: sqlTableAccess} + err = p.QueryWithNetwork(ctx).Where("signature = ?", x.SignatureHash(signature)).First(&r) + if errors.Is(err, sql.ErrNoRows) { + // Backwards compatibility: we previously did not always hash the + // signature before inserting. In case there are still very old (but + // valid) access tokens in the database, this should get them. + err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r) + if errors.Is(err, sql.ErrNoRows) { + return nil, errors.WithStack(fosite.ErrNotFound) + } + } + if err != nil { + return nil, sqlcon.HandleError(err) + } + if !r.Active { + fr, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, err + } + return fr, errors.WithStack(fosite.ErrInactiveToken) + } + + return r.toRequest(ctx, session, p) +} + +// DeleteAccessTokenSession implements AccessTokenStorage +func (p *Persister) DeleteAccessTokenSession(ctx context.Context, signature string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokenSession", + trace.WithAttributes(events.AccessTokenSignature(signature)), + ) + defer otelx.End(span, &err) + + err = sqlcon.HandleError( + p.QueryWithNetwork(ctx). + Where("signature = ?", x.SignatureHash(signature)). + Delete(OAuth2RequestSQL{Table: sqlTableAccess}.TableName())) + if errors.Is(err, sqlcon.ErrNoRows) { + // Backwards compatibility: we previously did not always hash the + // signature before inserting. In case there are still very old (but + // valid) access tokens in the database, this should get them. + err = sqlcon.HandleError( + p.QueryWithNetwork(ctx). + Where("signature = ?", signature). + Delete(OAuth2RequestSQL{Table: sqlTableAccess}.TableName())) + if errors.Is(err, sqlcon.ErrNoRows) { + return errors.WithStack(fosite.ErrNotFound) + } + } + if errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } + return err +} + +// CreateRefreshTokenSession implements RefreshTokenStorage +func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature string, accessTokenSignature string, requester fosite.Requester) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateRefreshTokenSession", + trace.WithAttributes(events.RefreshTokenSignature(signature)), + ) + defer otelx.End(span, &err) + events.Trace(ctx, events.RefreshTokenIssued, toEventOptions(requester)...) + + req, err := p.sqlSchemaFromRequest(ctx, signature, requester, sqlTableRefresh, requester.GetSession().GetExpiresAt(fosite.RefreshToken).UTC()) + if err != nil { + return err + } + + var sig sql.NullString + if len(accessTokenSignature) > 0 { + sig = sql.NullString{ + Valid: true, + String: x.SignatureHash(accessTokenSignature), + } + } + + if err = sqlcon.HandleError(p.CreateWithNetwork(ctx, &OAuth2RefreshTable{ + OAuth2RequestSQL: *req, + AccessTokenSignature: sig, + })); errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } else if err != nil { + return err + } + + return nil +} + +// GetRefreshTokenSession implements RefreshTokenStorage +func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession", + trace.WithAttributes(events.RefreshTokenSignature(signature)), + ) + defer otelx.End(span, &err) + + var row OAuth2RefreshTable + if err := p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&row); errors.Is(err, sql.ErrNoRows) { + return nil, errors.WithStack(fosite.ErrNotFound) + } else if err != nil { + return nil, sqlcon.HandleError(err) + } + + if row.Active { + // Token is active + return row.toRequest(ctx, session, p) + } + + if graceful := p.r.Config().GracefulRefreshTokenRotation(ctx); graceful.Period > 0 && + row.FirstUsedAt.Valid && + row.FirstUsedAt.Time.Add(graceful.Period).After(time.Now()) && + (graceful.Count == 0 || // no limit + (row.UsedTimes.Int32 < graceful.Count)) { + // We return the request as is, which indicates that the token is active (because we are in the grace period still). + return row.toRequest(ctx, session, p) + } + + fositeRequest, err := row.toRequest(ctx, session, p) + if err != nil { + return nil, err + } + + return fositeRequest, errors.WithStack(fosite.ErrInactiveToken) +} + +// DeleteRefreshTokenSession implements RefreshTokenStorage +func (p *Persister) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRefreshTokenSession", + trace.WithAttributes(events.RefreshTokenSignature(signature)), + ) + defer otelx.End(span, &err) + return p.deleteSessionBySignature(ctx, signature, sqlTableRefresh) +} + +// RotateRefreshToken implements RefreshTokenStorage func (p *Persister) RotateRefreshToken(ctx context.Context, requestID, refreshTokenSignature string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RotateRefreshToken") defer otelx.End(span, &err) @@ -792,8 +698,91 @@ func (p *Persister) RotateRefreshToken(ctx context.Context, requestID, refreshTo return handleRetryError(p.strictRefreshRotation(ctx, requestID)) } -func newUsedExpiry() time.Time { - // Reuse detection is racy and would generally happen within seconds. Using 30 minutes here is a paranoid - // setting but ensures that we do not prematurely remove rows while they may still be needed (e.g. for reuse detection). - return time.Now().UTC().Round(time.Millisecond).Add(time.Minute * 30) +// CreateOpenIDConnectSession implements OpenIDConnectRequestStorage +func (p *Persister) CreateOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateOpenIDConnectSession") + defer otelx.End(span, &err) + events.Trace(ctx, events.IdentityTokenIssued, toEventOptions(requester)...) + // The expiry of an OIDC session is equal to the expiry of the authorization code. If the code is invalid, so is this OIDC request. + return p.createSession(ctx, signature, requester, sqlTableOpenID, requester.GetSession().GetExpiresAt(fosite.AuthorizeCode).UTC()) +} + +// GetOpenIDConnectSession implements OpenIDConnectRequestStorage +func (p *Persister) GetOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) (_ fosite.Requester, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetOpenIDConnectSession") + defer otelx.End(span, &err) + return p.findSessionBySignature(ctx, signature, requester.GetSession(), sqlTableOpenID) +} + +// DeleteOpenIDConnectSession implements OpenIDConnectRequestStorage +func (p *Persister) DeleteOpenIDConnectSession(ctx context.Context, signature string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteOpenIDConnectSession") + defer otelx.End(span, &err) + return p.deleteSessionBySignature(ctx, signature, sqlTableOpenID) +} + +// GetPKCERequestSession implements PKCERequestStorage +func (p *Persister) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (_ fosite.Requester, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPKCERequestSession") + defer otelx.End(span, &err) + return p.findSessionBySignature(ctx, signature, session, sqlTablePKCE) +} + +// CreatePKCERequestSession implements PKCERequestStorage +func (p *Persister) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreatePKCERequestSession") + defer otelx.End(span, &err) + // The expiry of a PKCE session is equal to the expiry of the authorization code. If the code is invalid, so is this PKCE request. + return p.createSession(ctx, signature, requester, sqlTablePKCE, requester.GetSession().GetExpiresAt(fosite.AuthorizeCode).UTC()) +} + +// DeletePKCERequestSession implements PKCERequestStorage +func (p *Persister) DeletePKCERequestSession(ctx context.Context, signature string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeletePKCERequestSession") + defer otelx.End(span, &err) + return p.deleteSessionBySignature(ctx, signature, sqlTablePKCE) +} + +// RevokeRefreshToken implements TokenRevocationStorage +func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshToken", + trace.WithAttributes(events.ConsentRequestID(id)), + ) + defer otelx.End(span, &err) + return p.deleteSessionByRequestID(ctx, id, sqlTableRefresh) +} + +// RevokeAccessToken implements TokenRevocationStorage +func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeAccessToken", + trace.WithAttributes(events.ConsentRequestID(id)), + ) + defer otelx.End(span, &err) + return p.deleteSessionByRequestID(ctx, id, sqlTableAccess) +} + +// FlushInactiveAccessTokens implements FositeStorer +func (p *Persister) FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveAccessTokens") + defer otelx.End(span, &err) + return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableAccess, p.r.Config().GetAccessTokenLifespan(ctx)) +} + +// FlushInactiveRefreshTokens implements FositeStorer +func (p *Persister) FlushInactiveRefreshTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveRefreshTokens") + defer otelx.End(span, &err) + return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableRefresh, p.r.Config().GetRefreshTokenLifespan(ctx)) +} + +// DeleteAccessTokens implements FositeStorer +func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokens", + trace.WithAttributes(events.ClientID(clientID)), + ) + defer otelx.End(span, &err) + /* #nosec G201 table is static */ + return sqlcon.HandleError( + p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}), + ) } diff --git a/x/fosite_storer.go b/x/fosite_storer.go index 57aca46bb..2fab0cd45 100644 --- a/x/fosite_storer.go +++ b/x/fosite_storer.go @@ -17,7 +17,7 @@ import ( ) type FositeStorer interface { - fosite.Storage + fosite.ClientManager oauth2.AuthorizeCodeStorage oauth2.AccessTokenStorage oauth2.RefreshTokenStorage @@ -29,6 +29,7 @@ type FositeStorer interface { verifiable.NonceManager oauth2.ResourceOwnerPasswordCredentialsGrantStorage + // Hydra-specific storage utilities // flush the access token requests from the database. // no data will be deleted after the 'notAfter' timeframe. FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) error @@ -44,8 +45,9 @@ type FositeStorer interface { // DeleteOpenIDConnectSession deletes an OpenID Connect session. // This is duplicated from Ory Fosite to help against deprecation linting errors. - DeleteOpenIDConnectSession(ctx context.Context, authorizeCode string) error + // DeleteOpenIDConnectSession(ctx context.Context, authorizeCode string) error + // Hydra-specific RFC8628 Device Auth capabilities GetUserCodeSession(context.Context, string, fosite.Session) (fosite.DeviceRequester, error) GetDeviceCodeSessionByRequestID(ctx context.Context, requestID string, requester fosite.Session) (fosite.DeviceRequester, string, error) UpdateDeviceCodeSessionBySignature(ctx context.Context, requestID string, requester fosite.DeviceRequester) error