mirror of https://github.com/ollama/ollama
routes: add logprobs in tool calls (#13238)
This commit is contained in:
parent
dac4f17fea
commit
1c4e85b4df
|
|
@ -2195,7 +2195,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
|
||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done || len(res.Logprobs) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
|
||||
ch <- res
|
||||
} else {
|
||||
|
|
@ -2235,8 +2235,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
res.Message.ToolCalls = toolCalls
|
||||
res.Message.Content = ""
|
||||
} else if res.Message.Thinking != "" {
|
||||
// don't return
|
||||
// don't return, fall through to send
|
||||
} else {
|
||||
// Send logprobs while content is being buffered by the parser for tool calls
|
||||
if len(res.Logprobs) > 0 && !r.Done {
|
||||
logprobRes := res
|
||||
logprobRes.Message.Content = ""
|
||||
logprobRes.Message.ToolCalls = nil
|
||||
ch <- logprobRes
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
res.Message.Content = toolParser.Content()
|
||||
ch <- res
|
||||
|
|
|
|||
|
|
@ -708,6 +708,95 @@ func TestGenerateChat(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("messages with tools and logprobs (streaming)", func(t *testing.T) {
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
defer wg.Done()
|
||||
|
||||
// Simulate a response where logprobs are sent while the tool call is being buffered
|
||||
responses := []llm.CompletionResponse{
|
||||
{
|
||||
Content: `{ "name": "get_weather"`,
|
||||
Done: false,
|
||||
Logprobs: []llm.Logprob{{}},
|
||||
},
|
||||
{
|
||||
Content: `,"arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
|
||||
Done: false,
|
||||
Logprobs: []llm.Logprob{{}},
|
||||
},
|
||||
{
|
||||
Content: ``,
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
Logprobs: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, resp := range responses {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
fn(resp)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-system",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
},
|
||||
Tools: tools,
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(w.Body)
|
||||
var totalLogprobs int
|
||||
|
||||
for {
|
||||
var resp api.ChatResponse
|
||||
if err := decoder.Decode(&resp); err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
totalLogprobs += len(resp.Logprobs)
|
||||
}
|
||||
|
||||
expectedLogprobs := 2
|
||||
if totalLogprobs != expectedLogprobs {
|
||||
t.Errorf("expected %d logprobs, got %d", expectedLogprobs, totalLogprobs)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("status error non-streaming", func(t *testing.T) {
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
return api.StatusError{
|
||||
|
|
|
|||
Loading…
Reference in New Issue