mirror of https://github.com/ollama/ollama
openai: add v1/responses support (#13351)
Only the stateless part of the API is supported. Documentation updates to come once this is shipped. Closes: #9659
This commit is contained in:
parent
3475d915cb
commit
1eb5e75972
|
|
@ -433,3 +433,111 @@ func ChatMiddleware() gin.HandlerFunc {
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResponsesWriter adapts upstream chat output into the OpenAI-compatible
// /v1/responses wire format: SSE events when streaming, a single JSON
// object otherwise.
type ResponsesWriter struct {
	BaseWriter
	converter  *openai.ResponsesStreamConverter // translates chat chunks into Responses API stream events
	model      string                           // model name echoed back in the response
	stream     bool                             // whether the client requested SSE streaming
	responseID string                           // pre-generated "resp_..." identifier
	itemID     string                           // pre-generated "msg_..." identifier
}
|
||||||
|
|
||||||
|
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
|
||||||
|
d, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||||
|
f.Flush()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
|
||||||
|
var chatResponse api.ChatResponse
|
||||||
|
if err := json.Unmarshal(data, &chatResponse); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.stream {
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
|
||||||
|
events := w.converter.Process(chatResponse)
|
||||||
|
for _, event := range events {
|
||||||
|
if err := w.writeEvent(event.Event, event.Data); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-streaming response
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
|
||||||
|
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponsesWriter) Write(data []byte) (int, error) {
|
||||||
|
code := w.ResponseWriter.Status()
|
||||||
|
if code != http.StatusOK {
|
||||||
|
return w.writeError(data)
|
||||||
|
}
|
||||||
|
return w.writeResponse(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResponsesMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var req openai.ResponsesRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
chatReq, err := openai.FromResponsesRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if client requested streaming (defaults to false)
|
||||||
|
streamRequested := req.Stream != nil && *req.Stream
|
||||||
|
|
||||||
|
// Pass streaming preference to the underlying chat request
|
||||||
|
chatReq.Stream = &streamRequested
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
|
responseID := fmt.Sprintf("resp_%d", rand.Intn(999999))
|
||||||
|
itemID := fmt.Sprintf("msg_%d", rand.Intn(999999))
|
||||||
|
|
||||||
|
w := &ResponsesWriter{
|
||||||
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
|
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
|
||||||
|
model: req.Model,
|
||||||
|
stream: streamRequested,
|
||||||
|
responseID: responseID,
|
||||||
|
itemID: itemID,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set headers based on streaming mode
|
||||||
|
if streamRequested {
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer = w
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -487,29 +487,9 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
types := []string{"jpeg", "jpg", "png", "webp"}
|
img, err := decodeImageURL(url)
|
||||||
valid := false
|
|
||||||
// support blank mime type to match api/chat taking just unadorned base64
|
|
||||||
if strings.HasPrefix(url, "data:;base64,") {
|
|
||||||
url = strings.TrimPrefix(url, "data:;base64,")
|
|
||||||
valid = true
|
|
||||||
}
|
|
||||||
for _, t := range types {
|
|
||||||
prefix := "data:image/" + t + ";base64,"
|
|
||||||
if strings.HasPrefix(url, prefix) {
|
|
||||||
url = strings.TrimPrefix(url, prefix)
|
|
||||||
valid = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !valid {
|
|
||||||
return nil, errors.New("invalid image input")
|
|
||||||
}
|
|
||||||
|
|
||||||
img, err := base64.StdEncoding.DecodeString(url)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("invalid message format")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
||||||
|
|
@ -648,6 +628,35 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// decodeImageURL decodes a base64 data URI into raw image bytes.
|
||||||
|
func decodeImageURL(url string) (api.ImageData, error) {
|
||||||
|
types := []string{"jpeg", "jpg", "png", "webp"}
|
||||||
|
|
||||||
|
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
|
||||||
|
if strings.HasPrefix(url, "data:;base64,") {
|
||||||
|
url = strings.TrimPrefix(url, "data:;base64,")
|
||||||
|
} else {
|
||||||
|
valid := false
|
||||||
|
for _, t := range types {
|
||||||
|
prefix := "data:image/" + t + ";base64,"
|
||||||
|
if strings.HasPrefix(url, prefix) {
|
||||||
|
url = strings.TrimPrefix(url, prefix)
|
||||||
|
valid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
return nil, errors.New("invalid image input")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
img, err := base64.StdEncoding.DecodeString(url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("invalid image input")
|
||||||
|
}
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
|
||||||
// FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall
|
// FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall
|
||||||
func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
|
func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
|
||||||
apiToolCalls := make([]api.ToolCall, len(toolCalls))
|
apiToolCalls := make([]api.ToolCall, len(toolCalls))
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1532,6 +1532,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||||
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||||
|
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||||
|
|
||||||
if rc != nil {
|
if rc != nil {
|
||||||
// wrap old with new
|
// wrap old with new
|
||||||
|
|
@ -2393,3 +2394,4 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
||||||
}
|
}
|
||||||
return msgs
|
return msgs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue