diff --git a/server/routes.go b/server/routes.go
index 4dd870ed0..bbf6b9b90 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -2195,7 +2195,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				return
 			}
 
-			if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
+			if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done || len(res.Logprobs) > 0 {
 				slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
 				ch <- res
 			} else {
@@ -2235,8 +2235,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				res.Message.ToolCalls = toolCalls
 				res.Message.Content = ""
 			} else if res.Message.Thinking != "" {
-				// don't return
+				// don't return, fall through to send
 			} else {
+				// Send logprobs while content is being buffered by the parser for tool calls
+				if len(res.Logprobs) > 0 && !r.Done {
+					logprobRes := res
+					logprobRes.Message.Content = ""
+					logprobRes.Message.ToolCalls = nil
+					ch <- logprobRes
+				}
+
 				if r.Done {
 					res.Message.Content = toolParser.Content()
 					ch <- res
diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go
index a9931ea24..13befff2a 100644
--- a/server/routes_generate_test.go
+++ b/server/routes_generate_test.go
@@ -708,6 +708,95 @@ func TestGenerateChat(t *testing.T) {
 		}
 	})
 
+	t.Run("messages with tools and logprobs (streaming)", func(t *testing.T) {
+		tools := []api.Tool{
+			{
+				Type: "function",
+				Function: api.ToolFunction{
+					Name: "get_weather",
+					Parameters: api.ToolFunctionParameters{
+						Type: "object",
+						Properties: map[string]api.ToolProperty{
+							"location": {Type: api.PropertyType{"string"}},
+						},
+					},
+				},
+			},
+		}
+
+		var wg sync.WaitGroup
+		wg.Add(1)
+
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			defer wg.Done()
+
+			// Simulate a response where logprobs are sent while the tool call is being buffered
+			responses := []llm.CompletionResponse{
+				{
+					Content:  `{ "name": "get_weather"`,
+					Done:     false,
+					Logprobs: []llm.Logprob{{}},
+				},
+				{
+					Content:  `,"arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
+					Done:     false,
+					Logprobs: []llm.Logprob{{}},
+				},
+				{
+					Content:    ``,
+					Done:       true,
+					DoneReason: llm.DoneReasonStop,
+					Logprobs:   nil,
+				},
+			}
+
+			for _, resp := range responses {
+				select {
+				case <-ctx.Done():
+					return ctx.Err()
+				default:
+					fn(resp)
+					time.Sleep(10 * time.Millisecond)
+				}
+			}
+			return nil
+		}
+
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-system",
+			Messages: []api.Message{
+				{Role: "user", Content: "Weather?"},
+			},
+			Tools:  tools,
+			Stream: &stream,
+		})
+
+		wg.Wait()
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		decoder := json.NewDecoder(w.Body)
+		var totalLogprobs int
+
+		for {
+			var resp api.ChatResponse
+			if err := decoder.Decode(&resp); err == io.EOF {
+				break
+			} else if err != nil {
+				t.Fatal(err)
+			}
+
+			totalLogprobs += len(resp.Logprobs)
+		}
+
+		expectedLogprobs := 2
+		if totalLogprobs != expectedLogprobs {
+			t.Errorf("expected %d logprobs, got %d", expectedLogprobs, totalLogprobs)
+		}
+	})
+
 	t.Run("status error non-streaming", func(t *testing.T) {
 		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
 			return api.StatusError{