// ollama/kvcache/causal_test.go (974 lines, 24 KiB, Go)

package kvcache
import (
"fmt"
"math"
"slices"
"testing"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// testCase describes one forward step that testCache feeds into a cache,
// together with the state expected back from the cache afterwards.
type testCase struct {
// name is the subtest name passed to t.Run.
name string
// in holds the values used to build the tensor that is passed to cache.Put
// (the same tensor is supplied for both tensor arguments).
in []float32
// inShape is the shape used to build the input tensor from in.
inShape []int
// seqs is passed as input.Batch.Sequences.
seqs []int
// pos is passed as input.Batch.Positions.
pos []int32
// expected is compared with Floats() of the first tensor returned by cache.Get.
expected []float32
// expectedShape is compared with Shape() of the first tensor returned by cache.Get.
expectedShape []int
// expectedMask is compared with Floats() of the mask returned by cache.Get.
expectedMask []float32
}
// runPermutedVariants runs fn twice as subtests named "PermutedV=false" and
// "PermutedV=true", handing each a fresh testBackend whose permutedV flag
// matches the subtest name.
func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
	t.Helper()

	variants := [...]bool{false, true}
	for i := range variants {
		permutedV := variants[i]
		name := fmt.Sprintf("PermutedV=%t", permutedV)

		t.Run(name, func(t *testing.T) {
			backend := &testBackend{permutedV: permutedV}
			fn(t, backend)
		})
	}
}
// TestStore puts two consecutive batches for a single sequence into a cache
// built with NewCausalCache and checks the stored values, their shape and the
// mask returned after each batch.
func TestStore(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		negInf := float32(math.Inf(-1))

		testCache(t, backend, cache, []testCase{
			{
				name: "FirstBatch",
				in: []float32{
					111, 211, 121, 221, 131, 231,
					112, 212, 122, 222, 132, 232,
					113, 213, 123, 223, 133, 233,
					114, 214, 124, 224, 134, 234,
				},
				inShape: []int{2, 3, 4},
				seqs:    []int{0, 0, 0, 0},
				pos:     []int32{0, 1, 2, 3},
				expected: []float32{
					111, 211, 121, 221, 131, 231,
					112, 212, 122, 222, 132, 232,
					113, 213, 123, 223, 133, 233,
					114, 214, 124, 224, 134, 234,
				},
				expectedShape: []int{2, 3, 4},
				expectedMask: []float32{
					0, negInf, negInf, negInf,
					0, 0, negInf, negInf,
					0, 0, 0, negInf,
					0, 0, 0, 0,
				},
			},
			{
				name:    "SecondBatch",
				in:      []float32{115, 215, 125, 225, 135, 235},
				inShape: []int{2, 3, 1},
				seqs:    []int{0},
				pos:     []int32{4},
				expected: []float32{
					111, 211, 121, 221, 131, 231,
					112, 212, 122, 222, 132, 232,
					113, 213, 123, 223, 133, 233,
					114, 214, 124, 224, 134, 234,
					115, 215, 125, 225, 135, 235,
				},
				expectedShape: []int{2, 3, 5},
				expectedMask:  []float32{0, 0, 0, 0, 0},
			},
		})
	})
}
// TestSWA feeds two batches of one sequence into a cache built with
// NewSWACache(1, nil) and checks the values, shape and mask returned after
// each batch.
func TestSWA(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewSWACache(1, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		negInf := float32(math.Inf(-1))

		testCache(t, backend, cache, []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, negInf, negInf, negInf,
					0, 0, negInf, negInf,
					negInf, 0, 0, negInf,
					negInf, negInf, 0, 0,
				},
			},
			{
				name:          "SecondBatch",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{4, 5},
				expected:      []float32{5, 6, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, negInf, negInf, 0,
					0, 0, negInf, negInf,
				},
			},
		})
	})
}
// TestSWASeparateBatches interleaves batches belonging to two sequences in a
// cache built with NewSWACache(1, nil) and checks the values, shape and mask
// returned after every batch.
func TestSWASeparateBatches(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewSWACache(1, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 2, 16, 2)

		negInf := float32(math.Inf(-1))

		testCache(t, backend, cache, []testCase{
			{
				name:          "First seq 0",
				in:            []float32{1, 2},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{0, 1},
				expected:      []float32{1, 2},
				expectedShape: []int{1, 1, 2},
				expectedMask: []float32{
					0, negInf,
					0, 0,
				},
			},
			{
				name:          "Second seq 0",
				in:            []float32{3, 4},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{2, 3},
				expected:      []float32{2, 3, 4},
				expectedShape: []int{1, 1, 3},
				expectedMask: []float32{
					0, 0, negInf,
					negInf, 0, 0,
				},
			},
			{
				name:          "First seq 1",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{1, 1},
				pos:           []int32{0, 1},
				expected:      []float32{5, 6},
				expectedShape: []int{1, 1, 2},
				expectedMask: []float32{
					0, negInf,
					0, 0,
				},
			},
			{
				name:          "Second seq 1",
				in:            []float32{7, 8},
				inShape:       []int{1, 1, 2},
				seqs:          []int{1, 1},
				pos:           []int32{2, 3},
				expected:      []float32{6, 3, 4, 7, 8},
				expectedShape: []int{1, 1, 5},
				expectedMask: []float32{
					0, negInf, negInf, 0, negInf,
					negInf, negInf, negInf, 0, 0,
				},
			},
			{
				name:          "Third seq 0",
				in:            []float32{9, 10},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{4, 5},
				expected:      []float32{9, 10, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, negInf, negInf, 0,
					0, 0, negInf, negInf,
				},
			},
		})
	})
}
// TestSWAMem feeds two batches of one sequence into a cache built with
// NewSWAMemCache(1, 3, nil) and checks the values, shape and mask returned
// after each batch.
func TestSWAMem(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewSWAMemCache(1, 3, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		negInf := float32(math.Inf(-1))

		testCache(t, backend, cache, []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, negInf, negInf, negInf,
					0, 0, negInf, negInf,
					negInf, 0, 0, negInf,
					negInf, negInf, 0, 0,
				},
			},
			{
				name:          "SecondBatch",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{4, 5},
				expected:      []float32{5, 2, 3, 4, 6},
				expectedShape: []int{1, 1, 5},
				expectedMask: []float32{
					0, negInf, negInf, 0, negInf,
					0, negInf, negInf, negInf, 0,
				},
			},
		})
	})
}
// TestChunkedAttention feeds three batches of one sequence into a cache built
// with NewChunkedAttentionCache(2, nil) and checks the values, shape and mask
// returned after each batch.
func TestChunkedAttention(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewChunkedAttentionCache(2, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		negInf := float32(math.Inf(-1))

		steps := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, negInf, negInf, negInf,
					0, 0, negInf, negInf,
					negInf, negInf, 0, negInf,
					negInf, negInf, 0, 0,
				},
			},
			{
				name:          "SecondBatch",
				in:            []float32{5, 6, 7},
				inShape:       []int{1, 1, 3},
				seqs:          []int{0, 0, 0},
				pos:           []int32{4, 5, 6},
				expected:      []float32{1, 2, 3, 4, 5, 6, 7},
				expectedShape: []int{1, 1, 7},
				expectedMask: []float32{
					negInf, negInf, negInf, negInf, 0, negInf, negInf,
					negInf, negInf, negInf, negInf, 0, 0, negInf,
					negInf, negInf, negInf, negInf, negInf, negInf, 0,
				},
			},
			{
				name:          "ThirdBatch",
				in:            []float32{8, 9},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{7, 8},
				expected:      []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
				expectedShape: []int{1, 1, 9},
				expectedMask: []float32{
					negInf, negInf, negInf, negInf, negInf, negInf, 0, 0, negInf,
					negInf, negInf, negInf, negInf, negInf, negInf, negInf, negInf, 0,
				},
			},
		}

		testCache(t, backend, cache, steps)
	})
}
// TestSequences stores tokens of two sequences in the same batches of a cache
// built with NewCausalCache and checks that the returned mask only exposes
// each token to cells of its own sequence at positions up to its own.
func TestSequences(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		negInf := float32(math.Inf(-1))

		testCache(t, backend, cache, []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 1, 1},
				pos:           []int32{0, 1, 0, 1},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, negInf, negInf, negInf,
					0, 0, negInf, negInf,
					negInf, negInf, 0, negInf,
					negInf, negInf, 0, 0,
				},
			},
			{
				name:          "SecondBatch",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 1},
				pos:           []int32{2, 2},
				expected:      []float32{1, 2, 3, 4, 5, 6},
				expectedShape: []int{1, 1, 6},
				expectedMask: []float32{
					0, 0, negInf, negInf, 0, negInf,
					negInf, negInf, 0, 0, negInf, 0,
				},
			},
		})
	})
}
// TestRemove checks the cache contents and mask after removing position
// ranges from sequence 0 of a cache built with NewCausalCache. The shift
// callback given to the cache adds the shift tensor to the key tensor.
//
// The test runs in three stages: fill two sequences, remove the tail of
// sequence 0 and append, then remove the head of sequence 0 and append again.
func TestRemove(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
			return key.Add(ctx, shift), nil
		})
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 1, 1},
				pos:           []int32{0, 1, 0, 1},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					x, x, 0, x,
					x, x, 0, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)

		// Report failures through the testing framework rather than panicking,
		// so a Remove error fails this subtest without aborting the test binary.
		if err := cache.Remove(0, 1, math.MaxInt32); err != nil {
			t.Fatalf("Remove failed: %v", err)
		}

		tests = []testCase{
			{
				name:          "RemoveEnd",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 1},
				pos:           []int32{1, 2},
				expected:      []float32{1, 5, 3, 4, 6},
				expectedShape: []int{1, 1, 5},
				expectedMask: []float32{
					0, 0, x, x, x,
					x, x, 0, 0, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)

		if err := cache.Remove(0, 0, 1); err != nil {
			t.Fatalf("Remove failed: %v", err)
		}

		tests = []testCase{
			{
				name:          "RemoveMiddle",
				in:            []float32{7, 8},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{1, 2},
				expected:      []float32{7, 4, 3, 4, 6, 8},
				expectedShape: []int{1, 1, 6},
				expectedMask: []float32{
					0, 0, x, x, x, x,
					0, 0, x, x, x, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}
// TestCopy fills sequence 0 of a cache built with NewCausalCache, calls
// CopyPrefix(0, 1, 2), then appends to sequence 1 and checks the resulting
// values, shape and mask. The shift callback returns the key unchanged.
func TestCopy(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		identityShift := func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
			return key, nil
		}

		cache := NewCausalCache(identityShift)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		negInf := float32(math.Inf(-1))

		testCache(t, backend, cache, []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, negInf, negInf, negInf,
					0, 0, negInf, negInf,
					0, 0, 0, negInf,
					0, 0, 0, 0,
				},
			},
		})

		cache.CopyPrefix(0, 1, 2)

		testCache(t, backend, cache, []testCase{
			{
				name:          "Copy",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{1, 1},
				pos:           []int32{3, 4},
				expected:      []float32{1, 2, 3, 4, 5, 6},
				expectedShape: []int{1, 1, 6},
				expectedMask: []float32{
					0, 0, negInf, negInf, 0, negInf,
					0, 0, negInf, negInf, 0, 0,
				},
			},
		})
	})
}
// testCache runs every entry of tests, in order, as a subtest against cache.
//
// For each case it starts a forward pass with the case's positions and
// sequences, selects layer 0, stores the input tensor (passed for both tensor
// arguments of Put), and then compares the first tensor and the mask returned
// by Get with the expected values, shape and mask. The cache is shared between
// cases, so later cases observe what earlier ones stored.
func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			context := backend.NewContext()
			defer context.Close()

			batch := input.Batch{Positions: test.pos, Sequences: test.seqs}
			if err := cache.StartForward(context, batch, false); err != nil {
				// Fail only this subtest; a panic would abort the whole test
				// binary and hide the results of every other test.
				t.Fatalf("StartForward failed: %v", err)
			}

			cache.SetLayer(0)
			tensor := context.FromFloats(test.in, test.inShape...)
			cache.Put(context, tensor, tensor)

			out, _, mask := cache.Get(context)
			context.Forward(out, mask).Compute(out, mask)

			if got := out.Floats(); !slices.Equal(got, test.expected) {
				t.Errorf("TestCache: have %v; want %v", got, test.expected)
			}

			if got := out.Shape(); !slices.Equal(got, test.expectedShape) {
				t.Errorf("TestCache: has shape %v; want %v", got, test.expectedShape)
			}

			if got := mask.Floats(); !slices.Equal(got, test.expectedMask) {
				t.Errorf("TestCache: have mask: have %v want %v", got, test.expectedMask)
			}
		})
	}
}
// TestCanResume checks CanResume on a cache built with NewSWACache and a
// window size of 4. After storing positions 0-4 every stored position must be
// resumable; after additionally storing position 5 only position 5 may be.
func TestCanResume(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		windowSize := int32(4)
		cache := NewSWACache(windowSize, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		context := backend.NewContext()
		defer context.Close()

		err := cache.StartForward(context, input.Batch{
			Positions: []int32{0, 1, 2, 3, 4},
			Sequences: []int{0, 0, 0, 0, 0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
		cache.Put(context, tensor, tensor)

		// with window size 4, nothing has slid out of the window yet
		if !cache.CanResume(0, 0) {
			t.Errorf("CanResume(0, 0) = false, want true (within window)")
		}
		if !cache.CanResume(0, 1) {
			t.Errorf("CanResume(0, 1) = false, want true (within window)")
		}
		if !cache.CanResume(0, 2) {
			t.Errorf("CanResume(0, 2) = false, want true (within window)")
		}
		// Position 3 is not the latest position in this batch (4 is), so it is
		// labelled like the other in-window positions.
		if !cache.CanResume(0, 3) {
			t.Errorf("CanResume(0, 3) = false, want true (within window)")
		}
		if !cache.CanResume(0, 4) {
			t.Errorf("CanResume(0, 4) = false, want true (latest position)")
		}

		// shift window by adding position 5
		err = cache.StartForward(context, input.Batch{
			Positions: []int32{5},
			Sequences: []int{0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor = context.FromFloats([]float32{6}, 1, 1, 1)
		cache.Put(context, tensor, tensor)

		// only the latest position has overlapping windows
		if cache.CanResume(0, 0) {
			t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
		}
		if cache.CanResume(0, 1) {
			t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
		}
		if cache.CanResume(0, 2) {
			t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
		}
		if cache.CanResume(0, 3) {
			t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
		}
		if cache.CanResume(0, 4) {
			t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
		}
		if !cache.CanResume(0, 5) {
			t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
		}
	})
}
// TestCanResumeSWAMem checks CanResume on a cache built with NewSWAMemCache
// using a window size of 4 and a memory size of 5. Positions 0-6 are stored
// first, then position 7; afterwards positions 0-5 must not be resumable
// while positions 6 and 7 must be.
func TestCanResumeSWAMem(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		window, mem := int32(4), int32(5)

		cache := NewSWAMemCache(window, mem, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		ctx := backend.NewContext()
		defer ctx.Close()

		initial := input.Batch{
			Positions: []int32{0, 1, 2, 3, 4, 5, 6},
			Sequences: []int{0, 0, 0, 0, 0, 0, 0},
		}
		if err := cache.StartForward(ctx, initial, false); err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		kv := ctx.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
		cache.Put(ctx, kv, kv)

		// Shift the window by adding position 7.
		next := input.Batch{
			Positions: []int32{7},
			Sequences: []int{0},
		}
		if err := cache.StartForward(ctx, next, false); err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		kv = ctx.FromFloats([]float32{8}, 1, 1, 1)
		cache.Put(ctx, kv, kv)

		// Positions 0-5 are expected to be rejected after the shift.
		if cache.CanResume(0, 0) {
			t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
		}

		if cache.CanResume(0, 1) {
			t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
		}

		if cache.CanResume(0, 2) {
			t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
		}

		if cache.CanResume(0, 3) {
			t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
		}

		if cache.CanResume(0, 4) {
			t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
		}

		if cache.CanResume(0, 5) {
			t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
		}

		// Positions 6 and 7 are expected to remain resumable.
		if !cache.CanResume(0, 6) {
			t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
		}

		if !cache.CanResume(0, 7) {
			t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
		}
	})
}
// testBackend is a minimal ml.Backend used by the cache tests. It embeds the
// ml.Backend interface without assigning it, so only the methods defined
// below are safe to call.
type testBackend struct {
	ml.Backend

	// permutedV is reported to the cache through CacheConfig.
	permutedV bool
}

// NewContext returns a new, empty testContext.
func (b *testBackend) NewContext() ml.Context {
	return new(testContext)
}

// NewContextSize returns a new, empty testContext; the size argument is
// ignored.
func (b *testBackend) NewContextSize(_ int) ml.Context {
	return new(testContext)
}

// CacheConfig returns a cache configuration whose PermutedV field mirrors the
// backend's permutedV flag; all other fields keep their zero values.
func (b *testBackend) CacheConfig() ml.CacheConfig {
	var cfg ml.CacheConfig
	cfg.PermutedV = b.permutedV

	return cfg
}
// testContext is a minimal ml.Context used by the cache tests. It embeds the
// ml.Context interface without assigning it, so only the methods defined on
// *testContext in this file are safe to call. It carries no state of its own.
type testContext struct {
ml.Context
}
// Empty allocates a testTensor with the given dtype and shape. Its backing
// slice holds one float32 per element (the product of all dimensions, or no
// elements when no shape is given), and the element size is always recorded
// as 4 regardless of dtype.
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
	var count int
	if len(shape) != 0 {
		count = 1
		for i := range shape {
			count *= shape[i]
		}
	}

	return &testTensor{
		dtype:       dtype,
		elementSize: 4,
		data:        make([]float32, count),
		shape:       shape,
	}
}
// Zeros returns a tensor with the given dtype and shape by delegating to
// Empty.
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
	zeroed := c.Empty(dtype, shape...)

	return zeroed
}
// FromFloats returns a new DTypeF32 tensor of the given shape whose data is
// copied from s. The copy is bounded by the shorter of s and the tensor's
// element count.
func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
	out := c.Empty(ml.DTypeF32, shape...)
	copy(out.(*testTensor).data, s)

	return out
}
// FromInts returns a new tensor of the given shape holding the values of s
// converted to float32, with its dtype set to DTypeI32.
func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
	asFloats := make([]float32, 0, len(s))
	for _, v := range s {
		asFloats = append(asFloats, float32(v))
	}

	tensor := c.FromFloats(asFloats, shape...).(*testTensor)
	tensor.dtype = ml.DTypeI32

	return tensor
}
// Arange returns a one-dimensional tensor containing start, start+step, ...
// for as long as the value stays below stop, with its dtype set to dtype.
// The values are accumulated by repeated float32 addition.
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
	// Capacity hint only; append grows the slice if the hint is too small.
	capHint := int((stop - start) / step)
	values := make([]float32, 0, capHint)

	for v := start; v < stop; v += step {
		values = append(values, v)
	}

	tensor := c.FromFloats(values, len(values)).(*testTensor)
	tensor.dtype = dtype

	return tensor
}
// Input returns the receiver.
func (c *testContext) Input() ml.Context {
	return c
}

