model: default gemma 3 rope scale to 1.0, apply corrections based on layer counts (#13453)

This commit is contained in:
Jeffrey Morgan 2025-12-12 17:51:56 -08:00 committed by GitHub
parent 1b308e1d2a
commit 4ff8a691bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 34 additions and 26 deletions

View File

@ -2,7 +2,6 @@ package gemma3
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@ -13,11 +12,12 @@ import (
)
type TextConfig struct {
hiddenSize, numHeads, numKVHeads int
hiddenSize, contextLength, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase float32
largeModelScaling bool
slidingWindow uint32
slidingWindowPattern []bool
ropeBase float32
ropeType string
@ -55,6 +55,9 @@ type TextModel struct {
const (
gemmaGlobalCacheCount = 6
gemma1BLayerCount = 26
gemma4BLayerCount = 34
gemma12BLayerCount = 48
gemma27BLayerCount = 62
)
@ -70,6 +73,7 @@ func newTextModel(c fs.Config) *TextModel {
Layers: make([]TextLayer, numBlocks),
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")),
contextLength: int(c.Uint("context_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
@ -77,28 +81,32 @@ func newTextModel(c fs.Config) *TextModel {
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeBase: c.Float("rope.freq_base", 1000000.0),
slidingWindow: c.Uint("attention.sliding_window"),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
ropeType: c.String("rope.scaling.type"),
ropeOriginalContext: int(c.Uint("rope.scaling.original_context_length")),
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
ropeBetaFast: c.Float("rope.scaling.beta_fast", 64.0),
ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
ropeScale: c.Float("rope.scaling.factor", 8.0),
ropeScale: c.Float("rope.scaling.factor", 1.0),
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
},
}
// Google's Gemma 3 release with sliding window attention does
// not use final logit softcapping, and so force it to 0.0
// The QAT weights for Gemma 3 also included an incorrect
// value for the rope scale, so we need to set it to 1.0 here.
// TODO (jmorganca): this should ideally be set to 0.0 in the
// model configuration instead of here, as future versions of
// models may include both sliding window attention and final
// logit softcapping.
if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
// Apply corrections for older versions of the Gemma 3 models
// by looking at whether they use sliding window attention and
// based on their layer counts.
if m.TextConfig.slidingWindow < uint32(m.TextConfig.contextLength) {
switch numBlocks {
case gemma1BLayerCount:
// The 1B model has final logit softcapping set to 30.0
// but it should be 0.0
m.TextConfig.finalLogitSoftcap = 0.0
m.TextConfig.ropeScale = 1.0
case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
// The 4B, 12B, and 27B models have rope scale unset
// but it shuold be set to 8.0
m.TextConfig.ropeScale = 8.0
}
}
if numBlocks == gemma27BLayerCount {