mirror of https://github.com/ollama/ollama
embeddings: modified batch size (#13429)
This PR detects embedding models and sets batch_size = context_size so the full input fits in a single batch. Previously, if the batch size was smaller than the input, tokens could be split across batches and cause a SIGTRAP crash. This change keeps all tokens in one batch and prevents the crash.

Fixes: #12938 #13054
Co-authored-by: Jesse Gross <jesse@ollama.com>
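For context, a minimal client-side sketch of the scenario this commit addresses is shown below. It is not part of the commit: it assumes the Go client in the ollama api package and a locally pulled embedding model (the name "nomic-embed-text" is only an example), and mirrors the new integration test by sending an input large enough to exceed a typical default batch size.

package main

import (
	"context"
	"fmt"
	"log"
	"strings"
	"time"

	"github.com/ollama/ollama/api"
)

func main() {
	// Build an input well beyond a typical default batch size,
	// similar to the new TestEmbedLargeInput integration test.
	input := strings.TrimSpace(strings.Repeat("word ", 800))

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	res, err := client.Embed(ctx, &api.EmbedRequest{
		Model: "nomic-embed-text", // example model name; any local embedding model
		Input: input,
	})
	if err != nil {
		// Before this change, oversized inputs could crash the runner instead of returning here.
		log.Fatal(err)
	}

	fmt.Printf("embeddings: %d, tokens evaluated: %d\n", len(res.Embeddings), res.PromptEvalCount)
}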
This commit is contained in:
parent 48e78e9be1
commit 3475d915cb
@@ -487,6 +487,63 @@ func TestEmbedTruncation(t *testing.T) {
 	}
 }
 
+// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
+func TestEmbedLargeInput(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	for _, model := range libraryEmbedModels {
+		model := model
+		t.Run(model, func(t *testing.T) {
+			mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
+			defer mcancel()
+
+			// Test with progressively larger inputs
+			testCases := []struct {
+				name       string
+				inputWords int
+			}{
+				{"medium_input_256_words", 256},
+				{"large_input_512_words", 512},
+				{"very_large_input_800_words", 800},
+			}
+
+			for _, tc := range testCases {
+				t.Run(tc.name, func(t *testing.T) {
+					words := make([]string, tc.inputWords)
+					for i := range words {
+						words[i] = "word"
+					}
+					input := strings.Join(words, " ")
+
+					req := api.EmbedRequest{
+						Model:     model,
+						Input:     input,
+						KeepAlive: &api.Duration{Duration: 30 * time.Second},
+					}
+
+					res, err := embedTestHelper(mctx, client, t, req)
+					if err != nil {
+						t.Fatalf("embedding failed for %d words: %v", tc.inputWords, err)
+					}
+
+					if len(res.Embeddings) != 1 {
+						t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
+					}
+
+					if len(res.Embeddings[0]) == 0 {
+						t.Fatal("expected non-empty embedding")
+					}
+
+					t.Logf("Successfully embedded %d words (%d tokens)", tc.inputWords, res.PromptEvalCount)
+				})
+			}
+		})
+	}
+}
+
 // TestEmbedStatusCode tests that errors from the embedding endpoint
 // properly preserve their HTTP status codes when returned to the client.
 // This test specifically checks the error handling path in EmbedHandler
@@ -121,7 +121,8 @@ type ContextParams struct {
 func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
 	params := C.llama_context_default_params()
 	params.n_ctx = C.uint(numCtx)
-	params.n_batch = C.uint(batchSize)
+	params.n_batch = C.uint(batchSize * numSeqMax)
+	params.n_ubatch = C.uint(batchSize)
 	params.n_seq_max = C.uint(numSeqMax)
 	params.n_threads = C.int(threads)
 	params.n_threads_batch = params.n_threads
@@ -474,6 +474,13 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
 		s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
 	}
 
+	// Check if embedding model and adjust batch size accordingly
+	_, isEmbedding := s.ggml.KV()[fmt.Sprintf("%s.pooling_type", s.ggml.KV().Architecture())]
+	if isEmbedding && s.loadRequest.BatchSize < s.options.NumCtx {
+		s.loadRequest.BatchSize = s.options.NumCtx
+		slog.Info("embedding model detected, setting batch size to context length", "batch_size", s.loadRequest.BatchSize)
+	}
+
 	kv, graphPartialOffload, graphFullOffload := s.ggml.GraphSize(uint64(s.options.NumCtx), uint64(s.loadRequest.BatchSize),
 		s.loadRequest.Parallel, s.loadRequest.KvCacheType, s.loadRequest.FlashAttention)
@@ -842,7 +842,7 @@ func (s *Server) loadModel(
 		panic(err)
 	}
 
-	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
+	ctxParams := llama.NewContextParams(kvSize, s.batchSize, s.parallel, threads, flashAttention, kvCacheType)
 	s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
 	if err != nil {
 		panic(err)
@@ -1203,16 +1203,22 @@ func (s *Server) allocModel(
 		return errors.New("loras are not yet implemented")
 	}
 
+	if s.model.Config().Cache == nil {
+		if parallel > 1 {
+			parallel = 1
+			slog.Warn("model does not support caching, disabling parallel processing")
+		}
+		if s.batchSize < kvSize {
+			s.batchSize = kvSize
+			slog.Warn("model does not support caching, setting batch size to context length", "batch_size", kvSize)
+		}
+	}
+
 	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
 	if err != nil {
 		return err
 	}
 
-	if !s.cache.enabled && parallel > 1 {
-		parallel = 1
-		slog.Warn("model does not support caching, disabling parallel processing")
-	}
-
 	s.parallel = parallel
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))