embeddings: modified batch size (#13429)

This PR detects embedding models and sets batch_size = context_size so the full input fits in a single batch.
Previously, if batch size was smaller than the input, tokens could be split across batches and cause a SIGTRAP crash.
This change ensures all tokens stay in one batch and prevents crashes.
Fixes: #12938 #13054

Co-authored-by: Jesse Gross <jesse@ollama.com>
This commit is contained in:
nicole pardal 2025-12-11 15:36:31 -08:00 committed by GitHub
parent 48e78e9be1
commit 3475d915cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 78 additions and 7 deletions

View File

@ -487,6 +487,63 @@ func TestEmbedTruncation(t *testing.T) {
}
}
// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
func TestEmbedLargeInput(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryEmbedModels {
model := model
t.Run(model, func(t *testing.T) {
mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
defer mcancel()
// Test with progressively larger inputs
testCases := []struct {
name string
inputWords int
}{
{"medium_input_256_words", 256},
{"large_input_512_words", 512},
{"very_large_input_800_words", 800},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
words := make([]string, tc.inputWords)
for i := range words {
words[i] = "word"
}
input := strings.Join(words, " ")
req := api.EmbedRequest{
Model: model,
Input: input,
KeepAlive: &api.Duration{Duration: 30 * time.Second},
}
res, err := embedTestHelper(mctx, client, t, req)
if err != nil {
t.Fatalf("embedding failed for %d words: %v", tc.inputWords, err)
}
if len(res.Embeddings) != 1 {
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) == 0 {
t.Fatal("expected non-empty embedding")
}
t.Logf("Successfully embedded %d words (%d tokens)", tc.inputWords, res.PromptEvalCount)
})
}
})
}
}
// TestEmbedStatusCode tests that errors from the embedding endpoint
// properly preserve their HTTP status codes when returned to the client.
// This test specifically checks the error handling path in EmbedHandler

View File

@ -121,7 +121,8 @@ type ContextParams struct {
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
params := C.llama_context_default_params()
params.n_ctx = C.uint(numCtx)
params.n_batch = C.uint(batchSize)
params.n_batch = C.uint(batchSize * numSeqMax)
params.n_ubatch = C.uint(batchSize)
params.n_seq_max = C.uint(numSeqMax)
params.n_threads = C.int(threads)
params.n_threads_batch = params.n_threads

View File

@ -474,6 +474,13 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
}
// Check if embedding model and adjust batch size accordingly
_, isEmbedding := s.ggml.KV()[fmt.Sprintf("%s.pooling_type", s.ggml.KV().Architecture())]
if isEmbedding && s.loadRequest.BatchSize < s.options.NumCtx {
s.loadRequest.BatchSize = s.options.NumCtx
slog.Info("embedding model detected, setting batch size to context length", "batch_size", s.loadRequest.BatchSize)
}
kv, graphPartialOffload, graphFullOffload := s.ggml.GraphSize(uint64(s.options.NumCtx), uint64(s.loadRequest.BatchSize),
s.loadRequest.Parallel, s.loadRequest.KvCacheType, s.loadRequest.FlashAttention)

View File

@ -842,7 +842,7 @@ func (s *Server) loadModel(
panic(err)
}
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
ctxParams := llama.NewContextParams(kvSize, s.batchSize, s.parallel, threads, flashAttention, kvCacheType)
s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
if err != nil {
panic(err)

View File

@ -1203,16 +1203,22 @@ func (s *Server) allocModel(
return errors.New("loras are not yet implemented")
}
if s.model.Config().Cache == nil {
if parallel > 1 {
parallel = 1
slog.Warn("model does not support caching, disabling parallel processing")
}
if s.batchSize < kvSize {
s.batchSize = kvSize
slog.Warn("model does not support caching, setting batch size to context length", "batch_size", kvSize)
}
}
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
if err != nil {
return err
}
if !s.cache.enabled && parallel > 1 {
parallel = 1
slog.Warn("model does not support caching, disabling parallel processing")
}
s.parallel = parallel
s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))