hydra/oauth2/revocator_test.go

142 lines
4.5 KiB
Go

// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package oauth2_test
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
hydra "github.com/ory/hydra-client-go/v2"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/fosite"
"github.com/ory/hydra/v2/internal"
"github.com/ory/hydra/v2/internal/testhelpers"
"github.com/ory/hydra/v2/oauth2"
"github.com/ory/hydra/v2/persistence/sql"
"github.com/ory/hydra/v2/x"
"github.com/ory/pop/v6"
"github.com/ory/x/httprouterx"
"github.com/ory/x/prometheusx"
)
// createAccessTokenSession persists an access-token session for subject and
// client under the given token signature. It is the non-pairwise variant of
// createAccessTokenSessionPairwise: the session's subject claim is left as
// subject. A nil scopes falls back to the helper's default granted scope.
func createAccessTokenSession(t testing.TB, subject, client, token string, expiresAt time.Time, fs x.FositeStorer, scopes fosite.Arguments) {
	// An empty obfuscated subject tells the pairwise helper not to override the claim.
	const noObfuscatedSubject = ""
	createAccessTokenSessionPairwise(t, subject, client, token, expiresAt, fs, scopes, noObfuscatedSubject)
}
// createAccessTokenSessionPairwise builds an access request for subject and
// client and stores it in fs under the token signature `token`.
//
// The request is built as follows:
//   - Granted scopes default to {"core"} unless scopes is non-nil.
//   - RequestedAt is the current UTC time rounded to the minute.
//   - The access token expires at expiresAt.
//   - The session carries Extra {"foo": "bar"}.
//   - When obfuscated is non-empty, it replaces the session's subject claim,
//     which mimics a pairwise subject identifier.
//
// A storage error fails the calling test via t instead of panicking.
func createAccessTokenSessionPairwise(t testing.TB, subject, client, token string, expiresAt time.Time, fs x.FositeStorer, scopes fosite.Arguments, obfuscated string) {
	t.Helper()

	ar := fosite.NewAccessRequest(oauth2.NewTestSession(t, subject))
	ar.GrantedScope = fosite.Arguments{"core"}
	if scopes != nil {
		ar.GrantedScope = scopes
	}
	ar.RequestedAt = time.Now().UTC().Round(time.Minute)
	ar.Client = &fosite.DefaultClient{ID: client}
	ar.Session.SetExpiresAt(fosite.AccessToken, expiresAt)

	// Assert the concrete session type once and reuse it for the Hydra-specific fields.
	session := ar.Session.(*oauth2.Session)
	session.Extra = map[string]interface{}{"foo": "bar"}
	if obfuscated != "" {
		session.Claims.Subject = obfuscated
	}

	// Report storage failures through the test instead of panicking, so the
	// failure is attributed to the caller and other tests keep running.
	require.NoError(t, fs.CreateAccessTokenSession(context.Background(), token, ar))
}
// countAccessTokens returns how many rows the "access" OAuth2 request table
// holds on connection c. A failing count query aborts the test immediately.
func countAccessTokens(t *testing.T, c *pop.Connection) int {
	total, countErr := c.Count(&sql.OAuth2RequestSQL{Table: "access"})
	require.NoError(t, countErr)
	return total
}
// TestRevoke exercises the token revocation endpoint end to end.
//
// Setup:
//   - Seed four access tokens in an in-memory registry.
//   - Serve the OAuth2 handler's public and admin routes from an httptest server.
//
// Each case revokes one token through the generated Hydra client,
// authenticated as "my-client", and checks how many access-token rows remain.
// The cases run sequentially and rely on the side effects of earlier cases.
func TestRevoke(t *testing.T) {
	t.Parallel()

	reg := testhelpers.NewRegistryMemory(t)
	testhelpers.MustEnsureRegistryKeys(t, reg, x.OpenIDConnectKeyName)
	internal.AddFositeExamples(t, reg)

	tokens := Tokens(reg.OAuth2ProviderConfig(), 4)
	now := time.Now().UTC().Round(time.Second)

	metrics := prometheusx.NewMetricsManagerWithPrefix("hydra", prometheusx.HTTPMetrics, config.Version, config.Commit, config.Date)
	handler := oauth2.NewHandler(reg)
	router := httprouterx.NewRouterAdminWithPrefix(metrics)
	handler.SetPublicRoutes(router.ToPublic(), func(h http.Handler) http.Handler { return h })
	handler.SetAdminRoutes(router)
	server := httptest.NewServer(router)
	defer server.Close()

	// tokens[0] and tokens[1] are still valid, tokens[2] and tokens[3] are
	// already expired, and tokens[3] belongs to a different client
	// ("encoded:client") than the one performing the revocations.
	createAccessTokenSession(t, "alice", "my-client", tokens[0].sig, now.Add(time.Hour), reg.OAuth2Storage(), nil)
	createAccessTokenSession(t, "siri", "my-client", tokens[1].sig, now.Add(time.Hour), reg.OAuth2Storage(), nil)
	createAccessTokenSession(t, "siri", "my-client", tokens[2].sig, now.Add(-time.Hour), reg.OAuth2Storage(), nil)
	createAccessTokenSession(t, "siri", "encoded:client", tokens[3].sig, now.Add(-time.Hour), reg.OAuth2Storage(), nil)

	countTokens := func(t *testing.T) int {
		return countAccessTokens(t, reg.Persister().Connection(context.Background()))
	}
	require.Equal(t, 4, countTokens(t))

	client := hydra.NewAPIClient(hydra.NewConfiguration())
	client.GetConfig().Servers = hydra.ServerConfigurations{{URL: server.URL}}

	for k, c := range []struct {
		token  string
		assert func(*testing.T)
	}{
		{
			// Unknown tokens are accepted but must not remove anything.
			token: "invalid",
			assert: func(t *testing.T) {
				assert.Equal(t, 4, countTokens(t))
			},
		},
		{
			// A token owned by another client must not be revoked by "my-client".
			token: tokens[3].tok,
			assert: func(t *testing.T) {
				assert.Equal(t, 4, countTokens(t))
			},
		},
		{
			token: tokens[0].tok,
			assert: func(t *testing.T) {
				t.Logf("Tried to delete: %s %s", tokens[0].sig, tokens[0].tok)
				assert.Equal(t, 3, countTokens(t))
			},
		},
		{
			// Revoking an already-revoked token succeeds and is a no-op.
			token: tokens[0].tok,
			assert: func(t *testing.T) {
				assert.Equal(t, 3, countTokens(t))
			},
		},
		{
			token: tokens[2].tok,
			assert: func(t *testing.T) {
				assert.Equal(t, 2, countTokens(t))
			},
		},
		{
			token: tokens[1].tok,
			assert: func(t *testing.T) {
				assert.Equal(t, 1, countTokens(t))
			},
		},
	} {
		t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
			ctx := context.WithValue(
				context.Background(),
				hydra.ContextBasicAuth,
				hydra.BasicAuth{UserName: "my-client", Password: "foobar"},
			)
			resp, err := client.OAuth2API.RevokeOAuth2Token(ctx).Token(c.token).Execute()

			// On transport errors resp may be nil; only read (and close) the body
			// when a response exists. The body only enriches the failure message,
			// so a read error is deliberately ignored.
			var body []byte
			if resp != nil {
				defer resp.Body.Close()
				body, _ = io.ReadAll(resp.Body)
			}
			require.NoErrorf(t, err, "body: %s", body)

			if c.assert != nil {
				c.assert(t)
			}
		})
	}
}