kvcache: Use SetRows to store cache data

We currently copy data into the KV cache in contiguous buffers using
ggml_cpy(). ggml_set_rows() was introduced to allow scatter
operations, so contiguous buffers are no longer required. The primary
direct benefit of this is that we no longer need to perform
defragmentation.
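
For background, the core semantic of a row scatter is dst[idxs[i]] =
src[i], one row at a time. Below is a minimal, self-contained Go model
of that semantic (plain slices rather than GGML tensors; the names are
illustrative, not GGML's):

    package main

    import "fmt"

    // setRows models the core of ggml_set_rows() for a 2D, row-major
    // destination: row i of src is copied into row idxs[i] of dst.
    func setRows(dst, src []float32, idxs []int32, rowLen int) {
        for i, idx := range idxs {
            copy(dst[int(idx)*rowLen:(int(idx)+1)*rowLen], src[i*rowLen:(i+1)*rowLen])
        }
    }

    func main() {
        dst := make([]float32, 4*2)         // a 4-row cache, 2 elements per row
        src := []float32{1, 1, 2, 2}        // two new rows to store
        setRows(dst, src, []int32{3, 0}, 2) // scatter them to rows 3 and 0
        fmt.Println(dst)                    // [2 2 0 0 0 0 1 1]
    }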

However, GGML recently removed an optimization for ggml_cpy() and
we picked it up in 544b673 "ggml update to b6840 (#12791)". This
caused a roughly 40% drop in token generation performance on CUDA
due to CUDA graphs no longer being used. By switching to
ggml_set_rows(), the original optimization is no longer necessary
and CUDA performance is restored.

Fixes #13112
Author: Jesse Gross
Date: 2025-08-18 10:45:58 -07:00
Committed by: Jesse Gross
Parent: b6e02cbbd2
Commit: 53985b3c4d

4 changed files with 164 additions and 235 deletions

──────────────────────────────────────────────────────────────────────

@@ -3,7 +3,6 @@ package kvcache
 import (
 	"errors"
 	"fmt"
-	"log/slog"
 	"math"
 	"slices"
@ -40,18 +39,18 @@ type Causal struct {
// ** current forward pass ** // ** current forward pass **
// the active layer for Get and Put
curLayer int
// starting location for data storage for this batch
curLoc int
// size of the current batch // size of the current batch
curBatchSize int curBatchSize int
// locations for data storage for this batch
curLoc ml.Tensor
// mask of the cache as used by this batch // mask of the cache as used by this batch
curMask ml.Tensor curMask ml.Tensor
// the active layer for Get and Put
curLayer int
// locations in the cache that are needed for this batch // locations in the cache that are needed for this batch
curCellRange cellRange curCellRange cellRange
@@ -206,45 +205,47 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
 	c.curPositions = batch.Positions
 	c.opts.Except = nil

+	var locs []int32
 	if !reserve {
 		c.updateSlidingWindow()

 		var err error
-		c.curLoc, err = c.findStartLoc()
-		if errors.Is(err, ErrKvCacheFull) {
-			c.defrag()
-			c.curLoc, err = c.findStartLoc()
-		}
+		locs, err = c.findLocs()
 		if err != nil {
 			return err
 		}

 		for i, pos := range batch.Positions {
 			seq := batch.Sequences[i]
+			loc := int(locs[i])

-			c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
+			c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}

 			seqRange, ok := c.cellRanges[seq]
 			if !ok {
 				seqRange = newRange()
 			}

-			seqRange.min = min(seqRange.min, c.curLoc+i)
-			c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
+			seqRange.min = min(seqRange.min, loc)
+			c.curCellRange.min = min(c.curCellRange.min, loc)

-			seqRange.max = max(seqRange.max, c.curLoc+i)
-			c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
+			seqRange.max = max(seqRange.max, loc)
+			c.curCellRange.max = max(c.curCellRange.max, loc)

 			c.cellRanges[seq] = seqRange
 		}
 	} else {
 		// If we are reserving memory, don't update any of the cache metadata but set the size
 		// to the worst case.
-		c.curLoc = 0
+		locs = make([]int32, c.curBatchSize)
+		for i := range locs {
+			locs[i] = int32(i)
+		}
 		c.curCellRange.min = 0
 		c.curCellRange.max = len(c.cells) - 1
 	}

+	c.curLoc = ctx.Input().FromInts(locs, len(locs))
+
 	c.curMask = c.buildMask(ctx)

 	return nil
@@ -257,22 +258,20 @@ func newRange() cellRange {
 	}
 }

-// Find the first contiguous block of at least curBatchSize
-func (c *Causal) findStartLoc() (int, error) {
-	var start, count int
+// Returns a slice of locations where each token in the batch should be stored
+func (c *Causal) findLocs() ([]int32, error) {
+	loc := make([]int32, 0, c.curBatchSize)
 	for i := range c.cells {
 		if len(c.cells[i].sequences) == 0 {
-			count++
-			if count >= c.curBatchSize {
-				return start, nil
+			loc = append(loc, int32(i))
+			if len(loc) >= c.curBatchSize {
+				return loc, nil
 			}
-		} else {
-			start = i + 1
-			count = 0
 		}
 	}

-	return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
+	return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
 }

 func (c *Causal) updateSlidingWindow() {
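
The practical difference between the old and new allocators, as a
standalone sketch (the boolean-slice cache model here is ours, not the
package's):

    package main

    import "fmt"

    // findContiguous mimics the old findStartLoc: it needs one free run
    // of at least n cells, which is why defragmentation was sometimes
    // required before a batch could be stored.
    func findContiguous(used []bool, n int) (int, bool) {
        start, count := 0, 0
        for i := range used {
            if !used[i] {
                count++
                if count >= n {
                    return start, true
                }
            } else {
                start, count = i+1, 0
            }
        }
        return 0, false
    }

    // findScattered mimics the new findLocs: any n free cells will do,
    // so holes left by removed sequences are reused directly.
    func findScattered(used []bool, n int) ([]int, bool) {
        locs := make([]int, 0, n)
        for i := range used {
            if !used[i] {
                locs = append(locs, i)
                if len(locs) == n {
                    return locs, true
                }
            }
        }
        return nil, false
    }

    func main() {
        // A fragmented cache: cells 1, 3, and 4 are free.
        used := []bool{true, false, true, false, false, true}

        _, ok := findContiguous(used, 3)
        fmt.Println("contiguous fit:", ok) // false: would have forced a defrag

        locs, ok := findScattered(used, 3)
        fmt.Println("scattered fit:", ok, locs) // true [1 3 4]
    }
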
@@ -402,145 +401,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
 	return maskTensor
 }

-func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
-	for i, key := range c.keys {
-		if key == nil {
-			continue
-		}
-
-		kHeadDim := key.Dim(0)
-		numKVHeads := key.Dim(1)
-		rowSize := key.Stride(2)
-
-		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
-		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
-
-		value := c.values[i]
-		var vSrcView, vDstView ml.Tensor
-		if c.config.PermutedV {
-			vHeadDim := value.Dim(1)
-			elemSize := value.Stride(0)
-
-			vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
-			vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
-		} else {
-			vHeadDim := value.Dim(0)
-			rowSize := value.Stride(2)
-
-			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
-			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
-		}
-
-		ctx.Forward(
-			kSrcView.Copy(ctx, kDstView),
-			vSrcView.Copy(ctx, vDstView),
-		)
-	}
-}
-
-func (c *Causal) defrag() {
-	slog.Debug("defragmenting kv cache")
-
-	// Defrag strategy:
-	// - Search for empty holes at the beginning of the cache,
-	//   filling them with active data starting at the end
-	// - If there are contiguous elements that need to be moved,
-	//   combine them into a single operation by holding new moves
-	//   until we see that the next one is non-contiguous
-	// - Fill up the context with the maximum number of operations it
-	//   can hold then compute that and continue with a new context
-	//
-	// We could try to optimize placement by grouping blocks from
-	// the same sequences together but most likely the next forward
-	// pass will disrupt this anyways, so the real world benefit
-	// seems limited at this time.
-
-	ctx := c.backend.NewContext()
-
-	// For every move, 6 tensors are required per layer (2 views and a
-	// copy for each of k and v). We also need to refer to the original
-	// k and v cache tensors - once per layer, not per move.
-	layers := 0
-	for _, key := range c.keys {
-		if key == nil {
-			continue
-		}
-		layers++
-	}
-
-	maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
-	moves := 0
-
-	var pendingSrc, pendingDst, pendingLen int
-	src := len(c.cells) - 1
-
-	for dst := 0; dst < src; dst++ {
-		if len(c.cells[dst].sequences) == 0 {
-			for ; src > dst; src-- {
-				if len(c.cells[src].sequences) != 0 {
-					c.cells[dst] = c.cells[src]
-					c.cells[src] = cacheCell{}
-
-					if pendingLen > 0 {
-						if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
-							pendingSrc = src
-							pendingLen++
-							break
-						} else {
-							c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
-							moves++
-						}
-					}
-
-					pendingSrc = src
-					pendingDst = dst
-					pendingLen = 1
-
-					break
-				}
-			}
-		}
-
-		if moves >= maxMoves {
-			ctx.Compute()
-			ctx.Close()
-			ctx = c.backend.NewContext()
-
-			moves = 0
-		}
-	}

-	if pendingLen > 0 {
-		c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
-		moves++
-	}
-
-	if moves > 0 {
-		ctx.Compute()
-	}
-	ctx.Close()
-
-	// Reset range metadata
-	for seq := range c.cellRanges {
-		seqRange := newRange()
-
-		for i, cell := range c.cells {
-			if slices.Contains(cell.sequences, seq) {
-				if i < seqRange.min {
-					seqRange.min = i
-				}
-				if i > seqRange.max {
-					seqRange.max = i
-				}
-			}
-		}
-
-		c.cellRanges[seq] = seqRange
-	}
-
-	c.updateSlidingWindow()
-}
-
 func (c *Causal) SetLayer(layer int) {
 	c.curLayer = layer
 }
@@ -625,18 +485,25 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 		}
 	}

-	rowSize := c.keys[c.curLayer].Stride(2)
-	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
+	key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
+	keyCache := c.keys[c.curLayer]
+	keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
+	ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))

 	if c.config.PermutedV {
-		elemSize := c.values[c.curLayer].Stride(0)
-
-		value = value.Permute(ctx, 1, 2, 0, 3)
-		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
+		value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
+		value = value.Permute(ctx, 2, 0, 1, 3)
+
+		valueCache := c.values[c.curLayer]
+		valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
+		ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
 	} else {
-		rowSize := c.values[c.curLayer].Stride(2)
-
-		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
+		value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
+		valueCache := c.values[c.curLayer]
+		valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
+		ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
 	}
 }
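
The reshapes are what let a single SetRows call serve both value
layouts. In the standard layout each token occupies one cache row of
vHeadDim*numKVHeads elements, so a 2D reshape on each side is enough.
In the transposed (PermutedV) layout each token contributes one
element to each of vHeadDim*numKVHeads planes, and the same per-token
indices in c.curLoc are broadcast across the planes. A small sketch of
the shape bookkeeping, with illustrative sizes and a helper that
reproduces GGML-style axis reordering as used by
Permute(ctx, 2, 0, 1, 3) (our assumption about Permute's convention,
not code from this commit):

    package main

    import "fmt"

    // permute moves dim i of the input shape to position perm[i],
    // matching ggml_permute's argument convention.
    func permute(shape, perm [4]int) [4]int {
        var out [4]int
        for i, p := range perm {
            out[p] = shape[i]
        }
        return out
    }

    func main() {
        vHeadDim, numKVHeads := 64, 8 // illustrative model dims
        batchSize, numCells := 4, 1024

        // PermutedV path from Put: [vHeadDim*numKVHeads, 1, batchSize]
        // becomes [1, batchSize, vHeadDim*numKVHeads]: batchSize rows
        // of length 1 in each of vHeadDim*numKVHeads planes.
        value := permute([4]int{vHeadDim * numKVHeads, 1, batchSize, 1},
            [4]int{2, 0, 1, 3})

        // The cache is viewed the same way: numCells rows of length 1.
        cache := [4]int{1, numCells, vHeadDim * numKVHeads, 1}

        fmt.Println("src:", value)                               // [1 4 512 1]
        fmt.Println("dst:", cache)                               // [1 1024 512 1]
        fmt.Println("row lengths match:", value[0] == cache[0])  // true
    }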

──────────────────────────────────────────────────────────────────────

@@ -207,11 +207,11 @@ func TestSWAMem(t *testing.T) {
 			inShape: []int{1, 1, 2},
 			seqs: []int{0, 0},
 			pos: []int32{4, 5},
-			expected: []float32{4, 5, 6},
-			expectedShape: []int{1, 1, 3},
+			expected: []float32{5, 2, 3, 4, 6},
+			expectedShape: []int{1, 1, 5},
 			expectedMask: []float32{
-				0, 0, x,
-				x, 0, 0,
+				0, x, x, 0, x,
+				0, x, x, x, 0,
 			},
 		},
 	}
@@ -319,6 +319,8 @@ func TestRemove(t *testing.T) {
 	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

+	x := float32(math.Inf(-1))
+
 	tests := []testCase{
 		{
 			name: "FirstBatch",
@@ -328,7 +330,12 @@
 			pos: []int32{0, 1, 0, 1},
 			expected: []float32{1, 2, 3, 4},
 			expectedShape: []int{1, 1, 4},
-			expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
+			expectedMask: []float32{
+				0, x, x, x,
+				0, 0, x, x,
+				x, x, 0, x,
+				x, x, 0, 0,
+			},
 		},
 	}
@@ -346,9 +353,12 @@
 			inShape: []int{1, 1, 2},
 			seqs: []int{0, 1},
 			pos: []int32{1, 2},
-			expected: []float32{1, 2, 3, 4, 5, 6},
-			expectedShape: []int{1, 1, 6},
-			expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
+			expected: []float32{1, 5, 3, 4, 6},
+			expectedShape: []int{1, 1, 5},
+			expectedMask: []float32{
+				0, 0, x, x, x,
+				x, x, 0, 0, 0,
+			},
 		},
 	}
@@ -366,59 +376,12 @@
 			inShape: []int{1, 1, 2},
 			seqs: []int{0, 0},
 			pos: []int32{1, 2},
-			expected: []float32{7, 8, 3, 4, 4},
-			expectedShape: []int{1, 1, 5},
-			expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
+			expected: []float32{7, 4, 3, 4, 6, 8},
+			expectedShape: []int{1, 1, 6},
+			expectedMask: []float32{
+				0, 0, x, x, x, x,
+				0, 0, x, x, x, 0,
+			},
 		},
 	}

 	testCache(t, backend, cache, tests)
 }
-
-func TestDefrag(t *testing.T) {
-	backend := &testBackend{}
-	cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-		return key.Add(ctx, shift), nil
-	})
-	defer cache.Close()
-
-	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-	tests := []testCase{
-		{
-			name: "FirstBatch",
-			in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
-			inShape: []int{1, 1, 16},
-			seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
-			pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
-			expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
-			expectedShape: []int{1, 1, 16},
-			expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
-		},
-	}
-
-	testCache(t, backend, cache, tests)
-
-	err := cache.Remove(0, 2, 4)
-	if err != nil {
-		panic(err)
-	}
-
-	err = cache.Remove(0, 13, math.MaxInt32)
-	if err != nil {
-		panic(err)
-	}
-
-	tests = []testCase{
-		{
-			name: "Defrag",
-			in: []float32{17, 18, 19},
-			inShape: []int{1, 1, 3},
-			seqs: []int{0, 0, 0},
-			pos: []int32{16, 17, 18},
-			expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
-			expectedShape: []int{1, 1, 16},
-			expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
-		},
-	}
-
-	testCache(t, backend, cache, tests)
-}
@@ -770,6 +733,15 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return out
 }

+func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
+	return &testTensor{
+		dtype: t.dtype,
+		elementSize: t.elementSize,
+		data: t.data,
+		shape: shape,
+	}
+}
+
 func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	offset /= t.elementSize
@@ -778,6 +750,8 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	switch len(shape) {
 	case 1:
 		s = []int{shape[0]}
+	case 3:
+		s = []int{shape[0], shape[2]}
 	case 5:
 		s = []int{shape[0], shape[2], shape[4]}
 	default:
@@ -792,6 +766,86 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	return view
 }

+func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
+	dst := t
+	srcTensor := src.(*testTensor)
+	idxTensor := idxs.(*testTensor)
+
+	shapeTo4D := func(shape []int) [4]int {
+		out := [4]int{1, 1, 1, 1}
+		for i := 0; i < len(shape) && i < 4; i++ {
+			out[i] = shape[i]
+		}
+		return out
+	}
+
+	computeStrides := func(shape [4]int) [4]int {
+		out := [4]int{1, 1, 1, 1}
+		for i := 1; i < 4; i++ {
+			out[i] = out[i-1] * shape[i-1]
+		}
+		return out
+	}
+
+	dstShape4D := shapeTo4D(dst.shape)
+	srcShape4D := shapeTo4D(srcTensor.shape)
+	idxShape4D := shapeTo4D(idxTensor.shape)
+
+	if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
+		panic("SetRows requires matching tensor shapes")
+	}
+
+	if srcShape4D[1] != idxShape4D[0] {
+		panic("SetRows rows/index mismatch")
+	}
+
+	if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
+		panic("SetRows cannot broadcast indices")
+	}
+
+	if idxShape4D[3] != 1 {
+		panic("SetRows expects 1D or 2D index tensors")
+	}
+
+	dstStride := computeStrides(dstShape4D)
+	srcStride := computeStrides(srcShape4D)
+	idxStride := computeStrides(idxShape4D)
+
+	numColumns := srcShape4D[0]
+	numRows := srcShape4D[1]
+
+	for dim3Index := range dstShape4D[3] {
+		for dim2Index := range dstShape4D[2] {
+			idxDim2 := 0
+			idxDim3 := 0
+			if idxShape4D[1] > 0 {
+				idxDim2 = dim2Index % idxShape4D[1]
+			}
+			if idxShape4D[2] > 0 {
+				idxDim3 = dim3Index % idxShape4D[2]
+			}
+
+			idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
+			srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
+			dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
+
+			for row := range numRows {
+				idx := int(idxTensor.data[idxBase+row*idxStride[0]])
+				if idx < 0 || idx >= dstShape4D[1] {
+					panic("SetRows index out of range")
+				}
+
+				srcOffset := srcBase + row*srcStride[1]
+				dstOffset := dstBase + idx*dstStride[1]
+				copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
+			}
+		}
+	}
+
+	return dst
+}
+
 func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	copy(t2.(*testTensor).data, t.data)
 	return nil
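
A quick sanity check of the reference SetRows above, reusing the
testTensor type from this file (values are illustrative; the nil
context is fine because the method never touches it):

    dst := &testTensor{data: make([]float32, 8), shape: []int{2, 4}} // 4 rows of 2
    src := &testTensor{data: []float32{1, 1, 2, 2}, shape: []int{2, 2}}
    idx := &testTensor{data: []float32{3, 0}, shape: []int{2}}
    dst.SetRows(nil, src, idx)
    // dst.data is now [2 2 0 0 0 0 1 1]: src row 0 landed in
    // row 3, src row 1 in row 0.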

──────────────────────────────────────────────────────────────────────

@@ -194,6 +194,7 @@ type Tensor interface {
 	Repeat(ctx Context, dim, n int) Tensor
 	Concat(ctx Context, t2 Tensor, dim int) Tensor
 	Rows(ctx Context, t2 Tensor) Tensor
+	SetRows(ctx Context, src Tensor, idxs Tensor) Tensor
 	Copy(ctx Context, t2 Tensor) Tensor
 	Duplicate(ctx Context) Tensor

──────────────────────────────────────────────────────────────────────

@@ -1338,6 +1338,13 @@ func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	}
 }

+func (t *Tensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_set_rows(ctx.(*Context).ctx, t.t, src.(*Tensor).t, idxs.(*Tensor).t),
+	}
+}
+
 func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
 		b: t.b,
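
Read together, the four files form a small pipeline: StartForward
picks one free cell per token and uploads the indices as c.curLoc,
then Put scatters each layer's rows at those indices. A condensed
sketch of that flow, using only calls shown in this commit (simplified
for reading; in the real code StartForward chooses the locations once
per batch while Put runs per layer, and the value cache is handled the
same way):

    func storeKeys(ctx ml.Context, c *Causal, key ml.Tensor, kHeadDim, numKVHeads, batchSize int) error {
        locs, err := c.findLocs() // one free cell per token; holes are fine
        if err != nil {
            return err
        }
        c.curLoc = ctx.Input().FromInts(locs, len(locs))

        // One row per token on the source side, one row per cell on the
        // destination side, then a single scatter.
        key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
        keyCache := c.keys[c.curLayer].Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
        ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))
        return nil
    }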