mirror of https://github.com/ollama/ollama
974 lines
24 KiB
Go
974 lines
24 KiB
Go
package kvcache
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
"slices"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/ml"
|
|
"github.com/ollama/ollama/model/input"
|
|
)
|
|
|
|
// testCase describes one forward step through a cache: the batch written to
// layer 0 and what the cache's Get is expected to return afterwards.
type testCase struct {
	// name is used as the subtest name.
	name string

	// in holds the values stored as both key and value; inShape is the shape
	// of the tensor built from in.
	in      []float32
	inShape []int

	// seqs and pos give, per batch entry, the sequence ID and position that
	// are passed to StartForward.
	seqs []int
	pos  []int32

	// expected and expectedShape describe the first tensor returned by Get;
	// expectedMask is the flattened mask returned alongside it.
	expected      []float32
	expectedShape []int
	expectedMask  []float32
}
|
|
|
|
func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
|
|
t.Helper()
|
|
for _, permuted := range []bool{false, true} {
|
|
t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
|
|
fn(t, &testBackend{permutedV: permuted})
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStore(t *testing.T) {
|
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
|
cache := NewCausalCache(nil)
|
|
defer cache.Close()
|
|
|
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
|
|
|
tests := []testCase{
|
|
{
|
|
name: "FirstBatch",
|
|
in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
|
inShape: []int{2, 3, 4},
|
|
seqs: []int{0, 0, 0, 0},
|
|
pos: []int32{0, 1, 2, 3},
|
|
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
|
expectedShape: []int{2, 3, 4},
|
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
|
},
|
|
{
|
|
name: "SecondBatch",
|
|
in: []float32{115, 215, 125, 225, 135, 235},
|
|
inShape: []int{2, 3, 1},
|
|
seqs: []int{0},
|
|
pos: []int32{4},
|
|
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
|
|
expectedShape: []int{2, 3, 5},
|
|
expectedMask: []float32{0, 0, 0, 0, 0},
|
|
},
|
|
}
|
|
|
|
testCache(t, backend, cache, tests)
|
|
})
|
|
}
|
|
|
|
func TestSWA(t *testing.T) {
|
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
|
cache := NewSWACache(1, nil)
|
|
defer cache.Close()
|
|
|
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
|
|
|
x := float32(math.Inf(-1))
|
|
|
|
tests := []testCase{
|
|
{
|
|
name: "FirstBatch",
|
|
in: []float32{1, 2, 3, 4},
|
|
inShape: []int{1, 1, 4},
|
|
seqs: []int{0, 0, 0, 0},
|
|
pos: []int32{0, 1, 2, 3},
|
|
expected: []float32{1, 2, 3, 4},
|
|
expectedShape: []int{1, 1, 4},
|
|
expectedMask: []float32{
|
|
0, x, x, x,
|
|
0, 0, x, x,
|
|
x, 0, 0, x,
|
|
x, x, 0, 0,
|
|
},
|
|
},
|
|
{
|
|
name: "SecondBatch",
|
|
in: []float32{5, 6},
|
|
inShape: []int{1, 1, 2},
|
|
seqs: []int{0, 0},
|
|
pos: []int32{4, 5},
|
|
expected: []float32{5, 6, 3, 4},
|
|
expectedShape: []int{1, 1, 4},
|
|
expectedMask: []float32{
|
|
0, x, x, 0,
|
|
0, 0, x, x,
|
|
},
|
|
},
|
|
}
|
|
|
|
testCache(t, backend, cache, tests)
|
|
})
|
|
}
|
|
|
|
func TestSWASeparateBatches(t *testing.T) {
|
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
|
cache := NewSWACache(1, nil)
|
|
defer cache.Close()
|
|
|
|
cache.Init(backend, ml.DTypeF16, 2, 16, 2)
|
|
|
|
x := float32(math.Inf(-1))
|
|
|
|
tests := []testCase{
|
|
{
|
|
name: "First seq 0",
|
|
in: []float32{1, 2},
|
|
inShape: []int{1, 1, 2},
|
|
seqs: []int{0, 0},
|
|
pos: []int32{0, 1},
|
|
expected: []float32{1, 2},
|
|
expectedShape: []int{1, 1, 2},
|
|
expectedMask: []float32{
|
|
0, x,
|
|
0, 0,
|
|
},
|
|
},
|
|
{
|
|
name: "Second seq 0",
|
|
in: []float32{3, 4},
|
|
inShape: []int{1, 1, 2},
|
|
seqs: []int{0, 0},
|
|
pos: []int32{2, 3},
|
|
expected: []float32{2, 3, 4},
|
|
expectedShape: []int{1, 1, 3},
|
|
expectedMask: []float32{
|
|
0, 0, x,
|
|
x, 0, 0,
|
|
},
|
|
},
|
|
{
|
|
name: "First seq 1",
|
|
in: []float32{5, 6},
|
|
inShape: []int{1, 1, 2},
|
|
seqs: []int{1, 1},
|
|
pos: []int32{0, 1},
|
|
expected: []float32{5, 6},
|
|
expectedShape: []int{1, 1, 2},
|
|
expectedMask: []float32{
|
|
0, x,
|
|
0, 0,
|
|
},
|
|
},
|
|
{
|
|
name: "Second seq 1",
|
|
in: []float32{7, 8},
|
|
inShape: []int{1, 1, 2},
|
|
seqs: []int{1, 1},
|
|
pos: []int32{2, 3},
|
|
expected: []float32{6, 3, 4, 7, 8},
|
|
expectedShape: []int{1, 1, 5},
|
|
expectedMask: []float32{
|
|
0, x, x, 0, x,
|
|
x, x, x, 0, 0,
|
|
},
|
|
},
|
|
{
|
|
name: "Third seq 0",
|
|
in: []float32{9, 10},
|
|
inShape: []int{1, 1, 2},
|
|
seqs: []int{0, 0},
|
|
pos: []int32{4, 5},
|
|
expected: []float32{9, 10, 3, 4},
|
|
expectedShape: []int{1, 1, 4},
|
|
expectedMask: []float32{
|
|
0, x, x, 0,
|
|
0, 0, x, x,
|
|
},
|
|
},
|
|
}
|
|
|
|
testCache(t, backend, cache, tests)
|
|
})
|
|
}
|
|
|
|
func TestSWAMem(t *testing.T) {
|
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
|
cache := NewSWAMemCache(1, 3, nil)
|
|
defer cache.Close()
|
|
|
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
|
|
|
x := float32(math.Inf(-1))
|
|
|
|
tests := []testCase{
|
|
{
|
|
name: "FirstBatch",
|
|
in: []float32{1, 2, 3, 4},
|
|
inShape: []int{1, 1, 4},
|
|
seqs: []int{0, 0, 0, 0},
|
|
pos: []int32{0, 1, 2, 3},
|
|
expected: []float32{1, 2, 3, 4},
|
|
expectedShape: []int{1, 1, 4},
|
|
expectedMask: []float32{
|
|
0, x, x, x,
|
|
0, 0, x, x,
|
|
x, 0, 0, x,
|
|
x, x, 0, 0,
|
|
},
|
|
},
|
|
{
|
|
name: "SecondBatch",
|
|
in: []float32{5, 6},
|
|
inShape: []int{1, 1, 2},
|
|
seqs: []int{0, 0},
|
|
pos: []int32{4, 5},
|
|
expected: []float32{5, 2, 3, 4, 6},
|
|
expectedShape: []int{1, 1, 5},
|
|
expectedMask: []float32{
|
|
0, x, x, 0, x,
|
|
0, x, x, x, 0,
|
|
},
|
|
},
|
|
}
|
|
|
|
testCache(t, backend, cache, tests)
|
|
})
|
|
}
|
|
|
|
func TestChunkedAttention(t *testing.T) {
|
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
|
cache := NewChunkedAttentionCache(2, nil)
|
|
defer cache.Close()
|
|
|
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
|
|
|
x := float32(math.Inf(-1))
|
|
|
|
testCache(
|
|
t, backend, cache,
|
|
[]testCase{
|
|
{
|
|
name: "FirstBatch",
|
|
in: []float32{1, 2, 3, 4},
|
|
inShape: []int{1, 1, 4},
|
|
seqs: []int{0, 0, 0, 0},
|
|
pos: []int32{0, 1, 2, 3},
|
|
expected: []float32{1, 2, 3, 4},
|
|
expectedShape: []int{1, 1, 4},
|
|
expectedMask: []float32{
|
|
0, x, x, x,
|
|
0, 0, x, x,
|
|
x, x, 0, x,
|
|
x, x, 0, 0,
|
|
},
|
|
},
|
|
{
|
|
name: "SecondBatch",
|
|
in: []float32{5, 6, 7},
|
|
inShape: []int{1, 1, 3},
|
|
seqs: []int{0, 0, 0},
|
|
pos: []int32{4, 5, 6},
|
|
expected: []float32{1, 2, 3, 4, 5, 6, 7},
|
|
expectedShape: []int{1, 1, 7},
|
|
expectedMask: []float32{
|
|
x, x, x, x, 0, x, x,
|
|
x, x, x, x, 0, 0, x,
|
|
x, x, x, x, x, x, 0,
|
|
},
|
|
},
|
|
{
|
|
name: "ThirdBatch",
|
|
in: []float32{8, 9},
|
|
inShape: []int{1, 1, 2},
|
|
seqs: []int{0, 0},
|
|
pos: []int32{7, 8},
|
|
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
|
|
expectedShape: []int{1, 1, 9},
|
|
expectedMask: []float32{
|
|
x, x, x, x, x, x, 0, 0, x,
|
|
x, x, x, x, x, x, x, x, 0,
|
|
},
|
|
},
|
|
},
|
|
)
|
|
})
|
|
}
|
|
|
|
func TestSequences(t *testing.T) {
|
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
|
cache := NewCausalCache(nil)
|
|
defer cache.Close()
|
|
|
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
|
|
|
tests := []testCase{
|
|
{
|
|
name: "FirstBatch",
|
|
in: []float32{1, 2, 3, 4},
|
|
inShape: []int{1, 1, 4},
|
|
seqs: []int{0, 0, 1, 1},
|
|
pos: []int32{0, 1, 0, 1},
|
|
expected: []float32{1, 2, 3, 4},
|
|
expectedShape: []int{1, 1, 4},
|
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
|
},
|
|
{
|
|
name: "SecondBatch",
|
|
in: []float32{5, 6},
|
|
inShape: []int{1, 1, 2},
|
|
seqs: []int{0, 1},
|
|
pos: []int32{2, 2},
|
|
expected: []float32{1, 2, 3, 4, 5, 6},
|
|
expectedShape: []int{1, 1, 6},
|
|
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
|
|
},
|
|
}
|
|
|
|
testCache(t, backend, cache, tests)
|
|
})
|
|
}
|
|
|
|
func TestRemove(t *testing.T) {
|
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
|
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
|
return key.Add(ctx, shift), nil
|
|
})
|
|
defer cache.Close()
|
|
|
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
|
|
|
x := float32(math.Inf(-1))
|
|
|
|
tests := []testCase{
|
|
{
|
|
name: "FirstBatch",
|
|
in: []float32{1, 2, 3, 4},
|
|
inShape: []int{1, 1, 4},
|
|
seqs: []int{0, 0, 1, 1},
|
|
pos: []int32{0, 1, 0, 1},
|
|
expected: []float32{1, 2, 3, 4},
|
|
expectedShape: []int{1, 1, 4},
|
|
expectedMask: []float32{
|
|
0, x, x, x,
|
|
0, 0, x, x,
|
|
x, x, 0, x,
|
|
x, x, 0, 0,
|
|
},
|
|
},
|
|
}
|
|
|
|
testCache(t, backend, cache, tests)
|
|
|
|
err := cache.Remove(0, 1, math.MaxInt32)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
tests = []testCase{
|
|
{
|
|
name: "RemoveEnd",
|
|
in: []float32{5, 6},
|
|
inShape: []int{1, 1, 2},
|
|
seqs: []int{0, 1},
|
|
pos: []int32{1, 2},
|
|
expected: []float32{1, 5, 3, 4, 6},
|
|
expectedShape: []int{1, 1, 5},
|
|
expectedMask: []float32{
|
|
0, 0, x, x, x,
|
|
x, x, 0, 0, 0,
|
|
},
|
|
},
|
|
}
|
|
|
|
testCache(t, backend, cache, tests)
|
|
|
|
err = cache.Remove(0, 0, 1)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
tests = []testCase{
|
|
{
|
|
name: "RemoveMiddle",
|
|
in: []float32{7, 8},
|
|
inShape: []int{1, 1, 2},
|
|
seqs: []int{0, 0},
|
|
pos: []int32{1, 2},
|
|
expected: []float32{7, 4, 3, 4, 6, 8},
|
|
expectedShape: []int{1, 1, 6},
|
|
expectedMask: []float32{
|
|
0, 0, x, x, x, x,
|
|
0, 0, x, x, x, 0,
|
|
},
|
|
},
|
|
}
|
|
|
|
testCache(t, backend, cache, tests)
|
|
})
|
|
}
|
|
|
|
func TestCopy(t *testing.T) {
|
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
|
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
|
defer cache.Close()
|
|
|
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
|
|
|
tests := []testCase{
|
|
{
|
|
name: "FirstBatch",
|
|
in: []float32{1, 2, 3, 4},
|
|
inShape: []int{1, 1, 4},
|
|
seqs: []int{0, 0, 0, 0},
|
|
pos: []int32{0, 1, 2, 3},
|
|
expected: []float32{1, 2, 3, 4},
|
|
expectedShape: []int{1, 1, 4},
|
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
|
},
|
|
}
|
|
|
|
testCache(t, backend, cache, tests)
|
|
|
|
cache.CopyPrefix(0, 1, 2)
|
|
|
|
tests = []testCase{
|
|
{
|
|
name: "Copy",
|
|
in: []float32{5, 6},
|
|
inShape: []int{1, 1, 2},
|
|
seqs: []int{1, 1},
|
|
pos: []int32{3, 4},
|
|
expected: []float32{1, 2, 3, 4, 5, 6},
|
|
expectedShape: []int{1, 1, 6},
|
|
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
|
},
|
|
}
|
|
|
|
testCache(t, backend, cache, tests)
|
|
})
|
|
}
|
|
|
|
func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
context := backend.NewContext()
|
|
defer context.Close()
|
|
|
|
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
cache.SetLayer(0)
|
|
tensor := context.FromFloats(test.in, test.inShape...)
|
|
cache.Put(context, tensor, tensor)
|
|
|
|
out, _, mask := cache.Get(context)
|
|
|
|
context.Forward(out, mask).Compute(out, mask)
|
|
|
|
if !slices.Equal(out.Floats(), test.expected) {
|
|
t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
|
|
}
|
|
|
|
if !slices.Equal(out.Shape(), test.expectedShape) {
|
|
t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
|
|
}
|
|
|
|
if !slices.Equal(mask.Floats(), test.expectedMask) {
|
|
t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCanResume(t *testing.T) {
|
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
|
windowSize := int32(4)
|
|
cache := NewSWACache(windowSize, nil)
|
|
defer cache.Close()
|
|
|
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
|
|
|
context := backend.NewContext()
|
|
defer context.Close()
|
|
|
|
err := cache.StartForward(context, input.Batch{
|
|
Positions: []int32{0, 1, 2, 3, 4},
|
|
Sequences: []int{0, 0, 0, 0, 0},
|
|
}, false)
|
|
if err != nil {
|
|
t.Fatalf("StartForward failed: %v", err)
|
|
}
|
|
|
|
cache.SetLayer(0)
|
|
tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
|
|
cache.Put(context, tensor, tensor)
|
|
|
|
// with window size 4, nothing has slid out of the window yet
|
|
if !cache.CanResume(0, 0) {
|
|
t.Errorf("CanResume(0, 0) = false, want true (within window)")
|
|
}
|
|
if !cache.CanResume(0, 1) {
|
|
t.Errorf("CanResume(0, 1) = false, want true (within window)")
|
|
}
|
|
if !cache.CanResume(0, 2) {
|
|
t.Errorf("CanResume(0, 2) = false, want true (within window)")
|
|
}
|
|
if !cache.CanResume(0, 3) {
|
|
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
|
|
}
|
|
if !cache.CanResume(0, 4) {
|
|
t.Errorf("CanResume(0, 4) = false, want true (latest position)")
|
|
}
|
|
|
|
// shift window by adding position 5
|
|
err = cache.StartForward(context, input.Batch{
|
|
Positions: []int32{5},
|
|
Sequences: []int{0},
|
|
}, false)
|
|
if err != nil {
|
|
t.Fatalf("StartForward failed: %v", err)
|
|
}
|
|
|
|
cache.SetLayer(0)
|
|
tensor = context.FromFloats([]float32{6}, 1, 1, 1)
|
|
cache.Put(context, tensor, tensor)
|
|
|
|
// only the latest position has overlapping windows
|
|
if cache.CanResume(0, 0) {
|
|
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
|
}
|
|
if cache.CanResume(0, 1) {
|
|
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
|
}
|
|
if cache.CanResume(0, 2) {
|
|
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
|
}
|
|
if cache.CanResume(0, 3) {
|
|
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
|
}
|
|
if cache.CanResume(0, 4) {
|
|
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
|
}
|
|
if !cache.CanResume(0, 5) {
|
|
t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestCanResumeSWAMem(t *testing.T) {
|
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
|
windowSize := int32(4)
|
|
memSize := int32(5)
|
|
cache := NewSWAMemCache(windowSize, memSize, nil)
|
|
defer cache.Close()
|
|
|
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
|
|
|
context := backend.NewContext()
|
|
defer context.Close()
|
|
|
|
err := cache.StartForward(context, input.Batch{
|
|
Positions: []int32{0, 1, 2, 3, 4, 5, 6},
|
|
Sequences: []int{0, 0, 0, 0, 0, 0, 0},
|
|
}, false)
|
|
if err != nil {
|
|
t.Fatalf("StartForward failed: %v", err)
|
|
}
|
|
|
|
cache.SetLayer(0)
|
|
tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
|
|
cache.Put(context, tensor, tensor)
|
|
|
|
// shift window by adding position 7
|
|
err = cache.StartForward(context, input.Batch{
|
|
Positions: []int32{7},
|
|
Sequences: []int{0},
|
|
}, false)
|
|
if err != nil {
|
|
t.Fatalf("StartForward failed: %v", err)
|
|
}
|
|
|
|
cache.SetLayer(0)
|
|
tensor = context.FromFloats([]float32{8}, 1, 1, 1)
|
|
cache.Put(context, tensor, tensor)
|
|
|
|
// only the latest position has overlapping windows
|
|
if cache.CanResume(0, 0) {
|
|
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
|
}
|
|
if cache.CanResume(0, 1) {
|
|
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
|
}
|
|
if cache.CanResume(0, 2) {
|
|
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
|
}
|
|
if cache.CanResume(0, 3) {
|
|
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
|
}
|
|
if cache.CanResume(0, 4) {
|
|
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
|
}
|
|
if cache.CanResume(0, 5) {
|
|
t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
|
|
}
|
|
if !cache.CanResume(0, 6) {
|
|
t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
|
|
}
|
|
if !cache.CanResume(0, 7) {
|
|
t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
|
|
}
|
|
})
|
|
}
|
|
|
|
type testBackend struct {
|
|
ml.Backend
|
|
permutedV bool
|
|
}
|
|
|
|
func (b *testBackend) NewContext() ml.Context {
|
|
return &testContext{}
|
|
}
|
|
|
|
func (b *testBackend) NewContextSize(int) ml.Context {
|
|
return &testContext{}
|
|
}
|
|
|
|
func (b *testBackend) CacheConfig() ml.CacheConfig {
|
|
return ml.CacheConfig{PermutedV: b.permutedV}
|
|
}
|
|
|
|
type testContext struct {
|
|
ml.Context
|
|
}
|
|
|
|
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
|
total := 0
|
|
|
|
if len(shape) > 0 {
|
|
total = 1
|
|
for _, s := range shape {
|
|
total *= s
|
|
}
|
|
}
|
|
|
|
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
|
}
|
|
|
|
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
|
return c.Empty(dtype, shape...)
|
|
}
|
|
|
|
func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
|
|
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
|
|
|
|
copy(t.data, s)
|
|
|
|
return t
|
|
}
|
|
|
|
func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
|
|
f := make([]float32, len(s))
|
|
for i := range f {
|
|
f[i] = float32(s[i])
|
|
}
|
|
|
|
out := c.FromFloats(f, shape...)
|
|
out.(*testTensor).dtype = ml.DTypeI32
|
|
|
|
return out
|
|
}
|
|
|
|
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
|
|
s := make([]float32, 0, int((stop-start)/step))
|
|
for i := start; i < stop; i += step {
|
|
s = append(s, i)
|
|
}
|
|
|
|
out := c.FromFloats(s, len(s))
|
|
out.(*testTensor).dtype = dtype
|
|
return out
|
|
}
|
|
|
|
func (c *testContext) Input() ml.Context { return c }
|
|
func (c *testContext) Layer(int) ml.Context { return c }
|
|
|
|
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
|
|
|
func (c *testContext) Compute(...ml.Tensor) {}
|
|
|
|
func (c *testContext) Reserve() {}
|
|
|
|
func (c *testContext) MaxGraphNodes() int {
|
|
return 10
|
|
}
|
|
|
|
func (c *testContext) Close() {}
|
|
|
|
type testTensor struct {
|
|
ml.Tensor
|
|
|
|
dtype ml.DType
|
|
elementSize int
|
|
data []float32
|
|
shape []int
|
|
}
|
|
|
|
func (t *testTensor) Dim(n int) int {
|
|
return t.shape[n]
|
|
}
|
|
|
|
func (t *testTensor) Stride(n int) int {
|
|
stride := t.elementSize
|
|
for i := range n {
|
|
stride *= t.shape[i]
|
|
}
|
|
|
|
return stride
|
|
}
|
|
|
|
func (t *testTensor) Shape() []int {
|
|
return t.shape
|
|
}
|
|
|
|
func (t *testTensor) DType() ml.DType {
|
|
return t.dtype
|
|
}
|
|
|
|
func (t *testTensor) Floats() []float32 {
|
|
out := make([]float32, len(t.data))
|
|
copy(out, t.data)
|
|
return out
|
|
}
|
|
|
|
func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
|
|
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
|
for i := range out.data {
|
|
out.data[i] = -t.data[i]
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
|
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
|
|
|
for i := range out.data {
|
|
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
|
|
}
|
|
|
|
return out
|
|
}
|
|
|
|
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
|
return &testTensor{
|
|
dtype: t.dtype,
|
|
elementSize: t.elementSize,
|
|
data: t.data,
|
|
shape: shape,
|
|
}
|
|
}
|
|
|
|
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
|
offset /= t.elementSize
|
|
|
|
var s []int
|
|
|
|
switch len(shape) {
|
|
case 1:
|
|
s = []int{shape[0]}
|
|
case 3:
|
|
s = []int{shape[0], shape[2]}
|
|
case 5:
|
|
s = []int{shape[0], shape[2], shape[4]}
|
|
default:
|
|
panic("unsupported number of dimensions")
|
|
}
|
|
|
|
context := &testContext{}
|
|
|
|
view := context.Empty(t.dtype, s...).(*testTensor)
|
|
view.data = t.data[offset : offset+len(view.data)]
|
|
|
|
return view
|
|
}
|
|
|
|
// Permute reorders t's axes like ggml_permute: source axis i ends up at
// output position order[i]. The shape is padded to four dimensions with size
// 1. An order shorter than four is padded with identity entries. The result
// is a new, contiguous tensor whose data has been physically rearranged.
// Trailing size-0/size-1 dimensions are trimmed from its shape, but at least
// one dimension is always kept.
// NOTE(review): assumes len(t.data) equals the product of t.shape; a
// mismatch could index out of range in newData.
func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
	if len(t.shape) > 4 || len(order) > 4 {
		panic("permute only supports up to 4 dimensions")
	}

	// The order must match the tensor's rank or be a full 4-axis order.
	if len(order) != len(t.shape) && len(order) != 4 {
		panic("invalid number of dimensions for permute")
	}

	// ggml_permute expects 4 axes, so fill in any missing dimensions.
	orderFull := append(make([]int, 0, 4), order...)
	for len(orderFull) < 4 {
		// Missing entries map an axis to itself.
		orderFull = append(orderFull, len(orderFull))
	}

	// seen detects two source axes mapped to the same destination.
	seen := [4]bool{}

	shape4 := [4]int{1, 1, 1, 1}
	for i := 0; i < len(t.shape) && i < 4; i++ {
		shape4[i] = t.shape[i]
	}

	// Compute the output shape: source axis `axis` lands at orderFull[axis].
	newShape4 := [4]int{1, 1, 1, 1}
	for axis := range 4 {
		dst := orderFull[axis]
		if dst < 0 || dst >= 4 {
			panic("invalid axis for permute")
		}
		if seen[dst] {
			panic("duplicate axis for permute")
		}
		seen[dst] = true
		newShape4[dst] = shape4[axis]
	}

	total := len(t.data)
	newData := make([]float32, total)

	if total > 0 {
		oldDims := shape4
		newDims := newShape4

		// Element strides of contiguous layouts, dimension 0 fastest.
		// NOTE(review): oldStride is computed but never used below; the
		// source coordinates are derived by repeated division instead.
		oldStride := [4]int{1, 1, 1, 1}
		newStride := [4]int{1, 1, 1, 1}
		for i := 1; i < 4; i++ {
			oldStride[i] = oldStride[i-1] * oldDims[i-1]
			newStride[i] = newStride[i-1] * newDims[i-1]
		}

		var coords [4]int
		var newCoords [4]int

		for idx := range total {
			// Decompose the flat source index into per-axis coordinates.
			remainder := idx
			for axis := range 4 {
				dim := oldDims[axis]
				if dim == 0 {
					coords[axis] = 0
					continue
				}
				coords[axis] = remainder % dim
				remainder /= dim
			}

			// Move each coordinate to its permuted axis.
			for axis := range 4 {
				newCoords[orderFull[axis]] = coords[axis]
			}

			// Recompose the flat destination index.
			newIndex := 0
			for axis := range 4 {
				if newDims[axis] == 0 {
					continue
				}
				newIndex += newCoords[axis] * newStride[axis]
			}

			newData[newIndex] = t.data[idx]
		}
	}

	// Trim trailing dimensions of size <= 1, keeping at least one.
	numDims := 4
	for numDims > 1 && newShape4[numDims-1] <= 1 {
		numDims--
	}

	newShape := make([]int, numDims)
	copy(newShape, newShape4[:numDims])

	return &testTensor{
		dtype:       t.dtype,
		elementSize: t.elementSize,
		data:        newData,
		shape:       newShape,
	}
}
|
|
|
|
// SetRows scatters rows of src into t, modifying t in place, and returns t.
// Dimension 1 is the row dimension. For every (dim 2, dim 3) slice, row r of
// src (all dim-0 elements) is copied to row idxs[r] of t. idxs lists the
// destination row numbers along its dimension 0 (stored as float32 like all
// testTensor data). Its dims 1 and 2 are broadcast over src/t dims 2 and 3
// using modulo indexing. Panics on mismatched shapes or out-of-range indices.
func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
	dst := t
	srcTensor := src.(*testTensor)
	idxTensor := idxs.(*testTensor)

	// shapeTo4D pads a shape to four dimensions with trailing 1s.
	shapeTo4D := func(shape []int) [4]int {
		out := [4]int{1, 1, 1, 1}
		for i := 0; i < len(shape) && i < 4; i++ {
			out[i] = shape[i]
		}
		return out
	}

	// computeStrides returns contiguous element strides, dimension 0 fastest.
	computeStrides := func(shape [4]int) [4]int {
		out := [4]int{1, 1, 1, 1}
		for i := 1; i < 4; i++ {
			out[i] = out[i-1] * shape[i-1]
		}
		return out
	}

	dstShape4D := shapeTo4D(dst.shape)
	srcShape4D := shapeTo4D(srcTensor.shape)
	idxShape4D := shapeTo4D(idxTensor.shape)

	// Only the row count (dimension 1) may differ between dst and src.
	if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
		panic("SetRows requires matching tensor shapes")
	}

	// One index per source row.
	if srcShape4D[1] != idxShape4D[0] {
		panic("SetRows rows/index mismatch")
	}

	// Index dims 1 and 2 must evenly divide src dims 2 and 3 for broadcasting.
	// NOTE(review): a zero-sized index dim 1 or 2 panics here with an integer
	// divide by zero rather than with this message.
	if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
		panic("SetRows cannot broadcast indices")
	}

	if idxShape4D[3] != 1 {
		panic("SetRows expects 1D or 2D index tensors")
	}

	dstStride := computeStrides(dstShape4D)
	srcStride := computeStrides(srcShape4D)
	idxStride := computeStrides(idxShape4D)

	numColumns := srcShape4D[0]
	numRows := srcShape4D[1]

	for dim3Index := range dstShape4D[3] {
		for dim2Index := range dstShape4D[2] {
			// Map this (dim2, dim3) slice onto the broadcast index tensor.
			idxDim2 := 0
			idxDim3 := 0
			if idxShape4D[1] > 0 {
				idxDim2 = dim2Index % idxShape4D[1]
			}
			if idxShape4D[2] > 0 {
				idxDim3 = dim3Index % idxShape4D[2]
			}

			idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
			srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
			dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]

			for row := range numRows {
				idx := int(idxTensor.data[idxBase+row*idxStride[0]])
				if idx < 0 || idx >= dstShape4D[1] {
					panic("SetRows index out of range")
				}

				srcOffset := srcBase + row*srcStride[1]
				dstOffset := dstBase + idx*dstStride[1]

				// Copy one full row (all dimension-0 elements).
				copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
			}
		}
	}

	return dst
}
|
|
|
|
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
|
copy(t2.(*testTensor).data, t.data)
|
|
return nil
|
|
}
|