model: default gemma 3 rope scale to 1.0, apply corrections based on layer counts (#13453)

Jeffrey Morgan 2025-12-12 17:51:56 -08:00 committed by GitHub
parent 1b308e1d2a
commit 4ff8a691bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 34 additions and 26 deletions

@@ -2,7 +2,6 @@ package gemma3

 import (
 	"math"
-	"slices"

 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
@@ -13,19 +12,20 @@ import (
 )

 type TextConfig struct {
-	hiddenSize, numHeads, numKVHeads int
-	attnKeyLen, attnValLen int
-	eps, ropeScale float32
-	ropeLocalBase float32
-	largeModelScaling bool
-	slidingWindowPattern []bool
-	ropeBase float32
-	ropeType string
-	ropeOriginalContext int
-	ropeExtrapolation float32
-	ropeBetaFast float32
-	ropeBetaSlow float32
-	finalLogitSoftcap float32
+	hiddenSize, contextLength, numHeads, numKVHeads int
+	attnKeyLen, attnValLen int
+	eps, ropeScale float32
+	ropeLocalBase float32
+	largeModelScaling bool
+	slidingWindow uint32
+	slidingWindowPattern []bool
+	ropeBase float32
+	ropeType string
+	ropeOriginalContext int
+	ropeExtrapolation float32
+	ropeBetaFast float32
+	ropeBetaSlow float32
+	finalLogitSoftcap float32
 }

 func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
@@ -55,6 +55,9 @@ type TextModel struct {
 const (
 	gemmaGlobalCacheCount = 6
+	gemma1BLayerCount = 26
+	gemma4BLayerCount = 34
+	gemma12BLayerCount = 48
 	gemma27BLayerCount = 62
 )
@@ -70,6 +73,7 @@ func newTextModel(c fs.Config) *TextModel {
 		Layers: make([]TextLayer, numBlocks),
 		TextConfig: &TextConfig{
 			hiddenSize: int(c.Uint("embedding_length")),
+			contextLength: int(c.Uint("context_length")),
 			numHeads: int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			attnKeyLen: int(c.Uint("attention.key_length", 256)),
@@ -77,28 +81,32 @@
 			eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
 			ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
 			ropeBase: c.Float("rope.freq_base", 1000000.0),
+			slidingWindow: c.Uint("attention.sliding_window"),
 			slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
 			ropeType: c.String("rope.scaling.type"),
 			ropeOriginalContext: int(c.Uint("rope.scaling.original_context_length")),
 			ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
 			ropeBetaFast: c.Float("rope.scaling.beta_fast", 64.0),
 			ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
-			ropeScale: c.Float("rope.scaling.factor", 8.0),
+			ropeScale: c.Float("rope.scaling.factor", 1.0),
 			finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
 		},
 	}

-	// Google's Gemma 3 release with sliding window attention does
-	// not use final logit softcapping, and so force it to 0.0
-	// The QAT weights for Gemma 3 also included an incorrect
-	// value for the rope scale, so we need to set it to 1.0 here.
-	// TODO (jmorganca): this should ideally be set to 0.0 in the
-	// model configuration instead of here, as future versions of
-	// models may include both sliding window attention and final
-	// logit softcapping.
-	if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
-		m.TextConfig.finalLogitSoftcap = 0.0
-		m.TextConfig.ropeScale = 1.0
+	// Apply corrections for older versions of the Gemma 3 models
+	// by looking at whether they use sliding window attention and
+	// based on their layer counts.
+	if m.TextConfig.slidingWindow < uint32(m.TextConfig.contextLength) {
+		switch numBlocks {
+		case gemma1BLayerCount:
+			// The 1B model has final logit softcapping set to 30.0
+			// but it should be 0.0
+			m.TextConfig.finalLogitSoftcap = 0.0
+		case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
+			// The 4B, 12B, and 27B models have rope scale unset
+			// but it should be set to 8.0
+			m.TextConfig.ropeScale = 8.0
+		}
 	}

 	if numBlocks == gemma27BLayerCount {
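
For readers who want the new correction logic in isolation, here is a minimal standalone Go sketch that mirrors what the diff above does. The textConfig type, the applyGemma3Corrections helper, and the example values in main are hypothetical stand-ins, not ollama's actual types or API; only the layer-count constants, the sliding-window-versus-context check, and the corrected values are taken from the commit.

// Standalone sketch of the layer-count-based corrections introduced above.
// textConfig and applyGemma3Corrections are illustrative names only.
package main

import "fmt"

const (
	gemma1BLayerCount  = 26
	gemma4BLayerCount  = 34
	gemma12BLayerCount = 48
	gemma27BLayerCount = 62
)

type textConfig struct {
	contextLength     int
	slidingWindow     uint32
	ropeScale         float32
	finalLogitSoftcap float32
}

// applyGemma3Corrections patches configs from older Gemma 3 checkpoints.
// A model counts as using sliding window attention here when its window is
// smaller than its context length; the fix applied depends on the block count.
func applyGemma3Corrections(cfg *textConfig, numBlocks int) {
	if cfg.slidingWindow < uint32(cfg.contextLength) {
		switch numBlocks {
		case gemma1BLayerCount:
			// Older 1B checkpoints set final_logit_softcapping to 30.0; disable it.
			cfg.finalLogitSoftcap = 0.0
		case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
			// Older 4B/12B/27B checkpoints leave rope.scaling.factor unset
			// (now defaulted to 1.0); the intended value is 8.0.
			cfg.ropeScale = 8.0
		}
	}
}

func main() {
	// Example values only: a sliding-window model with a 27B-style layer count.
	cfg := &textConfig{
		contextLength:     131072,
		slidingWindow:     1024,
		ropeScale:         1.0, // new default from c.Float("rope.scaling.factor", 1.0)
		finalLogitSoftcap: 0.0,
	}
	applyGemma3Corrections(cfg, gemma27BLayerCount)
	fmt.Printf("ropeScale=%v finalLogitSoftcap=%v\n", cfg.ropeScale, cfg.finalLogitSoftcap)
	// Output: ropeScale=8 finalLogitSoftcap=0
}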