mirror of https://github.com/ory/hydra
306 lines
11 KiB
Go
306 lines
11 KiB
Go
// Copyright © 2022 Ory Corp
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package oauth2_test
|
|
|
|
import (
|
|
"context"
|
|
"flag"
|
|
"net/http"
|
|
"os"
|
|
"runtime"
|
|
"runtime/pprof"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-jose/go-jose/v3"
|
|
"github.com/pborman/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
|
"go.opentelemetry.io/otel"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
|
|
"go.opentelemetry.io/otel/propagation"
|
|
"go.opentelemetry.io/otel/sdk/resource"
|
|
"go.opentelemetry.io/otel/sdk/trace"
|
|
"go.opentelemetry.io/otel/sdk/trace/tracetest"
|
|
semconv "go.opentelemetry.io/otel/semconv/v1.12.0"
|
|
"golang.org/x/oauth2"
|
|
|
|
hydra "github.com/ory/hydra-client-go/v2"
|
|
hc "github.com/ory/hydra/v2/client"
|
|
"github.com/ory/hydra/v2/driver/config"
|
|
"github.com/ory/hydra/v2/internal"
|
|
"github.com/ory/hydra/v2/internal/testhelpers"
|
|
"github.com/ory/hydra/v2/jwk"
|
|
"github.com/ory/hydra/v2/x"
|
|
"github.com/ory/x/contextx"
|
|
"github.com/ory/x/pointerx"
|
|
"github.com/ory/x/stringsx"
|
|
)
|
|
|
|
var (
|
|
prof = flag.String("profile", "", "write a CPU profile to this filename")
|
|
conc = flag.Int("conc", 100, "dispatch this many requests concurrently")
|
|
tracing = flag.Bool("tracing", false, "send OpenTelemetry traces to localhost:4318")
|
|
)
|
|
|
|
func BenchmarkAuthCode(b *testing.B) {
|
|
flag.Parse()
|
|
|
|
ctx := context.Background()
|
|
|
|
spans := tracetest.NewSpanRecorder()
|
|
opts := []trace.TracerProviderOption{
|
|
trace.WithSpanProcessor(spans),
|
|
trace.WithResource(resource.NewWithAttributes(
|
|
semconv.SchemaURL, attribute.String(string(semconv.ServiceNameKey), "BenchmarkAuthCode"),
|
|
)),
|
|
}
|
|
if *tracing {
|
|
exporter, err := otlptracehttp.New(ctx, otlptracehttp.WithInsecure(), otlptracehttp.WithEndpoint("localhost:4318"))
|
|
require.NoError(b, err)
|
|
opts = append(opts, trace.WithSpanProcessor(trace.NewSimpleSpanProcessor(exporter)))
|
|
}
|
|
provider := trace.NewTracerProvider(opts...)
|
|
|
|
tracer := provider.Tracer("BenchmarkAuthCode")
|
|
otel.SetTextMapPropagator(propagation.TraceContext{})
|
|
otel.SetTracerProvider(provider)
|
|
|
|
ctx, span := tracer.Start(ctx, "BenchmarkAuthCode")
|
|
defer span.End()
|
|
|
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, otelhttp.DefaultClient)
|
|
|
|
dsn := stringsx.Coalesce(os.Getenv("DSN"), "postgres://postgres:secret@127.0.0.1:3445/postgres?sslmode=disable&max_conns=20&max_idle_conns=20")
|
|
// dsn := "mysql://root:secret@tcp(localhost:3444)/mysql?max_conns=16&max_idle_conns=16"
|
|
// dsn := "cockroach://root@localhost:3446/defaultdb?sslmode=disable&max_conns=16&max_idle_conns=16"
|
|
reg := internal.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer)
|
|
reg.Config().MustSet(ctx, config.KeyLogLevel, "error")
|
|
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
|
|
reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "")
|
|
oauth2Keys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OAuth2JWTKeyName, "sig")
|
|
require.NoError(b, err)
|
|
oidcKeys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OpenIDConnectKeyName, "sig")
|
|
require.NoError(b, err)
|
|
_, _ = oauth2Keys, oidcKeys
|
|
require.NoError(b, reg.KeyManager().UpdateKeySet(ctx, x.OAuth2JWTKeyName, oauth2Keys))
|
|
require.NoError(b, reg.KeyManager().UpdateKeySet(ctx, x.OpenIDConnectKeyName, oidcKeys))
|
|
_, adminTS := testhelpers.NewOAuth2Server(ctx, b, reg)
|
|
var (
|
|
authURL = reg.Config().OAuth2AuthURL(ctx).String()
|
|
tokenURL = reg.Config().OAuth2TokenURL(ctx).String()
|
|
nonce = uuid.New()
|
|
)
|
|
|
|
newOAuth2Client := func(b *testing.B, cb string) (*hc.Client, *oauth2.Config) {
|
|
secret := uuid.New()
|
|
c := &hc.Client{
|
|
Secret: secret,
|
|
RedirectURIs: []string{cb},
|
|
ResponseTypes: []string{"id_token", "code", "token"},
|
|
GrantTypes: []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"},
|
|
Scope: "hydra offline openid",
|
|
Audience: []string{"https://api.ory.sh/"},
|
|
}
|
|
require.NoError(b, reg.ClientManager().CreateClient(ctx, c))
|
|
return c, &oauth2.Config{
|
|
ClientID: c.GetID(),
|
|
ClientSecret: secret,
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: authURL,
|
|
TokenURL: tokenURL,
|
|
AuthStyle: oauth2.AuthStyleInHeader,
|
|
},
|
|
Scopes: strings.Split(c.Scope, " "),
|
|
}
|
|
}
|
|
|
|
cfg := hydra.NewConfiguration()
|
|
cfg.HTTPClient = otelhttp.DefaultClient
|
|
adminClient := hydra.NewAPIClient(cfg)
|
|
adminClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: adminTS.URL}}
|
|
|
|
getAuthorizeCode := func(ctx context.Context, b *testing.B, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (string, *http.Response) {
|
|
if c == nil {
|
|
c = testhelpers.NewEmptyJarClient(b)
|
|
}
|
|
|
|
state := uuid.New()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", conf.AuthCodeURL(state, params...), nil)
|
|
require.NoError(b, err)
|
|
resp, err := c.Do(req)
|
|
require.NoError(b, err)
|
|
defer resp.Body.Close()
|
|
|
|
q := resp.Request.URL.Query()
|
|
require.EqualValues(b, state, q.Get("state"))
|
|
return q.Get("code"), resp
|
|
}
|
|
|
|
acceptLoginHandler := func(b *testing.B, c *hc.Client, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc {
|
|
return otelhttp.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
rr, _, err := adminClient.OAuth2Api.GetOAuth2LoginRequest(ctx).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute()
|
|
require.NoError(b, err)
|
|
|
|
assert.EqualValues(b, c.GetID(), pointerx.Deref(rr.Client.ClientId))
|
|
assert.Empty(b, pointerx.Deref(rr.Client.ClientSecret))
|
|
assert.EqualValues(b, c.GrantTypes, rr.Client.GrantTypes)
|
|
assert.EqualValues(b, c.LogoURI, pointerx.Deref(rr.Client.LogoUri))
|
|
assert.EqualValues(b, c.RedirectURIs, rr.Client.RedirectUris)
|
|
assert.EqualValues(b, r.URL.Query().Get("login_challenge"), rr.Challenge)
|
|
assert.EqualValues(b, []string{"hydra", "offline", "openid"}, rr.RequestedScope)
|
|
assert.Contains(b, rr.RequestUrl, authURL)
|
|
|
|
acceptBody := hydra.AcceptOAuth2LoginRequest{
|
|
Subject: uuid.New(),
|
|
Remember: pointerx.Ptr(!rr.Skip),
|
|
Acr: pointerx.Ptr("1"),
|
|
Amr: []string{"pwd"},
|
|
Context: map[string]interface{}{"context": "bar"},
|
|
}
|
|
if checkRequestPayload != nil {
|
|
if b := checkRequestPayload(rr); b != nil {
|
|
acceptBody = *b
|
|
}
|
|
}
|
|
|
|
v, _, err := adminClient.OAuth2Api.AcceptOAuth2LoginRequest(ctx).
|
|
LoginChallenge(r.URL.Query().Get("login_challenge")).
|
|
AcceptOAuth2LoginRequest(acceptBody).
|
|
Execute()
|
|
require.NoError(b, err)
|
|
require.NotEmpty(b, v.RedirectTo)
|
|
http.Redirect(w, r, v.RedirectTo, http.StatusFound)
|
|
}), "acceptLoginHandler").ServeHTTP
|
|
}
|
|
|
|
acceptConsentHandler := func(b *testing.B, c *hc.Client, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc {
|
|
return otelhttp.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
rr, _, err := adminClient.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute()
|
|
require.NoError(b, err)
|
|
|
|
assert.EqualValues(b, c.GetID(), pointerx.Deref(rr.Client.ClientId))
|
|
assert.Empty(b, pointerx.Deref(rr.Client.ClientSecret))
|
|
assert.EqualValues(b, c.GrantTypes, rr.Client.GrantTypes)
|
|
assert.EqualValues(b, c.LogoURI, pointerx.Deref(rr.Client.LogoUri))
|
|
assert.EqualValues(b, c.RedirectURIs, rr.Client.RedirectUris)
|
|
// assert.EqualValues(b, subject, pointerx.Deref(rr.Subject))
|
|
assert.EqualValues(b, []string{"hydra", "offline", "openid"}, rr.RequestedScope)
|
|
assert.EqualValues(b, r.URL.Query().Get("consent_challenge"), rr.Challenge)
|
|
assert.Contains(b, *rr.RequestUrl, authURL)
|
|
if checkRequestPayload != nil {
|
|
checkRequestPayload(rr)
|
|
}
|
|
|
|
assert.Equal(b, map[string]interface{}{"context": "bar"}, rr.Context)
|
|
v, _, err := adminClient.OAuth2Api.AcceptOAuth2ConsentRequest(ctx).
|
|
ConsentChallenge(r.URL.Query().Get("consent_challenge")).
|
|
AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{
|
|
GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0),
|
|
GrantAccessTokenAudience: rr.RequestedAccessTokenAudience,
|
|
Session: &hydra.AcceptOAuth2ConsentRequestSession{
|
|
AccessToken: map[string]interface{}{"foo": "bar"},
|
|
IdToken: map[string]interface{}{"bar": "baz"},
|
|
},
|
|
}).
|
|
Execute()
|
|
require.NoError(b, err)
|
|
require.NotEmpty(b, v.RedirectTo)
|
|
http.Redirect(w, r, v.RedirectTo, http.StatusFound)
|
|
}), "acceptConsentHandler").ServeHTTP
|
|
}
|
|
|
|
run := func(b *testing.B, strategy string) func(*testing.B) {
|
|
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
|
|
c, conf := newOAuth2Client(b, testhelpers.NewCallbackURL(b, "callback", testhelpers.HTTPServerNotImplementedHandler))
|
|
testhelpers.NewLoginConsentUI(b, reg.Config(),
|
|
acceptLoginHandler(b, c, nil),
|
|
acceptConsentHandler(b, c, nil),
|
|
)
|
|
|
|
return func(b *testing.B) {
|
|
//pop.Debug = true
|
|
code, _ := getAuthorizeCode(ctx, b, conf, nil, oauth2.SetAuthURLParam("nonce", nonce))
|
|
require.NotEmpty(b, code)
|
|
|
|
_, err := conf.Exchange(ctx, code)
|
|
//pop.Debug = false
|
|
require.NoError(b, err)
|
|
}
|
|
}
|
|
|
|
b.ResetTimer()
|
|
|
|
b.SetParallelism(*conc / runtime.GOMAXPROCS(0))
|
|
|
|
b.Run("strategy=jwt", func(b *testing.B) {
|
|
initialDBSpans := dbSpans(spans)
|
|
B := run(b, "jwt")
|
|
|
|
stop := profile(b)
|
|
defer stop()
|
|
|
|
var totalMS int64 = 0
|
|
b.RunParallel(func(p *testing.PB) {
|
|
defer func(t0 time.Time) {
|
|
atomic.AddInt64(&totalMS, int64(time.Since(t0).Milliseconds()))
|
|
}(time.Now())
|
|
for p.Next() {
|
|
B(b)
|
|
}
|
|
})
|
|
|
|
b.ReportMetric(0, "ns/op")
|
|
b.ReportMetric(float64(atomic.LoadInt64(&totalMS))/float64(b.N), "ms/op")
|
|
b.ReportMetric((float64(dbSpans(spans)-initialDBSpans))/float64(b.N), "queries/op")
|
|
b.ReportMetric(float64(b.N)/b.Elapsed().Seconds(), "ops/s")
|
|
})
|
|
|
|
b.Run("strategy=opaque", func(b *testing.B) {
|
|
initialDBSpans := dbSpans(spans)
|
|
B := run(b, "opaque")
|
|
|
|
stop := profile(b)
|
|
defer stop()
|
|
|
|
var totalMS int64 = 0
|
|
b.RunParallel(func(p *testing.PB) {
|
|
defer func(t0 time.Time) {
|
|
atomic.AddInt64(&totalMS, int64(time.Since(t0).Milliseconds()))
|
|
}(time.Now())
|
|
for p.Next() {
|
|
B(b)
|
|
}
|
|
})
|
|
|
|
b.ReportMetric(0, "ns/op")
|
|
b.ReportMetric(float64(atomic.LoadInt64(&totalMS))/float64(b.N), "ms/op")
|
|
b.ReportMetric((float64(dbSpans(spans)-initialDBSpans))/float64(b.N), "queries/op")
|
|
b.ReportMetric(float64(b.N)/b.Elapsed().Seconds(), "ops/s")
|
|
})
|
|
|
|
}
|
|
|
|
func profile(t testing.TB) (stop func()) {
|
|
t.Helper()
|
|
if *prof == "" {
|
|
return func() {} // noop
|
|
}
|
|
f, err := os.Create(*prof)
|
|
require.NoError(t, err)
|
|
require.NoError(t, pprof.StartCPUProfile(f))
|
|
return func() {
|
|
pprof.StopCPUProfile()
|
|
require.NoError(t, f.Close())
|
|
t.Log("Wrote profile to", f.Name())
|
|
}
|
|
}
|