// Layer returns the receiver; the layer index is ignored.
func (c *testContext) Layer(_ int) ml.Context {
	return c
}

// Forward does nothing with the supplied tensors and returns the receiver.
func (c *testContext) Forward(_ ...ml.Tensor) ml.Context {
	return c
}

// Compute is a no-op.
func (c *testContext) Compute(_ ...ml.Tensor) {
}

// Reserve is a no-op.
func (c *testContext) Reserve() {
}

// MaxGraphNodes always reports 10.
func (c *testContext) MaxGraphNodes() int {
	const maxNodes = 10

	return maxNodes
}

// Close is a no-op.
func (c *testContext) Close() {
}
// testTensor is a minimal ml.Tensor used by the cache tests. It embeds the
// ml.Tensor interface without assigning it, so only the methods defined on
// *testTensor in this file are safe to call.
type testTensor struct {
ml.Tensor
// dtype is the value reported by DType.
dtype ml.DType
// elementSize is the per-element size used by Stride and View to convert
// between offsets/strides and indices into data.
elementSize int
// data holds the tensor elements as a flat slice; every element is stored as
// a float32 whatever dtype says.
data []float32
// shape lists the dimension sizes as reported by Shape and Dim.
shape []int
}
// Dim returns the size of dimension n.
func (t *testTensor) Dim(n int) int {
	size := t.shape[n]

	return size
}

