mirror of https://github.com/ollama/ollama
openai: add v1/responses support (#13351)
This only supports the stateless part of the API. Doc updates to come once this is shipped. Closes: #9659
parent 3475d915cb
commit 1eb5e75972
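For reference, a minimal sketch of calling the new endpoint, assuming a local server on Ollama's default port; the body fields (model, input, stream) follow OpenAI's Responses API shape, and the model name is a placeholder:

	package main

	import (
		"fmt"
		"io"
		"net/http"
		"strings"
	)

	func main() {
		// Non-streaming request against the new endpoint. "llama3.2" is a
		// placeholder model name; the body shape follows OpenAI's Responses API.
		body := strings.NewReader(`{"model": "llama3.2", "input": "Say hello.", "stream": false}`)
		resp, err := http.Post("http://localhost:11434/v1/responses", "application/json", body)
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()

		out, err := io.ReadAll(resp.Body)
		if err != nil {
			panic(err)
		}
		fmt.Println(resp.Status)
		fmt.Println(string(out))
	}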
@@ -433,3 +433,111 @@ func ChatMiddleware() gin.HandlerFunc {
 		c.Next()
 	}
 }
+
+type ResponsesWriter struct {
+	BaseWriter
+	converter  *openai.ResponsesStreamConverter
+	model      string
+	stream     bool
+	responseID string
+	itemID     string
+}
+
+func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
+	d, err := json.Marshal(data)
+	if err != nil {
+		return err
+	}
+	_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
+	if err != nil {
+		return err
+	}
+	if f, ok := w.ResponseWriter.(http.Flusher); ok {
+		f.Flush()
+	}
+	return nil
+}
+
+func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
+	var chatResponse api.ChatResponse
+	if err := json.Unmarshal(data, &chatResponse); err != nil {
+		return 0, err
+	}
+
+	if w.stream {
+		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
+
+		events := w.converter.Process(chatResponse)
+		for _, event := range events {
+			if err := w.writeEvent(event.Event, event.Data); err != nil {
+				return 0, err
+			}
+		}
+		return len(data), nil
+	}
+
+	// Non-streaming response
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
+	return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
+}
+
+func (w *ResponsesWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(data)
+	}
+	return w.writeResponse(data)
+}
+
+func ResponsesMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req openai.ResponsesRequest
+		if err := c.ShouldBindJSON(&req); err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		chatReq, err := openai.FromResponsesRequest(req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		// Check if client requested streaming (defaults to false)
+		streamRequested := req.Stream != nil && *req.Stream
+
+		// Pass streaming preference to the underlying chat request
+		chatReq.Stream = &streamRequested
+
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		responseID := fmt.Sprintf("resp_%d", rand.Intn(999999))
+		itemID := fmt.Sprintf("msg_%d", rand.Intn(999999))
+
+		w := &ResponsesWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			converter:  openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
+			model:      req.Model,
+			stream:     streamRequested,
+			responseID: responseID,
+			itemID:     itemID,
+		}
+
+		// Set headers based on streaming mode
+		if streamRequested {
+			c.Writer.Header().Set("Content-Type", "text/event-stream")
+			c.Writer.Header().Set("Cache-Control", "no-cache")
+			c.Writer.Header().Set("Connection", "keep-alive")
+		}
+
+		c.Writer = w
+		c.Next()
+	}
+}
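When stream is true, the middleware writes standard SSE frames in the "event: <type>" / "data: <json>" layout produced by writeEvent above. A rough sketch of consuming that stream, under the same placeholder-model assumption as the previous example:

	package main

	import (
		"bufio"
		"fmt"
		"net/http"
		"strings"
	)

	func main() {
		// Streaming request; the server responds with text/event-stream.
		body := strings.NewReader(`{"model": "llama3.2", "input": "Say hello.", "stream": true}`)
		resp, err := http.Post("http://localhost:11434/v1/responses", "application/json", body)
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()

		// Each frame is "event: <type>", then "data: <json>", then a blank
		// line, matching the format writeEvent emits.
		sc := bufio.NewScanner(resp.Body)
		for sc.Scan() {
			line := sc.Text()
			if after, ok := strings.CutPrefix(line, "event: "); ok {
				fmt.Println("event:", after)
			} else if after, ok := strings.CutPrefix(line, "data: "); ok {
				fmt.Println("data:", after)
			}
		}
	}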
@@ -487,29 +487,9 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 				}
 			}
 
-			types := []string{"jpeg", "jpg", "png", "webp"}
-			valid := false
-			// support blank mime type to match api/chat taking just unadorned base64
-			if strings.HasPrefix(url, "data:;base64,") {
-				url = strings.TrimPrefix(url, "data:;base64,")
-				valid = true
-			}
-			for _, t := range types {
-				prefix := "data:image/" + t + ";base64,"
-				if strings.HasPrefix(url, prefix) {
-					url = strings.TrimPrefix(url, prefix)
-					valid = true
-					break
-				}
-			}
-
-			if !valid {
-				return nil, errors.New("invalid image input")
-			}
-
-			img, err := base64.StdEncoding.DecodeString(url)
+			img, err := decodeImageURL(url)
 			if err != nil {
-				return nil, errors.New("invalid message format")
+				return nil, err
 			}
-
 			messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
@@ -648,6 +628,35 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
 	return ""
 }
 
+// decodeImageURL decodes a base64 data URI into raw image bytes.
+func decodeImageURL(url string) (api.ImageData, error) {
+	types := []string{"jpeg", "jpg", "png", "webp"}
+
+	// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
+	if strings.HasPrefix(url, "data:;base64,") {
+		url = strings.TrimPrefix(url, "data:;base64,")
+	} else {
+		valid := false
+		for _, t := range types {
+			prefix := "data:image/" + t + ";base64,"
+			if strings.HasPrefix(url, prefix) {
+				url = strings.TrimPrefix(url, prefix)
+				valid = true
+				break
+			}
+		}
+		if !valid {
+			return nil, errors.New("invalid image input")
+		}
+	}
+
+	img, err := base64.StdEncoding.DecodeString(url)
+	if err != nil {
+		return nil, errors.New("invalid image input")
+	}
+	return img, nil
+}
+
 // FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall
 func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
 	apiToolCalls := make([]api.ToolCall, len(toolCalls))
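The helper is unexported, so a standalone sketch that mirrors its accepted prefixes can be useful for testing data URIs locally; the 1x1 transparent PNG payload below is only an illustration:

	package main

	import (
		"encoding/base64"
		"errors"
		"fmt"
		"strings"
	)

	// decodeImageURL mirrors the unexported helper above: accept a blank mime
	// type or one of the four image types, then base64-decode the remainder.
	func decodeImageURL(url string) ([]byte, error) {
		types := []string{"jpeg", "jpg", "png", "webp"}

		if strings.HasPrefix(url, "data:;base64,") {
			url = strings.TrimPrefix(url, "data:;base64,")
		} else {
			valid := false
			for _, t := range types {
				prefix := "data:image/" + t + ";base64,"
				if strings.HasPrefix(url, prefix) {
					url = strings.TrimPrefix(url, prefix)
					valid = true
					break
				}
			}
			if !valid {
				return nil, errors.New("invalid image input")
			}
		}

		img, err := base64.StdEncoding.DecodeString(url)
		if err != nil {
			return nil, errors.New("invalid image input")
		}
		return img, nil
	}

	func main() {
		// A 1x1 transparent PNG, base64-encoded, as the data URI payload.
		uri := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
		img, err := decodeImageURL(uri)
		fmt.Println(len(img), err)
	}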
2 file diffs suppressed because they are too large
@@ -1532,6 +1532,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
 	r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
 	r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
+	r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
 
 	if rc != nil {
 		// wrap old with new
@@ -2393,3 +2394,4 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
 	}
 	return msgs
 }
+