hydra/client/manager_test_helpers.go

403 lines
14 KiB
Go

// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package client
import (
"context"
"crypto/x509"
"fmt"
"testing"
"time"
"github.com/go-faker/faker/v4"
"github.com/go-jose/go-jose/v3"
"github.com/gofrs/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/x/uuidx"
"github.com/ory/hydra/v2/fosite"
testhelpersuuid "github.com/ory/hydra/v2/internal/testhelpers/uuid"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/assertx"
"github.com/ory/x/contextx"
keysetpagination "github.com/ory/x/pagination/keysetpagination_v2"
"github.com/ory/x/sqlcon"
)
func TestHelperClientAutoGenerateKey(m Storage) func(t *testing.T) {
return func(t *testing.T) {
ctx := context.TODO()
c := &Client{
Secret: "secret",
RedirectURIs: []string{"http://redirect"},
TermsOfServiceURI: "foo",
}
require.NoError(t, m.CreateClient(ctx, c))
dbClient, err := m.GetClient(ctx, c.GetID())
require.NoError(t, err)
dbClientConcrete, ok := dbClient.(*Client)
require.True(t, ok)
testhelpersuuid.AssertUUID(t, dbClientConcrete.ID)
assert.NoError(t, m.DeleteClient(ctx, c.GetID()))
}
}
func TestHelperClientAuthenticate(m Manager) func(t *testing.T) {
return func(t *testing.T) {
ctx := context.TODO()
require.NoError(t, m.CreateClient(ctx, &Client{
ID: "1234321",
Secret: "secret",
RedirectURIs: []string{"http://redirect"},
}))
c, err := m.AuthenticateClient(ctx, "1234321", []byte("secret1"))
require.Error(t, err)
require.Nil(t, c)
c, err = m.AuthenticateClient(ctx, "1234321", []byte("secret"))
require.NoError(t, err)
assert.Equal(t, "1234321", c.GetID())
}
}
func TestHelperUpdateTwoClients(m Manager) func(t *testing.T) {
return func(t *testing.T) {
c1, c2 := &Client{Name: "test client 1"}, &Client{Name: "test client 2"}
require.NoError(t, m.CreateClient(context.Background(), c1))
require.NoError(t, m.CreateClient(context.Background(), c2))
c1.Name, c2.Name = "updated client 1", "updated client 2"
assert.NoError(t, m.UpdateClient(context.Background(), c1))
assert.NoError(t, m.UpdateClient(context.Background(), c2))
}
}
func testHelperUpdateClient(t *testing.T, ctx context.Context, store Storage, toUpdate *Client) {
require.NoError(t, store.UpdateClient(ctx, &Client{
ID: toUpdate.ID,
Name: "name-new",
Secret: "secret-new",
RedirectURIs: []string{"http://redirect/new"},
TermsOfServiceURI: "bar",
JSONWebKeys: new(x.JoseJSONWebKeySet),
}))
nc, err := store.GetConcreteClient(ctx, toUpdate.ID)
require.NoError(t, err)
require.NotZero(t, toUpdate.GetHashedSecret())
require.NotZero(t, nc.GetHashedSecret())
assert.NotEqual(t, toUpdate.GetHashedSecret(), nc.GetHashedSecret(), "ensure that the secret hash was updated")
assert.Equal(t, "bar", nc.TermsOfServiceURI)
assert.Equal(t, "name-new", nc.Name)
assert.Equal(t, []string{"http://redirect/new"}, nc.GetRedirectURIs())
assert.Len(t, nc.Contacts, 0)
}
func TestHelperCreateGetUpdateDeleteClientNext(t *testing.T, m Storage, networks []uuid.UUID) {
ctx := context.Background()
resources := map[uuid.UUID][]Client{}
for k := range networks {
nid := networks[k]
resources[nid] = []Client{}
ctx := contextx.SetNIDContext(ctx, nid)
t.Run(fmt.Sprintf("nid=%s", nid), func(t *testing.T) {
var client Client
require.NoError(t, faker.FakeData(&client))
client.CreatedAt = time.Now().Truncate(time.Second).UTC()
t.Run("lifecycle=does not exist", func(t *testing.T) {
_, err := m.GetClient(ctx, uuidx.NewV4().String())
require.Error(t, err)
})
t.Run("lifecycle=exists", func(t *testing.T) {
require.NoError(t, m.CreateClient(ctx, &client))
c, err := m.GetClient(ctx, client.GetID())
require.NoError(t, err)
assertx.EqualAsJSONExcept(t, &client, c, []string{
"registration_access_token",
"registration_client_uri",
"updated_at",
})
require.ErrorIs(t, m.CreateClient(ctx, &client), sqlcon.ErrUniqueViolation)
})
t.Run("lifecycle=update", func(t *testing.T) {
client.Name = "updated" + nid.String()
require.NoError(t, m.UpdateClient(ctx, &client))
c, err := m.GetClient(ctx, client.GetID())
require.NoError(t, err)
assertx.EqualAsJSONExcept(t, &client, c, []string{
"registration_access_token",
"registration_client_uri",
"updated_at",
})
resources[nid] = append(resources[nid], client)
})
})
}
for k := range resources {
original := k
clients := resources[k]
for i := range networks {
check := networks[i]
t.Run("network="+original.String(), func(t *testing.T) {
ctx := contextx.SetNIDContext(ctx, check)
for _, expected := range clients {
c, err := m.GetClient(ctx, expected.GetID())
if check != original {
t.Run(fmt.Sprintf("case=must not find client %s", expected.GetID()), func(t *testing.T) {
require.ErrorIs(t, err, sqlcon.ErrNoRows)
})
} else {
t.Run("case=updates must not override each other", func(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, "updated"+original.String(), c.(*Client).Name)
})
}
}
})
}
}
for k := range resources {
clients := resources[k]
ctx := contextx.SetNIDContext(ctx, k)
t.Run("network="+k.String(), func(t *testing.T) {
for _, client := range clients {
t.Run("lifecycle=cleanup", func(t *testing.T) {
assert.NoError(t, m.DeleteClient(ctx, client.GetID()))
_, err := m.GetClient(ctx, client.GetID())
assert.ErrorIs(t, err, sqlcon.ErrNoRows)
assert.ErrorIs(t, m.DeleteClient(ctx, client.GetID()), sqlcon.ErrNoRows)
})
}
})
}
}
func TestHelperCreateGetUpdateDeleteClient(t1, t2 Storage) func(t *testing.T) {
return func(t *testing.T) {
ctx := context.Background()
_, err := t1.GetClient(ctx, "1234")
require.Error(t, err)
t1c1 := &Client{
ID: uuidx.NewV4().String(),
Name: "name",
Secret: "secret",
RedirectURIs: []string{"http://redirect", "http://redirect1"},
GrantTypes: []string{"implicit", "refresh_token"},
ResponseTypes: []string{"code token", "token id_token", "code"},
Scope: "scope-a scope-b",
Owner: "aeneas",
PolicyURI: "http://policy",
TermsOfServiceURI: "http://tos",
ClientURI: "http://client",
LogoURI: "http://logo",
Contacts: []string{"aeneas1", "aeneas2"},
SecretExpiresAt: 0,
SectorIdentifierURI: "https://sector",
JSONWebKeys: &x.JoseJSONWebKeySet{JSONWebKeySet: &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{{KeyID: "foo", Key: []byte("asdf"), Certificates: []*x509.Certificate{}, CertificateThumbprintSHA1: []uint8{}, CertificateThumbprintSHA256: []uint8{}}}}},
JSONWebKeysURI: "https://...",
TokenEndpointAuthMethod: "none",
TokenEndpointAuthSigningAlgorithm: "RS256",
RequestURIs: []string{"foo", "bar"},
AllowedCORSOrigins: []string{"foo", "bar"},
RequestObjectSigningAlgorithm: "rs256",
UserinfoSignedResponseAlg: "RS256",
CreatedAt: time.Now().Add(-time.Hour).Round(time.Second).UTC(),
UpdatedAt: time.Now().Add(-time.Minute).Round(time.Second).UTC(),
FrontChannelLogoutURI: "http://fc-logout",
FrontChannelLogoutSessionRequired: true,
PostLogoutRedirectURIs: []string{"hello", "mister"},
BackChannelLogoutURI: "http://bc-logout",
BackChannelLogoutSessionRequired: true,
}
require.NoError(t, t1.CreateClient(ctx, t1c1))
assert.NotEmpty(t, t1c1.GetHashedSecret())
t2c1 := *t1c1
require.NoError(t, t2.CreateClient(ctx, &t2c1), "we should be able to create a client with the same ID in other network")
t2c3 := *t1c1
{
t2c3.ID = uuidx.NewV4().String()
require.NoError(t, t2.CreateClient(ctx, &t2c3))
require.ErrorIs(t, t2.CreateClient(ctx, &t2c3), sqlcon.ErrUniqueViolation)
}
c2Template := &Client{
ID: uuidx.NewV4().String(),
Name: "name2",
Secret: "secret",
RedirectURIs: []string{"http://redirect"},
TermsOfServiceURI: "foo",
SecretExpiresAt: 1,
}
assert.NoError(t, t1.CreateClient(ctx, c2Template))
assert.NoError(t, t2.CreateClient(ctx, c2Template))
d, err := t1.GetClient(ctx, t1c1.ID)
require.NoError(t, err)
cc := d.(*Client)
testhelpersuuid.AssertUUID(t, cc.NID)
compare(t, t1c1, d)
t.Run("list all", func(t *testing.T) {
cs, nextPage, err := t1.GetClients(ctx, Filter{})
require.NoError(t, err)
require.Len(t, cs, 2)
assert.ElementsMatch(t, []string{t1c1.ID, c2Template.ID}, []string{cs[0].GetID(), cs[1].GetID()})
// test if SecretExpiresAt was set properly
assert.ElementsMatch(t, []int{0, 1}, []int{cs[0].SecretExpiresAt, cs[1].SecretExpiresAt})
assert.True(t, nextPage.IsLast())
})
t.Run("pagination", func(t *testing.T) {
observedIDs := make([]string, 2)
cs, nextPage, err := t1.GetClients(ctx, Filter{PageOpts: []keysetpagination.Option{keysetpagination.WithSize(1)}})
require.NoError(t, err)
require.Len(t, cs, 1)
assert.False(t, nextPage.IsLast())
observedIDs[0] = cs[0].ID
cs, nextPage, err = t1.GetClients(ctx, Filter{PageOpts: nextPage.ToOptions()})
require.NoError(t, err)
require.Len(t, cs, 1)
assert.True(t, nextPage.IsLast())
observedIDs[1] = cs[0].ID
assert.ElementsMatch(t, []string{t1c1.ID, c2Template.ID}, observedIDs)
})
t.Run("list by name", func(t *testing.T) {
cs, nextPage, err := t1.GetClients(ctx, Filter{Name: "name"})
require.NoError(t, err)
require.Len(t, cs, 1)
assert.Equal(t, cs[0].Name, "name")
assert.True(t, nextPage.IsLast())
})
t.Run("list by unknown name", func(t *testing.T) {
cs, nextPage, err := t1.GetClients(ctx, Filter{Name: "bad name"})
require.NoError(t, err)
assert.Len(t, cs, 0)
assert.True(t, nextPage.IsLast())
})
t.Run("list by owner", func(t *testing.T) {
cs, nextPage, err := t1.GetClients(ctx, Filter{Owner: "aeneas"})
require.NoError(t, err)
require.Len(t, cs, 1)
assert.Equal(t, cs[0].Owner, "aeneas")
assert.True(t, nextPage.IsLast())
})
t.Run("list by ids", func(t *testing.T) {
cs, nextPage, err := t1.GetClients(ctx, Filter{IDs: []string{t1c1.ID, c2Template.ID}})
require.NoError(t, err)
require.Len(t, cs, 2)
assert.ElementsMatch(t, []string{t1c1.ID, c2Template.ID}, []string{cs[0].GetID(), cs[1].GetID()})
assert.True(t, nextPage.IsLast())
cs, nextPage, err = t1.GetClients(ctx, Filter{IDs: []string{t1c1.ID}})
require.NoError(t, err)
require.Len(t, cs, 1)
assert.Equal(t, t1c1.ID, cs[0].GetID())
assert.True(t, nextPage.IsLast())
cs, nextPage, err = t1.GetClients(ctx, Filter{IDs: []string{c2Template.ID}})
require.NoError(t, err)
require.Len(t, cs, 1)
assert.Equal(t, c2Template.ID, cs[0].GetID())
assert.True(t, nextPage.IsLast())
cs, nextPage, err = t1.GetClients(ctx, Filter{IDs: []string{uuidx.NewV4().String()}})
require.NoError(t, err)
require.Len(t, cs, 0)
assert.True(t, nextPage.IsLast())
})
testHelperUpdateClient(t, ctx, t1, t1c1)
testHelperUpdateClient(t, ctx, t2, &t2c1)
assert.ErrorIs(t, t1.DeleteClient(ctx, t2c3.ID), sqlcon.ErrNoRows)
assert.NoError(t, t1.DeleteClient(ctx, t1c1.ID))
_, err = t1.GetClient(ctx, t1c1.ID)
assert.ErrorIs(t, err, sqlcon.ErrNoRows)
}
}
func compare(t *testing.T, expected *Client, actual fosite.Client) {
assert.EqualValues(t, expected.GetID(), actual.GetID())
assert.EqualValues(t, expected.GetHashedSecret(), actual.GetHashedSecret())
assert.EqualValues(t, expected.GetRedirectURIs(), actual.GetRedirectURIs())
assert.EqualValues(t, expected.GetGrantTypes(), actual.GetGrantTypes())
assert.EqualValues(t, expected.GetResponseTypes(), actual.GetResponseTypes())
assert.EqualValues(t, expected.GetScopes(), actual.GetScopes())
assert.EqualValues(t, expected.IsPublic(), actual.IsPublic())
if actual, ok := actual.(*Client); ok {
assert.EqualValues(t, expected.Owner, actual.Owner)
assert.EqualValues(t, expected.Name, actual.Name)
assert.EqualValues(t, expected.PolicyURI, actual.PolicyURI)
assert.EqualValues(t, expected.TermsOfServiceURI, actual.TermsOfServiceURI)
assert.EqualValues(t, expected.ClientURI, actual.ClientURI)
assert.EqualValues(t, expected.LogoURI, actual.LogoURI)
assert.EqualValues(t, expected.Contacts, actual.Contacts)
assert.EqualValues(t, expected.SecretExpiresAt, actual.SecretExpiresAt)
assert.EqualValues(t, expected.SectorIdentifierURI, actual.SectorIdentifierURI)
assert.EqualValues(t, expected.UserinfoSignedResponseAlg, actual.UserinfoSignedResponseAlg)
assert.EqualValues(t, expected.CreatedAt.UTC().Unix(), actual.CreatedAt.UTC().Unix())
// these values are not the same because of https://github.com/ory/pop/issues/591
//assert.EqualValues(t, expected.UpdatedAt.UTC().Unix(), actual.UpdatedAt.UTC().Unix(), "%s\n%s", expected.UpdatedAt.String(), actual.UpdatedAt.String())
assert.EqualValues(t, expected.FrontChannelLogoutURI, actual.FrontChannelLogoutURI)
assert.EqualValues(t, expected.FrontChannelLogoutSessionRequired, actual.FrontChannelLogoutSessionRequired)
assert.EqualValues(t, expected.PostLogoutRedirectURIs, actual.PostLogoutRedirectURIs)
assert.EqualValues(t, expected.BackChannelLogoutURI, actual.BackChannelLogoutURI)
assert.EqualValues(t, expected.BackChannelLogoutSessionRequired, actual.BackChannelLogoutSessionRequired)
}
if actual, ok := actual.(fosite.OpenIDConnectClient); ok {
require.NotNil(t, expected.JSONWebKeys)
for k, v := range expected.JSONWebKeys.JSONWebKeySet.Keys {
if v.CertificateThumbprintSHA1 == nil {
v.CertificateThumbprintSHA1 = make([]byte, 0)
}
if v.CertificateThumbprintSHA256 == nil {
v.CertificateThumbprintSHA256 = make([]byte, 0)
}
expected.JSONWebKeys.JSONWebKeySet.Keys[k] = v
}
assert.EqualValues(t, expected.JSONWebKeys.JSONWebKeySet, actual.GetJSONWebKeys())
assert.EqualValues(t, expected.JSONWebKeysURI, actual.GetJSONWebKeysURI())
assert.EqualValues(t, expected.TokenEndpointAuthMethod, actual.GetTokenEndpointAuthMethod())
assert.EqualValues(t, expected.RequestURIs, actual.GetRequestURIs())
assert.EqualValues(t, expected.RequestObjectSigningAlgorithm, actual.GetRequestObjectSigningAlgorithm())
}
}