// Stride returns elementSize multiplied by the sizes of all dimensions below
// n.
func (t *testTensor) Stride(n int) int {
	byteStride := t.elementSize
	for i := 0; i < n; i++ {
		byteStride *= t.shape[i]
	}

	return byteStride
}

// Shape returns the tensor's shape slice itself, not a copy.
func (t *testTensor) Shape() []int {
	dims := t.shape

	return dims
}

// DType returns the tensor's recorded dtype.
func (t *testTensor) DType() ml.DType {
	kind := t.dtype

	return kind
}
// Floats returns a copy of the tensor's data, so callers cannot modify the
// tensor through the returned slice.
func (t *testTensor) Floats() []float32 {
	snapshot := make([]float32, 0, len(t.data))
	snapshot = append(snapshot, t.data...)

	return snapshot
}
// Neg returns a new tensor, allocated through ctx with the same dtype and
// shape, whose elements are the negation of the receiver's elements.
func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
	negated := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
	for i := 0; i < len(negated.data); i++ {
		negated.data[i] = -t.data[i]
	}

	return negated
}
// Add returns a new tensor, allocated through ctx with the receiver's dtype
// and shape, holding the element-wise sum of the receiver and t2. t2 must be
// a *testTensor; no broadcasting is performed.
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	sum := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
	for i := 0; i < len(sum.data); i++ {
		rhs := t2.(*testTensor)
		sum.data[i] = t.data[i] + rhs.data[i]
	}

	return sum
}
// Reshape returns a new tensor with the given shape that shares the
// receiver's data slice. The shape is not checked against the data length.
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
	reshaped := new(testTensor)
	reshaped.dtype = t.dtype
	reshaped.elementSize = t.elementSize
	reshaped.data = t.data
	reshaped.shape = shape

	return reshaped
}
// View returns a tensor whose data is a sub-slice of the receiver's data,
// starting at offset (divided by elementSize to obtain an element index).
//
// shape must contain 1, 3 or 5 entries. Only the entries at even indexes are
// used as dimension sizes; the entries between them are ignored. Any other
// length panics. The view aliases the receiver's storage, so writes through
// either tensor are visible in the other.
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	first := offset / t.elementSize

	var dims []int
	if n := len(shape); n == 1 {
		dims = []int{shape[0]}
	} else if n == 3 {
		dims = []int{shape[0], shape[2]}
	} else if n == 5 {
		dims = []int{shape[0], shape[2], shape[4]}
	} else {
		panic("unsupported number of dimensions")
	}

	// Allocate through a throwaway context to get a tensor of the right size,
	// then point it at the receiver's storage.
	scratch := &testContext{}
	window := scratch.Empty(t.dtype, dims...).(*testTensor)
	count := len(window.data)
	window.data = t.data[first : first+count]

	return window
}
// Permute returns a new tensor whose axes are rearranged according to order,
// where order[i] is the destination axis of source axis i. The data is
// physically rearranged into a freshly allocated slice rather than viewed
// through strides, so the result does not alias the receiver.
//
// The tensor is treated as four-dimensional: missing dimensions have size 1
// and missing order entries map an axis to itself. Trailing dimensions of size
// 1 or less are dropped from the resulting shape, always keeping at least one.
//
// Permute panics if the tensor or order has more than four dimensions, if
// order has neither four entries nor as many entries as the tensor has
// dimensions, or if order contains an out-of-range or duplicate axis.
func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
if len(t.shape) > 4 || len(order) > 4 {
panic("permute only supports up to 4 dimensions")
}
if len(order) != len(t.shape) && len(order) != 4 {
panic("invalid number of dimensions for permute")
}
// ggml_permute expects 4 axes, so fill in any missing dimensions.
orderFull := append(make([]int, 0, 4), order...)
for len(orderFull) < 4 {
orderFull = append(orderFull, len(orderFull))
}
// seen records which destination axes have already been assigned.
seen := [4]bool{}
// shape4 is the source shape padded with 1s to four dimensions.
shape4 := [4]int{1, 1, 1, 1}
for i := 0; i < len(t.shape) && i < 4; i++ {
shape4[i] = t.shape[i]
}
// newShape4 is the permuted shape: source axis i lands on axis orderFull[i].
newShape4 := [4]int{1, 1, 1, 1}
for axis := range 4 {
dst := orderFull[axis]
if dst < 0 || dst >= 4 {
panic("invalid axis for permute")
}
if seen[dst] {
panic("duplicate axis for permute")
}
seen[dst] = true
newShape4[dst] = shape4[axis]
}
total := len(t.data)
newData := make([]float32, total)
if total > 0 {
oldDims := shape4
newDims := newShape4
// Strides are in elements, with axis 0 varying fastest.
oldStride := [4]int{1, 1, 1, 1}
newStride := [4]int{1, 1, 1, 1}
for i := 1; i < 4; i++ {
oldStride[i] = oldStride[i-1] * oldDims[i-1]
newStride[i] = newStride[i-1] * newDims[i-1]
}
var coords [4]int
var newCoords [4]int
for idx := range total {
// Decompose the flat source index into per-axis coordinates.
remainder := idx
for axis := range 4 {
dim := oldDims[axis]
if dim == 0 {
coords[axis] = 0
continue
}
coords[axis] = remainder % dim
remainder /= dim
}
// Move each coordinate to its destination axis.
for axis := range 4 {
newCoords[orderFull[axis]] = coords[axis]
}
// Recompose the flat destination index from the permuted coordinates.
newIndex := 0
for axis := range 4 {
if newDims[axis] == 0 {
continue
}
newIndex += newCoords[axis] * newStride[axis]
}
newData[newIndex] = t.data[idx]
}
}
// Drop trailing dimensions of size <= 1, keeping at least one dimension.
numDims := 4
for numDims > 1 && newShape4[numDims-1] <= 1 {
numDims--
}
newShape := make([]int, numDims)
copy(newShape, newShape4[:numDims])
return &testTensor{
dtype: t.dtype,
elementSize: t.elementSize,
data: newData,
shape: newShape,
}
}
// SetRows copies rows of src into the receiver at the row positions given by
// idxs, modifying the receiver's data in place and returning the receiver.
//
// All three tensors are treated as four-dimensional, with missing dimensions
// of size 1. A row is the run of elements along dimension 0, and rows are
// numbered along dimension 1. For each (dim2, dim3) slice, source row r is
// written to destination row idxs[r]; the index tensor is broadcast over
// dimensions 2 and 3 by taking the slice coordinates modulo the index
// tensor's own dimension 1 and 2 sizes. Index values are read from the index
// tensor's float32 data and converted to int.
//
// src and idxs must be *testTensor values. SetRows panics if the receiver and
// src differ in dimensions 0, 2 or 3, if the number of source rows differs
// from the index tensor's dimension 0, if the indices cannot be broadcast, if
// the index tensor has a fourth dimension other than 1, or if an index falls
// outside the receiver's row range.
func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
dst := t
srcTensor := src.(*testTensor)
idxTensor := idxs.(*testTensor)
// shapeTo4D pads a shape with 1s up to four dimensions.
shapeTo4D := func(shape []int) [4]int {
out := [4]int{1, 1, 1, 1}
for i := 0; i < len(shape) && i < 4; i++ {
out[i] = shape[i]
}
return out
}
// computeStrides returns per-axis strides in elements, axis 0 fastest.
computeStrides := func(shape [4]int) [4]int {
out := [4]int{1, 1, 1, 1}
for i := 1; i < 4; i++ {
out[i] = out[i-1] * shape[i-1]
}
return out
}
dstShape4D := shapeTo4D(dst.shape)
srcShape4D := shapeTo4D(srcTensor.shape)
idxShape4D := shapeTo4D(idxTensor.shape)
// Only dimension 1 (the row count) may differ between destination and source.
if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
panic("SetRows requires matching tensor shapes")
}
if srcShape4D[1] != idxShape4D[0] {
panic("SetRows rows/index mismatch")
}
if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
panic("SetRows cannot broadcast indices")
}
if idxShape4D[3] != 1 {
panic("SetRows expects 1D or 2D index tensors")
}
dstStride := computeStrides(dstShape4D)
srcStride := computeStrides(srcShape4D)
idxStride := computeStrides(idxShape4D)
numColumns := srcShape4D[0]
numRows := srcShape4D[1]
for dim3Index := range dstShape4D[3] {
for dim2Index := range dstShape4D[2] {
// Broadcast the index tensor across dimensions 2 and 3.
idxDim2 := 0
idxDim3 := 0
if idxShape4D[1] > 0 {
idxDim2 = dim2Index % idxShape4D[1]
}
if idxShape4D[2] > 0 {
idxDim3 = dim3Index % idxShape4D[2]
}
idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
for row := range numRows {
// Destination row for this source row.
idx := int(idxTensor.data[idxBase+row*idxStride[0]])
if idx < 0 || idx >= dstShape4D[1] {
panic("SetRows index out of range")
}
srcOffset := srcBase + row*srcStride[1]
dstOffset := dstBase + idx*dstStride[1]
copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
}
}
}
return dst
}
// Copy copies the receiver's data into t2, which must be a *testTensor. The
// copy is bounded by the shorter of the two data slices. It always returns
// nil rather than a result tensor.
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	dst := t2.(*testTensor)
	copy(dst.data, t.data)

	return nil
}