diff --git a/.golangci.yaml b/.golangci.yaml
index b211e5de7..bb33a056c 100644
--- a/.golangci.yaml
+++ b/.golangci.yaml
@@ -11,7 +11,6 @@ linters:
     - errorlint
     - exptostd
     - gocheckcompilerdirectives
-    - gocritic
     - govet
     - ineffassign
     - intrange
diff --git a/cmd/cmd.go b/cmd/cmd.go
index a67299402..d77bb2c58 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -1430,7 +1430,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		latest.Summary()
 	}
 
-	return &api.Message{Role: role, Content: fullResponse.String()}, nil
+	return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
 }
 
 func generate(cmd *cobra.Command, opts runOptions) error {
diff --git a/convert/convert_mistral.go b/convert/convert_mistral.go
index a6fd4c41a..81774853b 100644
--- a/convert/convert_mistral.go
+++ b/convert/convert_mistral.go
@@ -29,6 +29,15 @@ type mistral3Model struct {
 		SlidingWindow *uint32 `json:"sliding_window"`
 		HiddenAct     string  `json:"hidden_act"`
 		VocabSize     uint32  `json:"vocab_size"`
+		RopeParameters struct {
+			BetaFast                  float32 `json:"beta_fast"`
+			BetaSlow                  float32 `json:"beta_slow"`
+			Factor                    float32 `json:"factor"`
+			ScalingBeta               float32 `json:"llama_4_scaling_beta"`
+			OrigMaxPositionEmbeddings uint32  `json:"original_max_position_embeddings"`
+			RopeType                  string  `json:"rope_type"`
+			RopeTheta                 float32 `json:"rope_theta"`
+		} `json:"rope_parameters"`
 	} `json:"text_config"`
 	VisionModel struct {
 		NumAttentionHeads uint32 `json:"num_attention_heads"`
@@ -61,8 +70,13 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
 	kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
 	kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
 	kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
-	kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
-	kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
+	kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
+	kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
+
+	if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
+		kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
+		kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
+	}
 
 	// Vision configuration
 	kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go
index e071d71a8..8230dde39 100644
--- a/model/models/mistral3/model.go
+++ b/model/models/mistral3/model.go
@@ -159,8 +159,9 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
+	positionsScale := m.getScale(ctx, batch.Positions)
 
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, positionsScale, batch.Outputs, batch, m.Cache), nil
 }
 
 func init() {
diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go
index d2e2eac6c..624d31510 100644
--- a/model/models/mistral3/model_text.go
+++ b/model/models/mistral3/model_text.go
@@ -16,6 +16,8 @@ type TextOptions struct {
 	hiddenSize, numHeads, numKVHeads int
 	headDim, ropeDim                 int
 	eps, ropeBase, ropeScale         float32
+	ropeOrigPosEmbeddings            int
+	ropeScalingBeta                  float32
 }
 
 type TextModel struct {
@@ -34,7 +36,7 @@ type SelfAttention struct {
 	Output *nn.Linear `gguf:"attn_output"`
 }
 
-func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
 
@@ -49,6 +51,10 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
+	if opts.ropeOrigPosEmbeddings > 0 {
+		q = q.Mul(ctx, positionsScale)
+	}
+
 	kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
 	kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
 	return sa.Output.Forward(ctx, kqv)
@@ -76,11 +82,11 @@ type Layer struct {
 	MLP           *MLP
 }
 
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
-	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
+	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, positionsScale, cache, opts)
 
 	// In the final layer (outputs != nil), optimize by pruning to just the token positions
 	// we need logits for.
@@ -97,7 +103,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
 
 	// image embeddings
@@ -114,25 +120,36 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 			lastLayerOutputs = outputs
 		}
 
-		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
+		hiddenState = layer.Forward(ctx, hiddenState, positions, positionsScale, lastLayerOutputs, cache, m.TextOptions)
 	}
 
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 	return m.Output.Forward(ctx, hiddenState)
 }
 
+func (m *TextModel) getScale(ctx ml.Context, positions []int32) ml.Tensor {
+	posScale := make([]float32, len(positions))
+	for n, pos := range positions {
+		interval := math.Floor(float64(pos) / float64(m.ropeOrigPosEmbeddings))
+		posScale[n] = float32(1.0 + float64(m.ropeScalingBeta)*math.Log(1.0+interval))
+	}
+	return ctx.Input().FromFloats(posScale, 1, 1, len(posScale))
+}
+
 func newTextModel(c fs.Config) *TextModel {
 	return &TextModel{
 		Layers: make([]Layer, c.Uint("block_count")),
 		TextOptions: &TextOptions{
-			hiddenSize: int(c.Uint("embedding_length")),
-			numHeads:   int(c.Uint("attention.head_count")),
-			numKVHeads: int(c.Uint("attention.head_count_kv")),
-			headDim:    int(c.Uint("attention.key_length")),
-			ropeDim:    int(c.Uint("rope.dimension_count")),
-			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:   c.Float("rope.freq_base"),
c.Float("rope.scaling.factor", 1), + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + headDim: int(c.Uint("attention.key_length")), + ropeDim: int(c.Uint("rope.dimension_count")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.scaling.factor", 1), + ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")), + ropeScalingBeta: c.Float("rope.scaling_beta"), }, } } diff --git a/model/parsers/ministral.go b/model/parsers/ministral.go new file mode 100644 index 000000000..fbb54ad2d --- /dev/null +++ b/model/parsers/ministral.go @@ -0,0 +1,136 @@ +package parsers + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/ollama/ollama/api" +) + +type ministralParserState int + +const ( + ministralCollectingContent = iota + ministralCollectingThinkingContent + ministralCollectingToolName + ministralCollectingToolArgs +) + +type MinistralParser struct { + state ministralParserState + buffer strings.Builder + tools []api.Tool + hasThinkingSupport bool + currentTool *api.Tool +} + +func (p *MinistralParser) HasToolSupport() bool { + return true +} + +func (p *MinistralParser) HasThinkingSupport() bool { + return p.hasThinkingSupport +} + +func (p *MinistralParser) setInitialState(lastMessage *api.Message) { + prefill := lastMessage != nil && lastMessage.Role == "assistant" + if !p.HasThinkingSupport() { + p.state = ministralCollectingContent + return + } + + if prefill && lastMessage.Content != "" { + p.state = ministralCollectingContent + return + } + + p.state = ministralCollectingThinkingContent +} + +func (p *MinistralParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { + p.tools = tools + p.setInitialState(lastMessage) + return tools +} + +func toolByName(tools []api.Tool, n string) (*api.Tool, error) { + for i := range tools { + if tools[i].Function.Name == n { + return &tools[i], nil + } + } + return nil, fmt.Errorf("tool '%s' not found", n) +} + +func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.buffer.WriteString(s) + + switch p.state { + case ministralCollectingContent: + if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") { + before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false) + if before != "" { + return before, "", calls, nil + } + p.state = ministralCollectingToolName + } else if strings.Contains(p.buffer.String(), "[THINK]") { + p.state = ministralCollectingThinkingContent + return "", "", calls, nil + } else { + p.buffer.Reset() + return s, "", calls, nil + } + case ministralCollectingThinkingContent: + if strings.Contains(p.buffer.String(), "[/THINK]") { + thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true) + p.state = ministralCollectingContent + if after != "" { + p.buffer.Reset() + return after, thinkingContent, calls, nil + } + return "", thinkingContent, calls, nil + } else { + p.buffer.Reset() + return "", s, calls, nil + } + case ministralCollectingToolName: + if strings.Contains(p.buffer.String(), "[ARGS]") { + name, _ := splitAtTag(&p.buffer, "[ARGS]", false) + + t, err := toolByName(p.tools, name) + if err != nil { + return "", "", calls, err + } + p.currentTool = t + p.state = ministralCollectingToolArgs + return "", "", calls, nil + } + return "", "", calls, nil + case ministralCollectingToolArgs: + if strings.Contains(p.buffer.String(), 
"}") { + before, _ := splitAtTag(&p.buffer, "}", false) + before += "}" + + var data map[string]any + if err := json.Unmarshal([]byte(before), &data); err != nil { + // todo - throw a better error + return "", "", calls, err + } + + p.state = ministralCollectingContent + + call := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: p.currentTool.Function.Name, + Arguments: api.ToolCallFunctionArguments(data), + }, + } + calls = append(calls, call) + return "", "", calls, nil + } + return "", "", calls, nil + } + + return p.buffer.String(), thinking, calls, nil +} diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go index 9bfd5c1ec..24ab07fb2 100644 --- a/model/parsers/parsers.go +++ b/model/parsers/parsers.go @@ -1,6 +1,9 @@ package parsers import ( + "strings" + "unicode" + "github.com/ollama/ollama/api" "github.com/ollama/ollama/harmony" ) @@ -38,16 +41,17 @@ func ParserForName(name string) Parser { if parser, ok := registry.constructors[name]; ok { return parser() } + var p Parser + switch name { case "qwen3-coder": - parser := &Qwen3CoderParser{} - return parser + p = &Qwen3CoderParser{} case "qwen3-vl-instruct": - parser := &Qwen3VLParser{hasThinkingSupport: false} - return parser + p = &Qwen3VLParser{hasThinkingSupport: false} case "qwen3-vl-thinking": - parser := &Qwen3VLParser{hasThinkingSupport: true} - return parser + p = &Qwen3VLParser{hasThinkingSupport: true} + case "ministral": + p = &MinistralParser{hasThinkingSupport: false} case "passthrough": return &PassthroughParser{} case "harmony": @@ -57,6 +61,7 @@ func ParserForName(name string) Parser { default: return nil } + return p } type PassthroughParser struct{} @@ -76,3 +81,20 @@ func (p *PassthroughParser) HasToolSupport() bool { func (p *PassthroughParser) HasThinkingSupport() bool { return false } + +func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string) { + split := strings.SplitN(sb.String(), tag, 2) + if len(split) == 1 { + sb.Reset() + return split[0], "" + } + before := split[0] + before = strings.TrimRightFunc(before, unicode.IsSpace) + after := split[1] + if trimAfter { + after = strings.TrimLeftFunc(after, unicode.IsSpace) + } + sb.Reset() + sb.WriteString(after) + return before, after // return events +} diff --git a/model/parsers/parsers_test.go b/model/parsers/parsers_test.go index e918a62f6..4f8566de3 100644 --- a/model/parsers/parsers_test.go +++ b/model/parsers/parsers_test.go @@ -1,6 +1,7 @@ package parsers import ( + "strings" "testing" "github.com/ollama/ollama/api" @@ -95,3 +96,164 @@ func TestUnknownParserReturnsNil(t *testing.T) { t.Error("expected nil for unknown parser") } } + +func TestSplitAtTag(t *testing.T) { + tests := []struct { + name string + input string + tag string + trimAfter bool + wantBefore string + wantAfter string + wantSB string // expected content of strings.Builder after operation + }{ + { + name: "basic split with trimAfter true", + input: "hello world", + tag: "", + trimAfter: true, + wantBefore: "hello", + wantAfter: "world", + wantSB: "world", + }, + { + name: "basic split with trimAfter false", + input: "hello world", + tag: "", + trimAfter: false, + wantBefore: "hello", + wantAfter: " world", + wantSB: " world", + }, + { + name: "tag at beginning with trimAfter true", + input: "world", + tag: "", + trimAfter: true, + wantBefore: "", + wantAfter: "world", + wantSB: "world", + }, + { + name: "tag at beginning with trimAfter false", + input: " world", + tag: "", + trimAfter: false, + wantBefore: "", + wantAfter: " world", + wantSB: " 
world", + }, + { + name: "tag at end with trimAfter true", + input: "hello ", + tag: "", + trimAfter: true, + wantBefore: "hello", + wantAfter: "", + wantSB: "", + }, + { + name: "tag at end with trimAfter false", + input: "hello ", + tag: "", + trimAfter: false, + wantBefore: "hello", + wantAfter: "", + wantSB: "", + }, + { + name: "multiple tags splits at first occurrence", + input: "hello world end", + tag: "", + trimAfter: true, + wantBefore: "hello", + wantAfter: "world end", + wantSB: "world end", + }, + { + name: "tag not present", + input: "hello world", + tag: "", + trimAfter: true, + wantBefore: "hello world", + wantAfter: "", + wantSB: "", + }, + { + name: "empty input", + input: "", + tag: "", + trimAfter: true, + wantBefore: "", + wantAfter: "", + wantSB: "", + }, + { + name: "only whitespace before tag", + input: " \t\nworld", + tag: "", + trimAfter: true, + wantBefore: "", + wantAfter: "world", + wantSB: "world", + }, + { + name: "only whitespace after tag with trimAfter true", + input: "hello \t\n", + tag: "", + trimAfter: true, + wantBefore: "hello", + wantAfter: "", + wantSB: "", + }, + { + name: "only whitespace after tag with trimAfter false", + input: "hello \t\n", + tag: "", + trimAfter: false, + wantBefore: "hello", + wantAfter: " \t\n", + wantSB: " \t\n", + }, + { + name: "complex whitespace trimming", + input: " hello \t\n \n\t world ", + tag: "", + trimAfter: true, + wantBefore: " hello", + wantAfter: "world ", + wantSB: "world ", + }, + { + name: "tag with special characters", + input: "text more text", + tag: "", + trimAfter: true, + wantBefore: "text", + wantAfter: "more text", + wantSB: "more text", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sb := &strings.Builder{} + sb.WriteString(tt.input) + + before, after := splitAtTag(sb, tt.tag, tt.trimAfter) + + // Check return values + if before != tt.wantBefore { + t.Errorf("splitAtTag() before = %q, want %q", before, tt.wantBefore) + } + if after != tt.wantAfter { + t.Errorf("splitAtTag() after = %q, want %q", after, tt.wantAfter) + } + + // Check strings.Builder state + if sb.String() != tt.wantSB { + t.Errorf("strings.Builder after split = %q, want %q", sb.String(), tt.wantSB) + } + }) + } +} diff --git a/model/parsers/qwen3vl.go b/model/parsers/qwen3vl.go index 979f1668c..cb7627638 100644 --- a/model/parsers/qwen3vl.go +++ b/model/parsers/qwen3vl.go @@ -70,7 +70,6 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin p.buffer.WriteString(s) events := p.parseEvents() - var toolCalls []api.ToolCall var contentSb strings.Builder var thinkingSb strings.Builder for _, event := range events { @@ -81,7 +80,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin slog.Warn("qwen tool call parsing failed", "error", err) return "", "", nil, err } - toolCalls = append(toolCalls, toolCall) + calls = append(calls, toolCall) case qwenEventThinkingContent: thinkingSb.WriteString(event.content) case qwenEventContent: @@ -91,7 +90,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin } } - return contentSb.String(), thinkingSb.String(), toolCalls, nil + return contentSb.String(), thinkingSb.String(), calls, nil } func (p *Qwen3VLParser) parseEvents() []qwenEvent { @@ -113,19 +112,6 @@ func (p *Qwen3VLParser) parseEvents() []qwenEvent { return all } -func splitAtTag(p *Qwen3VLParser, tag string, trimAfter bool) (string, string) { - split := strings.SplitN(p.buffer.String(), tag, 2) - before := split[0] - 
-	before = strings.TrimRightFunc(before, unicode.IsSpace)
-	after := split[1]
-	if trimAfter {
-		after = strings.TrimLeftFunc(after, unicode.IsSpace)
-	}
-	p.buffer.Reset()
-	p.buffer.WriteString(after)
-	return before, after // return events
-}
-
 func (p *Qwen3VLParser) eatLeadingWhitespaceAndTransitionTo(nextState qwenParserState) ([]qwenEvent, bool) {
 	trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
 	p.buffer.Reset()
@@ -144,7 +130,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
 	case CollectingContent:
 		if strings.Contains(p.buffer.String(), toolOpenTag) {
 			// events = emitContentBeforeTag(p, events, toolOpenTag)
-			before, _ := splitAtTag(p, toolOpenTag, false)
+			before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
 			if len(before) > 0 {
 				events = append(events, qwenEventContent{content: before})
 			}
@@ -195,7 +181,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
 		}
 	case CollectingThinkingContent:
 		if strings.Contains(p.buffer.String(), thinkingCloseTag) {
-			thinking, remaining := splitAtTag(p, thinkingCloseTag, true)
+			thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
 			if len(thinking) > 0 {
 				events = append(events, qwenEventThinkingContent{content: thinking})
 			}
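
For reference, the long-context handling wired in above boils down to scaling each query by 1 + rope.scaling_beta * ln(1 + floor(position / rope.scaling.original_context_length)), which is exactly what TextModel.getScale precomputes per position. A minimal standalone sketch of that factor follows; the metadata values 16384 and 0.1 are made up for illustration (the real ones come from the converted GGUF keys):

package main

import (
	"fmt"
	"math"
)

func main() {
	// Hypothetical metadata values, for illustration only; in the model they
	// are read from rope.scaling.original_context_length and rope.scaling_beta.
	origCtx := 16384.0
	scalingBeta := 0.1

	// Same formula as TextModel.getScale: positions past the original context
	// window get progressively larger query scale factors.
	for _, pos := range []int32{0, 8191, 16384, 65536, 262144} {
		interval := math.Floor(float64(pos) / origCtx)
		scale := 1.0 + scalingBeta*math.Log(1.0+interval)
		fmt.Printf("position %7d -> query scale %.4f\n", pos, scale)
	}
}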
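
Likewise, a sketch of how the new ministral parser is meant to be fed during streaming, using the [THINK]...[/THINK] and [TOOL_CALLS]name[ARGS]{...} markers it scans for. The tool definition and chunk boundaries are invented, and it assumes api.Tool wraps an api.ToolFunction with a Name field; treat it as an illustration rather than part of the patch:

package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/model/parsers"
)

func main() {
	// Hypothetical tool so the parser can resolve the streamed tool name.
	tools := []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}}

	p := parsers.ParserForName("ministral")
	p.Init(tools, nil, nil)

	// Simulated streamed model output: some reasoning, then a tool call in the
	// [TOOL_CALLS]<name>[ARGS]<json> wire format handled by MinistralParser.Add.
	chunks := []string{
		"[THINK]", "Need the forecast first.", "[/THINK]",
		"[TOOL_CALLS]", "get_weather", "[ARGS]", `{"city": "Paris"}`,
	}

	for _, chunk := range chunks {
		content, thinking, calls, err := p.Add(chunk, false)
		if err != nil {
			panic(err)
		}
		if thinking != "" {
			fmt.Println("thinking:", thinking)
		}
		if content != "" {
			fmt.Println("content:", content)
		}
		for _, call := range calls {
			fmt.Printf("tool call: %s(%v)\n", call.Function.Name, call.Function.Arguments)
		}
	}
}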