mirror of https://github.com/ollama/ollama
review comments
commit d7cda4f5ac
parent 6a43120754
@@ -360,15 +360,12 @@ func roundUp(length, pad int) int {
 // token in the history should apply. This is based on both the sequence and causality (the
 // position of the history is not ahead of the token in the batch).
 func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
-	// Align and pad the two dimensions as required by the backend
-	batchSize := c.curBatchSize
-
 	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
 	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
 
 	length := c.curCellRange.max - c.curCellRange.min + 1
 
-	mask := make([]float32, batchSize*length)
+	mask := make([]float32, c.curBatchSize*length)
 
 	for i := range c.curBatchSize {
 		enabled := !slices.Contains(c.opts.Except, i)
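
The context above shows how the mask buffer is laid out: a flat []float32 with one row of length entries per batch token, where disallowed entries are set to -Inf and allowed entries are left at 0 (the additive-mask convention visible in the padding loop removed in the next hunk), following the rule described in the doc comment (same sequence, and the cached position is not ahead of the batch token). The sketch below illustrates that rule in isolation. It is a minimal standalone example: the cell type, the buildMaskValues helper, and their fields are hypothetical stand-ins rather than the kvcache package's actual types; only the row-major layout, the Except escape hatch, and the -Inf convention are taken from the diff.

// Standalone sketch of a sequence + causality mask, with simplified
// stand-in types; not the actual kvcache implementation.
package main

import (
	"fmt"
	"math"
	"slices"
)

type cell struct {
	pos int32 // position of the cached token
	seq int   // sequence the cached token belongs to
}

// buildMaskValues returns a batchSize x length buffer in row-major order:
// 0 where a batch token may attend to a cache cell, -Inf where it may not.
func buildMaskValues(cells []cell, seqs []int, positions []int32, except []int) []float32 {
	batchSize, length := len(seqs), len(cells)
	mask := make([]float32, batchSize*length)

	for i := range batchSize {
		// Batch entries listed in except skip the causality check.
		causal := !slices.Contains(except, i)
		for j, cl := range cells {
			if cl.seq != seqs[i] || (causal && cl.pos > positions[i]) {
				mask[i*length+j] = float32(math.Inf(-1))
			}
		}
	}

	return mask
}

func main() {
	cells := []cell{
		{pos: 0, seq: 0}, // visible: same sequence, earlier position
		{pos: 1, seq: 0}, // visible: same sequence, same position
		{pos: 1, seq: 1}, // masked: different sequence
		{pos: 2, seq: 0}, // masked: same sequence but ahead of the token
	}
	// One batch token in sequence 0 at position 1.
	fmt.Println(buildMaskValues(cells, []int{0}, []int32{1}, nil))
	// Output: [0 0 -Inf -Inf]
}
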
@@ -382,13 +379,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
 		}
 	}
 
-	// Mask out any padding tokens we added. For padding that we added to the cache history, this
-	// has already been masked out because the sequence doesn't match.
-	for i := c.curBatchSize * length; i < len(mask); i++ {
-		mask[i] = float32(math.Inf(-1))
-	}
-
-	maskTensor := ctx.Input().FromFloats(mask, length, batchSize)
+	maskTensor := ctx.Input().FromFloats(mask, length, c.curBatchSize)
 
 	if c.config.MaskDType != ml.DTypeF32 {
 		maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
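
A note on this second hunk: on the left-hand side, batchSize was already assigned straight from c.curBatchSize with no rounding, so len(mask) equaled c.curBatchSize*length and the removed padding loop began at an index equal to len(mask) and never executed. Deleting the dead alias and the dead loop, and passing c.curBatchSize to FromFloats directly, therefore leaves behavior unchanged. A tiny runnable check of that arithmetic, using made-up sizes rather than anything from the cache:

package main

import "fmt"

func main() {
	// Hypothetical sizes; the point is only that with an unpadded batch
	// dimension the padding loop's start index equals len(mask).
	curBatchSize, length := 3, 8
	mask := make([]float32, curBatchSize*length)

	iterations := 0
	for i := curBatchSize * length; i < len(mask); i++ {
		iterations++
	}
	fmt.Println(iterations, len(mask)) // 0 24: the removed loop never ran
}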