diff --git a/integration/embed_test.go b/integration/embed_test.go index e155498db..f01903ee5 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -4,7 +4,9 @@ package integration import ( "context" + "errors" "math" + "strings" "testing" "time" @@ -204,8 +206,8 @@ func TestAllMiniLMEmbed(t *testing.T) { t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embeddings[0][0:5], sim) } - if res.PromptEvalCount != 6 { - t.Fatalf("expected 6 prompt tokens, got %d", res.PromptEvalCount) + if res.PromptEvalCount != 8 { + t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount) } } @@ -251,8 +253,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) { t.Fatalf("expected %v, got %v (similarity: %f)", expected[1][0:5], res.Embeddings[1][0:5], sim) } - if res.PromptEvalCount != 12 { - t.Fatalf("expected 12 prompt tokens, got %d", res.PromptEvalCount) + if res.PromptEvalCount != 16 { + t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount) } } @@ -275,7 +277,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { cases := []struct { name string request api.EmbedRequest - check func(*api.EmbedResponse, error) + check func(*testing.T, *api.EmbedResponse, error) }{ { name: "target truncation", @@ -283,7 +285,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { Model: "all-minilm", Input: "why", }, - check: func(got *api.EmbedResponse, err error) { + check: func(t *testing.T, got *api.EmbedResponse, err error) { if err != nil { t.Fatal(err) } @@ -300,10 +302,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { Input: "why is the sky blue?", Options: map[string]any{"num_ctx": 3}, }, - check: func(got *api.EmbedResponse, err error) { + check: func(t *testing.T, got *api.EmbedResponse, err error) { if err != nil { t.Fatal(err) } + t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount) if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" { t.Errorf("embedding mismatch (-want +got):\n%s", diff) } @@ -317,10 +320,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { Truncate: &truncTrue, Options: map[string]any{"num_ctx": 3}, }, - check: func(got *api.EmbedResponse, err error) { + check: func(t *testing.T, got *api.EmbedResponse, err error) { if err != nil { t.Fatal(err) } + t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount) if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" { t.Errorf("embedding mismatch (-want +got):\n%s", diff) } @@ -334,21 +338,21 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { Truncate: &truncFalse, Options: map[string]any{"num_ctx": 3}, }, - check: func(res *api.EmbedResponse, err error) { - if err.Error() != "input exceeds maximum context length" { + check: func(t *testing.T, res *api.EmbedResponse, err error) { + if err.Error() != "the input length exceeds the context length" { t.Fatalf("expected truncation error, got: %v", err) } }, }, { - name: "input after truncate error", + name: "input after truncate error with context length of 1", request: api.EmbedRequest{ Model: "all-minilm", Input: "why is the sky blue?", Truncate: &truncTrue, Options: map[string]any{"num_ctx": 1}, }, - check: func(res *api.EmbedResponse, err error) { + check: func(t *testing.T, res *api.EmbedResponse, err error) { if err.Error() != "input after truncation exceeds maximum context length" { t.Fatalf("expected truncation error, got: %v", err) } @@ -362,7 +366,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { Truncate: &truncTrue, Options: map[string]any{"num_ctx": 0}, }, - check: func(res *api.EmbedResponse, err error) { + check: func(t *testing.T, res *api.EmbedResponse, err error) { if err.Error() != "input after truncation exceeds maximum context length" { t.Fatalf("expected truncation error, got: %v", err) } @@ -375,7 +379,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { Input: "why is the sky blue? Why is the sky blue? hi there my", Options: map[string]any{"num_ctx": 16}, }, - check: func(res *api.EmbedResponse, err error) { + check: func(t *testing.T, res *api.EmbedResponse, err error) { if err != nil { t.Fatal(err) } @@ -385,7 +389,8 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { for _, req := range cases { t.Run(req.name, func(t *testing.T) { - req.check(embedTestHelper(ctx, client, t, req.request)) + resp, err := embedTestHelper(ctx, client, t, req.request) + req.check(t, resp, err) }) } } @@ -409,3 +414,173 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req return client.Embed(ctx, &req) } + +func TestEmbedTruncation(t *testing.T) { + // Use test deadline if set, otherwise default to 2 minutes + timeout := 2 * time.Minute + if deadline, ok := t.Deadline(); ok { + timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + for _, model := range libraryEmbedModels { + model := model + t.Run(model, func(t *testing.T) { + // Check if we're running out of time (reserve 20s for current model) + if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second { + t.Skip("skipping remaining tests to avoid timeout") + } + + // Give each model its own budget to account for first-time pulls/loads + mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute) + defer mcancel() + + t.Run("truncation batch", func(t *testing.T) { + truncTrue := true + req := api.EmbedRequest{ + Model: model, + Input: []string{"short", strings.Repeat("long ", 100), "medium text"}, + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 30}, + } + + res, err := embedTestHelper(mctx, client, t, req) + if err != nil { + t.Fatal(err) + } + + if len(res.Embeddings) != 3 { + t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings)) + } + + if res.PromptEvalCount > 90 { + t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount) + } + }) + + t.Run("runner token count accuracy", func(t *testing.T) { + baseline := api.EmbedRequest{Model: model, Input: "test"} + baseRes, err := embedTestHelper(mctx, client, t, baseline) + if err != nil { + t.Fatal(err) + } + + batch := api.EmbedRequest{ + Model: model, + Input: []string{"test", "test", "test"}, + } + batchRes, err := embedTestHelper(mctx, client, t, batch) + if err != nil { + t.Fatal(err) + } + + expectedCount := baseRes.PromptEvalCount * 3 + if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 { + t.Fatalf("expected ~%d tokens (3 × %d), got %d", + expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount) + } + }) + }) + } +} + +// TestEmbedStatusCode tests that errors from the embedding endpoint +// properly preserve their HTTP status codes when returned to the client. +// This test specifically checks the error handling path in EmbedHandler +// where api.StatusError errors should maintain their original status code. +func TestEmbedStatusCode(t *testing.T) { + // Use test deadline if set, otherwise default to 2 minutes + timeout := 2 * time.Minute + if deadline, ok := t.Deadline(); ok { + timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + for _, model := range libraryEmbedModels { + model := model + t.Run(model, func(t *testing.T) { + // Check if we're running out of time (reserve 20s for current model) + if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second { + t.Skip("skipping remaining tests to avoid timeout") + } + + mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute) + defer mcancel() + + // Pull the model if needed + if err := PullIfMissing(mctx, client, model); err != nil { + t.Fatal(err) + } + + t.Run("truncation error status code", func(t *testing.T) { + truncFalse := false + longInput := strings.Repeat("word ", 100) + + req := api.EmbedRequest{ + Model: model, + Input: longInput, + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 10}, + } + + _, err := embedTestHelper(mctx, client, t, req) + if err == nil { + t.Fatal("expected error when truncate=false with long input") + } + + // Check that it's a StatusError with the correct status code + var statusErr api.StatusError + if !errors.As(err, &statusErr) { + t.Fatalf("expected api.StatusError, got %T: %v", err, err) + } + + // The error should be a 4xx client error (likely 400 Bad Request) + // not a 500 Internal Server Error + if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 { + t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode) + } + + // Verify the error message is meaningful + if !strings.Contains(err.Error(), "context length") { + t.Errorf("expected error message to mention context length, got: %v", err) + } + }) + + t.Run("batch truncation error status code", func(t *testing.T) { + truncFalse := false + req := api.EmbedRequest{ + Model: model, + Input: []string{ + "short input", + strings.Repeat("very long input ", 100), + "another short input", + }, + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 10}, + } + + _, err := embedTestHelper(mctx, client, t, req) + if err == nil { + t.Fatal("expected error when one input exceeds context with truncate=false") + } + + // Check that it's a StatusError with the correct status code + var statusErr api.StatusError + if !errors.As(err, &statusErr) { + t.Fatalf("expected api.StatusError, got %T: %v", err, err) + } + + // The error should be a 4xx client error, not a 500 Internal Server Error + if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 { + t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode) + } + }) + }) + } +} diff --git a/llm/server.go b/llm/server.go index e9d0a030f..1c47601f4 100644 --- a/llm/server.go +++ b/llm/server.go @@ -69,7 +69,7 @@ type LlamaServer interface { Ping(ctx context.Context) error WaitUntilRunning(ctx context.Context) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error - Embedding(ctx context.Context, input string) ([]float32, error) + Embedding(ctx context.Context, input string) ([]float32, int, error) Tokenize(ctx context.Context, content string) ([]int, error) Detokenize(ctx context.Context, tokens []int) (string, error) Close() error @@ -1629,10 +1629,11 @@ type EmbeddingRequest struct { } type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` + Embedding []float32 `json:"embedding"` + PromptEvalCount int `json:"prompt_eval_count"` } -func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) { +func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, int, error) { logutil.Trace("embedding request", "input", input) if err := s.sem.Acquire(ctx, 1); err != nil { @@ -1641,51 +1642,54 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err } else { slog.Error("Failed to acquire semaphore", "error", err) } - return nil, err + return nil, 0, err } defer s.sem.Release(1) // Make sure the server is ready status, err := s.getServerStatusRetry(ctx) if err != nil { - return nil, err + return nil, 0, err } else if status != ServerStatusReady { - return nil, fmt.Errorf("unexpected server status: %s", status) + return nil, 0, fmt.Errorf("unexpected server status: %s", status) } data, err := json.Marshal(EmbeddingRequest{Content: input}) if err != nil { - return nil, fmt.Errorf("error marshaling embed data: %w", err) + return nil, 0, fmt.Errorf("error marshaling embed data: %w", err) } r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data)) if err != nil { - return nil, fmt.Errorf("error creating embed request: %w", err) + return nil, 0, fmt.Errorf("error creating embed request: %w", err) } r.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(r) if err != nil { - return nil, fmt.Errorf("do embedding request: %w", err) + return nil, 0, fmt.Errorf("do embedding request: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("error reading embed response: %w", err) + return nil, 0, fmt.Errorf("error reading embed response: %w", err) } if resp.StatusCode >= 400 { log.Printf("llm embedding error: %s", body) - return nil, fmt.Errorf("%s", body) + return nil, 0, api.StatusError{ + StatusCode: resp.StatusCode, + ErrorMessage: string(body), + } } var e EmbeddingResponse if err := json.Unmarshal(body, &e); err != nil { - return nil, fmt.Errorf("unmarshal tokenize response: %w", err) + return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err) } - return e.Embedding, nil + return e.Embedding, e.PromptEvalCount, nil } func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, error) { diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index a23ddd61a..0f32fd2af 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -757,13 +757,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{ embedding: true, - - // TODO (jmorganca): this should be provided by the server via the - // request options and truncated here in the runner, instead of relying on - // the server's truncate logic - truncate: true, + truncate: false, }) if err != nil { + if errors.Is(err, errorInputTooLong) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) return } @@ -806,7 +806,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { embedding := <-seq.embedding if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ - Embedding: embedding, + Embedding: embedding, + PromptEvalCount: seq.numPromptInputs, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 153390868..d0427662c 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -146,12 +146,12 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe params.numKeep = min(params.numKeep, s.cache.numCtx-1) if int32(len(inputs)) > s.cache.numCtx { - discard := int32(len(inputs)) - s.cache.numCtx - if !params.truncate { return nil, errorInputTooLong } + discard := int32(len(inputs)) - s.cache.numCtx + promptStart := params.numKeep + discard // If we need to truncate in the middle of a unbreakable batch, remove the entire batch @@ -996,13 +996,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{ embedding: true, - - // TODO (jmorganca): this should be provided by the server via the - // request options and truncated here in the runner, instead of relying on - // the server's truncate logic - truncate: true, + truncate: false, }) if err != nil { + if errors.Is(err, errorInputTooLong) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError) return } @@ -1043,7 +1043,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { } if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ - Embedding: <-seq.embedding, + Embedding: <-seq.embedding, + PromptEvalCount: seq.numPromptInputs, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } diff --git a/server/routes.go b/server/routes.go index e5e6dd5af..4dd870ed0 100644 --- a/server/routes.go +++ b/server/routes.go @@ -22,6 +22,7 @@ import ( "os/signal" "slices" "strings" + "sync/atomic" "syscall" "time" @@ -649,11 +650,6 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - truncate := true - if req.Truncate != nil && !*req.Truncate { - truncate = false - } - var input []string switch i := req.Input.(type) { @@ -701,55 +697,57 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - var count int - for i, s := range input { - tokens, err := r.Tokenize(c.Request.Context(), s) + ctx := c.Request.Context() + + embedWithRetry := func(text string) ([]float32, int, error) { + emb, tokCount, err := r.Embedding(ctx, text) + if err == nil { + return emb, tokCount, nil + } + + var serr api.StatusError + if !errors.As(err, &serr) || serr.StatusCode != http.StatusBadRequest { + return nil, 0, err + } + if req.Truncate != nil && !*req.Truncate { + return nil, 0, err + } + + tokens, err := r.Tokenize(ctx, text) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + return nil, 0, err } + // TODO @nicolepardal: avoid reaching into kvData here; pass required tokenizer metadata via model/options instead ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) - if len(tokens) > ctxLen { - if !truncate { - c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"}) - return - } - - if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) { - ctxLen-- - } - - if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) { - ctxLen-- - } - - slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens)) - if ctxLen <= 0 { - // return error if the truncated input would be empty or just special tokens - c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"}) - return - } - - tokens = tokens[:ctxLen] - - s, err = r.Detokenize(c.Request.Context(), tokens) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); len(tokens) > 0 && tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) { + ctxLen-- + } + if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); len(tokens) > 0 && tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) { + ctxLen-- } - count += len(tokens) + if len(tokens) <= ctxLen { + return nil, 0, fmt.Errorf("input exceeds maximum context length and cannot be truncated further") + } + if ctxLen <= 0 { + return nil, 0, fmt.Errorf("input after truncation exceeds maximum context length") + } - input[i] = s + truncatedTokens := tokens[:ctxLen] + truncated, err := r.Detokenize(ctx, truncatedTokens) + if err != nil { + return nil, 0, err + } + return r.Embedding(ctx, truncated) } var g errgroup.Group embeddings := make([][]float32, len(input)) + var totalTokens uint64 for i, text := range input { g.Go(func() error { - embedding, err := r.Embedding(c.Request.Context(), text) + embedding, tokenCount, err := embedWithRetry(text) if err != nil { return err } @@ -759,12 +757,23 @@ func (s *Server) EmbedHandler(c *gin.Context) { embedding = normalize(embedding[:req.Dimensions]) } embeddings[i] = embedding + atomic.AddUint64(&totalTokens, uint64(tokenCount)) return nil }) } if err := g.Wait(); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())}) + var serr api.StatusError + if errors.As(err, &serr) { + c.AbortWithStatusJSON(serr.StatusCode, gin.H{ + "error": strings.TrimSpace(serr.ErrorMessage), + }) + return + } + + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{ + "error": strings.TrimSpace(err.Error()), + }) return } @@ -773,7 +782,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { Embeddings: embeddings, TotalDuration: time.Since(checkpointStart), LoadDuration: checkpointLoaded.Sub(checkpointStart), - PromptEvalCount: count, + PromptEvalCount: int(totalTokens), } c.JSON(http.StatusOK, resp) } @@ -819,7 +828,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - embedding, err := r.Embedding(c.Request.Context(), req.Prompt) + embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())}) return diff --git a/server/sched_test.go b/server/sched_test.go index 678be954f..480aafa4e 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn return s.completionResp } -func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) { - return s.embeddingResp, s.embeddingRespErr +func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, int, error) { + return s.embeddingResp, 0, s.embeddingRespErr } func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {