hydra/driver/registry_sql_test.go

202 lines
5.8 KiB
Go

// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package driver
import (
"context"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/gorilla/sessions"
pkgerr "github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/persistence/sql"
"github.com/ory/x/configx"
"github.com/ory/x/dbal"
"github.com/ory/x/httpx"
"github.com/ory/x/logrusx"
"github.com/ory/x/popx"
"github.com/ory/x/randx"
)
// init runs once when this test package is initialized and sets
// sql.SilenceMigrations to true before any test executes.
// NOTE(review): presumably this suppresses migration log output during
// tests — confirm against the persistence/sql package.
func init() {
sql.SilenceMigrations = true
}
func TestGetJWKSFetcherStrategyHostEnforcement(t *testing.T) {
t.Parallel()
r, err := New(t.Context(), WithConfigOptions(
configx.WithValues(map[string]any{
config.KeyDSN: "memory",
config.HSMEnabled: "false",
config.KeyClientHTTPNoPrivateIPRanges: true,
}),
configx.WithConfigFiles("../internal/.hydra.yaml"),
))
require.NoError(t, err)
_, err = r.OAuth2Config().GetJWKSFetcherStrategy(t.Context()).Resolve(t.Context(), "http://localhost:8080", true)
require.ErrorAs(t, err, new(httpx.ErrPrivateIPAddressDisallowed))
}
// TestRegistrySQL_newKeyStrategy_handlesNetworkError ensures that a failing
// initial ping surfaces from Init and that the last log entry is emitted at
// info level and carries the underlying error message (issue #2338).
func TestRegistrySQL_newKeyStrategy_handlesNetworkError(t *testing.T) {
	t.Parallel()

	// The hook records log entries so they can be asserted on below; the
	// regular log output is discarded to keep test output clean.
	hook := test.Hook{}
	l := logrusx.New("", "", logrusx.WithHook(&hook))
	l.Logrus().SetOutput(io.Discard)

	// Create a config and set a valid but unresolvable DSN.
	c := config.MustNew(t, l,
		configx.WithConfigFiles("../internal/.hydra.yaml"),
		configx.WithValues(map[string]any{
			config.KeyDSN:     "postgres://user:password@127.0.0.1:9999/postgres",
			config.HSMEnabled: false,
		}),
	)

	r, err := newRegistryWithoutInit(c, l)
	require.NoError(t, err, "failed to create registry")

	// Replace the ping with one that always fails with a recognizable error.
	r.initialPing = failedPing(errors.New("snizzles"))

	assert.ErrorContains(t,
		r.Init(t.Context(), true, false, nil, nil),
		"snizzles",
	)

	// LastEntry returns nil when nothing was logged; fail cleanly instead of
	// panicking on a nil dereference.
	entry := hook.LastEntry()
	require.NotNil(t, entry, "expected at least one log entry")
	assert.Equal(t, logrus.InfoLevel, entry.Level)
	assert.Contains(t, entry.Message, "snizzles")
}
// TestRegistrySQL_CookieStore_MaxAgeZero ensures that the cookie store
// returned by the registry is a *sessions.CookieStore whose MaxAge option is
// zero after initialization.
func TestRegistrySQL_CookieStore_MaxAgeZero(t *testing.T) {
	t.Parallel()

	r, err := New(t.Context(), SkipNetworkInit(), DisableValidation(), WithConfigOptions(
		configx.WithValues(map[string]any{
			config.KeyDSN:             dbal.NewSQLiteTestDatabase(t),
			config.KeyGetSystemSecret: []string{randx.MustString(32, randx.AlphaNum)},
		}),
	))
	require.NoError(t, err)

	s, err := r.CookieStore(t.Context())
	require.NoError(t, err)

	// Use the checked form of the type assertion so an unexpected store type
	// fails the test with a readable message instead of panicking.
	cs, ok := s.(*sessions.CookieStore)
	require.Truef(t, ok, "expected *sessions.CookieStore, got %T", s)
	require.NotNil(t, cs.Options)

	// testify expects (expected, actual).
	assert.Equal(t, 0, cs.Options.MaxAge)
}
// TestRegistrySQL_HTTPClient verifies the registry's HTTP client when
// config.KeyClientHTTPNoPrivateIPRanges is enabled: requests to the local
// test server succeed only for URLs matching the glob configured under
// config.KeyClientHTTPPrivateIPExceptionURLs, and fail with a
// "prohibited IP address" error otherwise.
func TestRegistrySQL_HTTPClient(t *testing.T) {
	t.Parallel()

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()

	r, err := New(t.Context(), SkipNetworkInit(), DisableValidation(), WithConfigOptions(configx.WithValues(map[string]any{
		config.KeyDSN:                               dbal.NewSQLiteTestDatabase(t),
		config.KeyClientHTTPNoPrivateIPRanges:       true,
		config.KeyClientHTTPPrivateIPExceptionURLs: []string{ts.URL + "/exception/*"},
	})))
	require.NoError(t, err)

	t.Run("case=matches exception glob", func(t *testing.T) {
		res, err := r.HTTPClient(t.Context()).Get(ts.URL + "/exception/foo")
		require.NoError(t, err)
		// Always close the body so the underlying connection is released.
		defer func() { assert.NoError(t, res.Body.Close()) }()

		assert.Equal(t, http.StatusOK, res.StatusCode)
	})

	t.Run("case=does not match exception glob", func(t *testing.T) {
		_, err := r.HTTPClient(t.Context()).Get(ts.URL + "/foo")
		assert.ErrorContains(t, err, "prohibited IP address")
	})
}
func TestDefaultKeyManager_HsmDisabled(t *testing.T) {
t.Parallel()
r, err := New(t.Context(),
SkipNetworkInit(),
DisableValidation(),
WithConfigOptions(
configx.SkipValidation(),
configx.WithValues(map[string]any{
config.KeyDSN: "memory",
config.HSMEnabled: false,
}),
),
)
require.NoError(t, err)
assert.IsType(t, &sql.JWKPersister{}, r.KeyManager())
}
// TestDbUnknownTableColumns adds an extra column to the "hydra_client" table
// and checks how selecting it into []client.Client behaves: by default the
// query fails with "missing destination name temp_column", while a registry
// configured with config.KeyDBIgnoreUnknownTableColumns reads the row
// successfully.
func TestDbUnknownTableColumns(t *testing.T) {
	t.Parallel()

	dsn := dbal.NewSQLiteTestDatabase(t)
	reg, err := New(t.Context(), WithConfigOptions(configx.WithValue("dsn", dsn)), WithAutoMigrate())
	require.NoError(t, err)

	// Add an extra column to the table; the subtests below select it.
	statement := `ALTER TABLE "hydra_client" ADD COLUMN "temp_column" VARCHAR(128) NOT NULL DEFAULT '';`
	require.NoError(t, reg.Persister().Connection(t.Context()).RawQuery(statement).Exec())

	cl := &client.Client{
		ID: strconv.Itoa(rand.Int()),
	}
	require.NoError(t, reg.Persister().CreateClient(t.Context(), cl))

	// getClients selects all client columns plus temp_column through the
	// given registry's connection.
	getClients := func(ctx context.Context, reg *RegistrySQL) ([]client.Client, error) {
		readClients := make([]client.Client, 0)
		conn := reg.Persister().Connection(ctx)
		cols := popx.DBColumns[client.Client](conn.Dialect)
		query := fmt.Sprintf(`SELECT %s, temp_column FROM "hydra_client"`, cols)

		// Run the query in its own statement: in `return readClients, f(&readClients)`
		// the Go spec leaves the order of the variable read relative to the
		// call unspecified, so the pre-query slice could be returned.
		if err := conn.RawQuery(query).All(&readClients); err != nil {
			return nil, err
		}
		return readClients, nil
	}

	t.Run("with ignore disabled (default behavior)", func(t *testing.T) {
		_, err := getClients(t.Context(), reg)
		assert.ErrorContains(t, err, "missing destination name temp_column")
	})

	t.Run("with ignore enabled", func(t *testing.T) {
		// A second registry on the same database, distinctly named so it does
		// not shadow the outer one.
		ignoringReg, err := New(t.Context(), WithConfigOptions(
			configx.WithValue("dsn", dsn),
			configx.WithValue(config.KeyDBIgnoreUnknownTableColumns, true),
		))
		require.NoError(t, err)

		actual, err := getClients(t.Context(), ignoringReg)
		require.NoError(t, err)
		assert.Len(t, actual, 1)
	})
}
// failedPing builds a ping function that ignores all of its arguments and
// always returns err wrapped with pkgerr.WithStack.
func failedPing(err error) func(context.Context, *logrusx.Logger, *sql.BasePersister) error {
	alwaysFail := func(_ context.Context, _ *logrusx.Logger, _ *sql.BasePersister) error {
		return pkgerr.WithStack(err)
	}
	return alwaysFail
}