truncation: fixed runner truncation logic + removed server truncation (#12839)

This PR consolidates all embedding prompt-length checking, truncation, and prompt token counting into the runner to ensure a single source of truth.
Authored by nicole pardal on 2025-12-08 11:20:28 -08:00, committed by GitHub
commit e082d60a24 (parent 5dae738067)
6 changed files with 278 additions and 88 deletions
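From the client's perspective, the consolidated path behaves as follows. This is a minimal sketch using the Go API client; the model name, num_ctx value, and local-server assumption mirror the integration tests below and are illustrative only:

package main

import (
    "context"
    "errors"
    "fmt"
    "log"
    "strings"

    "github.com/ollama/ollama/api"
)

func main() {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        log.Fatal(err)
    }

    truncate := false
    resp, err := client.Embed(context.Background(), &api.EmbedRequest{
        Model:    "all-minilm",
        Input:    strings.Repeat("word ", 100),
        Truncate: &truncate,
        Options:  map[string]any{"num_ctx": 10},
    })

    var serr api.StatusError
    if errors.As(err, &serr) {
        // with this PR, the runner's 400 survives to the client
        // instead of being flattened into a 500
        fmt.Println(serr.StatusCode, serr.ErrorMessage)
        return
    }
    if err != nil {
        log.Fatal(err)
    }

    // on success, PromptEvalCount now comes from the runner's own accounting
    fmt.Println(len(resp.Embeddings), resp.PromptEvalCount)
}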

Changed file: integration embed tests

@@ -4,7 +4,9 @@ package integration
 import (
     "context"
+    "errors"
     "math"
+    "strings"
     "testing"
     "time"
@@ -204,8 +206,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
         t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embeddings[0][0:5], sim)
     }
-    if res.PromptEvalCount != 6 {
-        t.Fatalf("expected 6 prompt tokens, got %d", res.PromptEvalCount)
+    if res.PromptEvalCount != 8 {
+        t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount)
     }
 }
@@ -251,8 +253,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
         t.Fatalf("expected %v, got %v (similarity: %f)", expected[1][0:5], res.Embeddings[1][0:5], sim)
     }
-    if res.PromptEvalCount != 12 {
-        t.Fatalf("expected 12 prompt tokens, got %d", res.PromptEvalCount)
+    if res.PromptEvalCount != 16 {
+        t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount)
     }
 }
@@ -275,7 +277,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
     cases := []struct {
         name    string
         request api.EmbedRequest
-        check   func(*api.EmbedResponse, error)
+        check   func(*testing.T, *api.EmbedResponse, error)
     }{
         {
             name: "target truncation",
@@ -283,7 +285,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
                 Model: "all-minilm",
                 Input: "why",
             },
-            check: func(got *api.EmbedResponse, err error) {
+            check: func(t *testing.T, got *api.EmbedResponse, err error) {
                 if err != nil {
                     t.Fatal(err)
                 }
@@ -300,10 +302,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
                 Input:   "why is the sky blue?",
                 Options: map[string]any{"num_ctx": 3},
             },
-            check: func(got *api.EmbedResponse, err error) {
+            check: func(t *testing.T, got *api.EmbedResponse, err error) {
                 if err != nil {
                     t.Fatal(err)
                 }
+                t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
                 if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
                     t.Errorf("embedding mismatch (-want +got):\n%s", diff)
                 }
@@ -317,10 +320,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
                 Truncate: &truncTrue,
                 Options:  map[string]any{"num_ctx": 3},
             },
-            check: func(got *api.EmbedResponse, err error) {
+            check: func(t *testing.T, got *api.EmbedResponse, err error) {
                 if err != nil {
                     t.Fatal(err)
                 }
+                t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
                 if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
                     t.Errorf("embedding mismatch (-want +got):\n%s", diff)
                 }
@@ -334,21 +338,21 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
                 Truncate: &truncFalse,
                 Options:  map[string]any{"num_ctx": 3},
             },
-            check: func(res *api.EmbedResponse, err error) {
-                if err.Error() != "input exceeds maximum context length" {
+            check: func(t *testing.T, res *api.EmbedResponse, err error) {
+                if err.Error() != "the input length exceeds the context length" {
                     t.Fatalf("expected truncation error, got: %v", err)
                 }
             },
         },
         {
-            name: "input after truncate error",
+            name: "input after truncate error with context length of 1",
             request: api.EmbedRequest{
                 Model:    "all-minilm",
                 Input:    "why is the sky blue?",
                 Truncate: &truncTrue,
                 Options:  map[string]any{"num_ctx": 1},
             },
-            check: func(res *api.EmbedResponse, err error) {
+            check: func(t *testing.T, res *api.EmbedResponse, err error) {
                 if err.Error() != "input after truncation exceeds maximum context length" {
                     t.Fatalf("expected truncation error, got: %v", err)
                 }
@@ -362,7 +366,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
                 Truncate: &truncTrue,
                 Options:  map[string]any{"num_ctx": 0},
             },
-            check: func(res *api.EmbedResponse, err error) {
+            check: func(t *testing.T, res *api.EmbedResponse, err error) {
                 if err.Error() != "input after truncation exceeds maximum context length" {
                     t.Fatalf("expected truncation error, got: %v", err)
                 }
@@ -375,7 +379,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
                 Input:   "why is the sky blue? Why is the sky blue? hi there my",
                 Options: map[string]any{"num_ctx": 16},
             },
-            check: func(res *api.EmbedResponse, err error) {
+            check: func(t *testing.T, res *api.EmbedResponse, err error) {
                 if err != nil {
                     t.Fatal(err)
                 }
@@ -385,7 +389,8 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
     for _, req := range cases {
         t.Run(req.name, func(t *testing.T) {
-            req.check(embedTestHelper(ctx, client, t, req.request))
+            resp, err := embedTestHelper(ctx, client, t, req.request)
+            req.check(t, resp, err)
         })
     }
 }
@@ -409,3 +414,173 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
     return client.Embed(ctx, &req)
 }
+
+func TestEmbedTruncation(t *testing.T) {
+    // Use test deadline if set, otherwise default to 2 minutes
+    timeout := 2 * time.Minute
+    if deadline, ok := t.Deadline(); ok {
+        timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
+    }
+    ctx, cancel := context.WithTimeout(context.Background(), timeout)
+    defer cancel()
+
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+
+    for _, model := range libraryEmbedModels {
+        model := model
+        t.Run(model, func(t *testing.T) {
+            // Check if we're running out of time (reserve 20s for current model)
+            if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
+                t.Skip("skipping remaining tests to avoid timeout")
+            }
+
+            // Give each model its own budget to account for first-time pulls/loads
+            mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
+            defer mcancel()
+
+            t.Run("truncation batch", func(t *testing.T) {
+                truncTrue := true
+                req := api.EmbedRequest{
+                    Model:    model,
+                    Input:    []string{"short", strings.Repeat("long ", 100), "medium text"},
+                    Truncate: &truncTrue,
+                    Options:  map[string]any{"num_ctx": 30},
+                }
+                res, err := embedTestHelper(mctx, client, t, req)
+                if err != nil {
+                    t.Fatal(err)
+                }
+                if len(res.Embeddings) != 3 {
+                    t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
+                }
+                if res.PromptEvalCount > 90 {
+                    t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
+                }
+            })
+
+            t.Run("runner token count accuracy", func(t *testing.T) {
+                baseline := api.EmbedRequest{Model: model, Input: "test"}
+                baseRes, err := embedTestHelper(mctx, client, t, baseline)
+                if err != nil {
+                    t.Fatal(err)
+                }
+
+                batch := api.EmbedRequest{
+                    Model: model,
+                    Input: []string{"test", "test", "test"},
+                }
+                batchRes, err := embedTestHelper(mctx, client, t, batch)
+                if err != nil {
+                    t.Fatal(err)
+                }
+
+                expectedCount := baseRes.PromptEvalCount * 3
+                if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
+                    t.Fatalf("expected ~%d tokens (3 × %d), got %d",
+                        expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
+                }
+            })
+        })
+    }
+}
+
+// TestEmbedStatusCode tests that errors from the embedding endpoint
+// properly preserve their HTTP status codes when returned to the client.
+// This test specifically checks the error handling path in EmbedHandler
+// where api.StatusError errors should maintain their original status code.
+func TestEmbedStatusCode(t *testing.T) {
+    // Use test deadline if set, otherwise default to 2 minutes
+    timeout := 2 * time.Minute
+    if deadline, ok := t.Deadline(); ok {
+        timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
+    }
+    ctx, cancel := context.WithTimeout(context.Background(), timeout)
+    defer cancel()
+
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+
+    for _, model := range libraryEmbedModels {
+        model := model
+        t.Run(model, func(t *testing.T) {
+            // Check if we're running out of time (reserve 20s for current model)
+            if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
+                t.Skip("skipping remaining tests to avoid timeout")
+            }
+
+            mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
+            defer mcancel()
+
+            // Pull the model if needed
+            if err := PullIfMissing(mctx, client, model); err != nil {
+                t.Fatal(err)
+            }
+
+            t.Run("truncation error status code", func(t *testing.T) {
+                truncFalse := false
+                longInput := strings.Repeat("word ", 100)
+                req := api.EmbedRequest{
+                    Model:    model,
+                    Input:    longInput,
+                    Truncate: &truncFalse,
+                    Options:  map[string]any{"num_ctx": 10},
+                }
+                _, err := embedTestHelper(mctx, client, t, req)
+                if err == nil {
+                    t.Fatal("expected error when truncate=false with long input")
+                }
+
+                // Check that it's a StatusError with the correct status code
+                var statusErr api.StatusError
+                if !errors.As(err, &statusErr) {
+                    t.Fatalf("expected api.StatusError, got %T: %v", err, err)
+                }
+
+                // The error should be a 4xx client error (likely 400 Bad Request)
+                // not a 500 Internal Server Error
+                if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
+                    t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
+                }
+
+                // Verify the error message is meaningful
+                if !strings.Contains(err.Error(), "context length") {
+                    t.Errorf("expected error message to mention context length, got: %v", err)
+                }
+            })
+
+            t.Run("batch truncation error status code", func(t *testing.T) {
+                truncFalse := false
+                req := api.EmbedRequest{
+                    Model: model,
+                    Input: []string{
+                        "short input",
+                        strings.Repeat("very long input ", 100),
+                        "another short input",
+                    },
+                    Truncate: &truncFalse,
+                    Options:  map[string]any{"num_ctx": 10},
+                }
+                _, err := embedTestHelper(mctx, client, t, req)
+                if err == nil {
+                    t.Fatal("expected error when one input exceeds context with truncate=false")
+                }
+
+                // Check that it's a StatusError with the correct status code
+                var statusErr api.StatusError
+                if !errors.As(err, &statusErr) {
+                    t.Fatalf("expected api.StatusError, got %T: %v", err, err)
+                }
+
+                // The error should be a 4xx client error, not a 500 Internal Server Error
+                if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
+                    t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
+                }
+            })
+        })
+    }
+}

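A note on the updated token expectations above: PromptEvalCount is now reported by the runner itself, which counts every input it evaluates, special tokens included. The +2-per-input shift (6 → 8 for one input, 12 → 16 for two) is consistent with all-minilm adding BOS and EOS around each input — an inference from the diff, not something the PR states:

package main

import "fmt"

func main() {
    const specialPerInput = 2 // assumed: BOS + EOS per input

    fmt.Println(6 + 1*specialPerInput)  // 8: one input, previously counted as 6
    fmt.Println(12 + 2*specialPerInput) // 16: two inputs, previously counted as 12
}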
Changed file: llm server client (LlamaServer interface and llmServer.Embedding)

@@ -69,7 +69,7 @@ type LlamaServer interface {
     Ping(ctx context.Context) error
     WaitUntilRunning(ctx context.Context) error
     Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-    Embedding(ctx context.Context, input string) ([]float32, error)
+    Embedding(ctx context.Context, input string) ([]float32, int, error)
     Tokenize(ctx context.Context, content string) ([]int, error)
     Detokenize(ctx context.Context, tokens []int) (string, error)
     Close() error
@@ -1629,10 +1629,11 @@ type EmbeddingRequest struct {
 }

 type EmbeddingResponse struct {
     Embedding []float32 `json:"embedding"`
+    PromptEvalCount int `json:"prompt_eval_count"`
 }

-func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
+func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, int, error) {
     logutil.Trace("embedding request", "input", input)

     if err := s.sem.Acquire(ctx, 1); err != nil {
@@ -1641,51 +1642,54 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
         } else {
             slog.Error("Failed to acquire semaphore", "error", err)
         }
-        return nil, err
+        return nil, 0, err
     }
     defer s.sem.Release(1)

     // Make sure the server is ready
     status, err := s.getServerStatusRetry(ctx)
     if err != nil {
-        return nil, err
+        return nil, 0, err
     } else if status != ServerStatusReady {
-        return nil, fmt.Errorf("unexpected server status: %s", status)
+        return nil, 0, fmt.Errorf("unexpected server status: %s", status)
     }

     data, err := json.Marshal(EmbeddingRequest{Content: input})
     if err != nil {
-        return nil, fmt.Errorf("error marshaling embed data: %w", err)
+        return nil, 0, fmt.Errorf("error marshaling embed data: %w", err)
     }

     r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
     if err != nil {
-        return nil, fmt.Errorf("error creating embed request: %w", err)
+        return nil, 0, fmt.Errorf("error creating embed request: %w", err)
     }
     r.Header.Set("Content-Type", "application/json")

     resp, err := http.DefaultClient.Do(r)
     if err != nil {
-        return nil, fmt.Errorf("do embedding request: %w", err)
+        return nil, 0, fmt.Errorf("do embedding request: %w", err)
     }
     defer resp.Body.Close()

     body, err := io.ReadAll(resp.Body)
     if err != nil {
-        return nil, fmt.Errorf("error reading embed response: %w", err)
+        return nil, 0, fmt.Errorf("error reading embed response: %w", err)
     }

     if resp.StatusCode >= 400 {
         log.Printf("llm embedding error: %s", body)
-        return nil, fmt.Errorf("%s", body)
+        return nil, 0, api.StatusError{
+            StatusCode:   resp.StatusCode,
+            ErrorMessage: string(body),
+        }
     }

     var e EmbeddingResponse
     if err := json.Unmarshal(body, &e); err != nil {
-        return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
+        return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err)
     }

-    return e.Embedding, nil
+    return e.Embedding, e.PromptEvalCount, nil
 }

 func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {

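Downstream callers pick up the extra return value. Below is a sketch of how a caller might consume the new three-value Embedding signature while keeping the runner's HTTP status intact; embedAll is a hypothetical helper, not part of the PR, and the imports assume the ollama module layout:

package server // illustrative placement only

import (
    "context"
    "errors"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/llm"
)

// embedAll sums the runner-reported token counts the way EmbedHandler
// now does for PromptEvalCount.
func embedAll(ctx context.Context, s llm.LlamaServer, inputs []string) ([][]float32, int, error) {
    out := make([][]float32, len(inputs))
    total := 0
    for i, text := range inputs {
        emb, n, err := s.Embedding(ctx, text)
        if err != nil {
            var serr api.StatusError
            if errors.As(err, &serr) {
                // preserve the runner's HTTP status (e.g. 400 for over-long input)
                return nil, 0, serr
            }
            return nil, 0, err
        }
        out[i] = emb
        total += n
    }
    return out, total, nil
}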
Changed file: llama runner (embeddings handler)

@@ -757,13 +757,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
     seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
         embedding: true,
-        // TODO (jmorganca): this should be provided by the server via the
-        // request options and truncated here in the runner, instead of relying on
-        // the server's truncate logic
-        truncate: true,
+        truncate: false,
     })
     if err != nil {
+        if errors.Is(err, errorInputTooLong) {
+            http.Error(w, err.Error(), http.StatusBadRequest)
+            return
+        }
         http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
         return
     }
@@ -806,7 +806,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
     embedding := <-seq.embedding

     if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
-        Embedding: embedding,
+        Embedding:       embedding,
+        PromptEvalCount: seq.numPromptInputs,
     }); err != nil {
         http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
     }

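Both runners now map the sentinel error onto a 400 so the server can tell user error from runner failure. A minimal sketch of the pattern, assuming errorInputTooLong carries the message the integration test asserts on ("the input length exceeds the context length" — inferred, not stated in the diff); writeSeqError is a hypothetical helper:

package runner // illustrative placement only

import (
    "errors"
    "fmt"
    "net/http"
)

// message inferred from the integration test's expected error string
var errorInputTooLong = errors.New("the input length exceeds the context length")

// writeSeqError shows the mapping both runners apply: sentinel error -> 400
// (client's fault), anything else -> 500 (runner's fault).
func writeSeqError(w http.ResponseWriter, err error) {
    if errors.Is(err, errorInputTooLong) {
        http.Error(w, err.Error(), http.StatusBadRequest)
        return
    }
    http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
}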
Changed file: ollama runner (NewSequence truncation and embeddings handler)

@@ -146,12 +146,12 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
     params.numKeep = min(params.numKeep, s.cache.numCtx-1)

     if int32(len(inputs)) > s.cache.numCtx {
+        discard := int32(len(inputs)) - s.cache.numCtx
         if !params.truncate {
             return nil, errorInputTooLong
         }

-        discard := int32(len(inputs)) - s.cache.numCtx
         promptStart := params.numKeep + discard

     // If we need to truncate in the middle of a unbreakable batch, remove the entire batch
@@ -996,13 +996,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
     w.Header().Set("Content-Type", "application/json")

     seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
         embedding: true,
-        // TODO (jmorganca): this should be provided by the server via the
-        // request options and truncated here in the runner, instead of relying on
-        // the server's truncate logic
-        truncate: true,
+        truncate: false,
     })
     if err != nil {
+        if errors.Is(err, errorInputTooLong) {
+            http.Error(w, err.Error(), http.StatusBadRequest)
+            return
+        }
         http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
         return
     }
@@ -1043,7 +1043,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
     }

     if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
-        Embedding: <-seq.embedding,
+        Embedding:       <-seq.embedding,
+        PromptEvalCount: seq.numPromptInputs,
     }); err != nil {
         http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
     }

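For reference, the truncation NewSequence applies keeps the first numKeep inputs and drops `discard` inputs after them. A simplified sketch of that arithmetic, ignoring the unbreakable-batch adjustment the real code performs; truncateMiddle is a hypothetical stand-in:

package main

import "fmt"

// truncateMiddle mirrors the arithmetic in NewSequence: keep the first
// numKeep inputs, drop `discard` inputs after them, keep the rest.
func truncateMiddle(inputs []int, numCtx, numKeep int) []int {
    if len(inputs) <= numCtx {
        return inputs
    }
    discard := len(inputs) - numCtx
    promptStart := numKeep + discard

    out := append([]int{}, inputs[:numKeep]...)
    return append(out, inputs[promptStart:]...)
}

func main() {
    inputs := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
    // numCtx=8, numKeep=1: discard the 2 tokens after the first, keep the tail
    fmt.Println(truncateMiddle(inputs, 8, 1)) // [0 3 4 5 6 7 8 9]
}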
Changed file: server routes (EmbedHandler / EmbeddingsHandler)

@@ -22,6 +22,7 @@ import (
     "os/signal"
     "slices"
     "strings"
+    "sync/atomic"
     "syscall"
     "time"
@@ -649,11 +650,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
         return
     }

-    truncate := true
-    if req.Truncate != nil && !*req.Truncate {
-        truncate = false
-    }
-
     var input []string

     switch i := req.Input.(type) {
@@ -701,55 +697,57 @@ func (s *Server) EmbedHandler(c *gin.Context) {
         return
     }

-    var count int
-    for i, s := range input {
-        tokens, err := r.Tokenize(c.Request.Context(), s)
-        if err != nil {
-            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-            return
-        }
-
-        ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
-        if len(tokens) > ctxLen {
-            if !truncate {
-                c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
-                return
-            }
-            if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
-                ctxLen--
-            }
-            if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
-                ctxLen--
-            }
-            slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
-            if ctxLen <= 0 {
-                // return error if the truncated input would be empty or just special tokens
-                c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
-                return
-            }
-            tokens = tokens[:ctxLen]
-            s, err = r.Detokenize(c.Request.Context(), tokens)
-            if err != nil {
-                c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-                return
-            }
-        }
-        count += len(tokens)
-        input[i] = s
-    }
+    ctx := c.Request.Context()
+
+    embedWithRetry := func(text string) ([]float32, int, error) {
+        emb, tokCount, err := r.Embedding(ctx, text)
+        if err == nil {
+            return emb, tokCount, nil
+        }
+        var serr api.StatusError
+        if !errors.As(err, &serr) || serr.StatusCode != http.StatusBadRequest {
+            return nil, 0, err
+        }
+        if req.Truncate != nil && !*req.Truncate {
+            return nil, 0, err
+        }
+        tokens, err := r.Tokenize(ctx, text)
+        if err != nil {
+            return nil, 0, err
+        }
+        // TODO @nicolepardal: avoid reaching into kvData here; pass required tokenizer metadata via model/options instead
+        ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
+        if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); len(tokens) > 0 && tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
+            ctxLen--
+        }
+        if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); len(tokens) > 0 && tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
+            ctxLen--
+        }
+        if len(tokens) <= ctxLen {
+            return nil, 0, fmt.Errorf("input exceeds maximum context length and cannot be truncated further")
+        }
+        if ctxLen <= 0 {
+            return nil, 0, fmt.Errorf("input after truncation exceeds maximum context length")
+        }
+        truncatedTokens := tokens[:ctxLen]
+        truncated, err := r.Detokenize(ctx, truncatedTokens)
+        if err != nil {
+            return nil, 0, err
+        }
+        return r.Embedding(ctx, truncated)
+    }

     var g errgroup.Group
     embeddings := make([][]float32, len(input))
+    var totalTokens uint64

     for i, text := range input {
         g.Go(func() error {
-            embedding, err := r.Embedding(c.Request.Context(), text)
+            embedding, tokenCount, err := embedWithRetry(text)
             if err != nil {
                 return err
             }
@@ -759,12 +757,23 @@ func (s *Server) EmbedHandler(c *gin.Context) {
                 embedding = normalize(embedding[:req.Dimensions])
             }
             embeddings[i] = embedding
+            atomic.AddUint64(&totalTokens, uint64(tokenCount))
             return nil
         })
     }

     if err := g.Wait(); err != nil {
-        c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
+        var serr api.StatusError
+        if errors.As(err, &serr) {
+            c.AbortWithStatusJSON(serr.StatusCode, gin.H{
+                "error": strings.TrimSpace(serr.ErrorMessage),
+            })
+            return
+        }
+        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
+            "error": strings.TrimSpace(err.Error()),
+        })
         return
     }
@@ -773,7 +782,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
         Embeddings:      embeddings,
         TotalDuration:   time.Since(checkpointStart),
         LoadDuration:    checkpointLoaded.Sub(checkpointStart),
-        PromptEvalCount: count,
+        PromptEvalCount: int(totalTokens),
     }
     c.JSON(http.StatusOK, resp)
 }
@@ -819,7 +828,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
         return
     }

-    embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
+    embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt)
     if err != nil {
         c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
         return

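The BOS/EOS bookkeeping in embedWithRetry explains the integration-test expectations: reserving slots for special tokens shrinks the usable context before content tokens are kept. A worked sketch under the assumption that the tokenizer adds both BOS and EOS (effectiveCtx and the flags are illustrative):

package main

import "fmt"

// effectiveCtx mirrors the adjustment in embedWithRetry.
func effectiveCtx(numCtx int, addBOS, addEOS bool) int {
    ctxLen := numCtx
    if addBOS {
        ctxLen-- // reserve a slot for BOS
    }
    if addEOS {
        ctxLen-- // reserve a slot for EOS
    }
    return ctxLen
}

func main() {
    // num_ctx: 3 -> one content token survives (the "truncate" test case)
    fmt.Println(effectiveCtx(3, true, true)) // 1
    // num_ctx: 1 -> nothing survives; the handler reports
    // "input after truncation exceeds maximum context length"
    fmt.Println(effectiveCtx(1, true, true)) // -1
}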
Changed file: scheduler test mock

@@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
     return s.completionResp
 }

-func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
-    return s.embeddingResp, s.embeddingRespErr
+func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, int, error) {
+    return s.embeddingResp, 0, s.embeddingRespErr
 }

 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {