diff --git a/convert/convert.go b/convert/convert.go
index bc110c6fa..f08467952 100644
--- a/convert/convert.go
+++ b/convert/convert.go
@@ -182,6 +182,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
 		conv = &llama4Model{}
 	case "Mistral3ForConditionalGeneration":
 		conv = &mistral3Model{}
+	case "Ministral3ForCausalLM":
+		conv = &mistral3CausalModel{}
 	case "MixtralForCausalLM":
 		conv = &mixtralModel{}
 	case "GemmaForCausalLM":
diff --git a/convert/convert_mistral.go b/convert/convert_mistral.go
index 81774853b..f11bd9644 100644
--- a/convert/convert_mistral.go
+++ b/convert/convert_mistral.go
@@ -30,13 +30,15 @@ type mistral3Model struct {
 		HiddenAct             string  `json:"hidden_act"`
 		VocabSize             uint32  `json:"vocab_size"`
 		RopeParameters        struct {
-			BetaFast                  float32 `json:"beta_fast"`
-			BetaSlow                  float32 `json:"beta_slow"`
-			Factor                    float32 `json:"factor"`
-			ScalingBeta               float32 `json:"llama_4_scaling_beta"`
-			OrigMaxPositionEmbeddings uint32  `json:"original_max_position_embeddings"`
-			RopeType                  string  `json:"rope_type"`
-			RopeTheta                 float32 `json:"rope_theta"`
+			BetaFast                  float32  `json:"beta_fast"`
+			BetaSlow                  float32  `json:"beta_slow"`
+			Factor                    float32  `json:"factor"`
+			Llama4ScalingBeta         *float32 `json:"llama_4_scaling_beta"`
+			OrigMaxPositionEmbeddings uint32   `json:"original_max_position_embeddings"`
+			RopeType                  string   `json:"rope_type"`
+			RopeTheta                 float32  `json:"rope_theta"`
+			Mscale                    *float32 `json:"mscale"`
+			MscaleAllDim              *float32 `json:"mscale_all_dim"`
 		} `json:"rope_parameters"`
 	} `json:"text_config"`
 	VisionModel struct {
@@ -50,6 +52,9 @@ type mistral3Model struct {
 		HeadDim        uint32  `json:"head_dim"`
 		HiddenAct      string  `json:"hidden_act"`
 		RopeTheta      float32 `json:"rope_theta"`
+		RopeParameters struct {
+			RopeTheta float32 `json:"rope_theta"`
+		} `json:"rope_parameters"`
 	} `json:"vision_config"`
 	MultiModalProjectorBias bool   `json:"multimodal_projector_bias"`
 	ProjectorHiddenAct      string `json:"projector_hidden_act"`
@@ -72,10 +77,22 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
 	kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
 	kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
 	kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
+	kv["mistral3.rope.scaling.factor"] = p.TextModel.RopeParameters.Factor
+	kv["mistral3.rope.scaling.type"] = p.TextModel.RopeParameters.RopeType
+	kv["mistral3.rope.scaling.beta_fast"] = p.TextModel.RopeParameters.BetaFast
+	kv["mistral3.rope.scaling.beta_slow"] = p.TextModel.RopeParameters.BetaSlow
+	if p.TextModel.RopeParameters.Mscale != nil {
+		kv["mistral3.rope.scaling.mscale"] = *p.TextModel.RopeParameters.Mscale
+	}
+	if p.TextModel.RopeParameters.MscaleAllDim != nil {
+		kv["mistral3.rope.scaling.mscale_all_dim"] = *p.TextModel.RopeParameters.MscaleAllDim
+	}
 
 	if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
 		kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
-		kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
+	}
+	if p.TextModel.RopeParameters.Llama4ScalingBeta != nil {
+		kv["mistral3.rope.scaling_beta"] = *p.TextModel.RopeParameters.Llama4ScalingBeta
 	}
 
 	// Vision configuration
@@ -88,7 +105,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
 	kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
 	kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
 	// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
-	kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
+	kv["mistral3.vision.rope.freq_base"] = cmp.Or(p.VisionModel.RopeTheta, p.VisionModel.RopeParameters.RopeTheta)
 
 	// Multimodal configuration
 	kv["mistral3.image_token_index"] = p.ImageTokenIndex
diff --git a/convert/convert_mistral_causal.go b/convert/convert_mistral_causal.go
new file mode 100644
index 000000000..99a483736
--- /dev/null
+++ b/convert/convert_mistral_causal.go
@@ -0,0 +1,180 @@
+package convert
+
+import (
+	"cmp"
+	"fmt"
+	"strings"
+
+	"github.com/pdevine/tensor"
+	"github.com/pdevine/tensor/native"
+
+	"github.com/ollama/ollama/fs/ggml"
+)
+
+type mistral3CausalModel struct {
+	ModelParameters
+
+	NumHiddenLayers       uint32  `json:"num_hidden_layers"`
+	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
+	HiddenSize            uint32  `json:"hidden_size"`
+	IntermediateSize      uint32  `json:"intermediate_size"`
+	NumAttentionHeads     uint32  `json:"num_attention_heads"`
+	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
+	RopeTheta             float32 `json:"rope_theta"`
+	RMSNormEPS            float32 `json:"rms_norm_eps"`
+	HeadDim               uint32  `json:"head_dim"`
+	SlidingWindow         *uint32 `json:"sliding_window"`
+	HiddenAct             string  `json:"hidden_act"`
+	VocabSize             uint32  `json:"vocab_size"`
+	RopeParameters        struct {
+		BetaFast                  float32  `json:"beta_fast"`
+		BetaSlow                  float32  `json:"beta_slow"`
+		Factor                    float32  `json:"factor"`
+		Llama4ScalingBeta         *float32 `json:"llama_4_scaling_beta"`
+		OrigMaxPositionEmbeddings uint32   `json:"original_max_position_embeddings"`
+		RopeType                  string   `json:"rope_type"`
+		RopeTheta                 float32  `json:"rope_theta"`
+		Mscale                    *float32 `json:"mscale"`
+		MscaleAllDim              *float32 `json:"mscale_all_dim"`
+	} `json:"rope_parameters"`
+}
+
+func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
+	kv := p.ModelParameters.KV(t)
+	kv["general.architecture"] = "mistral3"
+	kv["mistral3.vocab_size"] = p.VocabSize
+
+	// Text configuration
+	kv["mistral3.block_count"] = p.NumHiddenLayers
+	kv["mistral3.context_length"] = p.MaxPositionEmbeddings
+	kv["mistral3.embedding_length"] = p.HiddenSize
+	kv["mistral3.feed_forward_length"] = p.IntermediateSize
+	kv["mistral3.attention.head_count"] = p.NumAttentionHeads
+	kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads
+	kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
+	kv["mistral3.attention.key_length"] = p.HeadDim
+	kv["mistral3.attention.value_length"] = p.HeadDim
+	kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads)
+	kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta)
+	kv["mistral3.rope.scaling.factor"] = p.RopeParameters.Factor
+	kv["mistral3.rope.scaling.type"] = p.RopeParameters.RopeType
+	kv["mistral3.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
+	kv["mistral3.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow
+
+	if p.RopeParameters.Mscale != nil {
+		kv["mistral3.rope.scaling.mscale"] = *p.RopeParameters.Mscale
+	}
+
+	if p.RopeParameters.MscaleAllDim != nil {
+		kv["mistral3.rope.scaling.mscale_all_dim"] = *p.RopeParameters.MscaleAllDim
+	}
+
+	if p.RopeParameters.OrigMaxPositionEmbeddings > 0 {
+		kv["mistral3.rope.scaling.original_context_length"] = p.RopeParameters.OrigMaxPositionEmbeddings
+	}
+
+	if p.RopeParameters.Llama4ScalingBeta != nil {
+		kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
+	}
+
+	return kv
+}
+
+func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor {
+	var out []*ggml.Tensor
+
+	for _, t := range ts {
+		if !strings.HasPrefix(t.Name(), "v.") {
+			if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
+				strings.HasSuffix(t.Name(), ".attn_k.weight") {
+				t.SetRepacker(p.repack)
+			}
+		}
+
+		out = append(out, &ggml.Tensor{
+			Name:     t.Name(),
+			Kind:     t.Kind(),
+			Shape:    t.Shape(),
+			WriterTo: t,
+		})
+	}
+
+	return out
+}
+
+func (p *mistral3CausalModel) Replacements() []string {
+	return []string{
+		"model.norm", "output_norm",
+		"model.", "",
+		"layers", "blk",
+		"transformer.layers", "blk",
+		"vision_tower", "v",
+		"ln_pre", "encoder_norm",
+		"input_layernorm", "attn_norm",
+		"post_attention_layernorm", "ffn_norm",
+		"embed_tokens", "token_embd",
+		"self_attn.q_proj", "attn_q",
+		"self_attn.k_proj", "attn_k",
+		"self_attn.v_proj", "attn_v",
+		"self_attn.o_proj", "attn_output",
+		"mlp.down_proj", "ffn_down",
+		"mlp.gate_proj", "ffn_gate",
+		"mlp.up_proj", "ffn_up",
+		"attention.q_proj", "attn_q",
+		"attention.k_proj", "attn_k",
+		"attention.v_proj", "attn_v",
+		"attention.o_proj", "attn_output",
+		"attention_norm", "attn_norm",
+		"feed_forward.gate_proj", "ffn_gate",
+		"feed_forward.down_proj", "ffn_down",
+		"feed_forward.up_proj", "ffn_up",
+		"multi_modal_projector", "mm",
+		"ffn_norm", "ffn_norm",
+		"lm_head", "output",
+	}
+}
+
+func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
+	var dims []int
+	for _, dim := range shape {
+		dims = append(dims, int(dim))
+	}
+
+	var heads uint32
+	if strings.HasSuffix(name, ".attn_q.weight") {
+		heads = p.NumAttentionHeads
+	} else if strings.HasSuffix(name, ".attn_k.weight") {
+		heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
+	} else {
+		return nil, fmt.Errorf("unknown tensor for repack: %s", name)
+	}
+
+	n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
+	if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
+		return nil, err
+	}
+
+	if err := n.T(0, 2, 1, 3); err != nil {
+		return nil, err
+	}
+
+	if err := n.Reshape(dims...); err != nil {
+		return nil, err
+	}
+
+	if err := n.Transpose(); err != nil {
+		return nil, err
+	}
+
+	ts, err := native.SelectF32(n, 1)
+	if err != nil {
+		return nil, err
+	}
+
+	var f32s []float32
+	for _, t := range ts {
+		f32s = append(f32s, t...)
+	}
+
+	return f32s, nil
+}
diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go
index ebb7b3aa1..01eca1c50 100644
--- a/model/models/mistral3/model_text.go
+++ b/model/models/mistral3/model_text.go
@@ -8,6 +8,7 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model/input"
 )
 
@@ -17,10 +18,41 @@ type TextOptions struct {
 	eps, ropeBase, ropeScale float32
 	ropeOrigPosEmbeddings    int
 	ropeScalingBeta          float32
+	ropeType                 string
+	ropeExtrapolation        float32
+	ropeBetaFast             float32
+	ropeBetaSlow             float32
+	ropeMscale               float32
+	ropeMscaleAllDim         float32
 }
 
 func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
-	return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale)
+	var ropeOpts []func(*rope.Options)
+	if o.ropeType == "yarn" {
+		getMscale := func(scale, mscale float64) float64 {
+			if scale <= 1.0 {
+				return 1.0
+			}
+			return 0.1*mscale*math.Log(scale) + 1.0
+		}
+
+		var attnFactor float32
+		if o.ropeMscale != 0 && o.ropeMscaleAllDim != 0 {
+			attnFactor = float32(getMscale(float64(o.ropeScale), float64(o.ropeMscale)) / getMscale(float64(o.ropeScale), float64(o.ropeMscaleAllDim)))
+		} else {
+			attnFactor = float32(getMscale(float64(o.ropeScale), 1))
+		}
+
+		ropeOpts = append(ropeOpts,
+			rope.WithOriginalContextLength(o.ropeOrigPosEmbeddings),
+			rope.WithExtrapolationFactor(o.ropeExtrapolation),
+			rope.WithAttentionFactor(attnFactor),
+			rope.WithBetaFast(o.ropeBetaFast),
+			rope.WithBetaSlow(o.ropeBetaSlow),
+		)
+	}
+
+	return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, ropeOpts...)
 }
 
 type TextModel struct {
@@ -150,9 +182,15 @@ func newTextModel(c fs.Config) *TextModel {
 			ropeDim:               int(c.Uint("rope.dimension_count")),
 			eps:                   c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:              c.Float("rope.freq_base"),
-			ropeScale:             c.Float("rope.scaling.factor", 1),
+			ropeScale:             c.Float("rope.scaling.factor", 1.0),
 			ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
-			ropeScalingBeta:       c.Float("rope.scaling_beta"),
+			ropeScalingBeta:       c.Float("rope.scaling_beta", 0.1),
+			ropeBetaFast:          c.Float("rope.scaling.beta_fast", 32.0),
+			ropeBetaSlow:          c.Float("rope.scaling.beta_slow", 1.0),
+			ropeType:              c.String("rope.scaling.type"),
+			ropeMscale:            c.Float("rope.scaling.mscale"),
+			ropeMscaleAllDim:      c.Float("rope.scaling.mscale_all_dim"),
+			ropeExtrapolation:     c.Float("rope.scaling.extrapolation_factor", 1),
 		},
 	}
 }
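
For reference, below is a minimal standalone sketch of the YaRN attention-factor math that applyRotaryPositionEmbeddings computes in the model_text.go hunk above. The helper name yarnAttnFactor and the example inputs are illustrative only and are not part of this change; the real code reads its inputs from the rope.scaling.* keys written by the converter.

package main

import (
	"fmt"
	"math"
)

// getMscale mirrors the closure in applyRotaryPositionEmbeddings: once the
// context is extended (scale > 1), YaRN boosts attention magnitude by
// 0.1*mscale*ln(scale) + 1; otherwise no correction is applied.
func getMscale(scale, mscale float64) float64 {
	if scale <= 1.0 {
		return 1.0
	}
	return 0.1*mscale*math.Log(scale) + 1.0
}

// yarnAttnFactor reproduces the branch above: when both mscale and
// mscale_all_dim are set, the two corrections are divided; otherwise the
// plain correction with mscale=1 is used. The function name is illustrative.
func yarnAttnFactor(ropeScale, mscale, mscaleAllDim float32) float32 {
	if mscale != 0 && mscaleAllDim != 0 {
		return float32(getMscale(float64(ropeScale), float64(mscale)) / getMscale(float64(ropeScale), float64(mscaleAllDim)))
	}
	return float32(getMscale(float64(ropeScale), 1))
}

func main() {
	// Example values only: a 4x context extension with no mscale overrides
	// gives 0.1*1*ln(4)+1 ≈ 1.139.
	fmt.Println(yarnAttnFactor(4.0, 0, 0))
	// When mscale == mscale_all_dim, the two corrections cancel and the factor is 1.
	fmt.Println(yarnAttnFactor(4.0, 0.707, 0.707))
}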