From 3475d915cb0882042041d7746e6baf888469c3e0 Mon Sep 17 00:00:00 2001
From: nicole pardal <109545900+npardal@users.noreply.github.com>
Date: Thu, 11 Dec 2025 15:36:31 -0800
Subject: [PATCH] embeddings: modified batch size (#13429)

This PR detects embedding models and sets batch_size = context_size so
the full input fits in a single batch. Previously, if batch size was
smaller than the input, tokens could be split across batches and cause a
SIGTRAP crash. This change ensures all tokens stay in one batch and
prevents crashes.

Fixes: #12938 #13054

Co-authored-by: Jesse Gross
---
 integration/embed_test.go     | 57 +++++++++++++++++++++++++++++++++++
 llama/llama.go                |  3 +-
 llm/server.go                 |  7 +++++
 runner/llamarunner/runner.go  |  2 +-
 runner/ollamarunner/runner.go | 16 +++++++---
 5 files changed, 78 insertions(+), 7 deletions(-)

diff --git a/integration/embed_test.go b/integration/embed_test.go
index f01903ee5..e45066739 100644
--- a/integration/embed_test.go
+++ b/integration/embed_test.go
@@ -487,6 +487,63 @@ func TestEmbedTruncation(t *testing.T) {
 	}
 }
 
+// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
+func TestEmbedLargeInput(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	for _, model := range libraryEmbedModels {
+		model := model
+		t.Run(model, func(t *testing.T) {
+			mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
+			defer mcancel()
+
+			// Test with progressively larger inputs
+			testCases := []struct {
+				name       string
+				inputWords int
+			}{
+				{"medium_input_256_words", 256},
+				{"large_input_512_words", 512},
+				{"very_large_input_800_words", 800},
+			}
+
+			for _, tc := range testCases {
+				t.Run(tc.name, func(t *testing.T) {
+					words := make([]string, tc.inputWords)
+					for i := range words {
+						words[i] = "word"
+					}
+					input := strings.Join(words, " ")
+
+					req := api.EmbedRequest{
+						Model:     model,
+						Input:     input,
+						KeepAlive: &api.Duration{Duration: 30 * time.Second},
+					}
+
+					res, err := embedTestHelper(mctx, client, t, req)
+					if err != nil {
+						t.Fatalf("embedding failed for %d words: %v", tc.inputWords, err)
+					}
+
+					if len(res.Embeddings) != 1 {
+						t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
+					}
+
+					if len(res.Embeddings[0]) == 0 {
+						t.Fatal("expected non-empty embedding")
+					}
+
+					t.Logf("Successfully embedded %d words (%d tokens)", tc.inputWords, res.PromptEvalCount)
+				})
+			}
+		})
+	}
+}
+
 // TestEmbedStatusCode tests that errors from the embedding endpoint
 // properly preserve their HTTP status codes when returned to the client.
 // This test specifically checks the error handling path in EmbedHandler
diff --git a/llama/llama.go b/llama/llama.go
index 582d4128c..70bf3b9c3 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -121,7 +121,8 @@ type ContextParams struct {
 func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
 	params := C.llama_context_default_params()
 	params.n_ctx = C.uint(numCtx)
-	params.n_batch = C.uint(batchSize)
+	params.n_batch = C.uint(batchSize * numSeqMax)
+	params.n_ubatch = C.uint(batchSize)
 	params.n_seq_max = C.uint(numSeqMax)
 	params.n_threads = C.int(threads)
 	params.n_threads_batch = params.n_threads
diff --git a/llm/server.go b/llm/server.go
index 1c47601f4..5c232f0fa 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -474,6 +474,13 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
 		s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
 	}
 
+	// Check if embedding model and adjust batch size accordingly
+	_, isEmbedding := s.ggml.KV()[fmt.Sprintf("%s.pooling_type", s.ggml.KV().Architecture())]
+	if isEmbedding && s.loadRequest.BatchSize < s.options.NumCtx {
+		s.loadRequest.BatchSize = s.options.NumCtx
+		slog.Info("embedding model detected, setting batch size to context length", "batch_size", s.loadRequest.BatchSize)
+	}
+
 	kv, graphPartialOffload, graphFullOffload := s.ggml.GraphSize(uint64(s.options.NumCtx), uint64(s.loadRequest.BatchSize),
 		s.loadRequest.Parallel, s.loadRequest.KvCacheType, s.loadRequest.FlashAttention)
 
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index 0f32fd2af..cb4bbe505 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -842,7 +842,7 @@ func (s *Server) loadModel(
 		panic(err)
 	}
-	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
+	ctxParams := llama.NewContextParams(kvSize, s.batchSize, s.parallel, threads, flashAttention, kvCacheType)
 	s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
 	if err != nil {
 		panic(err)
 	}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index d0427662c..a756cba23 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -1203,16 +1203,22 @@ func (s *Server) allocModel(
 		return errors.New("loras are not yet implemented")
 	}
 
+	if s.model.Config().Cache == nil {
+		if parallel > 1 {
+			parallel = 1
+			slog.Warn("model does not support caching, disabling parallel processing")
+		}
+		if s.batchSize < kvSize {
+			s.batchSize = kvSize
+			slog.Warn("model does not support caching, setting batch size to context length", "batch_size", kvSize)
+		}
+	}
+
 	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
 	if err != nil {
 		return err
 	}
 
-	if !s.cache.enabled && parallel > 1 {
-		parallel = 1
-		slog.Warn("model does not support caching, disabling parallel processing")
-	}
-
 	s.parallel = parallel
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
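
Reviewer note (not part of the patch): the sketch below is a minimal way to exercise the scenario this fix targets from outside the integration suite, assuming an Ollama server on localhost:11434 with an embedding model already pulled (the model name is a placeholder). It posts an input sized like the largest case in the integration test above to /api/embed and checks that a single non-empty embedding comes back; the JSON field names are assumed to follow the api.EmbedRequest and api.EmbedResponse shapes used in that test.

// embed_large.go: standalone sanity check that a large embedding input succeeds.
// Assumes an Ollama server is running on localhost:11434 and that the named
// embedding model has already been pulled (the model name is a placeholder).
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"strings"
)

func main() {
	// 800 repeated words, matching the largest case in the integration test above.
	input := strings.TrimSpace(strings.Repeat("word ", 800))

	body, err := json.Marshal(map[string]any{
		"model": "nomic-embed-text", // placeholder embedding model name
		"input": input,
	})
	if err != nil {
		log.Fatal(err)
	}

	resp, err := http.Post("http://localhost:11434/api/embed", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var out struct {
		Embeddings      [][]float64 `json:"embeddings"`
		PromptEvalCount int         `json:"prompt_eval_count"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}
	if len(out.Embeddings) != 1 || len(out.Embeddings[0]) == 0 {
		log.Fatalf("unexpected response: %d embeddings", len(out.Embeddings))
	}
	fmt.Printf("embedded %d prompt tokens into a %d-dim vector\n", out.PromptEvalCount, len(out.Embeddings[0]))
}

With the fix applied, running this prints the prompt token count and embedding dimension; before it, an input of this size could split across batches and trigger the SIGTRAP crash described in the commit message.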