mirror of https://github.com/ory/kratos
refactor: identity persistence (#3101)
This commit is contained in:
parent
ea6ad2a8fe
commit
ceb5cc2b8a
|
|
@ -32,7 +32,7 @@ func NewMigrateHandler() *MigrateHandler {
|
|||
return &MigrateHandler{}
|
||||
}
|
||||
|
||||
func (h *MigrateHandler) MigrateSQL(cmd *cobra.Command, args []string) error {
|
||||
func (h *MigrateHandler) MigrateSQL(cmd *cobra.Command, args []string, opts ...driver.RegistryOption) error {
|
||||
var d driver.Registry
|
||||
var err error
|
||||
|
||||
|
|
@ -78,7 +78,7 @@ func (h *MigrateHandler) MigrateSQL(cmd *cobra.Command, args []string) error {
|
|||
}
|
||||
}
|
||||
|
||||
err = d.Init(cmd.Context(), &contextx.Default{}, driver.SkipNetworkInit)
|
||||
err = d.Init(cmd.Context(), &contextx.Default{}, append(opts, driver.SkipNetworkInit)...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "an error occurred initializing migrations")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
package courier
|
||||
|
||||
import (
|
||||
cx "context"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
|
@ -22,7 +22,7 @@ import (
|
|||
)
|
||||
|
||||
func NewWatchCmd(slOpts []servicelocatorx.Option, dOpts []driver.RegistryOption) *cobra.Command {
|
||||
var c = &cobra.Command{
|
||||
c := &cobra.Command{
|
||||
Use: "watch",
|
||||
Short: "Starts the Ory Kratos message courier",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
|
|
@ -38,7 +38,7 @@ func NewWatchCmd(slOpts []servicelocatorx.Option, dOpts []driver.RegistryOption)
|
|||
return c
|
||||
}
|
||||
|
||||
func StartCourier(ctx cx.Context, r driver.Registry) error {
|
||||
func StartCourier(ctx context.Context, r driver.Registry) error {
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
if r.Config().CourierExposeMetricsPort(ctx) != 0 {
|
||||
|
|
@ -54,7 +54,7 @@ func StartCourier(ctx cx.Context, r driver.Registry) error {
|
|||
return eg.Wait()
|
||||
}
|
||||
|
||||
func ServeMetrics(ctx cx.Context, r driver.Registry) error {
|
||||
func ServeMetrics(ctx context.Context, r driver.Registry) error {
|
||||
c := r.Config()
|
||||
l := r.Logger()
|
||||
n := negroni.New()
|
||||
|
|
@ -73,39 +73,23 @@ func ServeMetrics(ctx cx.Context, r driver.Registry) error {
|
|||
handler = otelx.NewHandler(handler, "cmd.courier.ServeMetrics", otelhttp.WithTracerProvider(tp))
|
||||
}
|
||||
|
||||
// #nosec G112 - the correct settings are set by graceful.WithDefaults
|
||||
//#nosec G112 -- the correct settings are set by graceful.WithDefaults
|
||||
server := graceful.WithDefaults(&http.Server{
|
||||
Addr: c.MetricsListenOn(ctx),
|
||||
Handler: handler,
|
||||
})
|
||||
|
||||
l.Printf("Starting the metrics httpd on: %s", server.Addr)
|
||||
if err := graceful.Graceful(func() error {
|
||||
errChan := make(chan error, 1)
|
||||
go func(errChan chan error) {
|
||||
if err := server.ListenAndServe(); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}(errChan)
|
||||
select {
|
||||
case err := <-errChan:
|
||||
l.Errorf("Failed to start the metrics httpd: %s\n", err)
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
l.Printf("Shutting down the metrics httpd: context closed: %s\n", ctx.Err())
|
||||
return server.Shutdown(ctx)
|
||||
}
|
||||
}, server.Shutdown); err != nil {
|
||||
if err := graceful.GracefulContext(ctx, server.ListenAndServe, server.Shutdown); err != nil {
|
||||
l.Errorln("Failed to gracefully shutdown metrics httpd")
|
||||
return err
|
||||
} else {
|
||||
l.Println("Metrics httpd was shutdown gracefully")
|
||||
}
|
||||
l.Println("Metrics httpd was shutdown gracefully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func Watch(ctx cx.Context, r driver.Registry) error {
|
||||
ctx, cancel := cx.WithCancel(ctx)
|
||||
func Watch(ctx context.Context, r driver.Registry) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
r.Logger().Println("Courier worker started.")
|
||||
if err := graceful.Graceful(func() error {
|
||||
|
|
@ -115,7 +99,7 @@ func Watch(ctx cx.Context, r driver.Registry) error {
|
|||
}
|
||||
|
||||
return c.Work(ctx)
|
||||
}, func(_ cx.Context) error {
|
||||
}, func(_ context.Context) error {
|
||||
cancel()
|
||||
return nil
|
||||
}); err != nil {
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
|
|||
handler = otelx.TraceHandler(handler, otelhttp.WithTracerProvider(tracer.Provider()))
|
||||
}
|
||||
|
||||
// #nosec G112 - the correct settings are set by graceful.WithDefaults
|
||||
//#nosec G112 -- the correct settings are set by graceful.WithDefaults
|
||||
server := graceful.WithDefaults(&http.Server{
|
||||
Handler: handler,
|
||||
TLSConfig: &tls.Config{GetCertificate: certs, MinVersion: tls.VersionTLS12},
|
||||
|
|
@ -128,7 +128,7 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
|
|||
addr := c.PublicListenOn(ctx)
|
||||
|
||||
l.Printf("Starting the public httpd on: %s", addr)
|
||||
if err := graceful.Graceful(func() error {
|
||||
if err := graceful.GracefulContext(ctx, func() error {
|
||||
listener, err := networkx.MakeListener(addr, c.PublicSocketPermission(ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -191,7 +191,7 @@ func ServeAdmin(r driver.Registry, cmd *cobra.Command, args []string, slOpts *se
|
|||
)
|
||||
}
|
||||
|
||||
// #nosec G112 - the correct settings are set by graceful.WithDefaults
|
||||
//#nosec G112 -- the correct settings are set by graceful.WithDefaults
|
||||
server := graceful.WithDefaults(&http.Server{
|
||||
Handler: handler,
|
||||
TLSConfig: &tls.Config{GetCertificate: certs, MinVersion: tls.VersionTLS12},
|
||||
|
|
@ -200,7 +200,7 @@ func ServeAdmin(r driver.Registry, cmd *cobra.Command, args []string, slOpts *se
|
|||
addr := c.AdminListenOn(ctx)
|
||||
|
||||
l.Printf("Starting the admin httpd on: %s", addr)
|
||||
if err := graceful.Graceful(func() error {
|
||||
if err := graceful.GracefulContext(ctx, func() error {
|
||||
listener, err := networkx.MakeListener(addr, c.AdminSocketPermission(ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -188,7 +188,7 @@ func runLoadTest(cmd *cobra.Command, conf *argon2Config, reqPerMin int) (*result
|
|||
eg.Go(func(i int) func() error {
|
||||
return func() error {
|
||||
// wait randomly before starting, max. sample time
|
||||
// #nosec G404 - just a timeout to collect statistical data
|
||||
//#nosec G404 -- just a timeout to collect statistical data
|
||||
t := time.Duration(rand.Intn(int(sampleTime)))
|
||||
timer := time.NewTimer(t)
|
||||
defer timer.Stop()
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ Use -w or --write to write output back to files instead of stdout.
|
|||
cmdx.Must(err, `JSONNet file "%s" could not be formatted: %s`, file, err)
|
||||
|
||||
if shouldWrite {
|
||||
err := os.WriteFile(file, []byte(output), 0644) // #nosec
|
||||
err := os.WriteFile(file, []byte(output), 0644) //#nosec
|
||||
cmdx.Must(err, `Could not write to file "%s" because: %s`, file, err)
|
||||
} else {
|
||||
fmt.Println(output)
|
||||
|
|
|
|||
|
|
@ -7,11 +7,12 @@ import (
|
|||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ory/kratos/cmd/cliclient"
|
||||
"github.com/ory/kratos/driver"
|
||||
"github.com/ory/x/configx"
|
||||
)
|
||||
|
||||
// migrateSqlCmd represents the sql command
|
||||
func NewMigrateSQLCmd() *cobra.Command {
|
||||
func NewMigrateSQLCmd(opts ...driver.RegistryOption) *cobra.Command {
|
||||
c := &cobra.Command{
|
||||
Use: "sql <database-url>",
|
||||
Short: "Create SQL schemas and apply migration plans",
|
||||
|
|
@ -29,7 +30,7 @@ You can read in the database URL using the -e flag, for example:
|
|||
Before running this command on an existing database, create a back up!
|
||||
`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cliclient.NewMigrateHandler().MigrateSQL(cmd, args)
|
||||
return cliclient.NewMigrateHandler().MigrateSQL(cmd, args, opts...)
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -82,13 +82,13 @@ func newSMTP(ctx context.Context, deps Dependencies) (*smtpClient, error) {
|
|||
// Enforcing StartTLS by default for security best practices (config review, etc.)
|
||||
skipStartTLS, _ := strconv.ParseBool(uri.Query().Get("disable_starttls"))
|
||||
if !skipStartTLS {
|
||||
// #nosec G402 This is ok (and required!) because it is configurable and disabled by default.
|
||||
//#nosec G402 -- This is ok (and required!) because it is configurable and disabled by default.
|
||||
dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, Certificates: tlsCertificates, ServerName: serverName}
|
||||
// Enforcing StartTLS
|
||||
dialer.StartTLSPolicy = gomail.MandatoryStartTLS
|
||||
}
|
||||
case "smtps":
|
||||
// #nosec G402 This is ok (and required!) because it is configurable and disabled by default.
|
||||
//#nosec G402 -- This is ok (and required!) because it is configurable and disabled by default.
|
||||
dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, Certificates: tlsCertificates, ServerName: serverName}
|
||||
dialer.SSL = true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -517,8 +517,7 @@ func (m *RegistryDefault) ContinuityCookieManager(ctx context.Context) sessions.
|
|||
|
||||
func (m *RegistryDefault) Tracer(ctx context.Context) *otelx.Tracer {
|
||||
if m.trc == nil {
|
||||
m.Logger().WithError(errors.WithStack(errors.New(""))).Warn("No tracer setup in RegistryDefault")
|
||||
return otelx.NewNoop(m.l, m.Config().Tracing(ctx)) // should never happen
|
||||
return otelx.NewNoop(m.l, m.Config().Tracing(ctx))
|
||||
}
|
||||
return m.trc
|
||||
}
|
||||
|
|
|
|||
5
go.mod
5
go.mod
|
|
@ -61,6 +61,7 @@ require (
|
|||
github.com/jteeuwen/go-bindata v3.0.7+incompatible
|
||||
github.com/julienschmidt/httprouter v1.3.0
|
||||
github.com/knadh/koanf v1.4.4
|
||||
github.com/laher/mergefs v0.1.2-0.20230223191438-d16611b2f4e7
|
||||
github.com/luna-duclos/instrumentedsql v1.1.3
|
||||
github.com/mattn/goveralls v0.0.7
|
||||
github.com/mikefarah/yq/v4 v4.19.1
|
||||
|
|
@ -70,7 +71,7 @@ require (
|
|||
github.com/ory/client-go v0.2.0-alpha.60
|
||||
github.com/ory/dockertest/v3 v3.9.1
|
||||
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe
|
||||
github.com/ory/graceful v0.1.3
|
||||
github.com/ory/graceful v0.1.4-0.20230301144740-e222150c51d0
|
||||
github.com/ory/herodot v0.9.13
|
||||
github.com/ory/hydra-client-go/v2 v2.0.3
|
||||
github.com/ory/jsonschema/v3 v3.0.7
|
||||
|
|
@ -87,7 +88,7 @@ require (
|
|||
github.com/spf13/cobra v1.6.1
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/sqs/goreturns v0.0.0-20181028201513-538ac6014518
|
||||
github.com/stretchr/testify v1.8.1
|
||||
github.com/stretchr/testify v1.8.2
|
||||
github.com/tidwall/gjson v1.14.3
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/urfave/negroni v1.0.0
|
||||
|
|
|
|||
11
go.sum
11
go.sum
|
|
@ -913,6 +913,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
|||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/go-gypsy v1.0.0/go.mod h1:chkXM0zjdpXOiqkCW1XcCHDfjfk14PH2KKkQWxfJUcU=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/laher/mergefs v0.1.2-0.20230223191438-d16611b2f4e7 h1:PDeBswTUsSIT4QSrzLvlqKlGrANYa7TrXUwdBN9myU8=
|
||||
github.com/laher/mergefs v0.1.2-0.20230223191438-d16611b2f4e7/go.mod h1:FSY1hYy94on4Tz60waRMGdO1awwS23BacqJlqf9lJ9Q=
|
||||
github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y=
|
||||
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
|
||||
github.com/letsencrypt/pkcs11key/v4 v4.0.0/go.mod h1:EFUvBDay26dErnNb70Nd0/VW3tJiIbETBPTl9ATXQag=
|
||||
|
|
@ -945,6 +947,8 @@ github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJ
|
|||
github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE=
|
||||
github.com/markbates/pkger v0.17.1 h1:/MKEtWqtc0mZvu9OinB9UzVN9iYCwLWuyUv4Bw+PCno=
|
||||
github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0=
|
||||
github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE=
|
||||
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
|
||||
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
||||
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
|
||||
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||
|
|
@ -1093,8 +1097,8 @@ github.com/ory/dockertest/v3 v3.9.1/go.mod h1:42Ir9hmvaAPm0Mgibk6mBPi7SFvTXxEcnz
|
|||
github.com/ory/go-acc v0.2.6/go.mod h1:4Kb/UnPcT8qRAk3IAxta+hvVapdxTLWtrr7bFLlEgpw=
|
||||
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe h1:rvu4obdvqR0fkSIJ8IfgzKOWwZ5kOT2UNfLq81Qk7rc=
|
||||
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe/go.mod h1:z4n3u6as84LbV4YmgjHhnwtccQqzf4cZlSk9f1FhygI=
|
||||
github.com/ory/graceful v0.1.3 h1:FaeXcHZh168WzS+bqruqWEw/HgXWLdNv2nJ+fbhxbhc=
|
||||
github.com/ory/graceful v0.1.3/go.mod h1:4zFz687IAF7oNHHiB586U4iL+/4aV09o/PYLE34t2bA=
|
||||
github.com/ory/graceful v0.1.4-0.20230301144740-e222150c51d0 h1:VMUeLRfQD14fOMvhpYZIIT4vtAqxYh+f3KnSqCeJ13o=
|
||||
github.com/ory/graceful v0.1.4-0.20230301144740-e222150c51d0/go.mod h1:hg2iCy+LCWOXahBZ+NQa4dk8J2govyQD79rrqrgMyY8=
|
||||
github.com/ory/herodot v0.9.13 h1:cN/Z4eOkErl/9W7hDIDLb79IO/bfsH+8yscBjRpB4IU=
|
||||
github.com/ory/herodot v0.9.13/go.mod h1:IWDs9kSvFQqw/cQ8zi5ksyYvITiUU4dI7glUrhZcJYo=
|
||||
github.com/ory/hydra-client-go/v2 v2.0.3 h1:jIx968J9RBnjRuaQ21QMLCwZoa28FPvzYWAQ+88XVLw=
|
||||
|
|
@ -1342,8 +1346,9 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
|||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
|
||||
github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8=
|
||||
github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0=
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ import (
|
|||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/md5" // #nosec G501
|
||||
"crypto/sha1" // #nosec G505 - compatibility for imported passwords
|
||||
"crypto/md5" //#nosec G501 -- compatibility for imported passwords
|
||||
"crypto/sha1" //#nosec G505 -- compatibility for imported passwords
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/subtle"
|
||||
|
|
@ -220,7 +220,7 @@ func CompareMD5(_ context.Context, password []byte, hash []byte) error {
|
|||
r := strings.NewReplacer("{SALT}", string(salt), "{PASSWORD}", string(password))
|
||||
arg = []byte(r.Replace(string(pf)))
|
||||
}
|
||||
// #nosec G401
|
||||
//#nosec G401 -- compatibility for imported passwords
|
||||
otherHash := md5.Sum(arg)
|
||||
|
||||
// Check that the contents of the hashed passwords are identical. Note
|
||||
|
|
@ -412,7 +412,7 @@ func compareSHAHelper(hasher string, raw []byte, hash []byte) error {
|
|||
|
||||
switch hasher {
|
||||
case "sha1":
|
||||
sum := sha1.Sum(raw) // #nosec G401 - compatibility for imported passwords
|
||||
sum := sha1.Sum(raw) //#nosec G401 -- compatibility for imported passwords
|
||||
sha = sum[:]
|
||||
case "sha256":
|
||||
sum := sha256.Sum256(raw)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha1" // #nosec G505 - compatibility for imported passwords
|
||||
"crypto/sha1" //#nosec G505 -- compatibility for imported passwords
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@ func (i *Identity) AfterEagerFind(tx *pop.Connection) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if err := i.validate(); err != nil {
|
||||
if err := i.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -378,7 +378,7 @@ func (i WithCredentialsMetadataAndAdminMetadataInJSON) MarshalJSON() ([]byte, er
|
|||
return json.Marshal(localIdentity(i))
|
||||
}
|
||||
|
||||
func (i *Identity) validate() error {
|
||||
func (i *Identity) Validate() error {
|
||||
expected := i.NID
|
||||
if expected == uuid.Nil {
|
||||
return errors.WithStack(herodot.ErrInternalServerError.WithReason("Received empty nid."))
|
||||
|
|
|
|||
|
|
@ -267,7 +267,7 @@ func TestValidateNID(t *testing.T) {
|
|||
},
|
||||
} {
|
||||
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
|
||||
err := tc.i.validate()
|
||||
err := tc.i.Validate()
|
||||
if tc.expectedErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -77,5 +77,8 @@ type (
|
|||
|
||||
// HydrateIdentityAssociations hydrates the associations of an identity.
|
||||
HydrateIdentityAssociations(ctx context.Context, i *Identity, expandables Expandables) error
|
||||
|
||||
// InjectTraitsSchemaURL sets the identity's traits JSON schema URL from the schema's ID.
|
||||
InjectTraitsSchemaURL(ctx context.Context, i *Identity) error
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -41,9 +41,7 @@ import (
|
|||
"github.com/ory/kratos/x"
|
||||
)
|
||||
|
||||
func TestPool(ctx context.Context, conf *config.Config, p interface {
|
||||
persistence.Persister
|
||||
}, m *identity.Manager) func(t *testing.T) {
|
||||
func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, m *identity.Manager) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
exampleServerURL := urlx.ParseOrPanic("http://example.com")
|
||||
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, exampleServerURL.String())
|
||||
|
|
@ -134,7 +132,6 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
|
|||
assert.Empty(t, actual.RecoveryAddresses)
|
||||
assert.Empty(t, actual.VerifiableAddresses)
|
||||
|
||||
require.Len(t, actual.InternalCredentials, 2)
|
||||
require.Len(t, actual.Credentials, 2)
|
||||
|
||||
assertx.EqualAsJSONExcept(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword], []string{"updated_at", "created_at"})
|
||||
|
|
@ -181,7 +178,6 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
|
|||
t.Run("expand=everything", func(t *testing.T) {
|
||||
runner(t, identity.ExpandEverything, func(t *testing.T, actual *identity.Identity) {
|
||||
|
||||
require.Len(t, actual.InternalCredentials, 2)
|
||||
require.Len(t, actual.Credentials, 2)
|
||||
|
||||
assertx.EqualAsJSONExcept(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword], []string{"updated_at", "created_at"})
|
||||
|
|
@ -199,7 +195,6 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
|
|||
runner(t, identity.ExpandNothing, func(t *testing.T, actual *identity.Identity) {
|
||||
require.NoError(t, p.HydrateIdentityAssociations(ctx, actual, identity.ExpandEverything))
|
||||
|
||||
require.Len(t, actual.InternalCredentials, 2)
|
||||
require.Len(t, actual.Credentials, 2)
|
||||
|
||||
assertx.EqualAsJSONExcept(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword], []string{"updated_at", "created_at"})
|
||||
|
|
|
|||
|
|
@ -275,7 +275,7 @@ func AssertRegistrationRespectsValidation(t *testing.T, reg *driver.RegistryDefa
|
|||
})
|
||||
}
|
||||
|
||||
func AssertCommonErrorCases(t *testing.T, reg *driver.RegistryDefault, flows []string) {
|
||||
func AssertCommonErrorCases(t *testing.T, flows []string) {
|
||||
ctx := context.Background()
|
||||
conf, reg := internal.NewFastRegistryWithMocks(t)
|
||||
testhelpers.SetDefaultIdentitySchemaFromRaw(conf, basicSchema)
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ func CheckE2EServerOnHTTP(t *testing.T, publicPort, adminPort int) (publicUrl, a
|
|||
|
||||
func waitToComeAlive(t *testing.T, publicUrl, adminUrl string) {
|
||||
require.NoError(t, retry.Do(func() error {
|
||||
/* #nosec G402: TLS InsecureSkipVerify set true. */
|
||||
//#nosec G402 -- TLS InsecureSkipVerify set true
|
||||
tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
client := &http.Client{Transport: tr}
|
||||
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ func NewSettingsAPIServer(t *testing.T, reg *driver.RegistryDefault, ids map[str
|
|||
|
||||
reg.Config().MustSet(ctx, config.ViperKeyPublicBaseURL, tsp.URL)
|
||||
reg.Config().MustSet(ctx, config.ViperKeyAdminBaseURL, tsa.URL)
|
||||
// #nosec G112
|
||||
//#nosec G112
|
||||
return tsp, tsa, AddAndLoginIdentities(t, reg, &httptest.Server{Config: &http.Server{Handler: public}, URL: tsp.URL}, ids)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ type Persister interface {
|
|||
MigrateDown(c context.Context, steps int) error
|
||||
MigrateUp(c context.Context) error
|
||||
Migrator() *popx.Migrator
|
||||
MigrationBox() *popx.MigrationBox
|
||||
GetConnection(ctx context.Context) *pop.Connection
|
||||
Transaction(ctx context.Context, callback func(ctx context.Context, connection *pop.Connection) error) error
|
||||
Networker
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright © 2023 Ory Corp
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package devices
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gobuffalo/pop/v6"
|
||||
"github.com/gofrs/uuid"
|
||||
|
||||
"github.com/ory/kratos/session"
|
||||
"github.com/ory/x/contextx"
|
||||
"github.com/ory/x/popx"
|
||||
"github.com/ory/x/sqlcon"
|
||||
)
|
||||
|
||||
var _ session.DevicePersister = (*DevicePersister)(nil)
|
||||
|
||||
type DevicePersister struct {
|
||||
ctxer contextx.Provider
|
||||
c *pop.Connection
|
||||
nid uuid.UUID
|
||||
}
|
||||
|
||||
func NewPersister(r contextx.Provider, c *pop.Connection) *DevicePersister {
|
||||
return &DevicePersister{
|
||||
ctxer: r,
|
||||
c: c,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DevicePersister) NetworkID(ctx context.Context) uuid.UUID {
|
||||
return p.ctxer.Contextualizer().Network(ctx, p.nid)
|
||||
}
|
||||
|
||||
func (p DevicePersister) WithNetworkID(nid uuid.UUID) session.DevicePersister {
|
||||
p.nid = nid
|
||||
return &p
|
||||
}
|
||||
|
||||
func (p *DevicePersister) CreateDevice(ctx context.Context, d *session.Device) error {
|
||||
d.NID = p.NetworkID(ctx)
|
||||
return sqlcon.HandleError(popx.GetConnection(ctx, p.c.WithContext(ctx)).Create(d))
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright © 2023 Ory Corp
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package sql
|
||||
package identity
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
|
@ -10,7 +10,9 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ory/x/contextx"
|
||||
"github.com/ory/x/pointerx"
|
||||
"github.com/ory/x/popx"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
|
|
@ -21,7 +23,11 @@ import (
|
|||
"github.com/ory/jsonschema/v3"
|
||||
"github.com/ory/x/sqlxx"
|
||||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/otp"
|
||||
"github.com/ory/kratos/persistence/sql/update"
|
||||
"github.com/ory/kratos/schema"
|
||||
"github.com/ory/kratos/x"
|
||||
|
||||
"github.com/gobuffalo/pop/v6"
|
||||
|
|
@ -31,40 +37,81 @@ import (
|
|||
"github.com/ory/herodot"
|
||||
"github.com/ory/x/errorsx"
|
||||
"github.com/ory/x/sqlcon"
|
||||
|
||||
"github.com/ory/kratos/identity"
|
||||
)
|
||||
|
||||
var _ identity.Pool = new(Persister)
|
||||
var _ identity.PrivilegedPool = new(Persister)
|
||||
var _ identity.Pool = new(IdentityPersister)
|
||||
var _ identity.PrivilegedPool = new(IdentityPersister)
|
||||
|
||||
func (p *Persister) ListVerifiableAddresses(ctx context.Context, page, itemsPerPage int) (a []identity.VerifiableAddress, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListVerifiableAddresses")
|
||||
defer span.End()
|
||||
|
||||
if err := p.GetConnection(ctx).Where("nid = ?", p.NetworkID(ctx)).Order("id DESC").Paginate(page, x.MaxItemsPerPage(itemsPerPage)).All(&a); err != nil {
|
||||
return nil, sqlcon.HandleError(err)
|
||||
}
|
||||
|
||||
return a, err
|
||||
type dependencies interface {
|
||||
schema.IdentityTraitsProvider
|
||||
identity.ValidationProvider
|
||||
x.LoggingProvider
|
||||
config.Provider
|
||||
contextx.Provider
|
||||
x.TracingProvider
|
||||
}
|
||||
|
||||
func (p *Persister) ListRecoveryAddresses(ctx context.Context, page, itemsPerPage int) (a []identity.RecoveryAddress, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListRecoveryAddresses")
|
||||
defer span.End()
|
||||
type IdentityPersister struct {
|
||||
r dependencies
|
||||
c *pop.Connection
|
||||
nid uuid.UUID
|
||||
}
|
||||
|
||||
func NewPersister(r dependencies, c *pop.Connection) *IdentityPersister {
|
||||
return &IdentityPersister{
|
||||
c: c,
|
||||
r: r,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *IdentityPersister) NetworkID(ctx context.Context) uuid.UUID {
|
||||
return p.r.Contextualizer().Network(ctx, p.nid)
|
||||
}
|
||||
|
||||
func (p IdentityPersister) WithNetworkID(nid uuid.UUID) identity.PrivilegedPool {
|
||||
p.nid = nid
|
||||
return &p
|
||||
}
|
||||
|
||||
func WithTransaction(ctx context.Context, tx *pop.Connection) context.Context {
|
||||
return popx.WithTransaction(ctx, tx)
|
||||
}
|
||||
|
||||
func (p *IdentityPersister) Transaction(ctx context.Context, callback func(ctx context.Context, connection *pop.Connection) error) error {
|
||||
return popx.Transaction(ctx, p.c.WithContext(ctx), callback)
|
||||
}
|
||||
|
||||
func (p *IdentityPersister) GetConnection(ctx context.Context) *pop.Connection {
|
||||
return popx.GetConnection(ctx, p.c.WithContext(ctx))
|
||||
}
|
||||
|
||||
func (p *IdentityPersister) ListVerifiableAddresses(ctx context.Context, page, itemsPerPage int) (a []identity.VerifiableAddress, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListVerifiableAddresses")
|
||||
otelx.End(span, &err)
|
||||
|
||||
if err := p.GetConnection(ctx).Where("nid = ?", p.NetworkID(ctx)).Order("id DESC").Paginate(page, x.MaxItemsPerPage(itemsPerPage)).All(&a); err != nil {
|
||||
return nil, sqlcon.HandleError(err)
|
||||
}
|
||||
|
||||
return a, err
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (p *IdentityPersister) ListRecoveryAddresses(ctx context.Context, page, itemsPerPage int) (a []identity.RecoveryAddress, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListRecoveryAddresses")
|
||||
otelx.End(span, &err)
|
||||
|
||||
if err := p.GetConnection(ctx).Where("nid = ?", p.NetworkID(ctx)).Order("id DESC").Paginate(page, x.MaxItemsPerPage(itemsPerPage)).All(&a); err != nil {
|
||||
return nil, sqlcon.HandleError(err)
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func stringToLowerTrim(match string) string {
|
||||
return strings.ToLower(strings.TrimSpace(match))
|
||||
}
|
||||
|
||||
func (p *Persister) normalizeIdentifier(ct identity.CredentialsType, match string) string {
|
||||
func NormalizeIdentifier(ct identity.CredentialsType, match string) string {
|
||||
switch ct {
|
||||
case identity.CredentialsTypeLookup:
|
||||
// lookup credentials are case-sensitive
|
||||
|
|
@ -83,9 +130,9 @@ func (p *Persister) normalizeIdentifier(ct identity.CredentialsType, match strin
|
|||
return match
|
||||
}
|
||||
|
||||
func (p *Persister) FindByCredentialsIdentifier(ctx context.Context, ct identity.CredentialsType, match string) (*identity.Identity, *identity.Credentials, error) {
|
||||
func (p *IdentityPersister) FindByCredentialsIdentifier(ctx context.Context, ct identity.CredentialsType, match string) (_ *identity.Identity, _ *identity.Credentials, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindByCredentialsIdentifier")
|
||||
defer span.End()
|
||||
otelx.End(span, &err)
|
||||
|
||||
nid := p.NetworkID(ctx)
|
||||
|
||||
|
|
@ -94,22 +141,21 @@ func (p *Persister) FindByCredentialsIdentifier(ctx context.Context, ct identity
|
|||
}
|
||||
|
||||
// Force case-insensitivity and trimming for identifiers
|
||||
match = p.normalizeIdentifier(ct, match)
|
||||
match = NormalizeIdentifier(ct, match)
|
||||
|
||||
// #nosec G201
|
||||
if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(`SELECT
|
||||
ic.identity_id
|
||||
FROM %s ic
|
||||
INNER JOIN %s ict on ic.identity_credential_type_id = ict.id
|
||||
INNER JOIN %s ici on ic.id = ici.identity_credential_id AND ici.identity_credential_type_id = ict.id
|
||||
WHERE ici.identifier = ?
|
||||
AND ic.nid = ?
|
||||
AND ici.nid = ?
|
||||
AND ict.name = ?`,
|
||||
"identity_credentials",
|
||||
"identity_credential_types",
|
||||
"identity_credential_identifiers",
|
||||
),
|
||||
if err := p.GetConnection(ctx).RawQuery(`
|
||||
SELECT
|
||||
ic.identity_id
|
||||
FROM identity_credentials ic
|
||||
INNER JOIN identity_credential_types ict
|
||||
ON ic.identity_credential_type_id = ict.id
|
||||
INNER JOIN identity_credential_identifiers ici
|
||||
ON ic.id = ici.identity_credential_id AND ici.identity_credential_type_id = ict.id
|
||||
WHERE ici.identifier = ?
|
||||
AND ic.nid = ?
|
||||
AND ici.nid = ?
|
||||
AND ict.name = ?
|
||||
LIMIT 1`, // pop doesn't understand how to add a limit clause to this query
|
||||
match,
|
||||
nid,
|
||||
nid,
|
||||
|
|
@ -135,9 +181,9 @@ WHERE ici.identifier = ?
|
|||
return i.CopyWithoutCredentials(), creds, nil
|
||||
}
|
||||
|
||||
func (p *Persister) findIdentityCredentialsType(ctx context.Context, ct identity.CredentialsType) (*identity.CredentialsTypeTable, error) {
|
||||
func (p *IdentityPersister) findIdentityCredentialsType(ctx context.Context, ct identity.CredentialsType) (_ *identity.CredentialsTypeTable, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findIdentityCredentialsType")
|
||||
defer span.End()
|
||||
otelx.End(span, &err)
|
||||
|
||||
var m identity.CredentialsTypeTable
|
||||
if err := p.GetConnection(ctx).Where("name = ?", ct).First(&m); err != nil {
|
||||
|
|
@ -146,9 +192,9 @@ func (p *Persister) findIdentityCredentialsType(ctx context.Context, ct identity
|
|||
return &m, nil
|
||||
}
|
||||
|
||||
func (p *Persister) createIdentityCredentials(ctx context.Context, i *identity.Identity) error {
|
||||
func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, i *identity.Identity) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createIdentityCredentials")
|
||||
defer span.End()
|
||||
otelx.End(span, &err)
|
||||
|
||||
c := p.GetConnection(ctx)
|
||||
|
||||
|
|
@ -174,7 +220,7 @@ func (p *Persister) createIdentityCredentials(ctx context.Context, i *identity.I
|
|||
|
||||
for _, ids := range cred.Identifiers {
|
||||
// Force case-insensitivity and trimming for identifiers
|
||||
ids = p.normalizeIdentifier(cred.Type, ids)
|
||||
ids = NormalizeIdentifier(cred.Type, ids)
|
||||
|
||||
if len(ids) == 0 {
|
||||
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to create identity credentials with missing or empty identifier."))
|
||||
|
|
@ -196,9 +242,9 @@ func (p *Persister) createIdentityCredentials(ctx context.Context, i *identity.I
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Persister) createVerifiableAddresses(ctx context.Context, i *identity.Identity) error {
|
||||
func (p *IdentityPersister) createVerifiableAddresses(ctx context.Context, i *identity.Identity) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createVerifiableAddresses")
|
||||
defer span.End()
|
||||
otelx.End(span, &err)
|
||||
|
||||
for k := range i.VerifiableAddresses {
|
||||
if err := p.GetConnection(ctx).Create(&i.VerifiableAddresses[k]); err != nil {
|
||||
|
|
@ -210,7 +256,10 @@ func (p *Persister) createVerifiableAddresses(ctx context.Context, i *identity.I
|
|||
|
||||
func updateAssociation[T interface {
|
||||
Hash() string
|
||||
}](ctx context.Context, p *Persister, i *identity.Identity, inID []T) error {
|
||||
}](ctx context.Context, p *IdentityPersister, i *identity.Identity, inID []T) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.updateAssociation")
|
||||
otelx.End(span, &err)
|
||||
|
||||
var inDB []T
|
||||
if err := p.GetConnection(ctx).
|
||||
Where("identity_id = ? AND nid = ?", i.ID, p.NetworkID(ctx)).
|
||||
|
|
@ -252,12 +301,12 @@ func updateAssociation[T interface {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Persister) normalizeAllAddressess(ctx context.Context, id *identity.Identity) {
|
||||
func (p *IdentityPersister) normalizeAllAddressess(ctx context.Context, id *identity.Identity) {
|
||||
p.normalizeRecoveryAddresses(ctx, id)
|
||||
p.normalizeVerifiableAddresses(ctx, id)
|
||||
}
|
||||
|
||||
func (p *Persister) normalizeVerifiableAddresses(ctx context.Context, id *identity.Identity) {
|
||||
func (p *IdentityPersister) normalizeVerifiableAddresses(ctx context.Context, id *identity.Identity) {
|
||||
for k := range id.VerifiableAddresses {
|
||||
v := id.VerifiableAddresses[k]
|
||||
|
||||
|
|
@ -282,7 +331,7 @@ func (p *Persister) normalizeVerifiableAddresses(ctx context.Context, id *identi
|
|||
}
|
||||
}
|
||||
|
||||
func (p *Persister) normalizeRecoveryAddresses(ctx context.Context, id *identity.Identity) {
|
||||
func (p *IdentityPersister) normalizeRecoveryAddresses(ctx context.Context, id *identity.Identity) {
|
||||
for k := range id.RecoveryAddresses {
|
||||
id.RecoveryAddresses[k].IdentityID = id.ID
|
||||
id.RecoveryAddresses[k].NID = p.NetworkID(ctx)
|
||||
|
|
@ -291,9 +340,9 @@ func (p *Persister) normalizeRecoveryAddresses(ctx context.Context, id *identity
|
|||
}
|
||||
}
|
||||
|
||||
func (p *Persister) createRecoveryAddresses(ctx context.Context, i *identity.Identity) error {
|
||||
func (p *IdentityPersister) createRecoveryAddresses(ctx context.Context, i *identity.Identity) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createRecoveryAddresses")
|
||||
defer span.End()
|
||||
otelx.End(span, &err)
|
||||
|
||||
for k := range i.RecoveryAddresses {
|
||||
if err := p.GetConnection(ctx).Create(&i.RecoveryAddresses[k]); err != nil {
|
||||
|
|
@ -303,9 +352,9 @@ func (p *Persister) createRecoveryAddresses(ctx context.Context, i *identity.Ide
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Persister) CountIdentities(ctx context.Context) (int64, error) {
|
||||
func (p *IdentityPersister) CountIdentities(ctx context.Context) (n int64, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CountIdentities")
|
||||
defer span.End()
|
||||
otelx.End(span, &err)
|
||||
|
||||
count, err := p.c.WithContext(ctx).Where("nid = ?", p.NetworkID(ctx)).Count(new(identity.Identity))
|
||||
if err != nil {
|
||||
|
|
@ -314,9 +363,9 @@ func (p *Persister) CountIdentities(ctx context.Context) (int64, error) {
|
|||
return int64(count), nil
|
||||
}
|
||||
|
||||
func (p *Persister) CreateIdentity(ctx context.Context, i *identity.Identity) error {
|
||||
func (p *IdentityPersister) CreateIdentity(ctx context.Context, i *identity.Identity) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateIdentity")
|
||||
defer span.End()
|
||||
otelx.End(span, &err)
|
||||
|
||||
i.NID = p.NetworkID(ctx)
|
||||
|
||||
|
|
@ -334,7 +383,7 @@ func (p *Persister) CreateIdentity(ctx context.Context, i *identity.Identity) er
|
|||
i.Traits = identity.Traits("{}")
|
||||
}
|
||||
|
||||
if err := p.injectTraitsSchemaURL(ctx, i); err != nil {
|
||||
if err := p.InjectTraitsSchemaURL(ctx, i); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -361,156 +410,18 @@ func (p *Persister) CreateIdentity(ctx context.Context, i *identity.Identity) er
|
|||
})
|
||||
}
|
||||
|
||||
func (p *Persister) HydrateIdentityAssociations(ctx context.Context, i *identity.Identity, expand identity.Expandables) (err error) {
|
||||
func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *identity.Identity, expand identity.Expandables) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.HydrateIdentityAssociations")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
con := p.GetConnection(ctx)
|
||||
if err := con.Load(i, expand.ToEager()...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := i.AfterEagerFind(con); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.injectTraitsSchemaURL(ctx, i)
|
||||
}
|
||||
|
||||
func (p *Persister) ListIdentities(ctx context.Context, params identity.ListIdentityParameters) (res []identity.Identity, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListIdentities")
|
||||
defer otelx.End(span, &err)
|
||||
span.SetAttributes(
|
||||
attribute.Int("page", params.Page),
|
||||
attribute.Int("per_page", params.PerPage),
|
||||
attribute.StringSlice("expand", params.Expand.ToEager()),
|
||||
attribute.Bool("use:credential_identifier_filter", params.CredentialsIdentifier != ""),
|
||||
attribute.String("network.id", p.NetworkID(ctx).String()),
|
||||
)
|
||||
|
||||
is := make([]identity.Identity, 0)
|
||||
|
||||
con := p.GetConnection(ctx)
|
||||
nid := p.NetworkID(ctx)
|
||||
query := con.Where("identities.nid = ?", nid).Order("identities.id DESC")
|
||||
|
||||
if len(params.Expand) > 0 {
|
||||
query = query.EagerPreload(params.Expand.ToEager()...)
|
||||
}
|
||||
|
||||
if match := params.CredentialsIdentifier; len(match) > 0 {
|
||||
// When filtering by credentials identifier, we most likely are looking for a username or email. It is therefore
|
||||
// important to normalize the identifier before querying the database.
|
||||
match = p.normalizeIdentifier(identity.CredentialsTypePassword, match)
|
||||
query = query.
|
||||
InnerJoin("identity_credentials ic", "ic.identity_id = identities.id").
|
||||
InnerJoin("identity_credential_types ict", "ict.id = ic.identity_credential_type_id").
|
||||
InnerJoin("identity_credential_identifiers ici", "ici.identity_credential_id = ic.id").
|
||||
Where("(ic.nid = ? AND ici.nid = ? AND ici.identifier = ?)", nid, nid, match).
|
||||
Where("ict.name IN (?)", identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword).
|
||||
Limit(1)
|
||||
} else {
|
||||
query = query.Paginate(params.Page, params.PerPage)
|
||||
}
|
||||
|
||||
/* #nosec G201 TableName is static */
|
||||
if err := sqlcon.HandleError(query.All(&is)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
schemaCache := map[string]string{}
|
||||
for k := range is {
|
||||
i := &is[k]
|
||||
|
||||
if u, ok := schemaCache[i.SchemaID]; ok {
|
||||
i.SchemaURL = u
|
||||
} else {
|
||||
if err := p.injectTraitsSchemaURL(ctx, i); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
schemaCache[i.SchemaID] = i.SchemaURL
|
||||
}
|
||||
|
||||
is[k] = *i
|
||||
}
|
||||
|
||||
return is, nil
|
||||
}
|
||||
|
||||
func (p *Persister) UpdateIdentity(ctx context.Context, i *identity.Identity) error {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateIdentity")
|
||||
defer span.End()
|
||||
|
||||
if err := p.validateIdentity(ctx, i); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
i.NID = p.NetworkID(ctx)
|
||||
return sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
|
||||
if count, err := tx.Where("id = ? AND nid = ?", i.ID, p.NetworkID(ctx)).Count(i); err != nil {
|
||||
return err
|
||||
} else if count == 0 {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
p.normalizeAllAddressess(ctx, i)
|
||||
if err := updateAssociation(ctx, p, i, i.RecoveryAddresses); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := updateAssociation(ctx, p, i, i.VerifiableAddresses); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
/* #nosec G201 TableName is static */
|
||||
if err := tx.RawQuery(
|
||||
fmt.Sprintf(
|
||||
`DELETE FROM %s WHERE identity_id = ? AND nid = ?`,
|
||||
new(identity.Credentials).TableName(ctx)),
|
||||
i.ID, p.NetworkID(ctx)).Exec(); err != nil {
|
||||
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
|
||||
if err := p.update(WithTransaction(ctx, tx), i); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.createIdentityCredentials(ctx, i)
|
||||
}))
|
||||
}
|
||||
|
||||
func (p *Persister) DeleteIdentity(ctx context.Context, id uuid.UUID) error {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteIdentity")
|
||||
defer span.End()
|
||||
|
||||
return p.delete(ctx, new(identity.Identity), id)
|
||||
}
|
||||
|
||||
func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID, expand identity.Expandables) (res *identity.Identity, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetIdentity")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
span.SetAttributes(
|
||||
attribute.String("identity.id", id.String()),
|
||||
attribute.StringSlice("expand", expand.ToEager()),
|
||||
attribute.String("network.id", p.NetworkID(ctx).String()),
|
||||
)
|
||||
|
||||
nid := p.NetworkID(ctx)
|
||||
con := p.GetConnection(ctx)
|
||||
query := con.Where("id = ? AND nid = ?", id, nid)
|
||||
|
||||
var (
|
||||
i identity.Identity
|
||||
con = p.GetConnection(ctx)
|
||||
nid = p.NetworkID(ctx)
|
||||
credentials []identity.Credentials
|
||||
verifiableAddresses []identity.VerifiableAddress
|
||||
recoveryAddresses []identity.RecoveryAddress
|
||||
)
|
||||
|
||||
if err := query.First(&i); err != nil {
|
||||
return nil, sqlcon.HandleError(err)
|
||||
}
|
||||
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
if expand.Has(identity.ExpandFieldRecoveryAddresses) {
|
||||
eg.Go(func() error {
|
||||
|
|
@ -560,7 +471,7 @@ func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID, expand identi
|
|||
}
|
||||
|
||||
if err := eg.Wait(); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
i.VerifiableAddresses = verifiableAddresses
|
||||
|
|
@ -568,26 +479,164 @@ func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID, expand identi
|
|||
i.InternalCredentials = credentials
|
||||
|
||||
if err := i.AfterEagerFind(con); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.InjectTraitsSchemaURL(ctx, i)
|
||||
}
|
||||
|
||||
func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.ListIdentityParameters) (res []identity.Identity, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListIdentities")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
span.SetAttributes(
|
||||
attribute.Int("page", params.Page),
|
||||
attribute.Int("per_page", params.PerPage),
|
||||
attribute.StringSlice("expand", params.Expand.ToEager()),
|
||||
attribute.Bool("use:credential_identifier_filter", params.CredentialsIdentifier != ""),
|
||||
attribute.String("network.id", p.NetworkID(ctx).String()),
|
||||
)
|
||||
|
||||
is := make([]identity.Identity, 0)
|
||||
|
||||
con := p.GetConnection(ctx)
|
||||
nid := p.NetworkID(ctx)
|
||||
query := con.Where("identities.nid = ?", nid).Order("identities.id DESC")
|
||||
|
||||
if len(params.Expand) > 0 {
|
||||
query = query.EagerPreload(params.Expand.ToEager()...)
|
||||
}
|
||||
|
||||
if match := params.CredentialsIdentifier; len(match) > 0 {
|
||||
// When filtering by credentials identifier, we most likely are looking for a username or email. It is therefore
|
||||
// important to normalize the identifier before querying the database.
|
||||
match = NormalizeIdentifier(identity.CredentialsTypePassword, match)
|
||||
query = query.
|
||||
InnerJoin("identity_credentials ic", "ic.identity_id = identities.id").
|
||||
InnerJoin("identity_credential_types ict", "ict.id = ic.identity_credential_type_id").
|
||||
InnerJoin("identity_credential_identifiers ici", "ici.identity_credential_id = ic.id").
|
||||
Where("(ic.nid = ? AND ici.nid = ? AND ici.identifier = ?)", nid, nid, match).
|
||||
Where("ict.name IN (?)", identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword).
|
||||
Limit(1)
|
||||
} else {
|
||||
query = query.Paginate(params.Page, params.PerPage)
|
||||
}
|
||||
|
||||
if err := sqlcon.HandleError(query.All(&is)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.injectTraitsSchemaURL(ctx, &i); err != nil {
|
||||
schemaCache := map[string]string{}
|
||||
for k := range is {
|
||||
i := &is[k]
|
||||
|
||||
if u, ok := schemaCache[i.SchemaID]; ok {
|
||||
i.SchemaURL = u
|
||||
} else {
|
||||
if err := p.InjectTraitsSchemaURL(ctx, i); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
schemaCache[i.SchemaID] = i.SchemaURL
|
||||
}
|
||||
|
||||
is[k] = *i
|
||||
}
|
||||
|
||||
return is, nil
|
||||
}
|
||||
|
||||
func (p *IdentityPersister) UpdateIdentity(ctx context.Context, i *identity.Identity) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateIdentity")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
if err := p.validateIdentity(ctx, i); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
i.NID = p.NetworkID(ctx)
|
||||
return sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
|
||||
if count, err := tx.Where("id = ? AND nid = ?", i.ID, p.NetworkID(ctx)).Count(i); err != nil {
|
||||
return err
|
||||
} else if count == 0 {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
p.normalizeAllAddressess(ctx, i)
|
||||
if err := updateAssociation(ctx, p, i, i.RecoveryAddresses); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := updateAssociation(ctx, p, i, i.VerifiableAddresses); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//#nosec G201 -- TableName is static
|
||||
if err := tx.RawQuery(
|
||||
fmt.Sprintf(
|
||||
`DELETE FROM %s WHERE identity_id = ? AND nid = ?`,
|
||||
new(identity.Credentials).TableName(ctx)),
|
||||
i.ID, p.NetworkID(ctx)).Exec(); err != nil {
|
||||
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
|
||||
if err := update.Generic(WithTransaction(ctx, tx), tx, p.r.Tracer(ctx).Tracer(), i); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.createIdentityCredentials(ctx, i)
|
||||
}))
|
||||
}
|
||||
|
||||
func (p *IdentityPersister) DeleteIdentity(ctx context.Context, id uuid.UUID) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteIdentity")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
nid := p.NetworkID(ctx)
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ? AND nid = ?", new(identity.Identity).TableName(ctx)),
|
||||
id,
|
||||
nid,
|
||||
).ExecWithCount()
|
||||
if err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
if count == 0 {
|
||||
return errors.WithStack(sqlcon.ErrNoRows)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *IdentityPersister) GetIdentity(ctx context.Context, id uuid.UUID, expand identity.Expandables) (_ *identity.Identity, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetIdentity")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
span.SetAttributes(
|
||||
attribute.String("identity.id", id.String()),
|
||||
attribute.StringSlice("expand", expand.ToEager()),
|
||||
attribute.String("network.id", p.NetworkID(ctx).String()),
|
||||
)
|
||||
|
||||
var i identity.Identity
|
||||
if err := p.GetConnection(ctx).Where("id = ? AND nid = ?", id, p.NetworkID(ctx)).First(&i); err != nil {
|
||||
return nil, sqlcon.HandleError(err)
|
||||
}
|
||||
|
||||
if err := p.HydrateIdentityAssociations(ctx, &i, expand); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &i, nil
|
||||
}
|
||||
|
||||
func (p *Persister) GetIdentityConfidential(ctx context.Context, id uuid.UUID) (res *identity.Identity, err error) {
|
||||
func (p *IdentityPersister) GetIdentityConfidential(ctx context.Context, id uuid.UUID) (res *identity.Identity, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetIdentityConfidential")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
return p.GetIdentity(ctx, id, identity.ExpandEverything)
|
||||
}
|
||||
|
||||
func (p *Persister) FindVerifiableAddressByValue(ctx context.Context, via identity.VerifiableAddressType, value string) (*identity.VerifiableAddress, error) {
|
||||
func (p *IdentityPersister) FindVerifiableAddressByValue(ctx context.Context, via identity.VerifiableAddressType, value string) (_ *identity.VerifiableAddress, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindVerifiableAddressByValue")
|
||||
defer span.End()
|
||||
otelx.End(span, &err)
|
||||
|
||||
var address identity.VerifiableAddress
|
||||
if err := p.GetConnection(ctx).Where("nid = ? AND via = ? AND value = ?", p.NetworkID(ctx), via, stringToLowerTrim(value)).First(&address); err != nil {
|
||||
|
|
@ -597,9 +646,9 @@ func (p *Persister) FindVerifiableAddressByValue(ctx context.Context, via identi
|
|||
return &address, nil
|
||||
}
|
||||
|
||||
func (p *Persister) FindRecoveryAddressByValue(ctx context.Context, via identity.RecoveryAddressType, value string) (*identity.RecoveryAddress, error) {
|
||||
func (p *IdentityPersister) FindRecoveryAddressByValue(ctx context.Context, via identity.RecoveryAddressType, value string) (_ *identity.RecoveryAddress, err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindRecoveryAddressByValue")
|
||||
defer span.End()
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
var address identity.RecoveryAddress
|
||||
if err := p.GetConnection(ctx).Where("nid = ? AND via = ? AND value = ?", p.NetworkID(ctx), via, stringToLowerTrim(value)).First(&address); err != nil {
|
||||
|
|
@ -609,16 +658,17 @@ func (p *Persister) FindRecoveryAddressByValue(ctx context.Context, via identity
|
|||
return &address, nil
|
||||
}
|
||||
|
||||
func (p *Persister) VerifyAddress(ctx context.Context, code string) error {
|
||||
func (p *IdentityPersister) VerifyAddress(ctx context.Context, code string) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAddress")
|
||||
defer span.End()
|
||||
otelx.End(span, &err)
|
||||
|
||||
newCode, err := otp.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := p.GetConnection(ctx).RawQuery(
|
||||
/* #nosec G201 TableName is static */
|
||||
//#nosec G201 -- TableName is static
|
||||
fmt.Sprintf(
|
||||
"UPDATE %s SET status = ?, verified = true, verified_at = ?, code = ? WHERE nid = ? AND code = ? AND expires_at > ?",
|
||||
new(identity.VerifiableAddress).TableName(ctx),
|
||||
|
|
@ -641,18 +691,18 @@ func (p *Persister) VerifyAddress(ctx context.Context, code string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Persister) UpdateVerifiableAddress(ctx context.Context, address *identity.VerifiableAddress) error {
|
||||
func (p *IdentityPersister) UpdateVerifiableAddress(ctx context.Context, address *identity.VerifiableAddress) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateVerifiableAddress")
|
||||
defer span.End()
|
||||
otelx.End(span, &err)
|
||||
|
||||
address.NID = p.NetworkID(ctx)
|
||||
address.Value = stringToLowerTrim(address.Value)
|
||||
return p.update(ctx, address)
|
||||
return update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), address)
|
||||
}
|
||||
|
||||
func (p *Persister) validateIdentity(ctx context.Context, i *identity.Identity) error {
|
||||
func (p *IdentityPersister) validateIdentity(ctx context.Context, i *identity.Identity) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.validateIdentity")
|
||||
defer span.End()
|
||||
otelx.End(span, &err)
|
||||
|
||||
if err := p.r.IdentityValidator().ValidateWithRunner(ctx, i); err != nil {
|
||||
if _, ok := errorsx.Cause(err).(*jsonschema.ValidationError); ok {
|
||||
|
|
@ -664,9 +714,9 @@ func (p *Persister) validateIdentity(ctx context.Context, i *identity.Identity)
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Persister) injectTraitsSchemaURL(ctx context.Context, i *identity.Identity) error {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.injectTraitsSchemaURL")
|
||||
defer span.End()
|
||||
func (p *IdentityPersister) InjectTraitsSchemaURL(ctx context.Context, i *identity.Identity) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InjectTraitsSchemaURL")
|
||||
otelx.End(span, &err)
|
||||
|
||||
ss, err := p.r.IdentityTraitsSchemas(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -15,8 +15,6 @@ import (
|
|||
|
||||
"github.com/ory/x/servicelocatorx"
|
||||
|
||||
"github.com/ory/x/fsx"
|
||||
|
||||
"github.com/ory/kratos/identity"
|
||||
|
||||
"github.com/bradleyjkemp/cupaloy/v2"
|
||||
|
|
@ -124,7 +122,7 @@ func TestMigrations(t *testing.T) {
|
|||
|
||||
t.Run("suite=up", func(t *testing.T) {
|
||||
tm, err := popx.NewMigrationBox(
|
||||
fsx.Merge(os.DirFS("../migrations/sql")),
|
||||
os.DirFS("../migrations/sql"),
|
||||
popx.NewMigrator(c, logrusx.New("", "", logrusx.ForceLevel(logrus.DebugLevel)), nil, 1*time.Minute),
|
||||
popx.WithTestdata(t, os.DirFS("./testdata")),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,28 +6,24 @@ package sql
|
|||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ory/x/contextx"
|
||||
|
||||
"github.com/ory/x/fsx"
|
||||
|
||||
"github.com/gobuffalo/pop/v6"
|
||||
"github.com/gobuffalo/pop/v6/columns"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/laher/mergefs"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/x/networkx"
|
||||
"github.com/ory/x/sqlcon"
|
||||
|
||||
"github.com/ory/x/popx"
|
||||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/persistence"
|
||||
"github.com/ory/kratos/persistence/sql/devices"
|
||||
idpersistence "github.com/ory/kratos/persistence/sql/identity"
|
||||
"github.com/ory/kratos/schema"
|
||||
"github.com/ory/kratos/session"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/x/contextx"
|
||||
"github.com/ory/x/networkx"
|
||||
"github.com/ory/x/popx"
|
||||
)
|
||||
|
||||
var _ persistence.Persister = new(Persister)
|
||||
|
|
@ -37,34 +33,40 @@ var migrations embed.FS
|
|||
|
||||
type (
|
||||
persisterDependencies interface {
|
||||
schema.IdentityTraitsProvider
|
||||
identity.ValidationProvider
|
||||
x.LoggingProvider
|
||||
config.Provider
|
||||
contextx.Provider
|
||||
x.TracingProvider
|
||||
schema.IdentityTraitsProvider
|
||||
identity.ValidationProvider
|
||||
}
|
||||
Persister struct {
|
||||
nid uuid.UUID
|
||||
c *pop.Connection
|
||||
mb *popx.MigrationBox
|
||||
mbs popx.MigrationStatuses
|
||||
r persisterDependencies
|
||||
p *networkx.Manager
|
||||
isSQLite bool
|
||||
nid uuid.UUID
|
||||
c *pop.Connection
|
||||
mb *popx.MigrationBox
|
||||
mbs popx.MigrationStatuses
|
||||
r persisterDependencies
|
||||
p *networkx.Manager
|
||||
|
||||
identity.PrivilegedPool
|
||||
session.DevicePersister
|
||||
}
|
||||
)
|
||||
|
||||
func NewPersister(ctx context.Context, r persisterDependencies, c *pop.Connection) (*Persister, error) {
|
||||
m, err := popx.NewMigrationBox(fsx.Merge(migrations, networkx.Migrations), popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0))
|
||||
m, err := popx.NewMigrationBox(mergefs.Merge(migrations, networkx.Migrations), popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.DumpMigrations = false
|
||||
|
||||
return &Persister{
|
||||
c: c, mb: m, r: r, isSQLite: c.Dialect.Name() == "sqlite3",
|
||||
p: networkx.NewManager(c, r.Logger(), r.Tracer(ctx)),
|
||||
c: c,
|
||||
mb: m,
|
||||
r: r,
|
||||
PrivilegedPool: idpersistence.NewPersister(r, c),
|
||||
DevicePersister: devices.NewPersister(r, c),
|
||||
p: networkx.NewManager(c, r.Logger(), r.Tracer(ctx)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -72,8 +74,18 @@ func (p *Persister) NetworkID(ctx context.Context) uuid.UUID {
|
|||
return p.r.Contextualizer().Network(ctx, p.nid)
|
||||
}
|
||||
|
||||
func (p Persister) WithNetworkID(sid uuid.UUID) persistence.Persister {
|
||||
p.nid = sid
|
||||
func (p Persister) WithNetworkID(nid uuid.UUID) persistence.Persister {
|
||||
p.nid = nid
|
||||
if pp, ok := p.PrivilegedPool.(interface {
|
||||
WithNetworkID(uuid.UUID) identity.PrivilegedPool
|
||||
}); ok {
|
||||
p.PrivilegedPool = pp.WithNetworkID(nid)
|
||||
}
|
||||
if dp, ok := p.DevicePersister.(interface {
|
||||
WithNetworkID(uuid.UUID) session.DevicePersister
|
||||
}); ok {
|
||||
p.DevicePersister = dp.WithNetworkID(nid)
|
||||
}
|
||||
return &p
|
||||
}
|
||||
|
||||
|
|
@ -113,6 +125,10 @@ func (p *Persister) MigrateUp(ctx context.Context) error {
|
|||
return p.mb.Up(ctx)
|
||||
}
|
||||
|
||||
func (p *Persister) MigrationBox() *popx.MigrationBox {
|
||||
return p.mb
|
||||
}
|
||||
|
||||
func (p *Persister) Migrator() *popx.Migrator {
|
||||
return p.mb.Migrator
|
||||
}
|
||||
|
|
@ -130,15 +146,6 @@ func (p *Persister) Ping() error {
|
|||
return errors.WithStack(p.c.Store.(pinger).Ping())
|
||||
}
|
||||
|
||||
type quotable interface {
|
||||
Quote(key string) string
|
||||
}
|
||||
|
||||
type node interface {
|
||||
GetID() uuid.UUID
|
||||
GetNID() uuid.UUID
|
||||
}
|
||||
|
||||
func (p *Persister) CleanupDatabase(ctx context.Context, wait time.Duration, older time.Duration, batchSize int) error {
|
||||
currentTime := time.Now().Add(-older)
|
||||
p.r.Logger().Printf("Cleaning up records older than %s\n", currentTime)
|
||||
|
|
@ -189,81 +196,3 @@ func (p *Persister) CleanupDatabase(ctx context.Context, wait time.Duration, old
|
|||
"This should be re-run periodically, to be sure that all expired data is purged.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Persister) update(ctx context.Context, v node, columnNames ...string) error {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.update")
|
||||
defer span.End()
|
||||
|
||||
c := p.GetConnection(ctx)
|
||||
quoter, ok := c.Dialect.(quotable)
|
||||
if !ok {
|
||||
return errors.Errorf("store is not a quoter: %T", p.c.Store)
|
||||
}
|
||||
|
||||
model := pop.NewModel(v, ctx)
|
||||
tn := model.TableName()
|
||||
|
||||
cols := columns.Columns{}
|
||||
if len(columnNames) > 0 && tn == model.TableName() {
|
||||
cols = columns.NewColumnsWithAlias(tn, model.As, model.IDField())
|
||||
cols.Add(columnNames...)
|
||||
} else {
|
||||
cols = columns.ForStructWithAlias(v, tn, model.As, model.IDField())
|
||||
}
|
||||
|
||||
// #nosec
|
||||
stmt := fmt.Sprintf("SELECT COUNT(id) FROM %s AS %s WHERE %s.id = ? AND %s.nid = ?",
|
||||
quoter.Quote(model.TableName()),
|
||||
model.Alias(),
|
||||
model.Alias(),
|
||||
model.Alias(),
|
||||
)
|
||||
|
||||
var count int
|
||||
if err := c.Store.GetContext(ctx, &count, c.Dialect.TranslateSQL(stmt), v.GetID(), v.GetNID()); err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
} else if count == 0 {
|
||||
return errors.WithStack(sqlcon.ErrNoRows)
|
||||
}
|
||||
|
||||
// #nosec
|
||||
stmt = fmt.Sprintf("UPDATE %s AS %s SET %s WHERE %s AND %s.nid = :nid",
|
||||
quoter.Quote(model.TableName()),
|
||||
model.Alias(),
|
||||
cols.Writeable().QuotedUpdateString(quoter),
|
||||
model.WhereNamedID(),
|
||||
model.Alias(),
|
||||
)
|
||||
|
||||
if _, err := c.Store.NamedExecContext(ctx, stmt, v); err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Persister) delete(ctx context.Context, v interface{}, id uuid.UUID) error {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.delete")
|
||||
defer span.End()
|
||||
|
||||
nid := p.NetworkID(ctx)
|
||||
|
||||
tabler, ok := v.(interface {
|
||||
TableName(ctx context.Context) string
|
||||
})
|
||||
if !ok {
|
||||
return errors.Errorf("expected model to have TableName signature but got: %T", v)
|
||||
}
|
||||
|
||||
/* #nosec G201 TableName is static */
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ? AND nid = ?", tabler.TableName(ctx)),
|
||||
id,
|
||||
nid,
|
||||
).ExecWithCount()
|
||||
if err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
if count == 0 {
|
||||
return errors.WithStack(sqlcon.ErrNoRows)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ func (p *Persister) DeleteContinuitySession(ctx context.Context, id uuid.UUID) e
|
|||
defer span.End()
|
||||
|
||||
if count, err := p.GetConnection(ctx).RawQuery(
|
||||
// #nosec
|
||||
//#nosec G201 -- TableName is static
|
||||
fmt.Sprintf("DELETE FROM %s WHERE id=? AND nid=?",
|
||||
new(continuity.Container).TableName(ctx)), id, p.NetworkID(ctx)).ExecWithCount(); err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
|
|
@ -54,7 +54,7 @@ func (p *Persister) DeleteContinuitySession(ctx context.Context, id uuid.UUID) e
|
|||
}
|
||||
|
||||
func (p *Persister) DeleteExpiredContinuitySessions(ctx context.Context, expiresAt time.Time, limit int) error {
|
||||
// #nosec G201
|
||||
//#nosec G201 -- TableName is static
|
||||
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
new(continuity.Container).TableName(ctx),
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/ory/x/uuidx"
|
||||
|
||||
"github.com/ory/kratos/courier"
|
||||
"github.com/ory/kratos/persistence/sql/update"
|
||||
)
|
||||
|
||||
var _ courier.Persister = new(Persister)
|
||||
|
|
@ -89,7 +90,7 @@ func (p *Persister) NextMessages(ctx context.Context, limit uint8) (messages []c
|
|||
for i := range m {
|
||||
message := &m[i]
|
||||
message.Status = courier.MessageStatusProcessing
|
||||
if err := p.update(ctx, message, "status"); err != nil {
|
||||
if err := update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), message, "status"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
|
|
@ -56,9 +55,8 @@ func (p *Persister) Read(ctx context.Context, id uuid.UUID) (*errorx.ErrorContai
|
|||
return nil, sqlcon.HandleError(err)
|
||||
}
|
||||
|
||||
// #nosec G201
|
||||
if err := p.GetConnection(ctx).RawQuery(
|
||||
fmt.Sprintf("UPDATE %s SET was_seen = true, seen_at = ? WHERE id = ? AND nid = ?", "selfservice_errors"),
|
||||
"UPDATE selfservice_errors SET was_seen = true, seen_at = ? WHERE id = ? AND nid = ?",
|
||||
time.Now().UTC(), id, p.NetworkID(ctx)).Exec(); err != nil {
|
||||
return nil, sqlcon.HandleError(err)
|
||||
}
|
||||
|
|
@ -71,14 +69,12 @@ func (p *Persister) Clear(ctx context.Context, olderThan time.Duration, force bo
|
|||
defer span.End()
|
||||
|
||||
if force {
|
||||
// #nosec G201
|
||||
err = p.GetConnection(ctx).RawQuery(
|
||||
fmt.Sprintf("DELETE FROM %s WHERE nid = ? AND seen_at < ? AND seen_at IS NOT NULL", "selfservice_errors"),
|
||||
"DELETE FROM selfservice_errors WHERE nid = ? AND seen_at < ? AND seen_at IS NOT NULL",
|
||||
p.NetworkID(ctx), time.Now().UTC().Add(-olderThan)).Exec()
|
||||
} else {
|
||||
// #nosec G201
|
||||
err = p.GetConnection(ctx).RawQuery(
|
||||
fmt.Sprintf("DELETE FROM %s WHERE nid = ? AND was_seen=true AND seen_at < ? AND seen_at IS NOT NULL", "selfservice_errors"),
|
||||
"DELETE FROM selfservice_errors WHERE nid = ? AND was_seen=true AND seen_at < ? AND seen_at IS NOT NULL",
|
||||
p.NetworkID(ctx), time.Now().UTC().Add(-olderThan)).Exec()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,9 +13,6 @@ import (
|
|||
"github.com/ory/x/configx"
|
||||
"github.com/ory/x/otelx"
|
||||
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/schema"
|
||||
|
||||
"github.com/gobuffalo/pop/v6"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -23,6 +20,8 @@ import (
|
|||
"github.com/ory/x/logrusx"
|
||||
|
||||
"github.com/ory/kratos/driver/config"
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/schema"
|
||||
)
|
||||
|
||||
type logRegistryOnly struct {
|
||||
|
|
@ -39,14 +38,6 @@ func (l *logRegistryOnly) Contextualizer() contextx.Contextualizer {
|
|||
panic("implement me")
|
||||
}
|
||||
|
||||
func (l *logRegistryOnly) IdentityTraitsSchemas(ctx context.Context) (schema.Schemas, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (l *logRegistryOnly) IdentityValidator() *identity.Validator {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (l *logRegistryOnly) Logger() *logrusx.Logger {
|
||||
if l.l == nil {
|
||||
l.l = logrusx.New("kratos", "testing")
|
||||
|
|
@ -61,6 +52,13 @@ func (l *logRegistryOnly) Audit() *logrusx.Logger {
|
|||
func (l *logRegistryOnly) Tracer(ctx context.Context) *otelx.Tracer {
|
||||
return otelx.NewNoop(l.l, new(otelx.Config))
|
||||
}
|
||||
func (l *logRegistryOnly) IdentityTraitsSchemas(ctx context.Context) (schema.Schemas, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (l *logRegistryOnly) IdentityValidator() *identity.Validator {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
var _ persisterDependencies = &logRegistryOnly{}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/ory/x/sqlcon"
|
||||
|
||||
"github.com/ory/kratos/persistence/sql/update"
|
||||
"github.com/ory/kratos/selfservice/flow/login"
|
||||
)
|
||||
|
||||
|
|
@ -35,7 +36,7 @@ func (p *Persister) UpdateLoginFlow(ctx context.Context, r *login.Flow) error {
|
|||
r.EnsureInternalContext()
|
||||
cp := *r
|
||||
cp.NID = p.NetworkID(ctx)
|
||||
return p.update(ctx, cp)
|
||||
return update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), cp)
|
||||
}
|
||||
|
||||
func (p *Persister) GetLoginFlow(ctx context.Context, id uuid.UUID) (*login.Flow, error) {
|
||||
|
|
@ -68,7 +69,7 @@ func (p *Persister) ForceLoginFlow(ctx context.Context, id uuid.UUID) error {
|
|||
}
|
||||
|
||||
func (p *Persister) DeleteExpiredLoginFlows(ctx context.Context, expiresAt time.Time, limit int) error {
|
||||
// #nosec G201
|
||||
//#nosec G201 -- TableName is static
|
||||
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
new(login.Flow).TableName(ctx),
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/gofrs/uuid"
|
||||
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/persistence/sql/update"
|
||||
"github.com/ory/kratos/selfservice/flow"
|
||||
"github.com/ory/kratos/selfservice/flow/recovery"
|
||||
"github.com/ory/kratos/selfservice/strategy/code"
|
||||
|
|
@ -51,7 +52,7 @@ func (p *Persister) UpdateRecoveryFlow(ctx context.Context, r *recovery.Flow) er
|
|||
|
||||
cp := *r
|
||||
cp.NID = p.NetworkID(ctx)
|
||||
return p.update(ctx, cp)
|
||||
return update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), cp)
|
||||
}
|
||||
|
||||
func (p *Persister) CreateRecoveryToken(ctx context.Context, token *link.RecoveryToken) error {
|
||||
|
|
@ -101,7 +102,7 @@ func (p *Persister) UseRecoveryToken(ctx context.Context, fID uuid.UUID, token s
|
|||
}
|
||||
rt.RecoveryAddress = &ra
|
||||
|
||||
/* #nosec G201 TableName is static */
|
||||
//#nosec G201 -- TableName is static
|
||||
return tx.RawQuery(fmt.Sprintf("UPDATE %s SET used=true, used_at=? WHERE id=? AND nid = ?", rt.TableName(ctx)), time.Now().UTC(), rt.ID, nid).Exec()
|
||||
})); err != nil {
|
||||
return nil, err
|
||||
|
|
@ -114,12 +115,12 @@ func (p *Persister) DeleteRecoveryToken(ctx context.Context, token string) error
|
|||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRecoveryToken")
|
||||
defer span.End()
|
||||
|
||||
/* #nosec G201 TableName is static */
|
||||
//#nosec G201 -- TableName is static
|
||||
return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE token=? AND nid = ?", new(link.RecoveryToken).TableName(ctx)), token, p.NetworkID(ctx)).Exec()
|
||||
}
|
||||
|
||||
func (p *Persister) DeleteExpiredRecoveryFlows(ctx context.Context, expiresAt time.Time, limit int) error {
|
||||
// #nosec G201
|
||||
//#nosec G201 -- TableName is static
|
||||
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
new(recovery.Flow).TableName(ctx),
|
||||
|
|
@ -186,14 +187,14 @@ func (p *Persister) UseRecoveryCode(ctx context.Context, fID uuid.UUID, codeVal
|
|||
|
||||
if err := sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) {
|
||||
|
||||
/* #nosec G201 TableName is static */
|
||||
//#nosec G201 -- TableName is static
|
||||
if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName), fID, nid).Exec()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var submitCount int
|
||||
// Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on.
|
||||
/* #nosec G201 TableName is static */
|
||||
//#nosec G201 -- TableName is static
|
||||
if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName), fID, nid).First(&submitCount)); err != nil {
|
||||
if errors.Is(err, sqlcon.ErrNoRows) {
|
||||
// Return no error, as that would roll back the transaction
|
||||
|
|
@ -248,7 +249,7 @@ func (p *Persister) UseRecoveryCode(ctx context.Context, fID uuid.UUID, codeVal
|
|||
}
|
||||
recoveryCode.RecoveryAddress = &ra
|
||||
|
||||
/* #nosec G201 TableName is static */
|
||||
//#nosec G201 -- TableName is static
|
||||
return sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", recoveryCode.TableName(ctx)), time.Now().UTC(), recoveryCode.ID, nid).Exec())
|
||||
})); err != nil {
|
||||
return nil, err
|
||||
|
|
@ -273,6 +274,6 @@ func (p *Persister) DeleteRecoveryCodesOfFlow(ctx context.Context, fID uuid.UUID
|
|||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRecoveryCodesOfFlow")
|
||||
defer span.End()
|
||||
|
||||
/* #nosec G201 TableName is static */
|
||||
//#nosec G201 -- TableName is static
|
||||
return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE selfservice_recovery_flow_id = ? AND nid = ?", new(code.RecoveryCode).TableName(ctx)), fID, p.NetworkID(ctx)).Exec()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import (
|
|||
|
||||
"github.com/ory/x/sqlcon"
|
||||
|
||||
"github.com/ory/kratos/persistence/sql/update"
|
||||
"github.com/ory/kratos/selfservice/flow/registration"
|
||||
)
|
||||
|
||||
|
|
@ -31,7 +32,7 @@ func (p *Persister) UpdateRegistrationFlow(ctx context.Context, r *registration.
|
|||
r.EnsureInternalContext()
|
||||
cp := *r
|
||||
cp.NID = p.NetworkID(ctx)
|
||||
return p.update(ctx, cp)
|
||||
return update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), cp)
|
||||
}
|
||||
|
||||
func (p *Persister) GetRegistrationFlow(ctx context.Context, id uuid.UUID) (*registration.Flow, error) {
|
||||
|
|
@ -48,7 +49,7 @@ func (p *Persister) GetRegistrationFlow(ctx context.Context, id uuid.UUID) (*reg
|
|||
}
|
||||
|
||||
func (p *Persister) DeleteExpiredRegistrationFlows(ctx context.Context, expiresAt time.Time, limit int) error {
|
||||
// #nosec G201
|
||||
//#nosec G201 -- TableName is static
|
||||
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
new(registration.Flow).TableName(ctx),
|
||||
|
|
|
|||
|
|
@ -8,25 +8,17 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gobuffalo/pop/v6"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ory/x/otelx"
|
||||
|
||||
"github.com/ory/kratos/identity"
|
||||
|
||||
"github.com/ory/x/pagination/keysetpagination"
|
||||
|
||||
"github.com/ory/x/stringsx"
|
||||
|
||||
"github.com/gobuffalo/pop/v6"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
|
||||
"github.com/ory/x/sqlcon"
|
||||
|
||||
"github.com/ory/kratos/session"
|
||||
"github.com/ory/x/otelx"
|
||||
"github.com/ory/x/pagination/keysetpagination"
|
||||
"github.com/ory/x/sqlcon"
|
||||
"github.com/ory/x/stringsx"
|
||||
)
|
||||
|
||||
var _ session.Persister = new(Persister)
|
||||
|
|
@ -57,7 +49,7 @@ func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID, expandables s
|
|||
if expandables.Has(session.ExpandSessionIdentity) {
|
||||
// This is needed because of how identities are fetched from the store (if we use eager not all fields are
|
||||
// available!).
|
||||
i, err := p.GetIdentity(ctx, s.IdentityID, identity.ExpandDefault)
|
||||
i, err := p.PrivilegedPool.GetIdentity(ctx, s.IdentityID, identity.ExpandDefault)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -116,7 +108,7 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt
|
|||
if s[k].Identity == nil {
|
||||
continue
|
||||
}
|
||||
if err := p.injectTraitsSchemaURL(ctx, s[k].Identity); err != nil {
|
||||
if err := p.InjectTraitsSchemaURL(ctx, s[k].Identity); err != nil {
|
||||
return nil, 0, nil, err
|
||||
}
|
||||
}
|
||||
|
|
@ -212,7 +204,7 @@ func (p *Persister) UpsertSession(ctx context.Context, s *session.Session) (err
|
|||
device.UserAgent = stringsx.GetPointer(stringsx.TruncateByteLen(*device.UserAgent, SessionDeviceUserAgentMaxLength))
|
||||
}
|
||||
|
||||
if err := sqlcon.HandleError(tx.Create(device)); err != nil {
|
||||
if err := p.DevicePersister.CreateDevice(ctx, device); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -225,17 +217,29 @@ func (p *Persister) DeleteSession(ctx context.Context, sid uuid.UUID) (err error
|
|||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteSession")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
return p.delete(ctx, new(session.Session), sid)
|
||||
nid := p.NetworkID(ctx)
|
||||
//#nosec G201 -- TableName is static
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ? AND nid = ?", new(session.Session).TableName(ctx)),
|
||||
sid,
|
||||
nid,
|
||||
).ExecWithCount()
|
||||
if err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
if count == 0 {
|
||||
return errors.WithStack(sqlcon.ErrNoRows)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Persister) DeleteSessionsByIdentity(ctx context.Context, identityID uuid.UUID) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteSessionsByIdentity")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
// #nosec G201
|
||||
//#nosec G201 -- TableName is static
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE identity_id = ? AND nid = ?",
|
||||
"sessions",
|
||||
new(session.Session).TableName(ctx),
|
||||
),
|
||||
identityID,
|
||||
p.NetworkID(ctx),
|
||||
|
|
@ -279,7 +283,7 @@ func (p *Persister) GetSessionByToken(ctx context.Context, token string, expand
|
|||
// available!).
|
||||
if expand.Has(session.ExpandSessionIdentity) {
|
||||
eg.Go(func() (err error) {
|
||||
i, err = p.GetIdentity(ctx, s.IdentityID, identityExpand)
|
||||
i, err = p.PrivilegedPool.GetIdentity(ctx, s.IdentityID, identityExpand)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
|
@ -298,10 +302,10 @@ func (p *Persister) DeleteSessionByToken(ctx context.Context, token string) (err
|
|||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteSessionByToken")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
// #nosec G201
|
||||
//#nosec G201 -- TableName is static
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE token = ? AND nid = ?",
|
||||
"sessions",
|
||||
new(session.Session).TableName(ctx),
|
||||
),
|
||||
token,
|
||||
p.NetworkID(ctx),
|
||||
|
|
@ -319,10 +323,10 @@ func (p *Persister) RevokeSessionByToken(ctx context.Context, token string) (err
|
|||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSessionByToken")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
// #nosec G201
|
||||
//#nosec G201 -- TableName is static
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"UPDATE %s SET active = false WHERE token = ? AND nid = ?",
|
||||
"sessions",
|
||||
new(session.Session).TableName(ctx),
|
||||
),
|
||||
token,
|
||||
p.NetworkID(ctx),
|
||||
|
|
@ -341,10 +345,10 @@ func (p *Persister) RevokeSessionById(ctx context.Context, sID uuid.UUID) (err e
|
|||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSessionById")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
// #nosec G201
|
||||
//#nosec G201 -- TableName is static
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"UPDATE %s SET active = false WHERE id = ? AND nid = ?",
|
||||
"sessions",
|
||||
new(session.Session).TableName(ctx),
|
||||
),
|
||||
sID,
|
||||
p.NetworkID(ctx),
|
||||
|
|
@ -364,10 +368,10 @@ func (p *Persister) RevokeSession(ctx context.Context, iID, sID uuid.UUID) (err
|
|||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSession")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
// #nosec G201
|
||||
//#nosec G201 -- TableName is static
|
||||
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"UPDATE %s SET active = false WHERE id = ? AND identity_id = ? AND nid = ?",
|
||||
"sessions",
|
||||
new(session.Session).TableName(ctx),
|
||||
),
|
||||
sID,
|
||||
iID,
|
||||
|
|
@ -384,10 +388,10 @@ func (p *Persister) RevokeSessionsIdentityExcept(ctx context.Context, iID, sID u
|
|||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSessionsIdentityExcept")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
// #nosec G201
|
||||
//#nosec G201 -- TableName is static
|
||||
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"UPDATE %s SET active = false WHERE identity_id = ? AND id != ? AND nid = ?",
|
||||
"sessions",
|
||||
new(session.Session).TableName(ctx),
|
||||
),
|
||||
iID,
|
||||
sID,
|
||||
|
|
@ -402,10 +406,12 @@ func (p *Persister) RevokeSessionsIdentityExcept(ctx context.Context, iID, sID u
|
|||
func (p *Persister) DeleteExpiredSessions(ctx context.Context, expiresAt time.Time, limit int) (err error) {
|
||||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteExpiredSessions")
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
//#nosec G201 -- TableName is static
|
||||
err = p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
"sessions",
|
||||
"sessions",
|
||||
new(session.Session).TableName(ctx),
|
||||
new(session.Session).TableName(ctx),
|
||||
limit,
|
||||
),
|
||||
expiresAt,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/persistence/sql/update"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
|
||||
|
|
@ -39,7 +40,7 @@ func (p *Persister) GetSettingsFlow(ctx context.Context, id uuid.UUID) (*setting
|
|||
return nil, sqlcon.HandleError(err)
|
||||
}
|
||||
|
||||
r.Identity, err = p.GetIdentity(ctx, r.IdentityID, identity.ExpandDefault)
|
||||
r.Identity, err = p.PrivilegedPool.GetIdentity(ctx, r.IdentityID, identity.ExpandDefault)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -54,11 +55,11 @@ func (p *Persister) UpdateSettingsFlow(ctx context.Context, r *settings.Flow) er
|
|||
r.EnsureInternalContext()
|
||||
cp := *r
|
||||
cp.NID = p.NetworkID(ctx)
|
||||
return p.update(ctx, cp)
|
||||
return update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), cp)
|
||||
}
|
||||
|
||||
func (p *Persister) DeleteExpiredSettingsFlows(ctx context.Context, expiresAt time.Time, limit int) error {
|
||||
// #nosec G201
|
||||
//#nosec G201 -- TableName is static
|
||||
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
new(settings.Flow).TableName(ctx),
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/persistence/sql/update"
|
||||
|
||||
"github.com/gobuffalo/pop/v6"
|
||||
"github.com/gofrs/uuid"
|
||||
|
|
@ -53,7 +54,7 @@ func (p *Persister) UpdateVerificationFlow(ctx context.Context, r *verification.
|
|||
|
||||
cp := *r
|
||||
cp.NID = p.NetworkID(ctx)
|
||||
return p.update(ctx, cp)
|
||||
return update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), cp)
|
||||
}
|
||||
|
||||
func (p *Persister) CreateVerificationToken(ctx context.Context, token *link.VerificationToken) error {
|
||||
|
|
@ -101,7 +102,7 @@ func (p *Persister) UseVerificationToken(ctx context.Context, fID uuid.UUID, tok
|
|||
|
||||
rt.VerifiableAddress = &va
|
||||
|
||||
/* #nosec G201 TableName is static */
|
||||
//#nosec G201 -- TableName is static
|
||||
return tx.RawQuery(fmt.Sprintf("UPDATE %s SET used=true, used_at=? WHERE id=? AND nid = ?", rt.TableName(ctx)), time.Now().UTC(), rt.ID, nid).Exec()
|
||||
})); err != nil {
|
||||
return nil, err
|
||||
|
|
@ -115,12 +116,12 @@ func (p *Persister) DeleteVerificationToken(ctx context.Context, token string) e
|
|||
defer span.End()
|
||||
|
||||
nid := p.NetworkID(ctx)
|
||||
/* #nosec G201 TableName is static */
|
||||
//#nosec G201 -- TableName is static
|
||||
return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE token=? AND nid = ?", new(link.VerificationToken).TableName(ctx)), token, nid).Exec()
|
||||
}
|
||||
|
||||
func (p *Persister) DeleteExpiredVerificationFlows(ctx context.Context, expiresAt time.Time, limit int) error {
|
||||
// #nosec G201
|
||||
//#nosec G201 -- TableName is static
|
||||
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
|
||||
new(verification.Flow).TableName(ctx),
|
||||
|
|
@ -149,7 +150,7 @@ func (p *Persister) UseVerificationCode(ctx context.Context, fID uuid.UUID, code
|
|||
|
||||
if err := sqlcon.HandleError(
|
||||
tx.RawQuery(
|
||||
/* #nosec G201 TableName is static */
|
||||
//#nosec G201 -- TableName is static
|
||||
fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName),
|
||||
fID,
|
||||
nid,
|
||||
|
|
@ -162,7 +163,7 @@ func (p *Persister) UseVerificationCode(ctx context.Context, fID uuid.UUID, code
|
|||
// Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on.
|
||||
if err := sqlcon.HandleError(
|
||||
tx.RawQuery(
|
||||
/* #nosec G201 TableName is static */
|
||||
//#nosec G201 -- TableName is static
|
||||
fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName),
|
||||
fID,
|
||||
nid,
|
||||
|
|
@ -222,7 +223,7 @@ func (p *Persister) UseVerificationCode(ctx context.Context, fID uuid.UUID, code
|
|||
|
||||
verificationCode.VerifiableAddress = &va
|
||||
|
||||
/* #nosec G201 TableName is static */
|
||||
//#nosec G201 -- TableName is static
|
||||
return tx.
|
||||
RawQuery(
|
||||
fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", verificationCode.TableName(ctx)),
|
||||
|
|
@ -276,7 +277,7 @@ func (p *Persister) DeleteVerificationCodesOfFlow(ctx context.Context, fID uuid.
|
|||
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteVerificationCodesOfFlow")
|
||||
defer span.End()
|
||||
|
||||
/* #nosec G201 TableName is static */
|
||||
//#nosec G201 -- TableName is static
|
||||
return p.GetConnection(ctx).
|
||||
RawQuery(
|
||||
fmt.Sprintf("DELETE FROM %s WHERE selfservice_verification_flow_id = ? AND nid = ?", new(code.VerificationCode).TableName(ctx)),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,72 @@
|
|||
// Copyright © 2023 Ory Corp
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package update
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/gobuffalo/pop/v6"
|
||||
"github.com/gobuffalo/pop/v6/columns"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/ory/x/sqlcon"
|
||||
)
|
||||
|
||||
type Model interface {
|
||||
GetID() uuid.UUID
|
||||
GetNID() uuid.UUID
|
||||
}
|
||||
|
||||
func Generic(ctx context.Context, c *pop.Connection, tracer trace.Tracer, v Model, columnNames ...string) error {
|
||||
ctx, span := tracer.Start(ctx, "persistence.sql.update")
|
||||
defer span.End()
|
||||
|
||||
quoter, ok := c.Dialect.(interface{ Quote(key string) string })
|
||||
if !ok {
|
||||
return errors.Errorf("store is not a quoter: %T", c.Store)
|
||||
}
|
||||
|
||||
model := pop.NewModel(v, ctx)
|
||||
tn := model.TableName()
|
||||
|
||||
cols := columns.Columns{}
|
||||
if len(columnNames) > 0 && tn == model.TableName() {
|
||||
cols = columns.NewColumnsWithAlias(tn, model.As, model.IDField())
|
||||
cols.Add(columnNames...)
|
||||
} else {
|
||||
cols = columns.ForStructWithAlias(v, tn, model.As, model.IDField())
|
||||
}
|
||||
|
||||
//#nosec G201 -- TableName is static
|
||||
stmt := fmt.Sprintf("SELECT COUNT(id) FROM %s AS %s WHERE %s.id = ? AND %s.nid = ?",
|
||||
quoter.Quote(model.TableName()),
|
||||
model.Alias(),
|
||||
model.Alias(),
|
||||
model.Alias(),
|
||||
)
|
||||
|
||||
var count int
|
||||
if err := c.Store.GetContext(ctx, &count, c.Dialect.TranslateSQL(stmt), v.GetID(), v.GetNID()); err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
} else if count == 0 {
|
||||
return errors.WithStack(sqlcon.ErrNoRows)
|
||||
}
|
||||
|
||||
//#nosec G201 -- TableName is static
|
||||
stmt = fmt.Sprintf("UPDATE %s AS %s SET %s WHERE %s AND %s.nid = :nid",
|
||||
quoter.Quote(model.TableName()),
|
||||
model.Alias(),
|
||||
cols.Writeable().QuotedUpdateString(quoter),
|
||||
model.WhereNamedID(),
|
||||
model.Alias(),
|
||||
)
|
||||
|
||||
if _, err := c.Store.NamedExecContext(ctx, stmt, v); err != nil {
|
||||
return sqlcon.HandleError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/ory/kratos/ui/node"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/x/httpx"
|
||||
"github.com/ory/x/otelx"
|
||||
"github.com/ory/x/otelx/semconv"
|
||||
)
|
||||
|
||||
|
|
@ -49,6 +50,7 @@ type (
|
|||
x.CSRFTokenGeneratorProvider
|
||||
x.WriterProvider
|
||||
x.LoggingProvider
|
||||
x.TracingProvider
|
||||
|
||||
HooksProvider
|
||||
}
|
||||
|
|
@ -110,7 +112,12 @@ func (e *HookExecutor) handleLoginError(_ http.ResponseWriter, r *http.Request,
|
|||
return flowError
|
||||
}
|
||||
|
||||
func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *Flow, i *identity.Identity, s *session.Session) error {
|
||||
func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *Flow, i *identity.Identity, s *session.Session) (err error) {
|
||||
ctx := r.Context()
|
||||
ctx, span := e.d.Tracer(ctx).Tracer().Start(ctx, "HookExecutor.PostLoginHook")
|
||||
r = r.WithContext(ctx)
|
||||
defer otelx.End(span, &err)
|
||||
|
||||
if err := s.Activate(r, i, e.d.Config(), time.Now().UTC()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -128,7 +135,8 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g n
|
|||
return err
|
||||
}
|
||||
|
||||
s = s.Declassify()
|
||||
classified := s
|
||||
s = s.Declassified()
|
||||
|
||||
e.d.Logger().
|
||||
WithRequest(r).
|
||||
|
|
@ -181,7 +189,7 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g n
|
|||
)
|
||||
|
||||
response := &APIFlowResponse{Session: s, Token: s.Token}
|
||||
if required, _ := e.requiresAAL2(r, s, a); required {
|
||||
if required, _ := e.requiresAAL2(r, classified, a); required {
|
||||
// If AAL is not satisfied, we omit the identity to preserve the user's privacy in case of a phishing attack.
|
||||
response.Session.Identity = nil
|
||||
}
|
||||
|
|
@ -240,7 +248,7 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g n
|
|||
finalReturnTo = rt
|
||||
}
|
||||
|
||||
x.ContentNegotiationRedirection(w, r, s.Declassify(), e.d.Writer(), finalReturnTo)
|
||||
x.ContentNegotiationRedirection(w, r, s, e.d.Writer(), finalReturnTo)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -230,7 +230,7 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque
|
|||
finalReturnTo = cr
|
||||
}
|
||||
|
||||
x.ContentNegotiationRedirection(w, r, s.Declassify(), e.d.Writer(), finalReturnTo)
|
||||
x.ContentNegotiationRedirection(w, r, s.Declassified(), e.d.Writer(), finalReturnTo)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ func newHydraIntegration(t *testing.T, remote *string, subject *string, claims *
|
|||
parsed, err := url.ParseRequestURI(addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
// #nosec G112
|
||||
//#nosec G112
|
||||
server := &http.Server{Addr: ":" + parsed.Port(), Handler: router}
|
||||
go func(t *testing.T) {
|
||||
if err := server.ListenAndServe(); err != http.ErrServerClosed {
|
||||
|
|
|
|||
|
|
@ -76,8 +76,7 @@ func TestRegistration(t *testing.T) {
|
|||
apiClient := testhelpers.NewDebugClient(t)
|
||||
|
||||
t.Run("AssertCommonErrorCases", func(t *testing.T) {
|
||||
reg := newRegistrationRegistry(t)
|
||||
registrationhelpers.AssertCommonErrorCases(t, reg, flows)
|
||||
registrationhelpers.AssertCommonErrorCases(t, flows)
|
||||
})
|
||||
|
||||
t.Run("AssertRegistrationRespectsValidation", func(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -10,8 +10,7 @@ import (
|
|||
|
||||
"github.com/hashicorp/go-retryablehttp"
|
||||
|
||||
/* #nosec G505 sha1 is used for k-anonymity */
|
||||
"crypto/sha1"
|
||||
"crypto/sha1" //#nosec G505 -- sha1 is used for k-anonymity
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
|
@ -188,7 +187,7 @@ func (s *DefaultPasswordValidator) validate(ctx context.Context, identifier, pas
|
|||
return nil
|
||||
}
|
||||
|
||||
/* #nosec G401 sha1 is used for k-anonymity */
|
||||
//#nosec G401 -- sha1 is used for k-anonymity
|
||||
h := sha1.New()
|
||||
if _, err := h.Write([]byte(password)); err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha1" // #nosec G505 - compatibility for imported passwords
|
||||
"crypto/sha1" //#nosec G505 -- compatibility for imported passwords
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -144,7 +144,7 @@ func TestDefaultPasswordValidationStrategy(t *testing.T) {
|
|||
s.Client = httpx.NewResilientClient(httpx.ResilientClientWithClient(&fakeClient.Client), httpx.ResilientClientWithMaxRetry(1), httpx.ResilientClientWithConnectionTimeout(time.Millisecond))
|
||||
|
||||
var hashPw = func(t *testing.T, pw string) string {
|
||||
/* #nosec G401 sha1 is used for k-anonymity */
|
||||
//#nosec G401 -- sha1 is used for k-anonymity
|
||||
h := sha1.New()
|
||||
_, err := h.Write([]byte(pw))
|
||||
require.NoError(t, err)
|
||||
|
|
|
|||
|
|
@ -88,8 +88,7 @@ func TestRegistration(t *testing.T) {
|
|||
//}
|
||||
|
||||
t.Run("AssertCommonErrorCases", func(t *testing.T) {
|
||||
reg := newRegistrationRegistry(t)
|
||||
registrationhelpers.AssertCommonErrorCases(t, reg, flows)
|
||||
registrationhelpers.AssertCommonErrorCases(t, flows)
|
||||
})
|
||||
|
||||
t.Run("AssertRegistrationRespectsValidation", func(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -68,3 +68,7 @@ type Persister interface {
|
|||
// RevokeSessionsIdentityExcept marks all except the given session of an identity inactive. It returns the number of sessions that were revoked.
|
||||
RevokeSessionsIdentityExcept(ctx context.Context, iID, sID uuid.UUID) (int, error)
|
||||
}
|
||||
|
||||
type DevicePersister interface {
|
||||
CreateDevice(ctx context.Context, d *Device) error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -261,9 +261,9 @@ func (s *Session) SetSessionDeviceInformation(r *http.Request) {
|
|||
s.Devices = append(s.Devices, device)
|
||||
}
|
||||
|
||||
func (s *Session) Declassify() *Session {
|
||||
func (s Session) Declassified() *Session {
|
||||
s.Identity = s.Identity.CopyWithoutCredentials()
|
||||
return s
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s *Session) IsActive() bool {
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ func main() {
|
|||
})
|
||||
|
||||
addr := ":" + osx.GetenvDefault("PORT", "4446")
|
||||
// #nosec G112
|
||||
//#nosec G112
|
||||
server := &http.Server{Addr: addr, Handler: router}
|
||||
fmt.Printf("Starting web server at %s\n", addr)
|
||||
check(server.ListenAndServe())
|
||||
|
|
|
|||
|
|
@ -13,8 +13,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// #nosec G404
|
||||
var rnd = rand.New(rand.NewSource(time.Now().Unix()))
|
||||
var rnd = rand.New(rand.NewSource(time.Now().Unix())) //#nosec G404
|
||||
|
||||
func AssertEqualTime(t *testing.T, expected, actual time.Time) {
|
||||
assert.EqualValues(t, expected.UTC().Round(time.Second), actual.UTC().Round(time.Second))
|
||||
|
|
|
|||
Loading…
Reference in New Issue