mirror of https://github.com/ollama/ollama
Merge 36280de427 into a013693f80
This commit is contained in:
commit
e1aac9195c
|
|
@ -1525,6 +1525,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||||
r.POST("/api/chat", s.ChatHandler)
|
r.POST("/api/chat", s.ChatHandler)
|
||||||
r.POST("/api/embed", s.EmbedHandler)
|
r.POST("/api/embed", s.EmbedHandler)
|
||||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||||
|
r.POST("/api/web_search", s.WebSearchHandler)
|
||||||
|
r.POST("/api/web_fetch", s.WebFetchHandler)
|
||||||
|
|
||||||
// Inference (OpenAI compatibility)
|
// Inference (OpenAI compatibility)
|
||||||
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
|
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,92 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// signFunc is a variable to allow tests to override signing behavior.
|
||||||
|
var signFunc = auth.Sign
|
||||||
|
|
||||||
|
// httpClient is injectable for tests to capture outbound requests.
|
||||||
|
var httpClient = &http.Client{Timeout: 30 * time.Second}
|
||||||
|
|
||||||
|
// proxyToMain forwards the incoming request body to the main ollama server
|
||||||
|
// and sets an Authorization token if signing is available locally.
|
||||||
|
func (s *Server) proxyToMain(c *gin.Context, path string) {
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := strconv.FormatInt(time.Now().Unix(), 10)
|
||||||
|
chal := fmt.Sprintf("%s,%s?ts=%s", http.MethodPost, path, now)
|
||||||
|
|
||||||
|
token, err := signFunc(ctx, []byte(chal))
|
||||||
|
if err != nil {
|
||||||
|
// If signing fails, return an error so callers know the proxy couldn't
|
||||||
|
// obtain a token. Clients may fallback to asking the user for a key.
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to sign request"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteURL := &url.URL{Scheme: "https", Host: "ollama.com", Path: path}
|
||||||
|
q := remoteURL.Query()
|
||||||
|
q.Set("ts", now)
|
||||||
|
remoteURL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, remoteURL.String(), bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create outbound request"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preserve content type if provided
|
||||||
|
if ct := c.GetHeader("Content-Type"); ct != "" {
|
||||||
|
req.Header.Set("Content-Type", ct)
|
||||||
|
} else {
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
}
|
||||||
|
|
||||||
|
if token != "" {
|
||||||
|
req.Header.Set("Authorization", token)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to contact main server"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// proxy status and content-type back to caller
|
||||||
|
c.Status(resp.StatusCode)
|
||||||
|
if ct := resp.Header.Get("Content-Type"); ct != "" {
|
||||||
|
c.Header("Content-Type", ct)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stream body
|
||||||
|
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||||
|
// nothing much to do here, connection likely closed by client
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) WebSearchHandler(c *gin.Context) {
|
||||||
|
s.proxyToMain(c, "/api/web_search")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) WebFetchHandler(c *gin.Context) {
|
||||||
|
s.proxyToMain(c, "/api/web_fetch")
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rtStub struct {
|
||||||
|
fn func(req *http.Request) *http.Response
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rtStub) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
if r.fn == nil {
|
||||||
|
return nil, errors.New("no stub")
|
||||||
|
}
|
||||||
|
return r.fn(req), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyToMain_WithToken_SetsAuthorizationAndForwards(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// backup globals
|
||||||
|
origSign := signFunc
|
||||||
|
origClient := httpClient
|
||||||
|
defer func() { signFunc = origSign; httpClient = origClient }()
|
||||||
|
|
||||||
|
var captured *http.Request
|
||||||
|
|
||||||
|
// stub signing
|
||||||
|
signFunc = func(_ any, _ []byte) (string, error) { return "Bearer testtoken", nil }
|
||||||
|
|
||||||
|
// stub transport
|
||||||
|
stub := &rtStub{fn: func(req *http.Request) *http.Response {
|
||||||
|
captured = req
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
Header: http.Header{"Content-Type": {"application/json"}},
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
httpClient = &http.Client{Transport: stub}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/web_search", strings.NewReader(`{"q":"hi"}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
s := &Server{}
|
||||||
|
s.WebSearchHandler(c)
|
||||||
|
|
||||||
|
if w.Code != 200 {
|
||||||
|
t.Fatalf("expected 200 status, got %d", w.Code)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(w.Body.String()) != "ok" {
|
||||||
|
t.Fatalf("unexpected body: %q", w.Body.String())
|
||||||
|
}
|
||||||
|
if captured == nil {
|
||||||
|
t.Fatal("no outbound request captured")
|
||||||
|
}
|
||||||
|
if got := captured.Header.Get("Authorization"); got != "Bearer testtoken" {
|
||||||
|
t.Fatalf("expected Authorization header set, got %q", got)
|
||||||
|
}
|
||||||
|
if captured.URL.Path != "/api/web_search" {
|
||||||
|
t.Fatalf("expected path /api/web_search, got %s", captured.URL.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyToMain_SignFails_Returns500(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
origSign := signFunc
|
||||||
|
defer func() { signFunc = origSign }()
|
||||||
|
|
||||||
|
signFunc = func(_ any, _ []byte) (string, error) { return "", errors.New("no key") }
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/web_fetch", strings.NewReader(`{"url":"https://example.com"}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
s := &Server{}
|
||||||
|
s.WebFetchHandler(c)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500 status, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue