mirror of https://github.com/ory/hydra
205 lines
5.5 KiB
Go
205 lines
5.5 KiB
Go
// Copyright © 2022 Ory Corp
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package sql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"reflect"
|
|
|
|
"github.com/gofrs/uuid"
|
|
"github.com/pkg/errors"
|
|
|
|
"github.com/ory/hydra/v2/aead"
|
|
"github.com/ory/hydra/v2/driver/config"
|
|
"github.com/ory/hydra/v2/fosite"
|
|
"github.com/ory/hydra/v2/internal/kratos"
|
|
"github.com/ory/hydra/v2/jwk"
|
|
"github.com/ory/hydra/v2/oauth2"
|
|
"github.com/ory/hydra/v2/persistence"
|
|
"github.com/ory/hydra/v2/x"
|
|
"github.com/ory/pop/v6"
|
|
"github.com/ory/x/contextx"
|
|
"github.com/ory/x/logrusx"
|
|
"github.com/ory/x/networkx"
|
|
"github.com/ory/x/otelx"
|
|
"github.com/ory/x/popx"
|
|
)
|
|
|
|
var (
|
|
_ persistence.Persister = (*Persister)(nil)
|
|
_ fosite.Transactional = (*Persister)(nil)
|
|
_ fosite.ClientManager = (*Persister)(nil)
|
|
_ oauth2.AssertionJWTReader = (*Persister)(nil)
|
|
_ x.FositeStorer = (*Persister)(nil)
|
|
)
|
|
|
|
// ErrNoTransactionOpen is returned by Commit and Rollback when the context
// they receive does not carry an open transaction.
var ErrNoTransactionOpen = errors.New("There is no Transaction in this context.")
|
|
|
|
// skipCommitContextKey is the unexported type of the context key that BeginTX
// sets when it does not open a transaction of its own. Using a dedicated type
// keeps the key from colliding with context keys defined by other packages.
type skipCommitContextKey int
|
|
|
|
// skipCommitKey is the context key under which BeginTX stores `true` when it
// was a no-op; Commit and Rollback read it and return nil without touching the
// transaction in that case.
const skipCommitKey skipCommitContextKey = 0
|
|
|
|
type (
|
|
Persister struct {
|
|
*BasePersister
|
|
r Dependencies
|
|
l *logrusx.Logger
|
|
}
|
|
Dependencies interface {
|
|
ClientHasher() fosite.Hasher
|
|
KeyCipher() *aead.AESGCM
|
|
FlowCipher() *aead.XChaCha20Poly1305
|
|
Kratos() kratos.Client
|
|
contextx.Provider
|
|
x.RegistryLogger
|
|
x.TracingProvider
|
|
config.Provider
|
|
}
|
|
BasePersister struct {
|
|
c *pop.Connection
|
|
fallbackNID uuid.UUID
|
|
d baseDependencies
|
|
}
|
|
baseDependencies interface {
|
|
x.RegistryLogger
|
|
x.TracingProvider
|
|
contextx.Provider
|
|
config.Provider
|
|
jwk.ManagerProvider
|
|
}
|
|
BasePersisterProvider interface {
|
|
BasePersister() *BasePersister
|
|
}
|
|
)
|
|
|
|
func NewPersister(base *BasePersister, r Dependencies) *Persister {
|
|
return &Persister{
|
|
BasePersister: base,
|
|
r: r,
|
|
l: r.Logger(),
|
|
}
|
|
}
|
|
|
|
func NewBasePersister(c *pop.Connection, d baseDependencies) *BasePersister {
|
|
return &BasePersister{c: c, d: d}
|
|
}
|
|
|
|
func (p *BasePersister) DetermineNetwork(ctx context.Context) (*networkx.Network, error) {
|
|
return networkx.Determine(p.Connection(ctx))
|
|
}
|
|
|
|
func (p BasePersister) WithFallbackNetworkID(nid uuid.UUID) *BasePersister {
|
|
p.fallbackNID = nid
|
|
return &p
|
|
}
|
|
|
|
func (p *BasePersister) CreateWithNetwork(ctx context.Context, v interface{}) error {
|
|
p.mustSetNetwork(ctx, v)
|
|
return p.Connection(ctx).Create(v)
|
|
}
|
|
|
|
func (p *BasePersister) UpdateWithNetwork(ctx context.Context, v interface{}) (int64, error) {
|
|
p.mustSetNetwork(ctx, v)
|
|
|
|
m := pop.NewModel(v, ctx)
|
|
cols := m.Columns()
|
|
cs := make([]string, 0, len(cols.Cols))
|
|
for _, t := range m.Columns().Cols {
|
|
cs = append(cs, t.Name)
|
|
}
|
|
|
|
return p.Connection(ctx).Where(m.IDField()+" = ? AND nid = ?", m.ID(), p.NetworkID(ctx)).UpdateQuery(v, cs...)
|
|
}
|
|
|
|
func (p *BasePersister) NetworkID(ctx context.Context) uuid.UUID {
|
|
return p.d.Contextualizer().Network(ctx, p.fallbackNID)
|
|
}
|
|
|
|
func (p *BasePersister) QueryWithNetwork(ctx context.Context) *pop.Query {
|
|
return p.Connection(ctx).Where("nid = ?", p.NetworkID(ctx))
|
|
}
|
|
|
|
func (p *BasePersister) Connection(ctx context.Context) *pop.Connection {
|
|
return popx.GetConnection(ctx, p.c)
|
|
}
|
|
|
|
// Ping calls PingContext on the SQL handle behind the persister's own
// connection and returns its error.
func (p *BasePersister) Ping(ctx context.Context) error {
	db := p.c.Store.SQLDB()
	return db.PingContext(ctx)
}
|
|
|
|
func (p *BasePersister) mustSetNetwork(ctx context.Context, v interface{}) {
|
|
rv := reflect.ValueOf(v)
|
|
|
|
if rv.Kind() != reflect.Ptr || (rv.Kind() == reflect.Ptr && rv.Elem().Kind() != reflect.Struct) {
|
|
panic("v must be a pointer to a struct")
|
|
}
|
|
nf := rv.Elem().FieldByName("NID")
|
|
if !nf.IsValid() || !nf.CanSet() {
|
|
panic("v must have settable a field 'NID uuid.UUID'")
|
|
}
|
|
nf.Set(reflect.ValueOf(p.NetworkID(ctx)))
|
|
}
|
|
|
|
func (p *BasePersister) Transaction(ctx context.Context, f func(ctx context.Context, c *pop.Connection) error) error {
|
|
return popx.Transaction(ctx, p.c, f)
|
|
}
|
|
|
|
// BeginTX implements Transactional.
|
|
func (p *BasePersister) BeginTX(ctx context.Context) (_ context.Context, err error) {
|
|
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.BeginTX")
|
|
defer otelx.End(span, &err)
|
|
|
|
fallback := &pop.Connection{TX: &pop.Tx{}}
|
|
if popx.GetConnection(ctx, fallback).TX != fallback.TX {
|
|
return context.WithValue(ctx, skipCommitKey, true), nil // no-op
|
|
}
|
|
|
|
tx, err := p.c.Store.TransactionContextOptions(ctx, &sql.TxOptions{
|
|
Isolation: sql.LevelRepeatableRead,
|
|
ReadOnly: false,
|
|
})
|
|
c := &pop.Connection{
|
|
TX: tx,
|
|
Store: tx,
|
|
ID: uuid.Must(uuid.NewV4()).String(),
|
|
Dialect: p.c.Dialect,
|
|
}
|
|
return popx.WithTransaction(ctx, c), err
|
|
}
|
|
|
|
// Commit implements Transactional.
|
|
func (p *BasePersister) Commit(ctx context.Context) (err error) {
|
|
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Commit")
|
|
defer otelx.End(span, &err)
|
|
|
|
if skip, ok := ctx.Value(skipCommitKey).(bool); ok && skip {
|
|
return nil // we skipped BeginTX, so we also skip Commit
|
|
}
|
|
|
|
fallback := &pop.Connection{TX: &pop.Tx{}}
|
|
tx := popx.GetConnection(ctx, fallback)
|
|
if tx.TX == fallback.TX || tx.TX == nil {
|
|
return errors.WithStack(ErrNoTransactionOpen)
|
|
}
|
|
|
|
return errors.WithStack(tx.TX.Commit())
|
|
}
|
|
|
|
// Rollback implements Transactional.
|
|
func (p *BasePersister) Rollback(ctx context.Context) (err error) {
|
|
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Rollback")
|
|
defer otelx.End(span, &err)
|
|
|
|
if skip, ok := ctx.Value(skipCommitKey).(bool); ok && skip {
|
|
return nil // we skipped BeginTX, so we also skip Rollback
|
|
}
|
|
|
|
fallback := &pop.Connection{TX: &pop.Tx{}}
|
|
tx := popx.GetConnection(ctx, fallback)
|
|
if tx.TX == fallback.TX || tx.TX == nil {
|
|
return errors.WithStack(ErrNoTransactionOpen)
|
|
}
|
|
|
|
return errors.WithStack(tx.TX.Rollback())
|
|
}
|