mirror of https://github.com/ollama/ollama
truncation: fixed runner truncation logic + removed server truncation (#12839)
This PR consolidates all embedding prompt-length checking, truncation, and prompt token counting into the runner to ensure a single source of truth.
This commit is contained in:
parent
5dae738067
commit
e082d60a24
|
|
@ -4,7 +4,9 @@ package integration
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -204,8 +206,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
|||
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embeddings[0][0:5], sim)
|
||||
}
|
||||
|
||||
if res.PromptEvalCount != 6 {
|
||||
t.Fatalf("expected 6 prompt tokens, got %d", res.PromptEvalCount)
|
||||
if res.PromptEvalCount != 8 {
|
||||
t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -251,8 +253,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
|||
t.Fatalf("expected %v, got %v (similarity: %f)", expected[1][0:5], res.Embeddings[1][0:5], sim)
|
||||
}
|
||||
|
||||
if res.PromptEvalCount != 12 {
|
||||
t.Fatalf("expected 12 prompt tokens, got %d", res.PromptEvalCount)
|
||||
if res.PromptEvalCount != 16 {
|
||||
t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -275,7 +277,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||
cases := []struct {
|
||||
name string
|
||||
request api.EmbedRequest
|
||||
check func(*api.EmbedResponse, error)
|
||||
check func(*testing.T, *api.EmbedResponse, error)
|
||||
}{
|
||||
{
|
||||
name: "target truncation",
|
||||
|
|
@ -283,7 +285,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||
Model: "all-minilm",
|
||||
Input: "why",
|
||||
},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
check: func(t *testing.T, got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -300,10 +302,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||
Input: "why is the sky blue?",
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
check: func(t *testing.T, got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
@ -317,10 +320,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
check: func(t *testing.T, got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
@ -334,21 +338,21 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input exceeds maximum context length" {
|
||||
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "the input length exceeds the context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "input after truncate error",
|
||||
name: "input after truncate error with context length of 1",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
|
|
@ -362,7 +366,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 0},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
|
|
@ -375,7 +379,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||
Input: "why is the sky blue? Why is the sky blue? hi there my",
|
||||
Options: map[string]any{"num_ctx": 16},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -385,7 +389,8 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||
|
||||
for _, req := range cases {
|
||||
t.Run(req.name, func(t *testing.T) {
|
||||
req.check(embedTestHelper(ctx, client, t, req.request))
|
||||
resp, err := embedTestHelper(ctx, client, t, req.request)
|
||||
req.check(t, resp, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -409,3 +414,173 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
|
|||
|
||||
return client.Embed(ctx, &req)
|
||||
}
|
||||
|
||||
func TestEmbedTruncation(t *testing.T) {
|
||||
// Use test deadline if set, otherwise default to 2 minutes
|
||||
timeout := 2 * time.Minute
|
||||
if deadline, ok := t.Deadline(); ok {
|
||||
timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, model := range libraryEmbedModels {
|
||||
model := model
|
||||
t.Run(model, func(t *testing.T) {
|
||||
// Check if we're running out of time (reserve 20s for current model)
|
||||
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
|
||||
t.Skip("skipping remaining tests to avoid timeout")
|
||||
}
|
||||
|
||||
// Give each model its own budget to account for first-time pulls/loads
|
||||
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
|
||||
defer mcancel()
|
||||
|
||||
t.Run("truncation batch", func(t *testing.T) {
|
||||
truncTrue := true
|
||||
req := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 30},
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(mctx, client, t, req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 3 {
|
||||
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
|
||||
}
|
||||
|
||||
if res.PromptEvalCount > 90 {
|
||||
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("runner token count accuracy", func(t *testing.T) {
|
||||
baseline := api.EmbedRequest{Model: model, Input: "test"}
|
||||
baseRes, err := embedTestHelper(mctx, client, t, baseline)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
batch := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: []string{"test", "test", "test"},
|
||||
}
|
||||
batchRes, err := embedTestHelper(mctx, client, t, batch)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedCount := baseRes.PromptEvalCount * 3
|
||||
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
|
||||
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
|
||||
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbedStatusCode tests that errors from the embedding endpoint
|
||||
// properly preserve their HTTP status codes when returned to the client.
|
||||
// This test specifically checks the error handling path in EmbedHandler
|
||||
// where api.StatusError errors should maintain their original status code.
|
||||
func TestEmbedStatusCode(t *testing.T) {
|
||||
// Use test deadline if set, otherwise default to 2 minutes
|
||||
timeout := 2 * time.Minute
|
||||
if deadline, ok := t.Deadline(); ok {
|
||||
timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, model := range libraryEmbedModels {
|
||||
model := model
|
||||
t.Run(model, func(t *testing.T) {
|
||||
// Check if we're running out of time (reserve 20s for current model)
|
||||
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
|
||||
t.Skip("skipping remaining tests to avoid timeout")
|
||||
}
|
||||
|
||||
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
|
||||
defer mcancel()
|
||||
|
||||
// Pull the model if needed
|
||||
if err := PullIfMissing(mctx, client, model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("truncation error status code", func(t *testing.T) {
|
||||
truncFalse := false
|
||||
longInput := strings.Repeat("word ", 100)
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: longInput,
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 10},
|
||||
}
|
||||
|
||||
_, err := embedTestHelper(mctx, client, t, req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when truncate=false with long input")
|
||||
}
|
||||
|
||||
// Check that it's a StatusError with the correct status code
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
|
||||
}
|
||||
|
||||
// The error should be a 4xx client error (likely 400 Bad Request)
|
||||
// not a 500 Internal Server Error
|
||||
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
|
||||
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
|
||||
}
|
||||
|
||||
// Verify the error message is meaningful
|
||||
if !strings.Contains(err.Error(), "context length") {
|
||||
t.Errorf("expected error message to mention context length, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("batch truncation error status code", func(t *testing.T) {
|
||||
truncFalse := false
|
||||
req := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: []string{
|
||||
"short input",
|
||||
strings.Repeat("very long input ", 100),
|
||||
"another short input",
|
||||
},
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 10},
|
||||
}
|
||||
|
||||
_, err := embedTestHelper(mctx, client, t, req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when one input exceeds context with truncate=false")
|
||||
}
|
||||
|
||||
// Check that it's a StatusError with the correct status code
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
|
||||
}
|
||||
|
||||
// The error should be a 4xx client error, not a 500 Internal Server Error
|
||||
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
|
||||
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ type LlamaServer interface {
|
|||
Ping(ctx context.Context) error
|
||||
WaitUntilRunning(ctx context.Context) error
|
||||
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||
Embedding(ctx context.Context, input string) ([]float32, error)
|
||||
Embedding(ctx context.Context, input string) ([]float32, int, error)
|
||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||
Close() error
|
||||
|
|
@ -1630,9 +1630,10 @@ type EmbeddingRequest struct {
|
|||
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
|
||||
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, int, error) {
|
||||
logutil.Trace("embedding request", "input", input)
|
||||
|
||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||
|
|
@ -1641,51 +1642,54 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
|
|||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
}
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
defer s.sem.Release(1)
|
||||
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatusRetry(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
} else if status != ServerStatusReady {
|
||||
return nil, fmt.Errorf("unexpected server status: %s", status)
|
||||
return nil, 0, fmt.Errorf("unexpected server status: %s", status)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(EmbeddingRequest{Content: input})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
||||
return nil, 0, fmt.Errorf("error marshaling embed data: %w", err)
|
||||
}
|
||||
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating embed request: %w", err)
|
||||
return nil, 0, fmt.Errorf("error creating embed request: %w", err)
|
||||
}
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do embedding request: %w", err)
|
||||
return nil, 0, fmt.Errorf("do embedding request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading embed response: %w", err)
|
||||
return nil, 0, fmt.Errorf("error reading embed response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm embedding error: %s", body)
|
||||
return nil, fmt.Errorf("%s", body)
|
||||
return nil, 0, api.StatusError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ErrorMessage: string(body),
|
||||
}
|
||||
}
|
||||
|
||||
var e EmbeddingResponse
|
||||
if err := json.Unmarshal(body, &e); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
||||
return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err)
|
||||
}
|
||||
|
||||
return e.Embedding, nil
|
||||
return e.Embedding, e.PromptEvalCount, nil
|
||||
}
|
||||
|
||||
func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
|
|
|
|||
|
|
@ -757,13 +757,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
|
||||
embedding: true,
|
||||
|
||||
// TODO (jmorganca): this should be provided by the server via the
|
||||
// request options and truncated here in the runner, instead of relying on
|
||||
// the server's truncate logic
|
||||
truncate: true,
|
||||
truncate: false,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, errorInputTooLong) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -807,6 +807,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
|
||||
Embedding: embedding,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -146,12 +146,12 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||
|
||||
if int32(len(inputs)) > s.cache.numCtx {
|
||||
discard := int32(len(inputs)) - s.cache.numCtx
|
||||
|
||||
if !params.truncate {
|
||||
return nil, errorInputTooLong
|
||||
}
|
||||
|
||||
discard := int32(len(inputs)) - s.cache.numCtx
|
||||
|
||||
promptStart := params.numKeep + discard
|
||||
|
||||
// If we need to truncate in the middle of a unbreakable batch, remove the entire batch
|
||||
|
|
@ -996,13 +996,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
|||
w.Header().Set("Content-Type", "application/json")
|
||||
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
|
||||
embedding: true,
|
||||
|
||||
// TODO (jmorganca): this should be provided by the server via the
|
||||
// request options and truncated here in the runner, instead of relying on
|
||||
// the server's truncate logic
|
||||
truncate: true,
|
||||
truncate: false,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, errorInputTooLong) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -1044,6 +1044,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
|
||||
Embedding: <-seq.embedding,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"os/signal"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
|
|
@ -649,11 +650,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
truncate := true
|
||||
if req.Truncate != nil && !*req.Truncate {
|
||||
truncate = false
|
||||
}
|
||||
|
||||
var input []string
|
||||
|
||||
switch i := req.Input.(type) {
|
||||
|
|
@ -701,55 +697,57 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
var count int
|
||||
for i, s := range input {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
ctx := c.Request.Context()
|
||||
|
||||
embedWithRetry := func(text string) ([]float32, int, error) {
|
||||
emb, tokCount, err := r.Embedding(ctx, text)
|
||||
if err == nil {
|
||||
return emb, tokCount, nil
|
||||
}
|
||||
|
||||
var serr api.StatusError
|
||||
if !errors.As(err, &serr) || serr.StatusCode != http.StatusBadRequest {
|
||||
return nil, 0, err
|
||||
}
|
||||
if req.Truncate != nil && !*req.Truncate {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
tokens, err := r.Tokenize(ctx, text)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// TODO @nicolepardal: avoid reaching into kvData here; pass required tokenizer metadata via model/options instead
|
||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||
if len(tokens) > ctxLen {
|
||||
if !truncate {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
|
||||
return
|
||||
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); len(tokens) > 0 && tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
|
||||
ctxLen--
|
||||
}
|
||||
|
||||
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
|
||||
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); len(tokens) > 0 && tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
|
||||
ctxLen--
|
||||
}
|
||||
|
||||
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
|
||||
ctxLen--
|
||||
if len(tokens) <= ctxLen {
|
||||
return nil, 0, fmt.Errorf("input exceeds maximum context length and cannot be truncated further")
|
||||
}
|
||||
|
||||
slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
|
||||
if ctxLen <= 0 {
|
||||
// return error if the truncated input would be empty or just special tokens
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
|
||||
return
|
||||
return nil, 0, fmt.Errorf("input after truncation exceeds maximum context length")
|
||||
}
|
||||
|
||||
tokens = tokens[:ctxLen]
|
||||
|
||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||
truncatedTokens := tokens[:ctxLen]
|
||||
truncated, err := r.Detokenize(ctx, truncatedTokens)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
count += len(tokens)
|
||||
|
||||
input[i] = s
|
||||
return r.Embedding(ctx, truncated)
|
||||
}
|
||||
|
||||
var g errgroup.Group
|
||||
embeddings := make([][]float32, len(input))
|
||||
var totalTokens uint64
|
||||
for i, text := range input {
|
||||
g.Go(func() error {
|
||||
embedding, err := r.Embedding(c.Request.Context(), text)
|
||||
embedding, tokenCount, err := embedWithRetry(text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -759,12 +757,23 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||
embedding = normalize(embedding[:req.Dimensions])
|
||||
}
|
||||
embeddings[i] = embedding
|
||||
atomic.AddUint64(&totalTokens, uint64(tokenCount))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
|
||||
var serr api.StatusError
|
||||
if errors.As(err, &serr) {
|
||||
c.AbortWithStatusJSON(serr.StatusCode, gin.H{
|
||||
"error": strings.TrimSpace(serr.ErrorMessage),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
|
||||
"error": strings.TrimSpace(err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -773,7 +782,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||
Embeddings: embeddings,
|
||||
TotalDuration: time.Since(checkpointStart),
|
||||
LoadDuration: checkpointLoaded.Sub(checkpointStart),
|
||||
PromptEvalCount: count,
|
||||
PromptEvalCount: int(totalTokens),
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
|
@ -819,7 +828,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
|
||||
embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
|
||||
return
|
||||
|
|
|
|||
|
|
@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
|
|||
return s.completionResp
|
||||
}
|
||||
|
||||
func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
|
||||
return s.embeddingResp, s.embeddingRespErr
|
||||
func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, int, error) {
|
||||
return s.embeddingResp, 0, s.embeddingRespErr
|
||||
}
|
||||
|
||||
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue