mirror of https://github.com/ollama/ollama
model: default gemma 3 rope scale to 1.0, apply corrections based on layer counts (#13453)
This commit is contained in:
parent
1b308e1d2a
commit
4ff8a691bc
|
|
@ -2,7 +2,6 @@ package gemma3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math"
|
"math"
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/fs"
|
"github.com/ollama/ollama/fs"
|
||||||
"github.com/ollama/ollama/kvcache"
|
"github.com/ollama/ollama/kvcache"
|
||||||
|
|
@ -13,19 +12,20 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type TextConfig struct {
|
type TextConfig struct {
|
||||||
hiddenSize, numHeads, numKVHeads int
|
hiddenSize, contextLength, numHeads, numKVHeads int
|
||||||
attnKeyLen, attnValLen int
|
attnKeyLen, attnValLen int
|
||||||
eps, ropeScale float32
|
eps, ropeScale float32
|
||||||
ropeLocalBase float32
|
ropeLocalBase float32
|
||||||
largeModelScaling bool
|
largeModelScaling bool
|
||||||
slidingWindowPattern []bool
|
slidingWindow uint32
|
||||||
ropeBase float32
|
slidingWindowPattern []bool
|
||||||
ropeType string
|
ropeBase float32
|
||||||
ropeOriginalContext int
|
ropeType string
|
||||||
ropeExtrapolation float32
|
ropeOriginalContext int
|
||||||
ropeBetaFast float32
|
ropeExtrapolation float32
|
||||||
ropeBetaSlow float32
|
ropeBetaFast float32
|
||||||
finalLogitSoftcap float32
|
ropeBetaSlow float32
|
||||||
|
finalLogitSoftcap float32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
|
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
|
||||||
|
|
@ -55,6 +55,9 @@ type TextModel struct {
|
||||||
|
|
||||||
const (
|
const (
|
||||||
gemmaGlobalCacheCount = 6
|
gemmaGlobalCacheCount = 6
|
||||||
|
gemma1BLayerCount = 26
|
||||||
|
gemma4BLayerCount = 34
|
||||||
|
gemma12BLayerCount = 48
|
||||||
gemma27BLayerCount = 62
|
gemma27BLayerCount = 62
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -70,6 +73,7 @@ func newTextModel(c fs.Config) *TextModel {
|
||||||
Layers: make([]TextLayer, numBlocks),
|
Layers: make([]TextLayer, numBlocks),
|
||||||
TextConfig: &TextConfig{
|
TextConfig: &TextConfig{
|
||||||
hiddenSize: int(c.Uint("embedding_length")),
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
|
contextLength: int(c.Uint("context_length")),
|
||||||
numHeads: int(c.Uint("attention.head_count")),
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
attnKeyLen: int(c.Uint("attention.key_length", 256)),
|
attnKeyLen: int(c.Uint("attention.key_length", 256)),
|
||||||
|
|
@ -77,28 +81,32 @@ func newTextModel(c fs.Config) *TextModel {
|
||||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
||||||
ropeBase: c.Float("rope.freq_base", 1000000.0),
|
ropeBase: c.Float("rope.freq_base", 1000000.0),
|
||||||
|
slidingWindow: c.Uint("attention.sliding_window"),
|
||||||
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
|
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
|
||||||
ropeType: c.String("rope.scaling.type"),
|
ropeType: c.String("rope.scaling.type"),
|
||||||
ropeOriginalContext: int(c.Uint("rope.scaling.original_context_length")),
|
ropeOriginalContext: int(c.Uint("rope.scaling.original_context_length")),
|
||||||
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
|
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
|
||||||
ropeBetaFast: c.Float("rope.scaling.beta_fast", 64.0),
|
ropeBetaFast: c.Float("rope.scaling.beta_fast", 64.0),
|
||||||
ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
|
ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
|
||||||
ropeScale: c.Float("rope.scaling.factor", 8.0),
|
ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||||
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
|
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Google's Gemma 3 release with sliding window attention does
|
// Apply corrections for older versions of the Gemma 3 models
|
||||||
// not use final logit softcapping, and so force it to 0.0
|
// by looking at whether they use sliding window attention and
|
||||||
// The QAT weights for Gemma 3 also included an incorrect
|
// based on their layer counts.
|
||||||
// value for the rope scale, so we need to set it to 1.0 here.
|
if m.TextConfig.slidingWindow < uint32(m.TextConfig.contextLength) {
|
||||||
// TODO (jmorganca): this should ideally be set to 0.0 in the
|
switch numBlocks {
|
||||||
// model configuration instead of here, as future versions of
|
case gemma1BLayerCount:
|
||||||
// models may include both sliding window attention and final
|
// The 1B model has final logit softcapping set to 30.0
|
||||||
// logit softcapping.
|
// but it should be 0.0
|
||||||
if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
|
m.TextConfig.finalLogitSoftcap = 0.0
|
||||||
m.TextConfig.finalLogitSoftcap = 0.0
|
case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
|
||||||
m.TextConfig.ropeScale = 1.0
|
// The 4B, 12B, and 27B models have rope scale unset
|
||||||
|
// but it should be set to 8.0
|
||||||
|
m.TextConfig.ropeScale = 8.0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if numBlocks == gemma27BLayerCount {
|
if numBlocks == gemma27BLayerCount {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue