mirror of https://github.com/ory/hydra
382 lines
13 KiB
Go
382 lines
13 KiB
Go
// Copyright © 2022 Ory Corp
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package client
|
|
|
|
import (
|
|
"context"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-faker/faker/v4"
|
|
"github.com/go-jose/go-jose/v3"
|
|
"github.com/gobuffalo/pop/v6"
|
|
"github.com/gofrs/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/ory/fosite"
|
|
testhelpersuuid "github.com/ory/hydra/v2/internal/testhelpers/uuid"
|
|
"github.com/ory/hydra/v2/x"
|
|
"github.com/ory/x/assertx"
|
|
"github.com/ory/x/contextx"
|
|
"github.com/ory/x/sqlcon"
|
|
)
|
|
|
|
func TestHelperClientAutoGenerateKey(k string, m Storage) func(t *testing.T) {
|
|
return func(t *testing.T) {
|
|
ctx := context.TODO()
|
|
c := &Client{
|
|
Secret: "secret",
|
|
RedirectURIs: []string{"http://redirect"},
|
|
TermsOfServiceURI: "foo",
|
|
}
|
|
require.NoError(t, m.CreateClient(ctx, c))
|
|
dbClient, err := m.GetClient(ctx, c.GetID())
|
|
require.NoError(t, err)
|
|
dbClientConcrete, ok := dbClient.(*Client)
|
|
require.True(t, ok)
|
|
testhelpersuuid.AssertUUID(t, dbClientConcrete.ID)
|
|
assert.NoError(t, m.DeleteClient(ctx, c.GetID()))
|
|
}
|
|
}
|
|
|
|
func TestHelperClientAuthenticate(k string, m Manager) func(t *testing.T) {
|
|
return func(t *testing.T) {
|
|
ctx := context.TODO()
|
|
require.NoError(t, m.CreateClient(ctx, &Client{
|
|
ID: "1234321",
|
|
Secret: "secret",
|
|
RedirectURIs: []string{"http://redirect"},
|
|
}))
|
|
|
|
c, err := m.AuthenticateClient(ctx, "1234321", []byte("secret1"))
|
|
require.Error(t, err)
|
|
require.Nil(t, c)
|
|
|
|
c, err = m.AuthenticateClient(ctx, "1234321", []byte("secret"))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "1234321", c.GetID())
|
|
}
|
|
}
|
|
|
|
func TestHelperUpdateTwoClients(_ string, m Manager) func(t *testing.T) {
|
|
return func(t *testing.T) {
|
|
c1, c2 := &Client{Name: "test client 1"}, &Client{Name: "test client 2"}
|
|
|
|
require.NoError(t, m.CreateClient(context.Background(), c1))
|
|
require.NoError(t, m.CreateClient(context.Background(), c2))
|
|
|
|
c1.Name, c2.Name = "updated client 1", "updated client 2"
|
|
|
|
assert.NoError(t, m.UpdateClient(context.Background(), c1))
|
|
assert.NoError(t, m.UpdateClient(context.Background(), c2))
|
|
}
|
|
}
|
|
|
|
func testHelperUpdateClient(t *testing.T, ctx context.Context, network Storage, k string) {
|
|
d, err := network.GetClient(ctx, "1234")
|
|
assert.NoError(t, err)
|
|
err = network.UpdateClient(ctx, &Client{
|
|
ID: "2-1234",
|
|
Name: "name-new",
|
|
Secret: "secret-new",
|
|
RedirectURIs: []string{"http://redirect/new"},
|
|
TermsOfServiceURI: "bar",
|
|
JSONWebKeys: new(x.JoseJSONWebKeySet),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
nc, err := network.GetConcreteClient(ctx, "2-1234")
|
|
require.NoError(t, err)
|
|
|
|
if k != "http" {
|
|
// http always returns an empty secret
|
|
assert.NotEqual(t, d.GetHashedSecret(), nc.GetHashedSecret())
|
|
}
|
|
assert.Equal(t, "bar", nc.TermsOfServiceURI)
|
|
assert.Equal(t, "name-new", nc.Name)
|
|
assert.EqualValues(t, []string{"http://redirect/new"}, nc.GetRedirectURIs())
|
|
assert.Zero(t, len(nc.Contacts))
|
|
}
|
|
|
|
func TestHelperCreateGetUpdateDeleteClientNext(t *testing.T, m Storage, networks []uuid.UUID) {
|
|
ctx := context.Background()
|
|
|
|
resources := map[uuid.UUID][]Client{}
|
|
for k := range networks {
|
|
nid := networks[k]
|
|
resources[nid] = []Client{}
|
|
|
|
ctx := contextx.SetNIDContext(ctx, nid)
|
|
t.Run(fmt.Sprintf("nid=%s", nid), func(t *testing.T) {
|
|
var client Client
|
|
require.NoError(t, faker.FakeData(&client))
|
|
client.CreatedAt = time.Now().Truncate(time.Second).UTC()
|
|
|
|
t.Run("lifecycle=does not exist", func(t *testing.T) {
|
|
_, err := m.GetClient(ctx, "1234")
|
|
require.Error(t, err)
|
|
})
|
|
|
|
t.Run("lifecycle=exists", func(t *testing.T) {
|
|
require.NoError(t, m.CreateClient(ctx, &client))
|
|
c, err := m.GetClient(ctx, client.GetID())
|
|
require.NoError(t, err)
|
|
assertx.EqualAsJSONExcept(t, &client, c, []string{
|
|
"registration_access_token",
|
|
"registration_client_uri",
|
|
"updated_at",
|
|
})
|
|
|
|
n, err := m.CountClients(ctx)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 1, n)
|
|
copy := client
|
|
require.Error(t, m.CreateClient(ctx, ©))
|
|
})
|
|
|
|
t.Run("lifecycle=update", func(t *testing.T) {
|
|
client.Name = "updated" + nid.String()
|
|
require.NoError(t, m.UpdateClient(ctx, &client))
|
|
c, err := m.GetClient(ctx, client.GetID())
|
|
require.NoError(t, err)
|
|
assertx.EqualAsJSONExcept(t, &client, c, []string{
|
|
"registration_access_token",
|
|
"registration_client_uri",
|
|
"updated_at",
|
|
})
|
|
resources[nid] = append(resources[nid], client)
|
|
})
|
|
})
|
|
}
|
|
|
|
for k := range resources {
|
|
original := k
|
|
clients := resources[k]
|
|
for i := range networks {
|
|
check := networks[i]
|
|
|
|
t.Run("network="+original.String(), func(t *testing.T) {
|
|
ctx := contextx.SetNIDContext(ctx, check)
|
|
for _, expected := range clients {
|
|
c, err := m.GetClient(ctx, expected.GetID())
|
|
if check != original {
|
|
t.Run(fmt.Sprintf("case=must not find client %s", expected.GetID()), func(t *testing.T) {
|
|
require.ErrorIs(t, err, sqlcon.ErrNoRows)
|
|
})
|
|
} else {
|
|
t.Run("case=updates must not override each other", func(t *testing.T) {
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "updated"+original.String(), c.(*Client).Name)
|
|
})
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
for k := range resources {
|
|
clients := resources[k]
|
|
ctx := contextx.SetNIDContext(ctx, k)
|
|
t.Run("network="+k.String(), func(t *testing.T) {
|
|
for _, client := range clients {
|
|
t.Run("lifecycle=cleanup", func(t *testing.T) {
|
|
assert.NoError(t, m.DeleteClient(ctx, client.GetID()))
|
|
|
|
_, err := m.GetClient(ctx, client.GetID())
|
|
assert.ErrorIs(t, err, sqlcon.ErrNoRows)
|
|
|
|
n, err := m.CountClients(ctx)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 0, n)
|
|
assert.Error(t, m.DeleteClient(ctx, client.GetID()))
|
|
})
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHelperCreateGetUpdateDeleteClient(k string, connection *pop.Connection, t1 Storage, t2 Storage) func(t *testing.T) {
|
|
return func(t *testing.T) {
|
|
ctx := context.Background()
|
|
_, err := t1.GetClient(ctx, "1234")
|
|
require.Error(t, err)
|
|
|
|
t1c1 := &Client{
|
|
ID: "1234",
|
|
Name: "name",
|
|
Secret: "secret",
|
|
RedirectURIs: []string{"http://redirect", "http://redirect1"},
|
|
GrantTypes: []string{"implicit", "refresh_token"},
|
|
ResponseTypes: []string{"code token", "token id_token", "code"},
|
|
Scope: "scope-a scope-b",
|
|
Owner: "aeneas",
|
|
PolicyURI: "http://policy",
|
|
TermsOfServiceURI: "http://tos",
|
|
ClientURI: "http://client",
|
|
LogoURI: "http://logo",
|
|
Contacts: []string{"aeneas1", "aeneas2"},
|
|
SecretExpiresAt: 0,
|
|
SectorIdentifierURI: "https://sector",
|
|
JSONWebKeys: &x.JoseJSONWebKeySet{JSONWebKeySet: &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{{KeyID: "foo", Key: []byte("asdf"), Certificates: []*x509.Certificate{}, CertificateThumbprintSHA1: []uint8{}, CertificateThumbprintSHA256: []uint8{}}}}},
|
|
JSONWebKeysURI: "https://...",
|
|
TokenEndpointAuthMethod: "none",
|
|
TokenEndpointAuthSigningAlgorithm: "RS256",
|
|
RequestURIs: []string{"foo", "bar"},
|
|
AllowedCORSOrigins: []string{"foo", "bar"},
|
|
RequestObjectSigningAlgorithm: "rs256",
|
|
UserinfoSignedResponseAlg: "RS256",
|
|
CreatedAt: time.Now().Add(-time.Hour).Round(time.Second).UTC(),
|
|
UpdatedAt: time.Now().Add(-time.Minute).Round(time.Second).UTC(),
|
|
FrontChannelLogoutURI: "http://fc-logout",
|
|
FrontChannelLogoutSessionRequired: true,
|
|
PostLogoutRedirectURIs: []string{"hello", "mister"},
|
|
BackChannelLogoutURI: "http://bc-logout",
|
|
BackChannelLogoutSessionRequired: true,
|
|
}
|
|
|
|
require.NoError(t, t1.CreateClient(ctx, t1c1))
|
|
{
|
|
t2c1 := *t1c1
|
|
require.Error(t, connection.Create(&t2c1), "should not be able to create the same client in other manager/network; are they backed by the same database?")
|
|
require.NoError(t, t2.CreateClient(ctx, &t2c1), "we should be able to create a client with the same ID in other network")
|
|
}
|
|
|
|
t2c3 := *t1c1
|
|
{
|
|
t2c3.ID = "t2c2-1234"
|
|
require.NoError(t, t2.CreateClient(ctx, &t2c3))
|
|
require.Error(t, t2.CreateClient(ctx, &t2c3))
|
|
}
|
|
assert.Equal(t, t1c1.GetID(), "1234")
|
|
if k != "http" {
|
|
assert.NotEmpty(t, t1c1.GetHashedSecret())
|
|
}
|
|
|
|
c2Template := &Client{
|
|
ID: "2-1234",
|
|
Name: "name2",
|
|
Secret: "secret",
|
|
RedirectURIs: []string{"http://redirect"},
|
|
TermsOfServiceURI: "foo",
|
|
SecretExpiresAt: 1,
|
|
}
|
|
assert.NoError(t, t1.CreateClient(ctx, c2Template))
|
|
assert.NoError(t, t2.CreateClient(ctx, c2Template))
|
|
|
|
d, err := t1.GetClient(ctx, "1234")
|
|
require.NoError(t, err)
|
|
|
|
cc := d.(*Client)
|
|
testhelpersuuid.AssertUUID(t, cc.NID)
|
|
|
|
compare(t, t1c1, d, k)
|
|
|
|
ds, err := t1.GetClients(ctx, Filter{Limit: 100, Offset: 0})
|
|
assert.NoError(t, err)
|
|
assert.Len(t, ds, 2)
|
|
assert.NotEqual(t, ds[0].GetID(), ds[1].GetID())
|
|
assert.NotEqual(t, ds[0].GetID(), ds[1].GetID())
|
|
// test if SecretExpiresAt was set properly
|
|
assert.Equal(t, ds[0].SecretExpiresAt, 0)
|
|
assert.Equal(t, ds[1].SecretExpiresAt, 1)
|
|
|
|
ds, err = t1.GetClients(ctx, Filter{Limit: 1, Offset: 0})
|
|
assert.NoError(t, err)
|
|
assert.Len(t, ds, 1)
|
|
|
|
ds, err = t1.GetClients(ctx, Filter{Limit: 100, Offset: 100})
|
|
assert.NoError(t, err)
|
|
assert.Empty(t, ds)
|
|
|
|
// get by name
|
|
ds, err = t1.GetClients(ctx, Filter{Limit: 100, Offset: 0, Name: "name"})
|
|
assert.NoError(t, err)
|
|
assert.Len(t, ds, 1)
|
|
assert.Equal(t, ds[0].Name, "name")
|
|
|
|
// get by name not exist
|
|
ds, err = t1.GetClients(ctx, Filter{Limit: 100, Offset: 0, Name: "bad name"})
|
|
assert.NoError(t, err)
|
|
assert.Len(t, ds, 0)
|
|
|
|
// get by owner
|
|
ds, err = t1.GetClients(ctx, Filter{Limit: 100, Offset: 0, Owner: "aeneas"})
|
|
assert.NoError(t, err)
|
|
assert.Len(t, ds, 1)
|
|
assert.Equal(t, ds[0].Owner, "aeneas")
|
|
|
|
testHelperUpdateClient(t, ctx, t1, k)
|
|
testHelperUpdateClient(t, ctx, t2, k)
|
|
|
|
err = t1.DeleteClient(ctx, "1234")
|
|
assert.NoError(t, err)
|
|
err = t1.DeleteClient(ctx, t2c3.GetID())
|
|
assert.Error(t, err)
|
|
|
|
_, err = t1.GetClient(ctx, "1234")
|
|
assert.NotNil(t, err)
|
|
|
|
n, err := t1.CountClients(ctx)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 1, n)
|
|
}
|
|
}
|
|
|
|
func compare(t *testing.T, expected *Client, actual fosite.Client, k string) {
|
|
assert.EqualValues(t, expected.GetID(), actual.GetID())
|
|
if k != "http" {
|
|
assert.EqualValues(t, expected.GetHashedSecret(), actual.GetHashedSecret())
|
|
}
|
|
assert.EqualValues(t, expected.GetRedirectURIs(), actual.GetRedirectURIs())
|
|
assert.EqualValues(t, expected.GetGrantTypes(), actual.GetGrantTypes())
|
|
|
|
assert.EqualValues(t, expected.GetResponseTypes(), actual.GetResponseTypes())
|
|
assert.EqualValues(t, expected.GetScopes(), actual.GetScopes())
|
|
assert.EqualValues(t, expected.IsPublic(), actual.IsPublic())
|
|
|
|
if actual, ok := actual.(*Client); ok {
|
|
assert.EqualValues(t, expected.Owner, actual.Owner)
|
|
assert.EqualValues(t, expected.Name, actual.Name)
|
|
assert.EqualValues(t, expected.PolicyURI, actual.PolicyURI)
|
|
assert.EqualValues(t, expected.TermsOfServiceURI, actual.TermsOfServiceURI)
|
|
assert.EqualValues(t, expected.ClientURI, actual.ClientURI)
|
|
assert.EqualValues(t, expected.LogoURI, actual.LogoURI)
|
|
assert.EqualValues(t, expected.Contacts, actual.Contacts)
|
|
assert.EqualValues(t, expected.SecretExpiresAt, actual.SecretExpiresAt)
|
|
assert.EqualValues(t, expected.SectorIdentifierURI, actual.SectorIdentifierURI)
|
|
assert.EqualValues(t, expected.UserinfoSignedResponseAlg, actual.UserinfoSignedResponseAlg)
|
|
assert.EqualValues(t, expected.CreatedAt.UTC().Unix(), actual.CreatedAt.UTC().Unix())
|
|
// these values are not the same because of https://github.com/gobuffalo/pop/issues/591
|
|
//assert.EqualValues(t, expected.UpdatedAt.UTC().Unix(), actual.UpdatedAt.UTC().Unix(), "%s\n%s", expected.UpdatedAt.String(), actual.UpdatedAt.String())
|
|
assert.EqualValues(t, expected.FrontChannelLogoutURI, actual.FrontChannelLogoutURI)
|
|
assert.EqualValues(t, expected.FrontChannelLogoutSessionRequired, actual.FrontChannelLogoutSessionRequired)
|
|
assert.EqualValues(t, expected.PostLogoutRedirectURIs, actual.PostLogoutRedirectURIs)
|
|
assert.EqualValues(t, expected.BackChannelLogoutURI, actual.BackChannelLogoutURI)
|
|
assert.EqualValues(t, expected.BackChannelLogoutSessionRequired, actual.BackChannelLogoutSessionRequired)
|
|
}
|
|
|
|
if actual, ok := actual.(fosite.OpenIDConnectClient); ok {
|
|
require.NotNil(t, expected.JSONWebKeys)
|
|
|
|
for k, v := range expected.JSONWebKeys.JSONWebKeySet.Keys {
|
|
if v.CertificateThumbprintSHA1 == nil {
|
|
v.CertificateThumbprintSHA1 = make([]byte, 0)
|
|
}
|
|
if v.CertificateThumbprintSHA256 == nil {
|
|
v.CertificateThumbprintSHA256 = make([]byte, 0)
|
|
}
|
|
expected.JSONWebKeys.JSONWebKeySet.Keys[k] = v
|
|
}
|
|
|
|
assert.EqualValues(t, expected.JSONWebKeys.JSONWebKeySet, actual.GetJSONWebKeys())
|
|
assert.EqualValues(t, expected.JSONWebKeysURI, actual.GetJSONWebKeysURI())
|
|
assert.EqualValues(t, expected.TokenEndpointAuthMethod, actual.GetTokenEndpointAuthMethod())
|
|
assert.EqualValues(t, expected.RequestURIs, actual.GetRequestURIs())
|
|
assert.EqualValues(t, expected.RequestObjectSigningAlgorithm, actual.GetRequestObjectSigningAlgorithm())
|
|
}
|
|
}
|