refactor: identity persistence (#3101)

This commit is contained in:
Arne Luenser 2023-03-03 15:34:16 +01:00 committed by GitHub
parent ea6ad2a8fe
commit ceb5cc2b8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
46 changed files with 585 additions and 488 deletions

View File

@ -32,7 +32,7 @@ func NewMigrateHandler() *MigrateHandler {
return &MigrateHandler{}
}
func (h *MigrateHandler) MigrateSQL(cmd *cobra.Command, args []string) error {
func (h *MigrateHandler) MigrateSQL(cmd *cobra.Command, args []string, opts ...driver.RegistryOption) error {
var d driver.Registry
var err error
@ -78,7 +78,7 @@ func (h *MigrateHandler) MigrateSQL(cmd *cobra.Command, args []string) error {
}
}
err = d.Init(cmd.Context(), &contextx.Default{}, driver.SkipNetworkInit)
err = d.Init(cmd.Context(), &contextx.Default{}, append(opts, driver.SkipNetworkInit)...)
if err != nil {
return errors.Wrap(err, "an error occurred initializing migrations")
}

View File

@ -4,7 +4,7 @@
package courier
import (
cx "context"
"context"
"net/http"
"github.com/spf13/cobra"
@ -22,7 +22,7 @@ import (
)
func NewWatchCmd(slOpts []servicelocatorx.Option, dOpts []driver.RegistryOption) *cobra.Command {
var c = &cobra.Command{
c := &cobra.Command{
Use: "watch",
Short: "Starts the Ory Kratos message courier",
RunE: func(cmd *cobra.Command, args []string) error {
@ -38,7 +38,7 @@ func NewWatchCmd(slOpts []servicelocatorx.Option, dOpts []driver.RegistryOption)
return c
}
func StartCourier(ctx cx.Context, r driver.Registry) error {
func StartCourier(ctx context.Context, r driver.Registry) error {
eg, ctx := errgroup.WithContext(ctx)
if r.Config().CourierExposeMetricsPort(ctx) != 0 {
@ -54,7 +54,7 @@ func StartCourier(ctx cx.Context, r driver.Registry) error {
return eg.Wait()
}
func ServeMetrics(ctx cx.Context, r driver.Registry) error {
func ServeMetrics(ctx context.Context, r driver.Registry) error {
c := r.Config()
l := r.Logger()
n := negroni.New()
@ -73,39 +73,23 @@ func ServeMetrics(ctx cx.Context, r driver.Registry) error {
handler = otelx.NewHandler(handler, "cmd.courier.ServeMetrics", otelhttp.WithTracerProvider(tp))
}
// #nosec G112 - the correct settings are set by graceful.WithDefaults
//#nosec G112 -- the correct settings are set by graceful.WithDefaults
server := graceful.WithDefaults(&http.Server{
Addr: c.MetricsListenOn(ctx),
Handler: handler,
})
l.Printf("Starting the metrics httpd on: %s", server.Addr)
if err := graceful.Graceful(func() error {
errChan := make(chan error, 1)
go func(errChan chan error) {
if err := server.ListenAndServe(); err != nil {
errChan <- err
}
}(errChan)
select {
case err := <-errChan:
l.Errorf("Failed to start the metrics httpd: %s\n", err)
return err
case <-ctx.Done():
l.Printf("Shutting down the metrics httpd: context closed: %s\n", ctx.Err())
return server.Shutdown(ctx)
}
}, server.Shutdown); err != nil {
if err := graceful.GracefulContext(ctx, server.ListenAndServe, server.Shutdown); err != nil {
l.Errorln("Failed to gracefully shutdown metrics httpd")
return err
} else {
l.Println("Metrics httpd was shutdown gracefully")
}
l.Println("Metrics httpd was shutdown gracefully")
return nil
}
func Watch(ctx cx.Context, r driver.Registry) error {
ctx, cancel := cx.WithCancel(ctx)
func Watch(ctx context.Context, r driver.Registry) error {
ctx, cancel := context.WithCancel(ctx)
r.Logger().Println("Courier worker started.")
if err := graceful.Graceful(func() error {
@ -115,7 +99,7 @@ func Watch(ctx cx.Context, r driver.Registry) error {
}
return c.Work(ctx)
}, func(_ cx.Context) error {
}, func(_ context.Context) error {
cancel()
return nil
}); err != nil {

View File

@ -120,7 +120,7 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
handler = otelx.TraceHandler(handler, otelhttp.WithTracerProvider(tracer.Provider()))
}
// #nosec G112 - the correct settings are set by graceful.WithDefaults
//#nosec G112 -- the correct settings are set by graceful.WithDefaults
server := graceful.WithDefaults(&http.Server{
Handler: handler,
TLSConfig: &tls.Config{GetCertificate: certs, MinVersion: tls.VersionTLS12},
@ -128,7 +128,7 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
addr := c.PublicListenOn(ctx)
l.Printf("Starting the public httpd on: %s", addr)
if err := graceful.Graceful(func() error {
if err := graceful.GracefulContext(ctx, func() error {
listener, err := networkx.MakeListener(addr, c.PublicSocketPermission(ctx))
if err != nil {
return err
@ -191,7 +191,7 @@ func ServeAdmin(r driver.Registry, cmd *cobra.Command, args []string, slOpts *se
)
}
// #nosec G112 - the correct settings are set by graceful.WithDefaults
//#nosec G112 -- the correct settings are set by graceful.WithDefaults
server := graceful.WithDefaults(&http.Server{
Handler: handler,
TLSConfig: &tls.Config{GetCertificate: certs, MinVersion: tls.VersionTLS12},
@ -200,7 +200,7 @@ func ServeAdmin(r driver.Registry, cmd *cobra.Command, args []string, slOpts *se
addr := c.AdminListenOn(ctx)
l.Printf("Starting the admin httpd on: %s", addr)
if err := graceful.Graceful(func() error {
if err := graceful.GracefulContext(ctx, func() error {
listener, err := networkx.MakeListener(addr, c.AdminSocketPermission(ctx))
if err != nil {
return err

View File

@ -188,7 +188,7 @@ func runLoadTest(cmd *cobra.Command, conf *argon2Config, reqPerMin int) (*result
eg.Go(func(i int) func() error {
return func() error {
// wait randomly before starting, max. sample time
// #nosec G404 - just a timeout to collect statistical data
//#nosec G404 -- just a timeout to collect statistical data
t := time.Duration(rand.Intn(int(sampleTime)))
timer := time.NewTimer(t)
defer timer.Stop()

View File

@ -47,7 +47,7 @@ Use -w or --write to write output back to files instead of stdout.
cmdx.Must(err, `JSONNet file "%s" could not be formatted: %s`, file, err)
if shouldWrite {
err := os.WriteFile(file, []byte(output), 0644) // #nosec
err := os.WriteFile(file, []byte(output), 0644) //#nosec
cmdx.Must(err, `Could not write to file "%s" because: %s`, file, err)
} else {
fmt.Println(output)

View File

@ -7,11 +7,12 @@ import (
"github.com/spf13/cobra"
"github.com/ory/kratos/cmd/cliclient"
"github.com/ory/kratos/driver"
"github.com/ory/x/configx"
)
// migrateSqlCmd represents the sql command
func NewMigrateSQLCmd() *cobra.Command {
func NewMigrateSQLCmd(opts ...driver.RegistryOption) *cobra.Command {
c := &cobra.Command{
Use: "sql <database-url>",
Short: "Create SQL schemas and apply migration plans",
@ -29,7 +30,7 @@ You can read in the database URL using the -e flag, for example:
Before running this command on an existing database, create a back up!
`,
RunE: func(cmd *cobra.Command, args []string) error {
return cliclient.NewMigrateHandler().MigrateSQL(cmd, args)
return cliclient.NewMigrateHandler().MigrateSQL(cmd, args, opts...)
},
}

View File

@ -82,13 +82,13 @@ func newSMTP(ctx context.Context, deps Dependencies) (*smtpClient, error) {
// Enforcing StartTLS by default for security best practices (config review, etc.)
skipStartTLS, _ := strconv.ParseBool(uri.Query().Get("disable_starttls"))
if !skipStartTLS {
// #nosec G402 This is ok (and required!) because it is configurable and disabled by default.
//#nosec G402 -- This is ok (and required!) because it is configurable and disabled by default.
dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, Certificates: tlsCertificates, ServerName: serverName}
// Enforcing StartTLS
dialer.StartTLSPolicy = gomail.MandatoryStartTLS
}
case "smtps":
// #nosec G402 This is ok (and required!) because it is configurable and disabled by default.
//#nosec G402 -- This is ok (and required!) because it is configurable and disabled by default.
dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, Certificates: tlsCertificates, ServerName: serverName}
dialer.SSL = true
}

View File

@ -517,8 +517,7 @@ func (m *RegistryDefault) ContinuityCookieManager(ctx context.Context) sessions.
func (m *RegistryDefault) Tracer(ctx context.Context) *otelx.Tracer {
if m.trc == nil {
m.Logger().WithError(errors.WithStack(errors.New(""))).Warn("No tracer setup in RegistryDefault")
return otelx.NewNoop(m.l, m.Config().Tracing(ctx)) // should never happen
return otelx.NewNoop(m.l, m.Config().Tracing(ctx))
}
return m.trc
}

5
go.mod
View File

@ -61,6 +61,7 @@ require (
github.com/jteeuwen/go-bindata v3.0.7+incompatible
github.com/julienschmidt/httprouter v1.3.0
github.com/knadh/koanf v1.4.4
github.com/laher/mergefs v0.1.2-0.20230223191438-d16611b2f4e7
github.com/luna-duclos/instrumentedsql v1.1.3
github.com/mattn/goveralls v0.0.7
github.com/mikefarah/yq/v4 v4.19.1
@ -70,7 +71,7 @@ require (
github.com/ory/client-go v0.2.0-alpha.60
github.com/ory/dockertest/v3 v3.9.1
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe
github.com/ory/graceful v0.1.3
github.com/ory/graceful v0.1.4-0.20230301144740-e222150c51d0
github.com/ory/herodot v0.9.13
github.com/ory/hydra-client-go/v2 v2.0.3
github.com/ory/jsonschema/v3 v3.0.7
@ -87,7 +88,7 @@ require (
github.com/spf13/cobra v1.6.1
github.com/spf13/pflag v1.0.5
github.com/sqs/goreturns v0.0.0-20181028201513-538ac6014518
github.com/stretchr/testify v1.8.1
github.com/stretchr/testify v1.8.2
github.com/tidwall/gjson v1.14.3
github.com/tidwall/sjson v1.2.5
github.com/urfave/negroni v1.0.0

11
go.sum
View File

@ -913,6 +913,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/go-gypsy v1.0.0/go.mod h1:chkXM0zjdpXOiqkCW1XcCHDfjfk14PH2KKkQWxfJUcU=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/laher/mergefs v0.1.2-0.20230223191438-d16611b2f4e7 h1:PDeBswTUsSIT4QSrzLvlqKlGrANYa7TrXUwdBN9myU8=
github.com/laher/mergefs v0.1.2-0.20230223191438-d16611b2f4e7/go.mod h1:FSY1hYy94on4Tz60waRMGdO1awwS23BacqJlqf9lJ9Q=
github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/letsencrypt/pkcs11key/v4 v4.0.0/go.mod h1:EFUvBDay26dErnNb70Nd0/VW3tJiIbETBPTl9ATXQag=
@ -945,6 +947,8 @@ github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJ
github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE=
github.com/markbates/pkger v0.17.1 h1:/MKEtWqtc0mZvu9OinB9UzVN9iYCwLWuyUv4Bw+PCno=
github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0=
github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE=
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
@ -1093,8 +1097,8 @@ github.com/ory/dockertest/v3 v3.9.1/go.mod h1:42Ir9hmvaAPm0Mgibk6mBPi7SFvTXxEcnz
github.com/ory/go-acc v0.2.6/go.mod h1:4Kb/UnPcT8qRAk3IAxta+hvVapdxTLWtrr7bFLlEgpw=
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe h1:rvu4obdvqR0fkSIJ8IfgzKOWwZ5kOT2UNfLq81Qk7rc=
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe/go.mod h1:z4n3u6as84LbV4YmgjHhnwtccQqzf4cZlSk9f1FhygI=
github.com/ory/graceful v0.1.3 h1:FaeXcHZh168WzS+bqruqWEw/HgXWLdNv2nJ+fbhxbhc=
github.com/ory/graceful v0.1.3/go.mod h1:4zFz687IAF7oNHHiB586U4iL+/4aV09o/PYLE34t2bA=
github.com/ory/graceful v0.1.4-0.20230301144740-e222150c51d0 h1:VMUeLRfQD14fOMvhpYZIIT4vtAqxYh+f3KnSqCeJ13o=
github.com/ory/graceful v0.1.4-0.20230301144740-e222150c51d0/go.mod h1:hg2iCy+LCWOXahBZ+NQa4dk8J2govyQD79rrqrgMyY8=
github.com/ory/herodot v0.9.13 h1:cN/Z4eOkErl/9W7hDIDLb79IO/bfsH+8yscBjRpB4IU=
github.com/ory/herodot v0.9.13/go.mod h1:IWDs9kSvFQqw/cQ8zi5ksyYvITiUU4dI7glUrhZcJYo=
github.com/ory/hydra-client-go/v2 v2.0.3 h1:jIx968J9RBnjRuaQ21QMLCwZoa28FPvzYWAQ+88XVLw=
@ -1342,8 +1346,9 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8=
github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0=

View File

@ -7,8 +7,8 @@ import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/md5" // #nosec G501
"crypto/sha1" // #nosec G505 - compatibility for imported passwords
"crypto/md5" //#nosec G501 -- compatibility for imported passwords
"crypto/sha1" //#nosec G505 -- compatibility for imported passwords
"crypto/sha256"
"crypto/sha512"
"crypto/subtle"
@ -220,7 +220,7 @@ func CompareMD5(_ context.Context, password []byte, hash []byte) error {
r := strings.NewReplacer("{SALT}", string(salt), "{PASSWORD}", string(password))
arg = []byte(r.Replace(string(pf)))
}
// #nosec G401
//#nosec G401 -- compatibility for imported passwords
otherHash := md5.Sum(arg)
// Check that the contents of the hashed passwords are identical. Note
@ -412,7 +412,7 @@ func compareSHAHelper(hasher string, raw []byte, hash []byte) error {
switch hasher {
case "sha1":
sum := sha1.Sum(raw) // #nosec G401 - compatibility for imported passwords
sum := sha1.Sum(raw) //#nosec G401 -- compatibility for imported passwords
sha = sum[:]
case "sha256":
sum := sha256.Sum256(raw)

View File

@ -7,7 +7,7 @@ import (
"bytes"
"context"
"crypto/rand"
"crypto/sha1" // #nosec G505 - compatibility for imported passwords
"crypto/sha1" //#nosec G505 -- compatibility for imported passwords
"crypto/sha256"
"crypto/sha512"
"encoding/base64"

View File

@ -137,7 +137,7 @@ func (i *Identity) AfterEagerFind(tx *pop.Connection) error {
return err
}
if err := i.validate(); err != nil {
if err := i.Validate(); err != nil {
return err
}
@ -378,7 +378,7 @@ func (i WithCredentialsMetadataAndAdminMetadataInJSON) MarshalJSON() ([]byte, er
return json.Marshal(localIdentity(i))
}
func (i *Identity) validate() error {
func (i *Identity) Validate() error {
expected := i.NID
if expected == uuid.Nil {
return errors.WithStack(herodot.ErrInternalServerError.WithReason("Received empty nid."))

View File

@ -267,7 +267,7 @@ func TestValidateNID(t *testing.T) {
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
err := tc.i.validate()
err := tc.i.Validate()
if tc.expectedErr {
require.Error(t, err)
} else {

View File

@ -77,5 +77,8 @@ type (
// HydrateIdentityAssociations hydrates the associations of an identity.
HydrateIdentityAssociations(ctx context.Context, i *Identity, expandables Expandables) error
// InjectTraitsSchemaURL sets the identity's traits JSON schema URL from the schema's ID.
InjectTraitsSchemaURL(ctx context.Context, i *Identity) error
}
)

View File

@ -41,9 +41,7 @@ import (
"github.com/ory/kratos/x"
)
func TestPool(ctx context.Context, conf *config.Config, p interface {
persistence.Persister
}, m *identity.Manager) func(t *testing.T) {
func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, m *identity.Manager) func(t *testing.T) {
return func(t *testing.T) {
exampleServerURL := urlx.ParseOrPanic("http://example.com")
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, exampleServerURL.String())
@ -134,7 +132,6 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
assert.Empty(t, actual.RecoveryAddresses)
assert.Empty(t, actual.VerifiableAddresses)
require.Len(t, actual.InternalCredentials, 2)
require.Len(t, actual.Credentials, 2)
assertx.EqualAsJSONExcept(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword], []string{"updated_at", "created_at"})
@ -181,7 +178,6 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
t.Run("expand=everything", func(t *testing.T) {
runner(t, identity.ExpandEverything, func(t *testing.T, actual *identity.Identity) {
require.Len(t, actual.InternalCredentials, 2)
require.Len(t, actual.Credentials, 2)
assertx.EqualAsJSONExcept(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword], []string{"updated_at", "created_at"})
@ -199,7 +195,6 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
runner(t, identity.ExpandNothing, func(t *testing.T, actual *identity.Identity) {
require.NoError(t, p.HydrateIdentityAssociations(ctx, actual, identity.ExpandEverything))
require.Len(t, actual.InternalCredentials, 2)
require.Len(t, actual.Credentials, 2)
assertx.EqualAsJSONExcept(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword], []string{"updated_at", "created_at"})

View File

@ -275,7 +275,7 @@ func AssertRegistrationRespectsValidation(t *testing.T, reg *driver.RegistryDefa
})
}
func AssertCommonErrorCases(t *testing.T, reg *driver.RegistryDefault, flows []string) {
func AssertCommonErrorCases(t *testing.T, flows []string) {
ctx := context.Background()
conf, reg := internal.NewFastRegistryWithMocks(t)
testhelpers.SetDefaultIdentitySchemaFromRaw(conf, basicSchema)

View File

@ -148,7 +148,7 @@ func CheckE2EServerOnHTTP(t *testing.T, publicPort, adminPort int) (publicUrl, a
func waitToComeAlive(t *testing.T, publicUrl, adminUrl string) {
require.NoError(t, retry.Do(func() error {
/* #nosec G402: TLS InsecureSkipVerify set true. */
//#nosec G402 -- TLS InsecureSkipVerify set true
tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
client := &http.Client{Transport: tr}

View File

@ -197,7 +197,7 @@ func NewSettingsAPIServer(t *testing.T, reg *driver.RegistryDefault, ids map[str
reg.Config().MustSet(ctx, config.ViperKeyPublicBaseURL, tsp.URL)
reg.Config().MustSet(ctx, config.ViperKeyAdminBaseURL, tsa.URL)
// #nosec G112
//#nosec G112
return tsp, tsa, AddAndLoginIdentities(t, reg, &httptest.Server{Config: &http.Server{Handler: public}, URL: tsp.URL}, ids)
}

View File

@ -57,6 +57,7 @@ type Persister interface {
MigrateDown(c context.Context, steps int) error
MigrateUp(c context.Context) error
Migrator() *popx.Migrator
MigrationBox() *popx.MigrationBox
GetConnection(ctx context.Context) *pop.Connection
Transaction(ctx context.Context, callback func(ctx context.Context, connection *pop.Connection) error) error
Networker

View File

@ -0,0 +1,45 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package devices
import (
"context"
"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/ory/kratos/session"
"github.com/ory/x/contextx"
"github.com/ory/x/popx"
"github.com/ory/x/sqlcon"
)
var _ session.DevicePersister = (*DevicePersister)(nil)
type DevicePersister struct {
ctxer contextx.Provider
c *pop.Connection
nid uuid.UUID
}
func NewPersister(r contextx.Provider, c *pop.Connection) *DevicePersister {
return &DevicePersister{
ctxer: r,
c: c,
}
}
func (p *DevicePersister) NetworkID(ctx context.Context) uuid.UUID {
return p.ctxer.Contextualizer().Network(ctx, p.nid)
}
func (p DevicePersister) WithNetworkID(nid uuid.UUID) session.DevicePersister {
p.nid = nid
return &p
}
func (p *DevicePersister) CreateDevice(ctx context.Context, d *session.Device) error {
d.NID = p.NetworkID(ctx)
return sqlcon.HandleError(popx.GetConnection(ctx, p.c.WithContext(ctx)).Create(d))
}

View File

@ -1,7 +1,7 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package sql
package identity
import (
"context"
@ -10,7 +10,9 @@ import (
"strings"
"time"
"github.com/ory/x/contextx"
"github.com/ory/x/pointerx"
"github.com/ory/x/popx"
"golang.org/x/sync/errgroup"
@ -21,7 +23,11 @@ import (
"github.com/ory/jsonschema/v3"
"github.com/ory/x/sqlxx"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/otp"
"github.com/ory/kratos/persistence/sql/update"
"github.com/ory/kratos/schema"
"github.com/ory/kratos/x"
"github.com/gobuffalo/pop/v6"
@ -31,40 +37,81 @@ import (
"github.com/ory/herodot"
"github.com/ory/x/errorsx"
"github.com/ory/x/sqlcon"
"github.com/ory/kratos/identity"
)
var _ identity.Pool = new(Persister)
var _ identity.PrivilegedPool = new(Persister)
var _ identity.Pool = new(IdentityPersister)
var _ identity.PrivilegedPool = new(IdentityPersister)
func (p *Persister) ListVerifiableAddresses(ctx context.Context, page, itemsPerPage int) (a []identity.VerifiableAddress, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListVerifiableAddresses")
defer span.End()
if err := p.GetConnection(ctx).Where("nid = ?", p.NetworkID(ctx)).Order("id DESC").Paginate(page, x.MaxItemsPerPage(itemsPerPage)).All(&a); err != nil {
return nil, sqlcon.HandleError(err)
}
return a, err
type dependencies interface {
schema.IdentityTraitsProvider
identity.ValidationProvider
x.LoggingProvider
config.Provider
contextx.Provider
x.TracingProvider
}
func (p *Persister) ListRecoveryAddresses(ctx context.Context, page, itemsPerPage int) (a []identity.RecoveryAddress, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListRecoveryAddresses")
defer span.End()
type IdentityPersister struct {
r dependencies
c *pop.Connection
nid uuid.UUID
}
func NewPersister(r dependencies, c *pop.Connection) *IdentityPersister {
return &IdentityPersister{
c: c,
r: r,
}
}
func (p *IdentityPersister) NetworkID(ctx context.Context) uuid.UUID {
return p.r.Contextualizer().Network(ctx, p.nid)
}
func (p IdentityPersister) WithNetworkID(nid uuid.UUID) identity.PrivilegedPool {
p.nid = nid
return &p
}
func WithTransaction(ctx context.Context, tx *pop.Connection) context.Context {
return popx.WithTransaction(ctx, tx)
}
func (p *IdentityPersister) Transaction(ctx context.Context, callback func(ctx context.Context, connection *pop.Connection) error) error {
return popx.Transaction(ctx, p.c.WithContext(ctx), callback)
}
func (p *IdentityPersister) GetConnection(ctx context.Context) *pop.Connection {
return popx.GetConnection(ctx, p.c.WithContext(ctx))
}
func (p *IdentityPersister) ListVerifiableAddresses(ctx context.Context, page, itemsPerPage int) (a []identity.VerifiableAddress, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListVerifiableAddresses")
otelx.End(span, &err)
if err := p.GetConnection(ctx).Where("nid = ?", p.NetworkID(ctx)).Order("id DESC").Paginate(page, x.MaxItemsPerPage(itemsPerPage)).All(&a); err != nil {
return nil, sqlcon.HandleError(err)
}
return a, err
return a, nil
}
func (p *IdentityPersister) ListRecoveryAddresses(ctx context.Context, page, itemsPerPage int) (a []identity.RecoveryAddress, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListRecoveryAddresses")
otelx.End(span, &err)
if err := p.GetConnection(ctx).Where("nid = ?", p.NetworkID(ctx)).Order("id DESC").Paginate(page, x.MaxItemsPerPage(itemsPerPage)).All(&a); err != nil {
return nil, sqlcon.HandleError(err)
}
return a, nil
}
func stringToLowerTrim(match string) string {
return strings.ToLower(strings.TrimSpace(match))
}
func (p *Persister) normalizeIdentifier(ct identity.CredentialsType, match string) string {
func NormalizeIdentifier(ct identity.CredentialsType, match string) string {
switch ct {
case identity.CredentialsTypeLookup:
// lookup credentials are case-sensitive
@ -83,9 +130,9 @@ func (p *Persister) normalizeIdentifier(ct identity.CredentialsType, match strin
return match
}
func (p *Persister) FindByCredentialsIdentifier(ctx context.Context, ct identity.CredentialsType, match string) (*identity.Identity, *identity.Credentials, error) {
func (p *IdentityPersister) FindByCredentialsIdentifier(ctx context.Context, ct identity.CredentialsType, match string) (_ *identity.Identity, _ *identity.Credentials, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindByCredentialsIdentifier")
defer span.End()
otelx.End(span, &err)
nid := p.NetworkID(ctx)
@ -94,22 +141,21 @@ func (p *Persister) FindByCredentialsIdentifier(ctx context.Context, ct identity
}
// Force case-insensitivity and trimming for identifiers
match = p.normalizeIdentifier(ct, match)
match = NormalizeIdentifier(ct, match)
// #nosec G201
if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(`SELECT
ic.identity_id
FROM %s ic
INNER JOIN %s ict on ic.identity_credential_type_id = ict.id
INNER JOIN %s ici on ic.id = ici.identity_credential_id AND ici.identity_credential_type_id = ict.id
WHERE ici.identifier = ?
AND ic.nid = ?
AND ici.nid = ?
AND ict.name = ?`,
"identity_credentials",
"identity_credential_types",
"identity_credential_identifiers",
),
if err := p.GetConnection(ctx).RawQuery(`
SELECT
ic.identity_id
FROM identity_credentials ic
INNER JOIN identity_credential_types ict
ON ic.identity_credential_type_id = ict.id
INNER JOIN identity_credential_identifiers ici
ON ic.id = ici.identity_credential_id AND ici.identity_credential_type_id = ict.id
WHERE ici.identifier = ?
AND ic.nid = ?
AND ici.nid = ?
AND ict.name = ?
LIMIT 1`, // pop doesn't understand how to add a limit clause to this query
match,
nid,
nid,
@ -135,9 +181,9 @@ WHERE ici.identifier = ?
return i.CopyWithoutCredentials(), creds, nil
}
func (p *Persister) findIdentityCredentialsType(ctx context.Context, ct identity.CredentialsType) (*identity.CredentialsTypeTable, error) {
func (p *IdentityPersister) findIdentityCredentialsType(ctx context.Context, ct identity.CredentialsType) (_ *identity.CredentialsTypeTable, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findIdentityCredentialsType")
defer span.End()
otelx.End(span, &err)
var m identity.CredentialsTypeTable
if err := p.GetConnection(ctx).Where("name = ?", ct).First(&m); err != nil {
@ -146,9 +192,9 @@ func (p *Persister) findIdentityCredentialsType(ctx context.Context, ct identity
return &m, nil
}
func (p *Persister) createIdentityCredentials(ctx context.Context, i *identity.Identity) error {
func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, i *identity.Identity) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createIdentityCredentials")
defer span.End()
otelx.End(span, &err)
c := p.GetConnection(ctx)
@ -174,7 +220,7 @@ func (p *Persister) createIdentityCredentials(ctx context.Context, i *identity.I
for _, ids := range cred.Identifiers {
// Force case-insensitivity and trimming for identifiers
ids = p.normalizeIdentifier(cred.Type, ids)
ids = NormalizeIdentifier(cred.Type, ids)
if len(ids) == 0 {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to create identity credentials with missing or empty identifier."))
@ -196,9 +242,9 @@ func (p *Persister) createIdentityCredentials(ctx context.Context, i *identity.I
return nil
}
func (p *Persister) createVerifiableAddresses(ctx context.Context, i *identity.Identity) error {
func (p *IdentityPersister) createVerifiableAddresses(ctx context.Context, i *identity.Identity) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createVerifiableAddresses")
defer span.End()
otelx.End(span, &err)
for k := range i.VerifiableAddresses {
if err := p.GetConnection(ctx).Create(&i.VerifiableAddresses[k]); err != nil {
@ -210,7 +256,10 @@ func (p *Persister) createVerifiableAddresses(ctx context.Context, i *identity.I
func updateAssociation[T interface {
Hash() string
}](ctx context.Context, p *Persister, i *identity.Identity, inID []T) error {
}](ctx context.Context, p *IdentityPersister, i *identity.Identity, inID []T) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.updateAssociation")
otelx.End(span, &err)
var inDB []T
if err := p.GetConnection(ctx).
Where("identity_id = ? AND nid = ?", i.ID, p.NetworkID(ctx)).
@ -252,12 +301,12 @@ func updateAssociation[T interface {
return nil
}
func (p *Persister) normalizeAllAddressess(ctx context.Context, id *identity.Identity) {
func (p *IdentityPersister) normalizeAllAddressess(ctx context.Context, id *identity.Identity) {
p.normalizeRecoveryAddresses(ctx, id)
p.normalizeVerifiableAddresses(ctx, id)
}
func (p *Persister) normalizeVerifiableAddresses(ctx context.Context, id *identity.Identity) {
func (p *IdentityPersister) normalizeVerifiableAddresses(ctx context.Context, id *identity.Identity) {
for k := range id.VerifiableAddresses {
v := id.VerifiableAddresses[k]
@ -282,7 +331,7 @@ func (p *Persister) normalizeVerifiableAddresses(ctx context.Context, id *identi
}
}
func (p *Persister) normalizeRecoveryAddresses(ctx context.Context, id *identity.Identity) {
func (p *IdentityPersister) normalizeRecoveryAddresses(ctx context.Context, id *identity.Identity) {
for k := range id.RecoveryAddresses {
id.RecoveryAddresses[k].IdentityID = id.ID
id.RecoveryAddresses[k].NID = p.NetworkID(ctx)
@ -291,9 +340,9 @@ func (p *Persister) normalizeRecoveryAddresses(ctx context.Context, id *identity
}
}
func (p *Persister) createRecoveryAddresses(ctx context.Context, i *identity.Identity) error {
func (p *IdentityPersister) createRecoveryAddresses(ctx context.Context, i *identity.Identity) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createRecoveryAddresses")
defer span.End()
otelx.End(span, &err)
for k := range i.RecoveryAddresses {
if err := p.GetConnection(ctx).Create(&i.RecoveryAddresses[k]); err != nil {
@ -303,9 +352,9 @@ func (p *Persister) createRecoveryAddresses(ctx context.Context, i *identity.Ide
return nil
}
func (p *Persister) CountIdentities(ctx context.Context) (int64, error) {
func (p *IdentityPersister) CountIdentities(ctx context.Context) (n int64, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CountIdentities")
defer span.End()
otelx.End(span, &err)
count, err := p.c.WithContext(ctx).Where("nid = ?", p.NetworkID(ctx)).Count(new(identity.Identity))
if err != nil {
@ -314,9 +363,9 @@ func (p *Persister) CountIdentities(ctx context.Context) (int64, error) {
return int64(count), nil
}
func (p *Persister) CreateIdentity(ctx context.Context, i *identity.Identity) error {
func (p *IdentityPersister) CreateIdentity(ctx context.Context, i *identity.Identity) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateIdentity")
defer span.End()
otelx.End(span, &err)
i.NID = p.NetworkID(ctx)
@ -334,7 +383,7 @@ func (p *Persister) CreateIdentity(ctx context.Context, i *identity.Identity) er
i.Traits = identity.Traits("{}")
}
if err := p.injectTraitsSchemaURL(ctx, i); err != nil {
if err := p.InjectTraitsSchemaURL(ctx, i); err != nil {
return err
}
@ -361,156 +410,18 @@ func (p *Persister) CreateIdentity(ctx context.Context, i *identity.Identity) er
})
}
func (p *Persister) HydrateIdentityAssociations(ctx context.Context, i *identity.Identity, expand identity.Expandables) (err error) {
func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *identity.Identity, expand identity.Expandables) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.HydrateIdentityAssociations")
defer otelx.End(span, &err)
con := p.GetConnection(ctx)
if err := con.Load(i, expand.ToEager()...); err != nil {
return err
}
if err := i.AfterEagerFind(con); err != nil {
return err
}
return p.injectTraitsSchemaURL(ctx, i)
}
func (p *Persister) ListIdentities(ctx context.Context, params identity.ListIdentityParameters) (res []identity.Identity, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListIdentities")
defer otelx.End(span, &err)
span.SetAttributes(
attribute.Int("page", params.Page),
attribute.Int("per_page", params.PerPage),
attribute.StringSlice("expand", params.Expand.ToEager()),
attribute.Bool("use:credential_identifier_filter", params.CredentialsIdentifier != ""),
attribute.String("network.id", p.NetworkID(ctx).String()),
)
is := make([]identity.Identity, 0)
con := p.GetConnection(ctx)
nid := p.NetworkID(ctx)
query := con.Where("identities.nid = ?", nid).Order("identities.id DESC")
if len(params.Expand) > 0 {
query = query.EagerPreload(params.Expand.ToEager()...)
}
if match := params.CredentialsIdentifier; len(match) > 0 {
// When filtering by credentials identifier, we most likely are looking for a username or email. It is therefore
// important to normalize the identifier before querying the database.
match = p.normalizeIdentifier(identity.CredentialsTypePassword, match)
query = query.
InnerJoin("identity_credentials ic", "ic.identity_id = identities.id").
InnerJoin("identity_credential_types ict", "ict.id = ic.identity_credential_type_id").
InnerJoin("identity_credential_identifiers ici", "ici.identity_credential_id = ic.id").
Where("(ic.nid = ? AND ici.nid = ? AND ici.identifier = ?)", nid, nid, match).
Where("ict.name IN (?)", identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword).
Limit(1)
} else {
query = query.Paginate(params.Page, params.PerPage)
}
/* #nosec G201 TableName is static */
if err := sqlcon.HandleError(query.All(&is)); err != nil {
return nil, err
}
schemaCache := map[string]string{}
for k := range is {
i := &is[k]
if u, ok := schemaCache[i.SchemaID]; ok {
i.SchemaURL = u
} else {
if err := p.injectTraitsSchemaURL(ctx, i); err != nil {
return nil, err
}
schemaCache[i.SchemaID] = i.SchemaURL
}
is[k] = *i
}
return is, nil
}
func (p *Persister) UpdateIdentity(ctx context.Context, i *identity.Identity) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateIdentity")
defer span.End()
if err := p.validateIdentity(ctx, i); err != nil {
return err
}
i.NID = p.NetworkID(ctx)
return sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
if count, err := tx.Where("id = ? AND nid = ?", i.ID, p.NetworkID(ctx)).Count(i); err != nil {
return err
} else if count == 0 {
return sql.ErrNoRows
}
p.normalizeAllAddressess(ctx, i)
if err := updateAssociation(ctx, p, i, i.RecoveryAddresses); err != nil {
return err
}
if err := updateAssociation(ctx, p, i, i.VerifiableAddresses); err != nil {
return err
}
/* #nosec G201 TableName is static */
if err := tx.RawQuery(
fmt.Sprintf(
`DELETE FROM %s WHERE identity_id = ? AND nid = ?`,
new(identity.Credentials).TableName(ctx)),
i.ID, p.NetworkID(ctx)).Exec(); err != nil {
return sqlcon.HandleError(err)
}
if err := p.update(WithTransaction(ctx, tx), i); err != nil {
return err
}
return p.createIdentityCredentials(ctx, i)
}))
}
func (p *Persister) DeleteIdentity(ctx context.Context, id uuid.UUID) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteIdentity")
defer span.End()
return p.delete(ctx, new(identity.Identity), id)
}
func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID, expand identity.Expandables) (res *identity.Identity, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetIdentity")
defer otelx.End(span, &err)
span.SetAttributes(
attribute.String("identity.id", id.String()),
attribute.StringSlice("expand", expand.ToEager()),
attribute.String("network.id", p.NetworkID(ctx).String()),
)
nid := p.NetworkID(ctx)
con := p.GetConnection(ctx)
query := con.Where("id = ? AND nid = ?", id, nid)
var (
i identity.Identity
con = p.GetConnection(ctx)
nid = p.NetworkID(ctx)
credentials []identity.Credentials
verifiableAddresses []identity.VerifiableAddress
recoveryAddresses []identity.RecoveryAddress
)
if err := query.First(&i); err != nil {
return nil, sqlcon.HandleError(err)
}
eg, ctx := errgroup.WithContext(ctx)
if expand.Has(identity.ExpandFieldRecoveryAddresses) {
eg.Go(func() error {
@ -560,7 +471,7 @@ func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID, expand identi
}
if err := eg.Wait(); err != nil {
return nil, err
return err
}
i.VerifiableAddresses = verifiableAddresses
@ -568,26 +479,164 @@ func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID, expand identi
i.InternalCredentials = credentials
if err := i.AfterEagerFind(con); err != nil {
return err
}
return p.InjectTraitsSchemaURL(ctx, i)
}
func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.ListIdentityParameters) (res []identity.Identity, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListIdentities")
defer otelx.End(span, &err)
span.SetAttributes(
attribute.Int("page", params.Page),
attribute.Int("per_page", params.PerPage),
attribute.StringSlice("expand", params.Expand.ToEager()),
attribute.Bool("use:credential_identifier_filter", params.CredentialsIdentifier != ""),
attribute.String("network.id", p.NetworkID(ctx).String()),
)
is := make([]identity.Identity, 0)
con := p.GetConnection(ctx)
nid := p.NetworkID(ctx)
query := con.Where("identities.nid = ?", nid).Order("identities.id DESC")
if len(params.Expand) > 0 {
query = query.EagerPreload(params.Expand.ToEager()...)
}
if match := params.CredentialsIdentifier; len(match) > 0 {
// When filtering by credentials identifier, we most likely are looking for a username or email. It is therefore
// important to normalize the identifier before querying the database.
match = NormalizeIdentifier(identity.CredentialsTypePassword, match)
query = query.
InnerJoin("identity_credentials ic", "ic.identity_id = identities.id").
InnerJoin("identity_credential_types ict", "ict.id = ic.identity_credential_type_id").
InnerJoin("identity_credential_identifiers ici", "ici.identity_credential_id = ic.id").
Where("(ic.nid = ? AND ici.nid = ? AND ici.identifier = ?)", nid, nid, match).
Where("ict.name IN (?)", identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword).
Limit(1)
} else {
query = query.Paginate(params.Page, params.PerPage)
}
if err := sqlcon.HandleError(query.All(&is)); err != nil {
return nil, err
}
if err := p.injectTraitsSchemaURL(ctx, &i); err != nil {
schemaCache := map[string]string{}
for k := range is {
i := &is[k]
if u, ok := schemaCache[i.SchemaID]; ok {
i.SchemaURL = u
} else {
if err := p.InjectTraitsSchemaURL(ctx, i); err != nil {
return nil, err
}
schemaCache[i.SchemaID] = i.SchemaURL
}
is[k] = *i
}
return is, nil
}
func (p *IdentityPersister) UpdateIdentity(ctx context.Context, i *identity.Identity) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateIdentity")
defer otelx.End(span, &err)
if err := p.validateIdentity(ctx, i); err != nil {
return err
}
i.NID = p.NetworkID(ctx)
return sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
if count, err := tx.Where("id = ? AND nid = ?", i.ID, p.NetworkID(ctx)).Count(i); err != nil {
return err
} else if count == 0 {
return sql.ErrNoRows
}
p.normalizeAllAddressess(ctx, i)
if err := updateAssociation(ctx, p, i, i.RecoveryAddresses); err != nil {
return err
}
if err := updateAssociation(ctx, p, i, i.VerifiableAddresses); err != nil {
return err
}
//#nosec G201 -- TableName is static
if err := tx.RawQuery(
fmt.Sprintf(
`DELETE FROM %s WHERE identity_id = ? AND nid = ?`,
new(identity.Credentials).TableName(ctx)),
i.ID, p.NetworkID(ctx)).Exec(); err != nil {
return sqlcon.HandleError(err)
}
if err := update.Generic(WithTransaction(ctx, tx), tx, p.r.Tracer(ctx).Tracer(), i); err != nil {
return err
}
return p.createIdentityCredentials(ctx, i)
}))
}
func (p *IdentityPersister) DeleteIdentity(ctx context.Context, id uuid.UUID) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteIdentity")
defer otelx.End(span, &err)
nid := p.NetworkID(ctx)
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ? AND nid = ?", new(identity.Identity).TableName(ctx)),
id,
nid,
).ExecWithCount()
if err != nil {
return sqlcon.HandleError(err)
}
if count == 0 {
return errors.WithStack(sqlcon.ErrNoRows)
}
return nil
}
func (p *IdentityPersister) GetIdentity(ctx context.Context, id uuid.UUID, expand identity.Expandables) (_ *identity.Identity, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetIdentity")
defer otelx.End(span, &err)
span.SetAttributes(
attribute.String("identity.id", id.String()),
attribute.StringSlice("expand", expand.ToEager()),
attribute.String("network.id", p.NetworkID(ctx).String()),
)
var i identity.Identity
if err := p.GetConnection(ctx).Where("id = ? AND nid = ?", id, p.NetworkID(ctx)).First(&i); err != nil {
return nil, sqlcon.HandleError(err)
}
if err := p.HydrateIdentityAssociations(ctx, &i, expand); err != nil {
return nil, err
}
return &i, nil
}
func (p *Persister) GetIdentityConfidential(ctx context.Context, id uuid.UUID) (res *identity.Identity, err error) {
func (p *IdentityPersister) GetIdentityConfidential(ctx context.Context, id uuid.UUID) (res *identity.Identity, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetIdentityConfidential")
defer otelx.End(span, &err)
return p.GetIdentity(ctx, id, identity.ExpandEverything)
}
func (p *Persister) FindVerifiableAddressByValue(ctx context.Context, via identity.VerifiableAddressType, value string) (*identity.VerifiableAddress, error) {
func (p *IdentityPersister) FindVerifiableAddressByValue(ctx context.Context, via identity.VerifiableAddressType, value string) (_ *identity.VerifiableAddress, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindVerifiableAddressByValue")
defer span.End()
otelx.End(span, &err)
var address identity.VerifiableAddress
if err := p.GetConnection(ctx).Where("nid = ? AND via = ? AND value = ?", p.NetworkID(ctx), via, stringToLowerTrim(value)).First(&address); err != nil {
@ -597,9 +646,9 @@ func (p *Persister) FindVerifiableAddressByValue(ctx context.Context, via identi
return &address, nil
}
func (p *Persister) FindRecoveryAddressByValue(ctx context.Context, via identity.RecoveryAddressType, value string) (*identity.RecoveryAddress, error) {
func (p *IdentityPersister) FindRecoveryAddressByValue(ctx context.Context, via identity.RecoveryAddressType, value string) (_ *identity.RecoveryAddress, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindRecoveryAddressByValue")
defer span.End()
defer otelx.End(span, &err)
var address identity.RecoveryAddress
if err := p.GetConnection(ctx).Where("nid = ? AND via = ? AND value = ?", p.NetworkID(ctx), via, stringToLowerTrim(value)).First(&address); err != nil {
@ -609,16 +658,17 @@ func (p *Persister) FindRecoveryAddressByValue(ctx context.Context, via identity
return &address, nil
}
func (p *Persister) VerifyAddress(ctx context.Context, code string) error {
func (p *IdentityPersister) VerifyAddress(ctx context.Context, code string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAddress")
defer span.End()
otelx.End(span, &err)
newCode, err := otp.New()
if err != nil {
return err
}
count, err := p.GetConnection(ctx).RawQuery(
/* #nosec G201 TableName is static */
//#nosec G201 -- TableName is static
fmt.Sprintf(
"UPDATE %s SET status = ?, verified = true, verified_at = ?, code = ? WHERE nid = ? AND code = ? AND expires_at > ?",
new(identity.VerifiableAddress).TableName(ctx),
@ -641,18 +691,18 @@ func (p *Persister) VerifyAddress(ctx context.Context, code string) error {
return nil
}
func (p *Persister) UpdateVerifiableAddress(ctx context.Context, address *identity.VerifiableAddress) error {
func (p *IdentityPersister) UpdateVerifiableAddress(ctx context.Context, address *identity.VerifiableAddress) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateVerifiableAddress")
defer span.End()
otelx.End(span, &err)
address.NID = p.NetworkID(ctx)
address.Value = stringToLowerTrim(address.Value)
return p.update(ctx, address)
return update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), address)
}
func (p *Persister) validateIdentity(ctx context.Context, i *identity.Identity) error {
func (p *IdentityPersister) validateIdentity(ctx context.Context, i *identity.Identity) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.validateIdentity")
defer span.End()
otelx.End(span, &err)
if err := p.r.IdentityValidator().ValidateWithRunner(ctx, i); err != nil {
if _, ok := errorsx.Cause(err).(*jsonschema.ValidationError); ok {
@ -664,9 +714,9 @@ func (p *Persister) validateIdentity(ctx context.Context, i *identity.Identity)
return nil
}
func (p *Persister) injectTraitsSchemaURL(ctx context.Context, i *identity.Identity) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.injectTraitsSchemaURL")
defer span.End()
func (p *IdentityPersister) InjectTraitsSchemaURL(ctx context.Context, i *identity.Identity) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InjectTraitsSchemaURL")
otelx.End(span, &err)
ss, err := p.r.IdentityTraitsSchemas(ctx)
if err != nil {

View File

@ -15,8 +15,6 @@ import (
"github.com/ory/x/servicelocatorx"
"github.com/ory/x/fsx"
"github.com/ory/kratos/identity"
"github.com/bradleyjkemp/cupaloy/v2"
@ -124,7 +122,7 @@ func TestMigrations(t *testing.T) {
t.Run("suite=up", func(t *testing.T) {
tm, err := popx.NewMigrationBox(
fsx.Merge(os.DirFS("../migrations/sql")),
os.DirFS("../migrations/sql"),
popx.NewMigrator(c, logrusx.New("", "", logrusx.ForceLevel(logrus.DebugLevel)), nil, 1*time.Minute),
popx.WithTestdata(t, os.DirFS("./testdata")),
)

View File

@ -6,28 +6,24 @@ package sql
import (
"context"
"embed"
"fmt"
"time"
"github.com/ory/x/contextx"
"github.com/ory/x/fsx"
"github.com/gobuffalo/pop/v6"
"github.com/gobuffalo/pop/v6/columns"
"github.com/gofrs/uuid"
"github.com/laher/mergefs"
"github.com/pkg/errors"
"github.com/ory/x/networkx"
"github.com/ory/x/sqlcon"
"github.com/ory/x/popx"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/persistence"
"github.com/ory/kratos/persistence/sql/devices"
idpersistence "github.com/ory/kratos/persistence/sql/identity"
"github.com/ory/kratos/schema"
"github.com/ory/kratos/session"
"github.com/ory/kratos/x"
"github.com/ory/x/contextx"
"github.com/ory/x/networkx"
"github.com/ory/x/popx"
)
var _ persistence.Persister = new(Persister)
@ -37,34 +33,40 @@ var migrations embed.FS
type (
persisterDependencies interface {
schema.IdentityTraitsProvider
identity.ValidationProvider
x.LoggingProvider
config.Provider
contextx.Provider
x.TracingProvider
schema.IdentityTraitsProvider
identity.ValidationProvider
}
Persister struct {
nid uuid.UUID
c *pop.Connection
mb *popx.MigrationBox
mbs popx.MigrationStatuses
r persisterDependencies
p *networkx.Manager
isSQLite bool
nid uuid.UUID
c *pop.Connection
mb *popx.MigrationBox
mbs popx.MigrationStatuses
r persisterDependencies
p *networkx.Manager
identity.PrivilegedPool
session.DevicePersister
}
)
func NewPersister(ctx context.Context, r persisterDependencies, c *pop.Connection) (*Persister, error) {
m, err := popx.NewMigrationBox(fsx.Merge(migrations, networkx.Migrations), popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0))
m, err := popx.NewMigrationBox(mergefs.Merge(migrations, networkx.Migrations), popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0))
if err != nil {
return nil, err
}
m.DumpMigrations = false
return &Persister{
c: c, mb: m, r: r, isSQLite: c.Dialect.Name() == "sqlite3",
p: networkx.NewManager(c, r.Logger(), r.Tracer(ctx)),
c: c,
mb: m,
r: r,
PrivilegedPool: idpersistence.NewPersister(r, c),
DevicePersister: devices.NewPersister(r, c),
p: networkx.NewManager(c, r.Logger(), r.Tracer(ctx)),
}, nil
}
@ -72,8 +74,18 @@ func (p *Persister) NetworkID(ctx context.Context) uuid.UUID {
return p.r.Contextualizer().Network(ctx, p.nid)
}
func (p Persister) WithNetworkID(sid uuid.UUID) persistence.Persister {
p.nid = sid
func (p Persister) WithNetworkID(nid uuid.UUID) persistence.Persister {
p.nid = nid
if pp, ok := p.PrivilegedPool.(interface {
WithNetworkID(uuid.UUID) identity.PrivilegedPool
}); ok {
p.PrivilegedPool = pp.WithNetworkID(nid)
}
if dp, ok := p.DevicePersister.(interface {
WithNetworkID(uuid.UUID) session.DevicePersister
}); ok {
p.DevicePersister = dp.WithNetworkID(nid)
}
return &p
}
@ -113,6 +125,10 @@ func (p *Persister) MigrateUp(ctx context.Context) error {
return p.mb.Up(ctx)
}
func (p *Persister) MigrationBox() *popx.MigrationBox {
return p.mb
}
func (p *Persister) Migrator() *popx.Migrator {
return p.mb.Migrator
}
@ -130,15 +146,6 @@ func (p *Persister) Ping() error {
return errors.WithStack(p.c.Store.(pinger).Ping())
}
type quotable interface {
Quote(key string) string
}
type node interface {
GetID() uuid.UUID
GetNID() uuid.UUID
}
func (p *Persister) CleanupDatabase(ctx context.Context, wait time.Duration, older time.Duration, batchSize int) error {
currentTime := time.Now().Add(-older)
p.r.Logger().Printf("Cleaning up records older than %s\n", currentTime)
@ -189,81 +196,3 @@ func (p *Persister) CleanupDatabase(ctx context.Context, wait time.Duration, old
"This should be re-run periodically, to be sure that all expired data is purged.")
return nil
}
func (p *Persister) update(ctx context.Context, v node, columnNames ...string) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.update")
defer span.End()
c := p.GetConnection(ctx)
quoter, ok := c.Dialect.(quotable)
if !ok {
return errors.Errorf("store is not a quoter: %T", p.c.Store)
}
model := pop.NewModel(v, ctx)
tn := model.TableName()
cols := columns.Columns{}
if len(columnNames) > 0 && tn == model.TableName() {
cols = columns.NewColumnsWithAlias(tn, model.As, model.IDField())
cols.Add(columnNames...)
} else {
cols = columns.ForStructWithAlias(v, tn, model.As, model.IDField())
}
// #nosec
stmt := fmt.Sprintf("SELECT COUNT(id) FROM %s AS %s WHERE %s.id = ? AND %s.nid = ?",
quoter.Quote(model.TableName()),
model.Alias(),
model.Alias(),
model.Alias(),
)
var count int
if err := c.Store.GetContext(ctx, &count, c.Dialect.TranslateSQL(stmt), v.GetID(), v.GetNID()); err != nil {
return sqlcon.HandleError(err)
} else if count == 0 {
return errors.WithStack(sqlcon.ErrNoRows)
}
// #nosec
stmt = fmt.Sprintf("UPDATE %s AS %s SET %s WHERE %s AND %s.nid = :nid",
quoter.Quote(model.TableName()),
model.Alias(),
cols.Writeable().QuotedUpdateString(quoter),
model.WhereNamedID(),
model.Alias(),
)
if _, err := c.Store.NamedExecContext(ctx, stmt, v); err != nil {
return sqlcon.HandleError(err)
}
return nil
}
func (p *Persister) delete(ctx context.Context, v interface{}, id uuid.UUID) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.delete")
defer span.End()
nid := p.NetworkID(ctx)
tabler, ok := v.(interface {
TableName(ctx context.Context) string
})
if !ok {
return errors.Errorf("expected model to have TableName signature but got: %T", v)
}
/* #nosec G201 TableName is static */
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ? AND nid = ?", tabler.TableName(ctx)),
id,
nid,
).ExecWithCount()
if err != nil {
return sqlcon.HandleError(err)
}
if count == 0 {
return errors.WithStack(sqlcon.ErrNoRows)
}
return nil
}

View File

@ -43,7 +43,7 @@ func (p *Persister) DeleteContinuitySession(ctx context.Context, id uuid.UUID) e
defer span.End()
if count, err := p.GetConnection(ctx).RawQuery(
// #nosec
//#nosec G201 -- TableName is static
fmt.Sprintf("DELETE FROM %s WHERE id=? AND nid=?",
new(continuity.Container).TableName(ctx)), id, p.NetworkID(ctx)).ExecWithCount(); err != nil {
return sqlcon.HandleError(err)
@ -54,7 +54,7 @@ func (p *Persister) DeleteContinuitySession(ctx context.Context, id uuid.UUID) e
}
func (p *Persister) DeleteExpiredContinuitySessions(ctx context.Context, expiresAt time.Time, limit int) error {
// #nosec G201
//#nosec G201 -- TableName is static
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(continuity.Container).TableName(ctx),

View File

@ -18,6 +18,7 @@ import (
"github.com/ory/x/uuidx"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/persistence/sql/update"
)
var _ courier.Persister = new(Persister)
@ -89,7 +90,7 @@ func (p *Persister) NextMessages(ctx context.Context, limit uint8) (messages []c
for i := range m {
message := &m[i]
message.Status = courier.MessageStatusProcessing
if err := p.update(ctx, message, "status"); err != nil {
if err := update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), message, "status"); err != nil {
return err
}
}

View File

@ -7,7 +7,6 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"time"
"github.com/gofrs/uuid"
@ -56,9 +55,8 @@ func (p *Persister) Read(ctx context.Context, id uuid.UUID) (*errorx.ErrorContai
return nil, sqlcon.HandleError(err)
}
// #nosec G201
if err := p.GetConnection(ctx).RawQuery(
fmt.Sprintf("UPDATE %s SET was_seen = true, seen_at = ? WHERE id = ? AND nid = ?", "selfservice_errors"),
"UPDATE selfservice_errors SET was_seen = true, seen_at = ? WHERE id = ? AND nid = ?",
time.Now().UTC(), id, p.NetworkID(ctx)).Exec(); err != nil {
return nil, sqlcon.HandleError(err)
}
@ -71,14 +69,12 @@ func (p *Persister) Clear(ctx context.Context, olderThan time.Duration, force bo
defer span.End()
if force {
// #nosec G201
err = p.GetConnection(ctx).RawQuery(
fmt.Sprintf("DELETE FROM %s WHERE nid = ? AND seen_at < ? AND seen_at IS NOT NULL", "selfservice_errors"),
"DELETE FROM selfservice_errors WHERE nid = ? AND seen_at < ? AND seen_at IS NOT NULL",
p.NetworkID(ctx), time.Now().UTC().Add(-olderThan)).Exec()
} else {
// #nosec G201
err = p.GetConnection(ctx).RawQuery(
fmt.Sprintf("DELETE FROM %s WHERE nid = ? AND was_seen=true AND seen_at < ? AND seen_at IS NOT NULL", "selfservice_errors"),
"DELETE FROM selfservice_errors WHERE nid = ? AND was_seen=true AND seen_at < ? AND seen_at IS NOT NULL",
p.NetworkID(ctx), time.Now().UTC().Add(-olderThan)).Exec()
}

View File

@ -13,9 +13,6 @@ import (
"github.com/ory/x/configx"
"github.com/ory/x/otelx"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/schema"
"github.com/gobuffalo/pop/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -23,6 +20,8 @@ import (
"github.com/ory/x/logrusx"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/schema"
)
type logRegistryOnly struct {
@ -39,14 +38,6 @@ func (l *logRegistryOnly) Contextualizer() contextx.Contextualizer {
panic("implement me")
}
func (l *logRegistryOnly) IdentityTraitsSchemas(ctx context.Context) (schema.Schemas, error) {
panic("implement me")
}
func (l *logRegistryOnly) IdentityValidator() *identity.Validator {
panic("implement me")
}
func (l *logRegistryOnly) Logger() *logrusx.Logger {
if l.l == nil {
l.l = logrusx.New("kratos", "testing")
@ -61,6 +52,13 @@ func (l *logRegistryOnly) Audit() *logrusx.Logger {
func (l *logRegistryOnly) Tracer(ctx context.Context) *otelx.Tracer {
return otelx.NewNoop(l.l, new(otelx.Config))
}
func (l *logRegistryOnly) IdentityTraitsSchemas(ctx context.Context) (schema.Schemas, error) {
panic("implement me")
}
func (l *logRegistryOnly) IdentityValidator() *identity.Validator {
panic("implement me")
}
var _ persisterDependencies = &logRegistryOnly{}

View File

@ -14,6 +14,7 @@ import (
"github.com/ory/x/sqlcon"
"github.com/ory/kratos/persistence/sql/update"
"github.com/ory/kratos/selfservice/flow/login"
)
@ -35,7 +36,7 @@ func (p *Persister) UpdateLoginFlow(ctx context.Context, r *login.Flow) error {
r.EnsureInternalContext()
cp := *r
cp.NID = p.NetworkID(ctx)
return p.update(ctx, cp)
return update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), cp)
}
func (p *Persister) GetLoginFlow(ctx context.Context, id uuid.UUID) (*login.Flow, error) {
@ -68,7 +69,7 @@ func (p *Persister) ForceLoginFlow(ctx context.Context, id uuid.UUID) error {
}
func (p *Persister) DeleteExpiredLoginFlows(ctx context.Context, expiresAt time.Time, limit int) error {
// #nosec G201
//#nosec G201 -- TableName is static
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(login.Flow).TableName(ctx),

View File

@ -15,6 +15,7 @@ import (
"github.com/gofrs/uuid"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/persistence/sql/update"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/recovery"
"github.com/ory/kratos/selfservice/strategy/code"
@ -51,7 +52,7 @@ func (p *Persister) UpdateRecoveryFlow(ctx context.Context, r *recovery.Flow) er
cp := *r
cp.NID = p.NetworkID(ctx)
return p.update(ctx, cp)
return update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), cp)
}
func (p *Persister) CreateRecoveryToken(ctx context.Context, token *link.RecoveryToken) error {
@ -101,7 +102,7 @@ func (p *Persister) UseRecoveryToken(ctx context.Context, fID uuid.UUID, token s
}
rt.RecoveryAddress = &ra
/* #nosec G201 TableName is static */
//#nosec G201 -- TableName is static
return tx.RawQuery(fmt.Sprintf("UPDATE %s SET used=true, used_at=? WHERE id=? AND nid = ?", rt.TableName(ctx)), time.Now().UTC(), rt.ID, nid).Exec()
})); err != nil {
return nil, err
@ -114,12 +115,12 @@ func (p *Persister) DeleteRecoveryToken(ctx context.Context, token string) error
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRecoveryToken")
defer span.End()
/* #nosec G201 TableName is static */
//#nosec G201 -- TableName is static
return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE token=? AND nid = ?", new(link.RecoveryToken).TableName(ctx)), token, p.NetworkID(ctx)).Exec()
}
func (p *Persister) DeleteExpiredRecoveryFlows(ctx context.Context, expiresAt time.Time, limit int) error {
// #nosec G201
//#nosec G201 -- TableName is static
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(recovery.Flow).TableName(ctx),
@ -186,14 +187,14 @@ func (p *Persister) UseRecoveryCode(ctx context.Context, fID uuid.UUID, codeVal
if err := sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) {
/* #nosec G201 TableName is static */
//#nosec G201 -- TableName is static
if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName), fID, nid).Exec()); err != nil {
return err
}
var submitCount int
// Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on.
/* #nosec G201 TableName is static */
//#nosec G201 -- TableName is static
if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName), fID, nid).First(&submitCount)); err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
// Return no error, as that would roll back the transaction
@ -248,7 +249,7 @@ func (p *Persister) UseRecoveryCode(ctx context.Context, fID uuid.UUID, codeVal
}
recoveryCode.RecoveryAddress = &ra
/* #nosec G201 TableName is static */
//#nosec G201 -- TableName is static
return sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", recoveryCode.TableName(ctx)), time.Now().UTC(), recoveryCode.ID, nid).Exec())
})); err != nil {
return nil, err
@ -273,6 +274,6 @@ func (p *Persister) DeleteRecoveryCodesOfFlow(ctx context.Context, fID uuid.UUID
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRecoveryCodesOfFlow")
defer span.End()
/* #nosec G201 TableName is static */
//#nosec G201 -- TableName is static
return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE selfservice_recovery_flow_id = ? AND nid = ?", new(code.RecoveryCode).TableName(ctx)), fID, p.NetworkID(ctx)).Exec()
}

View File

@ -12,6 +12,7 @@ import (
"github.com/ory/x/sqlcon"
"github.com/ory/kratos/persistence/sql/update"
"github.com/ory/kratos/selfservice/flow/registration"
)
@ -31,7 +32,7 @@ func (p *Persister) UpdateRegistrationFlow(ctx context.Context, r *registration.
r.EnsureInternalContext()
cp := *r
cp.NID = p.NetworkID(ctx)
return p.update(ctx, cp)
return update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), cp)
}
func (p *Persister) GetRegistrationFlow(ctx context.Context, id uuid.UUID) (*registration.Flow, error) {
@ -48,7 +49,7 @@ func (p *Persister) GetRegistrationFlow(ctx context.Context, id uuid.UUID) (*reg
}
func (p *Persister) DeleteExpiredRegistrationFlows(ctx context.Context, expiresAt time.Time, limit int) error {
// #nosec G201
//#nosec G201 -- TableName is static
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(registration.Flow).TableName(ctx),

View File

@ -8,25 +8,17 @@ import (
"fmt"
"time"
"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"golang.org/x/sync/errgroup"
"github.com/ory/x/otelx"
"github.com/ory/kratos/identity"
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/stringsx"
"github.com/gobuffalo/pop/v6"
"github.com/pkg/errors"
"github.com/gofrs/uuid"
"github.com/ory/x/sqlcon"
"github.com/ory/kratos/session"
"github.com/ory/x/otelx"
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/sqlcon"
"github.com/ory/x/stringsx"
)
var _ session.Persister = new(Persister)
@ -57,7 +49,7 @@ func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID, expandables s
if expandables.Has(session.ExpandSessionIdentity) {
// This is needed because of how identities are fetched from the store (if we use eager not all fields are
// available!).
i, err := p.GetIdentity(ctx, s.IdentityID, identity.ExpandDefault)
i, err := p.PrivilegedPool.GetIdentity(ctx, s.IdentityID, identity.ExpandDefault)
if err != nil {
return nil, err
}
@ -116,7 +108,7 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt
if s[k].Identity == nil {
continue
}
if err := p.injectTraitsSchemaURL(ctx, s[k].Identity); err != nil {
if err := p.InjectTraitsSchemaURL(ctx, s[k].Identity); err != nil {
return nil, 0, nil, err
}
}
@ -212,7 +204,7 @@ func (p *Persister) UpsertSession(ctx context.Context, s *session.Session) (err
device.UserAgent = stringsx.GetPointer(stringsx.TruncateByteLen(*device.UserAgent, SessionDeviceUserAgentMaxLength))
}
if err := sqlcon.HandleError(tx.Create(device)); err != nil {
if err := p.DevicePersister.CreateDevice(ctx, device); err != nil {
return err
}
}
@ -225,17 +217,29 @@ func (p *Persister) DeleteSession(ctx context.Context, sid uuid.UUID) (err error
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteSession")
defer otelx.End(span, &err)
return p.delete(ctx, new(session.Session), sid)
nid := p.NetworkID(ctx)
//#nosec G201 -- TableName is static
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ? AND nid = ?", new(session.Session).TableName(ctx)),
sid,
nid,
).ExecWithCount()
if err != nil {
return sqlcon.HandleError(err)
}
if count == 0 {
return errors.WithStack(sqlcon.ErrNoRows)
}
return nil
}
func (p *Persister) DeleteSessionsByIdentity(ctx context.Context, identityID uuid.UUID) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteSessionsByIdentity")
defer otelx.End(span, &err)
// #nosec G201
//#nosec G201 -- TableName is static
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE identity_id = ? AND nid = ?",
"sessions",
new(session.Session).TableName(ctx),
),
identityID,
p.NetworkID(ctx),
@ -279,7 +283,7 @@ func (p *Persister) GetSessionByToken(ctx context.Context, token string, expand
// available!).
if expand.Has(session.ExpandSessionIdentity) {
eg.Go(func() (err error) {
i, err = p.GetIdentity(ctx, s.IdentityID, identityExpand)
i, err = p.PrivilegedPool.GetIdentity(ctx, s.IdentityID, identityExpand)
return err
})
}
@ -298,10 +302,10 @@ func (p *Persister) DeleteSessionByToken(ctx context.Context, token string) (err
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteSessionByToken")
defer otelx.End(span, &err)
// #nosec G201
//#nosec G201 -- TableName is static
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE token = ? AND nid = ?",
"sessions",
new(session.Session).TableName(ctx),
),
token,
p.NetworkID(ctx),
@ -319,10 +323,10 @@ func (p *Persister) RevokeSessionByToken(ctx context.Context, token string) (err
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSessionByToken")
defer otelx.End(span, &err)
// #nosec G201
//#nosec G201 -- TableName is static
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"UPDATE %s SET active = false WHERE token = ? AND nid = ?",
"sessions",
new(session.Session).TableName(ctx),
),
token,
p.NetworkID(ctx),
@ -341,10 +345,10 @@ func (p *Persister) RevokeSessionById(ctx context.Context, sID uuid.UUID) (err e
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSessionById")
defer otelx.End(span, &err)
// #nosec G201
//#nosec G201 -- TableName is static
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"UPDATE %s SET active = false WHERE id = ? AND nid = ?",
"sessions",
new(session.Session).TableName(ctx),
),
sID,
p.NetworkID(ctx),
@ -364,10 +368,10 @@ func (p *Persister) RevokeSession(ctx context.Context, iID, sID uuid.UUID) (err
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSession")
defer otelx.End(span, &err)
// #nosec G201
//#nosec G201 -- TableName is static
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"UPDATE %s SET active = false WHERE id = ? AND identity_id = ? AND nid = ?",
"sessions",
new(session.Session).TableName(ctx),
),
sID,
iID,
@ -384,10 +388,10 @@ func (p *Persister) RevokeSessionsIdentityExcept(ctx context.Context, iID, sID u
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSessionsIdentityExcept")
defer otelx.End(span, &err)
// #nosec G201
//#nosec G201 -- TableName is static
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"UPDATE %s SET active = false WHERE identity_id = ? AND id != ? AND nid = ?",
"sessions",
new(session.Session).TableName(ctx),
),
iID,
sID,
@ -402,10 +406,12 @@ func (p *Persister) RevokeSessionsIdentityExcept(ctx context.Context, iID, sID u
func (p *Persister) DeleteExpiredSessions(ctx context.Context, expiresAt time.Time, limit int) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteExpiredSessions")
defer otelx.End(span, &err)
//#nosec G201 -- TableName is static
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
"sessions",
"sessions",
new(session.Session).TableName(ctx),
new(session.Session).TableName(ctx),
limit,
),
expiresAt,

View File

@ -9,6 +9,7 @@ import (
"time"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/persistence/sql/update"
"github.com/gofrs/uuid"
@ -39,7 +40,7 @@ func (p *Persister) GetSettingsFlow(ctx context.Context, id uuid.UUID) (*setting
return nil, sqlcon.HandleError(err)
}
r.Identity, err = p.GetIdentity(ctx, r.IdentityID, identity.ExpandDefault)
r.Identity, err = p.PrivilegedPool.GetIdentity(ctx, r.IdentityID, identity.ExpandDefault)
if err != nil {
return nil, err
}
@ -54,11 +55,11 @@ func (p *Persister) UpdateSettingsFlow(ctx context.Context, r *settings.Flow) er
r.EnsureInternalContext()
cp := *r
cp.NID = p.NetworkID(ctx)
return p.update(ctx, cp)
return update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), cp)
}
func (p *Persister) DeleteExpiredSettingsFlows(ctx context.Context, expiresAt time.Time, limit int) error {
// #nosec G201
//#nosec G201 -- TableName is static
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(settings.Flow).TableName(ctx),

View File

@ -12,6 +12,7 @@ import (
"github.com/pkg/errors"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/persistence/sql/update"
"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
@ -53,7 +54,7 @@ func (p *Persister) UpdateVerificationFlow(ctx context.Context, r *verification.
cp := *r
cp.NID = p.NetworkID(ctx)
return p.update(ctx, cp)
return update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), cp)
}
func (p *Persister) CreateVerificationToken(ctx context.Context, token *link.VerificationToken) error {
@ -101,7 +102,7 @@ func (p *Persister) UseVerificationToken(ctx context.Context, fID uuid.UUID, tok
rt.VerifiableAddress = &va
/* #nosec G201 TableName is static */
//#nosec G201 -- TableName is static
return tx.RawQuery(fmt.Sprintf("UPDATE %s SET used=true, used_at=? WHERE id=? AND nid = ?", rt.TableName(ctx)), time.Now().UTC(), rt.ID, nid).Exec()
})); err != nil {
return nil, err
@ -115,12 +116,12 @@ func (p *Persister) DeleteVerificationToken(ctx context.Context, token string) e
defer span.End()
nid := p.NetworkID(ctx)
/* #nosec G201 TableName is static */
//#nosec G201 -- TableName is static
return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE token=? AND nid = ?", new(link.VerificationToken).TableName(ctx)), token, nid).Exec()
}
func (p *Persister) DeleteExpiredVerificationFlows(ctx context.Context, expiresAt time.Time, limit int) error {
// #nosec G201
//#nosec G201 -- TableName is static
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(verification.Flow).TableName(ctx),
@ -149,7 +150,7 @@ func (p *Persister) UseVerificationCode(ctx context.Context, fID uuid.UUID, code
if err := sqlcon.HandleError(
tx.RawQuery(
/* #nosec G201 TableName is static */
//#nosec G201 -- TableName is static
fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName),
fID,
nid,
@ -162,7 +163,7 @@ func (p *Persister) UseVerificationCode(ctx context.Context, fID uuid.UUID, code
// Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on.
if err := sqlcon.HandleError(
tx.RawQuery(
/* #nosec G201 TableName is static */
//#nosec G201 -- TableName is static
fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName),
fID,
nid,
@ -222,7 +223,7 @@ func (p *Persister) UseVerificationCode(ctx context.Context, fID uuid.UUID, code
verificationCode.VerifiableAddress = &va
/* #nosec G201 TableName is static */
//#nosec G201 -- TableName is static
return tx.
RawQuery(
fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", verificationCode.TableName(ctx)),
@ -276,7 +277,7 @@ func (p *Persister) DeleteVerificationCodesOfFlow(ctx context.Context, fID uuid.
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteVerificationCodesOfFlow")
defer span.End()
/* #nosec G201 TableName is static */
//#nosec G201 -- TableName is static
return p.GetConnection(ctx).
RawQuery(
fmt.Sprintf("DELETE FROM %s WHERE selfservice_verification_flow_id = ? AND nid = ?", new(code.VerificationCode).TableName(ctx)),

View File

@ -0,0 +1,72 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package update
import (
"context"
"fmt"
"github.com/gobuffalo/pop/v6"
"github.com/gobuffalo/pop/v6/columns"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/trace"
"github.com/ory/x/sqlcon"
)
type Model interface {
GetID() uuid.UUID
GetNID() uuid.UUID
}
func Generic(ctx context.Context, c *pop.Connection, tracer trace.Tracer, v Model, columnNames ...string) error {
ctx, span := tracer.Start(ctx, "persistence.sql.update")
defer span.End()
quoter, ok := c.Dialect.(interface{ Quote(key string) string })
if !ok {
return errors.Errorf("store is not a quoter: %T", c.Store)
}
model := pop.NewModel(v, ctx)
tn := model.TableName()
cols := columns.Columns{}
if len(columnNames) > 0 && tn == model.TableName() {
cols = columns.NewColumnsWithAlias(tn, model.As, model.IDField())
cols.Add(columnNames...)
} else {
cols = columns.ForStructWithAlias(v, tn, model.As, model.IDField())
}
//#nosec G201 -- TableName is static
stmt := fmt.Sprintf("SELECT COUNT(id) FROM %s AS %s WHERE %s.id = ? AND %s.nid = ?",
quoter.Quote(model.TableName()),
model.Alias(),
model.Alias(),
model.Alias(),
)
var count int
if err := c.Store.GetContext(ctx, &count, c.Dialect.TranslateSQL(stmt), v.GetID(), v.GetNID()); err != nil {
return sqlcon.HandleError(err)
} else if count == 0 {
return errors.WithStack(sqlcon.ErrNoRows)
}
//#nosec G201 -- TableName is static
stmt = fmt.Sprintf("UPDATE %s AS %s SET %s WHERE %s AND %s.nid = :nid",
quoter.Quote(model.TableName()),
model.Alias(),
cols.Writeable().QuotedUpdateString(quoter),
model.WhereNamedID(),
model.Alias(),
)
if _, err := c.Store.NamedExecContext(ctx, stmt, v); err != nil {
return sqlcon.HandleError(err)
}
return nil
}

View File

@ -22,6 +22,7 @@ import (
"github.com/ory/kratos/ui/node"
"github.com/ory/kratos/x"
"github.com/ory/x/httpx"
"github.com/ory/x/otelx"
"github.com/ory/x/otelx/semconv"
)
@ -49,6 +50,7 @@ type (
x.CSRFTokenGeneratorProvider
x.WriterProvider
x.LoggingProvider
x.TracingProvider
HooksProvider
}
@ -110,7 +112,12 @@ func (e *HookExecutor) handleLoginError(_ http.ResponseWriter, r *http.Request,
return flowError
}
func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *Flow, i *identity.Identity, s *session.Session) error {
func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *Flow, i *identity.Identity, s *session.Session) (err error) {
ctx := r.Context()
ctx, span := e.d.Tracer(ctx).Tracer().Start(ctx, "HookExecutor.PostLoginHook")
r = r.WithContext(ctx)
defer otelx.End(span, &err)
if err := s.Activate(r, i, e.d.Config(), time.Now().UTC()); err != nil {
return err
}
@ -128,7 +135,8 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g n
return err
}
s = s.Declassify()
classified := s
s = s.Declassified()
e.d.Logger().
WithRequest(r).
@ -181,7 +189,7 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g n
)
response := &APIFlowResponse{Session: s, Token: s.Token}
if required, _ := e.requiresAAL2(r, s, a); required {
if required, _ := e.requiresAAL2(r, classified, a); required {
// If AAL is not satisfied, we omit the identity to preserve the user's privacy in case of a phishing attack.
response.Session.Identity = nil
}
@ -240,7 +248,7 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g n
finalReturnTo = rt
}
x.ContentNegotiationRedirection(w, r, s.Declassify(), e.d.Writer(), finalReturnTo)
x.ContentNegotiationRedirection(w, r, s, e.d.Writer(), finalReturnTo)
return nil
}

View File

@ -230,7 +230,7 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque
finalReturnTo = cr
}
x.ContentNegotiationRedirection(w, r, s.Declassify(), e.d.Writer(), finalReturnTo)
x.ContentNegotiationRedirection(w, r, s.Declassified(), e.d.Writer(), finalReturnTo)
return nil
}

View File

@ -176,7 +176,7 @@ func newHydraIntegration(t *testing.T, remote *string, subject *string, claims *
parsed, err := url.ParseRequestURI(addr)
require.NoError(t, err)
// #nosec G112
//#nosec G112
server := &http.Server{Addr: ":" + parsed.Port(), Handler: router}
go func(t *testing.T) {
if err := server.ListenAndServe(); err != http.ErrServerClosed {

View File

@ -76,8 +76,7 @@ func TestRegistration(t *testing.T) {
apiClient := testhelpers.NewDebugClient(t)
t.Run("AssertCommonErrorCases", func(t *testing.T) {
reg := newRegistrationRegistry(t)
registrationhelpers.AssertCommonErrorCases(t, reg, flows)
registrationhelpers.AssertCommonErrorCases(t, flows)
})
t.Run("AssertRegistrationRespectsValidation", func(t *testing.T) {

View File

@ -10,8 +10,7 @@ import (
"github.com/hashicorp/go-retryablehttp"
/* #nosec G505 sha1 is used for k-anonymity */
"crypto/sha1"
"crypto/sha1" //#nosec G505 -- sha1 is used for k-anonymity
"fmt"
"net/http"
"strconv"
@ -188,7 +187,7 @@ func (s *DefaultPasswordValidator) validate(ctx context.Context, identifier, pas
return nil
}
/* #nosec G401 sha1 is used for k-anonymity */
//#nosec G401 -- sha1 is used for k-anonymity
h := sha1.New()
if _, err := h.Write([]byte(password)); err != nil {
return err

View File

@ -7,7 +7,7 @@ import (
"bytes"
"context"
"crypto/rand"
"crypto/sha1" // #nosec G505 - compatibility for imported passwords
"crypto/sha1" //#nosec G505 -- compatibility for imported passwords
"errors"
"fmt"
"io"
@ -144,7 +144,7 @@ func TestDefaultPasswordValidationStrategy(t *testing.T) {
s.Client = httpx.NewResilientClient(httpx.ResilientClientWithClient(&fakeClient.Client), httpx.ResilientClientWithMaxRetry(1), httpx.ResilientClientWithConnectionTimeout(time.Millisecond))
var hashPw = func(t *testing.T, pw string) string {
/* #nosec G401 sha1 is used for k-anonymity */
//#nosec G401 -- sha1 is used for k-anonymity
h := sha1.New()
_, err := h.Write([]byte(pw))
require.NoError(t, err)

View File

@ -88,8 +88,7 @@ func TestRegistration(t *testing.T) {
//}
t.Run("AssertCommonErrorCases", func(t *testing.T) {
reg := newRegistrationRegistry(t)
registrationhelpers.AssertCommonErrorCases(t, reg, flows)
registrationhelpers.AssertCommonErrorCases(t, flows)
})
t.Run("AssertRegistrationRespectsValidation", func(t *testing.T) {

View File

@ -68,3 +68,7 @@ type Persister interface {
// RevokeSessionsIdentityExcept marks all except the given session of an identity inactive. It returns the number of sessions that were revoked.
RevokeSessionsIdentityExcept(ctx context.Context, iID, sID uuid.UUID) (int, error)
}
type DevicePersister interface {
CreateDevice(ctx context.Context, d *Device) error
}

View File

@ -261,9 +261,9 @@ func (s *Session) SetSessionDeviceInformation(r *http.Request) {
s.Devices = append(s.Devices, device)
}
func (s *Session) Declassify() *Session {
func (s Session) Declassified() *Session {
s.Identity = s.Identity.CopyWithoutCredentials()
return s
return &s
}
func (s *Session) IsActive() bool {

View File

@ -176,7 +176,7 @@ func main() {
})
addr := ":" + osx.GetenvDefault("PORT", "4446")
// #nosec G112
//#nosec G112
server := &http.Server{Addr: addr, Handler: router}
fmt.Printf("Starting web server at %s\n", addr)
check(server.ListenAndServe())

View File

@ -13,8 +13,7 @@ import (
"github.com/stretchr/testify/require"
)
// #nosec G404
var rnd = rand.New(rand.NewSource(time.Now().Unix()))
var rnd = rand.New(rand.NewSource(time.Now().Unix())) //#nosec G404
func AssertEqualTime(t *testing.T, expected, actual time.Time) {
assert.EqualValues(t, expected.UTC().Round(time.Second), actual.UTC().Round(time.Second))