package kvcache

import (
	"fmt"
	"math"
	"slices"
	"testing"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

type testCase struct {
	name          string
	in            []float32
	inShape       []int
	seqs          []int
	pos           []int32
	expected      []float32
	expectedShape []int
	expectedMask  []float32
}

func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
	t.Helper()

	for _, permuted := range []bool{false, true} {
		t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
			fn(t, &testBackend{permutedV: permuted})
		})
	}
}

func TestStore(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
				inShape:       []int{2, 3, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
				expectedShape: []int{2, 3, 4},
				expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
			},
			{
				name:          "SecondBatch",
				in:            []float32{115, 215, 125, 225, 135, 235},
				inShape:       []int{2, 3, 1},
				seqs:          []int{0},
				pos:           []int32{4},
				expected:      []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
				expectedShape: []int{2, 3, 5},
				expectedMask:  []float32{0, 0, 0, 0, 0},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

func TestSWA(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewSWACache(1, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					x, 0, 0, x,
					x, x, 0, 0,
				},
			},
			{
				name:          "SecondBatch",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{4, 5},
				expected:      []float32{5, 6, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, 0,
					0, 0, x, x,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}
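// TestSWASeparateBatches drives two sequences through a shared sliding-window
// cache in separate batches. With windowSize=1 a token attends only to itself
// and the immediately preceding position of its own sequence, so slots whose
// entries have slid out of the window can be reclaimed by either sequence;
// the expected values below assert that slot reuse as well as the masks.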
func TestSWASeparateBatches(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewSWACache(1, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 2, 16, 2)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "First seq 0",
				in:            []float32{1, 2},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{0, 1},
				expected:      []float32{1, 2},
				expectedShape: []int{1, 1, 2},
				expectedMask: []float32{
					0, x,
					0, 0,
				},
			},
			{
				name:          "Second seq 0",
				in:            []float32{3, 4},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{2, 3},
				expected:      []float32{2, 3, 4},
				expectedShape: []int{1, 1, 3},
				expectedMask: []float32{
					0, 0, x,
					x, 0, 0,
				},
			},
			{
				name:          "First seq 1",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{1, 1},
				pos:           []int32{0, 1},
				expected:      []float32{5, 6},
				expectedShape: []int{1, 1, 2},
				expectedMask: []float32{
					0, x,
					0, 0,
				},
			},
			{
				name:          "Second seq 1",
				in:            []float32{7, 8},
				inShape:       []int{1, 1, 2},
				seqs:          []int{1, 1},
				pos:           []int32{2, 3},
				expected:      []float32{6, 3, 4, 7, 8},
				expectedShape: []int{1, 1, 5},
				expectedMask: []float32{
					0, x, x, 0, x,
					x, x, x, 0, 0,
				},
			},
			{
				name:          "Third seq 0",
				in:            []float32{9, 10},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{4, 5},
				expected:      []float32{9, 10, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, 0,
					0, 0, x, x,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

func TestSWAMem(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewSWAMemCache(1, 3, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					x, 0, 0, x,
					x, x, 0, 0,
				},
			},
			{
				name:          "SecondBatch",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{4, 5},
				expected:      []float32{5, 2, 3, 4, 6},
				expectedShape: []int{1, 1, 5},
				expectedMask: []float32{
					0, x, x, 0, x,
					0, x, x, x, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

func TestChunkedAttention(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewChunkedAttentionCache(2, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		testCache(
			t, backend, cache,
			[]testCase{
				{
					name:          "FirstBatch",
					in:            []float32{1, 2, 3, 4},
					inShape:       []int{1, 1, 4},
					seqs:          []int{0, 0, 0, 0},
					pos:           []int32{0, 1, 2, 3},
					expected:      []float32{1, 2, 3, 4},
					expectedShape: []int{1, 1, 4},
					expectedMask: []float32{
						0, x, x, x,
						0, 0, x, x,
						x, x, 0, x,
						x, x, 0, 0,
					},
				},
				{
					name:          "SecondBatch",
					in:            []float32{5, 6, 7},
					inShape:       []int{1, 1, 3},
					seqs:          []int{0, 0, 0},
					pos:           []int32{4, 5, 6},
					expected:      []float32{1, 2, 3, 4, 5, 6, 7},
					expectedShape: []int{1, 1, 7},
					expectedMask: []float32{
						x, x, x, x, 0, x, x,
						x, x, x, x, 0, 0, x,
						x, x, x, x, x, x, 0,
					},
				},
				{
					name:          "ThirdBatch",
					in:            []float32{8, 9},
					inShape:       []int{1, 1, 2},
					seqs:          []int{0, 0},
					pos:           []int32{7, 8},
					expected:      []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
					expectedShape: []int{1, 1, 9},
					expectedMask: []float32{
						x, x, x, x, x, x, 0, 0, x,
						x, x, x, x, x, x, x, x, 0,
					},
				},
			},
		)
	})
}

func TestSequences(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 1, 1},
				pos:           []int32{0, 1, 0, 1},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
			},
			{
				name:          "SecondBatch",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 1},
				pos:           []int32{2, 2},
				expected:      []float32{1, 2, 3, 4, 5, 6},
				expectedShape: []int{1, 1, 6},
				expectedMask:  []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
			},
		}

		testCache(t, backend, cache, tests)
	})
}
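// TestRemove removes cache entries and verifies that the remaining entries
// are shifted. The shift callback stands in for the position re-encoding
// (e.g. RoPE shifting) a real model would apply; here it simply adds the
// shift offset to the cached keys so the effect is visible in the test data.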
func TestRemove(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
			return key.Add(ctx, shift), nil
		})
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 1, 1},
				pos:           []int32{0, 1, 0, 1},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					x, x, 0, x,
					x, x, 0, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)

		err := cache.Remove(0, 1, math.MaxInt32)
		if err != nil {
			t.Fatalf("Remove failed: %v", err)
		}

		tests = []testCase{
			{
				name:          "RemoveEnd",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 1},
				pos:           []int32{1, 2},
				expected:      []float32{1, 5, 3, 4, 6},
				expectedShape: []int{1, 1, 5},
				expectedMask: []float32{
					0, 0, x, x, x,
					x, x, 0, 0, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)

		err = cache.Remove(0, 0, 1)
		if err != nil {
			t.Fatalf("Remove failed: %v", err)
		}

		tests = []testCase{
			{
				name:          "RemoveMiddle",
				in:            []float32{7, 8},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{1, 2},
				expected:      []float32{7, 4, 3, 4, 6, 8},
				expectedShape: []int{1, 1, 6},
				expectedMask: []float32{
					0, 0, x, x, x, x,
					0, 0, x, x, x, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

func TestCopy(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
			return key, nil
		})
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
			},
		}

		testCache(t, backend, cache, tests)

		cache.CopyPrefix(0, 1, 2)

		tests = []testCase{
			{
				name:          "Copy",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{1, 1},
				pos:           []int32{3, 4},
				expected:      []float32{1, 2, 3, 4, 5, 6},
				expectedShape: []int{1, 1, 6},
				expectedMask:  []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			context := backend.NewContext()
			defer context.Close()

			err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
			if err != nil {
				t.Fatalf("StartForward failed: %v", err)
			}

			cache.SetLayer(0)
			tensor := context.FromFloats(test.in, test.inShape...)
			cache.Put(context, tensor, tensor)

			out, _, mask := cache.Get(context)

			context.Forward(out, mask).Compute(out, mask)

			if !slices.Equal(out.Floats(), test.expected) {
				t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
			}

			if !slices.Equal(out.Shape(), test.expectedShape) {
				t.Errorf("TestCache: have shape %v; want %v", out.Shape(), test.expectedShape)
			}

			if !slices.Equal(mask.Floats(), test.expectedMask) {
				t.Errorf("TestCache: have mask %v; want %v", mask.Floats(), test.expectedMask)
			}
		})
	}
}
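// TestCanResume checks the sliding-window resume rule: CanResume(seq, pos)
// should report true only while every position that pos attends to is still
// resident in the cache. Once newer tokens slide the window past pos, the
// attention state for pos can no longer be reconstructed and resumption must
// be refused.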
func TestCanResume(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		windowSize := int32(4)
		cache := NewSWACache(windowSize, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		context := backend.NewContext()
		defer context.Close()

		err := cache.StartForward(context, input.Batch{
			Positions: []int32{0, 1, 2, 3, 4},
			Sequences: []int{0, 0, 0, 0, 0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
		cache.Put(context, tensor, tensor)

		// with window size 4, nothing has slid out of the window yet
		if !cache.CanResume(0, 0) {
			t.Errorf("CanResume(0, 0) = false, want true (within window)")
		}
		if !cache.CanResume(0, 1) {
			t.Errorf("CanResume(0, 1) = false, want true (within window)")
		}
		if !cache.CanResume(0, 2) {
			t.Errorf("CanResume(0, 2) = false, want true (within window)")
		}
		if !cache.CanResume(0, 3) {
			t.Errorf("CanResume(0, 3) = false, want true (latest position)")
		}
		if !cache.CanResume(0, 4) {
			t.Errorf("CanResume(0, 4) = false, want true (latest position)")
		}

		// shift window by adding position 5
		err = cache.StartForward(context, input.Batch{
			Positions: []int32{5},
			Sequences: []int{0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor = context.FromFloats([]float32{6}, 1, 1, 1)
		cache.Put(context, tensor, tensor)

		// only the latest position has overlapping windows
		if cache.CanResume(0, 0) {
			t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
		}
		if cache.CanResume(0, 1) {
			t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
		}
		if cache.CanResume(0, 2) {
			t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
		}
		if cache.CanResume(0, 3) {
			t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
		}
		if cache.CanResume(0, 4) {
			t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
		}
		if !cache.CanResume(0, 5) {
			t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
		}
	})
}

func TestCanResumeSWAMem(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		windowSize := int32(4)
		memSize := int32(5)
		cache := NewSWAMemCache(windowSize, memSize, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		context := backend.NewContext()
		defer context.Close()

		err := cache.StartForward(context, input.Batch{
			Positions: []int32{0, 1, 2, 3, 4, 5, 6},
			Sequences: []int{0, 0, 0, 0, 0, 0, 0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
		cache.Put(context, tensor, tensor)

		// shift window by adding position 7
		err = cache.StartForward(context, input.Batch{
			Positions: []int32{7},
			Sequences: []int{0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor = context.FromFloats([]float32{8}, 1, 1, 1)
		cache.Put(context, tensor, tensor)

		// only positions still covered by the window remain resumable
		if cache.CanResume(0, 0) {
			t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
		}
		if cache.CanResume(0, 1) {
			t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
		}
		if cache.CanResume(0, 2) {
			t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
		}
		if cache.CanResume(0, 3) {
			t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
		}
		if cache.CanResume(0, 4) {
			t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
		}
		if cache.CanResume(0, 5) {
			t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
		}
		if !cache.CanResume(0, 6) {
			t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
		}
		if !cache.CanResume(0, 7) {
			t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
		}
	})
}
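// The remainder of this file is a minimal in-process stand-in for the ml
// package interfaces: testBackend satisfies just enough of ml.Backend for the
// cache under test, testContext materializes tensors eagerly, and testTensor
// keeps its contents in a flat []float32. Compute and Forward are no-ops
// because every operation executes immediately when it is called.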
type testBackend struct {
	ml.Backend
	permutedV bool
}

func (b *testBackend) NewContext() ml.Context {
	return &testContext{}
}

func (b *testBackend) NewContextSize(int) ml.Context {
	return &testContext{}
}

func (b *testBackend) CacheConfig() ml.CacheConfig {
	return ml.CacheConfig{PermutedV: b.permutedV}
}

type testContext struct {
	ml.Context
}

func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
	total := 0

	if len(shape) > 0 {
		total = 1
		for _, s := range shape {
			total *= s
		}
	}

	return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
}

func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
	return c.Empty(dtype, shape...)
}

func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
	t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
	copy(t.data, s)
	return t
}

func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
	f := make([]float32, len(s))
	for i := range f {
		f[i] = float32(s[i])
	}
	out := c.FromFloats(f, shape...)
	out.(*testTensor).dtype = ml.DTypeI32
	return out
}

func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
	s := make([]float32, 0, int((stop-start)/step))
	for i := start; i < stop; i += step {
		s = append(s, i)
	}

	out := c.FromFloats(s, len(s))
	out.(*testTensor).dtype = dtype
	return out
}

func (c *testContext) Input() ml.Context    { return c }
func (c *testContext) Layer(int) ml.Context { return c }

func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }

func (c *testContext) Compute(...ml.Tensor) {}

func (c *testContext) Reserve() {}

func (c *testContext) MaxGraphNodes() int {
	return 10
}

func (c *testContext) Close() {}
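// testTensor implements only the ml.Tensor methods the cache exercises. The
// backing data is a flat []float32 with dimension 0 contiguous, and Stride
// reports strides in bytes (elementSize is fixed at 4), mirroring the
// conventions of the real backends.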
type testTensor struct {
	ml.Tensor
	dtype       ml.DType
	elementSize int
	data        []float32
	shape       []int
}

func (t *testTensor) Dim(n int) int {
	return t.shape[n]
}

func (t *testTensor) Stride(n int) int {
	stride := t.elementSize
	for i := range n {
		stride *= t.shape[i]
	}

	return stride
}

func (t *testTensor) Shape() []int {
	return t.shape
}

func (t *testTensor) DType() ml.DType {
	return t.dtype
}

func (t *testTensor) Floats() []float32 {
	out := make([]float32, len(t.data))
	copy(out, t.data)
	return out
}

func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
	out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
	for i := range out.data {
		out.data[i] = -t.data[i]
	}
	return out
}

func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)

	for i := range out.data {
		out.data[i] = t.data[i] + t2.(*testTensor).data[i]
	}

	return out
}

func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
	return &testTensor{
		dtype:       t.dtype,
		elementSize: t.elementSize,
		data:        t.data,
		shape:       shape,
	}
}

func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	offset /= t.elementSize

	var s []int

	// View receives interleaved size/stride arguments; this stub keeps only
	// the sizes (even indices) and slices the underlying data contiguously.
	switch len(shape) {
	case 1:
		s = []int{shape[0]}
	case 3:
		s = []int{shape[0], shape[2]}
	case 5:
		s = []int{shape[0], shape[2], shape[4]}
	default:
		panic("unsupported number of dimensions")
	}

	context := &testContext{}

	view := context.Empty(t.dtype, s...).(*testTensor)
	view.data = t.data[offset : offset+len(view.data)]

	return view
}
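// Permute follows ggml_permute semantics: order[i] names the destination axis
// for source axis i, with missing trailing axes treated as identity. For
// example, permuting a [2 3 4] tensor with order (1, 0, 2, 3) yields shape
// [3 2 4], moving element (i, j, k) to (j, i, k).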
func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
	if len(t.shape) > 4 || len(order) > 4 {
		panic("permute only supports up to 4 dimensions")
	}

	if len(order) != len(t.shape) && len(order) != 4 {
		panic("invalid number of dimensions for permute")
	}

	// ggml_permute expects 4 axes, so fill in any missing dimensions.
	orderFull := append(make([]int, 0, 4), order...)
	for len(orderFull) < 4 {
		orderFull = append(orderFull, len(orderFull))
	}

	seen := [4]bool{}
	shape4 := [4]int{1, 1, 1, 1}
	for i := 0; i < len(t.shape) && i < 4; i++ {
		shape4[i] = t.shape[i]
	}

	newShape4 := [4]int{1, 1, 1, 1}
	for axis := range 4 {
		dst := orderFull[axis]
		if dst < 0 || dst >= 4 {
			panic("invalid axis for permute")
		}
		if seen[dst] {
			panic("duplicate axis for permute")
		}
		seen[dst] = true
		newShape4[dst] = shape4[axis]
	}

	total := len(t.data)
	newData := make([]float32, total)

	if total > 0 {
		oldDims := shape4
		newDims := newShape4

		oldStride := [4]int{1, 1, 1, 1}
		newStride := [4]int{1, 1, 1, 1}
		for i := 1; i < 4; i++ {
			oldStride[i] = oldStride[i-1] * oldDims[i-1]
			newStride[i] = newStride[i-1] * newDims[i-1]
		}

		var coords [4]int
		var newCoords [4]int
		for idx := range total {
			// decompose the flat index into per-axis coordinates
			remainder := idx
			for axis := range 4 {
				dim := oldDims[axis]
				if dim == 0 {
					coords[axis] = 0
					continue
				}
				coords[axis] = remainder % dim
				remainder /= dim
			}

			// map each coordinate to its destination axis
			for axis := range 4 {
				newCoords[orderFull[axis]] = coords[axis]
			}

			newIndex := 0
			for axis := range 4 {
				if newDims[axis] == 0 {
					continue
				}
				newIndex += newCoords[axis] * newStride[axis]
			}

			newData[newIndex] = t.data[idx]
		}
	}

	// drop trailing size-1 dimensions from the reported shape
	numDims := 4
	for numDims > 1 && newShape4[numDims-1] <= 1 {
		numDims--
	}

	newShape := make([]int, numDims)
	copy(newShape, newShape4[:numDims])

	return &testTensor{
		dtype:       t.dtype,
		elementSize: t.elementSize,
		data:        newData,
		shape:       newShape,
	}
}

// SetRows scatters rows of src into t: row r of src is copied to the row
// index stored at position r of idxs, with the index tensor broadcast over
// the two outer dimensions.
func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
	dst := t
	srcTensor := src.(*testTensor)
	idxTensor := idxs.(*testTensor)

	shapeTo4D := func(shape []int) [4]int {
		out := [4]int{1, 1, 1, 1}
		for i := 0; i < len(shape) && i < 4; i++ {
			out[i] = shape[i]
		}
		return out
	}

	computeStrides := func(shape [4]int) [4]int {
		out := [4]int{1, 1, 1, 1}
		for i := 1; i < 4; i++ {
			out[i] = out[i-1] * shape[i-1]
		}
		return out
	}

	dstShape4D := shapeTo4D(dst.shape)
	srcShape4D := shapeTo4D(srcTensor.shape)
	idxShape4D := shapeTo4D(idxTensor.shape)

	if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
		panic("SetRows requires matching tensor shapes")
	}

	if srcShape4D[1] != idxShape4D[0] {
		panic("SetRows rows/index mismatch")
	}

	if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
		panic("SetRows cannot broadcast indices")
	}

	if idxShape4D[3] != 1 {
		panic("SetRows expects 1D or 2D index tensors")
	}

	dstStride := computeStrides(dstShape4D)
	srcStride := computeStrides(srcShape4D)
	idxStride := computeStrides(idxShape4D)

	numColumns := srcShape4D[0]
	numRows := srcShape4D[1]

	for dim3Index := range dstShape4D[3] {
		for dim2Index := range dstShape4D[2] {
			idxDim2 := 0
			idxDim3 := 0
			if idxShape4D[1] > 0 {
				idxDim2 = dim2Index % idxShape4D[1]
			}
			if idxShape4D[2] > 0 {
				idxDim3 = dim3Index % idxShape4D[2]
			}

			idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
			srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
			dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]

			for row := range numRows {
				idx := int(idxTensor.data[idxBase+row*idxStride[0]])
				if idx < 0 || idx >= dstShape4D[1] {
					panic("SetRows index out of range")
				}

				srcOffset := srcBase + row*srcStride[1]
				dstOffset := dstBase + idx*dstStride[1]
				copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
			}
		}
	}

	return dst
}

// Copy writes t's data into t2 in place; nothing is deferred in this test
// double, so no graph node needs to be returned.
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	copy(t2.(*testTensor).data, t.data)
	return nil
}