diff --git a/kvcache/causal.go b/kvcache/causal.go
index d804f3bf0..15a4cdea4 100644
--- a/kvcache/causal.go
+++ b/kvcache/causal.go
@@ -3,7 +3,6 @@ package kvcache
 import (
 	"errors"
 	"fmt"
-	"log/slog"
 	"math"
 	"slices"
 
@@ -40,18 +39,18 @@ type Causal struct {
 
 	// ** current forward pass **
 
-	// the active layer for Get and Put
-	curLayer int
-
-	// starting location for data storage for this batch
-	curLoc int
-
 	// size of the current batch
 	curBatchSize int
 
+	// locations for data storage for this batch
+	curLoc ml.Tensor
+
 	// mask of the cache as used by this batch
 	curMask ml.Tensor
 
+	// the active layer for Get and Put
+	curLayer int
+
 	// locations in the cache that are needed for this batch
 	curCellRange cellRange
 
@@ -206,45 +205,47 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
 	c.curPositions = batch.Positions
 	c.opts.Except = nil
 
+	var locs []int32
 	if !reserve {
 		c.updateSlidingWindow()
 
 		var err error
-		c.curLoc, err = c.findStartLoc()
-		if errors.Is(err, ErrKvCacheFull) {
-			c.defrag()
-			c.curLoc, err = c.findStartLoc()
-		}
+		locs, err = c.findLocs()
 		if err != nil {
 			return err
 		}
 
 		for i, pos := range batch.Positions {
 			seq := batch.Sequences[i]
+			loc := int(locs[i])
 
-			c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
+			c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
 
 			seqRange, ok := c.cellRanges[seq]
 			if !ok {
 				seqRange = newRange()
 			}
 
-			seqRange.min = min(seqRange.min, c.curLoc+i)
-			c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
+			seqRange.min = min(seqRange.min, loc)
+			c.curCellRange.min = min(c.curCellRange.min, loc)
 
-			seqRange.max = max(seqRange.max, c.curLoc+i)
-			c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
+			seqRange.max = max(seqRange.max, loc)
+			c.curCellRange.max = max(c.curCellRange.max, loc)
 
 			c.cellRanges[seq] = seqRange
 		}
 	} else {
 		// If we are reserving memory, don't update any of the cache metadata but set the size
 		// to the worst case.
-		c.curLoc = 0
+		locs = make([]int32, c.curBatchSize)
+		for i := range locs {
+			locs[i] = int32(i)
+		}
 		c.curCellRange.min = 0
 		c.curCellRange.max = len(c.cells) - 1
 	}
 
+	c.curLoc = ctx.Input().FromInts(locs, len(locs))
 	c.curMask = c.buildMask(ctx)
 
 	return nil
@@ -257,22 +258,20 @@
 	}
 }
 
-// Find the first contiguous block of at least curBatchSize
-func (c *Causal) findStartLoc() (int, error) {
-	var start, count int
+// findLocs returns a slice of locations where each token in the batch should be stored
+func (c *Causal) findLocs() ([]int32, error) {
+	loc := make([]int32, 0, c.curBatchSize)
+
 	for i := range c.cells {
 		if len(c.cells[i].sequences) == 0 {
-			count++
-			if count >= c.curBatchSize {
-				return start, nil
+			loc = append(loc, int32(i))
+			if len(loc) >= c.curBatchSize {
+				return loc, nil
 			}
-		} else {
-			start = i + 1
-			count = 0
 		}
 	}
 
-	return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
+	return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
 }
 
 func (c *Causal) updateSlidingWindow() {
@@ -402,145 +401,6 @@
 
 	return maskTensor
 }
 
-func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
-	for i, key := range c.keys {
-		if key == nil {
-			continue
-		}
-
-		kHeadDim := key.Dim(0)
-		numKVHeads := key.Dim(1)
-		rowSize := key.Stride(2)
-
-		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
-		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
-
-		value := c.values[i]
-		var vSrcView, vDstView ml.Tensor
-		if c.config.PermutedV {
-			vHeadDim := value.Dim(1)
-			elemSize := value.Stride(0)
-
-			vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
-			vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
-		} else {
-			vHeadDim := value.Dim(0)
-			rowSize := value.Stride(2)
-
-			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
-			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
-		}
-
-		ctx.Forward(
-			kSrcView.Copy(ctx, kDstView),
-			vSrcView.Copy(ctx, vDstView),
-		)
-	}
-}
-
-func (c *Causal) defrag() {
-	slog.Debug("defragmenting kv cache")
-
-	// Defrag strategy:
-	// - Search for empty holes at the beginning of the cache,
-	//   filling them with active data starting at the end
-	// - If there are contiguous elements that need to be moved,
-	//   combine them into a single operation by holding new moves
-	//   until we see that the next one is non-contiguous
-	// - Fill up the context with the maximum number of operations it
-	//   can hold then compute that and continue with a new context
-	//
-	// We could try to optimize placement by grouping blocks from
-	// the same sequences together but most likely the next forward
-	// pass will disrupt this anyways, so the real world benefit
-	// seems limited as this time.
-
-	ctx := c.backend.NewContext()
-
-	// For every move, 6 tensors are required per layer (2 views and a
-	// copy for each of k and v). We also need to refer to the original
-	// k and v cache tensors - once per layer, not per move.
-	layers := 0
-	for _, key := range c.keys {
-		if key == nil {
-			continue
-		}
-		layers++
-	}
-
-	maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
-	moves := 0
-
-	var pendingSrc, pendingDst, pendingLen int
-	src := len(c.cells) - 1
-
-	for dst := 0; dst < src; dst++ {
-		if len(c.cells[dst].sequences) == 0 {
-			for ; src > dst; src-- {
-				if len(c.cells[src].sequences) != 0 {
-					c.cells[dst] = c.cells[src]
-					c.cells[src] = cacheCell{}
-
-					if pendingLen > 0 {
-						if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
-							pendingSrc = src
-							pendingLen++
-							break
-						} else {
-							c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
-							moves++
-						}
-					}
-
-					pendingSrc = src
-					pendingDst = dst
-					pendingLen = 1
-
-					break
-				}
-			}
-		}
-
-		if moves >= maxMoves {
-			ctx.Compute()
-			ctx.Close()
-			ctx = c.backend.NewContext()
-
-			moves = 0
-		}
-	}
-
-	if pendingLen > 0 {
-		c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
-		moves++
-	}
-
-	if moves > 0 {
-		ctx.Compute()
-	}
-	ctx.Close()
-
-	// Reset range metadata
-	for seq := range c.cellRanges {
-		seqRange := newRange()
-
-		for i, cell := range c.cells {
-			if slices.Contains(cell.sequences, seq) {
-				if i < seqRange.min {
-					seqRange.min = i
-				}
-				if i > seqRange.max {
-					seqRange.max = i
-				}
-			}
-		}
-
-		c.cellRanges[seq] = seqRange
-	}
-
-	c.updateSlidingWindow()
-}
-
 func (c *Causal) SetLayer(layer int) {
 	c.curLayer = layer
 }
 
@@ -625,18 +485,25 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 		}
 	}
 
-	rowSize := c.keys[c.curLayer].Stride(2)
-	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
+	key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
+	keyCache := c.keys[c.curLayer]
+	keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
+	ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))
 
 	if c.config.PermutedV {
-		elemSize := c.values[c.curLayer].Stride(0)
+		value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
+		value = value.Permute(ctx, 2, 0, 1, 3)
 
-		value = value.Permute(ctx, 1, 2, 0, 3)
-		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
+		valueCache := c.values[c.curLayer]
+		valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
+
+		ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
 	} else {
-		rowSize := c.values[c.curLayer].Stride(2)
+		value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
+		valueCache := c.values[c.curLayer]
+		valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
 
-		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
+		ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
 	}
 }
diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index dd0c04427..4f441d0fd 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -207,11 +207,11 @@ func TestSWAMem(t *testing.T) {
 			inShape:       []int{1, 1, 2},
 			seqs:          []int{0, 0},
 			pos:           []int32{4, 5},
-			expected:      []float32{4, 5, 6},
-			expectedShape: []int{1, 1, 3},
+			expected:      []float32{5, 2, 3, 4, 6},
+			expectedShape: []int{1, 1, 5},
 			expectedMask: []float32{
-				0, 0, x,
-				x, 0, 0,
+				0, x, x, 0, x,
+				0, x, x, x, 0,
 			},
 		},
 	}
@@ -319,6 +319,8 @@ func TestRemove(t *testing.T) {
 
 	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
 
+	x := float32(math.Inf(-1))
+
 	tests := []testCase{
 		{
 			name:          "FirstBatch",
 			in:            []float32{1, 2, 3, 4},
 			inShape:       []int{1, 1, 4},
 			seqs:          []int{0, 0, 1, 1},
 			pos:           []int32{0, 1, 0, 1},
 			expected:      []float32{1, 2, 3, 4},
 			expectedShape: []int{1, 1, 4},
-			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
+			expectedMask: []float32{
+				0, x, x, x,
+				0, 0, x, x,
+				x, x, 0, x,
+				x, x, 0, 0,
+			},
 		},
 	}
 
@@ -346,9 +353,12 @@
 			inShape:       []int{1, 1, 2},
 			seqs:          []int{0, 1},
 			pos:           []int32{1, 2},
-			expected:      []float32{1, 2, 3, 4, 5, 6},
-			expectedShape: []int{1, 1, 6},
-			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
+			expected:      []float32{1, 5, 3, 4, 6},
+			expectedShape: []int{1, 1, 5},
+			expectedMask: []float32{
+				0, 0, x, x, x,
+				x, x, 0, 0, 0,
+			},
 		},
 	}
 
@@ -366,59 +376,12 @@
 			inShape:       []int{1, 1, 2},
 			seqs:          []int{0, 0},
 			pos:           []int32{1, 2},
-			expected:      []float32{7, 8, 3, 4, 4},
-			expectedShape: []int{1, 1, 5},
-			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
-		},
-	}
-
-	testCache(t, backend, cache, tests)
-}
-
-func TestDefrag(t *testing.T) {
-	backend := &testBackend{}
-	cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-		return key.Add(ctx, shift), nil
-	})
-	defer cache.Close()
-
-	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-	tests := []testCase{
-		{
-			name:          "FirstBatch",
-			in:            []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
-			inShape:       []int{1, 1, 16},
-			seqs:          []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
-			pos:           []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
-			expected:      []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
-			expectedShape: []int{1, 1, 16},
-			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
-		},
-	}
-
-	testCache(t, backend, cache, tests)
-
-	err := cache.Remove(0, 2, 4)
-	if err != nil {
-		panic(err)
-	}
-
-	err = cache.Remove(0, 13, math.MaxInt32)
-	if err != nil {
-		panic(err)
-	}
-
-	tests = []testCase{
-		{
-			name:          "Defrag",
-			in:            []float32{17, 18, 19},
-			inShape:       []int{1, 1, 3},
-			seqs:          []int{0, 0, 0},
-			pos:           []int32{16, 17, 18},
-			expected:      []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
-			expectedShape: []int{1, 1, 16},
-			expectedMask:  []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+			expected:      []float32{7, 4, 3, 4, 6, 8},
+			expectedShape: []int{1, 1, 6},
+			expectedMask: []float32{
+				0, 0, x, x, x, x,
+				0, 0, x, x, x, 0,
+			},
 		},
 	}
 
@@ -770,6 +733,15 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return out
 }
 
+func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
+	return &testTensor{
+		dtype:       t.dtype,
+		elementSize: t.elementSize,
+		data:        t.data,
+		shape:       shape,
+	}
+}
+
 func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	offset /= t.elementSize
 
@@ -778,6 +750,8 @@
 	switch len(shape) {
 	case 1:
 		s = []int{shape[0]}
+	case 3:
+		s = []int{shape[0], shape[2]}
 	case 5:
 		s = []int{shape[0], shape[2], shape[4]}
 	default:
@@ -792,6 +766,86 @@
 
 	return view
 }
 
+func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
+	dst := t
+	srcTensor := src.(*testTensor)
+	idxTensor := idxs.(*testTensor)
+
+	shapeTo4D := func(shape []int) [4]int {
+		out := [4]int{1, 1, 1, 1}
+		for i := 0; i < len(shape) && i < 4; i++ {
+			out[i] = shape[i]
+		}
+		return out
+	}
+
+	computeStrides := func(shape [4]int) [4]int {
+		out := [4]int{1, 1, 1, 1}
+		for i := 1; i < 4; i++ {
+			out[i] = out[i-1] * shape[i-1]
+		}
+		return out
+	}
+
+	dstShape4D := shapeTo4D(dst.shape)
+	srcShape4D := shapeTo4D(srcTensor.shape)
+	idxShape4D := shapeTo4D(idxTensor.shape)
+
+	if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
+		panic("SetRows requires matching tensor shapes")
+	}
+
+	if srcShape4D[1] != idxShape4D[0] {
+		panic("SetRows rows/index mismatch")
+	}
+
+	if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
+		panic("SetRows cannot broadcast indices")
+	}
+
+	if idxShape4D[3] != 1 {
+		panic("SetRows expects 1D or 2D index tensors")
+	}
+
+	dstStride := computeStrides(dstShape4D)
+	srcStride := computeStrides(srcShape4D)
+	idxStride := computeStrides(idxShape4D)
+
+	numColumns := srcShape4D[0]
+	numRows := srcShape4D[1]
+
+	for dim3Index := range dstShape4D[3] {
+		for dim2Index := range dstShape4D[2] {
+			idxDim2 := 0
+			idxDim3 := 0
+			if idxShape4D[1] > 0 {
+				idxDim2 = dim2Index % idxShape4D[1]
+			}
+			if idxShape4D[2] > 0 {
+				idxDim3 = dim3Index % idxShape4D[2]
+			}
+
+			idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
+			srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
+			dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
+
+			for row := range numRows {
+				idx := int(idxTensor.data[idxBase+row*idxStride[0]])
+				if idx < 0 || idx >= dstShape4D[1] {
+					panic("SetRows index out of range")
+				}
+
+				srcOffset := srcBase + row*srcStride[1]
+				dstOffset := dstBase + idx*dstStride[1]
+
+				copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
+			}
+		}
+	}
+
+	return dst
+}
+
 func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	copy(t2.(*testTensor).data, t.data)
 	return nil
diff --git a/ml/backend.go b/ml/backend.go
index bf2c5851f..4d930fe43 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -194,6 +194,7 @@ type Tensor interface {
 	Repeat(ctx Context, dim, n int) Tensor
 	Concat(ctx Context, t2 Tensor, dim int) Tensor
 	Rows(ctx Context, t2 Tensor) Tensor
+	SetRows(ctx Context, src Tensor, idxs Tensor) Tensor
 	Copy(ctx Context, t2 Tensor) Tensor
 	Duplicate(ctx Context) Tensor
 
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 8d0415cfa..520d95cb0 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1338,6 +1338,13 @@ func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	}
 }
 
+func (t *Tensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_set_rows(ctx.(*Context).ctx, t.t, src.(*Tensor).t, idxs.(*Tensor).t),
+	}
+}
+
 func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
 		b: t.b,
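Note: the heart of this change is replacing "find a contiguous free block, defragmenting if necessary" with "scatter each token's K/V row to any free cell". findLocs hands back one empty cell per token, and Put writes all of a batch's rows with a single SetRows (ggml_set_rows) per cache tensor, which is what allows moveCells, defrag, and TestDefrag to be deleted outright. A minimal sketch of the scatter semantics in plain Go (hypothetical helper names, not part of the patch):

package main

import "fmt"

// scatterRows mirrors the 2D case SetRows handles in Put: dst is the cache
// viewed as [numCells][rowSize], src is the batch viewed as
// [batchSize][rowSize], and idxs[i] is the destination cell for token i.
func scatterRows(dst, src [][]float32, idxs []int32) {
	for i, idx := range idxs {
		copy(dst[idx], src[i]) // rows may land anywhere; no contiguity required
	}
}

func main() {
	// A 4-cell cache with row size 2; cells 1 and 2 are occupied, so a
	// contiguous allocator would have had to move data before storing two
	// new tokens.
	cache := [][]float32{{0, 0}, {9, 9}, {9, 9}, {0, 0}}
	batch := [][]float32{{1, 1}, {2, 2}}

	// findLocs-style placement: token 0 goes to cell 3, token 1 to cell 0.
	scatterRows(cache, batch, []int32{3, 0})
	fmt.Println(cache) // [[2 2] [9 9] [9 9] [1 1]]
}

One visible consequence in the tests above: because the cache is no longer compacted, sequences stay spread across their original cells, so the expectedShape values grow to cover the whole occupied range and the expectedMask entries mark the interleaved cells of other sequences as masked.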