// Copyright © 2022 Ory Corp // SPDX-License-Identifier: Apache-2.0 package oauth2_test import ( "context" "fmt" "net/url" "slices" "testing" "time" "github.com/go-jose/go-jose/v3" "github.com/gofrs/uuid" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver" "github.com/ory/hydra/v2/driver/config" "github.com/ory/hydra/v2/fosite" "github.com/ory/hydra/v2/fosite/handler/openid" "github.com/ory/hydra/v2/fosite/handler/rfc7523" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/oauth2" "github.com/ory/hydra/v2/oauth2/trust" "github.com/ory/hydra/v2/persistence/sql" "github.com/ory/hydra/v2/x" "github.com/ory/x/assertx" "github.com/ory/x/sqlcon" "github.com/ory/x/sqlxx" ) var defaultIgnoreKeys = []string{ "id", "session", "requested_scope", "granted_scope", "form", "created_at", "updated_at", "client.created_at", "client.updated_at", "requestedAt", "client.client_secret", } func newDefaultRequest(t testing.TB, id string) *fosite.Request { return &fosite.Request{ ID: id, RequestedAt: time.Now().UTC().Round(time.Second), Client: &client.Client{ ID: "foobar", Contacts: []string{}, RedirectURIs: []string{}, Audience: []string{}, AllowedCORSOrigins: []string{}, ResponseTypes: []string{}, GrantTypes: []string{}, JSONWebKeys: &x.JoseJSONWebKeySet{}, Metadata: sqlxx.JSONRawMessage("{}"), }, RequestedScope: fosite.Arguments{"fa", "ba"}, GrantedScope: fosite.Arguments{"fa", "ba"}, RequestedAudience: fosite.Arguments{"ad1", "ad2"}, GrantedAudience: fosite.Arguments{"ad1", "ad2"}, Form: url.Values{"foo": []string{"bar", "baz"}}, Session: oauth2.NewTestSession(t, "bar"), } } // var lifespan = time.Hour var flushRequests = []*fosite.Request{ { ID: "flush-1", RequestedAt: time.Now().Round(time.Second), Client: &client.Client{ID: "foobar"}, RequestedScope: fosite.Arguments{"fa", "ba"}, GrantedScope: fosite.Arguments{"fa", "ba"}, Form: url.Values{"foo": []string{"bar", "baz"}}, Session: &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, }, { ID: "flush-2", RequestedAt: time.Now().Round(time.Second).Add(-(lifespan + time.Minute)), Client: &client.Client{ID: "foobar"}, RequestedScope: fosite.Arguments{"fa", "ba"}, GrantedScope: fosite.Arguments{"fa", "ba"}, Form: url.Values{"foo": []string{"bar", "baz"}}, Session: &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, }, { ID: "flush-3", RequestedAt: time.Now().Round(time.Second).Add(-(lifespan + time.Hour)), Client: &client.Client{ID: "foobar"}, RequestedScope: fosite.Arguments{"fa", "ba"}, GrantedScope: fosite.Arguments{"fa", "ba"}, Form: url.Values{"foo": []string{"bar", "baz"}}, Session: &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, }, } func mockRequestForeignKey(t *testing.T, _ string, x *driver.RegistrySQL) { cl := &client.Client{ID: "foobar"} if _, err := x.ClientManager().GetClient(t.Context(), cl.ID); errors.Is(err, sqlcon.ErrNoRows) { require.NoError(t, x.ClientManager().CreateClient(t.Context(), cl)) } } func testHelperRequestIDMultiples(m *driver.RegistrySQL, _ string) func(t *testing.T) { return func(t *testing.T) { ctx := t.Context() requestID := uuid.Must(uuid.NewV4()).String() mockRequestForeignKey(t, requestID, m) cl := &client.Client{ID: "foobar"} fositeRequest := &fosite.Request{ ID: requestID, Client: cl, RequestedAt: time.Now().UTC().Round(time.Second), Session: oauth2.NewTestSession(t, "bar"), } for range 4 { signature := uuid.Must(uuid.NewV4()).String() accessSignature := uuid.Must(uuid.NewV4()).String() err := m.OAuth2Storage().CreateRefreshTokenSession(ctx, signature, accessSignature, fositeRequest) assert.NoError(t, err) err = m.OAuth2Storage().CreateAccessTokenSession(ctx, signature, fositeRequest) assert.NoError(t, err) err = m.OAuth2Storage().CreateOpenIDConnectSession(ctx, signature, fositeRequest) assert.NoError(t, err) err = m.OAuth2Storage().CreatePKCERequestSession(ctx, signature, fositeRequest) assert.NoError(t, err) err = m.OAuth2Storage().CreateAuthorizeCodeSession(ctx, signature, fositeRequest) assert.NoError(t, err) } } } func testHelperCreateGetDeleteOpenIDConnectSession(x *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() code := uuid.Must(uuid.NewV4()).String() ctx := t.Context() _, err := m.GetOpenIDConnectSession(ctx, code, &fosite.Request{Session: oauth2.NewTestSession(t, "bar")}) assert.NotNil(t, err) err = m.CreateOpenIDConnectSession(ctx, code, newDefaultRequest(t, "blank")) require.NoError(t, err) res, err := m.GetOpenIDConnectSession(ctx, code, &fosite.Request{Session: oauth2.NewTestSession(t, "bar")}) require.NoError(t, err) AssertObjectKeysEqual(t, newDefaultRequest(t, "blank"), res, "RequestedScope", "GrantedScope", "Form", "Session") err = m.DeleteOpenIDConnectSession(ctx, code) require.NoError(t, err) _, err = m.GetOpenIDConnectSession(ctx, code, &fosite.Request{Session: oauth2.NewTestSession(t, "bar")}) assert.NotNil(t, err) } } func testHelperCreateGetDeleteRefreshTokenSession(x *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() code := uuid.Must(uuid.NewV4()).String() ctx := t.Context() _, err := m.GetRefreshTokenSession(ctx, code, oauth2.NewTestSession(t, "bar")) assert.NotNil(t, err) err = m.CreateRefreshTokenSession(ctx, code, "", newDefaultRequest(t, "blank")) require.NoError(t, err) res, err := m.GetRefreshTokenSession(ctx, code, oauth2.NewTestSession(t, "bar")) require.NoError(t, err) AssertObjectKeysEqual(t, newDefaultRequest(t, "blank"), res, "RequestedScope", "GrantedScope", "Form", "Session") err = m.DeleteRefreshTokenSession(ctx, code) require.NoError(t, err) _, err = m.GetRefreshTokenSession(ctx, code, oauth2.NewTestSession(t, "bar")) assert.NotNil(t, err) } } func testHelperRevokeRefreshToken(x *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() ctx := t.Context() _, err := m.GetRefreshTokenSession(ctx, "1111", oauth2.NewTestSession(t, "bar")) assert.Error(t, err) reqIdOne := uuid.Must(uuid.NewV4()).String() reqIdTwo := uuid.Must(uuid.NewV4()).String() mockRequestForeignKey(t, reqIdOne, x) mockRequestForeignKey(t, reqIdTwo, x) err = m.CreateRefreshTokenSession(ctx, "1111", "", &fosite.Request{ ID: reqIdOne, Client: &client.Client{ID: "foobar"}, RequestedAt: time.Now().UTC().Round(time.Second), Session: oauth2.NewTestSession(t, "user"), }) require.NoError(t, err) err = m.CreateRefreshTokenSession(ctx, "1122", "", &fosite.Request{ ID: reqIdTwo, Client: &client.Client{ID: "foobar"}, RequestedAt: time.Now().UTC().Round(time.Second), Session: oauth2.NewTestSession(t, "user"), }) require.NoError(t, err) _, err = m.GetRefreshTokenSession(ctx, "1111", oauth2.NewTestSession(t, "bar")) require.NoError(t, err) err = m.RevokeRefreshToken(ctx, reqIdOne) require.NoError(t, err) err = m.RevokeRefreshToken(ctx, reqIdTwo) require.NoError(t, err) req, err := m.GetRefreshTokenSession(ctx, "1111", oauth2.NewTestSession(t, "bar")) assert.Nil(t, req) assert.EqualError(t, err, fosite.ErrNotFound.Error()) req, err = m.GetRefreshTokenSession(ctx, "1122", oauth2.NewTestSession(t, "bar")) assert.Nil(t, req) assert.EqualError(t, err, fosite.ErrNotFound.Error()) } } func testHelperCreateGetDeleteAuthorizeCodes(x *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() mockRequestForeignKey(t, "blank", x) code := uuid.Must(uuid.NewV4()).String() ctx := t.Context() res, err := m.GetAuthorizeCodeSession(ctx, code, oauth2.NewTestSession(t, "bar")) assert.Error(t, err) assert.Nil(t, res) err = m.CreateAuthorizeCodeSession(ctx, code, newDefaultRequest(t, "blank")) require.NoError(t, err) res, err = m.GetAuthorizeCodeSession(ctx, code, oauth2.NewTestSession(t, "bar")) require.NoError(t, err) AssertObjectKeysEqual(t, newDefaultRequest(t, "blank"), res, "RequestedScope", "GrantedScope", "Form", "Session") err = m.InvalidateAuthorizeCodeSession(ctx, code) require.NoError(t, err) res, err = m.GetAuthorizeCodeSession(ctx, code, oauth2.NewTestSession(t, "bar")) require.Error(t, err) assert.EqualError(t, err, fosite.ErrInvalidatedAuthorizeCode.Error()) assert.NotNil(t, res) } } type testHelperExpiryFieldsResult struct { ExpiresAt time.Time `db:"expires_at"` name string } func (r testHelperExpiryFieldsResult) TableName() string { return "hydra_oauth2_" + r.name } func testHelperExpiryFields(reg *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { m := reg.OAuth2Storage() t.Parallel() mockRequestForeignKey(t, "blank", reg) ctx := t.Context() s := oauth2.NewTestSession(t, "bar") s.SetExpiresAt(fosite.AccessToken, time.Now().Add(time.Hour).Round(time.Minute)) s.SetExpiresAt(fosite.RefreshToken, time.Now().Add(time.Hour*2).Round(time.Minute)) s.SetExpiresAt(fosite.AuthorizeCode, time.Now().Add(time.Hour*3).Round(time.Minute)) request := fosite.Request{ ID: uuid.Must(uuid.NewV4()).String(), RequestedAt: time.Now().UTC().Round(time.Second), Client: &client.Client{ ID: "foobar", Metadata: sqlxx.JSONRawMessage("{}"), }, RequestedScope: fosite.Arguments{"fa", "ba"}, GrantedScope: fosite.Arguments{"fa", "ba"}, RequestedAudience: fosite.Arguments{"ad1", "ad2"}, GrantedAudience: fosite.Arguments{"ad1", "ad2"}, Form: url.Values{"foo": []string{"bar", "baz"}}, Session: s, } t.Run("case=CreateAccessTokenSession", func(t *testing.T) { id := uuid.Must(uuid.NewV4()).String() err := m.CreateAccessTokenSession(ctx, id, &request) require.NoError(t, err) r := testHelperExpiryFieldsResult{name: "access"} require.NoError(t, reg.Persister().Connection(ctx).Select("expires_at").Where("signature = ?", x.SignatureHash(id)).First(&r)) assert.EqualValues(t, s.GetExpiresAt(fosite.AccessToken).UTC(), r.ExpiresAt.UTC()) }) t.Run("case=CreateRefreshTokenSession", func(t *testing.T) { id := uuid.Must(uuid.NewV4()).String() err := m.CreateRefreshTokenSession(ctx, id, "", &request) require.NoError(t, err) r := testHelperExpiryFieldsResult{name: "refresh"} require.NoError(t, reg.Persister().Connection(ctx).Select("expires_at").Where("signature = ?", id).First(&r)) assert.EqualValues(t, s.GetExpiresAt(fosite.RefreshToken).UTC(), r.ExpiresAt.UTC()) }) t.Run("case=CreateAuthorizeCodeSession", func(t *testing.T) { id := uuid.Must(uuid.NewV4()).String() err := m.CreateAuthorizeCodeSession(ctx, id, &request) require.NoError(t, err) r := testHelperExpiryFieldsResult{name: "code"} require.NoError(t, reg.Persister().Connection(ctx).Select("expires_at").Where("signature = ?", id).First(&r)) assert.EqualValues(t, s.GetExpiresAt(fosite.AuthorizeCode).UTC(), r.ExpiresAt.UTC()) }) t.Run("case=CreatePKCERequestSession", func(t *testing.T) { id := uuid.Must(uuid.NewV4()).String() err := m.CreatePKCERequestSession(ctx, id, &request) require.NoError(t, err) r := testHelperExpiryFieldsResult{name: "pkce"} require.NoError(t, reg.Persister().Connection(ctx).Select("expires_at").Where("signature = ?", id).First(&r)) assert.EqualValues(t, s.GetExpiresAt(fosite.AuthorizeCode).UTC(), r.ExpiresAt.UTC()) }) t.Run("case=CreateOpenIDConnectSession", func(t *testing.T) { id := uuid.Must(uuid.NewV4()).String() err := m.CreateOpenIDConnectSession(ctx, id, &request) require.NoError(t, err) r := testHelperExpiryFieldsResult{name: "oidc"} require.NoError(t, reg.Persister().Connection(ctx).Select("expires_at").Where("signature = ?", id).First(&r)) assert.EqualValues(t, s.GetExpiresAt(fosite.AuthorizeCode).UTC(), r.ExpiresAt.UTC()) }) } } func testHelperNilAccessToken(x *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() c := &client.Client{ID: uuid.Must(uuid.NewV4()).String()} require.NoError(t, x.ClientManager().CreateClient(context.Background(), c)) err := m.CreateAccessTokenSession(context.Background(), uuid.Must(uuid.NewV4()).String(), &fosite.Request{ ID: "", RequestedAt: time.Now().UTC().Round(time.Second), Client: c, RequestedScope: fosite.Arguments{"fa", "ba"}, GrantedScope: fosite.Arguments{"fa", "ba"}, RequestedAudience: fosite.Arguments{"ad1", "ad2"}, GrantedAudience: fosite.Arguments{"ad1", "ad2"}, Form: url.Values{"foo": []string{"bar", "baz"}}, Session: oauth2.NewTestSession(t, "bar"), }) require.NoError(t, err) } } func testHelperCreateGetDeleteAccessTokenSession(x *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() code := uuid.Must(uuid.NewV4()).String() ctx := t.Context() _, err := m.GetAccessTokenSession(ctx, code, oauth2.NewTestSession(t, "bar")) assert.Error(t, err) err = m.CreateAccessTokenSession(ctx, code, newDefaultRequest(t, "blank")) require.NoError(t, err) res, err := m.GetAccessTokenSession(ctx, code, oauth2.NewTestSession(t, "bar")) require.NoError(t, err) AssertObjectKeysEqual(t, newDefaultRequest(t, "blank"), res, "RequestedScope", "GrantedScope", "Form", "Session") err = m.DeleteAccessTokenSession(ctx, code) require.NoError(t, err) _, err = m.GetAccessTokenSession(ctx, code, oauth2.NewTestSession(t, "bar")) assert.Error(t, err) } } func testHelperDeleteAccessTokens(x *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() ctx := t.Context() code := uuid.Must(uuid.NewV4()).String() err := m.CreateAccessTokenSession(ctx, code, newDefaultRequest(t, "blank")) require.NoError(t, err) _, err = m.GetAccessTokenSession(ctx, code, oauth2.NewTestSession(t, "bar")) require.NoError(t, err) err = m.DeleteAccessTokens(ctx, newDefaultRequest(t, "blank").Client.GetID()) require.NoError(t, err) req, err := m.GetAccessTokenSession(ctx, code, oauth2.NewTestSession(t, "bar")) assert.Nil(t, req) assert.EqualError(t, err, fosite.ErrNotFound.Error()) } } func testHelperRevokeAccessToken(x *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() ctx := t.Context() code := uuid.Must(uuid.NewV4()).String() err := m.CreateAccessTokenSession(ctx, code, newDefaultRequest(t, "blank")) require.NoError(t, err) _, err = m.GetAccessTokenSession(ctx, code, oauth2.NewTestSession(t, "bar")) require.NoError(t, err) err = m.RevokeAccessToken(ctx, newDefaultRequest(t, "blank").GetID()) require.NoError(t, err) req, err := m.GetAccessTokenSession(ctx, code, oauth2.NewTestSession(t, "bar")) assert.Nil(t, req) assert.EqualError(t, err, fosite.ErrNotFound.Error()) } } func testHelperRotateRefreshToken(x *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { ctx := t.Context() createTokens := func(t *testing.T, r *fosite.Request) (refreshTokenSession string, accessTokenSession string) { refreshTokenSession = fmt.Sprintf("refresh_token_%s", uuid.Must(uuid.NewV4()).String()) accessTokenSession = fmt.Sprintf("access_token_%s", uuid.Must(uuid.NewV4()).String()) err := x.OAuth2Storage().CreateAccessTokenSession(ctx, accessTokenSession, r) require.NoError(t, err) err = x.OAuth2Storage().CreateRefreshTokenSession(ctx, refreshTokenSession, accessTokenSession, r) require.NoError(t, err) // Sanity check req, err := x.OAuth2Storage().GetRefreshTokenSession(ctx, refreshTokenSession, nil) require.NoError(t, err) require.EqualValues(t, r.GetID(), req.GetID()) req, err = x.OAuth2Storage().GetAccessTokenSession(ctx, accessTokenSession, nil) require.NoError(t, err) require.EqualValues(t, r.GetID(), req.GetID()) return } t.Run("Revokes refresh token when grace period not configured", func(t *testing.T) { m := x.OAuth2Storage() r := newDefaultRequest(t, uuid.Must(uuid.NewV4()).String()) refreshTokenSession, accessTokenSession := createTokens(t, r) err := m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession) require.NoError(t, err) _, err = m.GetAccessTokenSession(ctx, accessTokenSession, nil) assert.ErrorIs(t, err, fosite.ErrNotFound, "Token is no longer active because it was refreshed") _, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) assert.ErrorIs(t, err, fosite.ErrInactiveToken, "Token is no longer active because it was refreshed") }) t.Run("Rotation works when access token is already pruned", func(t *testing.T) { // Test both with and without grace period testCases := []struct { name string configureGrace bool expectTokenActive bool }{ { name: "with grace period", configureGrace: true, expectTokenActive: true, }, { name: "without grace period", configureGrace: false, expectTokenActive: false, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { if tc.configureGrace { x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") } else { x.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod) } t.Cleanup(func() { x.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod) }) m := x.OAuth2Storage() r := newDefaultRequest(t, uuid.Must(uuid.NewV4()).String()) // Create tokens refreshTokenSession := fmt.Sprintf("refresh_token_%s", uuid.Must(uuid.NewV4()).String()) accessTokenSession := fmt.Sprintf("access_token_%s", uuid.Must(uuid.NewV4()).String()) // Create access token err := m.CreateAccessTokenSession(ctx, accessTokenSession, r) require.NoError(t, err) // Create refresh token linked to the access token err = m.CreateRefreshTokenSession(ctx, refreshTokenSession, accessTokenSession, r) require.NoError(t, err) // Verify tokens were created successfully req, err := m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) require.NoError(t, err) require.Equal(t, r.GetID(), req.GetID()) req, err = m.GetAccessTokenSession(ctx, accessTokenSession, nil) require.NoError(t, err) require.Equal(t, r.GetID(), req.GetID()) // Delete the access token (simulating it being pruned) err = m.DeleteAccessTokenSession(ctx, accessTokenSession) require.NoError(t, err) // Verify access token is gone _, err = m.GetAccessTokenSession(ctx, accessTokenSession, nil) assert.Error(t, err) // Rotation should still work even though the access token is gone err = m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession) require.NoError(t, err) // Check refresh token state based on grace period configuration req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) if tc.expectTokenActive { assert.NoError(t, err) assert.Equal(t, r.GetID(), req.GetID()) } else { assert.ErrorIs(t, err, fosite.ErrInactiveToken, "Token should be inactive when no grace period is configured") } }) } }) t.Run("refresh token is valid until the grace period has ended", func(t *testing.T) { x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") // By setting this to one hour we ensure that using the refresh token triggers the start of the grace period. x.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1h") t.Cleanup(func() { x.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod) }) m := x.OAuth2Storage() r := newDefaultRequest(t, uuid.Must(uuid.NewV4()).String()) refreshTokenSession, accessTokenSession1 := createTokens(t, r) accessTokenSession2 := fmt.Sprintf("access_token_%s", uuid.Must(uuid.NewV4()).String()) require.NoError(t, m.CreateAccessTokenSession(ctx, accessTokenSession2, r)) // Create a second access token require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) _, err := m.GetAccessTokenSession(ctx, accessTokenSession1, nil) assert.ErrorIs(t, err, fosite.ErrNotFound) _, err = m.GetAccessTokenSession(ctx, accessTokenSession2, nil) assert.NoError(t, err, "The second access token is still valid.") req, err := m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) assert.NoError(t, err) assert.Equal(t, r.GetID(), req.GetID()) // We only wait a second, meaning that the token is theoretically still within TTL, but since the // grace period was issued, the token is still valid. time.Sleep(time.Second * 2) _, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) assert.Error(t, err) }) t.Run("the used at time does not change", func(t *testing.T) { x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") // By setting this to one hour we ensure that using the refresh token triggers the start of the grace period. x.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, "1h") t.Cleanup(func() { x.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod) }) m := x.OAuth2Storage() r := newDefaultRequest(t, uuid.Must(uuid.NewV4()).String()) refreshTokenSession, _ := createTokens(t, r) require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) var expected sql.OAuth2RefreshTable require.NoError(t, x.Persister().Connection(ctx).Where("signature=?", refreshTokenSession).First(&expected)) assert.False(t, expected.FirstUsedAt.Time.IsZero()) assert.True(t, expected.FirstUsedAt.Valid) // Refresh does not change the time time.Sleep(time.Second * 2) require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) var actual sql.OAuth2RefreshTable require.NoError(t, x.Persister().Connection(ctx).Where("signature=?", refreshTokenSession).First(&actual)) assert.Equal(t, expected.FirstUsedAt.Time, actual.FirstUsedAt.Time) }) t.Run("refresh token revokes all access tokens from the request if the access token signature is not found", func(t *testing.T) { x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "1s") t.Cleanup(func() { x.Config().Delete(ctx, config.KeyRefreshTokenRotationGracePeriod) }) m := x.OAuth2Storage() r := newDefaultRequest(t, uuid.Must(uuid.NewV4()).String()) refreshTokenSession := fmt.Sprintf("refresh_token_%s", uuid.Must(uuid.NewV4()).String()) accessTokenSession1 := fmt.Sprintf("access_token_%s", uuid.Must(uuid.NewV4()).String()) accessTokenSession2 := fmt.Sprintf("access_token_%s", uuid.Must(uuid.NewV4()).String()) require.NoError(t, m.CreateAccessTokenSession(ctx, accessTokenSession1, r)) require.NoError(t, m.CreateAccessTokenSession(ctx, accessTokenSession2, r)) require.NoError(t, m.CreateRefreshTokenSession(ctx, refreshTokenSession, "", r), "precondition failed: could not create refresh token session") // ACT require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) require.NoError(t, m.RotateRefreshToken(ctx, r.GetID(), refreshTokenSession)) _, err := m.GetAccessTokenSession(ctx, accessTokenSession1, nil) assert.ErrorIs(t, err, fosite.ErrNotFound) _, err = m.GetAccessTokenSession(ctx, accessTokenSession2, nil) assert.ErrorIs(t, err, fosite.ErrNotFound) req, err := m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) assert.NoError(t, err) assert.Equal(t, r.GetID(), req.GetID()) time.Sleep(time.Second * 2) req, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, nil) assert.Error(t, err) }) } } func testHelperCreateGetDeletePKCERequestSession(x *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() code := uuid.Must(uuid.NewV4()).String() ctx := t.Context() _, err := m.GetPKCERequestSession(ctx, code, oauth2.NewTestSession(t, "bar")) assert.NotNil(t, err) err = m.CreatePKCERequestSession(ctx, code, newDefaultRequest(t, "blank")) require.NoError(t, err) res, err := m.GetPKCERequestSession(ctx, code, oauth2.NewTestSession(t, "bar")) require.NoError(t, err) AssertObjectKeysEqual(t, newDefaultRequest(t, "blank"), res, "RequestedScope", "GrantedScope", "Form", "Session") err = m.DeletePKCERequestSession(ctx, code) require.NoError(t, err) _, err = m.GetPKCERequestSession(ctx, code, oauth2.NewTestSession(t, "bar")) assert.NotNil(t, err) } } func testHelperFlushTokens(x *driver.RegistrySQL, lifespan time.Duration) func(t *testing.T) { m := x.OAuth2Storage() ds := &oauth2.Session{} return func(t *testing.T) { ctx := t.Context() for _, r := range flushRequests { mockRequestForeignKey(t, r.ID, x) require.NoError(t, m.CreateAccessTokenSession(ctx, r.ID, r)) _, err := m.GetAccessTokenSession(ctx, r.ID, ds) require.NoError(t, err) } require.NoError(t, m.FlushInactiveAccessTokens(ctx, time.Now().Add(-time.Hour*24), 100, 10)) _, err := m.GetAccessTokenSession(ctx, "flush-1", ds) require.NoError(t, err) _, err = m.GetAccessTokenSession(ctx, "flush-2", ds) require.NoError(t, err) _, err = m.GetAccessTokenSession(ctx, "flush-3", ds) require.NoError(t, err) require.NoError(t, m.FlushInactiveAccessTokens(ctx, time.Now().Add(-(lifespan+time.Hour/2)), 100, 10)) _, err = m.GetAccessTokenSession(ctx, "flush-1", ds) require.NoError(t, err) _, err = m.GetAccessTokenSession(ctx, "flush-2", ds) require.NoError(t, err) _, err = m.GetAccessTokenSession(ctx, "flush-3", ds) require.Error(t, err) require.NoError(t, m.FlushInactiveAccessTokens(ctx, time.Now(), 100, 10)) _, err = m.GetAccessTokenSession(ctx, "flush-1", ds) require.NoError(t, err) _, err = m.GetAccessTokenSession(ctx, "flush-2", ds) require.Error(t, err) _, err = m.GetAccessTokenSession(ctx, "flush-3", ds) require.Error(t, err) require.NoError(t, m.DeleteAccessTokens(ctx, "foobar")) } } func testHelperFlushTokensWithLimitAndBatchSize(x *driver.RegistrySQL, limit int, batchSize int) func(t *testing.T) { m := x.OAuth2Storage() ds := &oauth2.Session{} return func(t *testing.T) { ctx := t.Context() var requests []*fosite.Request // create five expired requests id := uuid.Must(uuid.NewV4()).String() totalCount := 5 for i := 0; i < totalCount; i++ { r := createTestRequest(fmt.Sprintf("%s-%d", id, i+1)) r.RequestedAt = time.Now().Add(-2 * time.Hour) mockRequestForeignKey(t, r.ID, x) require.NoError(t, m.CreateAccessTokenSession(ctx, r.ID, r)) _, err := m.GetAccessTokenSession(ctx, r.ID, ds) require.NoError(t, err) requests = append(requests, r) } require.NoError(t, m.FlushInactiveAccessTokens(ctx, time.Now(), limit, batchSize)) var notFoundCount, foundCount int for i := range requests { if _, err := m.GetAccessTokenSession(ctx, requests[i].ID, ds); err == nil { foundCount++ } else { require.ErrorIs(t, err, fosite.ErrNotFound) notFoundCount++ } } assert.Equal(t, limit, notFoundCount, "should have deleted %d tokens", limit) assert.Equal(t, totalCount-limit, foundCount, "should have found %d tokens", totalCount-limit) } } func testFositeSqlStoreTransactionCommitAccessToken(m *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { { doTestCommit(m, t, m.OAuth2Storage().CreateAccessTokenSession, m.OAuth2Storage().GetAccessTokenSession, m.OAuth2Storage().RevokeAccessToken) doTestCommit(m, t, m.OAuth2Storage().CreateAccessTokenSession, m.OAuth2Storage().GetAccessTokenSession, m.OAuth2Storage().DeleteAccessTokenSession) } } } func testFositeSqlStoreTransactionRollbackAccessToken(m *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { { doTestRollback(m, t, m.OAuth2Storage().CreateAccessTokenSession, m.OAuth2Storage().GetAccessTokenSession, m.OAuth2Storage().RevokeAccessToken) doTestRollback(m, t, m.OAuth2Storage().CreateAccessTokenSession, m.OAuth2Storage().GetAccessTokenSession, m.OAuth2Storage().DeleteAccessTokenSession) } } } func testFositeSqlStoreTransactionCommitRefreshToken(m *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { doTestCommitRefresh(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().RevokeRefreshToken) doTestCommitRefresh(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().DeleteRefreshTokenSession) } } func testFositeSqlStoreTransactionRollbackRefreshToken(m *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { doTestRollbackRefresh(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().RevokeRefreshToken) doTestRollbackRefresh(m, t, m.OAuth2Storage().CreateRefreshTokenSession, m.OAuth2Storage().GetRefreshTokenSession, m.OAuth2Storage().DeleteRefreshTokenSession) } } func testFositeSqlStoreTransactionCommitAuthorizeCode(m *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { doTestCommit(m, t, m.OAuth2Storage().CreateAuthorizeCodeSession, m.OAuth2Storage().GetAuthorizeCodeSession, m.OAuth2Storage().InvalidateAuthorizeCodeSession) } } func testFositeSqlStoreTransactionRollbackAuthorizeCode(m *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { doTestRollback(m, t, m.OAuth2Storage().CreateAuthorizeCodeSession, m.OAuth2Storage().GetAuthorizeCodeSession, m.OAuth2Storage().InvalidateAuthorizeCodeSession) } } func testFositeSqlStoreTransactionCommitPKCERequest(m *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { doTestCommit(m, t, m.OAuth2Storage().CreatePKCERequestSession, m.OAuth2Storage().GetPKCERequestSession, m.OAuth2Storage().DeletePKCERequestSession) } } func testFositeSqlStoreTransactionRollbackPKCERequest(m *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { doTestRollback(m, t, m.OAuth2Storage().CreatePKCERequestSession, m.OAuth2Storage().GetPKCERequestSession, m.OAuth2Storage().DeletePKCERequestSession) } } // OpenIdConnect tests can't use the helper functions, due to the signature of GetOpenIdConnectSession being // different from the other getter methods func testFositeSqlStoreTransactionCommitOpenIdConnectSession(m *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { txnStore, ok := m.OAuth2Storage().(fosite.Transactional) require.True(t, ok) ctx := t.Context() ctx, err := txnStore.BeginTX(ctx) require.NoError(t, err) signature := uuid.Must(uuid.NewV4()).String() testRequest := createTestRequest(signature) err = m.OAuth2Storage().CreateOpenIDConnectSession(ctx, signature, testRequest) require.NoError(t, err) err = txnStore.Commit(ctx) require.NoError(t, err) // Require a new context, since the old one contains the transaction. res, err := m.OAuth2Storage().GetOpenIDConnectSession(context.Background(), signature, testRequest) // session should have been created successfully because Commit did not return an error require.NoError(t, err) assertx.EqualAsJSONExcept(t, newDefaultRequest(t, "blank"), res, defaultIgnoreKeys) // test delete within a transaction ctx, err = txnStore.BeginTX(context.Background()) require.NoError(t, err) err = m.OAuth2Storage().DeleteOpenIDConnectSession(ctx, signature) require.NoError(t, err) err = txnStore.Commit(ctx) require.NoError(t, err) // Require a new context, since the old one contains the transaction. _, err = m.OAuth2Storage().GetOpenIDConnectSession(context.Background(), signature, testRequest) // Since commit worked for delete, we should get an error here. require.Error(t, err) } } func testFositeSqlStoreTransactionRollbackOpenIdConnectSession(m *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { txnStore, ok := m.OAuth2Storage().(fosite.Transactional) require.True(t, ok) ctx := t.Context() ctx, err := txnStore.BeginTX(ctx) require.NoError(t, err) signature := uuid.Must(uuid.NewV4()).String() testRequest := createTestRequest(signature) err = m.OAuth2Storage().CreateOpenIDConnectSession(ctx, signature, testRequest) require.NoError(t, err) err = txnStore.Rollback(ctx) require.NoError(t, err) // Require a new context, since the old one contains the transaction. ctx = context.Background() _, err = m.OAuth2Storage().GetOpenIDConnectSession(ctx, signature, testRequest) // Since we rolled back above, the session should not exist and getting it should result in an error require.Error(t, err) // create a new session, delete it, then rollback the delete. We should be able to then get it. signature2 := uuid.Must(uuid.NewV4()).String() testRequest2 := createTestRequest(signature2) err = m.OAuth2Storage().CreateOpenIDConnectSession(ctx, signature2, testRequest2) require.NoError(t, err) _, err = m.OAuth2Storage().GetOpenIDConnectSession(ctx, signature2, testRequest2) require.NoError(t, err) ctx, err = txnStore.BeginTX(context.Background()) require.NoError(t, err) err = m.OAuth2Storage().DeleteOpenIDConnectSession(ctx, signature2) require.NoError(t, err) err = txnStore.Rollback(ctx) require.NoError(t, err) _, err = m.OAuth2Storage().GetOpenIDConnectSession(context.Background(), signature2, testRequest2) require.NoError(t, err) } } func testFositeStoreSetClientAssertionJWT(m *driver.RegistrySQL) func(*testing.T) { return func(t *testing.T) { t.Run("case=basic setting works", func(t *testing.T) { store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(time.Minute)) require.NoError(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry)) cmp, err := store.GetClientAssertionJWT(context.Background(), jti.JTI) require.NotEqual(t, cmp.NID, uuid.Nil) cmp.NID = uuid.Nil require.NoError(t, err) assert.Equal(t, jti, cmp) }) t.Run("case=errors when the JTI is blacklisted", func(t *testing.T) { store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) assert.ErrorIs(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry), fosite.ErrJTIKnown) }) t.Run("case=deletes expired JTIs", func(t *testing.T) { store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) expiredJTI := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(-time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), expiredJTI)) newJTI := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(time.Minute)) require.NoError(t, store.SetClientAssertionJWT(context.Background(), newJTI.JTI, newJTI.Expiry)) _, err := store.GetClientAssertionJWT(context.Background(), expiredJTI.JTI) assert.True(t, errors.Is(err, sqlcon.ErrNoRows)) cmp, err := store.GetClientAssertionJWT(context.Background(), newJTI.JTI) require.NoError(t, err) require.NotEqual(t, cmp.NID, uuid.Nil) cmp.NID = uuid.Nil assert.Equal(t, newJTI, cmp) }) t.Run("case=inserts same JTI if expired", func(t *testing.T) { store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(-time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) jti.Expiry = jti.Expiry.Add(2 * time.Minute) assert.NoError(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry)) cmp, err := store.GetClientAssertionJWT(context.Background(), jti.JTI) assert.NoError(t, err) assert.Equal(t, jti, cmp) }) } } func testFositeStoreClientAssertionJWTValid(m *driver.RegistrySQL) func(*testing.T) { return func(t *testing.T) { t.Run("case=returns valid on unknown JTI", func(t *testing.T) { store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) assert.NoError(t, store.ClientAssertionJWTValid(context.Background(), uuid.Must(uuid.NewV4()).String())) }) t.Run("case=returns invalid on known JTI", func(t *testing.T) { store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) assert.True(t, errors.Is(store.ClientAssertionJWTValid(context.Background(), jti.JTI), fosite.ErrJTIKnown)) }) t.Run("case=returns valid on expired JTI", func(t *testing.T) { store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader) require.True(t, ok) jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(-time.Minute)) require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti)) assert.NoError(t, store.ClientAssertionJWTValid(context.Background(), jti.JTI)) }) } } func testFositeJWTBearerGrantStorage(x *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { ctx := t.Context() grantManager := x.GrantManager() keyManager := x.KeyManager() grantStorage := x.OAuth2Storage().(rfc7523.RFC7523KeyStorage) t.Run("case=associated key added with grant", func(t *testing.T) { keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() issuer := uuid.Must(uuid.NewV4()).String() subject := "bob+" + uuid.Must(uuid.NewV4()).String() + "@example.com" grant := trust.Grant{ ID: uuid.Must(uuid.NewV4()), Issuer: issuer, Subject: subject, AllowAnySubject: false, Scope: []string{"openid", "offline"}, PublicKey: trust.PublicKey{Set: issuer, KeyID: publicKey.KeyID}, CreatedAt: time.Now().UTC().Round(time.Second), ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } storedKeySet, err := grantStorage.GetPublicKeys(ctx, issuer, subject) require.NoError(t, err) require.Len(t, storedKeySet.Keys, 0) require.NoError(t, grantManager.CreateGrant(ctx, grant, publicKey)) storedKeySet, err = grantStorage.GetPublicKeys(ctx, issuer, subject) require.NoError(t, err) assert.Len(t, storedKeySet.Keys, 1) storedKey, err := grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) assert.Equal(t, publicKey.KeyID, storedKey.KeyID) assert.Equal(t, publicKey.Use, storedKey.Use) assert.Equal(t, publicKey.Key, storedKey.Key) storedScopes, err := grantStorage.GetPublicKeyScopes(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) assert.Equal(t, grant.Scope, storedScopes) storedKeySet, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) require.NoError(t, err) assert.Equal(t, publicKey.KeyID, storedKeySet.Keys[0].KeyID) assert.Equal(t, publicKey.Use, storedKeySet.Keys[0].Use) assert.Equal(t, publicKey.Key, storedKeySet.Keys[0].Key) }) t.Run("case=only associated key returns", func(t *testing.T) { keySetToNotReturn, err := jwk.GenerateJWK(jose.ES256, uuid.Must(uuid.NewV4()).String(), "sig") require.NoError(t, err) require.NoError(t, keyManager.AddKeySet(context.Background(), uuid.Must(uuid.NewV4()).String(), keySetToNotReturn), "adding a random key should not fail") issuer := uuid.Must(uuid.NewV4()).String() subject := "maria+" + uuid.Must(uuid.NewV4()).String() + "@example.com" keySet1ToReturn, err := jwk.GenerateJWK(jose.ES256, uuid.Must(uuid.NewV4()).String(), "sig") require.NoError(t, err) require.NoError(t, grantManager.CreateGrant(t.Context(), trust.Grant{ ID: uuid.Must(uuid.NewV4()), Issuer: issuer, Subject: subject, AllowAnySubject: false, Scope: []string{"openid"}, PublicKey: trust.PublicKey{Set: issuer, KeyID: keySet1ToReturn.Keys[0].Public().KeyID}, CreatedAt: time.Now().UTC().Round(time.Second), ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), }, keySet1ToReturn.Keys[0].Public())) keySet2ToReturn, err := jwk.GenerateJWK(jose.ES256, uuid.Must(uuid.NewV4()).String(), "sig") require.NoError(t, err) require.NoError(t, grantManager.CreateGrant(ctx, trust.Grant{ ID: uuid.Must(uuid.NewV4()), Issuer: issuer, Subject: subject, AllowAnySubject: false, Scope: []string{"openid"}, PublicKey: trust.PublicKey{Set: issuer, KeyID: keySet2ToReturn.Keys[0].Public().KeyID}, CreatedAt: time.Now().UTC().Round(time.Second), ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), }, keySet2ToReturn.Keys[0].Public())) storedKeySet, err := grantStorage.GetPublicKeys(context.Background(), issuer, subject) require.NoError(t, err) require.Len(t, storedKeySet.Keys, 2) // Cannot rely on sort order because the created_at timestamps may alias. idx1 := slices.IndexFunc(storedKeySet.Keys, func(k jose.JSONWebKey) bool { return k.KeyID == keySet1ToReturn.Keys[0].Public().KeyID }) require.GreaterOrEqual(t, idx1, 0) idx2 := slices.IndexFunc(storedKeySet.Keys, func(k jose.JSONWebKey) bool { return k.KeyID == keySet2ToReturn.Keys[0].Public().KeyID }) require.GreaterOrEqual(t, idx2, 0) assert.Equal(t, keySet1ToReturn.Keys[0].Public().KeyID, storedKeySet.Keys[idx1].KeyID) assert.Equal(t, keySet1ToReturn.Keys[0].Public().Use, storedKeySet.Keys[idx1].Use) assert.Equal(t, keySet1ToReturn.Keys[0].Public().Key, storedKeySet.Keys[idx1].Key) assert.Equal(t, keySet2ToReturn.Keys[0].Public().KeyID, storedKeySet.Keys[idx2].KeyID) assert.Equal(t, keySet2ToReturn.Keys[0].Public().Use, storedKeySet.Keys[idx2].Use) assert.Equal(t, keySet2ToReturn.Keys[0].Public().Key, storedKeySet.Keys[idx2].Key) storedKeySet, err = grantStorage.GetPublicKeys(context.Background(), issuer, "non-existing-subject") require.NoError(t, err) assert.Len(t, storedKeySet.Keys, 0) _, err = grantStorage.GetPublicKeyScopes(context.Background(), issuer, "non-existing-subject", keySet2ToReturn.Keys[0].Public().KeyID) require.Error(t, err) }) t.Run("case=associated key is deleted, when granted is deleted", func(t *testing.T) { keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() issuer := uuid.Must(uuid.NewV4()).String() subject := "aeneas+" + uuid.Must(uuid.NewV4()).String() + "@example.com" grant := trust.Grant{ ID: uuid.Must(uuid.NewV4()), Issuer: issuer, Subject: subject, AllowAnySubject: false, Scope: []string{"openid", "offline"}, PublicKey: trust.PublicKey{Set: issuer, KeyID: publicKey.KeyID}, CreatedAt: time.Now().UTC().Round(time.Second), ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } require.NoError(t, grantManager.CreateGrant(ctx, grant, publicKey)) _, err = grantStorage.GetPublicKey(ctx, issuer, subject, grant.PublicKey.KeyID) require.NoError(t, err) _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) require.NoError(t, err) err = grantManager.DeleteGrant(ctx, grant.ID) require.NoError(t, err) _, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID) assert.Error(t, err) _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) assert.Error(t, err) }) t.Run("case=associated grant is deleted, when key is deleted", func(t *testing.T) { keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() issuer := uuid.Must(uuid.NewV4()).String() subject := "vladimir+" + uuid.Must(uuid.NewV4()).String() + "@example.com" grant := trust.Grant{ ID: uuid.Must(uuid.NewV4()), Issuer: issuer, Subject: subject, AllowAnySubject: false, Scope: []string{"openid", "offline"}, PublicKey: trust.PublicKey{Set: issuer, KeyID: publicKey.KeyID}, CreatedAt: time.Now().UTC().Round(time.Second), ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } require.NoError(t, grantManager.CreateGrant(ctx, grant, publicKey)) _, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) require.NoError(t, err) err = keyManager.DeleteKey(ctx, issuer, publicKey.KeyID) require.NoError(t, err) _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID) assert.Error(t, err) _, err = grantManager.GetConcreteGrant(ctx, grant.ID) assert.Error(t, err) }) t.Run("case=only returns the key when subject matches", func(t *testing.T) { keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() issuer := uuid.Must(uuid.NewV4()).String() subject := "jagoba+" + uuid.Must(uuid.NewV4()).String() + "@example.com" grant := trust.Grant{ ID: uuid.Must(uuid.NewV4()), Issuer: issuer, Subject: subject, AllowAnySubject: false, Scope: []string{"openid", "offline"}, PublicKey: trust.PublicKey{Set: issuer, KeyID: publicKey.KeyID}, CreatedAt: time.Now().UTC().Round(time.Second), ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } require.NoError(t, grantManager.CreateGrant(ctx, grant, publicKey)) // All three get methods should only return the public key when using the valid subject _, err = grantStorage.GetPublicKey(ctx, issuer, "any-subject-1", publicKey.KeyID) require.Error(t, err) _, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) _, err = grantStorage.GetPublicKeyScopes(ctx, issuer, "any-subject-2", publicKey.KeyID) require.Error(t, err) _, err = grantStorage.GetPublicKeyScopes(ctx, issuer, subject, publicKey.KeyID) require.NoError(t, err) jwks, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3") require.NoError(t, err) require.NotNil(t, jwks) require.Empty(t, jwks.Keys) jwks, err = grantStorage.GetPublicKeys(ctx, issuer, subject) require.NoError(t, err) require.NotNil(t, jwks) require.NotEmpty(t, jwks.Keys) }) t.Run("case=returns the key when any subject is allowed", func(t *testing.T) { keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() issuer := uuid.Must(uuid.NewV4()).String() grant := trust.Grant{ ID: uuid.Must(uuid.NewV4()), Issuer: issuer, Subject: "", AllowAnySubject: true, Scope: []string{"openid", "offline"}, PublicKey: trust.PublicKey{Set: issuer, KeyID: publicKey.KeyID}, CreatedAt: time.Now().UTC().Round(time.Second), ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0), } require.NoError(t, grantManager.CreateGrant(ctx, grant, publicKey)) // All three get methods should always return the public key _, err = grantStorage.GetPublicKey(ctx, issuer, "any-subject-1", publicKey.KeyID) require.NoError(t, err) _, err = grantStorage.GetPublicKeyScopes(ctx, issuer, "any-subject-2", publicKey.KeyID) require.NoError(t, err) jwks, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3") require.NoError(t, err) require.NotNil(t, jwks) require.NotEmpty(t, jwks.Keys) }) t.Run("case=does not return expired values", func(t *testing.T) { keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig") require.NoError(t, err) publicKey := keySet.Keys[0].Public() issuer := uuid.Must(uuid.NewV4()).String() grant := trust.Grant{ ID: uuid.Must(uuid.NewV4()), Issuer: issuer, Subject: "", AllowAnySubject: true, Scope: []string{"openid", "offline"}, PublicKey: trust.PublicKey{Set: issuer, KeyID: publicKey.KeyID}, CreatedAt: time.Now().UTC().Round(time.Second), ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(-1, 0, 0), } require.NoError(t, grantManager.CreateGrant(ctx, grant, publicKey)) keys, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3") require.NoError(t, err) assert.Len(t, keys.Keys, 0) }) } } func doTestCommit(m *driver.RegistrySQL, t *testing.T, createFn func(context.Context, string, fosite.Requester) error, getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), revokeFn func(context.Context, string) error, ) { txnStore, ok := m.OAuth2Storage().(fosite.Transactional) require.True(t, ok) ctx := t.Context() ctx, err := txnStore.BeginTX(ctx) require.NoError(t, err) signature := uuid.Must(uuid.NewV4()).String() err = createFn(ctx, signature, createTestRequest(signature)) require.NoError(t, err) err = txnStore.Commit(ctx) require.NoError(t, err) // Require a new context, since the old one contains the transaction. res, err := getFn(context.Background(), signature, oauth2.NewTestSession(t, "bar")) // token should have been created successfully because Commit did not return an error require.NoError(t, err) assertx.EqualAsJSONExcept(t, newDefaultRequest(t, "blank"), res, defaultIgnoreKeys) // AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") // testrevoke within a transaction ctx, err = txnStore.BeginTX(context.Background()) require.NoError(t, err) err = revokeFn(ctx, signature) require.NoError(t, err) err = txnStore.Commit(ctx) require.NoError(t, err) // Require a new context, since the old one contains the transaction. _, err = getFn(context.Background(), signature, oauth2.NewTestSession(t, "bar")) // Since commit worked for revoke, we should get an error here. require.Error(t, err) } func doTestCommitRefresh(m *driver.RegistrySQL, t *testing.T, createFn func(context.Context, string, string, fosite.Requester) error, getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), revokeFn func(context.Context, string) error, ) { txnStore, ok := m.OAuth2Storage().(fosite.Transactional) require.True(t, ok) ctx := t.Context() ctx, err := txnStore.BeginTX(ctx) require.NoError(t, err) signature := uuid.Must(uuid.NewV4()).String() err = createFn(ctx, signature, "", createTestRequest(signature)) require.NoError(t, err) err = txnStore.Commit(ctx) require.NoError(t, err) // Require a new context, since the old one contains the transaction. res, err := getFn(context.Background(), signature, oauth2.NewTestSession(t, "bar")) // token should have been created successfully because Commit did not return an error require.NoError(t, err) assertx.EqualAsJSONExcept(t, newDefaultRequest(t, "blank"), res, defaultIgnoreKeys) // AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") // testrevoke within a transaction ctx, err = txnStore.BeginTX(context.Background()) require.NoError(t, err) err = revokeFn(ctx, signature) require.NoError(t, err) err = txnStore.Commit(ctx) require.NoError(t, err) // Require a new context, since the old one contains the transaction. _, err = getFn(context.Background(), signature, oauth2.NewTestSession(t, "bar")) // Since commit worked for revoke, we should get an error here. require.Error(t, err) } func doTestRollback(m *driver.RegistrySQL, t *testing.T, createFn func(context.Context, string, fosite.Requester) error, getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), revokeFn func(context.Context, string) error, ) { txnStore, ok := m.OAuth2Storage().(fosite.Transactional) require.True(t, ok) ctx := t.Context() ctx, err := txnStore.BeginTX(ctx) require.NoError(t, err) signature := uuid.Must(uuid.NewV4()).String() err = createFn(ctx, signature, createTestRequest(signature)) require.NoError(t, err) err = txnStore.Rollback(ctx) require.NoError(t, err) // Require a new context, since the old one contains the transaction. ctx = context.Background() _, err = getFn(ctx, signature, oauth2.NewTestSession(t, "bar")) // Since we rolled back above, the token should not exist and getting it should result in an error require.Error(t, err) // create a new token, revoke it, then rollback the revoke. We should be able to then get it successfully. signature2 := uuid.Must(uuid.NewV4()).String() err = createFn(ctx, signature2, createTestRequest(signature2)) require.NoError(t, err) _, err = getFn(ctx, signature2, oauth2.NewTestSession(t, "bar")) require.NoError(t, err) ctx, err = txnStore.BeginTX(context.Background()) require.NoError(t, err) err = revokeFn(ctx, signature2) require.NoError(t, err) err = txnStore.Rollback(ctx) require.NoError(t, err) _, err = getFn(context.Background(), signature2, oauth2.NewTestSession(t, "bar")) require.NoError(t, err) } func doTestRollbackRefresh(m *driver.RegistrySQL, t *testing.T, createFn func(context.Context, string, string, fosite.Requester) error, getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), revokeFn func(context.Context, string) error, ) { txnStore, ok := m.OAuth2Storage().(fosite.Transactional) require.True(t, ok) ctx := t.Context() ctx, err := txnStore.BeginTX(ctx) require.NoError(t, err) signature := uuid.Must(uuid.NewV4()).String() err = createFn(ctx, signature, "", createTestRequest(signature)) require.NoError(t, err) err = txnStore.Rollback(ctx) require.NoError(t, err) // Require a new context, since the old one contains the transaction. ctx = context.Background() _, err = getFn(ctx, signature, oauth2.NewTestSession(t, "bar")) // Since we rolled back above, the token should not exist and getting it should result in an error require.Error(t, err) // create a new token, revoke it, then rollback the revoke. We should be able to then get it successfully. signature2 := uuid.Must(uuid.NewV4()).String() err = createFn(ctx, signature2, "", createTestRequest(signature2)) require.NoError(t, err) _, err = getFn(ctx, signature2, oauth2.NewTestSession(t, "bar")) require.NoError(t, err) ctx, err = txnStore.BeginTX(context.Background()) require.NoError(t, err) err = revokeFn(ctx, signature2) require.NoError(t, err) err = txnStore.Rollback(ctx) require.NoError(t, err) _, err = getFn(context.Background(), signature2, oauth2.NewTestSession(t, "bar")) require.NoError(t, err) } func createTestRequest(id string) *fosite.Request { return &fosite.Request{ ID: id, RequestedAt: time.Now().UTC().Round(time.Second), Client: &client.Client{ID: "foobar"}, RequestedScope: fosite.Arguments{"fa", "ba"}, GrantedScope: fosite.Arguments{"fa", "ba"}, RequestedAudience: fosite.Arguments{"ad1", "ad2"}, GrantedAudience: fosite.Arguments{"ad1", "ad2"}, Form: url.Values{"foo": []string{"bar", "baz"}}, Session: &oauth2.Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, } } func testHelperRefreshTokenExpiryUpdate(x *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { ctx := t.Context() // Create client cl := &client.Client{ID: "refresh-expiry-client"} require.NoError(t, x.ClientManager().CreateClient(ctx, cl)) // Create a request with a long expiry initialRequest := fosite.Request{ ID: uuid.Must(uuid.NewV4()).String(), RequestedAt: time.Now().UTC().Round(time.Second), Client: cl, Session: oauth2.NewTestSession(t, "sub"), } // Set a long expiry time (24 hours) initialExpiry := time.Now().Add(24 * time.Hour) initialRequest.Session.SetExpiresAt(fosite.RefreshToken, initialExpiry) t.Run("regular rotation", func(t *testing.T) { // Create original refresh token regularSignature := uuid.Must(uuid.NewV4()).String() require.NoError(t, x.OAuth2Storage().CreateRefreshTokenSession(ctx, regularSignature, "", &initialRequest)) // Verify initial expiry is set correctly originalToken, err := x.OAuth2Storage().GetRefreshTokenSession(ctx, regularSignature, oauth2.NewTestSession(t, "sub")) require.NoError(t, err) require.Equal(t, initialExpiry.Unix(), originalToken.GetSession().GetExpiresAt(fosite.RefreshToken).Unix()) // Set up a connection to directly query the database var actualExpiresAt time.Time require.NoError(t, x.Persister().Connection(ctx).RawQuery("SELECT expires_at FROM hydra_oauth2_refresh WHERE signature=?", regularSignature).First(&actualExpiresAt)) require.Equal(t, initialExpiry.UTC().Round(time.Second), actualExpiresAt.UTC().Round(time.Second)) // Rotate the token err = x.OAuth2Storage().RotateRefreshToken(ctx, initialRequest.ID, regularSignature) require.NoError(t, err) // Check that the original token's expiry was updated to be closer to now var revokedData struct { ExpiresAt time.Time `db:"expires_at"` Active bool `db:"active"` } require.NoError(t, x.Persister().Connection(ctx).RawQuery("SELECT expires_at, active FROM hydra_oauth2_refresh WHERE signature=?", regularSignature).First(&revokedData)) // Verify the token is now inactive require.False(t, revokedData.Active) // Verify the expiry is updated to be closer to now than the original expiry require.True(t, revokedData.ExpiresAt.Before(initialExpiry), "Expiry should be updated to be sooner than original") require.True(t, revokedData.ExpiresAt.After(time.Now()), "Expiry should still be in the future") require.True(t, time.Until(revokedData.ExpiresAt) < time.Until(initialExpiry), "New expiry should be closer to now than original expiry") t.Logf("Original expiry: %v, Updated expiry: %v, Now: %v", initialExpiry, revokedData.ExpiresAt, time.Now()) }) t.Run("graceful rotation", func(t *testing.T) { // Create refresh token for graceful rotation gracefulSignature := uuid.Must(uuid.NewV4()).String() require.NoError(t, x.OAuth2Storage().CreateRefreshTokenSession(ctx, gracefulSignature, "", &initialRequest)) // Set config to graceful rotation oldPeriod := x.Config().GracefulRefreshTokenRotation(ctx).Period t.Cleanup(func() { x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, oldPeriod) x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGraceReuseCount, 0) }) x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, time.Minute*30) x.Config().MustSet(ctx, config.KeyRefreshTokenRotationGraceReuseCount, 3) // Record time before rotation beforeRotation := time.Now().UTC().Add(-time.Second) // Ensure we have a different timestamp for first_used_at // Rotate the token err := x.OAuth2Storage().RotateRefreshToken(ctx, initialRequest.ID, gracefulSignature) require.NoError(t, err) // Check the token's expiry and status var rotatedData struct { ExpiresAt time.Time `db:"expires_at"` Active bool `db:"active"` FirstUsedAt sqlxx.NullTime `db:"first_used_at"` UsedTimes sqlxx.NullInt64 `db:"used_times"` } require.NoError(t, x.Persister().Connection(ctx).RawQuery("SELECT expires_at, active, first_used_at, used_times FROM hydra_oauth2_refresh WHERE signature=?", gracefulSignature).First(&rotatedData)) // Token is used require.False(t, rotatedData.Active) // Verify first_used_at is set and reasonable assert.True(t, time.Time(rotatedData.FirstUsedAt).After(beforeRotation) || time.Time(rotatedData.FirstUsedAt).Equal(beforeRotation), "%s should be after or equal to %s", time.Time(rotatedData.FirstUsedAt), beforeRotation) now := time.Now().UTC().Add(time.Second) assert.True(t, time.Time(rotatedData.FirstUsedAt).Before(now) || time.Time(rotatedData.FirstUsedAt).Equal(now), "%s should be before or equal to %s", time.Time(rotatedData.FirstUsedAt), now) // Verify used_times was incremented assert.True(t, rotatedData.UsedTimes.Valid) assert.Equal(t, int64(1), rotatedData.UsedTimes.Int) // Verify the expiry is updated and is in the future assert.True(t, rotatedData.ExpiresAt.Before(initialExpiry), "Expiry should be updated to be sooner than original") assert.True(t, rotatedData.ExpiresAt.After(time.Now().UTC()), "Expiry should still be in the future") assert.True(t, time.Until(rotatedData.ExpiresAt) < time.Until(initialExpiry), "New expiry should be closer to now than original expiry") t.Logf("Original expiry: %v, Updated expiry: %v, Now: %v", initialExpiry, rotatedData.ExpiresAt, time.Now()) }) } } func testHelperAuthorizeCodeInvalidation(x *driver.RegistrySQL) func(t *testing.T) { return func(t *testing.T) { ctx := t.Context() // Create client cl := &client.Client{ID: "auth-code-client"} require.NoError(t, x.ClientManager().CreateClient(ctx, cl)) // Create a request with a long expiry initialRequest := fosite.Request{ ID: uuid.Must(uuid.NewV4()).String(), RequestedAt: time.Now().UTC().Round(time.Second), Client: cl, Session: oauth2.NewTestSession(t, "sub"), } // Set a long expiry time (1 hour) initialExpiry := time.Now().Add(1 * time.Hour) initialRequest.Session.SetExpiresAt(fosite.AuthorizeCode, initialExpiry) // Create authorize code session authCodeSignature := uuid.Must(uuid.NewV4()).String() require.NoError(t, x.OAuth2Storage().CreateAuthorizeCodeSession(ctx, authCodeSignature, &initialRequest)) // Verify initial state originalCode, err := x.OAuth2Storage().GetAuthorizeCodeSession(ctx, authCodeSignature, oauth2.NewTestSession(t, "sub")) require.NoError(t, err) require.Equal(t, initialExpiry.Unix(), originalCode.GetSession().GetExpiresAt(fosite.AuthorizeCode).Unix()) // Check database directly var codeData struct { ExpiresAt time.Time `db:"expires_at"` Active bool `db:"active"` } require.NoError(t, x.Persister().Connection(ctx).RawQuery( "SELECT expires_at, active FROM hydra_oauth2_code WHERE signature=?", authCodeSignature).First(&codeData)) require.Equal(t, initialExpiry.UTC().Round(time.Second), codeData.ExpiresAt.UTC().Round(time.Second)) require.True(t, codeData.Active) // Invalidate the code err = x.OAuth2Storage().InvalidateAuthorizeCodeSession(ctx, authCodeSignature) require.NoError(t, err) // Check that the code was invalidated but is still retrievable invalidatedCode, err := x.OAuth2Storage().GetAuthorizeCodeSession(ctx, authCodeSignature, oauth2.NewTestSession(t, "sub")) require.Error(t, err) require.ErrorIs(t, err, fosite.ErrInvalidatedAuthorizeCode) require.NotNil(t, invalidatedCode) // Should still be retrievable // Verify database state after invalidation var invalidatedData struct { ExpiresAt time.Time `db:"expires_at"` Active bool `db:"active"` } require.NoError(t, x.Persister().Connection(ctx).RawQuery( "SELECT expires_at, active FROM hydra_oauth2_code WHERE signature=?", authCodeSignature).First(&invalidatedData)) // Verify the code is now inactive require.False(t, invalidatedData.Active) // Verify the expiry is updated to be closer to now than the original expiry require.True(t, invalidatedData.ExpiresAt.Before(initialExpiry), "Expiry should be updated to be sooner than original") require.True(t, invalidatedData.ExpiresAt.After(time.Now()), "Expiry should still be in the future") require.True(t, time.Until(invalidatedData.ExpiresAt) < time.Until(initialExpiry), "New expiry should be closer to now than original expiry") t.Logf("Original expiry: %v, Updated expiry: %v, Now: %v", initialExpiry, invalidatedData.ExpiresAt, time.Now()) } }