mirror of https://github.com/ory/hydra
144 lines
4.5 KiB
Go
144 lines
4.5 KiB
Go
// Copyright © 2022 Ory Corp
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package server_test
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-jose/go-jose/v3"
|
|
"github.com/gofrs/uuid"
|
|
"github.com/sirupsen/logrus/hooks/test"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/ory/x/configx"
|
|
"github.com/ory/x/logrusx"
|
|
"github.com/ory/x/servicelocatorx"
|
|
"github.com/ory/x/tlsx"
|
|
|
|
"github.com/ory/hydra/v2/cmd/server"
|
|
"github.com/ory/hydra/v2/driver"
|
|
"github.com/ory/hydra/v2/driver/config"
|
|
"github.com/ory/hydra/v2/internal/testhelpers"
|
|
"github.com/ory/hydra/v2/jwk"
|
|
)
|
|
|
|
func TestGetOrCreateTLSCertificate(t *testing.T) {
|
|
certPath, keyPath, cert, priv := testhelpers.GenerateTLSCertificateFilesForTests(t)
|
|
logger := logrusx.New("", "")
|
|
logger.Logger.ExitFunc = func(code int) { t.Fatalf("Logger called os.Exit(%v)", code) }
|
|
d, err := driver.New(t.Context(),
|
|
driver.WithConfigOptions(configx.WithValues(map[string]interface{}{
|
|
"dsn": config.DSNMemory,
|
|
"serve.tls.enabled": true,
|
|
"serve.tls.cert.path": certPath,
|
|
"serve.tls.key.path": keyPath,
|
|
})),
|
|
driver.WithServiceLocatorOptions(servicelocatorx.WithLogger(logger)),
|
|
)
|
|
require.NoError(t, err)
|
|
getCert := server.GetOrCreateTLSCertificate(t.Context(), d, d.Config().ServeAdmin(t.Context()).TLS, "admin")
|
|
require.NotNil(t, getCert)
|
|
tlsCert, err := getCert(nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, tlsCert)
|
|
if tlsCert.Leaf == nil {
|
|
tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
|
|
require.NoError(t, err)
|
|
}
|
|
require.True(t, tlsCert.Leaf.Equal(cert))
|
|
require.True(t, priv.Equal(tlsCert.PrivateKey))
|
|
|
|
// generate new cert+key
|
|
newCertPath, newKeyPath, newCert, newPriv := testhelpers.GenerateTLSCertificateFilesForTests(t)
|
|
require.False(t, cert.Equal(newCert))
|
|
require.False(t, priv.Equal(newPriv))
|
|
require.NotEqual(t, certPath, newCertPath)
|
|
require.NotEqual(t, keyPath, newKeyPath)
|
|
|
|
hook := test.NewLocal(logger.Logger)
|
|
|
|
// move them into place
|
|
require.NoError(t, os.Rename(newKeyPath, keyPath))
|
|
require.NoError(t, os.Rename(newCertPath, certPath))
|
|
|
|
// give it some time and check we're reloaded
|
|
time.Sleep(150 * time.Millisecond)
|
|
require.Nil(t, hook.LastEntry())
|
|
|
|
// request another certificate: it should be the new one
|
|
tlsCert, err = getCert(nil)
|
|
require.NoError(t, err)
|
|
if tlsCert.Leaf == nil {
|
|
tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
|
|
require.NoError(t, err)
|
|
}
|
|
require.True(t, tlsCert.Leaf.Equal(newCert))
|
|
require.True(t, newPriv.Equal(tlsCert.PrivateKey))
|
|
|
|
require.NoError(t, os.WriteFile(certPath, []byte{'j', 'u', 'n', 'k'}, 0))
|
|
|
|
timeout := time.After(500 * time.Millisecond)
|
|
for {
|
|
if hook.LastEntry() != nil {
|
|
break
|
|
}
|
|
select {
|
|
case <-timeout:
|
|
require.FailNow(t, "expected error log entry")
|
|
default:
|
|
}
|
|
}
|
|
require.Contains(t, hook.LastEntry().Message, "Failed to reload TLS certificates, using previous certificates")
|
|
}
|
|
|
|
func TestGetOrCreateTLSCertificateBase64(t *testing.T) {
|
|
certPath, keyPath, cert, priv := testhelpers.GenerateTLSCertificateFilesForTests(t)
|
|
certPEM, err := os.ReadFile(certPath)
|
|
require.NoError(t, err)
|
|
certBase64 := base64.StdEncoding.EncodeToString(certPEM)
|
|
keyPEM, err := os.ReadFile(keyPath)
|
|
require.NoError(t, err)
|
|
keyBase64 := base64.StdEncoding.EncodeToString(keyPEM)
|
|
|
|
d, err := driver.New(t.Context(), driver.WithConfigOptions(configx.WithValues(map[string]interface{}{
|
|
"dsn": config.DSNMemory,
|
|
"serve.tls.enabled": true,
|
|
"serve.tls.cert.base64": certBase64,
|
|
"serve.tls.key.base64": keyBase64,
|
|
})))
|
|
require.NoError(t, err)
|
|
getCert := server.GetOrCreateTLSCertificate(t.Context(), d, d.Config().ServeAdmin(t.Context()).TLS, "admin")
|
|
require.NotNil(t, getCert)
|
|
tlsCert, err := getCert(nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, tlsCert)
|
|
if tlsCert.Leaf == nil {
|
|
tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
|
|
require.NoError(t, err)
|
|
}
|
|
require.True(t, tlsCert.Leaf.Equal(cert))
|
|
require.True(t, priv.Equal(tlsCert.PrivateKey))
|
|
}
|
|
|
|
func TestCreateSelfSignedCertificate(t *testing.T) {
|
|
keys, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig")
|
|
require.NoError(t, err)
|
|
|
|
private := keys.Keys[0]
|
|
cert, err := tlsx.CreateSelfSignedCertificate(private.Key)
|
|
require.NoError(t, err)
|
|
server.AttachCertificate(&private, cert)
|
|
|
|
var actual jose.JSONWebKeySet
|
|
var b bytes.Buffer
|
|
require.NoError(t, json.NewEncoder(&b).Encode(keys))
|
|
require.NoError(t, json.NewDecoder(&b).Decode(&actual))
|
|
}
|