diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 759cc6b37..e1c0004d9 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -2,7 +2,6 @@ package gemma3
 
 import (
 	"math"
-	"slices"
 
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
@@ -13,19 +12,20 @@ import (
 )
 
 type TextConfig struct {
-	hiddenSize, numHeads, numKVHeads int
-	attnKeyLen, attnValLen           int
-	eps, ropeScale                   float32
-	ropeLocalBase                    float32
-	largeModelScaling                bool
-	slidingWindowPattern             []bool
-	ropeBase                         float32
-	ropeType                         string
-	ropeOriginalContext              int
-	ropeExtrapolation                float32
-	ropeBetaFast                     float32
-	ropeBetaSlow                     float32
-	finalLogitSoftcap                float32
+	hiddenSize, contextLength, numHeads, numKVHeads int
+	attnKeyLen, attnValLen                          int
+	eps, ropeScale                                  float32
+	ropeLocalBase                                   float32
+	largeModelScaling                               bool
+	slidingWindow                                   uint32
+	slidingWindowPattern                            []bool
+	ropeBase                                        float32
+	ropeType                                        string
+	ropeOriginalContext                             int
+	ropeExtrapolation                               float32
+	ropeBetaFast                                    float32
+	ropeBetaSlow                                    float32
+	finalLogitSoftcap                               float32
 }
 
 func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
@@ -55,6 +55,9 @@ type TextModel struct {
 
 const (
 	gemmaGlobalCacheCount = 6
+	gemma1BLayerCount     = 26
+	gemma4BLayerCount     = 34
+	gemma12BLayerCount    = 48
 	gemma27BLayerCount    = 62
 )
 
@@ -70,6 +73,7 @@ func newTextModel(c fs.Config) *TextModel {
 		Layers: make([]TextLayer, numBlocks),
 		TextConfig: &TextConfig{
 			hiddenSize:           int(c.Uint("embedding_length")),
+			contextLength:        int(c.Uint("context_length")),
 			numHeads:             int(c.Uint("attention.head_count")),
 			numKVHeads:           int(c.Uint("attention.head_count_kv")),
 			attnKeyLen:           int(c.Uint("attention.key_length", 256)),
@@ -77,28 +81,32 @@ func newTextModel(c fs.Config) *TextModel {
 			eps:                  c.Float("attention.layer_norm_rms_epsilon", 1e-06),
 			ropeLocalBase:        c.Float("rope.local.freq_base", 10000.0),
 			ropeBase:             c.Float("rope.freq_base", 1000000.0),
+			slidingWindow:        c.Uint("attention.sliding_window"),
 			slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
 			ropeType:             c.String("rope.scaling.type"),
 			ropeOriginalContext:  int(c.Uint("rope.scaling.original_context_length")),
 			ropeExtrapolation:    c.Float("rope.scaling.extrapolation_factor", 1.0),
 			ropeBetaFast:         c.Float("rope.scaling.beta_fast", 64.0),
 			ropeBetaSlow:         c.Float("rope.scaling.beta_slow", 1.0),
-			ropeScale:            c.Float("rope.scaling.factor", 8.0),
+			ropeScale:            c.Float("rope.scaling.factor", 1.0),
 			finalLogitSoftcap:    c.Float("final_logit_softcapping", 0.0),
 		},
 	}
 
-	// Google's Gemma 3 release with sliding window attention does
-	// not use final logit softcapping, and so force it to 0.0
-	// The QAT weights for Gemma 3 also included an incorrect
-	// value for the rope scale, so we need to set it to 1.0 here.
-	// TODO (jmorganca): this should ideally be set to 0.0 in the
-	// model configuration instead of here, as future versions of
-	// models may include both sliding window attention and final
-	// logit softcapping.
-	if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
-		m.TextConfig.finalLogitSoftcap = 0.0
-		m.TextConfig.ropeScale = 1.0
+	// Apply corrections for older versions of the Gemma 3 models
+	// by looking at whether they use sliding window attention and
+	// based on their layer counts.
+	if m.TextConfig.slidingWindow < uint32(m.TextConfig.contextLength) {
+		switch numBlocks {
+		case gemma1BLayerCount:
+			// The 1B model has final logit softcapping set to 30.0
+			// but it should be 0.0
+			m.TextConfig.finalLogitSoftcap = 0.0
+		case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
+			// The 4B, 12B, and 27B models have rope scale unset
+			// but it should be set to 8.0
+			m.TextConfig.ropeScale = 8.0
+		}
 	}
 
 	if numBlocks == gemma27BLayerCount {