diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go
index da9fe3e93..4f405dc35 100644
--- a/harmony/harmonyparser.go
+++ b/harmony/harmonyparser.go
@@ -388,9 +388,9 @@ func NewFunctionNameMap() *FunctionNameMap {
 	}
 }
 
-// Init initializes the handler with tools and optional last message
+// Init initializes the handler with tools, optional last message, and think value
 // Implements the Parser interface
-func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
+func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
 	// Initialize the harmony parser
 	if h.HarmonyParser == nil {
 		h.HarmonyParser = &HarmonyParser{
diff --git a/model/parsers/cogito.go b/model/parsers/cogito.go
new file mode 100644
index 000000000..2415dd31b
--- /dev/null
+++ b/model/parsers/cogito.go
@@ -0,0 +1,308 @@
+package parsers
+
+import (
+	"encoding/json"
+	"errors"
+	"log/slog"
+	"strings"
+	"unicode"
+
+	"github.com/ollama/ollama/api"
+)
+
+// CogitoParserState identifies which section of the model output the parser
+// is currently consuming.
+type CogitoParserState int
+
+const (
+	// CogitoCollectingThinking consumes text up to the thinking close tag.
+	CogitoCollectingThinking CogitoParserState = iota
+	// CogitoCollectingContent consumes response text until a tool section begins.
+	CogitoCollectingContent
+	// CogitoCollectingToolCalls consumes the calls inside a tool-calls section.
+	CogitoCollectingToolCalls
+	// CogitoCollectingToolOutput consumes, and discards, a tool-outputs section.
+	CogitoCollectingToolOutput
+)
+
+const (
+	// cogitoThinkingCloseTag ends the thinking section. It must not be empty:
+	// every string contains the empty string, so an empty tag would end the
+	// thinking section immediately.
+	cogitoThinkingCloseTag    = "</think>"
+	cogitoToolCallsBeginTag   = "<|tool▁calls▁begin|>"
+	cogitoToolCallsEndTag     = "<|tool▁calls▁end|>"
+	cogitoToolCallBeginTag    = "<|tool▁call▁begin|>"
+	cogitoToolCallEndTag      = "<|tool▁call▁end|>"
+	cogitoToolSepTag          = "<|tool▁sep|>"
+	cogitoToolOutputBeginTag  = "<|tool▁output▁begin|>"
+	cogitoToolOutputEndTag    = "<|tool▁output▁end|>"
+	cogitoToolOutputsBeginTag = "<|tool▁outputs▁begin|>"
+	cogitoToolOutputsEndTag   = "<|tool▁outputs▁end|>"
+)
+
+// CogitoParser incrementally splits streamed Cogito model output into
+// thinking text, response content, and tool calls.
+type CogitoParser struct {
+	state  CogitoParserState
+	buffer strings.Builder // text received but not yet emitted or discarded
+}
+
+// HasToolSupport reports that this parser extracts tool calls.
+func (p *CogitoParser) HasToolSupport() bool {
+	return true
+}
+
+// HasThinkingSupport reports that this parser can separate thinking from content.
+func (p *CogitoParser) HasThinkingSupport() bool {
+	return true
+}
+
+// setInitialState picks the starting state. Thinking is collected only when
+// the request enables it, no tools are supplied, and the last message is not
+// an assistant message that already has content (a prefilled response).
+func (p *CogitoParser) setInitialState(lastMessage *api.Message, tools []api.Tool, thinkValue *api.ThinkValue) {
+	thinkingRequested := thinkValue != nil && thinkValue.Bool()
+	prefilledContent := lastMessage != nil && lastMessage.Role == "assistant" && lastMessage.Content != ""
+
+	// Note: for cogito, if there are tools, then we don't want to be thinking
+	if !thinkingRequested || prefilledContent || len(tools) > 0 {
+		p.state = CogitoCollectingContent
+		return
+	}
+
+	p.state = CogitoCollectingThinking
+}
+
+// Init sets the initial parser state for a request and returns tools unchanged.
+// Implements the Parser interface.
+func (p *CogitoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
+	p.setInitialState(lastMessage, tools, thinkValue)
+	return tools
+}
+
+// cogitoEvent is one parsed unit of output produced by eat.
+type cogitoEvent interface {
+	isCogitoEvent()
+}
+
+type cogitoEventThinkingContent struct {
+	content string
+}
+
+type cogitoEventContent struct {
+	content string
+}
+
+type cogitoEventToolCall struct {
+	toolCall api.ToolCall
+}
+
+func (cogitoEventThinkingContent) isCogitoEvent() {}
+func (cogitoEventContent) isCogitoEvent()         {}
+func (cogitoEventToolCall) isCogitoEvent()        {}
+
+// Add appends s to the buffer and returns the content, thinking text, and tool
+// calls that can be emitted unambiguously so far. The done flag is not
+// consulted, so text still held back when the stream ends is not returned.
+// The returned error is always nil.
+func (p *CogitoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
+	p.buffer.WriteString(s)
+
+	var toolCalls []api.ToolCall
+	var contentSb, thinkingSb strings.Builder
+	for _, event := range p.parseEvents() {
+		switch event := event.(type) {
+		case cogitoEventToolCall:
+			toolCalls = append(toolCalls, event.toolCall)
+		case cogitoEventThinkingContent:
+			thinkingSb.WriteString(event.content)
+		case cogitoEventContent:
+			contentSb.WriteString(event.content)
+		}
+	}
+
+	return contentSb.String(), thinkingSb.String(), toolCalls, nil
+}
+
+// parseEvents calls eat until it reports no further progress and returns all
+// events produced along the way.
+func (p *CogitoParser) parseEvents() []cogitoEvent {
+	var all []cogitoEvent
+	for {
+		events, more := p.eat()
+		all = append(all, events...)
+		if !more {
+			return all
+		}
+	}
+}
+
+// setBuffer replaces the buffered text with s.
+func (p *CogitoParser) setBuffer(s string) {
+	p.buffer.Reset()
+	p.buffer.WriteString(s)
+}
+
+// eat consumes as much of the buffer as the current state allows. The bool
+// result reports whether calling eat again may produce more events.
+func (p *CogitoParser) eat() ([]cogitoEvent, bool) {
+	bufStr := p.buffer.String()
+	if bufStr == "" {
+		return nil, false
+	}
+
+	switch p.state {
+	case CogitoCollectingThinking:
+		return p.eatThinking(bufStr)
+	case CogitoCollectingContent:
+		return p.eatContent(bufStr)
+	case CogitoCollectingToolCalls:
+		return p.eatToolCalls(bufStr)
+	case CogitoCollectingToolOutput:
+		return p.eatToolOutput(bufStr)
+	}
+
+	return nil, false
+}
+
+// eatThinking emits thinking text and switches to content at the close tag.
+func (p *CogitoParser) eatThinking(bufStr string) ([]cogitoEvent, bool) {
+	var events []cogitoEvent
+
+	if before, after, found := strings.Cut(bufStr, cogitoThinkingCloseTag); found {
+		p.setBuffer(strings.TrimLeftFunc(after, unicode.IsSpace))
+		p.state = CogitoCollectingContent
+
+		if thinking := strings.TrimRightFunc(before, unicode.IsSpace); thinking != "" {
+			events = append(events, cogitoEventThinkingContent{content: thinking})
+		}
+		return events, true
+	}
+
+	// No complete close tag yet. Hold back a possible partial tag at the end of
+	// the buffer together with the whitespace before it; that whitespace is
+	// dropped if the close tag follows.
+	overlapLen := max(overlap(bufStr, cogitoThinkingCloseTag), 0)
+	beforePartialTag := bufStr[:len(bufStr)-overlapLen]
+	ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen(beforePartialTag)
+
+	p.setBuffer(bufStr[ambiguousStart:])
+	if ambiguousStart > 0 {
+		events = append(events, cogitoEventThinkingContent{content: bufStr[:ambiguousStart]})
+	}
+	return events, false
+}
+
+// eatContent emits content and switches state when a tool section begins.
+// A begin tag split across chunks is not detected; its text is emitted as content.
+func (p *CogitoParser) eatContent(bufStr string) ([]cogitoEvent, bool) {
+	// The tool-calls tag takes precedence when both begin tags are buffered.
+	if before, after, found := strings.Cut(bufStr, cogitoToolCallsBeginTag); found {
+		return p.leaveContent(before, after, CogitoCollectingToolCalls), true
+	}
+	if before, after, found := strings.Cut(bufStr, cogitoToolOutputsBeginTag); found {
+		return p.leaveContent(before, after, CogitoCollectingToolOutput), true
+	}
+
+	// Otherwise everything buffered is content; eat has already ruled out an empty buffer.
+	p.buffer.Reset()
+	return []cogitoEvent{cogitoEventContent{content: bufStr}}, false
+}
+
+// leaveContent buffers the text after a begin tag, enters next, and returns
+// the content that preceded the tag with trailing whitespace removed.
+func (p *CogitoParser) leaveContent(before, after string, next CogitoParserState) []cogitoEvent {
+	p.setBuffer(after)
+	p.state = next
+
+	content := strings.TrimRightFunc(before, unicode.IsSpace)
+	if content == "" {
+		return nil
+	}
+	return []cogitoEvent{cogitoEventContent{content: content}}
+}
+
+// eatToolCalls emits the next complete tool call, or leaves the section at its end tag.
+func (p *CogitoParser) eatToolCalls(bufStr string) ([]cogitoEvent, bool) {
+	if _, afterBegin, found := strings.Cut(bufStr, cogitoToolCallBeginTag); found {
+		if body, rest, complete := strings.Cut(afterBegin, cogitoToolCallEndTag); complete {
+			// The call is consumed even when it fails to parse; otherwise a
+			// malformed call would be retried on every chunk and would hide
+			// the well-formed calls that follow it.
+			p.setBuffer(strings.TrimLeftFunc(rest, unicode.IsSpace))
+
+			toolCall, err := p.parseToolCallContent(body)
+			if err != nil {
+				slog.Warn("cogito tool call parsing failed", "error", err)
+				return nil, true
+			}
+			return []cogitoEvent{cogitoEventToolCall{toolCall: toolCall}}, true
+		}
+	}
+
+	if _, rest, found := strings.Cut(bufStr, cogitoToolCallsEndTag); found {
+		p.setBuffer(strings.TrimLeftFunc(rest, unicode.IsSpace))
+		p.state = CogitoCollectingContent
+		return nil, true
+	}
+
+	return nil, false
+}
+
+// eatToolOutput discards tool outputs and leaves the section at its end tag.
+func (p *CogitoParser) eatToolOutput(bufStr string) ([]cogitoEvent, bool) {
+	if _, afterBegin, found := strings.Cut(bufStr, cogitoToolOutputBeginTag); found {
+		if _, rest, complete := strings.Cut(afterBegin, cogitoToolOutputEndTag); complete {
+			p.setBuffer(strings.TrimLeftFunc(rest, unicode.IsSpace))
+			return nil, true
+		}
+	}
+
+	if _, rest, found := strings.Cut(bufStr, cogitoToolOutputsEndTag); found {
+		p.setBuffer(strings.TrimLeftFunc(rest, unicode.IsSpace))
+		p.state = CogitoCollectingContent
+		return nil, true
+	}
+
+	return nil, false
+}
+
+// parseToolCallContent parses the text between a tool call's begin and end tags.
+// Expected format: function<|tool▁sep|>tool_name\n```json\n{args}\n```
+// The text before the separator is not validated. An error is returned when
+// the separator or either code fence is missing, or when the arguments cannot
+// be decoded from JSON.
+func (p *CogitoParser) parseToolCallContent(content string) (api.ToolCall, error) {
+	const (
+		fenceOpen  = "\n```json\n"
+		fenceClose = "\n```"
+	)
+
+	_, nameAndArgs, found := strings.Cut(content, cogitoToolSepTag)
+	if !found {
+		return api.ToolCall{}, errors.New("invalid format")
+	}
+
+	toolName, jsonContent, found := strings.Cut(nameAndArgs, fenceOpen)
+	if !found {
+		return api.ToolCall{}, errors.New("invalid format")
+	}
+
+	argsJSON, _, found := strings.Cut(jsonContent, fenceClose)
+	if !found {
+		return api.ToolCall{}, errors.New("invalid format")
+	}
+
+	var args api.ToolCallFunctionArguments
+	if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
+		return api.ToolCall{}, err
+	}
+
+	return api.ToolCall{
+		Function: api.ToolCallFunction{
+			Name:      strings.TrimSpace(toolName),
+			Arguments: args,
+		},
+	}, nil
+}
diff --git a/model/parsers/cogito_test.go b/model/parsers/cogito_test.go
new file
mode 100644
index 000000000..7eaa1c2e2
--- /dev/null
+++ b/model/parsers/cogito_test.go
@@ -0,0 +1,565 @@
+package parsers
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+
+	"github.com/ollama/ollama/api"
+)
+
+// TestCogitoParser feeds each input to a fresh parser in a single Add call and
+// compares the returned content, thinking text, and tool calls.
+// NOTE(review): several inputs below expect a thinking/content split but contain
+// no thinking close tag; confirm the literals were not stripped of their tags.
+func TestCogitoParser(t *testing.T) {
+	tests := []struct {
+		name              string
+		input             string
+		expectedContent   string
+		expectedThinking  string
+		expectedToolCalls []api.ToolCall
+		tools             []api.Tool
+		lastMessage       *api.Message
+	}{
+		{
+			name:             "simple_content",
+			input:            "This is a simple response.",
+			expectedContent:  "This is a simple response.",
+			expectedThinking: "",
+		},
+		{
+			name:             "thinking_only",
+			input:            "This is thinking content.This is response content.",
+			expectedContent:  "This is response content.",
+			expectedThinking: "This is thinking content.",
+		},
+		{
+			name: "tool_call_simple",
+			input: `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
+` + "```json\n" + `{"location":"Paris"}
+` + "```" + `<|tool▁call▁end|><|tool▁calls▁end|>`,
+			expectedToolCalls: []api.ToolCall{
+				{
+					Function: api.ToolCallFunction{
+						Name: "get_weather",
+						Arguments: api.ToolCallFunctionArguments{
+							"location": "Paris",
+						},
+					},
+				},
+			},
+			tools: []api.Tool{
+				{
+					Type: "function",
+					Function: api.ToolFunction{
+						Name: "get_weather",
+						Parameters: api.ToolFunctionParameters{
+							Properties: map[string]api.ToolProperty{
+								"location": {Type: api.PropertyType{"string"}},
+							},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "thinking_with_tool_call",
+			input: `I need to check the weather.<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
+` + "```json\n" + `{"location":"Paris"}
+` + "```" + `<|tool▁call▁end|><|tool▁calls▁end|>`,
+			expectedContent:  "I need to check the weather.",
+			expectedThinking: "", // No thinking when tools are present (Cogito-specific behavior)
+			expectedToolCalls: []api.ToolCall{
+				{
+					Function: api.ToolCallFunction{
+						Name: "get_weather",
+						Arguments: api.ToolCallFunctionArguments{
+							"location": "Paris",
+						},
+					},
+				},
+			},
+			tools: []api.Tool{
+				{
+					Type: "function",
+					Function: api.ToolFunction{
+						Name: "get_weather",
+						Parameters: api.ToolFunctionParameters{
+							Properties: map[string]api.ToolProperty{
+								"location": {Type: api.PropertyType{"string"}},
+							},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "multiple_tool_calls",
+			input: `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
+` + "```json\n" + `{"location":"Paris"}
+` + "```" + `<|tool▁call▁end|>
+<|tool▁call▁begin|>function<|tool▁sep|>get_weather
+` + "```json\n" + `{"location":"London"}
+` + "```" + `<|tool▁call▁end|><|tool▁calls▁end|>`,
+			expectedToolCalls: []api.ToolCall{
+				{
+					Function: api.ToolCallFunction{
+						Name: "get_weather",
+						Arguments: api.ToolCallFunctionArguments{
+							"location": "Paris",
+						},
+					},
+				},
+				{
+					Function: api.ToolCallFunction{
+						Name: "get_weather",
+						Arguments: api.ToolCallFunctionArguments{
+							"location": "London",
+						},
+					},
+				},
+			},
+			tools: []api.Tool{
+				{
+					Type: "function",
+					Function: api.ToolFunction{
+						Name: "get_weather",
+						Parameters: api.ToolFunctionParameters{
+							Properties: map[string]api.ToolProperty{
+								"location": {Type: api.PropertyType{"string"}},
+							},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "complex_tool_arguments",
+			input: `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>process_data
+` + "```json\n" + `{"items":["item1","item2"],"config":{"enabled":true,"threshold":0.95},"count":42}
+` + "```" + `<|tool▁call▁end|><|tool▁calls▁end|>`,
+			expectedToolCalls: []api.ToolCall{
+				{
+					Function: api.ToolCallFunction{
+						Name: "process_data",
+						Arguments: api.ToolCallFunctionArguments{
+							"items":  []any{"item1", "item2"},
+							"config": map[string]any{"enabled": true, "threshold": 0.95},
+							"count":  42.0,
+						},
+					},
+				},
+			},
+		},
+		{
+			name:             "tool_output_parsing",
+			input:            `<|tool▁outputs▁begin|><|tool▁output▁begin|>{"temperature": 22, "condition": "sunny"}<|tool▁output▁end|><|tool▁outputs▁end|>`,
+			expectedContent:  "",
+			expectedThinking: "",
+		},
+		{
+			name: "thinking_with_multiline_content",
+			input: `This is line 1
+This is line 2
+This is line 3Final response here.`,
+			expectedContent:  "Final response here.",
+			expectedThinking: "This is line 1\nThis is line 2\nThis is line 3",
+		},
+		{
+			name:             "no_thinking_simple",
+			input:            "This is content.",
+			expectedContent:  "This is content.",
+			expectedThinking: "",
+		},
+		{
+			name:            "prefill_content_only",
+			input:           "Continuing from previous content.",
+			expectedContent: "Continuing from previous content.",
+			lastMessage: &api.Message{
+				Role:    "assistant",
+				Content: "Previous content",
+			},
+		},
+		{
+			name:             "prefill_with_thinking",
+			input:            "Continuing thinkingContinuing content.",
+			expectedContent:  "Continuing content.",
+			expectedThinking: "Continuing thinking",
+			lastMessage: &api.Message{
+				Role: "assistant",
+			},
+		},
+		// Edge cases
+		{
+			name:             "nested_think_tags_in_thinking",
+			input:            "I'm thinking nested more thinkingFinal content.",
+			expectedContent:  "more thinkingFinal content.",
+			expectedThinking: "I'm thinking nested",
+		},
+		{
+			name:             "multiple_think_close_tags",
+			input:            "First thinkingContentMore content.",
+			expectedContent:  "ContentMore content.",
+			expectedThinking: "First thinking",
+		},
+		{
+			name:             "empty_thinking_content",
+			input:            "Just content here.",
+			expectedContent:  "Just content here.",
+			expectedThinking: "",
+		},
+		{
+			name:             "thinking_disabled_with_think_tags",
+			input:            "Content with tags should be treated as content.",
+			expectedContent:  "Content with tags should be treated as content.",
+			expectedThinking: "",
+			lastMessage: &api.Message{
+				Role:    "assistant",
+				Content: "existing", // Forces non-thinking mode
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Use thinking-enabled parser for tests that expect thinking
+			hasThinking := tt.expectedThinking != ""
+			parser := &CogitoParser{} // it has thinking support
+			parser.Init(tt.tools, tt.lastMessage, &api.ThinkValue{Value: hasThinking}) // but we should set it with the request that the user wants
+
+			content, thinking, toolCalls, err := parser.Add(tt.input, true)
+			if err != nil {
+				t.Fatalf("Add() error = %v", err)
+			}
+
+			if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
+				t.Errorf("content mismatch (-want +got):\n%s", diff)
+			}
+
+			if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
+				t.Errorf("thinking mismatch (-want +got):\n%s", diff)
+			}
+
+			if diff := cmp.Diff(tt.expectedToolCalls, toolCalls); diff != "" {
+				t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+// TestCogitoParser_Streaming feeds the output in several chunks with thinking
+// enabled and checks the concatenated results of all Add calls.
+// NOTE(review): no chunk contains a thinking close tag although a
+// thinking/content split is expected; confirm the literals are intact.
+func TestCogitoParser_Streaming(t *testing.T) {
+	parser := &CogitoParser{}
+	parser.Init(nil, nil, &api.ThinkValue{Value: true})
+
+	chunks := []string{
+		"This is ",
+		"thinking content",
+		".This is ",
+		"content.<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test_tool\n```json\n{\"arg\":\"value\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>",
+	}
+
+	var finalContent, finalThinking strings.Builder
+	var finalToolCalls []api.ToolCall
+
+	for i, chunk := range chunks {
+		done := i == len(chunks)-1
+		content, thinking, toolCalls, err := parser.Add(chunk, done)
+		if err != nil {
+			t.Fatalf("Add() error on chunk %d: %v", i, err)
+		}
+
+		finalContent.WriteString(content)
+		finalThinking.WriteString(thinking)
+		finalToolCalls = append(finalToolCalls, toolCalls...)
+	}
+
+	expectedContent := "This is content."
+	expectedThinking := "This is thinking content."
+	expectedToolCalls := []api.ToolCall{
+		{
+			Function: api.ToolCallFunction{
+				Name: "test_tool",
+				Arguments: api.ToolCallFunctionArguments{
+					"arg": "value",
+				},
+			},
+		},
+	}
+
+	if finalContent.String() != expectedContent {
+		t.Errorf("expected content %q, got %q", expectedContent, finalContent.String())
+	}
+
+	if finalThinking.String() != expectedThinking {
+		t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String())
+	}
+
+	if diff := cmp.Diff(expectedToolCalls, finalToolCalls); diff != "" {
+		t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
+	}
+}
+
+// TestCogitoParser_StreamingEdgeCases covers chunked inputs, including a
+// tool-calls begin tag split across two chunks.
+// NOTE(review): "split_thinking_tag" and "thinking_disabled_with_split_tags"
+// each hold a single chunk with no tag in it; confirm the chunk lists are intact.
+func TestCogitoParser_StreamingEdgeCases(t *testing.T) {
+	tests := []struct {
+		name               string
+		chunks             []string
+		expectedContent    string
+		expectedThinking   string
+		expectedToolCalls  []api.ToolCall
+		hasThinkingSupport bool
+	}{
+		{
+			name: "split_thinking_tag",
+			chunks: []string{
+				"This is thinking contentThis is content.",
+			},
+			expectedContent:    "This is content.",
+			expectedThinking:   "This is thinking content",
+			hasThinkingSupport: true,
+		},
+		{
+			name: "split_tool_calls_begin_tag_conservative_parsing",
+			chunks: []string{
+				"Content before<|tool▁calls▁beg",
+				"in|><|tool▁call▁begin|>function<|tool▁sep|>test\n```json\n{}\n```<|tool▁call▁end|><|tool▁calls▁end|>",
+			},
+			// Parser is conservative - treats incomplete tags as content
+			expectedContent:    "Content before<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test\n```json\n{}\n```<|tool▁call▁end|><|tool▁calls▁end|>",
+			expectedToolCalls:  nil,
+			hasThinkingSupport: false,
+		},
+		{
+			name: "thinking_disabled_with_split_tags",
+			chunks: []string{
+				"Content with should be treated as content.",
+			},
+			expectedContent:    "Content with should be treated as content.",
+			expectedThinking:   "",
+			hasThinkingSupport: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			parser := &CogitoParser{}
+			parser.Init(nil, nil, &api.ThinkValue{Value: tt.hasThinkingSupport})
+
+			var finalContent, finalThinking strings.Builder
+			var finalToolCalls []api.ToolCall
+
+			for i, chunk := range tt.chunks {
+				done := i == len(tt.chunks)-1
+				content, thinking, toolCalls, err := parser.Add(chunk, done)
+				if err != nil {
+					t.Fatalf("Add() error on chunk %d: %v", i, err)
+				}
+
+				finalContent.WriteString(content)
+				finalThinking.WriteString(thinking)
+				finalToolCalls = append(finalToolCalls, toolCalls...)
+			}
+
+			if finalContent.String() != tt.expectedContent {
+				t.Errorf("expected content %q, got %q", tt.expectedContent, finalContent.String())
+			}
+
+			if finalThinking.String() != tt.expectedThinking {
+				t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
+			}
+
+			if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls); diff != "" {
+				t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+// TestCogitoParser_HasToolSupport checks that the parser reports tool support.
+func TestCogitoParser_HasToolSupport(t *testing.T) {
+	parser := &CogitoParser{}
+	if !parser.HasToolSupport() {
+		t.Error("CogitoParser should support tools")
+	}
+}
+
+// TestCogitoParser_Init checks that Init returns as many tools as it was given.
+func TestCogitoParser_Init(t *testing.T) {
+	parser := &CogitoParser{}
+
+	tools := []api.Tool{
+		{Function: api.ToolFunction{Name: "test_tool"}},
+	}
+
+	lastMessage := &api.Message{Role: "assistant", Content: "previous"}
+
+	returnedTools := parser.Init(tools, lastMessage, nil)
+
+	if len(returnedTools) != len(tools) {
+		t.Errorf("expected %d tools returned, got %d", len(tools), len(returnedTools))
+	}
+}
+
+// TestCogitoParser_parseToolCallContent checks parsing of a single tool call
+// body, for both well-formed and malformed inputs.
+func TestCogitoParser_parseToolCallContent(t *testing.T) {
+	tests := []struct {
+		name        string
+		content     string
+		expected    api.ToolCall
+		expectError bool
+	}{
+		{
+			name: "valid_tool_call_standard_format",
+			content: `function<|tool▁sep|>get_weather
+` + "```json\n" + `{"location":"Paris"}
+` + "```",
+			expected: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "get_weather",
+					Arguments: api.ToolCallFunctionArguments{
+						"location": "Paris",
+					},
+				},
+			},
+			expectError: false,
+		},
+		{
+			name: "valid_tool_call_complex_args",
+			content: `function<|tool▁sep|>process_data
+` + "```json\n" + `{"items":["item1","item2"],"config":{"enabled":true},"count":42}
+` + "```",
+			expected: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "process_data",
+					Arguments: api.ToolCallFunctionArguments{
+						"items":  []any{"item1", "item2"},
+						"config": map[string]any{"enabled": true},
+						"count":  42.0,
+					},
+				},
+			},
+			expectError: false,
+		},
+		{
+			name: "valid_tool_call_empty_args",
+			content: `function<|tool▁sep|>no_args_tool
+` + "```json\n" + `{}
+` + "```",
+			expected: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name:      "no_args_tool",
+					Arguments: api.ToolCallFunctionArguments{},
+				},
+			},
+			expectError: false,
+		},
+		{
+			name:        "missing_separator",
+			content:     `functionget_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
+			expected:    api.ToolCall{},
+			expectError: true,
+		},
+		{
+			name:        "invalid_function_type",
+			content:     `not_function<|tool▁sep|>get_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
+			expected:    api.ToolCall{},
+			expectError: true,
+		},
+		{
+			name:        "missing_json_block_start",
+			content:     `function<|tool▁sep|>get_weather{"location":"Paris"}` + "```",
+			expected:    api.ToolCall{},
+			expectError: true,
+		},
+		{
+			name:        "missing_json_block_end",
+			content:     `function<|tool▁sep|>get_weather` + "```json\n" + `{"location":"Paris"}`,
+			expected:    api.ToolCall{},
+			expectError: true,
+		},
+		{
+			name:        "invalid_json",
+			content:     `function<|tool▁sep|>get_weather` + "```json\n" + `{location:Paris}` + "\n```",
+			expected:    api.ToolCall{},
+			expectError: true,
+		},
+		{
+			name:        "empty_function_type",
+			content:     `<|tool▁sep|>get_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
+			expected:    api.ToolCall{},
+			expectError: true,
+		},
+		{
+			name: "tool_with_spaces_in_name",
+			content: `function<|tool▁sep|> get_weather
+` + "```json\n" + `{"location":"Paris"}
+` + "```",
+			expected: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "get_weather",
+					Arguments: api.ToolCallFunctionArguments{
+						"location": "Paris",
+					},
+				},
+			},
+			expectError: false,
+		},
+		{
+			name: "tool_with_multiline_json",
+			content: `function<|tool▁sep|>get_weather
+` + "```json\n" + `{
+  "location": "Paris",
+  "units": "metric"
+}
+` + "```",
+			expected: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "get_weather",
+					Arguments: api.ToolCallFunctionArguments{
+						"location": "Paris",
+						"units":    "metric",
+					},
+				},
+			},
+			expectError: false,
+		},
+		{
+			name: "tool_with_nested_objects",
+			content: `function<|tool▁sep|>complex_tool
+` + "```json\n" + `{"nested":{"deep":{"value":123}}}
+` + "```",
+			expected: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "complex_tool",
+					Arguments: api.ToolCallFunctionArguments{
+						"nested": map[string]any{
+							"deep": map[string]any{
+								"value": 123.0,
+							},
+						},
+					},
+				},
+			},
+			expectError: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			parser := &CogitoParser{}
+
+			result, err := parser.parseToolCallContent(tt.content)
+
+			if tt.expectError {
+				if err == nil {
+					t.Errorf("expected error but got none")
+				}
+				return
+			}
+
+			if err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+
+			if diff := cmp.Diff(tt.expected, result); diff != "" {
+				t.Errorf("tool call mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go
index 4374f3e28..9bfd5c1ec 100644
--- a/model/parsers/parsers.go
+++ b/model/parsers/parsers.go
@@ -6,9 +6,9 @@ import (
 )
 
 type Parser interface {
-	// Init initializes the parser with tools and optional last message for chat prefill
+	// Init initializes the parser with tools, optional last message for chat prefill, and think value
 	// Returns processed tools if the parser needs to modify them (e.g., harmony renames them)
-	Init(tools []api.Tool, lastMessage *api.Message) []api.Tool
+	Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool
 	// Add processes
streamed content and returns parsed content, thinking, and tool calls // The done flag indicates if this is the last chunk (used for draining accumulators) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) @@ -52,6 +52,8 @@ func ParserForName(name string) Parser { return &PassthroughParser{} case "harmony": return harmony.NewHarmonyMessageHandler() + case "cogito": + return &CogitoParser{} default: return nil } @@ -59,7 +61,7 @@ func ParserForName(name string) Parser { type PassthroughParser struct{} -func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { +func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { return tools // passthrough doesn't modify tools } diff --git a/model/parsers/parsers_test.go b/model/parsers/parsers_test.go index 8a64a2358..e918a62f6 100644 --- a/model/parsers/parsers_test.go +++ b/model/parsers/parsers_test.go @@ -10,7 +10,7 @@ type mockParser struct { name string } -func (m *mockParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { +func (m *mockParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { return tools } diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go index bfa9762c1..9a073b1c4 100644 --- a/model/parsers/qwen3coder.go +++ b/model/parsers/qwen3coder.go @@ -43,7 +43,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool { return false } -func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { +func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { p.tools = tools return tools // Qwen doesn't modify tools } @@ -432,7 +432,7 @@ func transformToXML(raw string) string { groups := qwenTagRegex.FindStringSubmatch(match) tag := groups[1] var escapedValue strings.Builder - xml.EscapeText(&escapedValue, []byte(groups[2])) + 
_ = xml.EscapeText(&escapedValue, []byte(groups[2])) // error is always nil for strings.Builder return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String()) }) diff --git a/model/parsers/qwen3vl.go b/model/parsers/qwen3vl.go index 87f49e892..979f1668c 100644 --- a/model/parsers/qwen3vl.go +++ b/model/parsers/qwen3vl.go @@ -54,7 +54,7 @@ func (p *Qwen3VLParser) setInitialState(lastMessage *api.Message) { p.state = CollectingThinkingContent } -func (p *Qwen3VLParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { +func (p *Qwen3VLParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { p.tools = tools p.setInitialState(lastMessage) return tools diff --git a/model/parsers/qwen3vl_nonthinking_test.go b/model/parsers/qwen3vl_nonthinking_test.go index e0b9a02b3..803824a68 100644 --- a/model/parsers/qwen3vl_nonthinking_test.go +++ b/model/parsers/qwen3vl_nonthinking_test.go @@ -198,7 +198,7 @@ func TestQwen3VLNonThinkingParserStreaming(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { parser := Qwen3VLParser{hasThinkingSupport: false} - parser.Init([]api.Tool{}, nil) + parser.Init([]api.Tool{}, nil, nil) for i, step := range tc.steps { parser.buffer.WriteString(step.input) @@ -515,7 +515,7 @@ func TestQwenOldParserStreaming(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { parser := Qwen3VLParser{hasThinkingSupport: false} - parser.Init([]api.Tool{}, nil) + parser.Init([]api.Tool{}, nil, nil) for i, step := range tc.steps { parser.buffer.WriteString(step.input) @@ -822,7 +822,7 @@ func TestQwen3VLNonThinkingToolCallWhitespaceHandling(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { parser := Qwen3VLParser{hasThinkingSupport: false} - parser.Init([]api.Tool{}, nil) + parser.Init([]api.Tool{}, nil, nil) for i, step := range tc.steps { parser.buffer.WriteString(step.input) diff --git a/model/parsers/qwen3vl_thinking_test.go b/model/parsers/qwen3vl_thinking_test.go index 04b2a7db3..2d2424d20 100644 --- 
a/model/parsers/qwen3vl_thinking_test.go +++ b/model/parsers/qwen3vl_thinking_test.go @@ -205,7 +205,7 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { parser := Qwen3VLParser{hasThinkingSupport: true} - parser.Init([]api.Tool{}, nil) + parser.Init([]api.Tool{}, nil, nil) // parser.state = CollectingThinkingContent for i, step := range tc.steps { @@ -386,7 +386,7 @@ func TestQwen3VLParserState(t *testing.T) { for _, tc := range cases { parser := Qwen3VLParser{hasThinkingSupport: tc.hasThinking} - parser.Init(nil, tc.last) + parser.Init(nil, tc.last, nil) if parser.state != tc.wantState { t.Errorf("%s: got state %v, want %v", tc.desc, parser.state, tc.wantState) } @@ -437,7 +437,7 @@ func TestQwen3VLThinkingParserWithThinkingPrefill(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { parser := Qwen3VLParser{hasThinkingSupport: true} - parser.Init([]api.Tool{}, last) + parser.Init([]api.Tool{}, last, nil) for i, step := range tc.steps { parser.buffer.WriteString(step.input) @@ -500,7 +500,7 @@ func TestQwen3VLThinkingParserWithNonThinkingPrefill(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { parser := Qwen3VLParser{hasThinkingSupport: true} - parser.Init([]api.Tool{}, last) + parser.Init([]api.Tool{}, last, nil) for i, step := range tc.steps { parser.buffer.WriteString(step.input) @@ -523,7 +523,7 @@ func TestQwen3VLThinkingParserStreamingAssistantPrefillContent(t *testing.T) { // last message is assistant with content ⇒ start in CollectingContent last := &api.Message{Role: "assistant", Content: "has content"} parser := Qwen3VLParser{hasThinkingSupport: true} - parser.Init([]api.Tool{}, last) + parser.Init([]api.Tool{}, last, nil) type step struct { input string @@ -750,7 +750,7 @@ func TestQwen3VLThinkingWhitespaceHandling(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { parser := Qwen3VLParser{hasThinkingSupport: true} - parser.Init([]api.Tool{}, nil) + 
parser.Init([]api.Tool{}, nil, nil) for i, step := range tc.steps { parser.buffer.WriteString(step.input) @@ -859,7 +859,7 @@ func TestQwen3VLToolCallWhitespaceHandling(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { parser := Qwen3VLParser{hasThinkingSupport: true} - parser.Init([]api.Tool{}, tc.prefillMsg) + parser.Init([]api.Tool{}, tc.prefillMsg, nil) for i, step := range tc.steps { parser.buffer.WriteString(step.input) diff --git a/server/routes.go b/server/routes.go index 38ce4e4d4..16df3f4fc 100644 --- a/server/routes.go +++ b/server/routes.go @@ -340,7 +340,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { builtinParser = parsers.ParserForName(m.Config.Parser) if builtinParser != nil { // no tools or last message for generate endpoint - builtinParser.Init(nil, nil) + builtinParser.Init(nil, nil, req.Think) } } @@ -2051,7 +2051,7 @@ func (s *Server) ChatHandler(c *gin.Context) { lastMessage = &msgs[len(msgs)-1] } // Initialize parser and get processed tools - processedTools = builtinParser.Init(req.Tools, lastMessage) + processedTools = builtinParser.Init(req.Tools, lastMessage, req.Think) } }