openai compatibility

Roy Han 2024-06-28 21:03:55 -07:00
parent b910fa9010
commit d3f98a811e
4 changed files with 11 additions and 2 deletions

@@ -94,11 +94,14 @@ type ChatRequest struct {
 	Format string `json:"format"`

 	// KeepAlive controls how long the model will stay loaded into memory
-	// followin the request.
+	// following the request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`

 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
+
+	// OpenAI indicates redirection from the compatibility endpoint.
+	OpenAI bool `json:"openai,omitempty"`
 }

 // Message is a single message in a chat sequence. The message contains the
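Note on the new field: thanks to `omitempty`, the flag never appears in JSON produced by native clients; only the compatibility layer sets it when translating a /v1/chat/completions request (see `fromRequest` further down). A minimal sketch of the wire behavior, using a trimmed stand-in for `api.ChatRequest` and an illustrative model name:

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-in for api.ChatRequest; the real struct also carries
// Messages, Format, KeepAlive, Options, and more.
type ChatRequest struct {
	Model  string `json:"model"`
	OpenAI bool   `json:"openai,omitempty"`
}

func main() {
	native, _ := json.Marshal(ChatRequest{Model: "llama3"})
	compat, _ := json.Marshal(ChatRequest{Model: "llama3", OpenAI: true})
	fmt.Println(string(native)) // {"model":"llama3"}, omitempty hides the flag
	fmt.Println(string(compat)) // {"model":"llama3","openai":true}
}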

@@ -656,6 +656,7 @@ type completion struct {
 	Prompt       string `json:"prompt"`
 	Stop         bool   `json:"stop"`
 	StoppedLimit bool   `json:"stopped_limit"`
+	TokensEval   int    `json:"tokens_evaluated"`

 	Timings struct {
 		PredictedN int `json:"predicted_n"`
@@ -680,6 +681,7 @@ type CompletionResponse struct {
 	PromptEvalDuration time.Duration
 	EvalCount          int
 	EvalDuration       time.Duration
+	TokensEval         int
 }

 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -828,6 +830,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
 					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 					EvalCount:          c.Timings.PredictedN,
 					EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+					TokensEval:         c.TokensEval,
 				})
 				return nil
 			}
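For context: `tokens_evaluated` comes from the llama.cpp server response this struct decodes. Judging by the TODO removed in the next file, it reports the full prompt size even when the prompt was served from the KV cache, whereas `timings.prompt_n` (which feeds `PromptEvalCount`) only counts tokens evaluated during this call. A small decoding sketch with a trimmed stand-in for the unexported `completion` struct and illustrative numbers:

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-in for the completion struct above.
type completion struct {
	Stop       bool `json:"stop"`
	TokensEval int  `json:"tokens_evaluated"`
	Timings    struct {
		PromptN    int `json:"prompt_n"`
		PredictedN int `json:"predicted_n"`
	} `json:"timings"`
}

func main() {
	// Illustrative final chunk for a fully cached prompt: prompt_n is 0
	// because nothing new was evaluated, yet tokens_evaluated still
	// carries the full prompt size.
	data := []byte(`{"stop":true,"tokens_evaluated":128,"timings":{"prompt_n":0,"predicted_n":42}}`)
	var c completion
	if err := json.Unmarshal(data, &c); err != nil {
		panic(err)
	}
	fmt.Println(c.TokensEval, c.Timings.PromptN) // 128 0
}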

@@ -117,7 +117,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 			}(r.DoneReason),
 		}},
 		Usage: Usage{
-			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
 			PromptTokens:     r.PromptEvalCount,
 			CompletionTokens: r.EvalCount,
 			TotalTokens:      r.PromptEvalCount + r.EvalCount,
@@ -205,6 +204,7 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
 		Format:  format,
 		Options: options,
 		Stream:  &r.Stream,
+		OpenAI:  true,
 	}
 }
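Why the TODO can go: with the `ChatHandler` change in the next file, `r.PromptEvalCount` already holds the full prompt size for OpenAI-routed requests, so this usage block now matches what OpenAI clients expect even when the prompt was cached. A sketch of the resulting accounting, with a stand-in `Usage` struct and illustrative counts:

package main

import (
	"encoding/json"
	"fmt"
)

// Stand-in for the Usage struct that toChatCompletion fills in.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

func main() {
	// Illustrative counts: a 128-token prompt answered from cache plus a
	// 42-token reply. Before this commit, PromptTokens would have been 0.
	promptEval, evalCount := 128, 42
	u := Usage{
		PromptTokens:     promptEval,
		CompletionTokens: evalCount,
		TotalTokens:      promptEval + evalCount,
	}
	b, _ := json.Marshal(u)
	fmt.Println(string(b)) // {"prompt_tokens":128,"completion_tokens":42,"total_tokens":170}
}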

@@ -1372,6 +1372,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		defer close(ch)

 		fn := func(r llm.CompletionResponse) {
+			if req.OpenAI {
+				r.PromptEvalCount = r.TokensEval
+			}
 			resp := api.ChatResponse{
 				Model:     req.Model,
 				CreatedAt: time.Now().UTC(),
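The override is gated on `req.OpenAI`, so native /api/chat responses are unchanged: their `prompt_eval_count` still reflects only newly evaluated tokens. A self-contained sketch of the two paths, with a stand-in type and illustrative numbers:

package main

import "fmt"

// Stand-in for llm.CompletionResponse with only the fields involved.
type CompletionResponse struct {
	PromptEvalCount int // tokens evaluated during this call (0 if fully cached)
	TokensEval      int // full prompt size, per tokens_evaluated
}

func promptTokens(r CompletionResponse, openAI bool) int {
	if openAI {
		// Mirrors the ChatHandler change: surface the full prompt size.
		r.PromptEvalCount = r.TokensEval
	}
	return r.PromptEvalCount
}

func main() {
	r := CompletionResponse{PromptEvalCount: 0, TokensEval: 128} // fully cached prompt
	fmt.Println(promptTokens(r, false)) // 0, the native /api/chat behavior
	fmt.Println(promptTokens(r, true))  // 128, the OpenAI-compatible endpoint
}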