chore: fosite and hydra interface enhancements

GitOrigin-RevId: f149949f3fdd7b1264ce78c011d49dee61af52a2
This commit is contained in:
shaunn 2025-11-14 09:23:51 -08:00 committed by ory-bot
parent 10ec9bf3e5
commit db17987979
48 changed files with 1242 additions and 1051 deletions

View File

@ -25,7 +25,7 @@ type Manager interface {
}
type Storage interface {
GetClient(ctx context.Context, id string) (fosite.Client, error)
fosite.ClientManager
CreateClient(ctx context.Context, c *Client) error

View File

@ -26,6 +26,8 @@ type Registry interface {
OpenIDJWTStrategy() jwk.JWTSigner
OAuth2HMACStrategy() foauth2.CoreStrategy
OAuth2EnigmaStrategy() *enigma.HMACStrategy
RFC8628HMACStrategy() rfc8628.RFC8628CodeStrategy
rfc8628.DeviceRateLimitStrategyProvider
rfc8628.DeviceCodeStrategyProvider
rfc8628.UserCodeStrategyProvider
config.Provider
}

View File

@ -1032,7 +1032,7 @@ func (h *Handler) acceptUserCodeRequest(w http.ResponseWriter, r *http.Request)
return
}
userCodeSignature, err := h.r.RFC8628HMACStrategy().UserCodeSignature(r.Context(), reqBody.UserCode)
userCodeSignature, err := h.r.UserCodeStrategy().UserCodeSignature(r.Context(), reqBody.UserCode)
if err != nil {
h.r.Writer().WriteError(w, r, fosite.ErrServerError.WithWrap(err).WithHint(`The 'user_code' signature could not be computed.`))
return
@ -1044,7 +1044,7 @@ func (h *Handler) acceptUserCodeRequest(w http.ResponseWriter, r *http.Request)
return
}
if err := h.r.RFC8628HMACStrategy().ValidateUserCode(ctx, userCodeRequest, reqBody.UserCode); err != nil {
if err := h.r.UserCodeStrategy().ValidateUserCode(ctx, userCodeRequest, reqBody.UserCode); err != nil {
h.r.Writer().WriteError(w, r, fosite.ErrInvalidRequest.WithWrap(err).WithHint(`The 'user_code' session could not be found or has expired or is otherwise malformed.`))
return
}

View File

@ -310,9 +310,9 @@ func TestAcceptCodeDeviceRequest(t *testing.T) {
deviceRequest.Client = cl
deviceRequest.SetSession(oauth2.NewTestSession(t, "test-subject"))
_, deviceCodeSig, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(t.Context())
_, deviceCodeSig, err := reg.DeviceCodeStrategy().GenerateDeviceCode(t.Context())
require.NoError(t, err)
userCode, sig, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context())
userCode, sig, err := reg.UserCodeStrategy().GenerateUserCode(t.Context())
require.NoError(t, err)
require.NoError(t, reg.OAuth2Storage().CreateDeviceAuthSession(t.Context(), deviceCodeSig, sig, deviceRequest))
@ -337,7 +337,7 @@ func TestAcceptCodeDeviceRequest(t *testing.T) {
})
t.Run("case=random user_code, not persisted in the database", func(t *testing.T) {
userCode, _, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context())
userCode, _, err := reg.UserCodeStrategy().GenerateUserCode(t.Context())
require.NoError(t, err)
resp := submitCode(t, &flow.AcceptDeviceUserCodeRequest{UserCode: userCode}, challenge)
@ -360,7 +360,7 @@ func TestAcceptCodeDeviceRequest(t *testing.T) {
})
t.Run("case=empty challenge", func(t *testing.T) {
userCode, _, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context())
userCode, _, err := reg.UserCodeStrategy().GenerateUserCode(t.Context())
require.NoError(t, err)
resp := submitCode(t, &flow.AcceptDeviceUserCodeRequest{UserCode: userCode}, "")
require.EqualValues(t, http.StatusBadRequest, resp.StatusCode)
@ -372,7 +372,7 @@ func TestAcceptCodeDeviceRequest(t *testing.T) {
})
t.Run("case=invalid challenge", func(t *testing.T) {
userCode, _, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context())
userCode, _, err := reg.UserCodeStrategy().GenerateUserCode(t.Context())
require.NoError(t, err)
resp := submitCode(t, &hydra.AcceptDeviceUserCodeRequest{UserCode: &userCode}, "invalid-challenge")
require.EqualValues(t, http.StatusNotFound, resp.StatusCode)
@ -388,9 +388,9 @@ func TestAcceptCodeDeviceRequest(t *testing.T) {
deviceRequest.SetSession(oauth2.NewTestSession(t, "test-subject"))
deviceRequest.Session.SetExpiresAt(fosite.UserCode, time.Now().Add(-time.Hour).UTC())
_, deviceCodeSig, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(t.Context())
_, deviceCodeSig, err := reg.DeviceCodeStrategy().GenerateDeviceCode(t.Context())
require.NoError(t, err)
userCode, sig, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context())
userCode, sig, err := reg.UserCodeStrategy().GenerateUserCode(t.Context())
require.NoError(t, err)
require.NoError(t, reg.OAuth2Storage().CreateDeviceAuthSession(t.Context(), deviceCodeSig, sig, deviceRequest))
@ -409,9 +409,9 @@ func TestAcceptCodeDeviceRequest(t *testing.T) {
deviceRequest.SetSession(oauth2.NewTestSession(t, "test-subject"))
deviceRequest.UserCodeState = fosite.UserCodeAccepted
_, deviceCodeSig, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(t.Context())
_, deviceCodeSig, err := reg.DeviceCodeStrategy().GenerateDeviceCode(t.Context())
require.NoError(t, err)
userCode, sig, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context())
userCode, sig, err := reg.UserCodeStrategy().GenerateUserCode(t.Context())
require.NoError(t, err)
require.NoError(t, reg.OAuth2Storage().CreateDeviceAuthSession(t.Context(), deviceCodeSig, sig, deviceRequest))
@ -430,9 +430,9 @@ func TestAcceptCodeDeviceRequest(t *testing.T) {
deviceRequest.SetSession(oauth2.NewTestSession(t, "test-subject"))
deviceRequest.UserCodeState = fosite.UserCodeRejected
_, deviceCodesig, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(t.Context())
_, deviceCodesig, err := reg.DeviceCodeStrategy().GenerateDeviceCode(t.Context())
require.NoError(t, err)
userCode, sig, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context())
userCode, sig, err := reg.UserCodeStrategy().GenerateUserCode(t.Context())
require.NoError(t, err)
require.NoError(t, reg.OAuth2Storage().CreateDeviceAuthSession(t.Context(), deviceCodesig, sig, deviceRequest))
@ -450,9 +450,9 @@ func TestAcceptCodeDeviceRequest(t *testing.T) {
deviceRequest.Client = cl
deviceRequest.SetSession(oauth2.NewTestSession(t, "test-subject"))
_, deviceCodeSig, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(t.Context())
_, deviceCodeSig, err := reg.DeviceCodeStrategy().GenerateDeviceCode(t.Context())
require.NoError(t, err)
userCode, sig, err := reg.RFC8628HMACStrategy().GenerateUserCode(t.Context())
userCode, sig, err := reg.UserCodeStrategy().GenerateUserCode(t.Context())
require.NoError(t, err)
require.NoError(t, reg.OAuth2Storage().CreateDeviceAuthSession(t.Context(), deviceCodeSig, sig, deviceRequest))

View File

@ -4,6 +4,7 @@
package driver
import (
"github.com/ory/hydra/v2/fosite"
"github.com/ory/hydra/v2/fosite/handler/oauth2"
"github.com/ory/hydra/v2/jwk"
)
@ -29,3 +30,24 @@ func RegistryWithKeyManager(km func(r *RegistrySQL) (jwk.Manager, error)) Regist
return err
}
}
func RegistryWithOAuth2Provider(pr func(r *RegistrySQL) fosite.OAuth2Provider) RegistryModifier {
return func(r *RegistrySQL) error {
r.fop = pr(r)
return nil
}
}
func RegistryWithAccessTokenStorage(s func(r *RegistrySQL) oauth2.AccessTokenStorage) RegistryModifier {
return func(r *RegistrySQL) error {
r.accessTokenStorage = s(r)
return nil
}
}
func RegistryWithAuthorizeCodeStorage(s func(r *RegistrySQL) oauth2.AuthorizeCodeStorage) RegistryModifier {
return func(r *RegistrySQL) error {
r.authorizeCodeStorage = s(r)
return nil
}
}

View File

@ -12,14 +12,13 @@ import (
"strings"
"time"
"github.com/urfave/negroni"
"github.com/gorilla/sessions"
"github.com/hashicorp/go-retryablehttp"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/cors"
"github.com/urfave/negroni"
"go.uber.org/automaxprocs/maxprocs"
"github.com/ory/herodot"
@ -31,7 +30,10 @@ import (
"github.com/ory/hydra/v2/fosite/compose"
foauth2 "github.com/ory/hydra/v2/fosite/handler/oauth2"
"github.com/ory/hydra/v2/fosite/handler/openid"
"github.com/ory/hydra/v2/fosite/handler/pkce"
"github.com/ory/hydra/v2/fosite/handler/rfc7523"
"github.com/ory/hydra/v2/fosite/handler/rfc8628"
"github.com/ory/hydra/v2/fosite/handler/verifiable"
"github.com/ory/hydra/v2/fosite/token/hmac"
"github.com/ory/hydra/v2/fositex"
"github.com/ory/hydra/v2/hsm"
@ -58,35 +60,37 @@ import (
)
type RegistrySQL struct {
l, al *logrusx.Logger
conf *config.DefaultProvider
fh fosite.Hasher
cv *client.Validator
ctxer contextx.Contextualizer
hh *healthx.Handler
kc *aead.AESGCM
flowc *aead.XChaCha20Poly1305
cos consent.Strategy
writer herodot.Writer
hsm hsm.Context
forv *openid.OpenIDConnectRequestValidator
fop fosite.OAuth2Provider
trc *otelx.Tracer
tracerWrapper func(*otelx.Tracer) *otelx.Tracer
arhs []oauth2.AccessRequestHook
basePersister *sql.BasePersister
oc fosite.Configurator
oidcs jwk.JWTSigner
ats jwk.JWTSigner
hmacs foauth2.CoreStrategy
enigmaHMAC *hmac.HMACStrategy
deviceHmac rfc8628.RFC8628CodeStrategy
fc *fositex.Config
publicCORS *cors.Cors
kratos kratos.Client
fositeFactories []fositex.Factory
migrator *sql.MigrationManager
dbOptsModifier []func(details *pop.ConnectionDetails)
l, al *logrusx.Logger
conf *config.DefaultProvider
fh fosite.Hasher
cv *client.Validator
ctxer contextx.Contextualizer
hh *healthx.Handler
kc *aead.AESGCM
flowc *aead.XChaCha20Poly1305
cos consent.Strategy
writer herodot.Writer
hsm hsm.Context
forv *openid.OpenIDConnectRequestValidator
fop fosite.OAuth2Provider
trc *otelx.Tracer
tracerWrapper func(*otelx.Tracer) *otelx.Tracer
arhs []oauth2.AccessRequestHook
basePersister *sql.BasePersister
accessTokenStorage foauth2.AccessTokenStorage
authorizeCodeStorage foauth2.AuthorizeCodeStorage
oc fosite.Configurator
oidcs jwk.JWTSigner
ats jwk.JWTSigner
hmacs foauth2.CoreStrategy
enigmaHMAC *hmac.HMACStrategy
deviceHmac *rfc8628.DefaultDeviceStrategy
fc *fositex.Config
publicCORS *cors.Cors
kratos kratos.Client
fositeFactories []fositex.Factory
migrator *sql.MigrationManager
dbOptsModifier []func(details *pop.ConnectionDetails)
keyManager jwk.Manager
initialPing func(ctx context.Context, l *logrusx.Logger, p *sql.BasePersister) error
@ -98,6 +102,66 @@ var (
_ registry = (*RegistrySQL)(nil)
)
func (m *RegistrySQL) FositeClientManager() fosite.ClientManager {
return m.OAuth2Storage()
}
// AuthorizeCodeStorage implements foauth2.AuthorizeCodeStorageProvider
func (m *RegistrySQL) AuthorizeCodeStorage() foauth2.AuthorizeCodeStorage {
if m.authorizeCodeStorage != nil {
return m.authorizeCodeStorage
}
return m.OAuth2Storage()
}
// AccessTokenStorage implements foauth2.AccessTokenStorageProvider
func (m *RegistrySQL) AccessTokenStorage() foauth2.AccessTokenStorage {
if m.accessTokenStorage != nil {
return m.accessTokenStorage
}
return m.OAuth2Storage()
}
// RefreshTokenStorage implements foauth2.RefreshTokenStorageProvider
func (m *RegistrySQL) RefreshTokenStorage() foauth2.RefreshTokenStorage {
return m.OAuth2Storage()
}
// TokenRevocationStorage implements foauth2.TokenRevocationStorageProvider
func (m *RegistrySQL) TokenRevocationStorage() foauth2.TokenRevocationStorage {
return m.OAuth2Storage()
}
// ResourceOwnerPasswordCredentialsGrantStorage implements foauth2.ResourceOwnerPasswordCredentialsGrantStorage
func (m *RegistrySQL) ResourceOwnerPasswordCredentialsGrantStorage() foauth2.ResourceOwnerPasswordCredentialsGrantStorage {
return m.OAuth2Storage()
}
// OpenIDConnectRequestStorage implements openid.OIDCRequestStorageProvider
func (m *RegistrySQL) OpenIDConnectRequestStorage() openid.OpenIDConnectRequestStorage {
return m.OAuth2Storage()
}
// PKCERequestStorage implements pkce.PKCERequestStorageProvider
func (m *RegistrySQL) PKCERequestStorage() pkce.PKCERequestStorage {
return m.OAuth2Storage()
}
// DeviceAuthStorage implements rfc8628.DeviceAuthStorageProvider
func (m *RegistrySQL) DeviceAuthStorage() rfc8628.DeviceAuthStorage {
return m.OAuth2Storage()
}
// RFC7523KeyStorage implements rfc7523.RFC7523KeyStorageProvider
func (m *RegistrySQL) RFC7523KeyStorage() rfc7523.RFC7523KeyStorage {
return m.OAuth2Storage()
}
// NonceManager implements verifiable.NonceManager
func (m *RegistrySQL) NonceManager() verifiable.NonceManager {
return m.OAuth2Storage()
}
// defaultInitialPing is the default function that will be called within RegistrySQL.Init to make sure
// the database is reachable. It can be injected for test purposes by changing the value
// of RegistrySQL.initialPing.
@ -209,7 +273,9 @@ func (m *RegistrySQL) ConsentManager() consent.Manager {
func (m *RegistrySQL) ObfuscatedSubjectManager() consent.ObfuscatedSubjectManager {
return m.Persister()
}
func (m *RegistrySQL) LoginManager() consent.LoginManager { return m.Persister() }
func (m *RegistrySQL) LoginManager() consent.LoginManager { return m.Persister() }
func (m *RegistrySQL) LogoutManager() consent.LogoutManager { return m.Persister() }
func (m *RegistrySQL) OAuth2Storage() x.FositeStorer {
@ -412,7 +478,7 @@ func (m *RegistrySQL) HTTPClient(ctx context.Context, opts ...httpx.ResilientOpt
func (m *RegistrySQL) OAuth2Provider() fosite.OAuth2Provider {
if m.fop == nil {
m.fop = fosite.NewOAuth2Provider(m.OAuth2Storage(), m.OAuth2ProviderConfig())
m.fop = fosite.NewOAuth2Provider(m, m.OAuth2ProviderConfig())
}
return m.fop
}
@ -445,14 +511,29 @@ func (m *RegistrySQL) OAuth2HMACStrategy() foauth2.CoreStrategy {
return m.hmacs
}
// RFC8628HMACStrategy returns the rfc8628 strategy
func (m *RegistrySQL) RFC8628HMACStrategy() rfc8628.RFC8628CodeStrategy {
// rfc8628HMACStrategy returns the rfc8628 strategy
func (m *RegistrySQL) rfc8628HMACStrategy() *rfc8628.DefaultDeviceStrategy {
if m.deviceHmac == nil {
m.deviceHmac = compose.NewDeviceStrategy(m.OAuth2Config())
}
return m.deviceHmac
}
// DeviceRateLimitStrategy implements rfc8628.DeviceRateLimitStrategyProvider
func (m *RegistrySQL) DeviceRateLimitStrategy() rfc8628.DeviceRateLimitStrategy {
return m.rfc8628HMACStrategy()
}
// DeviceCodeStrategy implements rfc8628.DeviceCodeStrategyProvider
func (m *RegistrySQL) DeviceCodeStrategy() rfc8628.DeviceCodeStrategy {
return m.rfc8628HMACStrategy()
}
// UserCodeStrategy implements rfc8628.UserCodeStrategyProvider
func (m *RegistrySQL) UserCodeStrategy() rfc8628.UserCodeStrategy {
return m.rfc8628HMACStrategy()
}
func (m *RegistrySQL) OAuth2Config() *fositex.Config {
if m.fc == nil {
m.fc = fositex.NewConfig(m)
@ -471,7 +552,7 @@ func (m *RegistrySQL) OAuth2ProviderConfig() fosite.Configurator {
conf := m.OAuth2Config()
hmacAtStrategy := m.OAuth2HMACStrategy()
deviceHmacAtStrategy := m.RFC8628HMACStrategy()
deviceHmacAtStrategy := m.rfc8628HMACStrategy()
oidcSigner := m.OpenIDJWTStrategy()
atSigner := m.AccessTokenJWTStrategy()
jwtAtStrategy := &foauth2.DefaultJWTStrategy{
@ -480,13 +561,13 @@ func (m *RegistrySQL) OAuth2ProviderConfig() fosite.Configurator {
Config: conf,
}
conf.LoadDefaultHandlers(&compose.CommonStrategy{
conf.LoadDefaultHandlers(m, &compose.CommonStrategy{
CoreStrategy: fositex.NewTokenStrategy(m.Config(), hmacAtStrategy, &foauth2.DefaultJWTStrategy{
Signer: jwtAtStrategy,
Strategy: hmacAtStrategy,
Config: conf,
}),
RFC8628CodeStrategy: deviceHmacAtStrategy,
DeviceStrategy: deviceHmacAtStrategy,
OIDCTokenStrategy: &openid.DefaultStrategy{
Config: conf,
Signer: oidcSigner,

View File

@ -77,7 +77,7 @@ func TestNewAccessRequest(t *testing.T) {
},
expectErr: ErrInvalidClient,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New(""))
},
handlers: TokenEndpointHandlers{handler},
@ -103,7 +103,7 @@ func TestNewAccessRequest(t *testing.T) {
},
expectErr: ErrInvalidClient,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New(""))
},
handlers: TokenEndpointHandlers{handler},
@ -118,7 +118,7 @@ func TestNewAccessRequest(t *testing.T) {
},
expectErr: ErrInvalidClient,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
@ -136,7 +136,7 @@ func TestNewAccessRequest(t *testing.T) {
},
expectErr: ErrServerError,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
@ -154,7 +154,7 @@ func TestNewAccessRequest(t *testing.T) {
"grant_type": {"foo"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
@ -178,7 +178,7 @@ func TestNewAccessRequest(t *testing.T) {
"grant_type": {"foo"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = true
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
@ -269,7 +269,7 @@ func TestNewAccessRequestWithoutClientAuth(t *testing.T) {
},
mock: func() {
// despite error from storage, we should success, because client auth is not required
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "foo").Return(nil, errors.New("no client")).Times(1)
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
},
@ -308,7 +308,7 @@ func TestNewAccessRequestWithoutClientAuth(t *testing.T) {
"grant_type": {"foo"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "foo").Return(anotherClient, nil).Times(1)
hasher.EXPECT().Compare(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1)
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
@ -383,7 +383,7 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) {
"grant_type": {"foo"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
@ -402,7 +402,7 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) {
"grant_type": {"foo"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")

View File

@ -357,7 +357,7 @@ func (f *Fosite) newAuthorizeRequest(ctx context.Context, r *http.Request, isPAR
}
}
client, err := f.Store.ClientManager().GetClient(ctx, request.GetRequestForm().Get("client_id"))
client, err := f.Store.FositeClientManager().GetClient(ctx, request.GetRequestForm().Get("client_id"))
if err != nil {
return request, errorsx.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 Client does not exist.").WithWrap(err).WithDebug(err.Error()))
}

View File

@ -47,7 +47,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
r: &http.Request{},
expectedError: ErrInvalidClient,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo"))
},
},
@ -58,7 +58,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
query: url.Values{"redirect_uri": []string{"invalid"}},
expectedError: ErrInvalidClient,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo"))
},
},
@ -69,7 +69,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
query: url.Values{"redirect_uri": []string{"https://foo.bar/cb"}},
expectedError: ErrInvalidClient,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo"))
},
},
@ -82,7 +82,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidRequest,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}}, nil)
},
},
@ -96,7 +96,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidRequest,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}}, nil)
},
},
@ -110,7 +110,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidRequest,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}}, nil)
},
},
@ -125,7 +125,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidState,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{}}, nil)
},
},
@ -141,7 +141,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidState,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{}}, nil)
},
},
@ -157,7 +157,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"scope": {"foo bar baz"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}}, nil)
},
expectedError: ErrInvalidScope,
@ -175,7 +175,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{
RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"},
Audience: []string{"https://cloud.ory.sh/api"},
@ -196,7 +196,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{
ResponseTypes: []string{"code token"},
RedirectURIs: []string{"https://foo.bar/cb"},
@ -232,7 +232,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api", "https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{
ResponseTypes: []string{"code token"},
RedirectURIs: []string{"https://foo.bar/cb"},
@ -268,7 +268,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"audience": {"test value", ""},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{
ResponseTypes: []string{"code token"},
RedirectURIs: []string{"https://foo.bar/cb"},
@ -304,7 +304,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{
ResponseTypes: []string{"code token"},
RedirectURIs: []string{"web+application://callback"},
@ -340,7 +340,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{
ResponseTypes: []string{"code token"},
RedirectURIs: []string{"https://foo.bar/cb"},
@ -376,7 +376,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"response_mode": {"unknown"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}}, nil)
},
expectedError: ErrUnsupportedResponseMode,
@ -394,7 +394,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"response_mode": {"form_post"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}}, nil)
},
expectedError: ErrUnsupportedResponseMode,
@ -412,7 +412,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"response_mode": {"form_post"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{
DefaultClient: &DefaultClient{
RedirectURIs: []string{"https://foo.bar/cb"},
@ -438,7 +438,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{
DefaultClient: &DefaultClient{
RedirectURIs: []string{"https://foo.bar/cb"},
@ -481,7 +481,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{
DefaultClient: &DefaultClient{
RedirectURIs: []string{"https://foo.bar/cb"},
@ -524,7 +524,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{
DefaultClient: &DefaultClient{
RedirectURIs: []string{"https://foo.bar/cb"},

View File

@ -91,7 +91,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
}
}
client, err = f.Store.ClientManager().GetClient(ctx, clientID)
client, err = f.Store.FositeClientManager().GetClient(ctx, clientID)
if err != nil {
return nil, errorsx.WithStack(ErrInvalidClient.WithWrap(err).WithDebug(err.Error()))
}
@ -156,7 +156,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client."))
} else if jti, ok = claims["jti"].(string); !ok || len(jti) == 0 {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'jti' from 'client_assertion' must be set but is not."))
} else if f.Store.ClientManager().ClientAssertionJWTValid(ctx, jti) != nil {
} else if f.Store.FositeClientManager().ClientAssertionJWTValid(ctx, jti) != nil {
return nil, errorsx.WithStack(ErrJTIKnown.WithHint("Claim 'jti' from 'client_assertion' MUST only be used once."))
}
@ -177,7 +177,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
if err != nil {
return nil, errorsx.WithStack(err)
}
if err := f.Store.ClientManager().SetClientAssertionJWT(ctx, jti, time.Unix(expiry, 0)); err != nil {
if err := f.Store.FositeClientManager().SetClientAssertionJWT(ctx, jti, time.Unix(expiry, 0)); err != nil {
return nil, err
}
@ -197,7 +197,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
return nil, err
}
client, err := f.Store.ClientManager().GetClient(ctx, clientID)
client, err := f.Store.FositeClientManager().GetClient(ctx, clientID)
if err != nil {
return nil, errorsx.WithStack(ErrInvalidClient.WithWrap(err).WithDebug(err.Error()))
}

View File

@ -70,10 +70,10 @@ func ComposeAllEnabled(config *fosite.Config, storage fosite.Storage, key interf
config,
storage,
&CommonStrategy{
CoreStrategy: NewOAuth2HMACStrategy(config),
RFC8628CodeStrategy: NewDeviceStrategy(config),
OIDCTokenStrategy: NewOpenIDConnectStrategy(keyGetter, config),
Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter},
CoreStrategy: NewOAuth2HMACStrategy(config),
DeviceStrategy: NewDeviceStrategy(config),
OIDCTokenStrategy: NewOpenIDConnectStrategy(keyGetter, config),
Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter},
},
OAuth2AuthorizeExplicitFactory,
OAuth2AuthorizeImplicitFactory,

View File

@ -77,8 +77,12 @@ func OAuth2ResourceOwnerPasswordCredentialsFactory(config fosite.Configurator, s
oauth2.AccessTokenStrategyProvider
oauth2.RefreshTokenStrategyProvider
}),
Storage: storage.(oauth2.ResourceOwnerPasswordCredentialsGrantStorage),
Config: config,
Storage: storage.(interface {
oauth2.ResourceOwnerPasswordCredentialsGrantStorageProvider
oauth2.AccessTokenStorageProvider
oauth2.RefreshTokenStorageProvider
}),
Config: config,
}
}

View File

@ -15,9 +15,9 @@ import (
)
type CommonStrategy struct {
CoreStrategy oauth2.CoreStrategy
RFC8628CodeStrategy rfc8628.RFC8628CodeStrategy
OIDCTokenStrategy openid.OpenIDConnectTokenStrategy
CoreStrategy oauth2.CoreStrategy
DeviceStrategy *rfc8628.DefaultDeviceStrategy
OIDCTokenStrategy openid.OpenIDConnectTokenStrategy
jwt.Signer
}
@ -41,15 +41,15 @@ func (s *CommonStrategy) OpenIDConnectTokenStrategy() openid.OpenIDConnectTokenS
// RFC8628 Device Strategy Providers
func (s *CommonStrategy) DeviceRateLimitStrategy() rfc8628.DeviceRateLimitStrategy {
return s.RFC8628CodeStrategy
return s.DeviceStrategy
}
func (s *CommonStrategy) DeviceCodeStrategy() rfc8628.DeviceCodeStrategy {
return s.RFC8628CodeStrategy
return s.DeviceStrategy
}
func (s *CommonStrategy) UserCodeStrategy() rfc8628.UserCodeStrategy {
return s.RFC8628CodeStrategy
return s.DeviceStrategy
}
type HMACSHAStrategyConfigurator interface {

View File

@ -12,7 +12,7 @@ import (
// handler.
func OIDCUserinfoVerifiableCredentialFactory(config fosite.Configurator, storage fosite.Storage, strategy any) any {
return &verifiable.Handler{
NonceManager: storage.(verifiable.NonceManager),
Config: config,
NonceManagerProvider: storage.(verifiable.NonceManagerProvider),
Config: config,
}
}

View File

@ -64,7 +64,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) {
expectedError: ErrInvalidClient,
method: "POST",
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(nil, errors.New(""))
},
}, {
@ -75,7 +75,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) {
},
method: "POST",
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil)
},
expectedError: ErrInvalidScope,
@ -88,7 +88,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) {
},
method: "POST",
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil)
},
expectedError: ErrInvalidRequest,
@ -100,7 +100,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) {
},
method: "POST",
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id_2")).Return(authCodeClient, nil)
},
expectedError: ErrInvalidGrant,
@ -112,7 +112,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) {
},
method: "POST",
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil)
},
}} {
@ -166,7 +166,7 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) {
expectedError: ErrInvalidClient,
method: "POST",
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
hasher.EXPECT().Compare(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New(""))
},
@ -183,7 +183,7 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) {
expectedError: ErrInvalidRequest,
method: "POST",
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("client_secret")), gomock.Eq([]byte("client_secret"))).Return(nil)
},
@ -199,7 +199,7 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) {
},
method: "POST",
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("client_secret")), gomock.Eq([]byte("client_secret"))).Return(nil)
},

View File

@ -1,55 +1,56 @@
#!/bin/bash
mockgen -package internal -destination internal/access_request.go github.com/ory/fosite AccessRequester
mockgen -package internal -destination internal/access_response.go github.com/ory/fosite AccessResponder
mockgen -package internal -destination internal/access_token_storage.go github.com/ory/fosite/handler/oauth2 AccessTokenStorage
mockgen -package internal -destination internal/access_token_storage_provider.go github.com/ory/fosite/handler/oauth2 AccessTokenStorageProvider
mockgen -package internal -destination internal/access_token_strategy.go github.com/ory/fosite/handler/oauth2 AccessTokenStrategy
mockgen -package internal -destination internal/access_token_strategy_provider.go github.com/ory/fosite/handler/oauth2 AccessTokenStrategyProvider
mockgen -package internal -destination internal/authorize_code_storage.go github.com/ory/fosite/handler/oauth2 AuthorizeCodeStorage
mockgen -package internal -destination internal/authorize_code_storage_provider.go github.com/ory/fosite/handler/oauth2 AuthorizeCodeStorageProvider
mockgen -package internal -destination internal/authorize_code_strategy.go github.com/ory/fosite/handler/oauth2 AuthorizeCodeStrategy
mockgen -package internal -destination internal/authorize_code_strategy_provider.go github.com/ory/fosite/handler/oauth2 AuthorizeCodeStrategyProvider
mockgen -package internal -destination internal/authorize_endpoint_handler.go github.com/ory/fosite AuthorizeEndpointHandler
mockgen -package internal -destination internal/authorize_endpoint_handlers_provider.go github.com/ory/fosite AuthorizeEndpointHandlersProvider
mockgen -package internal -destination internal/authorize_request.go github.com/ory/fosite AuthorizeRequester
mockgen -package internal -destination internal/authorize_response.go github.com/ory/fosite AuthorizeResponder
mockgen -package internal -destination internal/client.go github.com/ory/fosite Client
mockgen -package internal -destination internal/client_manager.go github.com/ory/fosite ClientManager
mockgen -package internal -destination internal/oauth2_storage.go github.com/ory/fosite/handler/oauth2 CoreStorage
mockgen -package internal -destination internal/oauth2_strategy.go github.com/ory/fosite/handler/oauth2 CoreStrategy
mockgen -package internal -destination internal/device_auth_storage.go github.com/ory/fosite/handler/rfc8628 DeviceAuthStorage
mockgen -package internal -destination internal/device_auth_storage_provider.go github.com/ory/fosite/handler/rfc8628 DeviceAuthStorageProvider
mockgen -package internal -destination internal/device_code_strategy.go github.com/ory/fosite/handler/rfc8628 DeviceCodeStrategy
mockgen -package internal -destination internal/device_code_strategy_provider.go github.com/ory/fosite/handler/rfc8628 DeviceCodeStrategyProvider
mockgen -package internal -destination internal/device_rate_limit_strategy.go github.com/ory/fosite/handler/rfc8628 DeviceRateLimitStrategy
mockgen -package internal -destination internal/device_rate_limit_strategy_provider.go github.com/ory/fosite/handler/rfc8628 DeviceRateLimitStrategyProvider
mockgen -package internal -destination internal/hash.go github.com/ory/fosite Hasher
mockgen -package internal -destination internal/open_id_connect_token_strategy.go github.com/ory/fosite/handler/openid OpenIDConnectTokenStrategy
mockgen -package internal -destination internal/open_id_connect_token_strategy_provider.go github.com/ory/fosite/handler/openid OpenIDConnectTokenStrategyProvider
mockgen -package internal -destination internal/open_id_connect_request_storage.go github.com/ory/fosite/handler/openid OpenIDConnectRequestStorage
mockgen -package internal -destination internal/open_id_connect_request_storage_provider.go github.com/ory/fosite/handler/openid OpenIDConnectRequestStorageProvider
mockgen -package internal -destination internal/par_storage.go github.com/ory/fosite PARStorage
mockgen -package internal -destination internal/par_storage_provider.go github.com/ory/fosite PARStorageProvider
mockgen -package internal -destination internal/pkce_request_storage.go github.com/ory/fosite/handler/pkce PKCERequestStorage
mockgen -package internal -destination internal/pkce_request_storage_provider.go github.com/ory/fosite/handler/pkce PKCERequestStorageProvider
mockgen -package internal -destination internal/refresh_token_storage.go github.com/ory/fosite/handler/oauth2 RefreshTokenStorage
mockgen -package internal -destination internal/refresh_token_storage_provider.go github.com/ory/fosite/handler/oauth2 RefreshTokenStorageProvider
mockgen -package internal -destination internal/refresh_token_strategy.go github.com/ory/fosite/handler/oauth2 RefreshTokenStrategy
mockgen -package internal -destination internal/refresh_token_strategy_provider.go github.com/ory/fosite/handler/oauth2 RefreshTokenStrategyProvider
mockgen -package internal -destination internal/request.go github.com/ory/fosite Requester
mockgen -package internal -destination internal/resource_owner_password_credentials_grant_storage.go github.com/ory/fosite/handler/oauth2 ResourceOwnerPasswordCredentialsGrantStorage
mockgen -package internal -destination internal/revocation_handler.go github.com/ory/fosite RevocationHandler
mockgen -package internal -destination internal/revocation_handlers_provider.go github.com/ory/fosite RevocationHandlersProvider
mockgen -package internal -destination internal/rfc7523_key_storage.go github.com/ory/fosite/handler/rfc7523 RFC7523KeyStorage
mockgen -package internal -destination internal/rfc7523_key_storage_provider.go github.com/ory/fosite/handler/rfc7523 RFC7523KeyStorageProvider
mockgen -package internal -destination internal/storage.go github.com/ory/fosite Storage
mockgen -package internal -destination internal/token_endpoint_handler.go github.com/ory/fosite TokenEndpointHandler
mockgen -package internal -destination internal/token_introspector.go github.com/ory/fosite TokenIntrospector
mockgen -package internal -destination internal/token_revocation_storage.go github.com/ory/fosite/handler/oauth2 TokenRevocationStorage
mockgen -package internal -destination internal/token_revocation_storage_provider.go github.com/ory/fosite/handler/oauth2 TokenRevocationStorageProvider
mockgen -package internal -destination internal/transactional.go github.com/ory/fosite Transactional
mockgen -package internal -destination internal/user_code_strategy.go github.com/ory/fosite/handler/rfc8628 UserCodeStrategy
mockgen -package internal -destination internal/user_code_strategy_provider.go github.com/ory/fosite/handler/rfc8628 UserCodeStrategyProvider
mockgen -package internal -destination internal/access_request.go github.com/ory/hydra/v2/fosite AccessRequester
mockgen -package internal -destination internal/access_response.go github.com/ory/hydra/v2/fosite AccessResponder
mockgen -package internal -destination internal/access_token_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 AccessTokenStorage
mockgen -package internal -destination internal/access_token_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 AccessTokenStorageProvider
mockgen -package internal -destination internal/access_token_strategy.go github.com/ory/hydra/v2/fosite/handler/oauth2 AccessTokenStrategy
mockgen -package internal -destination internal/access_token_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 AccessTokenStrategyProvider
mockgen -package internal -destination internal/authorize_code_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 AuthorizeCodeStorage
mockgen -package internal -destination internal/authorize_code_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 AuthorizeCodeStorageProvider
mockgen -package internal -destination internal/authorize_code_strategy.go github.com/ory/hydra/v2/fosite/handler/oauth2 AuthorizeCodeStrategy
mockgen -package internal -destination internal/authorize_code_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 AuthorizeCodeStrategyProvider
mockgen -package internal -destination internal/authorize_endpoint_handler.go github.com/ory/hydra/v2/fosite AuthorizeEndpointHandler
mockgen -package internal -destination internal/authorize_endpoint_handlers_provider.go github.com/ory/hydra/v2/fosite AuthorizeEndpointHandlersProvider
mockgen -package internal -destination internal/authorize_request.go github.com/ory/hydra/v2/fosite AuthorizeRequester
mockgen -package internal -destination internal/authorize_response.go github.com/ory/hydra/v2/fosite AuthorizeResponder
mockgen -package internal -destination internal/client.go github.com/ory/hydra/v2/fosite Client
mockgen -package internal -destination internal/client_manager.go github.com/ory/hydra/v2/fosite ClientManager
mockgen -package internal -destination internal/oauth2_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 CoreStorage
mockgen -package internal -destination internal/oauth2_strategy.go github.com/ory/hydra/v2/fosite/handler/oauth2 CoreStrategy
mockgen -package internal -destination internal/device_auth_storage.go github.com/ory/hydra/v2/fosite/handler/rfc8628 DeviceAuthStorage
mockgen -package internal -destination internal/device_auth_storage_provider.go github.com/ory/hydra/v2/fosite/handler/rfc8628 DeviceAuthStorageProvider
mockgen -package internal -destination internal/device_code_strategy.go github.com/ory/hydra/v2/fosite/handler/rfc8628 DeviceCodeStrategy
mockgen -package internal -destination internal/device_code_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/rfc8628 DeviceCodeStrategyProvider
mockgen -package internal -destination internal/device_rate_limit_strategy.go github.com/ory/hydra/v2/fosite/handler/rfc8628 DeviceRateLimitStrategy
mockgen -package internal -destination internal/device_rate_limit_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/rfc8628 DeviceRateLimitStrategyProvider
mockgen -package internal -destination internal/hash.go github.com/ory/hydra/v2/fosite Hasher
mockgen -package internal -destination internal/open_id_connect_token_strategy.go github.com/ory/hydra/v2/fosite/handler/openid OpenIDConnectTokenStrategy
mockgen -package internal -destination internal/open_id_connect_token_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/openid OpenIDConnectTokenStrategyProvider
mockgen -package internal -destination internal/open_id_connect_request_storage.go github.com/ory/hydra/v2/fosite/handler/openid OpenIDConnectRequestStorage
mockgen -package internal -destination internal/open_id_connect_request_storage_provider.go github.com/ory/hydra/v2/fosite/handler/openid OpenIDConnectRequestStorageProvider
mockgen -package internal -destination internal/par_storage.go github.com/ory/hydra/v2/fosite PARStorage
mockgen -package internal -destination internal/par_storage_provider.go github.com/ory/hydra/v2/fosite PARStorageProvider
mockgen -package internal -destination internal/pkce_request_storage.go github.com/ory/hydra/v2/fosite/handler/pkce PKCERequestStorage
mockgen -package internal -destination internal/pkce_request_storage_provider.go github.com/ory/hydra/v2/fosite/handler/pkce PKCERequestStorageProvider
mockgen -package internal -destination internal/refresh_token_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 RefreshTokenStorage
mockgen -package internal -destination internal/refresh_token_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 RefreshTokenStorageProvider
mockgen -package internal -destination internal/refresh_token_strategy.go github.com/ory/hydra/v2/fosite/handler/oauth2 RefreshTokenStrategy
mockgen -package internal -destination internal/refresh_token_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 RefreshTokenStrategyProvider
mockgen -package internal -destination internal/request.go github.com/ory/hydra/v2/fosite Requester
mockgen -package internal -destination internal/resource_owner_password_credentials_grant_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 ResourceOwnerPasswordCredentialsGrantStorage
mockgen -package internal -destination internal/resource_owner_password_credentials_grant_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 ResourceOwnerPasswordCredentialsGrantStorageProvider
mockgen -package internal -destination internal/revocation_handler.go github.com/ory/hydra/v2/fosite RevocationHandler
mockgen -package internal -destination internal/revocation_handlers_provider.go github.com/ory/hydra/v2/fosite RevocationHandlersProvider
mockgen -package internal -destination internal/rfc7523_key_storage.go github.com/ory/hydra/v2/fosite/handler/rfc7523 RFC7523KeyStorage
mockgen -package internal -destination internal/rfc7523_key_storage_provider.go github.com/ory/hydra/v2/fosite/handler/rfc7523 RFC7523KeyStorageProvider
mockgen -package internal -destination internal/storage.go github.com/ory/hydra/v2/fosite Storage
mockgen -package internal -destination internal/token_endpoint_handler.go github.com/ory/hydra/v2/fosite TokenEndpointHandler
mockgen -package internal -destination internal/token_introspector.go github.com/ory/hydra/v2/fosite TokenIntrospector
mockgen -package internal -destination internal/token_revocation_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 TokenRevocationStorage
mockgen -package internal -destination internal/token_revocation_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 TokenRevocationStorageProvider
mockgen -package internal -destination internal/transactional.go github.com/ory/hydra/v2/fosite Transactional
mockgen -package internal -destination internal/user_code_strategy.go github.com/ory/hydra/v2/fosite/handler/rfc8628 UserCodeStrategy
mockgen -package internal -destination internal/user_code_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/rfc8628 UserCodeStrategyProvider
goimports -w internal/

