mirror of https://github.com/ory/hydra
236 lines
6.9 KiB
Go
236 lines
6.9 KiB
Go
// Copyright © 2022 Ory Corp
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package jwk
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"io"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-jose/go-jose/v3"
|
|
"github.com/gofrs/uuid"
|
|
"github.com/pkg/errors"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/ory/x/assertx"
|
|
)
|
|
|
|
// RandomBytes returns a freshly allocated slice of n bytes filled from
// crypto/rand.Reader. If the reader cannot supply n bytes, it returns an
// empty (non-nil) slice together with the read error wrapped via
// errors.WithStack.
func RandomBytes(n int) ([]byte, error) {
	buf := make([]byte, n)
	_, err := io.ReadFull(rand.Reader, buf)
	if err != nil {
		return []byte{}, errors.WithStack(err)
	}
	return buf, nil
}
|
|
|
|
func canonicalizeThumbprints(js []jose.JSONWebKey) []jose.JSONWebKey {
|
|
for k, v := range js {
|
|
js[k] = canonicalizeKeyThumbprints(&v)
|
|
}
|
|
return js
|
|
}
|
|
|
|
func canonicalizeKeyThumbprints(v *jose.JSONWebKey) jose.JSONWebKey {
|
|
if len(v.CertificateThumbprintSHA1) == 0 {
|
|
v.CertificateThumbprintSHA1 = nil
|
|
}
|
|
if len(v.CertificateThumbprintSHA256) == 0 {
|
|
v.CertificateThumbprintSHA256 = nil
|
|
}
|
|
return *v
|
|
}
|
|
|
|
// TestHelperManagerKey returns a subtest that exercises the single-key
// operations of m (AddKey, GetKey, UpdateKey, GetKeySet, DeleteKey) using the
// keys in keys whose key ID equals suffix. algo only prefixes the randomly
// generated key set ID.
//
// The returned function never mutates the fixtures it captures (keys, priv,
// pub), so it can safely be run more than once.
func TestHelperManagerKey(m Manager, algo string, keys *jose.JSONWebKeySet, suffix string) func(t *testing.T) {
	priv := canonicalizeThumbprints(keys.Key(suffix))

	var pub []jose.JSONWebKey
	for _, k := range priv {
		pub = append(pub, canonicalizeThumbprints([]jose.JSONWebKey{k.Public()})...)
	}

	return func(t *testing.T) {
		ctx := t.Context()

		// Fail clearly instead of panicking on a nil First(...) below.
		require.NotEmpty(t, priv, "fixture has no key with ID %q", suffix)

		set := algo + uuid.Must(uuid.NewV4()).String()

		_, err := m.GetKey(ctx, set, suffix)
		assert.Error(t, err)

		require.NoError(t, m.AddKey(ctx, set, First(priv)))

		got, err := m.GetKey(ctx, set, suffix)
		require.NoError(t, err)
		assertx.EqualAsJSON(t, priv, canonicalizeThumbprints(got.Keys))

		// Add the public key under a random key ID. Work on a copy: First
		// returns a pointer into pub, and mutating it would leak into later runs.
		addKey := *First(pub)
		addKey.KeyID = uuid.Must(uuid.NewV4()).String()
		require.NoError(t, m.AddKey(ctx, set, &addKey))

		// Looking up by suffix must still return only the original key(s).
		got, err = m.GetKey(ctx, set, suffix)
		require.NoError(t, err)
		assertx.EqualAsJSON(t, priv, canonicalizeThumbprints(got.Keys))

		// Because MySQL. NOTE(review): presumably the store's timestamps have
		// one-second resolution and this keeps the next key's creation time
		// distinct from the earlier ones — confirm against the SQL persister.
		time.Sleep(time.Second * 2)

		newKID := "new-key-id:" + suffix
		newPub := *First(pub)
		newPub.KeyID = newKID
		newPub.Use = "sig"
		require.NoError(t, m.AddKey(ctx, set, &newPub))

		got, err = m.GetKey(ctx, set, newKID)
		require.NoError(t, err)
		newKey := First(got.Keys)
		assert.EqualValues(t, "sig", newKey.Use)

		newKey.Use = "enc"
		require.NoError(t, m.UpdateKey(ctx, set, newKey))
		updated, err := m.GetKey(ctx, set, newKID)
		require.NoError(t, err)
		assert.EqualValues(t, "enc", First(updated.Keys).Use)

		// Use a local instead of reassigning the captured keys parameter.
		ks, err := m.GetKeySet(ctx, set)
		require.NoError(t, err)
		var found bool
		for _, k := range ks.Keys {
			if k.KeyID == newKID {
				found = true
				break
			}
		}
		assert.True(t, found, "Key %s not found in key set: %+v", newKID, ks)

		beforeDeleteKeysCount := len(ks.Keys)
		require.NoError(t, m.DeleteKey(ctx, set, suffix))

		_, err = m.GetKey(ctx, set, suffix)
		require.Error(t, err)

		ks, err = m.GetKeySet(ctx, set)
		require.NoError(t, err)
		assert.Len(t, ks.Keys, beforeDeleteKeysCount-1)
	}
}
|
|
|
|
// TestHelperManagerKeySet returns a subtest that exercises the key-set
// operations of m (AddKeySet, GetKeySet, UpdateKeySet, DeleteKeySet) with
// keys. suffix selects the key ID whose entries are compared after each round
// trip; parallel marks the subtest with t.Parallel.
func TestHelperManagerKeySet(m Manager, algo string, keys *jose.JSONWebKeySet, suffix string, parallel bool) func(t *testing.T) {
	return func(t *testing.T) {
		ctx := t.Context()

		if parallel {
			t.Parallel()
		}

		// A random, algorithm-prefixed set ID keeps concurrent runs apart.
		set := algo + uuid.Must(uuid.NewV4()).String()

		_, err := m.GetKeySet(ctx, set)
		require.Error(t, err)

		require.NoError(t, m.AddKeySet(ctx, set, keys))

		got, err := m.GetKeySet(ctx, set)
		require.NoError(t, err)
		assertx.EqualAsJSON(t, canonicalizeThumbprints(keys.Key(suffix)), canonicalizeThumbprints(got.Key(suffix)))

		for i := range got.Keys {
			got.Keys[i].Use = "enc"
		}
		require.NoError(t, m.UpdateKeySet(ctx, set, got))

		updated, err := m.GetKeySet(ctx, set)
		require.NoError(t, err)
		updatedKeys := updated.Key(suffix)
		// Guard the [0] index below so a missing key fails instead of panicking.
		require.NotEmpty(t, updatedKeys, "no key with ID %q after update", suffix)
		assert.EqualValues(t, "enc", updatedKeys[0].Public().Use)
		assert.EqualValues(t, "enc", updatedKeys[0].Use)

		require.NoError(t, m.DeleteKeySet(ctx, set))

		_, err = m.GetKeySet(ctx, set)
		require.Error(t, err)
	}
}
|
|
|
|
// TestHelperManagerGenerateAndPersistKeySet returns a subtest that generates
// a key set for alg with m.GenerateAndPersistKeySet (key ID "bar", use "sig").
// It reads the set back with GetKeySet and checks that the public key and the
// private key's ID survived the round trip. Finally it deletes the set.
//
// Each run uses its own randomly named set. With a fixed name, runs executing
// concurrently (parallel == true) could observe or delete each other's keys.
func TestHelperManagerGenerateAndPersistKeySet(m Manager, alg string, parallel bool) func(t *testing.T) {
	return func(t *testing.T) {
		ctx := t.Context()

		if parallel {
			t.Parallel()
		}

		set := "foo-" + alg + "-" + uuid.Must(uuid.NewV4()).String()

		_, err := m.GetKeySet(ctx, set)
		require.Error(t, err)

		keys, err := m.GenerateAndPersistKeySet(ctx, set, "bar", alg, "sig")
		require.NoError(t, err)
		genPub, err := FindPublicKey(keys)
		require.NoError(t, err)
		require.NotEmpty(t, genPub)
		genPriv, err := FindPrivateKey(keys)
		require.NoError(t, err)

		got, err := m.GetKeySet(ctx, set)
		require.NoError(t, err)
		gotPub, err := FindPublicKey(got)
		require.NoError(t, err)
		require.NotEmpty(t, gotPub)
		gotPriv, err := FindPrivateKey(got)
		require.NoError(t, err)

		assertx.EqualAsJSON(t, canonicalizeKeyThumbprints(genPub), canonicalizeKeyThumbprints(gotPub))
		assert.EqualValues(t, genPriv.KeyID, gotPriv.KeyID)

		require.NoError(t, m.DeleteKeySet(ctx, set))

		_, err = m.GetKeySet(ctx, set)
		require.Error(t, err)
	}
}
|
|
|
|
func TestHelperNID(t1ValidNID, t2InvalidNID Manager) func(t *testing.T) {
|
|
return func(t *testing.T) {
|
|
ctx := context.Background()
|
|
jwks, err := GenerateJWK(jose.RS256, "2022-03-11-ks-1-kid", "test")
|
|
require.NoError(t, err)
|
|
require.Error(t, t2InvalidNID.AddKey(ctx, "2022-03-11-k-1", &jwks.Keys[0]))
|
|
require.NoError(t, t1ValidNID.AddKey(ctx, "2022-03-11-k-1", &jwks.Keys[0]))
|
|
require.Error(t, t2InvalidNID.AddKeySet(ctx, "2022-03-11-ks-1", jwks))
|
|
require.NoError(t, t1ValidNID.AddKeySet(ctx, "2022-03-11-ks-1", jwks))
|
|
require.NoError(t, t2InvalidNID.DeleteKey(ctx, "2022-03-11-ks-1", jwks.Keys[0].KeyID)) // Delete doesn't report error if key doesn't exist
|
|
require.NoError(t, t1ValidNID.DeleteKey(ctx, "2022-03-11-ks-1", jwks.Keys[0].KeyID))
|
|
_, err = t2InvalidNID.GenerateAndPersistKeySet(ctx, "2022-03-11-ks-2", "2022-03-11-ks-2-kid", "RS256", "sig")
|
|
require.Error(t, err)
|
|
gks2, err := t1ValidNID.GenerateAndPersistKeySet(ctx, "2022-03-11-ks-2", "2022-03-11-ks-2-kid", "RS256", "sig")
|
|
require.NoError(t, err)
|
|
|
|
_, err = t1ValidNID.GetKey(ctx, "2022-03-11-ks-2", gks2.Keys[0].KeyID)
|
|
require.NoError(t, err)
|
|
_, err = t2InvalidNID.GetKey(ctx, "2022-03-11-ks-2", gks2.Keys[0].KeyID)
|
|
require.Error(t, err)
|
|
|
|
_, err = t1ValidNID.GetKeySet(ctx, "2022-03-11-ks-2")
|
|
require.NoError(t, err)
|
|
_, err = t2InvalidNID.GetKeySet(ctx, "2022-03-11-ks-2")
|
|
require.Error(t, err)
|
|
updatedKey := &gks2.Keys[0]
|
|
updatedKey.Use = "enc"
|
|
require.Error(t, t2InvalidNID.UpdateKey(ctx, "2022-03-11-ks-2", updatedKey))
|
|
require.NoError(t, t1ValidNID.UpdateKey(ctx, "2022-03-11-ks-2", updatedKey))
|
|
gks2.Keys[0].Use = "enc"
|
|
require.Error(t, t2InvalidNID.UpdateKeySet(ctx, "2022-03-11-ks-2", gks2))
|
|
require.NoError(t, t1ValidNID.UpdateKeySet(ctx, "2022-03-11-ks-2", gks2))
|
|
require.NoError(t, t2InvalidNID.DeleteKeySet(ctx, "2022-03-11-ks-2"))
|
|
require.NoError(t, t1ValidNID.DeleteKeySet(ctx, "2022-03-11-ks-2"))
|
|
}
|
|
}
|