mirror of https://github.com/ollama/ollama
model: default gemma 3 rope scale to 1.0, apply corrections based on layer counts (#13453)
This commit is contained in:
parent 1b308e1d2a
commit 4ff8a691bc
@@ -2,7 +2,6 @@ package gemma3
 
 import (
 	"math"
-	"slices"
 
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
@@ -13,19 +12,20 @@ import (
 )
 
 type TextConfig struct {
-	hiddenSize, numHeads, numKVHeads int
-	attnKeyLen, attnValLen           int
-	eps, ropeScale                   float32
-	ropeLocalBase                    float32
-	largeModelScaling                bool
-	slidingWindowPattern             []bool
-	ropeBase                         float32
-	ropeType                         string
-	ropeOriginalContext              int
-	ropeExtrapolation                float32
-	ropeBetaFast                     float32
-	ropeBetaSlow                     float32
-	finalLogitSoftcap                float32
+	hiddenSize, contextLength, numHeads, numKVHeads int
+	attnKeyLen, attnValLen                          int
+	eps, ropeScale                                  float32
+	ropeLocalBase                                   float32
+	largeModelScaling                               bool
+	slidingWindow                                   uint32
+	slidingWindowPattern                            []bool
+	ropeBase                                        float32
+	ropeType                                        string
+	ropeOriginalContext                             int
+	ropeExtrapolation                               float32
+	ropeBetaFast                                    float32
+	ropeBetaSlow                                    float32
+	finalLogitSoftcap                               float32
 }
 
 func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
@@ -55,6 +55,9 @@ type TextModel struct {
 
 const (
 	gemmaGlobalCacheCount = 6
+	gemma1BLayerCount     = 26
+	gemma4BLayerCount     = 34
+	gemma12BLayerCount    = 48
 	gemma27BLayerCount    = 62
 )
@@ -70,6 +73,7 @@ func newTextModel(c fs.Config) *TextModel {
 		Layers: make([]TextLayer, numBlocks),
 		TextConfig: &TextConfig{
 			hiddenSize:    int(c.Uint("embedding_length")),
+			contextLength: int(c.Uint("context_length")),
 			numHeads:      int(c.Uint("attention.head_count")),
 			numKVHeads:    int(c.Uint("attention.head_count_kv")),
 			attnKeyLen:    int(c.Uint("attention.key_length", 256)),
@@ -77,28 +81,32 @@ func newTextModel(c fs.Config) *TextModel {
 			eps:                  c.Float("attention.layer_norm_rms_epsilon", 1e-06),
 			ropeLocalBase:        c.Float("rope.local.freq_base", 10000.0),
 			ropeBase:             c.Float("rope.freq_base", 1000000.0),
+			slidingWindow:        c.Uint("attention.sliding_window"),
 			slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
 			ropeType:             c.String("rope.scaling.type"),
 			ropeOriginalContext:  int(c.Uint("rope.scaling.original_context_length")),
 			ropeExtrapolation:    c.Float("rope.scaling.extrapolation_factor", 1.0),
 			ropeBetaFast:         c.Float("rope.scaling.beta_fast", 64.0),
 			ropeBetaSlow:         c.Float("rope.scaling.beta_slow", 1.0),
-			ropeScale:            c.Float("rope.scaling.factor", 8.0),
+			ropeScale:            c.Float("rope.scaling.factor", 1.0),
 			finalLogitSoftcap:    c.Float("final_logit_softcapping", 0.0),
 		},
 	}
 
-	// Google's Gemma 3 release with sliding window attention does
-	// not use final logit softcapping, and so force it to 0.0
-	// The QAT weights for Gemma 3 also included an incorrect
-	// value for the rope scale, so we need to set it to 1.0 here.
-	// TODO (jmorganca): this should ideally be set to 0.0 in the
-	// model configuration instead of here, as future versions of
-	// models may include both sliding window attention and final
-	// logit softcapping.
-	if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
-		m.TextConfig.finalLogitSoftcap = 0.0
-		m.TextConfig.ropeScale = 1.0
-	}
+	// Apply corrections for older versions of the Gemma 3 models
+	// by looking at whether they use sliding window attention and
+	// based on their layer counts.
+	if m.TextConfig.slidingWindow < uint32(m.TextConfig.contextLength) {
+		switch numBlocks {
+		case gemma1BLayerCount:
+			// The 1B model has final logit softcapping set to 30.0
+			// but it should be 0.0
+			m.TextConfig.finalLogitSoftcap = 0.0
+		case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
+			// The 4B, 12B, and 27B models have rope scale unset
+			// but it should be 8.0
+			m.TextConfig.ropeScale = 8.0
+		}
+	}
 
 	if numBlocks == gemma27BLayerCount {
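
For readers skimming the diff, the new fix-up logic is easy to lift out and read in isolation. Below is a minimal, self-contained sketch of the same decision table; textConfig, applyGemma3Corrections, and the sample numbers are hypothetical stand-ins for this illustration, not the package's actual API:

package main

import "fmt"

// Layer counts for the Gemma 3 text-model variants, mirroring the
// constants introduced in this commit.
const (
	gemma1BLayerCount  = 26
	gemma4BLayerCount  = 34
	gemma12BLayerCount = 48
	gemma27BLayerCount = 62
)

// textConfig is a trimmed-down, hypothetical stand-in for TextConfig;
// only the fields the correction touches are kept.
type textConfig struct {
	contextLength     int
	slidingWindow     uint32
	ropeScale         float32
	finalLogitSoftcap float32
}

// applyGemma3Corrections mirrors the commit's fix-up: checkpoints that
// use sliding window attention (window smaller than the full context)
// shipped with a wrong final logit softcap (1B) or an unset rope scale
// (4B, 12B, 27B), identified here by their layer counts.
func applyGemma3Corrections(cfg *textConfig, numBlocks int) {
	if cfg.slidingWindow < uint32(cfg.contextLength) {
		switch numBlocks {
		case gemma1BLayerCount:
			cfg.finalLogitSoftcap = 0.0
		case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
			cfg.ropeScale = 8.0
		}
	}
}

func main() {
	// A 27B-shaped config whose metadata left ropeScale at the new
	// 1.0 default; the correction restores the intended 8.0.
	cfg := &textConfig{contextLength: 131072, slidingWindow: 1024, ropeScale: 1.0}
	applyGemma3Corrections(cfg, gemma27BLayerCount)
	fmt.Printf("ropeScale=%.1f softcap=%.1f\n", cfg.ropeScale, cfg.finalLogitSoftcap)
}

The sliding-window check is what gates the corrections: a model whose window equals its full context length does not use sliding window attention and passes through untouched.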
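
The other behavioral change is the rope.scaling.factor default moving from 8.0 to 1.0, so a model whose metadata omits the factor gets no position scaling unless the layer-count correction above restores the intended 8.0. Assuming the common linear interpretation of a RoPE scaling factor (positions divided by the factor), here is a generic sketch of the effect; ropeAngle is an illustrative helper, not the rope kernel the model package actually uses:

package main

import (
	"fmt"
	"math"
)

// ropeAngle computes the RoPE rotation angle for one position and one
// frequency pair, with a linear position-scaling factor applied. It
// illustrates why 1.0 is the safe identity default.
func ropeAngle(pos, pair, headDim int, base, scale float64) float64 {
	invFreq := 1.0 / math.Pow(base, float64(2*pair)/float64(headDim))
	// Linear RoPE scaling stretches the position axis: with
	// scale=1.0 positions pass through unchanged; with scale=8.0
	// position 4096 rotates as if it were position 512.
	return float64(pos) / scale * invFreq
}

func main() {
	fmt.Println(ropeAngle(4096, 0, 256, 1e6, 1.0)) // 4096
	fmt.Println(ropeAngle(4096, 0, 256, 1e6, 8.0)) // 512
}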