From 1ca608bcd155c771d0fed683a75d8367fe9c7144 Mon Sep 17 00:00:00 2001 From: nicole pardal <109545900+npardal@users.noreply.github.com> Date: Wed, 5 Nov 2025 11:58:03 -0800 Subject: [PATCH] embeddings: added embedding command for cl (#12795) Co-authored-by: A-Akhil This PR introduces a new ollama embed command that allows users to generate embeddings directly from the command line. Added ollama embed MODEL [TEXT...] command for generating text embeddings Supports both direct text arguments and stdin piping for scripted workflows Outputs embeddings as JSON arrays (one per line) --- README.md | 12 ++ cmd/cmd.go | 69 ++++++++++- cmd/cmd_test.go | 324 ++++++++++++++++++++++++++++++++++++++++++++++++ docs/cli.mdx | 12 ++ 4 files changed, 416 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 58df3d55f..be1b88cc9 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,18 @@ ollama ps ollama stop llama3.2 ``` +### Generate embeddings from the CLI + +```shell +ollama run embeddinggemma "Your text to embed" +``` + +You can also pipe text for scripted workflows: + +```shell +echo "Your text to embed" | ollama run embeddinggemma +``` + ### Start Ollama `ollama serve` is used when you want to start ollama without running the desktop application. diff --git a/cmd/cmd.go b/cmd/cmd.go index fbacda7c3..a67299402 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -322,6 +322,44 @@ func StopHandler(cmd *cobra.Command, args []string) error { return nil } +func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *api.Duration, truncate *bool, dimensions int) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + req := &api.EmbedRequest{ + Model: modelName, + Input: input, + } + if keepAlive != nil { + req.KeepAlive = keepAlive + } + if truncate != nil { + req.Truncate = truncate + } + if dimensions > 0 { + req.Dimensions = dimensions + } + + resp, err := client.Embed(cmd.Context(), req) + if err != nil { + return err + } + + if len(resp.Embeddings) == 0 { + return errors.New("no embeddings returned") + } + + output, err := json.Marshal(resp.Embeddings[0]) + if err != nil { + return err + } + fmt.Println(string(output)) + + return nil +} + func RunHandler(cmd *cobra.Command, args []string) error { interactive := true @@ -386,7 +424,11 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } - prompts = append([]string{string(in)}, prompts...) + // Only prepend stdin content if it's not empty + stdinContent := string(in) + if len(stdinContent) > 0 { + prompts = append([]string{stdinContent}, prompts...) + } opts.ShowConnect = false opts.WordWrap = false interactive = false @@ -452,6 +494,29 @@ func RunHandler(cmd *cobra.Command, args []string) error { opts.ParentModel = info.Details.ParentModel + // Check if this is an embedding model + isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding) + + // If it's an embedding model, handle embedding generation + if isEmbeddingModel { + if opts.Prompt == "" { + return errors.New("embedding models require input text. 
Usage: ollama run " + name + " \"your text here\"") + } + + // Get embedding-specific flags + var truncate *bool + if truncateFlag, err := cmd.Flags().GetBool("truncate"); err == nil && cmd.Flags().Changed("truncate") { + truncate = &truncateFlag + } + + dimensions, err := cmd.Flags().GetInt("dimensions") + if err != nil { + return err + } + + return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions) + } + if interactive { if err := loadOrUnloadModel(cmd, &opts); err != nil { var sErr api.AuthorizationError @@ -1684,6 +1749,8 @@ func NewCLI() *cobra.Command { runCmd.Flags().String("think", "", "Enable thinking mode: true/false or high/medium/low for supported models") runCmd.Flags().Lookup("think").NoOptDefVal = "true" runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)") + runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead") + runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)") stopCmd := &cobra.Command{ Use: "stop MODEL", diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index a84272c8e..1c9d19942 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -355,6 +355,330 @@ func TestDeleteHandler(t *testing.T) { } } +func TestRunEmbeddingModel(t *testing.T) { + reqCh := make(chan api.EmbedRequest, 1) + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/show" && r.Method == http.MethodPost { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(api.ShowResponse{ + Capabilities: []model.Capability{model.CapabilityEmbedding}, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + if r.URL.Path == "/api/embed" && r.Method == http.MethodPost { + var req api.EmbedRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + reqCh <- req + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(api.EmbedResponse{ + Model: "test-embedding-model", + Embeddings: [][]float32{{0.1, 0.2, 0.3}}, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + http.NotFound(w, r) + })) + + t.Setenv("OLLAMA_HOST", mockServer.URL) + t.Cleanup(mockServer.Close) + + cmd := &cobra.Command{} + cmd.SetContext(t.Context()) + cmd.Flags().String("keepalive", "", "") + cmd.Flags().Bool("truncate", false, "") + cmd.Flags().Int("dimensions", 0, "") + cmd.Flags().Bool("verbose", false, "") + cmd.Flags().Bool("insecure", false, "") + cmd.Flags().Bool("nowordwrap", false, "") + cmd.Flags().String("format", "", "") + cmd.Flags().String("think", "", "") + cmd.Flags().Bool("hidethinking", false, "") + + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + errCh := make(chan error, 1) + go func() { + errCh <- RunHandler(cmd, []string{"test-embedding-model", "hello", "world"}) + }() + + err := <-errCh + w.Close() + os.Stdout = oldStdout + + if err != nil { + t.Fatalf("RunHandler returned error: %v", err) + } + + var out bytes.Buffer + io.Copy(&out, r) + + select { + case req := <-reqCh: + inputText, _ := req.Input.(string) + if diff := cmp.Diff("hello world", inputText); diff != "" { + t.Errorf("unexpected input (-want +got):\n%s", diff) + } + if req.Truncate != nil { + t.Errorf("expected truncate to be nil, 
got %v", *req.Truncate) + } + if req.KeepAlive != nil { + t.Errorf("expected keepalive to be nil, got %v", req.KeepAlive) + } + if req.Dimensions != 0 { + t.Errorf("expected dimensions to be 0, got %d", req.Dimensions) + } + default: + t.Fatal("server did not receive embed request") + } + + expectOutput := "[0.1,0.2,0.3]\n" + if diff := cmp.Diff(expectOutput, out.String()); diff != "" { + t.Errorf("unexpected output (-want +got):\n%s", diff) + } +} + +func TestRunEmbeddingModelWithFlags(t *testing.T) { + reqCh := make(chan api.EmbedRequest, 1) + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/show" && r.Method == http.MethodPost { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(api.ShowResponse{ + Capabilities: []model.Capability{model.CapabilityEmbedding}, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + if r.URL.Path == "/api/embed" && r.Method == http.MethodPost { + var req api.EmbedRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + reqCh <- req + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(api.EmbedResponse{ + Model: "test-embedding-model", + Embeddings: [][]float32{{0.4, 0.5}}, + LoadDuration: 5 * time.Millisecond, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + http.NotFound(w, r) + })) + + t.Setenv("OLLAMA_HOST", mockServer.URL) + t.Cleanup(mockServer.Close) + + cmd := &cobra.Command{} + cmd.SetContext(t.Context()) + cmd.Flags().String("keepalive", "", "") + cmd.Flags().Bool("truncate", false, "") + cmd.Flags().Int("dimensions", 0, "") + cmd.Flags().Bool("verbose", false, "") + cmd.Flags().Bool("insecure", false, "") + cmd.Flags().Bool("nowordwrap", false, "") + cmd.Flags().String("format", "", "") + cmd.Flags().String("think", "", "") + cmd.Flags().Bool("hidethinking", false, "") + + if err := cmd.Flags().Set("truncate", "true"); err != nil { + t.Fatalf("failed to set truncate flag: %v", err) + } + if err := cmd.Flags().Set("dimensions", "2"); err != nil { + t.Fatalf("failed to set dimensions flag: %v", err) + } + if err := cmd.Flags().Set("keepalive", "5m"); err != nil { + t.Fatalf("failed to set keepalive flag: %v", err) + } + + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + errCh := make(chan error, 1) + go func() { + errCh <- RunHandler(cmd, []string{"test-embedding-model", "test", "input"}) + }() + + err := <-errCh + w.Close() + os.Stdout = oldStdout + + if err != nil { + t.Fatalf("RunHandler returned error: %v", err) + } + + var out bytes.Buffer + io.Copy(&out, r) + + select { + case req := <-reqCh: + inputText, _ := req.Input.(string) + if diff := cmp.Diff("test input", inputText); diff != "" { + t.Errorf("unexpected input (-want +got):\n%s", diff) + } + if req.Truncate == nil || !*req.Truncate { + t.Errorf("expected truncate pointer true, got %v", req.Truncate) + } + if req.Dimensions != 2 { + t.Errorf("expected dimensions 2, got %d", req.Dimensions) + } + if req.KeepAlive == nil || req.KeepAlive.Duration != 5*time.Minute { + t.Errorf("unexpected keepalive duration: %v", req.KeepAlive) + } + default: + t.Fatal("server did not receive embed request") + } + + expectOutput := "[0.4,0.5]\n" + if diff := cmp.Diff(expectOutput, out.String()); diff != "" { + t.Errorf("unexpected output (-want +got):\n%s", diff) + } +} 
+ +func TestRunEmbeddingModelPipedInput(t *testing.T) { + reqCh := make(chan api.EmbedRequest, 1) + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/show" && r.Method == http.MethodPost { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(api.ShowResponse{ + Capabilities: []model.Capability{model.CapabilityEmbedding}, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + if r.URL.Path == "/api/embed" && r.Method == http.MethodPost { + var req api.EmbedRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + reqCh <- req + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(api.EmbedResponse{ + Model: "test-embedding-model", + Embeddings: [][]float32{{0.6, 0.7}}, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + http.NotFound(w, r) + })) + + t.Setenv("OLLAMA_HOST", mockServer.URL) + t.Cleanup(mockServer.Close) + + cmd := &cobra.Command{} + cmd.SetContext(t.Context()) + cmd.Flags().String("keepalive", "", "") + cmd.Flags().Bool("truncate", false, "") + cmd.Flags().Int("dimensions", 0, "") + cmd.Flags().Bool("verbose", false, "") + cmd.Flags().Bool("insecure", false, "") + cmd.Flags().Bool("nowordwrap", false, "") + cmd.Flags().String("format", "", "") + cmd.Flags().String("think", "", "") + cmd.Flags().Bool("hidethinking", false, "") + + // Capture stdin + oldStdin := os.Stdin + stdinR, stdinW, _ := os.Pipe() + os.Stdin = stdinR + stdinW.Write([]byte("piped text")) + stdinW.Close() + + // Capture stdout + oldStdout := os.Stdout + stdoutR, stdoutW, _ := os.Pipe() + os.Stdout = stdoutW + + errCh := make(chan error, 1) + go func() { + errCh <- RunHandler(cmd, []string{"test-embedding-model", "additional", "args"}) + }() + + err := <-errCh + stdoutW.Close() + os.Stdout = oldStdout + os.Stdin = oldStdin + + if err != nil { + t.Fatalf("RunHandler returned error: %v", err) + } + + var out bytes.Buffer + io.Copy(&out, stdoutR) + + select { + case req := <-reqCh: + inputText, _ := req.Input.(string) + // Should combine piped input with command line args + if diff := cmp.Diff("piped text additional args", inputText); diff != "" { + t.Errorf("unexpected input (-want +got):\n%s", diff) + } + default: + t.Fatal("server did not receive embed request") + } + + expectOutput := "[0.6,0.7]\n" + if diff := cmp.Diff(expectOutput, out.String()); diff != "" { + t.Errorf("unexpected output (-want +got):\n%s", diff) + } +} + +func TestRunEmbeddingModelNoInput(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/show" && r.Method == http.MethodPost { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(api.ShowResponse{ + Capabilities: []model.Capability{model.CapabilityEmbedding}, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + http.NotFound(w, r) + })) + + t.Setenv("OLLAMA_HOST", mockServer.URL) + t.Cleanup(mockServer.Close) + + cmd := &cobra.Command{} + cmd.SetContext(t.Context()) + cmd.Flags().String("keepalive", "", "") + cmd.Flags().Bool("truncate", false, "") + cmd.Flags().Int("dimensions", 0, "") + cmd.Flags().Bool("verbose", false, "") + cmd.Flags().Bool("insecure", false, "") + cmd.Flags().Bool("nowordwrap", false, 
"") + cmd.Flags().String("format", "", "") + cmd.Flags().String("think", "", "") + cmd.Flags().Bool("hidethinking", false, "") + + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + + // Test with no input arguments (only model name) + err := RunHandler(cmd, []string{"test-embedding-model"}) + if err == nil || !strings.Contains(err.Error(), "embedding models require input text") { + t.Fatalf("expected error about missing input, got %v", err) + } +} + func TestGetModelfileName(t *testing.T) { tests := []struct { name string diff --git a/docs/cli.mdx b/docs/cli.mdx index 3081838f9..97810e64a 100644 --- a/docs/cli.mdx +++ b/docs/cli.mdx @@ -25,6 +25,18 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol ollama run gemma3 "What's in this image? /Users/jmorgan/Desktop/smile.png" ``` +### Generate embeddings + +``` +ollama run embeddinggemma "Hello world" +``` + +Output is a JSON array: + +``` +echo "Hello world" | ollama run nomic-embed-text +``` + ### Download a model ```