model: default gemma 3 rope scale to 1.0, apply corrections based on layer counts (#13453)

Jeffrey Morgan 2025-12-12 17:51:56 -08:00 committed by GitHub
parent 1b308e1d2a
commit 4ff8a691bc
1 changed file with 34 additions and 26 deletions
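The diff below changes the fallback for the rope scaling factor from 8.0 to 1.0 and replaces the old blanket sliding-window check with corrections keyed on the transformer layer count. As a rough standalone sketch of the new rules (using a hypothetical modelConfig struct and applyGemma3Corrections helper, not the package's actual TextConfig and newTextModel), the logic amounts to:

package main

import "fmt"

// modelConfig is a hypothetical, trimmed-down stand-in for the fields that
// the correction logic in this commit reads and writes.
type modelConfig struct {
	contextLength     int
	slidingWindow     uint32
	ropeScale         float32
	finalLogitSoftcap float32
}

// Layer counts for the Gemma 3 text models, mirroring the constants added below.
const (
	gemma1BLayerCount  = 26
	gemma4BLayerCount  = 34
	gemma12BLayerCount = 48
	gemma27BLayerCount = 62
)

// applyGemma3Corrections mirrors the patching done in newTextModel: only
// configs that use sliding window attention (window smaller than the context
// length) are corrected, and the correction depends on the layer count.
func applyGemma3Corrections(cfg *modelConfig, numBlocks int) {
	if cfg.slidingWindow < uint32(cfg.contextLength) {
		switch numBlocks {
		case gemma1BLayerCount:
			// Older 1B configs ship final logit softcapping as 30.0; it should be 0.0.
			cfg.finalLogitSoftcap = 0.0
		case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
			// Older 4B, 12B, and 27B configs leave the rope scale unset; it should be 8.0.
			cfg.ropeScale = 8.0
		}
	}
}

func main() {
	// Example values loosely modeled on a 27B-style config: the rope scale starts
	// at the new 1.0 fallback and is corrected to 8.0 because the model uses
	// sliding window attention and has 62 layers.
	cfg := modelConfig{contextLength: 131072, slidingWindow: 1024, ropeScale: 1.0}
	applyGemma3Corrections(&cfg, gemma27BLayerCount)
	fmt.Printf("ropeScale=%.1f finalLogitSoftcap=%.1f\n", cfg.ropeScale, cfg.finalLogitSoftcap)
}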


@@ -2,7 +2,6 @@ package gemma3
 
 import (
 	"math"
-	"slices"
 
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
@@ -13,19 +12,20 @@ import (
 )
 
 type TextConfig struct {
-	hiddenSize, numHeads, numKVHeads int
-	attnKeyLen, attnValLen           int
-	eps, ropeScale                   float32
-	ropeLocalBase                    float32
-	largeModelScaling                bool
-	slidingWindowPattern             []bool
-	ropeBase                         float32
-	ropeType                         string
-	ropeOriginalContext              int
-	ropeExtrapolation                float32
-	ropeBetaFast                     float32
-	ropeBetaSlow                     float32
-	finalLogitSoftcap                float32
+	hiddenSize, contextLength, numHeads, numKVHeads int
+	attnKeyLen, attnValLen                          int
+	eps, ropeScale                                  float32
+	ropeLocalBase                                   float32
+	largeModelScaling                               bool
+	slidingWindow                                   uint32
+	slidingWindowPattern                            []bool
+	ropeBase                                        float32
+	ropeType                                        string
+	ropeOriginalContext                             int
+	ropeExtrapolation                               float32
+	ropeBetaFast                                    float32
+	ropeBetaSlow                                    float32
+	finalLogitSoftcap                               float32
 }
 
 func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
@@ -55,6 +55,9 @@ type TextModel struct {
 
 const (
 	gemmaGlobalCacheCount = 6
+	gemma1BLayerCount     = 26
+	gemma4BLayerCount     = 34
+	gemma12BLayerCount    = 48
 	gemma27BLayerCount    = 62
 )
 
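These constants key the corrections on the transformer block count rather than on a model-name string. As a purely illustrative sketch (the gemma3SizeForLayers helper below is hypothetical, not part of the repository), the implied mapping from layer count to Gemma 3 text-model size is:

package main

import "fmt"

// gemma3SizeForLayers is a hypothetical helper spelling out the layer-count
// to model-size mapping implied by the constants above.
func gemma3SizeForLayers(numBlocks int) string {
	switch numBlocks {
	case 26:
		return "1B"
	case 34:
		return "4B"
	case 48:
		return "12B"
	case 62:
		return "27B"
	default:
		return "unknown"
	}
}

func main() {
	for _, n := range []int{26, 34, 48, 62, 30} {
		fmt.Printf("%d layers -> %s\n", n, gemma3SizeForLayers(n))
	}
}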
@@ -70,6 +73,7 @@ func newTextModel(c fs.Config) *TextModel {
 		Layers:     make([]TextLayer, numBlocks),
 		TextConfig: &TextConfig{
 			hiddenSize:           int(c.Uint("embedding_length")),
+			contextLength:        int(c.Uint("context_length")),
 			numHeads:             int(c.Uint("attention.head_count")),
 			numKVHeads:           int(c.Uint("attention.head_count_kv")),
 			attnKeyLen:           int(c.Uint("attention.key_length", 256)),
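In these getters the second argument is the fallback used when the key is missing from the model metadata; that fallback is what the commit title means by defaulting the rope scale to 1.0. A minimal sketch of the lookup-with-default pattern, using a plain map and a hypothetical lookupFloat helper rather than ollama's fs.Config:

package main

import "fmt"

// lookupFloat is a hypothetical stand-in for a config getter like the ones
// above: it returns the stored value when the key is present and the
// supplied default otherwise.
func lookupFloat(kv map[string]float32, key string, def float32) float32 {
	if v, ok := kv[key]; ok {
		return v
	}
	return def
}

func main() {
	// Metadata without "rope.scaling.factor": the new 1.0 fallback applies.
	older := map[string]float32{"rope.freq_base": 1000000.0}
	// Metadata that carries its own value: the fallback is ignored.
	newer := map[string]float32{"rope.scaling.factor": 8.0}

	fmt.Println(lookupFloat(older, "rope.scaling.factor", 1.0)) // 1
	fmt.Println(lookupFloat(newer, "rope.scaling.factor", 1.0)) // 8
}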
@@ -77,28 +81,32 @@ func newTextModel(c fs.Config) *TextModel {
 			eps:                  c.Float("attention.layer_norm_rms_epsilon", 1e-06),
 			ropeLocalBase:        c.Float("rope.local.freq_base", 10000.0),
 			ropeBase:             c.Float("rope.freq_base", 1000000.0),
+			slidingWindow:        c.Uint("attention.sliding_window"),
 			slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
 			ropeType:             c.String("rope.scaling.type"),
 			ropeOriginalContext:  int(c.Uint("rope.scaling.original_context_length")),
 			ropeExtrapolation:    c.Float("rope.scaling.extrapolation_factor", 1.0),
 			ropeBetaFast:         c.Float("rope.scaling.beta_fast", 64.0),
 			ropeBetaSlow:         c.Float("rope.scaling.beta_slow", 1.0),
-			ropeScale:            c.Float("rope.scaling.factor", 8.0),
+			ropeScale:            c.Float("rope.scaling.factor", 1.0),
 			finalLogitSoftcap:    c.Float("final_logit_softcapping", 0.0),
 		},
 	}
 
-	// Google's Gemma 3 release with sliding window attention does
-	// not use final logit softcapping, and so force it to 0.0
-	// The QAT weights for Gemma 3 also included an incorrect
-	// value for the rope scale, so we need to set it to 1.0 here.
-	// TODO (jmorganca): this should ideally be set to 0.0 in the
-	// model configuration instead of here, as future versions of
-	// models may include both sliding window attention and final
-	// logit softcapping.
-	if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
-		m.TextConfig.finalLogitSoftcap = 0.0
-		m.TextConfig.ropeScale = 1.0
+	// Apply corrections for older versions of the Gemma 3 models
+	// by looking at whether they use sliding window attention and
+	// based on their layer counts.
+	if m.TextConfig.slidingWindow < uint32(m.TextConfig.contextLength) {
+		switch numBlocks {
+		case gemma1BLayerCount:
+			// The 1B model has final logit softcapping set to 30.0
+			// but it should be 0.0
+			m.TextConfig.finalLogitSoftcap = 0.0
+		case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
+			// The 4B, 12B, and 27B models have rope scale unset
+			// but it should be set to 8.0
+			m.TextConfig.ropeScale = 8.0
+		}
 	}
 
 	if numBlocks == gemma27BLayerCount {