mirror of https://github.com/ory/hydra
202 lines
5.8 KiB
Go
202 lines
5.8 KiB
Go
// Copyright © 2022 Ory Corp
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package driver
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math/rand"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strconv"
|
|
"testing"
|
|
|
|
"github.com/gorilla/sessions"
|
|
pkgerr "github.com/pkg/errors"
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/sirupsen/logrus/hooks/test"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/ory/hydra/v2/client"
|
|
"github.com/ory/hydra/v2/driver/config"
|
|
"github.com/ory/hydra/v2/persistence/sql"
|
|
"github.com/ory/x/configx"
|
|
"github.com/ory/x/dbal"
|
|
"github.com/ory/x/httpx"
|
|
"github.com/ory/x/logrusx"
|
|
"github.com/ory/x/popx"
|
|
"github.com/ory/x/randx"
|
|
)
|
|
|
|
func init() {
|
|
sql.SilenceMigrations = true
|
|
}
|
|
|
|
func TestGetJWKSFetcherStrategyHostEnforcement(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r, err := New(t.Context(), WithConfigOptions(
|
|
configx.WithValues(map[string]any{
|
|
config.KeyDSN: "memory",
|
|
config.HSMEnabled: "false",
|
|
config.KeyClientHTTPNoPrivateIPRanges: true,
|
|
}),
|
|
configx.WithConfigFiles("../internal/.hydra.yaml"),
|
|
))
|
|
require.NoError(t, err)
|
|
|
|
_, err = r.OAuth2Config().GetJWKSFetcherStrategy(t.Context()).Resolve(t.Context(), "http://localhost:8080", true)
|
|
require.ErrorAs(t, err, new(httpx.ErrPrivateIPAddressDisallowed))
|
|
}
|
|
|
|
func TestRegistrySQL_newKeyStrategy_handlesNetworkError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Test ensures any network specific error is logged with a
|
|
// specific message when attempting to create a new key strategy: issue #2338
|
|
|
|
hook := test.Hook{} // Test hook for asserting log messages
|
|
|
|
l := logrusx.New("", "", logrusx.WithHook(&hook))
|
|
l.Logrus().SetOutput(io.Discard)
|
|
|
|
// Create a config and set a valid but unresolvable DSN
|
|
c := config.MustNew(t, l,
|
|
configx.WithConfigFiles("../internal/.hydra.yaml"),
|
|
configx.WithValues(map[string]any{
|
|
config.KeyDSN: "postgres://user:password@127.0.0.1:9999/postgres",
|
|
config.HSMEnabled: false,
|
|
}),
|
|
)
|
|
|
|
r, err := newRegistryWithoutInit(c, l)
|
|
if err != nil {
|
|
t.Errorf("Failed to create registry: %s", err)
|
|
return
|
|
}
|
|
|
|
r.initialPing = failedPing(errors.New("snizzles"))
|
|
|
|
assert.ErrorContains(t,
|
|
r.Init(t.Context(), true, false, nil, nil),
|
|
"snizzles",
|
|
)
|
|
|
|
assert.Equal(t, logrus.InfoLevel, hook.LastEntry().Level)
|
|
assert.Contains(t, hook.LastEntry().Message, "snizzles")
|
|
}
|
|
|
|
func TestRegistrySQL_CookieStore_MaxAgeZero(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Test ensures that CookieStore MaxAge option is equal to zero after initialization
|
|
|
|
r, err := New(t.Context(), SkipNetworkInit(), DisableValidation(), WithConfigOptions(
|
|
configx.WithValues(map[string]any{
|
|
config.KeyDSN: dbal.NewSQLiteTestDatabase(t),
|
|
config.KeyGetSystemSecret: []string{randx.MustString(32, randx.AlphaNum)},
|
|
}),
|
|
))
|
|
require.NoError(t, err)
|
|
|
|
s, err := r.CookieStore(t.Context())
|
|
require.NoError(t, err)
|
|
cs := s.(*sessions.CookieStore)
|
|
|
|
assert.Equal(t, cs.Options.MaxAge, 0)
|
|
}
|
|
|
|
func TestRegistrySQL_HTTPClient(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ts := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, _ *http.Request) {
|
|
writer.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer ts.Close()
|
|
|
|
r, err := New(t.Context(), SkipNetworkInit(), DisableValidation(), WithConfigOptions(configx.WithValues(map[string]interface{}{
|
|
config.KeyDSN: dbal.NewSQLiteTestDatabase(t),
|
|
config.KeyClientHTTPNoPrivateIPRanges: true,
|
|
config.KeyClientHTTPPrivateIPExceptionURLs: []string{ts.URL + "/exception/*"},
|
|
})))
|
|
require.NoError(t, err)
|
|
|
|
t.Run("case=matches exception glob", func(t *testing.T) {
|
|
res, err := r.HTTPClient(t.Context()).Get(ts.URL + "/exception/foo")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 200, res.StatusCode)
|
|
})
|
|
|
|
t.Run("case=does not match exception glob", func(t *testing.T) {
|
|
_, err := r.HTTPClient(t.Context()).Get(ts.URL + "/foo")
|
|
assert.ErrorContains(t, err, "prohibited IP address")
|
|
})
|
|
}
|
|
|
|
func TestDefaultKeyManager_HsmDisabled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r, err := New(t.Context(),
|
|
SkipNetworkInit(),
|
|
DisableValidation(),
|
|
WithConfigOptions(
|
|
configx.SkipValidation(),
|
|
configx.WithValues(map[string]any{
|
|
config.KeyDSN: "memory",
|
|
config.HSMEnabled: false,
|
|
}),
|
|
),
|
|
)
|
|
require.NoError(t, err)
|
|
assert.IsType(t, &sql.JWKPersister{}, r.KeyManager())
|
|
}
|
|
|
|
func TestDbUnknownTableColumns(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dsn := dbal.NewSQLiteTestDatabase(t)
|
|
reg, err := New(t.Context(), WithConfigOptions(configx.WithValue("dsn", dsn)), WithAutoMigrate())
|
|
require.NoError(t, err)
|
|
|
|
statement := `ALTER TABLE "hydra_client" ADD COLUMN "temp_column" VARCHAR(128) NOT NULL DEFAULT '';`
|
|
require.NoError(t, reg.Persister().Connection(t.Context()).RawQuery(statement).Exec())
|
|
|
|
cl := &client.Client{
|
|
ID: strconv.Itoa(rand.Int()),
|
|
}
|
|
require.NoError(t, reg.Persister().CreateClient(t.Context(), cl))
|
|
getClients := func(ctx context.Context, reg *RegistrySQL) ([]client.Client, error) {
|
|
readClients := make([]client.Client, 0)
|
|
conn := reg.Persister().Connection(ctx)
|
|
cols := popx.DBColumns[client.Client](conn.Dialect)
|
|
return readClients, conn.RawQuery(fmt.Sprintf(`SELECT %s, temp_column FROM "hydra_client"`, cols)).All(&readClients)
|
|
}
|
|
|
|
t.Run("with ignore disabled (default behavior)", func(t *testing.T) {
|
|
_, err := getClients(t.Context(), reg)
|
|
assert.ErrorContains(t, err, "missing destination name temp_column")
|
|
})
|
|
|
|
t.Run("with ignore enabled", func(t *testing.T) {
|
|
reg, err := New(t.Context(), WithConfigOptions(
|
|
configx.WithValue("dsn", dsn),
|
|
configx.WithValue(config.KeyDBIgnoreUnknownTableColumns, true),
|
|
))
|
|
require.NoError(t, err)
|
|
|
|
actual, err := getClients(t.Context(), reg)
|
|
require.NoError(t, err)
|
|
assert.Len(t, actual, 1)
|
|
})
|
|
}
|
|
|
|
func failedPing(err error) func(context.Context, *logrusx.Logger, *sql.BasePersister) error {
|
|
return func(context.Context, *logrusx.Logger, *sql.BasePersister) error {
|
|
return pkgerr.WithStack(err)
|
|
}
|
|
}
|