diff --git a/api/client.go b/api/client.go index 3dffce600..9f0dba8dc 100644 --- a/api/client.go +++ b/api/client.go @@ -24,7 +24,10 @@ import ( "net/http" "net/url" "runtime" + "strconv" + "time" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/version" @@ -76,6 +79,14 @@ func NewClient(base *url.URL, http *http.Client) *Client { } } +func getAuthorizationToken(ctx context.Context, challenge string) (string, error) { + token, err := auth.Sign(ctx, []byte(challenge)) + if err != nil { + return "", err + } + return token, nil +} + func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error { var reqBody io.Reader var data []byte @@ -97,6 +108,21 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData } requestURL := c.base.JoinPath(path) + + var token string + if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" { + now := strconv.FormatInt(time.Now().Unix(), 10) + chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now) + token, err = getAuthorizationToken(ctx, chal) + if err != nil { + return err + } + + q := requestURL.Query() + q.Set("ts", now) + requestURL.RawQuery = q.Encode() + } + request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody) if err != nil { return err @@ -106,6 +132,10 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData request.Header.Set("Accept", "application/json") request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) + if token != "" { + request.Header.Set("Authorization", token) + } + respObj, err := c.http.Do(request) if err != nil { return err @@ -143,6 +173,22 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f } requestURL := c.base.JoinPath(path) + + var token string + if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" { + var err error + now := strconv.FormatInt(time.Now().Unix(), 10) + chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now) + token, err = getAuthorizationToken(ctx, chal) + if err != nil { + return err + } + + q := requestURL.Query() + q.Set("ts", now) + requestURL.RawQuery = q.Encode() + } + request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf) if err != nil { return err @@ -152,6 +198,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f request.Header.Set("Accept", "application/x-ndjson") request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) + if token != "" { + request.Header.Set("Authorization", token) + } + response, err := c.http.Do(request) if err != nil { return err diff --git a/envconfig/config.go b/envconfig/config.go index 9d7c2e218..763f04646 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -183,6 +183,8 @@ var ( NewEngine = Bool("OLLAMA_NEW_ENGINE") // ContextLength sets the default context length ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096) + // Auth enables authentication between the Ollama client and server + UseAuth = Bool("OLLAMA_AUTH") ) func String(s string) func() string {