// Copyright © 2022 Ory Corp // SPDX-License-Identifier: Apache-2.0 package driver import ( "context" "io/fs" "strings" "time" "github.com/gobuffalo/pop/v6" _ "github.com/jackc/pgx/v4/stdlib" "github.com/luna-duclos/instrumentedsql" "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/hsm" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/oauth2/trust" "github.com/ory/hydra/v2/persistence/sql" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" "github.com/ory/x/dbal" "github.com/ory/x/errorsx" otelsql "github.com/ory/x/otelx/sql" "github.com/ory/x/popx" "github.com/ory/x/resilience" "github.com/ory/x/sqlcon" ) type RegistrySQL struct { *RegistryBase defaultKeyManager jwk.Manager initialPing func(r *RegistrySQL) error } var _ Registry = new(RegistrySQL) // defaultInitialPing is the default function that will be called within RegistrySQL.Init to make sure // the database is reachable. It can be injected for test purposes by changing the value // of RegistrySQL.initialPing. var defaultInitialPing = func(m *RegistrySQL) error { if err := resilience.Retry(m.l, 5*time.Second, 5*time.Minute, m.Ping); err != nil { m.Logger().Print("Could not ping database: ", err) return errorsx.WithStack(err) } return nil } func init() { dbal.RegisterDriver( func() dbal.Driver { return NewRegistrySQL() }, ) } func NewRegistrySQL() *RegistrySQL { r := &RegistrySQL{ RegistryBase: new(RegistryBase), initialPing: defaultInitialPing, } r.RegistryBase.with(r) return r } func (m *RegistrySQL) Init( ctx context.Context, skipNetworkInit bool, migrate bool, ctxer contextx.Contextualizer, extraMigrations []fs.FS, goMigrations []popx.Migration, ) error { if m.persister == nil { m.WithContextualizer(ctxer) var opts []instrumentedsql.Opt if m.Tracer(ctx).IsLoaded() { opts = []instrumentedsql.Opt{ instrumentedsql.WithTracer(otelsql.NewTracer()), instrumentedsql.WithOmitArgs(), // don't risk leaking PII or secrets instrumentedsql.WithOpsExcluded(instrumentedsql.OpSQLRowsNext), } } // new db connection pool, idlePool, connMaxLifetime, connMaxIdleTime, cleanedDSN := sqlcon.ParseConnectionOptions( m.l, m.Config().DSN(), ) c, err := pop.NewConnection( &pop.ConnectionDetails{ URL: sqlcon.FinalizeDSN(m.l, cleanedDSN), IdlePool: idlePool, ConnMaxLifetime: connMaxLifetime, ConnMaxIdleTime: connMaxIdleTime, Pool: pool, UseInstrumentedDriver: m.Tracer(ctx).IsLoaded(), InstrumentedDriverOptions: opts, Unsafe: m.Config().DbIgnoreUnknownTableColumns(), }, ) if err != nil { return errorsx.WithStack(err) } if err := resilience.Retry(m.l, 5*time.Second, 5*time.Minute, c.Open); err != nil { return errorsx.WithStack(err) } p, err := sql.NewPersister(ctx, c, m, m.Config(), extraMigrations, goMigrations) if err != nil { return err } m.persister = p if err := m.initialPing(m); err != nil { return err } if m.Config().HSMEnabled() { hardwareKeyManager := hsm.NewKeyManager(m.HSMContext(), m.Config()) m.defaultKeyManager = jwk.NewManagerStrategy(hardwareKeyManager, m.persister) } else { m.defaultKeyManager = m.persister } // if dsn is memory we have to run the migrations on every start // use case - such as // - just in memory // - shared connection // - shared but unique in the same process // see: https://sqlite.org/inmemorydb.html if dbal.IsMemorySQLite(m.Config().DSN()) { m.Logger().Print("Hydra is running migrations on every startup as DSN is memory.\n") m.Logger().Print("This means your data is lost when Hydra terminates.\n") if err := p.MigrateUp(context.Background()); err != nil { return err } } else if migrate { if err := p.MigrateUp(context.Background()); err != nil { return err } } if skipNetworkInit { m.persister = p } else { net, err := p.DetermineNetwork(ctx) if err != nil { m.Logger().WithError(err).Warnf("Unable to determine network, retrying.") return err } m.persister = p.WithFallbackNetworkID(net.ID) } if m.Config().HSMEnabled() { hardwareKeyManager := hsm.NewKeyManager(m.HSMContext(), m.Config()) m.defaultKeyManager = jwk.NewManagerStrategy(hardwareKeyManager, m.persister) } else { m.defaultKeyManager = m.persister } } return nil } func (m *RegistrySQL) alwaysCanHandle(dsn string) bool { scheme := strings.Split(dsn, "://")[0] s := dbal.Canonicalize(scheme) return s == dbal.DriverMySQL || s == dbal.DriverPostgreSQL || s == dbal.DriverCockroachDB } func (m *RegistrySQL) Ping() error { return m.Persister().Ping() } func (m *RegistrySQL) ClientManager() client.Manager { return m.Persister() } func (m *RegistrySQL) ConsentManager() consent.Manager { return m.Persister() } func (m *RegistrySQL) OAuth2Storage() x.FositeStorer { return m.Persister() } func (m *RegistrySQL) KeyManager() jwk.Manager { return m.defaultKeyManager } func (m *RegistrySQL) SoftwareKeyManager() jwk.Manager { return m.Persister() } func (m *RegistrySQL) GrantManager() trust.GrantManager { return m.Persister() }