test: resturcture and improve integration tests

GitOrigin-RevId: 83dfe53cfc33f0a974d7b2f7eeed81d017d2518c
This commit is contained in:
Patrik 2025-09-25 14:28:45 +02:00 committed by ory-bot
parent b49cba50e8
commit 2769a75d0c
20 changed files with 221 additions and 206 deletions

View File

@ -3,5 +3,4 @@
"github.com/ory/x","Apache-2.0"
"github.com/stretchr/testify","MIT"
"go.opentelemetry.io/otel/sdk","Apache-2.0"
"go.opentelemetry.io/otel/sdk","BSD-3-Clause"

1 module name licenses
3 github.com/stretchr/testify MIT
4 go.opentelemetry.io/otel/sdk Apache-2.0
5
6

View File

@ -5,7 +5,6 @@ package server_test
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"encoding/json"
@ -129,7 +128,7 @@ func TestGetOrCreateTLSCertificateBase64(t *testing.T) {
}
func TestCreateSelfSignedCertificate(t *testing.T) {
keys, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
keys, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
require.NoError(t, err)
private := keys.Keys[0]

View File

@ -8,11 +8,8 @@ import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"strings"
@ -24,11 +21,9 @@ import (
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"github.com/urfave/negroni"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"golang.org/x/oauth2"
"github.com/ory/fosite/token/jwt"
hydra "github.com/ory/hydra-client-go/v2"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/driver"
"github.com/ory/hydra/v2/driver/config"
@ -37,7 +32,6 @@ import (
"github.com/ory/x/httpx"
"github.com/ory/x/ioutilx"
"github.com/ory/x/prometheusx"
"github.com/ory/x/uuidx"
)
func NewIDToken(t *testing.T, reg *driver.RegistrySQL, subject string) string {
@ -230,39 +224,6 @@ func NewCallbackURL(t testing.TB, prefix string, h http.HandlerFunc) string {
return ts.URL + "/" + prefix
}
func NewEmptyCookieJar(t testing.TB) *cookiejar.Jar {
c, err := cookiejar.New(&cookiejar.Options{})
require.NoError(t, err)
return c
}
func NewEmptyJarClient(t testing.TB) *http.Client {
return &http.Client{
Jar: NewEmptyCookieJar(t),
Transport: &loggingTransport{t},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
//t.Logf("Redirect to %s", req.URL.String())
if len(via) >= 20 {
for k, v := range via {
t.Logf("Failed with redirect (%d): %s", k, v.URL.String())
}
return errors.New("stopped after 20 redirects")
}
return nil
},
}
}
type loggingTransport struct{ t testing.TB }
func (s *loggingTransport) RoundTrip(r *http.Request) (*http.Response, error) {
//s.t.Logf("%s %s", r.Method, r.URL.String())
//s.t.Logf("%s %s\nWith Cookies: %v", r.Method, r.URL.String(), r.Cookies())
return otelhttp.DefaultClient.Transport.RoundTrip(r)
}
// InsecureDecodeJWT decodes a JWT payload without checking the signature.
func InsecureDecodeJWT(t require.TestingT, token string) []byte {
parts := strings.Split(token, ".")
@ -272,66 +233,11 @@ func InsecureDecodeJWT(t require.TestingT, token string) []byte {
return dec
}
const (
ClientCallbackURL = "https://client.ory/callback"
LoginURL = "https://ui.ory/login"
ConsentURL = "https://ui.ory/consent"
var (
NewEmptyCookieJar = x.NewEmptyCookieJar
NewEmptyJarClient = x.NewEmptyJarClient
)
func GetExpectRedirect(t *testing.T, cl *http.Client, uri string) *url.URL {
resp, err := cl.Get(uri)
require.NoError(t, err)
require.Equalf(t, 3, resp.StatusCode/100, "status: %d\nresponse: %s", resp.StatusCode, ioutilx.MustReadAll(resp.Body))
loc, err := resp.Location()
require.NoError(t, err)
return loc
}
func PerformAuthCodeFlow(t *testing.T, cfg *oauth2.Config, admin *hydra.APIClient, lr func(*testing.T, *hydra.OAuth2LoginRequest) hydra.AcceptOAuth2LoginRequest, cr func(*testing.T, *hydra.OAuth2ConsentRequest) hydra.AcceptOAuth2ConsentRequest, authCodeOpts ...oauth2.AuthCodeOption) *oauth2.Token {
cl := NewEmptyJarClient(t)
cl.CheckRedirect = func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse }
// start the auth code flow
state := uuidx.NewV4().String()
loc := GetExpectRedirect(t, cl, cfg.AuthCodeURL(state, authCodeOpts...))
require.Equal(t, LoginURL, fmt.Sprintf("%s://%s%s", loc.Scheme, loc.Host, loc.Path))
// get & submit the login request
lReq, _, err := admin.OAuth2API.GetOAuth2LoginRequest(t.Context()).LoginChallenge(loc.Query().Get("login_challenge")).Execute()
require.NoError(t, err)
v, _, err := admin.OAuth2API.AcceptOAuth2LoginRequest(t.Context()).
LoginChallenge(lReq.Challenge).
AcceptOAuth2LoginRequest(lr(t, lReq)).
Execute()
require.NoError(t, err)
loc = GetExpectRedirect(t, cl, v.RedirectTo)
require.Equal(t, ConsentURL, fmt.Sprintf("%s://%s%s", loc.Scheme, loc.Host, loc.Path))
// get & submit the consent request
cReq, _, err := admin.OAuth2API.GetOAuth2ConsentRequest(t.Context()).ConsentChallenge(loc.Query().Get("consent_challenge")).Execute()
require.NoError(t, err)
v, _, err = admin.OAuth2API.AcceptOAuth2ConsentRequest(t.Context()).
ConsentChallenge(cReq.Challenge).
AcceptOAuth2ConsentRequest(cr(t, cReq)).
Execute()
require.NoError(t, err)
loc = GetExpectRedirect(t, cl, v.RedirectTo)
// ensure we got redirected to the client callback URL
require.Equal(t, ClientCallbackURL, fmt.Sprintf("%s://%s%s", loc.Scheme, loc.Host, loc.Path))
require.Equal(t, state, loc.Query().Get("state"))
// exchange the code for a token
code := loc.Query().Get("code")
require.NotEmpty(t, code)
token, err := cfg.Exchange(t.Context(), code)
require.NoError(t, err)
return token
}
func AssertTokenValid(t *testing.T, accessOrIDToken gjson.Result, sub string) {
assert.Equal(t, sub, accessOrIDToken.Get("sub").Str)
assert.WithinDurationf(t, time.Now(), time.Unix(accessOrIDToken.Get("iat").Int(), 0), time.Minute, "%s", accessOrIDToken.Raw)

View File

@ -4,7 +4,6 @@
package jwk
import (
"context"
"testing"
"github.com/go-jose/go-jose/v3"
@ -15,7 +14,7 @@ import (
func TestMustRSAPrivate(t *testing.T) {
t.Parallel()
keys, err := GenerateJWK(context.Background(), jose.RS256, "foo", "sig")
keys, err := GenerateJWK(jose.RS256, "foo", "sig")
require.NoError(t, err)
priv := keys.Key("foo")[0]

View File

@ -4,18 +4,16 @@
package jwk
import (
"context"
"crypto/x509"
"github.com/gofrs/uuid"
"github.com/go-jose/go-jose/v3"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/ory/x/josex"
)
func GenerateJWK(ctx context.Context, alg jose.SignatureAlgorithm, kid, use string) (*jose.JSONWebKeySet, error) {
func GenerateJWK(alg jose.SignatureAlgorithm, kid, use string) (*jose.JSONWebKeySet, error) {
bits := 0
if alg == jose.RS256 || alg == jose.RS384 || alg == jose.RS512 {
bits = 4096

View File

@ -4,7 +4,6 @@
package jwk
import (
"context"
"testing"
"github.com/go-jose/go-jose/v3"
@ -14,7 +13,7 @@ import (
func TestGenerateJWK(t *testing.T) {
t.Parallel()
jwks, err := GenerateJWK(context.Background(), jose.RS256, "", "")
jwks, err := GenerateJWK(jose.RS256, "", "")
require.NoError(t, err)
assert.NotEmpty(t, jwks.Keys[0].KeyID)
assert.EqualValues(t, jose.RS256, jwks.Keys[0].Algorithm)

View File

@ -41,7 +41,7 @@ func TestHandlerWellKnown(t *testing.T) {
if reg.Config().HSMEnabled() {
t.Skip("Skipping test. Not applicable when Hardware Security Module is enabled. Public/private keys on HSM are generated with equal key id's and are not using prefixes")
}
IDKS, _ := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig")
IDKS, _ := jwk.GenerateJWK(jose.RS256, "test-id-1", "sig")
require.NoError(t, reg.KeyManager().AddKeySet(context.TODO(), x.OpenIDConnectKeyName, IDKS))
res, err := http.Get(urlx.MustJoin(testServer.URL, JWKPath))
require.NoError(t, err, "problem in http request")
@ -72,7 +72,7 @@ func TestHandlerWellKnown(t *testing.T) {
require.NoError(t, err, "problem in generating keys")
} else {
var err error
IDKS, err = jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-2", "sig")
IDKS, err = jwk.GenerateJWK(jose.RS256, "test-id-2", "sig")
require.NoError(t, err, "problem in generating keys")
IDKS.Keys[0].KeyID = "test-id-2"
require.NoError(t, reg.KeyManager().AddKeySet(context.TODO(), x.OpenIDConnectKeyName, IDKS))

View File

@ -48,7 +48,7 @@ func TestHandlerFindPublicKey(t *testing.T) {
t.Run("Test_Helper/Run_FindPublicKey_With_RSA", func(t *testing.T) {
t.Parallel()
RSIDKS, err := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig")
RSIDKS, err := jwk.GenerateJWK(jose.RS256, "test-id-1", "sig")
require.NoError(t, err)
keys, err := jwk.FindPublicKey(RSIDKS)
require.NoError(t, err)
@ -58,7 +58,7 @@ func TestHandlerFindPublicKey(t *testing.T) {
t.Run("Test_Helper/Run_FindPublicKey_With_Opaque", func(t *testing.T) {
t.Parallel()
key, err := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig")
key, err := jwk.GenerateJWK(jose.RS256, "test-id-1", "sig")
RSIDKS := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{{
Algorithm: "RS256",
Use: "sig",
@ -85,7 +85,7 @@ func TestHandlerFindPublicKey(t *testing.T) {
t.Run("Test_Helper/Run_FindPublicKey_With_ECDSA", func(t *testing.T) {
t.Parallel()
ECDSAIDKS, err := jwk.GenerateJWK(context.Background(), jose.ES256, "test-id-2", "sig")
ECDSAIDKS, err := jwk.GenerateJWK(jose.ES256, "test-id-2", "sig")
require.NoError(t, err)
keys, err := jwk.FindPublicKey(ECDSAIDKS)
require.NoError(t, err)
@ -95,7 +95,7 @@ func TestHandlerFindPublicKey(t *testing.T) {
t.Run("Test_Helper/Run_FindPublicKey_With_EdDSA", func(t *testing.T) {
t.Parallel()
EdDSAIDKS, err := jwk.GenerateJWK(context.Background(), jose.EdDSA, "test-id-3", "sig")
EdDSAIDKS, err := jwk.GenerateJWK(jose.EdDSA, "test-id-3", "sig")
require.NoError(t, err)
keys, err := jwk.FindPublicKey(EdDSAIDKS)
require.NoError(t, err)
@ -115,7 +115,7 @@ func TestHandlerFindPublicKey(t *testing.T) {
func TestHandlerFindPrivateKey(t *testing.T) {
t.Parallel()
t.Run("Test_Helper/Run_FindPrivateKey_With_RSA", func(t *testing.T) {
RSIDKS, _ := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig")
RSIDKS, _ := jwk.GenerateJWK(jose.RS256, "test-id-1", "sig")
keys, err := jwk.FindPrivateKey(RSIDKS)
require.NoError(t, err)
assert.Equal(t, keys.KeyID, "test-id-1")
@ -123,7 +123,7 @@ func TestHandlerFindPrivateKey(t *testing.T) {
})
t.Run("Test_Helper/Run_FindPrivateKey_With_ECDSA", func(t *testing.T) {
ECDSAIDKS, err := jwk.GenerateJWK(context.Background(), jose.ES256, "test-id-2", "sig")
ECDSAIDKS, err := jwk.GenerateJWK(jose.ES256, "test-id-2", "sig")
require.NoError(t, err)
keys, err := jwk.FindPrivateKey(ECDSAIDKS)
require.NoError(t, err)
@ -132,7 +132,7 @@ func TestHandlerFindPrivateKey(t *testing.T) {
})
t.Run("Test_Helper/Run_FindPrivateKey_With_EdDSA", func(t *testing.T) {
EdDSAIDKS, err := jwk.GenerateJWK(context.Background(), jose.EdDSA, "test-id-3", "sig")
EdDSAIDKS, err := jwk.GenerateJWK(jose.EdDSA, "test-id-3", "sig")
require.NoError(t, err)
keys, err := jwk.FindPrivateKey(EdDSAIDKS)
require.NoError(t, err)
@ -151,7 +151,7 @@ func TestHandlerFindPrivateKey(t *testing.T) {
func TestPEMBlockForKey(t *testing.T) {
t.Parallel()
t.Run("Test_Helper/Run_PEMBlockForKey_With_RSA", func(t *testing.T) {
RSIDKS, err := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig")
RSIDKS, err := jwk.GenerateJWK(jose.RS256, "test-id-1", "sig")
require.NoError(t, err)
key, err := jwk.FindPrivateKey(RSIDKS)
require.NoError(t, err)
@ -162,7 +162,7 @@ func TestPEMBlockForKey(t *testing.T) {
})
t.Run("Test_Helper/Run_PEMBlockForKey_With_ECDSA", func(t *testing.T) {
ECDSAIDKS, err := jwk.GenerateJWK(context.Background(), jose.ES256, "test-id-2", "sig")
ECDSAIDKS, err := jwk.GenerateJWK(jose.ES256, "test-id-2", "sig")
require.NoError(t, err)
key, err := jwk.FindPrivateKey(ECDSAIDKS)
require.NoError(t, err)
@ -173,7 +173,7 @@ func TestPEMBlockForKey(t *testing.T) {
})
t.Run("Test_Helper/Run_PEMBlockForKey_With_EdDSA", func(t *testing.T) {
EdDSAIDKS, err := jwk.GenerateJWK(context.Background(), jose.EdDSA, "test-id-3", "sig")
EdDSAIDKS, err := jwk.GenerateJWK(jose.EdDSA, "test-id-3", "sig")
require.NoError(t, err)
key, err := jwk.FindPrivateKey(EdDSAIDKS)
require.NoError(t, err)
@ -193,7 +193,7 @@ func TestPEMBlockForKey(t *testing.T) {
func TestExcludeOpaquePrivateKeys(t *testing.T) {
t.Parallel()
opaqueKeys, err := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig")
opaqueKeys, err := jwk.GenerateJWK(jose.RS256, "test-id-1", "sig")
assert.NoError(t, err)
require.Len(t, opaqueKeys.Keys, 1)
opaqueKeys.Keys[0].Key = cryptosigner.Opaque(opaqueKeys.Keys[0].Key.(*rsa.PrivateKey))
@ -213,7 +213,7 @@ func TestGetOrGenerateKeys(t *testing.T) {
setID := uuid.NewUUID().String()
keyID := uuid.NewUUID().String()
keySet, err := jwk.GenerateJWK(t.Context(), jose.RS256, keyID, "sig")
keySet, err := jwk.GenerateJWK(jose.RS256, keyID, "sig")
require.NoError(t, err)
require.Len(t, keySet.Keys, 1)
keySetWithoutPrivateKey := &jose.JSONWebKeySet{
@ -272,7 +272,7 @@ func TestGetOrGenerateKeys(t *testing.T) {
}
func TestOnlyPublicSDKKeys(t *testing.T) {
set, err := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig")
set, err := jwk.GenerateJWK(jose.RS256, "test-id-1", "sig")
require.NoError(t, err)
out, err := json.Marshal(set.Keys)

View File

@ -200,7 +200,7 @@ func TestHelperManagerGenerateAndPersistKeySet(m Manager, alg string, parallel b
func TestHelperNID(t1ValidNID, t2InvalidNID Manager) func(t *testing.T) {
return func(t *testing.T) {
ctx := context.Background()
jwks, err := GenerateJWK(ctx, jose.RS256, "2022-03-11-ks-1-kid", "test")
jwks, err := GenerateJWK(jose.RS256, "2022-03-11-ks-1-kid", "test")
require.NoError(t, err)
require.Error(t, t2InvalidNID.AddKey(ctx, "2022-03-11-k-1", &jwks.Keys[0]))
require.NoError(t, t1ValidNID.AddKey(ctx, "2022-03-11-k-1", &jwks.Keys[0]))

View File

@ -36,8 +36,8 @@ func TestAuthCodeFlowE2E(t *testing.T) {
reg := testhelpers.NewRegistryMemory(t, driver.WithConfigOptions(configx.WithValues(map[string]any{
config.KeyAccessTokenStrategy: "opaque",
config.KeyRefreshTokenHook: "",
config.KeyLoginURL: testhelpers.LoginURL,
config.KeyConsentURL: testhelpers.ConsentURL,
config.KeyLoginURL: x.LoginURL,
config.KeyConsentURL: x.ConsentURL,
config.KeyAccessTokenLifespan: 10 * time.Minute, // allow to debug
config.KeyRefreshTokenLifespan: 20 * time.Minute, // allow to debug
config.KeyScopeStrategy: "exact",
@ -55,18 +55,18 @@ func TestAuthCodeFlowE2E(t *testing.T) {
t.Run("auth code flow", func(t *testing.T) {
t.Run("rejects invalid audience", func(t *testing.T) {
cl := testhelpers.NewEmptyJarClient(t)
cl := x.NewEmptyJarClient(t)
cl.CheckRedirect = func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse }
_, conf := newOAuth2Client(t, reg, testhelpers.ClientCallbackURL)
loc := testhelpers.GetExpectRedirect(t, cl, conf.AuthCodeURL(uuidx.NewV4().String(), oauth2.SetAuthURLParam("audience", "invalid-audience")))
require.Equal(t, testhelpers.ClientCallbackURL, fmt.Sprintf("%s://%s%s", loc.Scheme, loc.Host, loc.Path))
_, conf := newOAuth2Client(t, reg, x.ClientCallbackURL)
loc := x.GetExpectRedirect(t, cl, conf.AuthCodeURL(uuidx.NewV4().String(), oauth2.SetAuthURLParam("audience", "invalid-audience")))
require.Equal(t, x.ClientCallbackURL, fmt.Sprintf("%s://%s%s", loc.Scheme, loc.Host, loc.Path))
assert.Equal(t, "invalid_request", loc.Query().Get("error"))
assert.Contains(t, loc.Query().Get("error_description"), "Requested audience 'invalid-audience' has not been whitelisted by the OAuth 2.0 Client.")
})
for _, accessTokenStrategy := range []string{"opaque", "jwt"} {
t.Run("strategy="+accessTokenStrategy, func(t *testing.T) {
cl, conf := newOAuth2Client(t, reg, testhelpers.ClientCallbackURL, func(c *client.Client) {
cl, conf := newOAuth2Client(t, reg, x.ClientCallbackURL, func(c *client.Client) {
c.AccessTokenStrategy = accessTokenStrategy
c.Audience = []string{"audience-1", "audience-2"}
c.ID = "64f78bf1-f388-4eeb-9fee-e7207226c6be-" + accessTokenStrategy
@ -74,29 +74,27 @@ func TestAuthCodeFlowE2E(t *testing.T) {
sub := "c6a8ee1c-e0c4-404c-bba7-6a5b8702a2e9"
t.Run("access and id tokens with extra claims", func(t *testing.T) {
token := testhelpers.PerformAuthCodeFlow(t, conf, adminClient,
func(t *testing.T, req *hydra.OAuth2LoginRequest) hydra.AcceptOAuth2LoginRequest {
snapshotx.SnapshotT(t, req,
snapshotx.ExceptPaths("challenge", "client.created_at", "client.updated_at", "session_id", "request_url"),
snapshotx.WithName("login_request"))
return hydra.AcceptOAuth2LoginRequest{
Amr: []string{"amr1", "amr2"},
Acr: pointerx.Ptr("acr-value"),
Subject: sub,
}
},
func(t *testing.T, req *hydra.OAuth2ConsentRequest) hydra.AcceptOAuth2ConsentRequest {
snapshotx.SnapshotT(t, req,
snapshotx.ExceptPaths("challenge", "client.created_at", "client.updated_at", "consent_request_id", "login_challenge", "login_session_id", "request_url"),
snapshotx.WithName("consent_request"))
return hydra.AcceptOAuth2ConsentRequest{
GrantScope: []string{"openid"},
Session: &hydra.AcceptOAuth2ConsentRequestSession{
AccessToken: map[string]any{"key_access": "extra access token value"},
IdToken: map[string]any{"key_id": "extra id token value"},
},
}
})
token := x.PerformAuthCodeFlow(t.Context(), t, nil, conf, adminClient, func(t *testing.T, req *hydra.OAuth2LoginRequest) hydra.AcceptOAuth2LoginRequest {
snapshotx.SnapshotT(t, req,
snapshotx.ExceptPaths("challenge", "client.created_at", "client.updated_at", "session_id", "request_url"),
snapshotx.WithName("login_request"))
return hydra.AcceptOAuth2LoginRequest{
Amr: []string{"amr1", "amr2"},
Acr: pointerx.Ptr("acr-value"),
Subject: sub,
}
}, func(t *testing.T, req *hydra.OAuth2ConsentRequest) hydra.AcceptOAuth2ConsentRequest {
snapshotx.SnapshotT(t, req,
snapshotx.ExceptPaths("challenge", "client.created_at", "client.updated_at", "consent_request_id", "login_challenge", "login_session_id", "request_url"),
snapshotx.WithName("consent_request"))
return hydra.AcceptOAuth2ConsentRequest{
GrantScope: []string{"openid"},
Session: &hydra.AcceptOAuth2ConsentRequestSession{
AccessToken: map[string]any{"key_access": "extra access token value"},
IdToken: map[string]any{"key_id": "extra id token value"},
},
}
})
// check access token
introspected := testhelpers.IntrospectToken(t, token.AccessToken, adminTS)
@ -120,19 +118,17 @@ func TestAuthCodeFlowE2E(t *testing.T) {
})
t.Run("refreshed access and id tokens with extra claims", func(t *testing.T) {
token := testhelpers.PerformAuthCodeFlow(t, conf, adminClient,
func(*testing.T, *hydra.OAuth2LoginRequest) hydra.AcceptOAuth2LoginRequest {
return hydra.AcceptOAuth2LoginRequest{Subject: sub}
},
func(*testing.T, *hydra.OAuth2ConsentRequest) hydra.AcceptOAuth2ConsentRequest {
return hydra.AcceptOAuth2ConsentRequest{
GrantScope: []string{"openid", "offline"},
Session: &hydra.AcceptOAuth2ConsentRequestSession{
AccessToken: map[string]any{"key_access": "extra access token value"},
IdToken: map[string]any{"key_id": "extra id token value"},
},
}
})
token := x.PerformAuthCodeFlow(t.Context(), t, nil, conf, adminClient, func(*testing.T, *hydra.OAuth2LoginRequest) hydra.AcceptOAuth2LoginRequest {
return hydra.AcceptOAuth2LoginRequest{Subject: sub}
}, func(*testing.T, *hydra.OAuth2ConsentRequest) hydra.AcceptOAuth2ConsentRequest {
return hydra.AcceptOAuth2ConsentRequest{
GrantScope: []string{"openid", "offline"},
Session: &hydra.AcceptOAuth2ConsentRequestSession{
AccessToken: map[string]any{"key_access": "extra access token value"},
IdToken: map[string]any{"key_id": "extra id token value"},
},
}
})
token.Expiry = time.Now().Add(-time.Hour)
refreshed, err := conf.TokenSource(t.Context(), token).Token()
@ -169,19 +165,16 @@ func TestAuthCodeFlowE2E(t *testing.T) {
})
t.Run("audience is forwarded to access token", func(t *testing.T) {
token := testhelpers.PerformAuthCodeFlow(t, conf, adminClient,
func(t *testing.T, req *hydra.OAuth2LoginRequest) hydra.AcceptOAuth2LoginRequest {
assert.EqualValues(t, cl.Audience, req.RequestedAccessTokenAudience)
return hydra.AcceptOAuth2LoginRequest{Subject: sub}
},
func(t *testing.T, req *hydra.OAuth2ConsentRequest) hydra.AcceptOAuth2ConsentRequest {
assert.EqualValues(t, cl.Audience, req.RequestedAccessTokenAudience)
return hydra.AcceptOAuth2ConsentRequest{
GrantScope: []string{"openid"},
GrantAccessTokenAudience: req.RequestedAccessTokenAudience,
}
},
oauth2.SetAuthURLParam("audience", strings.Join(cl.Audience, " ")))
token := x.PerformAuthCodeFlow(t.Context(), t, nil, conf, adminClient, func(t *testing.T, req *hydra.OAuth2LoginRequest) hydra.AcceptOAuth2LoginRequest {
assert.EqualValues(t, cl.Audience, req.RequestedAccessTokenAudience)
return hydra.AcceptOAuth2LoginRequest{Subject: sub}
}, func(t *testing.T, req *hydra.OAuth2ConsentRequest) hydra.AcceptOAuth2ConsentRequest {
assert.EqualValues(t, cl.Audience, req.RequestedAccessTokenAudience)
return hydra.AcceptOAuth2ConsentRequest{
GrantScope: []string{"openid"},
GrantAccessTokenAudience: req.RequestedAccessTokenAudience,
}
}, oauth2.SetAuthURLParam("audience", strings.Join(cl.Audience, " ")))
expectedAud, err := json.Marshal(cl.Audience)
require.NoError(t, err)

View File

@ -1000,7 +1000,7 @@ func testFositeJWTBearerGrantStorage(x oauth2.InternalRegistry) func(t *testing.
grantStorage := x.OAuth2Storage().(rfc7523.RFC7523KeyStorage)
t.Run("case=associated key added with grant", func(t *testing.T) {
keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
require.NoError(t, err)
publicKey := keySet.Keys[0].Public()
@ -1045,14 +1045,14 @@ func testFositeJWTBearerGrantStorage(x oauth2.InternalRegistry) func(t *testing.
})
t.Run("case=only associated key returns", func(t *testing.T) {
keySetToNotReturn, err := jwk.GenerateJWK(context.Background(), jose.ES256, uuid.Must(uuid.NewV4()).String(), "sig")
keySetToNotReturn, err := jwk.GenerateJWK(jose.ES256, uuid.Must(uuid.NewV4()).String(), "sig")
require.NoError(t, err)
require.NoError(t, keyManager.AddKeySet(context.Background(), uuid.Must(uuid.NewV4()).String(), keySetToNotReturn), "adding a random key should not fail")
issuer := uuid.Must(uuid.NewV4()).String()
subject := "maria+" + uuid.Must(uuid.NewV4()).String() + "@example.com"
keySet1ToReturn, err := jwk.GenerateJWK(t.Context(), jose.ES256, uuid.Must(uuid.NewV4()).String(), "sig")
keySet1ToReturn, err := jwk.GenerateJWK(jose.ES256, uuid.Must(uuid.NewV4()).String(), "sig")
require.NoError(t, err)
require.NoError(t, grantManager.CreateGrant(t.Context(), trust.Grant{
ID: uuid.Must(uuid.NewV4()),
@ -1065,7 +1065,7 @@ func testFositeJWTBearerGrantStorage(x oauth2.InternalRegistry) func(t *testing.
ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(1, 0, 0),
}, keySet1ToReturn.Keys[0].Public()))
keySet2ToReturn, err := jwk.GenerateJWK(t.Context(), jose.ES256, uuid.Must(uuid.NewV4()).String(), "sig")
keySet2ToReturn, err := jwk.GenerateJWK(jose.ES256, uuid.Must(uuid.NewV4()).String(), "sig")
require.NoError(t, err)
require.NoError(t, grantManager.CreateGrant(ctx, trust.Grant{
ID: uuid.Must(uuid.NewV4()),
@ -1108,7 +1108,7 @@ func testFositeJWTBearerGrantStorage(x oauth2.InternalRegistry) func(t *testing.
})
t.Run("case=associated key is deleted, when granted is deleted", func(t *testing.T) {
keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
require.NoError(t, err)
publicKey := keySet.Keys[0].Public()
@ -1144,7 +1144,7 @@ func testFositeJWTBearerGrantStorage(x oauth2.InternalRegistry) func(t *testing.
})
t.Run("case=associated grant is deleted, when key is deleted", func(t *testing.T) {
keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
require.NoError(t, err)
publicKey := keySet.Keys[0].Public()
@ -1180,7 +1180,7 @@ func testFositeJWTBearerGrantStorage(x oauth2.InternalRegistry) func(t *testing.
})
t.Run("case=only returns the key when subject matches", func(t *testing.T) {
keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
require.NoError(t, err)
publicKey := keySet.Keys[0].Public()
@ -1221,7 +1221,7 @@ func testFositeJWTBearerGrantStorage(x oauth2.InternalRegistry) func(t *testing.
})
t.Run("case=returns the key when any subject is allowed", func(t *testing.T) {
keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
require.NoError(t, err)
publicKey := keySet.Keys[0].Public()
@ -1253,7 +1253,7 @@ func testFositeJWTBearerGrantStorage(x oauth2.InternalRegistry) func(t *testing.
})
t.Run("case=does not return expired values", func(t *testing.T) {
keySet, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
require.NoError(t, err)
publicKey := keySet.Keys[0].Public()

View File

@ -90,9 +90,9 @@ func BenchmarkAuthCode(b *testing.B) {
})),
driver.WithTracerWrapper(func(t *otelx.Tracer) *otelx.Tracer { return new(otelx.Tracer).WithOTLP(tracer) }),
)
oauth2Keys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OAuth2JWTKeyName, "sig")
oauth2Keys, err := jwk.GenerateJWK(jose.ES256, x.OAuth2JWTKeyName, "sig")
require.NoError(b, err)
oidcKeys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OpenIDConnectKeyName, "sig")
oidcKeys, err := jwk.GenerateJWK(jose.ES256, x.OpenIDConnectKeyName, "sig")
require.NoError(b, err)
_, _ = oauth2Keys, oidcKeys
require.NoError(b, reg.KeyManager().UpdateKeySet(ctx, x.OAuth2JWTKeyName, oauth2Keys))

View File

@ -135,7 +135,7 @@ func TestJWTBearer(t *testing.T) {
})
set, kid := uuid.Must(uuid.NewV4()).String(), uuid.Must(uuid.NewV4()).String()
keys, err := jwk.GenerateJWK(ctx, jose.RS256, kid, "sig")
keys, err := jwk.GenerateJWK(jose.RS256, kid, "sig")
require.NoError(t, err)
trustGrant := trust.Grant{
ID: uuid.Must(uuid.NewV4()),
@ -208,7 +208,7 @@ func TestJWTBearer(t *testing.T) {
})
t.Run("case=unable to exchange token with an invalid key", func(t *testing.T) {
keys, err := jwk.GenerateJWK(ctx, jose.RS256, kid, "sig")
keys, err := jwk.GenerateJWK(jose.RS256, kid, "sig")
require.NoError(t, err)
signer := jwk.NewDefaultJWTSigner(reg, set)
signer.GetPrivateKey = func(ctx context.Context) (interface{}, error) {

View File

@ -28,15 +28,15 @@ func TestHelperGrantManagerCreateGetDeleteGrant(t1 GrantManager, km jwk.Manager,
kid3 := uuid.Must(uuid.NewV4()).String()
set := uuid.Must(uuid.NewV4()).String()
key1, err := jwk.GenerateJWK(t.Context(), jose.RS256, kid1, "sig")
key1, err := jwk.GenerateJWK(jose.RS256, kid1, "sig")
require.NoError(t, err)
tokenServicePubKey1 := josex.ToPublicKey(&key1.Keys[0])
key2, err := jwk.GenerateJWK(t.Context(), jose.RS256, kid2, "sig")
key2, err := jwk.GenerateJWK(jose.RS256, kid2, "sig")
require.NoError(t, err)
tokenServicePubKey2 := josex.ToPublicKey(&key2.Keys[0])
key3, err := jwk.GenerateJWK(t.Context(), jose.RS256, kid3, "sig")
key3, err := jwk.GenerateJWK(jose.RS256, kid3, "sig")
require.NoError(t, err)
mikePubKey := josex.ToPublicKey(&key3.Keys[0])
@ -157,11 +157,11 @@ func TestHelperGrantManagerErrors(m GrantManager, km jwk.Manager) func(t *testin
t.Parallel()
key1, err := jwk.GenerateJWK(t.Context(), jose.RS256, kid1, "sig")
key1, err := jwk.GenerateJWK(jose.RS256, kid1, "sig")
require.NoError(t, err)
pubKey1 := josex.ToPublicKey(&key1.Keys[0])
key2, err := jwk.GenerateJWK(t.Context(), jose.RS256, kid2, "sig")
key2, err := jwk.GenerateJWK(jose.RS256, kid2, "sig")
require.NoError(t, err)
pubKey2 := josex.ToPublicKey(&key2.Keys[0])

View File

@ -146,8 +146,16 @@ var _ io.Writer = (*CallbackWriter)(nil)
func prepareCmd(cmd *cobra.Command, stdIn io.Reader, stdOut, stdErr io.Writer, args []string) {
cmd.SetIn(stdIn)
cmd.SetOut(io.MultiWriter(stdOut, debugStdout))
cmd.SetErr(io.MultiWriter(stdErr, debugStderr))
outs := []io.Writer{debugStdout}
if stdOut != nil {
outs = append(outs, stdOut)
}
cmd.SetOut(io.MultiWriter(outs...))
errs := []io.Writer{debugStderr}
if stdErr != nil {
errs = append(errs, stdErr)
}
cmd.SetErr(io.MultiWriter(errs...))
if args == nil {
args = []string{}

View File

@ -3,7 +3,10 @@
package contextx
import "context"
import (
"context"
"testing"
)
type ContextKey int
@ -13,6 +16,10 @@ const (
var RootContext = context.WithValue(context.Background(), ValidContextKey, true)
func TestRootContext(t *testing.T) context.Context {
return context.WithValue(t.Context(), ValidContextKey, true)
}
func IsRootContext(ctx context.Context) bool {
is, ok := ctx.Value(ValidContextKey).(bool)
return is && ok

View File

@ -38,7 +38,7 @@ func (p *JWKPersister) GenerateAndPersistKeySet(ctx context.Context, set, kid, a
kid = uuid.Must(uuid.NewV4()).String()
}
keys, err := jwk.GenerateJWK(ctx, jose.SignatureAlgorithm(alg), kid, use)
keys, err := jwk.GenerateJWK(jose.SignatureAlgorithm(alg), kid, use)
if err != nil {
return nil, errors.Wrapf(jwk.ErrUnsupportedKeyAlgorithm, "%s", err)
}

View File

@ -765,7 +765,7 @@ func (s *PersisterTestSuite) TestGenerateAndPersistKeySet() {
s.T().Run(k, func(t *testing.T) {
actual := &jwk.SQLData{}
key, err := jwk.GenerateJWK(s.t1, "RS256", "kid", "use")
key, err := jwk.GenerateJWK("RS256", "kid", "use")
require.NoError(t, err)
require.NoError(t, r.KeyManager().AddKey(s.t1, "ks", pointerx.Ptr(key.Keys[0].Public())))
@ -1908,7 +1908,7 @@ func newLogoutRequest() *flow.LogoutRequest {
}
func newKey(ksID string, use string) jose.JSONWebKey {
ks, err := jwk.GenerateJWK(context.Background(), jose.RS256, ksID, use)
ks, err := jwk.GenerateJWK(jose.RS256, ksID, use)
if err != nil {
panic(err)
}
@ -1916,7 +1916,7 @@ func newKey(ksID string, use string) jose.JSONWebKey {
}
func newKeySet(id string, use string) *jose.JSONWebKeySet {
return x.Must(jwk.GenerateJWK(context.Background(), jose.RS256, id, use))
return x.Must(jwk.GenerateJWK(jose.RS256, id, use))
}
func newLoginSession() *flow.LoginSession {

View File

@ -75,7 +75,7 @@ func testRegistry(t *testing.T, k string, t1, t2 *driver.RegistrySQL) {
} else {
kid, err := uuid.NewV4()
require.NoError(t, err)
ks, err := jwk.GenerateJWK(context.Background(), jose.SignatureAlgorithm(tc.alg), kid.String(), "sig")
ks, err := jwk.GenerateJWK(jose.SignatureAlgorithm(tc.alg), kid.String(), "sig")
require.NoError(t, err)
t.Run("TestManagerKey", jwk.TestHelperManagerKey(t1.KeyManager(), tc.alg, ks, kid.String()))
t.Run("Parallel", func(t *testing.T) {

107
x/oauth2_test_client.go Normal file
View File

@ -0,0 +1,107 @@
package x
import (
"context"
"fmt"
"net/http"
"net/http/cookiejar"
"net/url"
"testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
hydra "github.com/ory/hydra-client-go/v2"
"github.com/ory/x/ioutilx"
"github.com/ory/x/uuidx"
)
func NewEmptyCookieJar(t testing.TB) *cookiejar.Jar {
c, err := cookiejar.New(&cookiejar.Options{})
require.NoError(t, err)
return c
}
func NewEmptyJarClient(t testing.TB) *http.Client {
return &http.Client{
Jar: NewEmptyCookieJar(t),
CheckRedirect: func(req *http.Request, via []*http.Request) error {
//t.Logf("Redirect to %s", req.URL.String())
if len(via) >= 20 {
for k, v := range via {
t.Logf("Failed with redirect (%d): %s", k, v.URL.String())
}
return errors.New("stopped after 20 redirects")
}
return nil
},
}
}
func GetExpectRedirect(t *testing.T, cl *http.Client, uri string) *url.URL {
resp, err := cl.Get(uri)
require.NoError(t, err)
require.Equalf(t, 3, resp.StatusCode/100, "status: %d\nresponse: %s", resp.StatusCode, ioutilx.MustReadAll(resp.Body))
loc, err := resp.Location()
require.NoError(t, err)
return loc
}
const (
ClientCallbackURL = "https://client.ory/callback"
LoginURL = "https://ui.ory/login"
ConsentURL = "https://ui.ory/consent"
)
func PerformAuthCodeFlow(ctx context.Context, t *testing.T, baseClient *http.Client, cfg *oauth2.Config, admin *hydra.APIClient, lr func(*testing.T, *hydra.OAuth2LoginRequest) hydra.AcceptOAuth2LoginRequest, cr func(*testing.T, *hydra.OAuth2ConsentRequest) hydra.AcceptOAuth2ConsentRequest, authCodeOpts ...oauth2.AuthCodeOption) *oauth2.Token {
var cl http.Client
if baseClient != nil {
cl = *baseClient
}
if cl.Jar == nil {
cl.Jar = NewEmptyCookieJar(t)
}
cl.CheckRedirect = func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse }
// start the auth code flow
state := uuidx.NewV4().String()
loc := GetExpectRedirect(t, &cl, cfg.AuthCodeURL(state, authCodeOpts...))
require.Equal(t, LoginURL, fmt.Sprintf("%s://%s%s", loc.Scheme, loc.Host, loc.Path))
// get & submit the login request
lReq, _, err := admin.OAuth2API.GetOAuth2LoginRequest(ctx).LoginChallenge(loc.Query().Get("login_challenge")).Execute()
require.NoError(t, err)
v, _, err := admin.OAuth2API.AcceptOAuth2LoginRequest(ctx).
LoginChallenge(lReq.Challenge).
AcceptOAuth2LoginRequest(lr(t, lReq)).
Execute()
require.NoError(t, err)
loc = GetExpectRedirect(t, &cl, v.RedirectTo)
require.Equal(t, ConsentURL, fmt.Sprintf("%s://%s%s", loc.Scheme, loc.Host, loc.Path))
// get & submit the consent request
cReq, _, err := admin.OAuth2API.GetOAuth2ConsentRequest(ctx).ConsentChallenge(loc.Query().Get("consent_challenge")).Execute()
require.NoError(t, err)
v, _, err = admin.OAuth2API.AcceptOAuth2ConsentRequest(ctx).
ConsentChallenge(cReq.Challenge).
AcceptOAuth2ConsentRequest(cr(t, cReq)).
Execute()
require.NoError(t, err)
loc = GetExpectRedirect(t, &cl, v.RedirectTo)
// ensure we got redirected to the client callback URL
require.Equal(t, ClientCallbackURL, fmt.Sprintf("%s://%s%s", loc.Scheme, loc.Host, loc.Path))
require.Equal(t, state, loc.Query().Get("state"))
// exchange the code for a token
code := loc.Query().Get("code")
require.NotEmpty(t, code)
token, err := cfg.Exchange(ctx, code)
require.NoError(t, err)
return token
}