hydra/jwk/manager_test_helpers.go

236 lines
6.9 KiB
Go

// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jwk
import (
"context"
"crypto/rand"
"io"
"testing"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/x/assertx"
)
// RandomBytes returns n bytes read from the crypto/rand reader.
//
// If the reader cannot fill the whole buffer, an empty slice is returned
// together with the read error wrapped by errors.WithStack.
func RandomBytes(n int) ([]byte, error) {
	buf := make([]byte, n)

	_, err := io.ReadFull(rand.Reader, buf)
	if err != nil {
		return []byte{}, errors.WithStack(err)
	}

	return buf, nil
}
// canonicalizeThumbprints runs canonicalizeKeyThumbprints over every key in
// js. The elements of js are overwritten with the normalized keys and the
// same slice is returned for convenient chaining.
func canonicalizeThumbprints(js []jose.JSONWebKey) []jose.JSONWebKey {
	for i := range js {
		// Work on a copy of the element and store the normalized result back.
		key := js[i]
		js[i] = canonicalizeKeyThumbprints(&key)
	}
	return js
}
// canonicalizeKeyThumbprints sets zero-length SHA-1 and SHA-256 certificate
// thumbprints on v to nil and returns a copy of the resulting key. The key
// that v points to is modified as well.
//
// NOTE(review): presumably this exists so that empty and nil thumbprints do
// not show up as differences in the JSON equality assertions — confirm.
func canonicalizeKeyThumbprints(v *jose.JSONWebKey) jose.JSONWebKey {
	if sha1 := v.CertificateThumbprintSHA1; len(sha1) == 0 {
		v.CertificateThumbprintSHA1 = nil
	}

	if sha256 := v.CertificateThumbprintSHA256; len(sha256) == 0 {
		v.CertificateThumbprintSHA256 = nil
	}

	return *v
}
// TestHelperManagerKey returns a test function that exercises the single-key
// operations of m: AddKey, GetKey, UpdateKey and DeleteKey, plus a GetKeySet
// lookup.
//
// The keys under test are those returned by keys.Key(suffix); their public
// counterparts are derived once, before the returned test runs. Every run
// stores its keys in a fresh set named algo followed by a random UUID.
func TestHelperManagerKey(m Manager, algo string, keys *jose.JSONWebKeySet, suffix string) func(t *testing.T) {
	priv := canonicalizeThumbprints(keys.Key(suffix))

	var pub []jose.JSONWebKey
	for _, k := range priv {
		pub = append(pub, canonicalizeThumbprints([]jose.JSONWebKey{k.Public()})...)
	}

	return func(t *testing.T) {
		ctx := t.Context()
		set := algo + uuid.Must(uuid.NewV4()).String()

		// Nothing has been stored in the freshly named set yet.
		_, err := m.GetKey(ctx, set, suffix)
		assert.Error(t, err)

		err = m.AddKey(ctx, set, First(priv))
		require.NoError(t, err)

		got, err := m.GetKey(ctx, set, suffix)
		require.NoError(t, err)
		assertx.EqualAsJSON(t, priv, canonicalizeThumbprints(got.Keys))

		// Adding another key under a random key ID must not change what a
		// lookup by suffix returns.
		addKey := First(pub)
		addKey.KeyID = uuid.Must(uuid.NewV4()).String()
		err = m.AddKey(ctx, set, addKey)
		require.NoError(t, err)

		got, err = m.GetKey(ctx, set, suffix)
		require.NoError(t, err)
		assertx.EqualAsJSON(t, priv, canonicalizeThumbprints(got.Keys))

		// Because MySQL
		// NOTE(review): the reason for the pause is not visible here; it
		// presumably works around timestamp resolution — confirm.
		time.Sleep(time.Second * 2)

		// Store the public key under a new key ID with use "sig" ...
		newKID := "new-key-id:" + suffix
		pub[0].KeyID = newKID
		pub[0].Use = "sig"
		err = m.AddKey(ctx, set, First(pub))
		require.NoError(t, err)

		got, err = m.GetKey(ctx, set, newKID)
		require.NoError(t, err)
		newKey := First(got.Keys)
		assert.EqualValues(t, "sig", newKey.Use)

		// ... then flip it to "enc" and check that the update is persisted.
		newKey.Use = "enc"
		err = m.UpdateKey(ctx, set, newKey)
		require.NoError(t, err)

		updated, err := m.GetKey(ctx, set, newKID)
		require.NoError(t, err)
		updatedKey := First(updated.Keys)
		assert.EqualValues(t, "enc", updatedKey.Use)

		// Use a local variable for the fetched set instead of overwriting the
		// captured keys parameter.
		keySet, err := m.GetKeySet(ctx, set)
		require.NoError(t, err)

		var found bool
		for _, k := range keySet.Keys {
			if k.KeyID == newKID {
				found = true
				break
			}
		}
		// The message has three verbs: set name, key ID, and the key set dump.
		assert.True(t, found, "Key not found in key set: %s / %s\n%+v", set, newKID, keySet)

		// Deleting the key identified by suffix removes exactly one key from
		// the set.
		beforeDeleteKeysCount := len(keySet.Keys)
		err = m.DeleteKey(ctx, set, suffix)
		require.NoError(t, err)

		_, err = m.GetKey(ctx, set, suffix)
		require.Error(t, err)

		keySet, err = m.GetKeySet(ctx, set)
		require.NoError(t, err)
		assert.EqualValues(t, beforeDeleteKeysCount-1, len(keySet.Keys))
	}
}
// TestHelperManagerKeySet returns a test function that exercises the key-set
// operations of m: AddKeySet, GetKeySet, UpdateKeySet and DeleteKeySet.
//
// keys is stored under a set named algo followed by a random UUID, and the
// keys selected by suffix are compared before and after the round trip. When
// parallel is true the returned test calls t.Parallel.
func TestHelperManagerKeySet(m Manager, algo string, keys *jose.JSONWebKeySet, suffix string, parallel bool) func(t *testing.T) {
	return func(t *testing.T) {
		ctx := t.Context()
		if parallel {
			t.Parallel()
		}

		set := algo + uuid.Must(uuid.NewV4()).String()

		// The set must not exist before it is added.
		_, err := m.GetKeySet(ctx, set)
		require.Error(t, err)

		err = m.AddKeySet(ctx, set, keys)
		require.NoError(t, err)

		got, err := m.GetKeySet(ctx, set)
		require.NoError(t, err)
		assertx.EqualAsJSON(t, canonicalizeThumbprints(keys.Key(suffix)), canonicalizeThumbprints(got.Key(suffix)))

		// Switch every key to "enc" and verify the update is persisted.
		for i := range got.Keys {
			got.Keys[i].Use = "enc"
		}
		err = m.UpdateKeySet(ctx, set, got)
		require.NoError(t, err)

		updated, err := m.GetKeySet(ctx, set)
		require.NoError(t, err)

		// Fail the test cleanly instead of panicking on an out-of-range index
		// when no key matches suffix.
		updatedKeys := updated.Key(suffix)
		require.NotEmpty(t, updatedKeys)
		assert.EqualValues(t, "enc", updatedKeys[0].Public().Use)
		assert.EqualValues(t, "enc", updatedKeys[0].Use)

		err = m.DeleteKeySet(ctx, set)
		require.NoError(t, err)

		_, err = m.GetKeySet(ctx, set)
		require.Error(t, err)
	}
}
// TestHelperManagerGenerateAndPersistKeySet returns a test function that
// generates a key set through m.GenerateAndPersistKeySet using algorithm alg
// and use "sig", reads it back with GetKeySet, and checks that the public key
// and the private key ID match what was generated. The set is deleted at the
// end. When parallel is true the returned test calls t.Parallel.
//
// Each run uses its own randomly named set so that parallel or repeated runs
// against the same Manager cannot collide on a shared set name.
func TestHelperManagerGenerateAndPersistKeySet(m Manager, alg string, parallel bool) func(t *testing.T) {
	return func(t *testing.T) {
		ctx := t.Context()
		if parallel {
			t.Parallel()
		}

		set := "foo-" + uuid.Must(uuid.NewV4()).String()

		// The set must not exist before it is generated.
		_, err := m.GetKeySet(ctx, set)
		require.Error(t, err)

		keys, err := m.GenerateAndPersistKeySet(ctx, set, "bar", alg, "sig")
		require.NoError(t, err)

		genPub, err := FindPublicKey(keys)
		require.NoError(t, err)
		require.NotEmpty(t, genPub)
		genPriv, err := FindPrivateKey(keys)
		require.NoError(t, err)

		got, err := m.GetKeySet(ctx, set)
		require.NoError(t, err)

		gotPub, err := FindPublicKey(got)
		require.NoError(t, err)
		require.NotEmpty(t, gotPub)
		gotPriv, err := FindPrivateKey(got)
		require.NoError(t, err)

		// Public keys are compared in full, private keys only by key ID.
		assertx.EqualAsJSON(t, canonicalizeKeyThumbprints(genPub), canonicalizeKeyThumbprints(gotPub))
		assert.EqualValues(t, genPriv.KeyID, gotPriv.KeyID)

		err = m.DeleteKeySet(ctx, set)
		require.NoError(t, err)

		_, err = m.GetKeySet(ctx, set)
		require.Error(t, err)
	}
}
// TestHelperNID returns a test function that runs the same sequence of key
// and key-set operations against two managers. Every write and read through
// t1ValidNID is expected to succeed, while the corresponding call through
// t2InvalidNID is expected to fail. Deletes are the exception: they are
// expected to succeed on both managers.
//
// NOTE(review): "NID" presumably stands for network ID — confirm.
func TestHelperNID(t1ValidNID, t2InvalidNID Manager) func(t *testing.T) {
	const (
		singleKeySet    = "2022-03-11-k-1"
		importedKeySet  = "2022-03-11-ks-1"
		generatedKeySet = "2022-03-11-ks-2"
		generatedKeyID  = "2022-03-11-ks-2-kid"
	)

	return func(t *testing.T) {
		ctx := context.Background()

		jwks, err := GenerateJWK(jose.RS256, "2022-03-11-ks-1-kid", "test")
		require.NoError(t, err)

		// Adding a single key.
		err = t2InvalidNID.AddKey(ctx, singleKeySet, &jwks.Keys[0])
		require.Error(t, err)
		err = t1ValidNID.AddKey(ctx, singleKeySet, &jwks.Keys[0])
		require.NoError(t, err)

		// Adding a whole key set.
		err = t2InvalidNID.AddKeySet(ctx, importedKeySet, jwks)
		require.Error(t, err)
		err = t1ValidNID.AddKeySet(ctx, importedKeySet, jwks)
		require.NoError(t, err)

		// Delete doesn't report error if key doesn't exist
		err = t2InvalidNID.DeleteKey(ctx, importedKeySet, jwks.Keys[0].KeyID)
		require.NoError(t, err)
		err = t1ValidNID.DeleteKey(ctx, importedKeySet, jwks.Keys[0].KeyID)
		require.NoError(t, err)

		// Generating and persisting a key set.
		_, err = t2InvalidNID.GenerateAndPersistKeySet(ctx, generatedKeySet, generatedKeyID, "RS256", "sig")
		require.Error(t, err)
		gks2, err := t1ValidNID.GenerateAndPersistKeySet(ctx, generatedKeySet, generatedKeyID, "RS256", "sig")
		require.NoError(t, err)

		// Reading a single key.
		_, err = t1ValidNID.GetKey(ctx, generatedKeySet, gks2.Keys[0].KeyID)
		require.NoError(t, err)
		_, err = t2InvalidNID.GetKey(ctx, generatedKeySet, gks2.Keys[0].KeyID)
		require.Error(t, err)

		// Reading the key set.
		_, err = t1ValidNID.GetKeySet(ctx, generatedKeySet)
		require.NoError(t, err)
		_, err = t2InvalidNID.GetKeySet(ctx, generatedKeySet)
		require.Error(t, err)

		// Updating a single key.
		encKey := &gks2.Keys[0]
		encKey.Use = "enc"
		err = t2InvalidNID.UpdateKey(ctx, generatedKeySet, encKey)
		require.Error(t, err)
		err = t1ValidNID.UpdateKey(ctx, generatedKeySet, encKey)
		require.NoError(t, err)

		// Updating the key set.
		gks2.Keys[0].Use = "enc"
		err = t2InvalidNID.UpdateKeySet(ctx, generatedKeySet, gks2)
		require.Error(t, err)
		err = t1ValidNID.UpdateKeySet(ctx, generatedKeySet, gks2)
		require.NoError(t, err)

		// Deleting the key set succeeds on both managers.
		err = t2InvalidNID.DeleteKeySet(ctx, generatedKeySet)
		require.NoError(t, err)
		err = t1ValidNID.DeleteKeySet(ctx, generatedKeySet)
		require.NoError(t, err)
	}
}