diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 4f441d0fd..aeda93bc6 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -1,6 +1,7 @@ package kvcache import ( + "fmt" "math" "slices" "testing" @@ -20,217 +21,59 @@ type testCase struct { expectedMask []float32 } -func TestStore(t *testing.T) { - backend := &testBackend{} - cache := NewCausalCache(nil) - defer cache.Close() - - cache.Init(backend, ml.DTypeF16, 1, 16, 16) - - tests := []testCase{ - { - name: "FirstBatch", - in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, - inShape: []int{2, 3, 4}, - seqs: []int{0, 0, 0, 0}, - pos: []int32{0, 1, 2, 3}, - expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, - expectedShape: []int{2, 3, 4}, - expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, - }, - { - name: "SecondBatch", - in: []float32{115, 215, 125, 225, 135, 235}, - inShape: []int{2, 3, 1}, - seqs: []int{0}, - pos: []int32{4}, - expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235}, - expectedShape: []int{2, 3, 5}, - expectedMask: []float32{0, 0, 0, 0, 0}, - }, +func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) { + t.Helper() + for _, permuted := range []bool{false, true} { + t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) { + fn(t, &testBackend{permutedV: permuted}) + }) } +} - testCache(t, backend, cache, tests) +func TestStore(t *testing.T) { + runPermutedVariants(t, func(t *testing.T, backend *testBackend) { + cache := NewCausalCache(nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 1, 16, 16) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, + inShape: []int{2, 3, 4}, + seqs: []int{0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3}, + expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, + expectedShape: []int{2, 3, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, + }, + { + name: "SecondBatch", + in: []float32{115, 215, 125, 225, 135, 235}, + inShape: []int{2, 3, 1}, + seqs: []int{0}, + pos: []int32{4}, + expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235}, + expectedShape: []int{2, 3, 5}, + expectedMask: []float32{0, 0, 0, 0, 0}, + }, + } + + testCache(t, backend, cache, tests) + }) } func TestSWA(t *testing.T) { - backend := &testBackend{} - cache := NewSWACache(1, nil) - defer cache.Close() + runPermutedVariants(t, func(t *testing.T, backend *testBackend) { + cache := NewSWACache(1, nil) + defer cache.Close() - cache.Init(backend, ml.DTypeF16, 1, 16, 16) + cache.Init(backend, ml.DTypeF16, 1, 16, 16) - x := float32(math.Inf(-1)) + x := float32(math.Inf(-1)) - tests := []testCase{ - { - name: "FirstBatch", - in: []float32{1, 2, 3, 4}, - inShape: []int{1, 1, 4}, - seqs: []int{0, 0, 0, 0}, - pos: []int32{0, 1, 2, 3}, - expected: []float32{1, 2, 3, 4}, - expectedShape: []int{1, 1, 4}, - expectedMask: []float32{ - 0, x, x, x, - 0, 0, x, x, - x, 0, 0, x, - x, x, 0, 0, - }, - }, - { - name: "SecondBatch", - in: []float32{5, 6}, - inShape: []int{1, 1, 2}, - seqs: []int{0, 0}, - pos: []int32{4, 5}, - expected: []float32{5, 6, 3, 4}, - expectedShape: []int{1, 1, 4}, - expectedMask: []float32{ - 0, x, x, 0, - 0, 0, x, x, - }, - }, - } - - testCache(t, backend, cache, tests) -} - -func TestSWASeparateBatches(t *testing.T) { - backend := &testBackend{} - cache := NewSWACache(1, nil) - defer cache.Close() - - cache.Init(backend, ml.DTypeF16, 2, 16, 2) - - x := float32(math.Inf(-1)) - - tests := []testCase{ - { - name: "First seq 0", - in: []float32{1, 2}, - inShape: []int{1, 1, 2}, - seqs: []int{0, 0}, - pos: []int32{0, 1}, - expected: []float32{1, 2}, - expectedShape: []int{1, 1, 2}, - expectedMask: []float32{ - 0, x, - 0, 0, - }, - }, - { - name: "Second seq 0", - in: []float32{3, 4}, - inShape: []int{1, 1, 2}, - seqs: []int{0, 0}, - pos: []int32{2, 3}, - expected: []float32{2, 3, 4}, - expectedShape: []int{1, 1, 3}, - expectedMask: []float32{ - 0, 0, x, - x, 0, 0, - }, - }, - { - name: "First seq 1", - in: []float32{5, 6}, - inShape: []int{1, 1, 2}, - seqs: []int{1, 1}, - pos: []int32{0, 1}, - expected: []float32{5, 6}, - expectedShape: []int{1, 1, 2}, - expectedMask: []float32{ - 0, x, - 0, 0, - }, - }, - { - name: "Second seq 1", - in: []float32{7, 8}, - inShape: []int{1, 1, 2}, - seqs: []int{1, 1}, - pos: []int32{2, 3}, - expected: []float32{6, 3, 4, 7, 8}, - expectedShape: []int{1, 1, 5}, - expectedMask: []float32{ - 0, x, x, 0, x, - x, x, x, 0, 0, - }, - }, - { - name: "Third seq 0", - in: []float32{9, 10}, - inShape: []int{1, 1, 2}, - seqs: []int{0, 0}, - pos: []int32{4, 5}, - expected: []float32{9, 10, 3, 4}, - expectedShape: []int{1, 1, 4}, - expectedMask: []float32{ - 0, x, x, 0, - 0, 0, x, x, - }, - }, - } - - testCache(t, backend, cache, tests) -} - -func TestSWAMem(t *testing.T) { - backend := &testBackend{} - cache := NewSWAMemCache(1, 3, nil) - defer cache.Close() - - cache.Init(backend, ml.DTypeF16, 1, 16, 16) - - x := float32(math.Inf(-1)) - - tests := []testCase{ - { - name: "FirstBatch", - in: []float32{1, 2, 3, 4}, - inShape: []int{1, 1, 4}, - seqs: []int{0, 0, 0, 0}, - pos: []int32{0, 1, 2, 3}, - expected: []float32{1, 2, 3, 4}, - expectedShape: []int{1, 1, 4}, - expectedMask: []float32{ - 0, x, x, x, - 0, 0, x, x, - x, 0, 0, x, - x, x, 0, 0, - }, - }, - { - name: "SecondBatch", - in: []float32{5, 6}, - inShape: []int{1, 1, 2}, - seqs: []int{0, 0}, - pos: []int32{4, 5}, - expected: []float32{5, 2, 3, 4, 6}, - expectedShape: []int{1, 1, 5}, - expectedMask: []float32{ - 0, x, x, 0, x, - 0, x, x, x, 0, - }, - }, - } - - testCache(t, backend, cache, tests) -} - -func TestChunkedAttention(t *testing.T) { - cache := NewChunkedAttentionCache(2, nil) - defer cache.Close() - - var b testBackend - cache.Init(&b, ml.DTypeF16, 1, 16, 16) - - x := float32(math.Inf(-1)) - - testCache( - t, &b, cache, - []testCase{ + tests := []testCase{ { name: "FirstBatch", in: []float32{1, 2, 3, 4}, @@ -242,190 +85,365 @@ func TestChunkedAttention(t *testing.T) { expectedMask: []float32{ 0, x, x, x, 0, 0, x, x, - x, x, 0, x, + x, 0, 0, x, x, x, 0, 0, }, }, { name: "SecondBatch", - in: []float32{5, 6, 7}, - inShape: []int{1, 1, 3}, - seqs: []int{0, 0, 0}, - pos: []int32{4, 5, 6}, - expected: []float32{1, 2, 3, 4, 5, 6, 7}, - expectedShape: []int{1, 1, 7}, + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{4, 5}, + expected: []float32{5, 6, 3, 4}, + expectedShape: []int{1, 1, 4}, expectedMask: []float32{ - x, x, x, x, 0, x, x, - x, x, x, x, 0, 0, x, - x, x, x, x, x, x, 0, + 0, x, x, 0, + 0, 0, x, x, + }, + }, + } + + testCache(t, backend, cache, tests) + }) +} + +func TestSWASeparateBatches(t *testing.T) { + runPermutedVariants(t, func(t *testing.T, backend *testBackend) { + cache := NewSWACache(1, nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 2, 16, 2) + + x := float32(math.Inf(-1)) + + tests := []testCase{ + { + name: "First seq 0", + in: []float32{1, 2}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{0, 1}, + expected: []float32{1, 2}, + expectedShape: []int{1, 1, 2}, + expectedMask: []float32{ + 0, x, + 0, 0, }, }, { - name: "ThirdBatch", - in: []float32{8, 9}, + name: "Second seq 0", + in: []float32{3, 4}, inShape: []int{1, 1, 2}, seqs: []int{0, 0}, - pos: []int32{7, 8}, - expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, - expectedShape: []int{1, 1, 9}, + pos: []int32{2, 3}, + expected: []float32{2, 3, 4}, + expectedShape: []int{1, 1, 3}, expectedMask: []float32{ - x, x, x, x, x, x, 0, 0, x, - x, x, x, x, x, x, x, x, 0, + 0, 0, x, + x, 0, 0, }, }, - }, - ) + { + name: "First seq 1", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{1, 1}, + pos: []int32{0, 1}, + expected: []float32{5, 6}, + expectedShape: []int{1, 1, 2}, + expectedMask: []float32{ + 0, x, + 0, 0, + }, + }, + { + name: "Second seq 1", + in: []float32{7, 8}, + inShape: []int{1, 1, 2}, + seqs: []int{1, 1}, + pos: []int32{2, 3}, + expected: []float32{6, 3, 4, 7, 8}, + expectedShape: []int{1, 1, 5}, + expectedMask: []float32{ + 0, x, x, 0, x, + x, x, x, 0, 0, + }, + }, + { + name: "Third seq 0", + in: []float32{9, 10}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{4, 5}, + expected: []float32{9, 10, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{ + 0, x, x, 0, + 0, 0, x, x, + }, + }, + } + + testCache(t, backend, cache, tests) + }) +} + +func TestSWAMem(t *testing.T) { + runPermutedVariants(t, func(t *testing.T, backend *testBackend) { + cache := NewSWAMemCache(1, 3, nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 1, 16, 16) + + x := float32(math.Inf(-1)) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{ + 0, x, x, x, + 0, 0, x, x, + x, 0, 0, x, + x, x, 0, 0, + }, + }, + { + name: "SecondBatch", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{4, 5}, + expected: []float32{5, 2, 3, 4, 6}, + expectedShape: []int{1, 1, 5}, + expectedMask: []float32{ + 0, x, x, 0, x, + 0, x, x, x, 0, + }, + }, + } + + testCache(t, backend, cache, tests) + }) +} + +func TestChunkedAttention(t *testing.T) { + runPermutedVariants(t, func(t *testing.T, backend *testBackend) { + cache := NewChunkedAttentionCache(2, nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 1, 16, 16) + + x := float32(math.Inf(-1)) + + testCache( + t, backend, cache, + []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{ + 0, x, x, x, + 0, 0, x, x, + x, x, 0, x, + x, x, 0, 0, + }, + }, + { + name: "SecondBatch", + in: []float32{5, 6, 7}, + inShape: []int{1, 1, 3}, + seqs: []int{0, 0, 0}, + pos: []int32{4, 5, 6}, + expected: []float32{1, 2, 3, 4, 5, 6, 7}, + expectedShape: []int{1, 1, 7}, + expectedMask: []float32{ + x, x, x, x, 0, x, x, + x, x, x, x, 0, 0, x, + x, x, x, x, x, x, 0, + }, + }, + { + name: "ThirdBatch", + in: []float32{8, 9}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{7, 8}, + expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, + expectedShape: []int{1, 1, 9}, + expectedMask: []float32{ + x, x, x, x, x, x, 0, 0, x, + x, x, x, x, x, x, x, x, 0, + }, + }, + }, + ) + }) } func TestSequences(t *testing.T) { - backend := &testBackend{} - cache := NewCausalCache(nil) - defer cache.Close() + runPermutedVariants(t, func(t *testing.T, backend *testBackend) { + cache := NewCausalCache(nil) + defer cache.Close() - cache.Init(backend, ml.DTypeF16, 1, 16, 16) + cache.Init(backend, ml.DTypeF16, 1, 16, 16) - tests := []testCase{ - { - name: "FirstBatch", - in: []float32{1, 2, 3, 4}, - inShape: []int{1, 1, 4}, - seqs: []int{0, 0, 1, 1}, - pos: []int32{0, 1, 0, 1}, - expected: []float32{1, 2, 3, 4}, - expectedShape: []int{1, 1, 4}, - expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, - }, - { - name: "SecondBatch", - in: []float32{5, 6}, - inShape: []int{1, 1, 2}, - seqs: []int{0, 1}, - pos: []int32{2, 2}, - expected: []float32{1, 2, 3, 4, 5, 6}, - expectedShape: []int{1, 1, 6}, - expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0}, - }, - } + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 1, 1}, + pos: []int32{0, 1, 0, 1}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, + }, + { + name: "SecondBatch", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 1}, + pos: []int32{2, 2}, + expected: []float32{1, 2, 3, 4, 5, 6}, + expectedShape: []int{1, 1, 6}, + expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0}, + }, + } - testCache(t, backend, cache, tests) + testCache(t, backend, cache, tests) + }) } func TestRemove(t *testing.T) { - backend := &testBackend{} - cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.Add(ctx, shift), nil + runPermutedVariants(t, func(t *testing.T, backend *testBackend) { + cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key.Add(ctx, shift), nil + }) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 1, 16, 16) + + x := float32(math.Inf(-1)) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 1, 1}, + pos: []int32{0, 1, 0, 1}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{ + 0, x, x, x, + 0, 0, x, x, + x, x, 0, x, + x, x, 0, 0, + }, + }, + } + + testCache(t, backend, cache, tests) + + err := cache.Remove(0, 1, math.MaxInt32) + if err != nil { + panic(err) + } + + tests = []testCase{ + { + name: "RemoveEnd", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 1}, + pos: []int32{1, 2}, + expected: []float32{1, 5, 3, 4, 6}, + expectedShape: []int{1, 1, 5}, + expectedMask: []float32{ + 0, 0, x, x, x, + x, x, 0, 0, 0, + }, + }, + } + + testCache(t, backend, cache, tests) + + err = cache.Remove(0, 0, 1) + if err != nil { + panic(err) + } + + tests = []testCase{ + { + name: "RemoveMiddle", + in: []float32{7, 8}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{1, 2}, + expected: []float32{7, 4, 3, 4, 6, 8}, + expectedShape: []int{1, 1, 6}, + expectedMask: []float32{ + 0, 0, x, x, x, x, + 0, 0, x, x, x, 0, + }, + }, + } + + testCache(t, backend, cache, tests) }) - defer cache.Close() - - cache.Init(backend, ml.DTypeF16, 1, 16, 16) - - x := float32(math.Inf(-1)) - - tests := []testCase{ - { - name: "FirstBatch", - in: []float32{1, 2, 3, 4}, - inShape: []int{1, 1, 4}, - seqs: []int{0, 0, 1, 1}, - pos: []int32{0, 1, 0, 1}, - expected: []float32{1, 2, 3, 4}, - expectedShape: []int{1, 1, 4}, - expectedMask: []float32{ - 0, x, x, x, - 0, 0, x, x, - x, x, 0, x, - x, x, 0, 0, - }, - }, - } - - testCache(t, backend, cache, tests) - - err := cache.Remove(0, 1, math.MaxInt32) - if err != nil { - panic(err) - } - - tests = []testCase{ - { - name: "RemoveEnd", - in: []float32{5, 6}, - inShape: []int{1, 1, 2}, - seqs: []int{0, 1}, - pos: []int32{1, 2}, - expected: []float32{1, 5, 3, 4, 6}, - expectedShape: []int{1, 1, 5}, - expectedMask: []float32{ - 0, 0, x, x, x, - x, x, 0, 0, 0, - }, - }, - } - - testCache(t, backend, cache, tests) - - err = cache.Remove(0, 0, 1) - if err != nil { - panic(err) - } - - tests = []testCase{ - { - name: "RemoveMiddle", - in: []float32{7, 8}, - inShape: []int{1, 1, 2}, - seqs: []int{0, 0}, - pos: []int32{1, 2}, - expected: []float32{7, 4, 3, 4, 6, 8}, - expectedShape: []int{1, 1, 6}, - expectedMask: []float32{ - 0, 0, x, x, x, x, - 0, 0, x, x, x, 0, - }, - }, - } - - testCache(t, backend, cache, tests) } func TestCopy(t *testing.T) { - backend := &testBackend{} - cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil }) - defer cache.Close() + runPermutedVariants(t, func(t *testing.T, backend *testBackend) { + cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil }) + defer cache.Close() - cache.Init(backend, ml.DTypeF16, 1, 16, 16) + cache.Init(backend, ml.DTypeF16, 1, 16, 16) - tests := []testCase{ - { - name: "FirstBatch", - in: []float32{1, 2, 3, 4}, - inShape: []int{1, 1, 4}, - seqs: []int{0, 0, 0, 0}, - pos: []int32{0, 1, 2, 3}, - expected: []float32{1, 2, 3, 4}, - expectedShape: []int{1, 1, 4}, - expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, - }, - } + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, + }, + } - testCache(t, backend, cache, tests) + testCache(t, backend, cache, tests) - cache.CopyPrefix(0, 1, 2) + cache.CopyPrefix(0, 1, 2) - tests = []testCase{ - { - name: "Copy", - in: []float32{5, 6}, - inShape: []int{1, 1, 2}, - seqs: []int{1, 1}, - pos: []int32{3, 4}, - expected: []float32{1, 2, 3, 4, 5, 6}, - expectedShape: []int{1, 1, 6}, - expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, - }, - } + tests = []testCase{ + { + name: "Copy", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{1, 1}, + pos: []int32{3, 4}, + expected: []float32{1, 2, 3, 4, 5, 6}, + expectedShape: []int{1, 1, 6}, + expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, + }, + } - testCache(t, backend, cache, tests) + testCache(t, backend, cache, tests) + }) } func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) { @@ -463,145 +481,148 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) } func TestCanResume(t *testing.T) { - backend := &testBackend{} - windowSize := int32(4) - cache := NewSWACache(windowSize, nil) - defer cache.Close() + runPermutedVariants(t, func(t *testing.T, backend *testBackend) { + windowSize := int32(4) + cache := NewSWACache(windowSize, nil) + defer cache.Close() - cache.Init(backend, ml.DTypeF16, 1, 16, 16) + cache.Init(backend, ml.DTypeF16, 1, 16, 16) - context := backend.NewContext() - defer context.Close() + context := backend.NewContext() + defer context.Close() - err := cache.StartForward(context, input.Batch{ - Positions: []int32{0, 1, 2, 3, 4}, - Sequences: []int{0, 0, 0, 0, 0}, - }, false) - if err != nil { - t.Fatalf("StartForward failed: %v", err) - } + err := cache.StartForward(context, input.Batch{ + Positions: []int32{0, 1, 2, 3, 4}, + Sequences: []int{0, 0, 0, 0, 0}, + }, false) + if err != nil { + t.Fatalf("StartForward failed: %v", err) + } - cache.SetLayer(0) - tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5) - cache.Put(context, tensor, tensor) + cache.SetLayer(0) + tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5) + cache.Put(context, tensor, tensor) - // with window size 4, nothing has slid out of the window yet - if !cache.CanResume(0, 0) { - t.Errorf("CanResume(0, 0) = false, want true (within window)") - } - if !cache.CanResume(0, 1) { - t.Errorf("CanResume(0, 1) = false, want true (within window)") - } - if !cache.CanResume(0, 2) { - t.Errorf("CanResume(0, 2) = false, want true (within window)") - } - if !cache.CanResume(0, 3) { - t.Errorf("CanResume(0, 3) = false, want true (latest position)") - } - if !cache.CanResume(0, 4) { - t.Errorf("CanResume(0, 4) = false, want true (latest position)") - } + // with window size 4, nothing has slid out of the window yet + if !cache.CanResume(0, 0) { + t.Errorf("CanResume(0, 0) = false, want true (within window)") + } + if !cache.CanResume(0, 1) { + t.Errorf("CanResume(0, 1) = false, want true (within window)") + } + if !cache.CanResume(0, 2) { + t.Errorf("CanResume(0, 2) = false, want true (within window)") + } + if !cache.CanResume(0, 3) { + t.Errorf("CanResume(0, 3) = false, want true (latest position)") + } + if !cache.CanResume(0, 4) { + t.Errorf("CanResume(0, 4) = false, want true (latest position)") + } - // shift window by adding position 5 - err = cache.StartForward(context, input.Batch{ - Positions: []int32{5}, - Sequences: []int{0}, - }, false) - if err != nil { - t.Fatalf("StartForward failed: %v", err) - } + // shift window by adding position 5 + err = cache.StartForward(context, input.Batch{ + Positions: []int32{5}, + Sequences: []int{0}, + }, false) + if err != nil { + t.Fatalf("StartForward failed: %v", err) + } - cache.SetLayer(0) - tensor = context.FromFloats([]float32{6}, 1, 1, 1) - cache.Put(context, tensor, tensor) + cache.SetLayer(0) + tensor = context.FromFloats([]float32{6}, 1, 1, 1) + cache.Put(context, tensor, tensor) - // only the latest position has overlapping windows - if cache.CanResume(0, 0) { - t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") - } - if cache.CanResume(0, 1) { - t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") - } - if cache.CanResume(0, 2) { - t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") - } - if cache.CanResume(0, 3) { - t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") - } - if cache.CanResume(0, 4) { - t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") - } - if !cache.CanResume(0, 5) { - t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)") - } + // only the latest position has overlapping windows + if cache.CanResume(0, 0) { + t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") + } + if cache.CanResume(0, 1) { + t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") + } + if cache.CanResume(0, 2) { + t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") + } + if cache.CanResume(0, 3) { + t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") + } + if cache.CanResume(0, 4) { + t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") + } + if !cache.CanResume(0, 5) { + t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)") + } + }) } func TestCanResumeSWAMem(t *testing.T) { - backend := &testBackend{} - windowSize := int32(4) - memSize := int32(5) - cache := NewSWAMemCache(windowSize, memSize, nil) - defer cache.Close() + runPermutedVariants(t, func(t *testing.T, backend *testBackend) { + windowSize := int32(4) + memSize := int32(5) + cache := NewSWAMemCache(windowSize, memSize, nil) + defer cache.Close() - cache.Init(backend, ml.DTypeF16, 1, 16, 16) + cache.Init(backend, ml.DTypeF16, 1, 16, 16) - context := backend.NewContext() - defer context.Close() + context := backend.NewContext() + defer context.Close() - err := cache.StartForward(context, input.Batch{ - Positions: []int32{0, 1, 2, 3, 4, 5, 6}, - Sequences: []int{0, 0, 0, 0, 0, 0, 0}, - }, false) - if err != nil { - t.Fatalf("StartForward failed: %v", err) - } + err := cache.StartForward(context, input.Batch{ + Positions: []int32{0, 1, 2, 3, 4, 5, 6}, + Sequences: []int{0, 0, 0, 0, 0, 0, 0}, + }, false) + if err != nil { + t.Fatalf("StartForward failed: %v", err) + } - cache.SetLayer(0) - tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7) - cache.Put(context, tensor, tensor) + cache.SetLayer(0) + tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7) + cache.Put(context, tensor, tensor) - // shift window by adding position 7 - err = cache.StartForward(context, input.Batch{ - Positions: []int32{7}, - Sequences: []int{0}, - }, false) - if err != nil { - t.Fatalf("StartForward failed: %v", err) - } + // shift window by adding position 7 + err = cache.StartForward(context, input.Batch{ + Positions: []int32{7}, + Sequences: []int{0}, + }, false) + if err != nil { + t.Fatalf("StartForward failed: %v", err) + } - cache.SetLayer(0) - tensor = context.FromFloats([]float32{8}, 1, 1, 1) - cache.Put(context, tensor, tensor) + cache.SetLayer(0) + tensor = context.FromFloats([]float32{8}, 1, 1, 1) + cache.Put(context, tensor, tensor) - // only the latest position has overlapping windows - if cache.CanResume(0, 0) { - t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") - } - if cache.CanResume(0, 1) { - t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") - } - if cache.CanResume(0, 2) { - t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") - } - if cache.CanResume(0, 3) { - t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") - } - if cache.CanResume(0, 4) { - t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") - } - if cache.CanResume(0, 5) { - t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)") - } - if !cache.CanResume(0, 6) { - t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)") - } - if !cache.CanResume(0, 7) { - t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)") - } + // only the latest position has overlapping windows + if cache.CanResume(0, 0) { + t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") + } + if cache.CanResume(0, 1) { + t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") + } + if cache.CanResume(0, 2) { + t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") + } + if cache.CanResume(0, 3) { + t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") + } + if cache.CanResume(0, 4) { + t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") + } + if cache.CanResume(0, 5) { + t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)") + } + if !cache.CanResume(0, 6) { + t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)") + } + if !cache.CanResume(0, 7) { + t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)") + } + }) } type testBackend struct { ml.Backend + permutedV bool } func (b *testBackend) NewContext() ml.Context { @@ -612,6 +633,10 @@ func (b *testBackend) NewContextSize(int) ml.Context { return &testContext{} } +func (b *testBackend) CacheConfig() ml.CacheConfig { + return ml.CacheConfig{PermutedV: b.permutedV} +} + type testContext struct { ml.Context } @@ -766,6 +791,102 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { return view } +func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor { + if len(t.shape) > 4 || len(order) > 4 { + panic("permute only supports up to 4 dimensions") + } + + if len(order) != len(t.shape) && len(order) != 4 { + panic("invalid number of dimensions for permute") + } + + // ggml_permute expects 4 axes, so fill in any missing dimensions. + orderFull := append(make([]int, 0, 4), order...) + for len(orderFull) < 4 { + orderFull = append(orderFull, len(orderFull)) + } + + seen := [4]bool{} + + shape4 := [4]int{1, 1, 1, 1} + for i := 0; i < len(t.shape) && i < 4; i++ { + shape4[i] = t.shape[i] + } + + newShape4 := [4]int{1, 1, 1, 1} + for axis := range 4 { + dst := orderFull[axis] + if dst < 0 || dst >= 4 { + panic("invalid axis for permute") + } + if seen[dst] { + panic("duplicate axis for permute") + } + seen[dst] = true + newShape4[dst] = shape4[axis] + } + + total := len(t.data) + newData := make([]float32, total) + + if total > 0 { + oldDims := shape4 + newDims := newShape4 + + oldStride := [4]int{1, 1, 1, 1} + newStride := [4]int{1, 1, 1, 1} + for i := 1; i < 4; i++ { + oldStride[i] = oldStride[i-1] * oldDims[i-1] + newStride[i] = newStride[i-1] * newDims[i-1] + } + + var coords [4]int + var newCoords [4]int + + for idx := range total { + remainder := idx + for axis := range 4 { + dim := oldDims[axis] + if dim == 0 { + coords[axis] = 0 + continue + } + coords[axis] = remainder % dim + remainder /= dim + } + + for axis := range 4 { + newCoords[orderFull[axis]] = coords[axis] + } + + newIndex := 0 + for axis := range 4 { + if newDims[axis] == 0 { + continue + } + newIndex += newCoords[axis] * newStride[axis] + } + + newData[newIndex] = t.data[idx] + } + } + + numDims := 4 + for numDims > 1 && newShape4[numDims-1] <= 1 { + numDims-- + } + + newShape := make([]int, numDims) + copy(newShape, newShape4[:numDims]) + + return &testTensor{ + dtype: t.dtype, + elementSize: t.elementSize, + data: newData, + shape: newShape, + } +} + func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor { dst := t srcTensor := src.(*testTensor)