View File

@ -42,6 +42,7 @@ package fosite
//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/refresh_token_strategy_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 RefreshTokenStrategyProvider
//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/request.go github.com/ory/hydra/v2/fosite Requester
//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/resource_owner_password_credentials_grant_storage.go github.com/ory/hydra/v2/fosite/handler/oauth2 ResourceOwnerPasswordCredentialsGrantStorage
//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/resource_owner_password_credentials_grant_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 ResourceOwnerPasswordCredentialsGrantStorageProvider
//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/revocation_handler.go github.com/ory/hydra/v2/fosite RevocationHandler
//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/revocation_handlers_provider.go github.com/ory/hydra/v2/fosite RevocationHandlersProvider
//go:generate go run go.uber.org/mock/mockgen -package internal -destination internal/rfc7523_key_storage.go github.com/ory/hydra/v2/fosite/handler/rfc7523 RFC7523KeyStorage

View File

@ -20,7 +20,11 @@ var _ fosite.TokenEndpointHandler = (*ResourceOwnerPasswordCredentialsGrantHandl
// is at the time of this writing going to be omitted in the OAuth 2.1 spec. For more information on why this grant type
// is discouraged see: https://www.scottbrady91.com/oauth/why-the-resource-owner-password-credentials-grant-type-is-not-authentication-nor-suitable-for-modern-applications
type ResourceOwnerPasswordCredentialsGrantHandler struct {
Storage ResourceOwnerPasswordCredentialsGrantStorage
Storage interface {
ResourceOwnerPasswordCredentialsGrantStorageProvider
AccessTokenStorageProvider
RefreshTokenStorageProvider
}
Strategy interface {
AccessTokenStrategyProvider
RefreshTokenStrategyProvider
@ -64,7 +68,7 @@ func (c *ResourceOwnerPasswordCredentialsGrantHandler) HandleTokenEndpointReques
password := request.GetRequestForm().Get("password")
if username == "" || password == "" {
return errorsx.WithStack(fosite.ErrInvalidRequest.WithHint("Username or password are missing from the POST body."))
} else if sub, err := c.Storage.Authenticate(ctx, username, password); errors.Is(err, fosite.ErrNotFound) {
} else if sub, err := c.Storage.ResourceOwnerPasswordCredentialsGrantStorage().Authenticate(ctx, username, password); errors.Is(err, fosite.ErrNotFound) {
return errorsx.WithStack(fosite.ErrInvalidGrant.WithHint("Unable to authenticate the provided username and password credentials.").WithWrap(err).WithDebug(err.Error()))
} else if err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))

View File

@ -7,8 +7,12 @@ import (
"context"
)
// ResourceOwnerPasswordCredentialsGrantStorage provides storage for the resource owner password credentials grant.
type ResourceOwnerPasswordCredentialsGrantStorage interface {
Authenticate(ctx context.Context, name string, secret string) (subject string, err error)
AccessTokenStorageProvider
RefreshTokenStorageProvider
}
// ResourceOwnerPasswordCredentialsGrantStorageProvider provides the resource owner password credentials grant storage.
type ResourceOwnerPasswordCredentialsGrantStorageProvider interface {
ResourceOwnerPasswordCredentialsGrantStorage() ResourceOwnerPasswordCredentialsGrantStorage
}

View File

@ -22,7 +22,10 @@ import (
func TestResourceOwnerFlow_HandleTokenEndpointRequest(t *testing.T) {
ctrl := gomock.NewController(t)
store := internal.NewMockResourceOwnerPasswordCredentialsGrantStorage(ctrl)
mockRopcgStorageProvider := internal.NewMockResourceOwnerPasswordCredentialsGrantStorageProvider(ctrl)
mockRopcgStorage := internal.NewMockResourceOwnerPasswordCredentialsGrantStorage(ctrl)
mockAccessTokenStorageProvider := internal.NewMockAccessTokenStorageProvider(ctrl)
mockRefreshTokenStorageProvider := internal.NewMockRefreshTokenStorageProvider(ctrl)
t.Cleanup(ctrl.Finish)
areq := fosite.NewAccessRequest(new(fosite.DefaultSession))
@ -72,21 +75,24 @@ func TestResourceOwnerFlow_HandleTokenEndpointRequest(t *testing.T) {
areq.Form.Set("password", "pan")
areq.Client = &fosite.DefaultClient{GrantTypes: fosite.Arguments{"password"}, Scopes: []string{"foo-scope"}, Audience: []string{"https://www.ory.sh/api"}}
store.EXPECT().Authenticate(gomock.Any(), "peter", "pan").Return("", fosite.ErrNotFound)
mockRopcgStorageProvider.EXPECT().ResourceOwnerPasswordCredentialsGrantStorage().Return(mockRopcgStorage).Times(1)
mockRopcgStorage.EXPECT().Authenticate(gomock.Any(), "peter", "pan").Return("", fosite.ErrNotFound)
},
expectErr: fosite.ErrInvalidGrant,
},
{
description: "should fail because error on lookup",
setup: func(config *fosite.Config) {
store.EXPECT().Authenticate(gomock.Any(), "peter", "pan").Return("", errors.New(""))
mockRopcgStorageProvider.EXPECT().ResourceOwnerPasswordCredentialsGrantStorage().Return(mockRopcgStorage).Times(1)
mockRopcgStorage.EXPECT().Authenticate(gomock.Any(), "peter", "pan").Return("", errors.New(""))
},
expectErr: fosite.ErrServerError,
},
{
description: "should pass",
setup: func(config *fosite.Config) {
store.EXPECT().Authenticate(gomock.Any(), "peter", "pan").Return("", nil)
mockRopcgStorageProvider.EXPECT().ResourceOwnerPasswordCredentialsGrantStorage().Return(mockRopcgStorage).Times(1)
mockRopcgStorage.EXPECT().Authenticate(gomock.Any(), "peter", "pan").Return("", nil)
},
check: func(areq *fosite.AccessRequest) {
// assert.NotEmpty(t, areq.GetSession().GetExpiresAt(fosite.AccessToken))
@ -102,8 +108,17 @@ func TestResourceOwnerFlow_HandleTokenEndpointRequest(t *testing.T) {
ScopeStrategy: fosite.HierarchicScopeStrategy,
AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy,
}
mockStorage := struct {
*internal.MockResourceOwnerPasswordCredentialsGrantStorageProvider
*internal.MockAccessTokenStorageProvider
*internal.MockRefreshTokenStorageProvider
}{
MockResourceOwnerPasswordCredentialsGrantStorageProvider: mockRopcgStorageProvider,
MockAccessTokenStorageProvider: mockAccessTokenStorageProvider,
MockRefreshTokenStorageProvider: mockRefreshTokenStorageProvider,
}
h := oauth2.ResourceOwnerPasswordCredentialsGrantHandler{
Storage: store,
Storage: mockStorage,
Config: config,
}
c.setup(config)
@ -123,8 +138,10 @@ func TestResourceOwnerFlow_HandleTokenEndpointRequest(t *testing.T) {
func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) {
var (
mockRopcgStorage *internal.MockResourceOwnerPasswordCredentialsGrantStorage
mockRopcgStorageProvider *internal.MockResourceOwnerPasswordCredentialsGrantStorageProvider
mockAccessTokenStorageProvider *internal.MockAccessTokenStorageProvider
mockAccessTokenStorage *internal.MockAccessTokenStorage
mockRefreshTokenStorageProvider *internal.MockRefreshTokenStorageProvider
mockRefreshTokenStorage *internal.MockRefreshTokenStorage
mockAccessTokenStrategyProvider *internal.MockAccessTokenStrategyProvider
mockAccessTokenStrategy *internal.MockAccessTokenStrategy
@ -161,7 +178,7 @@ func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) {
areq.GrantTypes = fosite.Arguments{"password"}
mockAccessTokenStrategyProvider.EXPECT().AccessTokenStrategy().Return(mockAccessTokenStrategy).Times(1)
mockAccessTokenStrategy.EXPECT().GenerateAccessToken(gomock.Any(), areq).Return(mockAT, "bar", nil)
mockRopcgStorage.EXPECT().AccessTokenStorage().Return(mockAccessTokenStorage).Times(1)
mockAccessTokenStorageProvider.EXPECT().AccessTokenStorage().Return(mockAccessTokenStorage).Times(1)
mockAccessTokenStorage.EXPECT().CreateAccessTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil)
},
expect: func() {
@ -175,11 +192,11 @@ func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) {
areq.GrantScope("offline")
mockRefreshTokenStrategyProvider.EXPECT().RefreshTokenStrategy().Return(mockRefreshTokenStrategy).Times(1)
mockRefreshTokenStrategy.EXPECT().GenerateRefreshToken(gomock.Any(), areq).Return(mockRT, "bar", nil)
mockRopcgStorage.EXPECT().RefreshTokenStorage().Return(mockRefreshTokenStorage).Times(1)
mockRefreshTokenStorageProvider.EXPECT().RefreshTokenStorage().Return(mockRefreshTokenStorage).Times(1)
mockRefreshTokenStorage.EXPECT().CreateRefreshTokenSession(gomock.Any(), "bar", "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil)
mockAccessTokenStrategyProvider.EXPECT().AccessTokenStrategy().Return(mockAccessTokenStrategy).Times(1)
mockAccessTokenStrategy.EXPECT().GenerateAccessToken(gomock.Any(), areq).Return(mockAT, "bar", nil)
mockRopcgStorage.EXPECT().AccessTokenStorage().Return(mockAccessTokenStorage).Times(1)
mockAccessTokenStorageProvider.EXPECT().AccessTokenStorage().Return(mockAccessTokenStorage).Times(1)
mockAccessTokenStorage.EXPECT().CreateAccessTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil)
},
expect: func() {
@ -193,11 +210,11 @@ func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) {
areq.GrantTypes = fosite.Arguments{"password"}
mockAccessTokenStrategyProvider.EXPECT().AccessTokenStrategy().Return(mockAccessTokenStrategy).Times(1)
mockAccessTokenStrategy.EXPECT().GenerateAccessToken(gomock.Any(), areq).Return(mockAT, "bar", nil)
mockRopcgStorage.EXPECT().AccessTokenStorage().Return(mockAccessTokenStorage).Times(1)
mockAccessTokenStorageProvider.EXPECT().AccessTokenStorage().Return(mockAccessTokenStorage).Times(1)
mockAccessTokenStorage.EXPECT().CreateAccessTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil)
mockRefreshTokenStrategyProvider.EXPECT().RefreshTokenStrategy().Return(mockRefreshTokenStrategy).Times(1)
mockRefreshTokenStrategy.EXPECT().GenerateRefreshToken(gomock.Any(), areq).Return(mockRT, "bar", nil)
mockRopcgStorage.EXPECT().RefreshTokenStorage().Return(mockRefreshTokenStorage).Times(1)
mockRefreshTokenStorageProvider.EXPECT().RefreshTokenStorage().Return(mockRefreshTokenStorage).Times(1)
mockRefreshTokenStorage.EXPECT().CreateRefreshTokenSession(gomock.Any(), "bar", "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil)
},
expect: func() {
@ -218,14 +235,26 @@ func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) {
AccessTokenLifespan: time.Hour,
}
mockRopcgStorage = internal.NewMockResourceOwnerPasswordCredentialsGrantStorage(ctrl)
mockRopcgStorageProvider = internal.NewMockResourceOwnerPasswordCredentialsGrantStorageProvider(ctrl)
mockAccessTokenStorageProvider = internal.NewMockAccessTokenStorageProvider(ctrl)
mockAccessTokenStorage = internal.NewMockAccessTokenStorage(ctrl)
mockRefreshTokenStorageProvider = internal.NewMockRefreshTokenStorageProvider(ctrl)
mockRefreshTokenStorage = internal.NewMockRefreshTokenStorage(ctrl)
mockAccessTokenStrategyProvider = internal.NewMockAccessTokenStrategyProvider(ctrl)
mockAccessTokenStrategy = internal.NewMockAccessTokenStrategy(ctrl)
mockRefreshTokenStrategyProvider = internal.NewMockRefreshTokenStrategyProvider(ctrl)
mockRefreshTokenStrategy = internal.NewMockRefreshTokenStrategy(ctrl)
mockStorage := struct {
*internal.MockResourceOwnerPasswordCredentialsGrantStorageProvider
*internal.MockAccessTokenStorageProvider
*internal.MockRefreshTokenStorageProvider
}{
MockResourceOwnerPasswordCredentialsGrantStorageProvider: mockRopcgStorageProvider,
MockAccessTokenStorageProvider: mockAccessTokenStorageProvider,
MockRefreshTokenStorageProvider: mockRefreshTokenStorageProvider,
}
mockStrategy := struct {
*internal.MockAccessTokenStrategyProvider
*internal.MockRefreshTokenStrategyProvider
@ -235,7 +264,7 @@ func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) {
}
h = oauth2.ResourceOwnerPasswordCredentialsGrantHandler{
Storage: mockRopcgStorage,
Storage: mockStorage,
Strategy: mockStrategy,
Config: config,
}

View File

@ -9,12 +9,6 @@ import (
"github.com/ory/hydra/v2/fosite"
)
type CoreStorage interface {
AuthorizeCodeStorageProvider
AccessTokenStorageProvider
RefreshTokenStorageProvider
}
// AuthorizeCodeStorage handles storage requests related to authorization codes.
type AuthorizeCodeStorage interface {
// CreateAuthorizeCodeSession stores the authorization request for a given authorization code.

View File

@ -9,13 +9,6 @@ import (
"github.com/ory/hydra/v2/fosite"
)
// RFC8628CodeStrategy is the code strategy needed for the DeviceAuthHandler
type RFC8628CodeStrategy interface {
DeviceRateLimitStrategy
DeviceCodeStrategy
UserCodeStrategy
}
// DeviceRateLimitStrategy handles the rate limiting strategy
type DeviceRateLimitStrategy interface {
// ShouldRateLimit checks whether the token request should be rate-limited

View File

@ -21,7 +21,7 @@ type Handler struct {
Config interface {
fosite.VerifiableCredentialsNonceLifespanProvider
}
NonceManager
NonceManagerProvider
}
var _ fosite.TokenEndpointHandler = (*Handler)(nil)
@ -45,7 +45,7 @@ func (c *Handler) PopulateTokenEndpointResponse(
lifespan := c.Config.GetVerifiableCredentialsNonceLifespan(ctx)
expiry := time.Now().UTC().Add(lifespan)
nonce, err := c.NewNonce(ctx, response.GetAccessToken(), expiry)
nonce, err := c.NonceManager().NewNonce(ctx, response.GetAccessToken(), expiry)
if err != nil {
return err
}

View File

@ -15,6 +15,10 @@ import (
"github.com/ory/hydra/v2/fosite/internal"
)
type mockNonceManagerProvider struct{ n NonceManager }
func (m mockNonceManagerProvider) NonceManager() NonceManager { return m.n }
type mockNonceManager struct{ t *testing.T }
func (m *mockNonceManager) NewNonce(ctx context.Context, accessToken string, expiresAt time.Time) (string, error) {
@ -67,7 +71,7 @@ func TestHandler(t *testing.T) {
func newHandler(t *testing.T) *Handler {
return &Handler{
Config: new(fosite.Config),
NonceManager: &mockNonceManager{t: t},
Config: new(fosite.Config),
NonceManagerProvider: mockNonceManagerProvider{n: &mockNonceManager{t: t}},
}
}

View File

@ -15,3 +15,7 @@ type NonceManager interface {
// IsNonceValid checks if the given nonce is valid for the given access token and not expired.
IsNonceValid(ctx context.Context, accessToken string, nonce string) error
}
type NonceManagerProvider interface {
NonceManager() NonceManager
}

View File

@ -17,8 +17,6 @@ import (
reflect "reflect"
gomock "go.uber.org/mock/gomock"
oauth2 "github.com/ory/hydra/v2/fosite/handler/oauth2"
)
// MockResourceOwnerPasswordCredentialsGrantStorage is a mock of ResourceOwnerPasswordCredentialsGrantStorage interface.
@ -45,20 +43,6 @@ func (m *MockResourceOwnerPasswordCredentialsGrantStorage) EXPECT() *MockResourc
return m.recorder
}
// AccessTokenStorage mocks base method.
func (m *MockResourceOwnerPasswordCredentialsGrantStorage) AccessTokenStorage() oauth2.AccessTokenStorage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AccessTokenStorage")
ret0, _ := ret[0].(oauth2.AccessTokenStorage)
return ret0
}
// AccessTokenStorage indicates an expected call of AccessTokenStorage.
func (mr *MockResourceOwnerPasswordCredentialsGrantStorageMockRecorder) AccessTokenStorage() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenStorage", reflect.TypeOf((*MockResourceOwnerPasswordCredentialsGrantStorage)(nil).AccessTokenStorage))
}
// Authenticate mocks base method.
func (m *MockResourceOwnerPasswordCredentialsGrantStorage) Authenticate(ctx context.Context, name, secret string) (string, error) {
m.ctrl.T.Helper()
@ -73,17 +57,3 @@ func (mr *MockResourceOwnerPasswordCredentialsGrantStorageMockRecorder) Authenti
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authenticate", reflect.TypeOf((*MockResourceOwnerPasswordCredentialsGrantStorage)(nil).Authenticate), ctx, name, secret)
}
// RefreshTokenStorage mocks base method.
func (m *MockResourceOwnerPasswordCredentialsGrantStorage) RefreshTokenStorage() oauth2.RefreshTokenStorage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RefreshTokenStorage")
ret0, _ := ret[0].(oauth2.RefreshTokenStorage)
return ret0
}
// RefreshTokenStorage indicates an expected call of RefreshTokenStorage.
func (mr *MockResourceOwnerPasswordCredentialsGrantStorageMockRecorder) RefreshTokenStorage() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshTokenStorage", reflect.TypeOf((*MockResourceOwnerPasswordCredentialsGrantStorage)(nil).RefreshTokenStorage))
}

View File

@ -0,0 +1,58 @@
// Copyright © 2025 Ory Corp
// SPDX-License-Identifier: Apache-2.0
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ory/hydra/v2/fosite/handler/oauth2 (interfaces: ResourceOwnerPasswordCredentialsGrantStorageProvider)
//
// Generated by this command:
//
// mockgen -package internal -destination internal/resource_owner_password_credentials_grant_storage_provider.go github.com/ory/hydra/v2/fosite/handler/oauth2 ResourceOwnerPasswordCredentialsGrantStorageProvider
//
// Package internal is a generated GoMock package.
package internal
import (
reflect "reflect"
oauth2 "github.com/ory/hydra/v2/fosite/handler/oauth2"
gomock "go.uber.org/mock/gomock"
)
// MockResourceOwnerPasswordCredentialsGrantStorageProvider is a mock of ResourceOwnerPasswordCredentialsGrantStorageProvider interface.
type MockResourceOwnerPasswordCredentialsGrantStorageProvider struct {
ctrl *gomock.Controller
recorder *MockResourceOwnerPasswordCredentialsGrantStorageProviderMockRecorder
isgomock struct{}
}
// MockResourceOwnerPasswordCredentialsGrantStorageProviderMockRecorder is the mock recorder for MockResourceOwnerPasswordCredentialsGrantStorageProvider.
type MockResourceOwnerPasswordCredentialsGrantStorageProviderMockRecorder struct {
mock *MockResourceOwnerPasswordCredentialsGrantStorageProvider
}
// NewMockResourceOwnerPasswordCredentialsGrantStorageProvider creates a new mock instance.
func NewMockResourceOwnerPasswordCredentialsGrantStorageProvider(ctrl *gomock.Controller) *MockResourceOwnerPasswordCredentialsGrantStorageProvider {
mock := &MockResourceOwnerPasswordCredentialsGrantStorageProvider{ctrl: ctrl}
mock.recorder = &MockResourceOwnerPasswordCredentialsGrantStorageProviderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockResourceOwnerPasswordCredentialsGrantStorageProvider) EXPECT() *MockResourceOwnerPasswordCredentialsGrantStorageProviderMockRecorder {
return m.recorder
}
// ResourceOwnerPasswordCredentialsGrantStorage mocks base method.
func (m *MockResourceOwnerPasswordCredentialsGrantStorageProvider) ResourceOwnerPasswordCredentialsGrantStorage() oauth2.ResourceOwnerPasswordCredentialsGrantStorage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ResourceOwnerPasswordCredentialsGrantStorage")
ret0, _ := ret[0].(oauth2.ResourceOwnerPasswordCredentialsGrantStorage)
return ret0
}
// ResourceOwnerPasswordCredentialsGrantStorage indicates an expected call of ResourceOwnerPasswordCredentialsGrantStorage.
func (mr *MockResourceOwnerPasswordCredentialsGrantStorageProviderMockRecorder) ResourceOwnerPasswordCredentialsGrantStorage() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResourceOwnerPasswordCredentialsGrantStorage", reflect.TypeOf((*MockResourceOwnerPasswordCredentialsGrantStorageProvider)(nil).ResourceOwnerPasswordCredentialsGrantStorage))
}

View File

@ -15,9 +15,8 @@ package internal
import (
reflect "reflect"
gomock "go.uber.org/mock/gomock"
fosite "github.com/ory/hydra/v2/fosite"
gomock "go.uber.org/mock/gomock"
)
// MockStorage is a mock of Storage interface.
@ -44,16 +43,16 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
return m.recorder
}
// ClientManager mocks base method.
func (m *MockStorage) ClientManager() fosite.ClientManager {
// FositeClientManager mocks base method.
func (m *MockStorage) FositeClientManager() fosite.ClientManager {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClientManager")
ret := m.ctrl.Call(m, "FositeClientManager")
ret0, _ := ret[0].(fosite.ClientManager)
return ret0
}
// ClientManager indicates an expected call of ClientManager.
func (mr *MockStorageMockRecorder) ClientManager() *gomock.Call {
// FositeClientManager indicates an expected call of FositeClientManager.
func (mr *MockStorageMockRecorder) FositeClientManager() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClientManager", reflect.TypeOf((*MockStorage)(nil).ClientManager))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FositeClientManager", reflect.TypeOf((*MockStorage)(nil).FositeClientManager))
}

View File

@ -138,7 +138,7 @@ func (f *Fosite) NewIntrospectionRequest(ctx context.Context, r *http.Request, s
return &IntrospectionResponse{Active: false}, errorsx.WithStack(ErrRequestUnauthorized.WithHint("Unable to decode OAuth 2.0 Client Secret from HTTP basic authorization header, make sure it is properly encoded.").WithWrap(err).WithDebug(err.Error()))
}
client, err := f.Store.ClientManager().GetClient(ctx, clientID)
client, err := f.Store.FositeClientManager().GetClient(ctx, clientID)
if err != nil {
return &IntrospectionResponse{Active: false}, errorsx.WithStack(ErrRequestUnauthorized.WithHint("Unable to find OAuth 2.0 Client from HTTP basic authorization header.").WithWrap(err).WithDebug(err.Error()))
}

View File

@ -89,7 +89,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidRequest,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2)
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil)
},
@ -105,7 +105,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidRequest,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2)
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil)
},
@ -121,7 +121,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidRequest,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2)
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil)
},
@ -138,7 +138,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidState,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2)
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil)
},
@ -156,7 +156,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidState,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2)
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil)
},
@ -174,7 +174,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"scope": {"foo bar baz"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, Secret: []byte("1234")}, nil).MaxTimes(2)
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil)
},
@ -194,7 +194,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{
RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"},
Audience: []string{"https://cloud.ory.sh/api"},
@ -218,7 +218,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{
ResponseTypes: []string{"code token"},
RedirectURIs: []string{"https://foo.bar/cb"},
@ -258,7 +258,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api", "https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{
ResponseTypes: []string{"code token"},
RedirectURIs: []string{"https://foo.bar/cb"},
@ -298,7 +298,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"audience": {"test value", ""},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{
ResponseTypes: []string{"code token"},
RedirectURIs: []string{"https://foo.bar/cb"},
@ -338,7 +338,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{
ResponseTypes: []string{"code token"},
RedirectURIs: []string{"web+application://callback"},
@ -378,7 +378,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{
ResponseTypes: []string{"code token"},
RedirectURIs: []string{"https://foo.bar/cb"},
@ -418,7 +418,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"response_mode": {"unknown"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}, Secret: []byte("1234")}, nil).MaxTimes(2)
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil)
},
@ -438,7 +438,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"response_mode": {"form_post"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}, Secret: []byte("1234")}, nil).MaxTimes(2)
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil)
},
@ -458,7 +458,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"response_mode": {"form_post"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{
DefaultClient: &DefaultClient{
RedirectURIs: []string{"https://foo.bar/cb"},
@ -487,7 +487,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{
DefaultClient: &DefaultClient{
RedirectURIs: []string{"https://foo.bar/cb"},
@ -534,7 +534,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{
DefaultClient: &DefaultClient{
RedirectURIs: []string{"https://foo.bar/cb"},
@ -581,7 +581,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{
DefaultClient: &DefaultClient{
RedirectURIs: []string{"https://foo.bar/cb"},
@ -629,7 +629,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"response_mode": {"form_post"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(2)
store.EXPECT().FositeClientManager().Return(clientManager).Times(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}, Secret: []byte("1234")}, nil).MaxTimes(2)
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil)
},
@ -650,7 +650,7 @@ func TestNewPushedAuthorizeRequest(t *testing.T) {
"response_mode": {"form_post"},
},
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).MaxTimes(2)
store.EXPECT().FositeClientManager().Return(clientManager).MaxTimes(2)
clientManager.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}, Secret: []byte("1234")}, nil).MaxTimes(2)
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("1234")), gomock.Eq([]byte("4321"))).Return(fmt.Errorf("invalid hash"))
},

