diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 01eca1c50..36106107b 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -29,24 +29,13 @@ type TextOptions struct { func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { var ropeOpts []func(*rope.Options) if o.ropeType == "yarn" { - getMscale := func(scale, mscale float64) float64 { - if scale <= 1.0 { - return 1.0 - } - return 0.1*mscale*math.Log(scale) + 1.0 - } - - var attnFactor float32 if o.ropeMscale != 0 && o.ropeMscaleAllDim != 0 { - attnFactor = float32(getMscale(float64(o.ropeScale), float64(o.ropeMscale)) / getMscale(float64(o.ropeScale), float64(o.ropeMscaleAllDim))) - } else { - attnFactor = float32(getMscale(float64(o.ropeScale), 1)) + ropeOpts = append(ropeOpts, rope.WithAttentionFactor(1.0/float32(0.1*math.Log(float64(o.ropeScale))+1.0))) } ropeOpts = append(ropeOpts, rope.WithOriginalContextLength(o.ropeOrigPosEmbeddings), rope.WithExtrapolationFactor(o.ropeExtrapolation), - rope.WithAttentionFactor(attnFactor), rope.WithBetaFast(o.ropeBetaFast), rope.WithBetaSlow(o.ropeBetaSlow), )