This commit is contained in:
Tiago Terenas Almeida 2025-12-12 19:07:47 +00:00
parent 2a310f5fc6
commit 36280de427
3 changed files with 107 additions and 6 deletions

View File

@ -1521,13 +1521,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
// Inference
r.GET("/api/ps", s.PsHandler)
// Web proxy endpoints: forward web search/fetch to main server (ollama.com)
r.POST("/api/web_search", s.WebSearchHandler)
r.POST("/api/web_fetch", s.WebFetchHandler)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/web_search", s.WebSearchHandler)
r.POST("/api/web_fetch", s.WebFetchHandler)
// Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)

View File

@ -13,6 +13,12 @@ import (
"github.com/ollama/ollama/auth"
)
// signFunc is a variable to allow tests to override signing behavior.
var signFunc = auth.Sign
// httpClient is injectable for tests to capture outbound requests.
var httpClient = &http.Client{Timeout: 30 * time.Second}
// proxyToMain forwards the incoming request body to the main ollama server
// and sets an Authorization token if signing is available locally.
func (s *Server) proxyToMain(c *gin.Context, path string) {
@ -27,7 +33,7 @@ func (s *Server) proxyToMain(c *gin.Context, path string) {
now := strconv.FormatInt(time.Now().Unix(), 10)
chal := fmt.Sprintf("%s,%s?ts=%s", http.MethodPost, path, now)
token, err := auth.Sign(ctx, []byte(chal))
token, err := signFunc(ctx, []byte(chal))
if err != nil {
// If signing fails, return an error so callers know the proxy couldn't
// obtain a token. Clients may fallback to asking the user for a key.
@ -57,8 +63,7 @@ func (s *Server) proxyToMain(c *gin.Context, path string) {
req.Header.Set("Authorization", token)
}
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
resp, err := httpClient.Do(req)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to contact main server"})
return

97
server/webproxy_test.go Normal file
View File

@ -0,0 +1,97 @@
package server
import (
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
type rtStub struct {
fn func(req *http.Request) *http.Response
}
func (r *rtStub) RoundTrip(req *http.Request) (*http.Response, error) {
if r.fn == nil {
return nil, errors.New("no stub")
}
return r.fn(req), nil
}
func TestProxyToMain_WithToken_SetsAuthorizationAndForwards(t *testing.T) {
t.Parallel()
// backup globals
origSign := signFunc
origClient := httpClient
defer func() { signFunc = origSign; httpClient = origClient }()
var captured *http.Request
// stub signing
signFunc = func(_ any, _ []byte) (string, error) { return "Bearer testtoken", nil }
// stub transport
stub := &rtStub{fn: func(req *http.Request) *http.Response {
captured = req
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader("ok")),
Header: http.Header{"Content-Type": {"application/json"}},
}
}}
httpClient = &http.Client{Transport: stub}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodPost, "/api/web_search", strings.NewReader(`{"q":"hi"}`))
req.Header.Set("Content-Type", "application/json")
c.Request = req
s := &Server{}
s.WebSearchHandler(c)
if w.Code != 200 {
t.Fatalf("expected 200 status, got %d", w.Code)
}
if strings.TrimSpace(w.Body.String()) != "ok" {
t.Fatalf("unexpected body: %q", w.Body.String())
}
if captured == nil {
t.Fatal("no outbound request captured")
}
if got := captured.Header.Get("Authorization"); got != "Bearer testtoken" {
t.Fatalf("expected Authorization header set, got %q", got)
}
if captured.URL.Path != "/api/web_search" {
t.Fatalf("expected path /api/web_search, got %s", captured.URL.Path)
}
}
func TestProxyToMain_SignFails_Returns500(t *testing.T) {
t.Parallel()
origSign := signFunc
defer func() { signFunc = origSign }()
signFunc = func(_ any, _ []byte) (string, error) { return "", errors.New("no key") }
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodPost, "/api/web_fetch", strings.NewReader(`{"url":"https://example.com"}`))
req.Header.Set("Content-Type", "application/json")
c.Request = req
s := &Server{}
s.WebFetchHandler(c)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500 status, got %d", w.Code)
}
}