mirror of https://github.com/ory/kratos
test: fix data races
GitOrigin-RevId: 624aeee77f199dd8c67b67cefe2a61c2f9e7d759
This commit is contained in:
parent
c90675c65e
commit
8482dd5de0
|
|
@ -47,14 +47,17 @@ func TestPersister(ctx context.Context, newNetworkUnlessExisting NetworkWrapper,
|
|||
|
||||
messages := make([]courier.Message, 5)
|
||||
t.Run("case=add messages to the queue", func(t *testing.T) {
|
||||
t.Cleanup(func() { pop.SetNowFunc(func() time.Time { return time.Now().Round(time.Second) }) })
|
||||
now := time.Now()
|
||||
for k := range messages {
|
||||
// We need to fake the time func to control the created_at column, which is the
|
||||
// sort key for the messages.
|
||||
pop.SetNowFunc(func() time.Time { return now.Add(time.Duration(k) * time.Hour).Round(time.Second) })
|
||||
require.NoError(t, faker.FakeData(&messages[k]))
|
||||
require.NoError(t, p.AddMessage(ctx, &messages[k]))
|
||||
require.NoError(t, p.GetConnection(ctx).
|
||||
RawQuery(
|
||||
"UPDATE courier_messages SET created_at = ?, updated_at = ? WHERE id = ? AND nid = ?",
|
||||
now.Add(time.Duration(k)*time.Hour).Round(time.Second),
|
||||
now.Add(time.Duration(k)*time.Hour).Round(time.Second),
|
||||
messages[k].ID, nid).
|
||||
Exec())
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -623,7 +623,6 @@ func TestDriverDefault_Hooks(t *testing.T) {
|
|||
func TestDriverDefault_Strategies(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
_, reg := internal.NewVeryFastRegistryWithoutDB(t)
|
||||
|
||||
t.Run("case=registration", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
@ -680,6 +679,7 @@ func TestDriverDefault_Strategies(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
ctx := contextx.WithConfigValues(ctx, tc.config)
|
||||
_, reg := internal.NewVeryFastRegistryWithoutDB(t)
|
||||
s := reg.RegistrationStrategies(ctx)
|
||||
require.Len(t, s, len(tc.expect))
|
||||
for k, e := range tc.expect {
|
||||
|
|
@ -753,6 +753,7 @@ func TestDriverDefault_Strategies(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
ctx := contextx.WithConfigValues(ctx, tc.config)
|
||||
_, reg := internal.NewVeryFastRegistryWithoutDB(t)
|
||||
s := reg.LoginStrategies(ctx)
|
||||
require.Len(t, s, len(tc.expect))
|
||||
for k, e := range tc.expect {
|
||||
|
|
@ -786,6 +787,7 @@ func TestDriverDefault_Strategies(t *testing.T) {
|
|||
|
||||
ctx := contextx.WithConfigValues(ctx, tc.config)
|
||||
|
||||
_, reg := internal.NewVeryFastRegistryWithoutDB(t)
|
||||
s := reg.RecoveryStrategies(ctx)
|
||||
require.Len(t, s, len(tc.expect))
|
||||
for k, e := range tc.expect {
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ import (
|
|||
continuity "github.com/ory/kratos/continuity/test"
|
||||
"github.com/ory/kratos/corpx"
|
||||
courier "github.com/ory/kratos/courier/test"
|
||||
"github.com/ory/kratos/driver"
|
||||
"github.com/ory/kratos/driver/config"
|
||||
ri "github.com/ory/kratos/identity"
|
||||
identity "github.com/ory/kratos/identity/test"
|
||||
|
|
@ -93,10 +92,11 @@ func pl(t testing.TB) func(lvl logging.Level, s string, args ...interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
func createCleanDatabases(t testing.TB) map[string]*driver.RegistryDefault {
|
||||
func createCleanDatabases(t testing.TB) map[string]string {
|
||||
conns := map[string]string{
|
||||
"sqlite": "sqlite://file:" + t.TempDir() + "/db.sqlite?_fk=true&max_conns=1&lock=false",
|
||||
}
|
||||
connsMtx := sync.Mutex{}
|
||||
|
||||
if !testing.Short() {
|
||||
funcs := map[string]func(t testing.TB) string{
|
||||
|
|
@ -116,14 +116,18 @@ func createCleanDatabases(t testing.TB) map[string]*driver.RegistryDefault {
|
|||
go func(s string, f func(t testing.TB) string) {
|
||||
defer wg.Done()
|
||||
db := f(t)
|
||||
connsMtx.Lock()
|
||||
conns[s] = db
|
||||
connsMtx.Unlock()
|
||||
}(k, f)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
ps := make(map[string]*driver.RegistryDefault, len(conns))
|
||||
ps := make(map[string]string, len(conns))
|
||||
psMtx := sync.Mutex{}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(conns))
|
||||
for name, dsn := range conns {
|
||||
|
|
@ -154,7 +158,9 @@ func createCleanDatabases(t testing.TB) map[string]*driver.RegistryDefault {
|
|||
require.NoError(t, err)
|
||||
require.False(t, status.HasPending())
|
||||
|
||||
ps[name] = reg
|
||||
psMtx.Lock()
|
||||
ps[name] = dsn
|
||||
psMtx.Unlock()
|
||||
|
||||
t.Logf("Database %s initialized successfully", name)
|
||||
}(name, dsn)
|
||||
|
|
@ -170,26 +176,24 @@ func TestPersister(t *testing.T) {
|
|||
conns := createCleanDatabases(t)
|
||||
ctx := testhelpers.WithDefaultIdentitySchema(context.Background(), "file://./stub/identity.schema.json")
|
||||
|
||||
for name, reg := range conns {
|
||||
for name, dsn := range conns {
|
||||
t.Run(fmt.Sprintf("database=%s", name), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
|
||||
t.Logf("DSN: %s", reg.Config().DSN(ctx))
|
||||
t.Logf("DSN: %s", dsn)
|
||||
|
||||
t.Run("racy identity creation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
_, ps := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
|
||||
for i := range 10 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, ps := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
id := ri.NewIdentity("")
|
||||
id.SetCredentials(ri.CredentialsTypePassword, ri.Credentials{
|
||||
Type: ri.CredentialsTypePassword,
|
||||
|
|
@ -207,6 +211,8 @@ func TestPersister(t *testing.T) {
|
|||
|
||||
t.Run("case=credential types exist", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
for _, ct := range []ri.CredentialsType{ri.CredentialsTypeOIDC, ri.CredentialsTypePassword} {
|
||||
require.NoError(t, p.(*sql.Persister).Connection(context.Background()).Where("name = ?", ct).First(&ri.CredentialsTypeTable{}))
|
||||
}
|
||||
|
|
@ -214,59 +220,88 @@ func TestPersister(t *testing.T) {
|
|||
|
||||
t.Run("contract=identity.TestPool", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
identity.TestPool(ctx, p, reg.IdentityManager(), name)(t)
|
||||
})
|
||||
t.Run("contract=registration.TestFlowPersister", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
registration.TestFlowPersister(ctx, p)(t)
|
||||
})
|
||||
t.Run("contract=errorx.TestPersister", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
errorx.TestPersister(ctx, p)(t)
|
||||
})
|
||||
t.Run("contract=login.TestFlowPersister", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
login.TestFlowPersister(ctx, p)(t)
|
||||
})
|
||||
t.Run("contract=settings.TestFlowPersister", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
settings.TestFlowPersister(ctx, p)(t)
|
||||
})
|
||||
t.Run("contract=session.TestPersister", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
session.TestPersister(ctx, reg.Config(), p)(t)
|
||||
})
|
||||
t.Run("contract=sessiontokenexchange.TestPersister", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
sessiontokenexchange.TestPersister(ctx, p)(t)
|
||||
})
|
||||
t.Run("contract=courier.TestPersister", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
upsert, insert := sqltesthelpers.DefaultNetworkWrapper(p)
|
||||
courier.TestPersister(ctx, upsert, insert)(t)
|
||||
})
|
||||
t.Run("contract=verification.TestFlowPersister", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
verification.TestFlowPersister(ctx, p)(t)
|
||||
})
|
||||
t.Run("contract=recovery.TestFlowPersister", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
recovery.TestFlowPersister(ctx, p)(t)
|
||||
})
|
||||
t.Run("contract=link.TestPersister", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
link.TestPersister(ctx, p)(t)
|
||||
})
|
||||
t.Run("contract=code.TestPersister", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
code.TestPersister(ctx, p)(t)
|
||||
})
|
||||
t.Run("contract=continuity.TestPersister", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
continuity.TestPersister(ctx, p)(t)
|
||||
})
|
||||
t.Run("contract=batch.TestPersister", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
|
||||
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
|
||||
batch.TestPersister(ctx, reg.Tracer(ctx), p)(t)
|
||||
})
|
||||
})
|
||||
|
|
@ -334,7 +369,8 @@ func Benchmark_BatchCreateIdentities(b *testing.B) {
|
|||
batchSizes := []int{1, 10, 100, 500, 800, 900, 1000, 2000, 3000}
|
||||
parallelRequests := []int{1, 4, 8, 16}
|
||||
|
||||
for name, reg := range conns {
|
||||
for name, dsn := range conns {
|
||||
_, reg := internal.NewRegistryDefaultWithDSN(b, dsn)
|
||||
b.Run(fmt.Sprintf("database=%s", name), func(b *testing.B) {
|
||||
conf := reg.Config()
|
||||
_, p := testhelpers.NewNetwork(b, ctx, reg.Persister())
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ package test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -21,7 +22,6 @@ import (
|
|||
"github.com/ory/kratos/persistence"
|
||||
"github.com/ory/kratos/session"
|
||||
"github.com/ory/kratos/x"
|
||||
"github.com/ory/pop/v6"
|
||||
"github.com/ory/x/contextx"
|
||||
"github.com/ory/x/dbal"
|
||||
"github.com/ory/x/pagination/keysetpagination"
|
||||
|
|
@ -151,7 +151,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
|
|||
|
||||
seedSessionIDs := make([]uuid.UUID, 5)
|
||||
seedSessionsList := make([]session.Session, 5)
|
||||
start := time.Now()
|
||||
now := time.Now()
|
||||
for j := range seedSessionsList {
|
||||
require.NoError(t, faker.FakeData(&seedSessionsList[j]))
|
||||
seedSessionsList[j].Identity = &identity1
|
||||
|
|
@ -168,13 +168,16 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
|
|||
seedSessionsList[j].Devices = []session.Device{
|
||||
device,
|
||||
}
|
||||
pop.SetNowFunc(func() time.Time {
|
||||
return start.Add(time.Duration(j) * time.Minute).Round(time.Second)
|
||||
})
|
||||
require.NoError(t, l.UpsertSession(ctx, &seedSessionsList[j]))
|
||||
require.NoError(t, p.GetConnection(ctx).
|
||||
RawQuery(
|
||||
"UPDATE sessions SET created_at = ?, updated_at = ? WHERE id = ?",
|
||||
now.Add(time.Duration(j)*time.Minute).Round(time.Second),
|
||||
now.Add(time.Duration(j)*time.Minute).Round(time.Second),
|
||||
seedSessionsList[j].ID).
|
||||
Exec())
|
||||
seedSessionIDs[j] = seedSessionsList[j].ID
|
||||
}
|
||||
pop.SetNowFunc(time.Now)
|
||||
|
||||
identity2Session.Identity = &identity2
|
||||
identity2Session.Active = true
|
||||
|
|
@ -646,13 +649,13 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
|
|||
|
||||
expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt
|
||||
|
||||
foundExpectedCockroachError := false
|
||||
foundExpectedCockroachError := atomic.Bool{}
|
||||
g := errgroup.Group{}
|
||||
for range 10 {
|
||||
g.Go(func() error {
|
||||
err := p.ExtendSession(ctx, expected.ID)
|
||||
if errors.Is(err, sqlcon.ErrNoRows) {
|
||||
foundExpectedCockroachError = true
|
||||
foundExpectedCockroachError.Store(true)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
|
|
@ -663,7 +666,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
|
|||
actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, expectedExpiry.Sub(actual.ExpiresAt).Abs(), 10*time.Second)
|
||||
assert.True(t, foundExpectedCockroachError, "We expect to find a not found error caused by ... FOR UPDATE SKIP LOCKED")
|
||||
assert.True(t, foundExpectedCockroachError.Load(), "We expect to find a not found error caused by ... FOR UPDATE SKIP LOCKED")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue