model: conversion and hyperparameter fixes for ministral and devstral (#13424)

This commit is contained in:
Jeffrey Morgan 2025-12-11 13:04:00 -08:00 committed by GitHub
parent 1c4e85b4df
commit a838421ea3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 250 additions and 12 deletions

View File

@ -182,6 +182,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &llama4Model{}
case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{}
case "Ministral3ForCausalLM":
conv = &mistral3CausalModel{}
case "MixtralForCausalLM":
conv = &mixtralModel{}
case "GemmaForCausalLM":

View File

@ -30,13 +30,15 @@ type mistral3Model struct {
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
ScalingBeta float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
Mscale *float32 `json:"mscale"`
MscaleAllDim *float32 `json:"mscale_all_dim"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionModel struct {
@ -50,6 +52,9 @@ type mistral3Model struct {
HeadDim uint32 `json:"head_dim"`
HiddenAct string `json:"hidden_act"`
RopeTheta float32 `json:"rope_theta"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
} `json:"rope_parameters"`
} `json:"vision_config"`
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
@ -72,10 +77,22 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
kv["mistral3.rope.scaling.factor"] = p.TextModel.RopeParameters.Factor
kv["mistral3.rope.scaling.type"] = p.TextModel.RopeParameters.RopeType
kv["mistral3.rope.scaling.beta_fast"] = p.TextModel.RopeParameters.BetaFast
kv["mistral3.rope.scaling.beta_slow"] = p.TextModel.RopeParameters.BetaSlow
if p.TextModel.RopeParameters.Mscale != nil {
kv["mistral3.rope.scaling.mscale"] = *p.TextModel.RopeParameters.Mscale
}
if p.TextModel.RopeParameters.MscaleAllDim != nil {
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.TextModel.RopeParameters.MscaleAllDim
}
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
}
if p.TextModel.RopeParameters.Llama4ScalingBeta != nil {
kv["mistral3.rope.scaling_beta"] = *p.TextModel.RopeParameters.Llama4ScalingBeta
}
// Vision configuration
@ -88,7 +105,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
kv["mistral3.vision.rope.freq_base"] = cmp.Or(p.VisionModel.RopeTheta, p.VisionModel.RopeParameters.RopeTheta)
// Multimodal configuration
kv["mistral3.image_token_index"] = p.ImageTokenIndex

View File

@ -0,0 +1,181 @@
package convert
import (
"cmp"
"fmt"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
// mistral3CausalModel holds the configuration values read for a text-only
// Mistral 3 style checkpoint. Each field carries the JSON key it is tagged
// with; KV copies most of them into "mistral3.*" metadata entries.
type mistral3CausalModel struct {
	// ModelParameters is embedded; KV starts from the map returned by its KV method.
	ModelParameters
	// NumHiddenLayers is written to mistral3.block_count.
	NumHiddenLayers uint32 `json:"num_hidden_layers"`
	// MaxPositionEmbeddings is written to mistral3.context_length.
	MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
	// HiddenSize is written to mistral3.embedding_length and, divided by
	// NumAttentionHeads, is the fallback for mistral3.rope.dimension_count.
	HiddenSize uint32 `json:"hidden_size"`
	// IntermediateSize is written to mistral3.feed_forward_length.
	IntermediateSize uint32 `json:"intermediate_size"`
	// NumAttentionHeads is written to mistral3.attention.head_count and is the
	// head count repack uses for ".attn_q.weight" tensors.
	NumAttentionHeads uint32 `json:"num_attention_heads"`
	// NumKeyValueHeads is written to mistral3.attention.head_count_kv; repack
	// uses it for ".attn_k.weight" tensors, falling back to NumAttentionHeads
	// when it is zero.
	NumKeyValueHeads uint32 `json:"num_key_value_heads"`
	// RopeTheta is the preferred source for mistral3.rope.freq_base;
	// RopeParameters.RopeTheta is used when this is zero.
	RopeTheta float32 `json:"rope_theta"`
	// RMSNormEPS is written to mistral3.attention.layer_norm_rms_epsilon.
	RMSNormEPS float32 `json:"rms_norm_eps"`
	// HeadDim is written to the attention key_length and value_length keys and,
	// when non-zero, to mistral3.rope.dimension_count.
	HeadDim uint32 `json:"head_dim"`
	// SlidingWindow is parsed but not referenced by KV in this file.
	SlidingWindow *uint32 `json:"sliding_window"`
	// HiddenAct is parsed but not referenced by KV in this file.
	HiddenAct string `json:"hidden_act"`
	// VocabSize is written to mistral3.vocab_size.
	VocabSize uint32 `json:"vocab_size"`
	// RopeParameters groups the rope settings nested under "rope_parameters".
	// The pointer fields let KV skip a key entirely when the pointer is nil.
	RopeParameters struct {
		// BetaFast is written to mistral3.rope.scaling.beta_fast.
		BetaFast float32 `json:"beta_fast"`
		// BetaSlow is written to mistral3.rope.scaling.beta_slow.
		BetaSlow float32 `json:"beta_slow"`
		// Factor is written to mistral3.rope.scaling.factor.
		Factor float32 `json:"factor"`
		// Llama4ScalingBeta is written to mistral3.rope.scaling_beta when non-nil.
		Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
		// OrigMaxPositionEmbeddings is written to
		// mistral3.rope.scaling.original_context_length when greater than zero.
		OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
		// RopeType is written to mistral3.rope.scaling.type.
		RopeType string `json:"rope_type"`
		// RopeTheta is the fallback for mistral3.rope.freq_base when the
		// top-level RopeTheta is zero.
		RopeTheta float32 `json:"rope_theta"`
		// Mscale is written to mistral3.rope.scaling.mscale when non-nil.
		Mscale *float32 `json:"mscale"`
		// MscaleAllDim is written to mistral3.rope.scaling.mscale_all_dim when non-nil.
		MscaleAllDim *float32 `json:"mscale_all_dim"`
	} `json:"rope_parameters"`
}
// KV returns the metadata map for the model: the entries produced by the
// embedded ModelParameters for tokenizer t, plus the "mistral3.*" text
// configuration derived from the parsed config fields.
//
// Optional rope values (mscale, mscale_all_dim, llama_4_scaling_beta) are only
// emitted when they were present in the config, i.e. when their pointer is
// non-nil.
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "mistral3"
	kv["mistral3.vocab_size"] = p.VocabSize

	// Text configuration
	kv["mistral3.block_count"] = p.NumHiddenLayers
	kv["mistral3.context_length"] = p.MaxPositionEmbeddings
	kv["mistral3.embedding_length"] = p.HiddenSize
	kv["mistral3.feed_forward_length"] = p.IntermediateSize
	kv["mistral3.attention.head_count"] = p.NumAttentionHeads
	kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads
	kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
	kv["mistral3.attention.key_length"] = p.HeadDim
	kv["mistral3.attention.value_length"] = p.HeadDim

	// Prefer the explicit head_dim / top-level rope_theta; fall back to the
	// derived head size and the nested rope_parameters value when they are zero.
	kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads)
	kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta)

	rp := &p.RopeParameters
	kv["mistral3.rope.scaling.factor"] = rp.Factor
	kv["mistral3.rope.scaling.type"] = rp.RopeType
	kv["mistral3.rope.scaling.beta_fast"] = rp.BetaFast
	kv["mistral3.rope.scaling.beta_slow"] = rp.BetaSlow

	if rp.Mscale != nil {
		kv["mistral3.rope.scaling.mscale"] = *rp.Mscale
	}

	if rp.MscaleAllDim != nil {
		kv["mistral3.rope.scaling.mscale_all_dim"] = *rp.MscaleAllDim
	}

	if rp.OrigMaxPositionEmbeddings > 0 {
		kv["mistral3.rope.scaling.original_context_length"] = rp.OrigMaxPositionEmbeddings
	}

	// Llama4ScalingBeta is optional, so it must only be dereferenced behind a
	// nil check; a config with original_max_position_embeddings but without
	// llama_4_scaling_beta must not panic.
	if rp.Llama4ScalingBeta != nil {
		kv["mistral3.rope.scaling_beta"] = *rp.Llama4ScalingBeta
	}

	return kv
}
// Tensors converts the source tensors into ggml tensors, preserving name,
// kind and shape. Attention Q and K weights whose name does not start with
// "v." get p.repack installed as their repacker before being wrapped.
func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var converted []*ggml.Tensor
	for _, src := range ts {
		name := src.Name()
		skipRepack := strings.HasPrefix(name, "v.")
		isQK := strings.HasSuffix(name, ".attn_q.weight") || strings.HasSuffix(name, ".attn_k.weight")
		if isQK && !skipRepack {
			src.SetRepacker(p.repack)
		}

		wrapped := &ggml.Tensor{
			Name:     src.Name(),
			Kind:     src.Kind(),
			Shape:    src.Shape(),
			WriterTo: src,
		}
		converted = append(converted, wrapped)
	}

	return converted
}
// Replacements returns a flat list of strings laid out as consecutive
// (source name fragment, target name fragment) pairs. The order of the pairs
// is significant to the consumer and is kept exactly as declared below.
// NOTE(review): presumably fed to a string replacer for tensor renaming —
// confirm against the caller.
func (p *mistral3CausalModel) Replacements() []string {
	pairs := [][2]string{
		{"model.norm", "output_norm"},
		{"model.", ""},
		{"layers", "blk"},
		{"transformer.layers", "blk"},
		{"vision_tower", "v"},
		{"ln_pre", "encoder_norm"},
		{"input_layernorm", "attn_norm"},
		{"post_attention_layernorm", "ffn_norm"},
		{"embed_tokens", "token_embd"},
		{"self_attn.q_proj", "attn_q"},
		{"self_attn.k_proj", "attn_k"},
		{"self_attn.v_proj", "attn_v"},
		{"self_attn.o_proj", "attn_output"},
		{"mlp.down_proj", "ffn_down"},
		{"mlp.gate_proj", "ffn_gate"},
		{"mlp.up_proj", "ffn_up"},
		{"attention.q_proj", "attn_q"},
		{"attention.k_proj", "attn_k"},
		{"attention.v_proj", "attn_v"},
		{"attention.o_proj", "attn_output"},
		{"attention_norm", "attn_norm"},
		{"feed_forward.gate_proj", "ffn_gate"},
		{"feed_forward.down_proj", "ffn_down"},
		{"feed_forward.up_proj", "ffn_up"},
		{"multi_modal_projector", "mm"},
		{"ffn_norm", "ffn_norm"},
		{"lm_head", "output"},
	}

	// Flatten to the [old0, new0, old1, new1, ...] layout callers expect.
	flat := make([]string, 0, 2*len(pairs))
	for _, pair := range pairs {
		flat = append(flat, pair[0], pair[1])
	}

	return flat
}
// repack rearranges the data of an attention Q or K weight tensor and returns
// the result as a new flat slice.
//
// name selects the head count: tensors ending in ".attn_q.weight" use
// NumAttentionHeads, tensors ending in ".attn_k.weight" use NumKeyValueHeads
// (falling back to NumAttentionHeads when that is zero). data is the flat
// tensor content and shape its dimensions.
//
// An error is returned for any other tensor name, for an empty shape, for a
// zero head count, and for any failure reported by the tensor operations.
func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
	var heads uint32
	switch {
	case strings.HasSuffix(name, ".attn_q.weight"):
		heads = p.NumAttentionHeads
	case strings.HasSuffix(name, ".attn_k.weight"):
		heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
	default:
		return nil, fmt.Errorf("unknown tensor for repack: %s", name)
	}

	// Guard the dims[0] / heads computation below: without these checks a
	// missing head count or a shapeless tensor panics instead of failing the
	// conversion cleanly.
	if heads == 0 {
		return nil, fmt.Errorf("repack %s: head count is zero", name)
	}
	if len(shape) == 0 {
		return nil, fmt.Errorf("repack %s: empty shape", name)
	}

	dims := make([]int, 0, len(shape))
	for _, dim := range shape {
		dims = append(dims, int(dim))
	}

	n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))

	// Split the leading dimension into (heads, 2, rows/heads/2), swap the two
	// inner axes, then restore the original shape.
	// NOTE(review): presumably converts between rotary weight layouts — confirm.
	if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
		return nil, err
	}

	if err := n.T(0, 2, 1, 3); err != nil {
		return nil, err
	}

	if err := n.Reshape(dims...); err != nil {
		return nil, err
	}

	if err := n.Transpose(); err != nil {
		return nil, err
	}

	ts, err := native.SelectF32(n, 1)
	if err != nil {
		return nil, err
	}

	// Concatenate the selected slices back into one flat buffer.
	f32s := make([]float32, 0, len(data))
	for _, t := range ts {
		f32s = append(f32s, t...)
	}

	return f32s, nil
}

View File

@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@ -17,10 +18,41 @@ type TextOptions struct {
eps, ropeBase, ropeScale float32
ropeOrigPosEmbeddings int
ropeScalingBeta float32
ropeType string
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
ropeMscale float32
ropeMscaleAllDim float32
}
// applyRotaryPositionEmbeddings applies rotary position embeddings to states
// for the given positions using the configured rope dimension, base and
// scale (passed to nn.RoPE as 1/ropeScale).
//
// When ropeType is "yarn", additional rope options are supplied: the original
// context length, extrapolation factor, beta_fast/beta_slow, and an attention
// factor derived from the mscale settings.
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
	var ropeOpts []func(*rope.Options)
	if o.ropeType == "yarn" {
		// getMscale returns 1 for scale <= 1, otherwise 0.1*mscale*ln(scale) + 1.
		getMscale := func(scale, mscale float64) float64 {
			if scale <= 1.0 {
				return 1.0
			}
			return 0.1*mscale*math.Log(scale) + 1.0
		}

		scale := float64(o.ropeScale)

		// With both mscale values configured the attention factor is the ratio
		// of the two; otherwise it falls back to an mscale of 1.
		var attnFactor float32
		if o.ropeMscale != 0 && o.ropeMscaleAllDim != 0 {
			attnFactor = float32(getMscale(scale, float64(o.ropeMscale)) / getMscale(scale, float64(o.ropeMscaleAllDim)))
		} else {
			attnFactor = float32(getMscale(scale, 1))
		}

		ropeOpts = append(ropeOpts,
			rope.WithOriginalContextLength(o.ropeOrigPosEmbeddings),
			rope.WithExtrapolationFactor(o.ropeExtrapolation),
			rope.WithAttentionFactor(attnFactor),
			rope.WithBetaFast(o.ropeBetaFast),
			rope.WithBetaSlow(o.ropeBetaSlow),
		)
	}

	return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, ropeOpts...)
}
type TextModel struct {
@ -150,9 +182,15 @@ func newTextModel(c fs.Config) *TextModel {
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
ropeScale: c.Float("rope.scaling.factor", 1.0),
ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
ropeScalingBeta: c.Float("rope.scaling_beta"),
ropeScalingBeta: c.Float("rope.scaling_beta", 0.1),
ropeBetaFast: c.Float("rope.scaling.beta_fast", 32.0),
ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
ropeType: c.String("rope.scaling.type"),
ropeMscale: c.Float("rope.scaling.mscale"),
ropeMscaleAllDim: c.Float("rope.scaling.mscale_all_dim"),
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1),
},
}
}