View File

@ -70,7 +70,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: ErrInvalidClient,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New(""))
},
},
@ -84,7 +84,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: ErrInvalidClient,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
@ -101,7 +101,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: nil,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
@ -121,7 +121,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: nil,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
@ -141,7 +141,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: nil,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = true
handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
@ -159,7 +159,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: nil,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
@ -179,7 +179,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: nil,
mock: func() {
store.EXPECT().ClientManager().Return(clientManager).Times(1)
store.EXPECT().FositeClientManager().Return(clientManager).Times(1)
clientManager.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false

View File

@ -9,7 +9,7 @@ import (
// Storage defines fosite's minimal storage interface.
type Storage interface {
ClientManager() ClientManager
FositeClientManager() ClientManager
}
type PARStorageProvider interface {

View File

@ -101,7 +101,7 @@ func NewMemoryStore() *MemoryStore {
}
}
func (s *MemoryStore) ClientManager() fosite.ClientManager {
func (s *MemoryStore) FositeClientManager() fosite.ClientManager {
return s
}
@ -121,6 +121,10 @@ func (s *MemoryStore) TokenRevocationStorage() oauth2.TokenRevocationStorage {
return s
}
func (s *MemoryStore) ResourceOwnerPasswordCredentialsGrantStorage() oauth2.ResourceOwnerPasswordCredentialsGrantStorage {
return s
}
func (s *MemoryStore) OpenIDConnectRequestStorage() openid.OpenIDConnectRequestStorage {
return s
}

View File

@ -76,10 +76,10 @@ func NewConfig(deps configDependencies) *Config {
}
}
func (c *Config) LoadDefaultHandlers(strategy interface{}) {
func (c *Config) LoadDefaultHandlers(storage fosite.Storage, strategy interface{}) {
factories := append(defaultFactories, c.deps.ExtraFositeFactories()...)
for _, factory := range factories {
res := factory(c, c.deps.Persister(), strategy)
res := factory(c, storage, strategy)
if ah, ok := res.(fosite.AuthorizeEndpointHandler); ok {
c.authorizeEndpointHandlers.Append(ah)
}

View File

@ -908,7 +908,7 @@ func testFositeStoreSetClientAssertionJWT(m *driver.RegistrySQL) func(*testing.T
require.True(t, ok)
jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(time.Minute))
require.NoError(t, store.ClientManager().SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry))
require.NoError(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry))
cmp, err := store.GetClientAssertionJWT(context.Background(), jti.JTI)
require.NotEqual(t, cmp.NID, uuid.Nil)
@ -923,7 +923,7 @@ func testFositeStoreSetClientAssertionJWT(m *driver.RegistrySQL) func(*testing.T
jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(time.Minute))
require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti))
assert.ErrorIs(t, store.ClientManager().SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry), fosite.ErrJTIKnown)
assert.ErrorIs(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry), fosite.ErrJTIKnown)
})
t.Run("case=deletes expired JTIs", func(t *testing.T) {
@ -933,7 +933,7 @@ func testFositeStoreSetClientAssertionJWT(m *driver.RegistrySQL) func(*testing.T
require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), expiredJTI))
newJTI := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(time.Minute))
require.NoError(t, store.ClientManager().SetClientAssertionJWT(context.Background(), newJTI.JTI, newJTI.Expiry))
require.NoError(t, store.SetClientAssertionJWT(context.Background(), newJTI.JTI, newJTI.Expiry))
_, err := store.GetClientAssertionJWT(context.Background(), expiredJTI.JTI)
assert.True(t, errors.Is(err, sqlcon.ErrNoRows))
@ -951,7 +951,7 @@ func testFositeStoreSetClientAssertionJWT(m *driver.RegistrySQL) func(*testing.T
require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti))
jti.Expiry = jti.Expiry.Add(2 * time.Minute)
assert.NoError(t, store.ClientManager().SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry))
assert.NoError(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry))
cmp, err := store.GetClientAssertionJWT(context.Background(), jti.JTI)
assert.NoError(t, err)
assert.Equal(t, jti, cmp)
@ -965,7 +965,7 @@ func testFositeStoreClientAssertionJWTValid(m *driver.RegistrySQL) func(*testing
store, ok := m.OAuth2Storage().(oauth2.AssertionJWTReader)
require.True(t, ok)
assert.NoError(t, store.ClientManager().ClientAssertionJWTValid(context.Background(), uuid.Must(uuid.NewV4()).String()))
assert.NoError(t, store.ClientAssertionJWTValid(context.Background(), uuid.Must(uuid.NewV4()).String()))
})
t.Run("case=returns invalid on known JTI", func(t *testing.T) {
@ -975,7 +975,7 @@ func testFositeStoreClientAssertionJWTValid(m *driver.RegistrySQL) func(*testing
require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti))
assert.True(t, errors.Is(store.ClientManager().ClientAssertionJWTValid(context.Background(), jti.JTI), fosite.ErrJTIKnown))
assert.True(t, errors.Is(store.ClientAssertionJWTValid(context.Background(), jti.JTI), fosite.ErrJTIKnown))
})
t.Run("case=returns valid on expired JTI", func(t *testing.T) {
@ -985,7 +985,7 @@ func testFositeStoreClientAssertionJWTValid(m *driver.RegistrySQL) func(*testing
require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti))
assert.NoError(t, store.ClientManager().ClientAssertionJWTValid(context.Background(), jti.JTI))
assert.NoError(t, store.ClientAssertionJWTValid(context.Background(), jti.JTI))
})
}
}

View File

@ -205,9 +205,9 @@ func TestDeviceTokenRequest(t *testing.T) {
for _, testCase := range testCases {
t.Run("case="+testCase.description, func(t *testing.T) {
code, signature, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(context.TODO())
code, signature, err := reg.DeviceCodeStrategy().GenerateDeviceCode(context.TODO())
require.NoError(t, err)
_, userCodeSignature, err := reg.RFC8628HMACStrategy().GenerateUserCode(context.TODO())
_, userCodeSignature, err := reg.UserCodeStrategy().GenerateUserCode(context.TODO())
require.NoError(t, err)
if testCase.setUp != nil {
@ -428,7 +428,6 @@ func TestDeviceCodeWithDefaultStrategy(t *testing.T) {
subject := "aeneas-rekkas"
nonce := uuid.New()
t.Run("case=perform device flow without ID and refresh tokens", func(t *testing.T) {
c, conf := newDeviceClient(t, reg)
conf.Scopes = []string{"hydra"}
testhelpers.NewDeviceLoginConsentUI(t, reg.Config(),
@ -453,7 +452,6 @@ func TestDeviceCodeWithDefaultStrategy(t *testing.T) {
assert.Empty(t, token.RefreshToken)
})
t.Run("case=perform device flow with ID token", func(t *testing.T) {
c, conf := newDeviceClient(t, reg)
conf.Scopes = []string{"openid", "hydra"}
testhelpers.NewDeviceLoginConsentUI(t, reg.Config(),
@ -479,7 +477,6 @@ func TestDeviceCodeWithDefaultStrategy(t *testing.T) {
assert.Empty(t, token.RefreshToken)
})
t.Run("case=perform device flow with refresh token", func(t *testing.T) {
c, conf := newDeviceClient(t, reg)
conf.Scopes = []string{"hydra", "offline"}
testhelpers.NewDeviceLoginConsentUI(t, reg.Config(),

View File

@ -35,5 +35,7 @@ type Registry interface {
OpenIDConnectRequestValidator() *openid.OpenIDConnectRequestValidator
AccessRequestHooks() []AccessRequestHook
OAuth2ProviderConfig() fosite.Configurator
RFC8628HMACStrategy() rfc8628.RFC8628CodeStrategy
rfc8628.DeviceRateLimitStrategyProvider
rfc8628.DeviceCodeStrategyProvider
rfc8628.UserCodeStrategyProvider
}

View File

@ -16,6 +16,7 @@ import (
"github.com/ory/hydra/v2/fosite"
"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/oauth2"
"github.com/ory/hydra/v2/persistence"
"github.com/ory/hydra/v2/x"
"github.com/ory/pop/v6"
@ -27,8 +28,11 @@ import (
)
var (
_ persistence.Persister = (*Persister)(nil)
_ fosite.Transactional = (*Persister)(nil)
_ persistence.Persister = (*Persister)(nil)
_ fosite.Transactional = (*Persister)(nil)
_ fosite.ClientManager = (*Persister)(nil)
_ oauth2.AssertionJWTReader = (*Persister)(nil)
_ x.FositeStorer = (*Persister)(nil)
)
var ErrNoTransactionOpen = errors.New("There is no Transaction in this context.")
@ -67,62 +71,6 @@ type (
}
)
func (p *BasePersister) BeginTX(ctx context.Context) (_ context.Context, err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.BeginTX")
defer otelx.End(span, &err)
fallback := &pop.Connection{TX: &pop.Tx{}}
if popx.GetConnection(ctx, fallback).TX != fallback.TX {
return context.WithValue(ctx, skipCommitKey, true), nil // no-op
}
tx, err := p.c.Store.TransactionContextOptions(ctx, &sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
ReadOnly: false,
})
c := &pop.Connection{
TX: tx,
Store: tx,
ID: uuid.Must(uuid.NewV4()).String(),
Dialect: p.c.Dialect,
}
return popx.WithTransaction(ctx, c), err
}
func (p *BasePersister) Commit(ctx context.Context) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Commit")
defer otelx.End(span, &err)
if skip, ok := ctx.Value(skipCommitKey).(bool); ok && skip {
return nil // we skipped BeginTX, so we also skip Commit
}
fallback := &pop.Connection{TX: &pop.Tx{}}
tx := popx.GetConnection(ctx, fallback)
if tx.TX == fallback.TX || tx.TX == nil {
return errors.WithStack(ErrNoTransactionOpen)
}
return errors.WithStack(tx.TX.Commit())
}
func (p *BasePersister) Rollback(ctx context.Context) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Rollback")
defer otelx.End(span, &err)
if skip, ok := ctx.Value(skipCommitKey).(bool); ok && skip {
return nil // we skipped BeginTX, so we also skip Rollback
}
fallback := &pop.Connection{TX: &pop.Tx{}}
tx := popx.GetConnection(ctx, fallback)
if tx.TX == fallback.TX || tx.TX == nil {
return errors.WithStack(ErrNoTransactionOpen)
}
return errors.WithStack(tx.TX.Rollback())
}
func NewPersister(base *BasePersister, r Dependencies) *Persister {
return &Persister{
BasePersister: base,
@ -192,3 +140,62 @@ func (p *BasePersister) mustSetNetwork(ctx context.Context, v interface{}) {
func (p *BasePersister) Transaction(ctx context.Context, f func(ctx context.Context, c *pop.Connection) error) error {
return popx.Transaction(ctx, p.c, f)
}
// BeginTX implements Transactional.
func (p *BasePersister) BeginTX(ctx context.Context) (_ context.Context, err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.BeginTX")
defer otelx.End(span, &err)
fallback := &pop.Connection{TX: &pop.Tx{}}
if popx.GetConnection(ctx, fallback).TX != fallback.TX {
return context.WithValue(ctx, skipCommitKey, true), nil // no-op
}
tx, err := p.c.Store.TransactionContextOptions(ctx, &sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
ReadOnly: false,
})
c := &pop.Connection{
TX: tx,
Store: tx,
ID: uuid.Must(uuid.NewV4()).String(),
Dialect: p.c.Dialect,
}
return popx.WithTransaction(ctx, c), err
}
// Commit implements Transactional.
func (p *BasePersister) Commit(ctx context.Context) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Commit")
defer otelx.End(span, &err)
if skip, ok := ctx.Value(skipCommitKey).(bool); ok && skip {
return nil // we skipped BeginTX, so we also skip Commit
}
fallback := &pop.Connection{TX: &pop.Tx{}}
tx := popx.GetConnection(ctx, fallback)
if tx.TX == fallback.TX || tx.TX == nil {
return errors.WithStack(ErrNoTransactionOpen)
}
return errors.WithStack(tx.TX.Commit())
}
// Rollback implements Transactional.
func (p *BasePersister) Rollback(ctx context.Context) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Rollback")
defer otelx.End(span, &err)
if skip, ok := ctx.Value(skipCommitKey).(bool); ok && skip {
return nil // we skipped BeginTX, so we also skip Rollback
}
fallback := &pop.Connection{TX: &pop.Tx{}}
tx := popx.GetConnection(ctx, fallback)
if tx.TX == fallback.TX || tx.TX == nil {
return errors.WithStack(ErrNoTransactionOpen)
}
return errors.WithStack(tx.TX.Rollback())
}

View File

@ -7,6 +7,7 @@ import (
"context"
)
// Authenticate implements ResourceOwnerPasswordCredentialsGrantStorage.
func (p *Persister) Authenticate(ctx context.Context, name, secret string) (subject string, err error) {
session, err := p.r.Kratos().Authenticate(ctx, name, secret)
if err != nil {

View File

@ -18,23 +18,51 @@ import (
"github.com/ory/x/sqlcon"
)
func (p *Persister) GetConcreteClient(ctx context.Context, id string) (c *client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConcreteClient",
// AuthenticateClient implements client.Manager.
func (p *Persister) AuthenticateClient(ctx context.Context, id string, secret []byte) (_ *client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AuthenticateClient",
trace.WithAttributes(events.ClientID(id)),
)
defer otelx.End(span, &err)
var cl client.Client
if err := p.QueryWithNetwork(ctx).Where("id = ?", id).First(&cl); err != nil {
return nil, sqlcon.HandleError(err)
c, err := p.GetConcreteClient(ctx, id)
if err != nil {
return nil, err
}
return &cl, nil
if err := p.r.ClientHasher().Compare(ctx, c.GetHashedSecret(), secret); err != nil {
return nil, err
}
return c, nil
}
func (p *Persister) GetClient(ctx context.Context, id string) (fosite.Client, error) {
return p.GetConcreteClient(ctx, id)
// CreateClient implements client.Storage.
func (p *Persister) CreateClient(ctx context.Context, c *client.Client) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateClient")
defer otelx.End(span, &err)
h, err := p.r.ClientHasher().Hash(ctx, []byte(c.Secret))
if err != nil {
return err
}
c.Secret = string(h)
if c.ID == "" {
c.ID = uuid.Must(uuid.NewV4()).String()
}
if err := sqlcon.HandleError(p.CreateWithNetwork(ctx, c)); err != nil {
return err
}
events.Trace(ctx, events.ClientCreated,
events.WithClientID(c.ID),
events.WithClientName(c.Name))
return nil
}
// UpdateClient implements client.Storage.
func (p *Persister) UpdateClient(ctx context.Context, cl *client.Client) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateClient",
trace.WithAttributes(events.ClientID(cl.ID)),
@ -79,48 +107,7 @@ func (p *Persister) UpdateClient(ctx context.Context, cl *client.Client) (err er
})
}
func (p *Persister) AuthenticateClient(ctx context.Context, id string, secret []byte) (_ *client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AuthenticateClient",
trace.WithAttributes(events.ClientID(id)),
)
defer otelx.End(span, &err)
c, err := p.GetConcreteClient(ctx, id)
if err != nil {
return nil, err
}
if err := p.r.ClientHasher().Compare(ctx, c.GetHashedSecret(), secret); err != nil {
return nil, err
}
return c, nil
}
func (p *Persister) CreateClient(ctx context.Context, c *client.Client) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateClient")
defer otelx.End(span, &err)
h, err := p.r.ClientHasher().Hash(ctx, []byte(c.Secret))
if err != nil {
return err
}
c.Secret = string(h)
if c.ID == "" {
c.ID = uuid.Must(uuid.NewV4()).String()
}
if err := sqlcon.HandleError(p.CreateWithNetwork(ctx, c)); err != nil {
return err
}
events.Trace(ctx, events.ClientCreated,
events.WithClientID(c.ID),
events.WithClientName(c.Name))
return nil
}
// DeleteClient implements client.Storage.
func (p *Persister) DeleteClient(ctx context.Context, id string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteClient",
trace.WithAttributes(events.ClientID(id)),
@ -143,6 +130,7 @@ func (p *Persister) DeleteClient(ctx context.Context, id string) (err error) {
return nil
}
// GetClients implements client.Storage.
func (p *Persister) GetClients(ctx context.Context, filters client.Filter) (cs []client.Client, _ *keysetpagination.Paginator, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetClients")
defer otelx.End(span, &err)
@ -170,3 +158,22 @@ func (p *Persister) GetClients(ctx context.Context, filters client.Filter) (cs [
cs, nextPage := keysetpagination.Result(cs, paginator)
return cs, nextPage, nil
}
// GetConcreteClient implements client.Storage.
func (p *Persister) GetConcreteClient(ctx context.Context, id string) (c *client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConcreteClient",
trace.WithAttributes(events.ClientID(id)),
)
defer otelx.End(span, &err)
var cl client.Client
if err := p.QueryWithNetwork(ctx).Where("id = ?", id).First(&cl); err != nil {
return nil, sqlcon.HandleError(err)
}
return &cl, nil
}
// GetClient implements fosite.ClientManager.
func (p *Persister) GetClient(ctx context.Context, id string) (fosite.Client, error) {
return p.GetConcreteClient(ctx, id)
}

View File

@ -27,7 +27,7 @@ import (
"github.com/ory/x/sqlxx"
)
var _ consent.Manager = &Persister{}
var _ consent.Manager = (*Persister)(nil)
func (p *Persister) RevokeSubjectConsentSession(ctx context.Context, user string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectConsentSession")

View File

@ -17,7 +17,6 @@ import (
"github.com/tidwall/gjson"
"github.com/ory/hydra/v2/fosite"
"github.com/ory/hydra/v2/fosite/handler/rfc8628"
"github.com/ory/hydra/v2/oauth2"
"github.com/ory/x/otelx"
"github.com/ory/x/sqlcon"
@ -102,10 +101,6 @@ func (r *DeviceRequestSQL) toRequest(ctx context.Context, session fosite.Session
}, nil
}
func (p *Persister) DeviceAuthStorage() rfc8628.DeviceAuthStorage {
return p
}
func (p *Persister) sqlDeviceSchemaFromRequest(ctx context.Context, deviceCodeSignature, userCodeSignature string, r fosite.DeviceRequester, expiresAt time.Time) (*DeviceRequestSQL, error) {
subject := ""
if r.GetSession() == nil {
@ -157,7 +152,7 @@ func (p *Persister) sqlDeviceSchemaFromRequest(ctx context.Context, deviceCodeSi
}, nil
}
// CreateDeviceCodeSession creates a new device code session and stores it in the database
// CreateDeviceCodeSession creates a new device code session and stores it in the database. Implements DeviceAuthStorage.
func (p *Persister) CreateDeviceAuthSession(ctx context.Context, deviceCodeSignature, userCodeSignature string, requester fosite.DeviceRequester) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateDeviceCodeSession")
defer otelx.End(span, &err)
@ -178,33 +173,7 @@ func (p *Persister) CreateDeviceAuthSession(ctx context.Context, deviceCodeSigna
return nil
}
// UpdateDeviceCodeSessionBySignature updates a device code session by the device_code signature
func (p *Persister) UpdateDeviceCodeSessionBySignature(ctx context.Context, signature string, requester fosite.DeviceRequester) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateDeviceCodeSessionBySignature")
defer otelx.End(span, &err)
req, err := p.sqlDeviceSchemaFromRequest(ctx, signature, "", requester, requester.GetSession().GetExpiresAt(fosite.DeviceCode).UTC())
if err != nil {
return err
}
stmt := fmt.Sprintf(
"UPDATE %s SET granted_scope=?, granted_audience=?, session_data=?, user_code_state=?, subject=?, challenge_id=? WHERE device_code_signature=? AND nid = ?",
sqlTableDeviceAuthCodes,
)
/* #nosec G201 table is static */
return sqlcon.HandleError(
p.Connection(ctx).RawQuery(stmt,
req.GrantedScope, req.GrantedAudience,
req.Session, req.UserCodeState,
req.Subject, req.ConsentChallenge,
signature, p.NetworkID(ctx),
).Exec(),
)
}
// GetDeviceCodeSession returns a device code session from the database
// GetDeviceCodeSession returns a device code session from the database. Implements DeviceAuthStorage.
func (p *Persister) GetDeviceCodeSession(ctx context.Context, signature string, session fosite.Session) (_ fosite.DeviceRequester, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetDeviceCodeSession")
defer otelx.End(span, &err)
@ -227,34 +196,7 @@ func (p *Persister) GetDeviceCodeSession(ctx context.Context, signature string,
return r.toRequest(ctx, session, p)
}
// GetDeviceCodeSessionByRequestID returns a device code session from the database
func (p *Persister) GetDeviceCodeSessionByRequestID(ctx context.Context, requestID string, session fosite.Session) (_ fosite.DeviceRequester, deviceCodeSignature string, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetDeviceCodeSessionByRequestID")
defer otelx.End(span, &err)
r := DeviceRequestSQL{}
if err = p.QueryWithNetwork(ctx).Where("request_id = ?", requestID).First(&r); errors.Is(err, sql.ErrNoRows) {
return nil, "", errors.WithStack(fosite.ErrNotFound)
} else if err != nil {
return nil, "", sqlcon.HandleError(err)
}
if !r.DeviceCodeActive {
fr, err := r.toRequest(ctx, session, p)
if err != nil {
return nil, "", err
}
return fr, r.ID, errors.WithStack(fosite.ErrInactiveToken)
}
fr, err := r.toRequest(ctx, session, p)
if err != nil {
return nil, "", err
}
return fr, r.ID, nil
}
// InvalidateDeviceCodeSession invalidates a device code session
// InvalidateDeviceCodeSession invalidates a device code session. Implements DeviceAuthStorage.
func (p *Persister) InvalidateDeviceCodeSession(ctx context.Context, signature string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InvalidateDeviceCodeSession")
defer otelx.End(span, &err)
@ -266,7 +208,7 @@ func (p *Persister) InvalidateDeviceCodeSession(ctx context.Context, signature s
Delete(DeviceRequestSQL{}))
}
// GetUserCodeSession returns a user code session from the database
// GetUserCodeSession returns a user code session from the database. Implements FositeStorer.
func (p *Persister) GetUserCodeSession(ctx context.Context, signature string, session fosite.Session) (_ fosite.DeviceRequester, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetUserCodeSession")
defer otelx.End(span, &err)
@ -293,3 +235,56 @@ func (p *Persister) GetUserCodeSession(ctx context.Context, signature string, se
return fr, err
}
// GetDeviceCodeSessionByRequestID returns a device code session from the database. Implements FositeStorer.
func (p *Persister) GetDeviceCodeSessionByRequestID(ctx context.Context, requestID string, session fosite.Session) (_ fosite.DeviceRequester, deviceCodeSignature string, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetDeviceCodeSessionByRequestID")
defer otelx.End(span, &err)
r := DeviceRequestSQL{}
if err = p.QueryWithNetwork(ctx).Where("request_id = ?", requestID).First(&r); errors.Is(err, sql.ErrNoRows) {
return nil, "", errors.WithStack(fosite.ErrNotFound)
} else if err != nil {
return nil, "", sqlcon.HandleError(err)
}
if !r.DeviceCodeActive {
fr, err := r.toRequest(ctx, session, p)
if err != nil {
return nil, "", err
}
return fr, r.ID, errors.WithStack(fosite.ErrInactiveToken)
}
fr, err := r.toRequest(ctx, session, p)
if err != nil {
return nil, "", err
}
return fr, r.ID, nil
}
// UpdateDeviceCodeSessionBySignature updates a device code session by the device_code signature. Implements FositeStorer.
func (p *Persister) UpdateDeviceCodeSessionBySignature(ctx context.Context, signature string, requester fosite.DeviceRequester) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateDeviceCodeSessionBySignature")
defer otelx.End(span, &err)
req, err := p.sqlDeviceSchemaFromRequest(ctx, signature, "", requester, requester.GetSession().GetExpiresAt(fosite.DeviceCode).UTC())
if err != nil {
return err
}
stmt := fmt.Sprintf(
"UPDATE %s SET granted_scope=?, granted_audience=?, session_data=?, user_code_state=?, subject=?, challenge_id=? WHERE device_code_signature=? AND nid = ?",
sqlTableDeviceAuthCodes,
)
/* #nosec G201 table is static */
return sqlcon.HandleError(
p.Connection(ctx).RawQuery(stmt,
req.GrantedScope, req.GrantedAudience,
req.Session, req.UserCodeState,
req.Subject, req.ConsentChallenge,
signature, p.NetworkID(ctx),
).Exec(),
)
}

View File

@ -12,7 +12,6 @@ import (
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/ory/hydra/v2/fosite/handler/rfc7523"
"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/oauth2/trust"
"github.com/ory/pop/v6"
@ -22,7 +21,7 @@ import (
"github.com/ory/x/sqlxx"
)
var _ trust.GrantManager = &Persister{}
var _ trust.GrantManager = (*Persister)(nil)
type SQLGrant struct {
ID uuid.UUID `db:"id"`
@ -41,10 +40,37 @@ func (SQLGrant) TableName() string {
return "hydra_oauth2_trusted_jwt_bearer_issuer"
}
func (p *Persister) RFC7523KeyStorage() rfc7523.RFC7523KeyStorage {
return p
func (SQLGrant) fromGrant(g trust.Grant) SQLGrant {
return SQLGrant{
ID: g.ID,
Issuer: g.Issuer,
Subject: g.Subject,
AllowAnySubject: g.AllowAnySubject,
Scope: g.Scope,
KeySet: g.PublicKey.Set,
KeyID: g.PublicKey.KeyID,
CreatedAt: g.CreatedAt,
ExpiresAt: g.ExpiresAt,
}
}
func (d SQLGrant) toGrant() trust.Grant {
return trust.Grant{
ID: d.ID,
Issuer: d.Issuer,
Subject: d.Subject,
AllowAnySubject: d.AllowAnySubject,
Scope: d.Scope,
PublicKey: trust.PublicKey{
Set: d.KeySet,
KeyID: d.KeyID,
},
CreatedAt: d.CreatedAt,
ExpiresAt: d.ExpiresAt,
}
}
// CreateGrant implements GrantManager
func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jose.JSONWebKey) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateGrant")
defer otelx.End(span, &err)
@ -66,6 +92,7 @@ func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jo
})
}
// GetConcreteGrant implements GrantManager
func (p *Persister) GetConcreteGrant(ctx context.Context, id uuid.UUID) (_ trust.Grant, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConcreteGrant")
defer otelx.End(span, &err)
@ -78,6 +105,7 @@ func (p *Persister) GetConcreteGrant(ctx context.Context, id uuid.UUID) (_ trust
return data.toGrant(), nil
}
// DeleteGrant implements GrantManager
func (p *Persister) DeleteGrant(ctx context.Context, id uuid.UUID) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteGrant")
defer otelx.End(span, &err)
@ -96,6 +124,7 @@ func (p *Persister) DeleteGrant(ctx context.Context, id uuid.UUID) (err error) {
})
}
// GetGrants implements GrantManager
func (p *Persister) GetGrants(ctx context.Context, optionalIssuer string, pageOpts ...keysetpagination.Option) (_ []trust.Grant, _ *keysetpagination.Paginator, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetGrants")
defer otelx.End(span, &err)
@ -123,6 +152,19 @@ func (p *Persister) GetGrants(ctx context.Context, optionalIssuer string, pageOp
return grants, nextPage, nil
}
// FlushInactiveGrants implements GrantManager
func (p *Persister) FlushInactiveGrants(ctx context.Context, notAfter time.Time, _ int, _ int) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveGrants")
defer otelx.End(span, &err)
deleteUntil := time.Now().UTC()
if deleteUntil.After(notAfter) {
deleteUntil = notAfter
}
return sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("expires_at < ?", deleteUntil).Delete(&SQLGrant{}))
}
// GetPublicKey implements RFC7523KeyStorage
func (p *Persister) GetPublicKey(ctx context.Context, issuer string, subject string, keyId string) (_ *jose.JSONWebKey, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKey")
defer otelx.End(span, &err)
@ -151,6 +193,7 @@ func (p *Persister) GetPublicKey(ctx context.Context, issuer string, subject str
return &keySet.Keys[0], nil
}
// GetPublicKeys implements RFC7523KeyStorage
func (p *Persister) GetPublicKeys(ctx context.Context, issuer string, subject string) (_ *jose.JSONWebKeySet, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKeys")
defer otelx.End(span, &err)
@ -212,6 +255,7 @@ func (p *Persister) GetPublicKeys(ctx context.Context, issuer string, subject st
return js.ToJWK(ctx, p.r.KeyCipher())
}
// GetPublicKeyScopes implements RFC7523KeyStorage
func (p *Persister) GetPublicKeyScopes(ctx context.Context, issuer string, subject string, keyId string) (_ []string, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKeyScopes")
defer otelx.End(span, &err)
@ -234,6 +278,7 @@ func (p *Persister) GetPublicKeyScopes(ctx context.Context, issuer string, subje
return scopes, nil
}
// IsJWTUsed implements RFC7523KeyStorage
func (p *Persister) IsJWTUsed(ctx context.Context, jti string) (ok bool, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.IsJWTUsed")
defer otelx.End(span, &err)
@ -246,50 +291,10 @@ func (p *Persister) IsJWTUsed(ctx context.Context, jti string) (ok bool, err err
return false, nil
}
// MarkJWTUsedForTime implements RFC7523KeyStorage
func (p *Persister) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MarkJWTUsedForTime")
defer otelx.End(span, &err)
return p.SetClientAssertionJWT(ctx, jti, exp)
}
func (SQLGrant) fromGrant(g trust.Grant) SQLGrant {
return SQLGrant{
ID: g.ID,
Issuer: g.Issuer,
Subject: g.Subject,
AllowAnySubject: g.AllowAnySubject,
Scope: g.Scope,
KeySet: g.PublicKey.Set,
KeyID: g.PublicKey.KeyID,
CreatedAt: g.CreatedAt,
ExpiresAt: g.ExpiresAt,
}
}
func (d SQLGrant) toGrant() trust.Grant {
return trust.Grant{
ID: d.ID,
Issuer: d.Issuer,
Subject: d.Subject,
AllowAnySubject: d.AllowAnySubject,
Scope: d.Scope,
PublicKey: trust.PublicKey{
Set: d.KeySet,
KeyID: d.KeyID,
},
CreatedAt: d.CreatedAt,
ExpiresAt: d.ExpiresAt,
}
}
func (p *Persister) FlushInactiveGrants(ctx context.Context, notAfter time.Time, _ int, _ int) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveGrants")
defer otelx.End(span, &err)
deleteUntil := time.Now().UTC()
if deleteUntil.After(notAfter) {
deleteUntil = notAfter
}
return sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("expires_at < ?", deleteUntil).Delete(&SQLGrant{}))
}

View File

@ -26,6 +26,7 @@ type JWKPersister struct {
*BasePersister
}
// GenerateAndPersistKeySet implements jwk.Manager.
func (p *JWKPersister) GenerateAndPersistKeySet(ctx context.Context, set, kid, alg, use string) (_ *jose.JSONWebKeySet, err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GenerateAndPersistKeySet",
trace.WithAttributes(
@ -51,6 +52,7 @@ func (p *JWKPersister) GenerateAndPersistKeySet(ctx context.Context, set, kid, a
return keys, nil
}
// AddKey implements jwk.Manager.
func (p *JWKPersister) AddKey(ctx context.Context, set string, key *jose.JSONWebKey) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AddKey",
trace.WithAttributes(
@ -76,6 +78,7 @@ func (p *JWKPersister) AddKey(ctx context.Context, set string, key *jose.JSONWeb
}))
}
// AddKeySet implements jwk.Manager.
func (p *JWKPersister) AddKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AddKeySet", trace.WithAttributes(attribute.String("set", set)))
defer otelx.End(span, &err)
@ -105,7 +108,7 @@ func (p *JWKPersister) AddKeySet(ctx context.Context, set string, keys *jose.JSO
})
}
// UpdateKey updates or creates the key.
// UpdateKey updates or creates the key. Implements jwk.Manager.
func (p *JWKPersister) UpdateKey(ctx context.Context, set string, key *jose.JSONWebKey) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateKey",
trace.WithAttributes(
@ -124,7 +127,7 @@ func (p *JWKPersister) UpdateKey(ctx context.Context, set string, key *jose.JSON
})
}
// UpdateKeySet updates or creates the key set.
// UpdateKeySet updates or creates the key set. Implements jwk.Manager.
func (p *JWKPersister) UpdateKeySet(ctx context.Context, set string, keySet *jose.JSONWebKeySet) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateKeySet", trace.WithAttributes(attribute.String("set", set)))
defer otelx.End(span, &err)
@ -140,6 +143,7 @@ func (p *JWKPersister) UpdateKeySet(ctx context.Context, set string, keySet *jos
})
}
// GetKey implements jwk.Manager.
func (p *JWKPersister) GetKey(ctx context.Context, set, kid string) (_ *jose.JSONWebKeySet, err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetKey",
trace.WithAttributes(
@ -170,6 +174,7 @@ func (p *JWKPersister) GetKey(ctx context.Context, set, kid string) (_ *jose.JSO
}, nil
}
// GetKeySet implements jwk.Manager.
func (p *JWKPersister) GetKeySet(ctx context.Context, set string) (keys *jose.JSONWebKeySet, err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetKeySet", trace.WithAttributes(attribute.String("set", set)))
defer otelx.End(span, &err)
@ -185,6 +190,7 @@ func (p *JWKPersister) GetKeySet(ctx context.Context, set string) (keys *jose.JS
return js.ToJWK(ctx, aead.NewAESGCM(p.d.Config()))
}
// DeleteKey implements jwk.Manager.
func (p *JWKPersister) DeleteKey(ctx context.Context, set, kid string) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteKey",
trace.WithAttributes(
@ -196,6 +202,7 @@ func (p *JWKPersister) DeleteKey(ctx context.Context, set, kid string) (err erro
return sqlcon.HandleError(err)
}
// DeleteKeySet implements jwk.Manager.
func (p *JWKPersister) DeleteKeySet(ctx context.Context, set string) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteKeySet", trace.WithAttributes(attribute.String("set", set)))
defer otelx.End(span, &err)

View File

@ -163,10 +163,10 @@ func (s *PersisterTestSuite) TestClientAssertionJWTValid() {
for k, r := range s.registries {
s.T().Run(k, func(t *testing.T) {
jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(24*time.Hour))
require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t1, jti.JTI, jti.Expiry))
require.NoError(t, r.Persister().SetClientAssertionJWT(s.t1, jti.JTI, jti.Expiry))
require.NoError(t, r.Persister().ClientManager().ClientAssertionJWTValid(s.t2, jti.JTI))
require.Error(t, r.Persister().ClientManager().ClientAssertionJWTValid(s.t1, jti.JTI))
require.NoError(t, r.Persister().ClientAssertionJWTValid(s.t2, jti.JTI))
require.Error(t, r.Persister().ClientAssertionJWTValid(s.t1, jti.JTI))
})
}
}
@ -799,7 +799,7 @@ func (s *PersisterTestSuite) TestGetClientAssertionJWT() {
store, ok := r.OAuth2Storage().(oauth2.AssertionJWTReader)
require.True(t, ok)
expected := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(24*time.Hour))
require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t1, expected.JTI, expected.Expiry))
require.NoError(t, r.Persister().SetClientAssertionJWT(s.t1, expected.JTI, expected.Expiry))
_, err := store.GetClientAssertionJWT(s.t2, expected.JTI)
require.Error(t, err)
@ -1151,7 +1151,7 @@ func (s *PersisterTestSuite) TestIsJWTUsed() {
for k, r := range s.registries {
s.T().Run(k, func(t *testing.T) {
jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(24*time.Hour))
require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t1, jti.JTI, jti.Expiry))
require.NoError(t, r.Persister().SetClientAssertionJWT(s.t1, jti.JTI, jti.Expiry))
actual, err := r.Persister().IsJWTUsed(s.t2, jti.JTI)
require.NoError(t, err)
@ -1245,9 +1245,9 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithFrontChannelLog
func (s *PersisterTestSuite) TestMarkJWTUsedForTime() {
for k, r := range s.registries {
s.T().Run(k, func(t *testing.T) {
require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t1, "a", time.Now().Add(-24*time.Hour)))
require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t2, "a", time.Now().Add(-24*time.Hour)))
require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t2, "b", time.Now().Add(-24*time.Hour)))
require.NoError(t, r.Persister().SetClientAssertionJWT(s.t1, "a", time.Now().Add(-24*time.Hour)))
require.NoError(t, r.Persister().SetClientAssertionJWT(s.t2, "a", time.Now().Add(-24*time.Hour)))
require.NoError(t, r.Persister().SetClientAssertionJWT(s.t2, "b", time.Now().Add(-24*time.Hour)))
require.NoError(t, r.Persister().MarkJWTUsedForTime(s.t2, "a", time.Now().Add(48*time.Hour)))
@ -1469,7 +1469,7 @@ func (s *PersisterTestSuite) TestSetClientAssertionJWT() {
for k, r := range s.registries {
s.T().Run(k, func(t *testing.T) {
jti := oauth2.NewBlacklistedJTI(uuid.Must(uuid.NewV4()).String(), time.Now().Add(24*time.Hour))
require.NoError(t, r.Persister().ClientManager().SetClientAssertionJWT(s.t1, jti.JTI, jti.Expiry))
require.NoError(t, r.Persister().SetClientAssertionJWT(s.t1, jti.JTI, jti.Expiry))
actual := &oauth2.BlacklistedJTI{}
require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, jti.ID))

View File

@ -10,13 +10,17 @@ import (
"github.com/pkg/errors"
"github.com/ory/hydra/v2/fosite"
"github.com/ory/hydra/v2/fosite/handler/verifiable"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/otelx"
)
var _ verifiable.NonceManager = (*Persister)(nil)
// Set the aadAccessTokenPrefix to something unique to avoid ciphertext confusion with other usages of the AEAD cipher.
var aadAccessTokenPrefix = "vc-nonce-at:" // nolint:gosec
// NewNonce implements NonceManager.
func (p *Persister) NewNonce(ctx context.Context, accessToken string, expiresIn time.Time) (res string, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.NewNonce")
defer otelx.End(span, &err)
@ -27,6 +31,7 @@ func (p *Persister) NewNonce(ctx context.Context, accessToken string, expiresIn
return p.r.FlowCipher().Encrypt(ctx, plaintext, aad)
}
// IsNonceValid implements NonceManager.
func (p *Persister) IsNonceValid(ctx context.Context, accessToken, nonce string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.IsNonceValid")
defer otelx.End(span, &err)

File diff suppressed because it is too large Load Diff

View File

@ -17,7 +17,7 @@ import (
)
type FositeStorer interface {
fosite.Storage
fosite.ClientManager
oauth2.AuthorizeCodeStorage
oauth2.AccessTokenStorage
oauth2.RefreshTokenStorage
@ -29,6 +29,7 @@ type FositeStorer interface {
verifiable.NonceManager
oauth2.ResourceOwnerPasswordCredentialsGrantStorage
// Hydra-specific storage utilities
// flush the access token requests from the database.
// no data will be deleted after the 'notAfter' timeframe.
FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) error
@ -44,8 +45,9 @@ type FositeStorer interface {
// DeleteOpenIDConnectSession deletes an OpenID Connect session.
// This is duplicated from Ory Fosite to help against deprecation linting errors.
DeleteOpenIDConnectSession(ctx context.Context, authorizeCode string) error
// DeleteOpenIDConnectSession(ctx context.Context, authorizeCode string) error
// Hydra-specific RFC8628 Device Auth capabilities
GetUserCodeSession(context.Context, string, fosite.Session) (fosite.DeviceRequester, error)
GetDeviceCodeSessionByRequestID(ctx context.Context, requestID string, requester fosite.Session) (fosite.DeviceRequester, string, error)
UpdateDeviceCodeSessionBySignature(ctx context.Context, requestID string, requester fosite.DeviceRequester) error