test: fix data races

GitOrigin-RevId: 624aeee77f199dd8c67b67cefe2a61c2f9e7d759
This commit is contained in:
Philippe Gaultier 2025-12-03 15:15:50 +01:00 committed by ory-bot
parent c90675c65e
commit 8482dd5de0
4 changed files with 69 additions and 25 deletions

View File

@ -47,14 +47,17 @@ func TestPersister(ctx context.Context, newNetworkUnlessExisting NetworkWrapper,
messages := make([]courier.Message, 5)
t.Run("case=add messages to the queue", func(t *testing.T) {
t.Cleanup(func() { pop.SetNowFunc(func() time.Time { return time.Now().Round(time.Second) }) })
now := time.Now()
for k := range messages {
// We need to fake the time func to control the created_at column, which is the
// sort key for the messages.
pop.SetNowFunc(func() time.Time { return now.Add(time.Duration(k) * time.Hour).Round(time.Second) })
require.NoError(t, faker.FakeData(&messages[k]))
require.NoError(t, p.AddMessage(ctx, &messages[k]))
require.NoError(t, p.GetConnection(ctx).
RawQuery(
"UPDATE courier_messages SET created_at = ?, updated_at = ? WHERE id = ? AND nid = ?",
now.Add(time.Duration(k)*time.Hour).Round(time.Second),
now.Add(time.Duration(k)*time.Hour).Round(time.Second),
messages[k].ID, nid).
Exec())
}
})

View File

@ -623,7 +623,6 @@ func TestDriverDefault_Hooks(t *testing.T) {
func TestDriverDefault_Strategies(t *testing.T) {
t.Parallel()
ctx := context.Background()
_, reg := internal.NewVeryFastRegistryWithoutDB(t)
t.Run("case=registration", func(t *testing.T) {
t.Parallel()
@ -680,6 +679,7 @@ func TestDriverDefault_Strategies(t *testing.T) {
t.Parallel()
ctx := contextx.WithConfigValues(ctx, tc.config)
_, reg := internal.NewVeryFastRegistryWithoutDB(t)
s := reg.RegistrationStrategies(ctx)
require.Len(t, s, len(tc.expect))
for k, e := range tc.expect {
@ -753,6 +753,7 @@ func TestDriverDefault_Strategies(t *testing.T) {
t.Parallel()
ctx := contextx.WithConfigValues(ctx, tc.config)
_, reg := internal.NewVeryFastRegistryWithoutDB(t)
s := reg.LoginStrategies(ctx)
require.Len(t, s, len(tc.expect))
for k, e := range tc.expect {
@ -786,6 +787,7 @@ func TestDriverDefault_Strategies(t *testing.T) {
ctx := contextx.WithConfigValues(ctx, tc.config)
_, reg := internal.NewVeryFastRegistryWithoutDB(t)
s := reg.RecoveryStrategies(ctx)
require.Len(t, s, len(tc.expect))
for k, e := range tc.expect {

View File

@ -21,7 +21,6 @@ import (
continuity "github.com/ory/kratos/continuity/test"
"github.com/ory/kratos/corpx"
courier "github.com/ory/kratos/courier/test"
"github.com/ory/kratos/driver"
"github.com/ory/kratos/driver/config"
ri "github.com/ory/kratos/identity"
identity "github.com/ory/kratos/identity/test"
@ -93,10 +92,11 @@ func pl(t testing.TB) func(lvl logging.Level, s string, args ...interface{}) {
}
}
func createCleanDatabases(t testing.TB) map[string]*driver.RegistryDefault {
func createCleanDatabases(t testing.TB) map[string]string {
conns := map[string]string{
"sqlite": "sqlite://file:" + t.TempDir() + "/db.sqlite?_fk=true&max_conns=1&lock=false",
}
connsMtx := sync.Mutex{}
if !testing.Short() {
funcs := map[string]func(t testing.TB) string{
@ -116,14 +116,18 @@ func createCleanDatabases(t testing.TB) map[string]*driver.RegistryDefault {
go func(s string, f func(t testing.TB) string) {
defer wg.Done()
db := f(t)
connsMtx.Lock()
conns[s] = db
connsMtx.Unlock()
}(k, f)
}
wg.Wait()
}
ps := make(map[string]*driver.RegistryDefault, len(conns))
ps := make(map[string]string, len(conns))
psMtx := sync.Mutex{}
var wg sync.WaitGroup
wg.Add(len(conns))
for name, dsn := range conns {
@ -154,7 +158,9 @@ func createCleanDatabases(t testing.TB) map[string]*driver.RegistryDefault {
require.NoError(t, err)
require.False(t, status.HasPending())
ps[name] = reg
psMtx.Lock()
ps[name] = dsn
psMtx.Unlock()
t.Logf("Database %s initialized successfully", name)
}(name, dsn)
@ -170,26 +176,24 @@ func TestPersister(t *testing.T) {
conns := createCleanDatabases(t)
ctx := testhelpers.WithDefaultIdentitySchema(context.Background(), "file://./stub/identity.schema.json")
for name, reg := range conns {
for name, dsn := range conns {
t.Run(fmt.Sprintf("database=%s", name), func(t *testing.T) {
t.Parallel()
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
t.Logf("DSN: %s", reg.Config().DSN(ctx))
t.Logf("DSN: %s", dsn)
t.Run("racy identity creation", func(t *testing.T) {
t.Parallel()
var wg sync.WaitGroup
_, ps := testhelpers.NewNetwork(t, ctx, reg.Persister())
for i := range 10 {
wg.Add(1)
go func() {
defer wg.Done()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, ps := testhelpers.NewNetwork(t, ctx, reg.Persister())
id := ri.NewIdentity("")
id.SetCredentials(ri.CredentialsTypePassword, ri.Credentials{
Type: ri.CredentialsTypePassword,
@ -207,6 +211,8 @@ func TestPersister(t *testing.T) {
t.Run("case=credential types exist", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
for _, ct := range []ri.CredentialsType{ri.CredentialsTypeOIDC, ri.CredentialsTypePassword} {
require.NoError(t, p.(*sql.Persister).Connection(context.Background()).Where("name = ?", ct).First(&ri.CredentialsTypeTable{}))
}
@ -214,59 +220,88 @@ func TestPersister(t *testing.T) {
t.Run("contract=identity.TestPool", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
identity.TestPool(ctx, p, reg.IdentityManager(), name)(t)
})
t.Run("contract=registration.TestFlowPersister", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
registration.TestFlowPersister(ctx, p)(t)
})
t.Run("contract=errorx.TestPersister", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
errorx.TestPersister(ctx, p)(t)
})
t.Run("contract=login.TestFlowPersister", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
login.TestFlowPersister(ctx, p)(t)
})
t.Run("contract=settings.TestFlowPersister", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
settings.TestFlowPersister(ctx, p)(t)
})
t.Run("contract=session.TestPersister", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
session.TestPersister(ctx, reg.Config(), p)(t)
})
t.Run("contract=sessiontokenexchange.TestPersister", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
sessiontokenexchange.TestPersister(ctx, p)(t)
})
t.Run("contract=courier.TestPersister", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
upsert, insert := sqltesthelpers.DefaultNetworkWrapper(p)
courier.TestPersister(ctx, upsert, insert)(t)
})
t.Run("contract=verification.TestFlowPersister", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
verification.TestFlowPersister(ctx, p)(t)
})
t.Run("contract=recovery.TestFlowPersister", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
recovery.TestFlowPersister(ctx, p)(t)
})
t.Run("contract=link.TestPersister", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
link.TestPersister(ctx, p)(t)
})
t.Run("contract=code.TestPersister", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
code.TestPersister(ctx, p)(t)
})
t.Run("contract=continuity.TestPersister", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
continuity.TestPersister(ctx, p)(t)
})
t.Run("contract=batch.TestPersister", func(t *testing.T) {
t.Parallel()
_, reg := internal.NewRegistryDefaultWithDSN(t, dsn)
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
batch.TestPersister(ctx, reg.Tracer(ctx), p)(t)
})
})
@ -334,7 +369,8 @@ func Benchmark_BatchCreateIdentities(b *testing.B) {
batchSizes := []int{1, 10, 100, 500, 800, 900, 1000, 2000, 3000}
parallelRequests := []int{1, 4, 8, 16}
for name, reg := range conns {
for name, dsn := range conns {
_, reg := internal.NewRegistryDefaultWithDSN(b, dsn)
b.Run(fmt.Sprintf("database=%s", name), func(b *testing.B) {
conf := reg.Config()
_, p := testhelpers.NewNetwork(b, ctx, reg.Persister())

View File

@ -5,6 +5,7 @@ package test
import (
"context"
"sync/atomic"
"testing"
"time"
@ -21,7 +22,6 @@ import (
"github.com/ory/kratos/persistence"
"github.com/ory/kratos/session"
"github.com/ory/kratos/x"
"github.com/ory/pop/v6"
"github.com/ory/x/contextx"
"github.com/ory/x/dbal"
"github.com/ory/x/pagination/keysetpagination"
@ -151,7 +151,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
seedSessionIDs := make([]uuid.UUID, 5)
seedSessionsList := make([]session.Session, 5)
start := time.Now()
now := time.Now()
for j := range seedSessionsList {
require.NoError(t, faker.FakeData(&seedSessionsList[j]))
seedSessionsList[j].Identity = &identity1
@ -168,13 +168,16 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
seedSessionsList[j].Devices = []session.Device{
device,
}
pop.SetNowFunc(func() time.Time {
return start.Add(time.Duration(j) * time.Minute).Round(time.Second)
})
require.NoError(t, l.UpsertSession(ctx, &seedSessionsList[j]))
require.NoError(t, p.GetConnection(ctx).
RawQuery(
"UPDATE sessions SET created_at = ?, updated_at = ? WHERE id = ?",
now.Add(time.Duration(j)*time.Minute).Round(time.Second),
now.Add(time.Duration(j)*time.Minute).Round(time.Second),
seedSessionsList[j].ID).
Exec())
seedSessionIDs[j] = seedSessionsList[j].ID
}
pop.SetNowFunc(time.Now)
identity2Session.Identity = &identity2
identity2Session.Active = true
@ -646,13 +649,13 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt
foundExpectedCockroachError := false
foundExpectedCockroachError := atomic.Bool{}
g := errgroup.Group{}
for range 10 {
g.Go(func() error {
err := p.ExtendSession(ctx, expected.ID)
if errors.Is(err, sqlcon.ErrNoRows) {
foundExpectedCockroachError = true
foundExpectedCockroachError.Store(true)
return nil
}
return err
@ -663,7 +666,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing)
require.NoError(t, err)
assert.LessOrEqual(t, expectedExpiry.Sub(actual.ExpiresAt).Abs(), 10*time.Second)
assert.True(t, foundExpectedCockroachError, "We expect to find a not found error caused by ... FOR UPDATE SKIP LOCKED")
assert.True(t, foundExpectedCockroachError.Load(), "We expect to find a not found error caused by ... FOR UPDATE SKIP LOCKED")
})
}
}