perf: build graph for next batch async to keep GPU busy (#11863)

* perf: build graph for next batch in parallel to keep GPU busy

This refactors the main run loop of the ollama runner to perform the main GPU
intensive tasks (Compute+Floats) in a go routine so we can prepare the next
batch in parallel to reduce the amount of time the GPU stalls waiting for the
next batch of work.

* tests: tune integration tests for ollama engine

This tunes the integration tests to focus more on models supported
by the new engine.
This commit is contained in:
Daniel Hiltgen 2025-08-29 14:20:28 -07:00 committed by GitHub
parent ead4a9a1d0
commit 517807cdf2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 591 additions and 235 deletions

View File

@ -2,10 +2,13 @@
This directory contains integration tests to exercise Ollama end-to-end to verify behavior This directory contains integration tests to exercise Ollama end-to-end to verify behavior
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable to allow scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU.). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`
The integration tests have 2 modes of operating. The integration tests have 2 modes of operating.
1. By default, they will start the server on a random port, run the tests, and then shutdown the server. 1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote 2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable
> [!IMPORTANT]
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.

View File

@ -390,7 +390,7 @@ func TestAPIEmbeddings(t *testing.T) {
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
req := api.EmbeddingRequest{ req := api.EmbeddingRequest{
Model: "orca-mini", Model: libraryEmbedModels[0],
Prompt: "why is the sky blue?", Prompt: "why is the sky blue?",
Options: map[string]interface{}{ Options: map[string]interface{}{
"temperature": 0, "temperature": 0,

View File

@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
) )
func TestBlueSky(t *testing.T) { func TestBlueSky(t *testing.T) {
@ -37,8 +36,8 @@ func TestUnicode(t *testing.T) {
// Set up the test data // Set up the test data
req := api.GenerateRequest{ req := api.GenerateRequest{
// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test // DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
Prompt: "天空为什么是蓝色的?", Prompt: "天空为什么是蓝色的?", // Why is the sky blue?
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]any{
"temperature": 0, "temperature": 0,
@ -50,8 +49,20 @@ func TestUnicode(t *testing.T) {
} }
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model)) if err := PullIfMissing(ctx, client, req.Model); err != nil {
DoGenerate(ctx, t, client, req, []string{"散射", "频率"}, 120*time.Second, 120*time.Second) t.Fatal(err)
}
slog.Info("loading", "model", req.Model)
err := client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
if err != nil {
t.Fatalf("failed to load model %s: %s", req.Model, err)
}
skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
DoGenerate(ctx, t, client, req, []string{
"散射", // scattering
"频率", // frequency
}, 120*time.Second, 120*time.Second)
} }
func TestExtendedUnicodeOutput(t *testing.T) { func TestExtendedUnicodeOutput(t *testing.T) {
@ -69,7 +80,9 @@ func TestExtendedUnicodeOutput(t *testing.T) {
} }
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model)) if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second) DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
} }
@ -84,7 +97,9 @@ func TestUnicodeModelDir(t *testing.T) {
} }
modelDir, err := os.MkdirTemp("", "ollama_埃") modelDir, err := os.MkdirTemp("", "ollama_埃")
require.NoError(t, err) if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(modelDir) defer os.RemoveAll(modelDir)
slog.Info("unicode", "OLLAMA_MODELS", modelDir) slog.Info("unicode", "OLLAMA_MODELS", modelDir)

View File

@ -14,8 +14,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
@ -79,21 +77,21 @@ func TestMultiModelStress(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
// All models compatible with ollama-engine
smallModels := []string{ smallModels := []string{
"llama3.2:1b", "llama3.2:1b",
"qwen3:0.6b", "qwen3:0.6b",
"gemma:2b", "gemma2:2b",
"deepseek-r1:1.5b", "deepseek-r1:1.5b", // qwen2 arch
"starcoder2:3b", "gemma3:270m",
} }
mediumModels := []string{ mediumModels := []string{
"qwen3:8b", "llama3.2:3b", // ~3.4G
"llama2", "qwen3:8b", // ~6.6G
"deepseek-r1:7b", "gpt-oss:20b", // ~15G
"mistral", "deepseek-r1:7b", // ~5.6G
"dolphin-mistral", "gemma3:4b", // ~5.8G
"gemma:7b", "gemma2:9b", // ~8.1G
"codellama:7b",
} }
var chosenModels []string var chosenModels []string
@ -114,7 +112,9 @@ func TestMultiModelStress(t *testing.T) {
// Make sure all the models are pulled before we get started // Make sure all the models are pulled before we get started
for _, model := range chosenModels { for _, model := range chosenModels {
require.NoError(t, PullIfMissing(ctx, client, model)) if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatal(err)
}
} }
// Determine how many models we can load in parallel before we exceed VRAM // Determine how many models we can load in parallel before we exceed VRAM

View File

@ -22,7 +22,7 @@ func TestLongInputContext(t *testing.T) {
defer cancel() defer cancel()
// Set up the test data // Set up the test data
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: "llama2", Model: smol,
Prompt: "Oh, dont speak to me of Austria. Perhaps I dont understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexanders loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I dont believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?", Prompt: "Oh, dont speak to me of Austria. Perhaps I dont understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexanders loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I dont believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]any{
@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err) t.Fatalf("PullIfMissing failed: %v", err)
} }
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia"}, 120*time.Second, 10*time.Second) DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
} }
func TestContextExhaustion(t *testing.T) { func TestContextExhaustion(t *testing.T) {
@ -49,7 +49,7 @@ func TestContextExhaustion(t *testing.T) {
defer cancel() defer cancel()
// Set up the test data // Set up the test data
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: "llama2", Model: smol,
Prompt: "Write me a story with a ton of emojis?", Prompt: "Write me a story with a ton of emojis?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]any{
@ -63,10 +63,10 @@ func TestContextExhaustion(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err) t.Fatalf("PullIfMissing failed: %v", err)
} }
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second) DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second)
} }
// Send multiple requests with prior context and ensure the response is coherant and expected // Send multiple generate requests with prior context and ensure the response is coherant and expected
func TestGenerateWithHistory(t *testing.T) { func TestGenerateWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := GenerateRequests() req, resp := GenerateRequests()
@ -111,5 +111,56 @@ func TestGenerateWithHistory(t *testing.T) {
}(i) }(i)
} }
wg.Wait() wg.Wait()
}
// Send multiple chat requests with prior context and ensure the response is coherant and expected
func TestChatWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := ChatRequests()
numParallel := 2
iterLimit := 2
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Get the server running (if applicable) warm the model up with a single initial empty request
slog.Info("loading", "model", modelOverride)
err := client.Generate(ctx,
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", modelOverride, err)
}
var wg sync.WaitGroup
wg.Add(numParallel)
for i := range numParallel {
go func(i int) {
defer wg.Done()
k := i % len(req)
req[k].Model = modelOverride
for j := 0; j < iterLimit; j++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
return
}
slog.Info("Starting", "thread", i, "iter", j)
// On slower GPUs it can take a while to process the concurrent requests
// so we allow a much longer initial timeout
assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
if assistant == nil {
t.Fatalf("didn't get an assistant response for context")
}
req[k].Messages = append(req[k].Messages,
*assistant,
api.Message{Role: "user", Content: "tell me more!"},
)
}
}(i)
}
wg.Wait()
} }

View File

@ -9,7 +9,6 @@ import (
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
) )
func TestVisionModels(t *testing.T) { func TestVisionModels(t *testing.T) {
@ -32,7 +31,9 @@ func TestVisionModels(t *testing.T) {
for _, v := range testCases { for _, v := range testCases {
t.Run(v.model, func(t *testing.T) { t.Run(v.model, func(t *testing.T) {
image, err := base64.StdEncoding.DecodeString(imageEncoding) image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err) if err != nil {
t.Fatal(err)
}
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: v.model, Model: v.model,
Prompt: "what does the text in this image say?", Prompt: "what does the text in this image say?",
@ -52,7 +53,9 @@ func TestVisionModels(t *testing.T) {
// Note: sometimes it returns "the ollamas" sometimes "the ollams" // Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam" resp := "the ollam"
defer cleanup() defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model)) if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
// llava models on CPU can be quite slow to start // llava models on CPU can be quite slow to start
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second) DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
}) })
@ -62,7 +65,9 @@ func TestVisionModels(t *testing.T) {
func TestIntegrationSplitBatch(t *testing.T) { func TestIntegrationSplitBatch(t *testing.T) {
skipUnderMinVRAM(t, 6) skipUnderMinVRAM(t, 6)
image, err := base64.StdEncoding.DecodeString(imageEncoding) image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err) if err != nil {
t.Fatal(err)
}
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: "gemma3:4b", Model: "gemma3:4b",
// Fill up a chunk of the batch so the image will partially spill over into the next one // Fill up a chunk of the batch so the image will partially spill over into the next one
@ -84,7 +89,9 @@ func TestIntegrationSplitBatch(t *testing.T) {
defer cancel() defer cancel()
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model)) if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
// llava models on CPU can be quite slow to start, // llava models on CPU can be quite slow to start,
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second) DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
} }

View File

@ -1,47 +0,0 @@
//go:build integration
package integration
import (
"context"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
// package to avoid circular dependencies
var (
stream = false
req = [2]api.GenerateRequest{
{
Model: smol,
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: smol,
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
},
}
resp = [2][]string{
{"sunlight", "scattering", "interact"},
{"england", "english", "massachusetts", "pilgrims"},
}
)
func TestIntegrationSimple(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
GenerateTestHelper(ctx, t, req[0], resp[0])
}

View File

@ -13,12 +13,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
func TestMaxQueue(t *testing.T) { func TestMaxQueue(t *testing.T) {
t.Skip("this test needs to be re-evaluated to use a proper embedding model")
if os.Getenv("OLLAMA_TEST_EXISTING") != "" { if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size") t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
return return
@ -45,7 +45,9 @@ func TestMaxQueue(t *testing.T) {
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model)) if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
// Context for the worker threads so we can shut them down // Context for the worker threads so we can shut them down
// embedCtx, embedCancel := context.WithCancel(ctx) // embedCtx, embedCancel := context.WithCancel(ctx)
@ -89,7 +91,9 @@ func TestMaxQueue(t *testing.T) {
switch { switch {
case genErr == nil: case genErr == nil:
successCount++ successCount++
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable if len(resp.Embedding) < 5 { // somewhat arbitrary, but sufficient to be reasonable
t.Fatalf("embeddings shorter than expected: %d", len(resp.Embedding))
}
case errors.Is(genErr, context.Canceled): case errors.Is(genErr, context.Canceled):
canceledCount++ canceledCount++
case strings.Contains(genErr.Error(), "busy"): case strings.Contains(genErr.Error(), "busy"):
@ -97,7 +101,9 @@ func TestMaxQueue(t *testing.T) {
case strings.Contains(genErr.Error(), "connection reset by peer"): case strings.Contains(genErr.Error(), "connection reset by peer"):
resetByPeerCount++ resetByPeerCount++
default: default:
require.NoError(t, genErr, "%d request failed", i) if genErr != nil {
t.Fatalf("%d request failed", i)
}
} }
slog.Info("embed finished", "id", i) slog.Info("embed finished", "id", i)
@ -108,8 +114,13 @@ func TestMaxQueue(t *testing.T) {
embedwg.Wait() embedwg.Wait()
slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount) slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?") if resetByPeerCount != 0 {
require.True(t, busyCount > 0, "no requests hit busy error but some should have") t.Fatalf("Connections reset by peer, have you updated your fd and socket limits? %d", resetByPeerCount)
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout") }
if busyCount == 0 {
t.Fatalf("no requests hit busy error but some should have")
}
if canceledCount > 0 {
t.Fatalf("no requests should have been canceled due to timeout %d", canceledCount)
}
} }

View File

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"math"
"math/rand" "math/rand"
"net" "net"
"net/http" "net/http"
@ -25,11 +26,11 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle" "github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/stretchr/testify/require"
) )
var ( var (
smol = "llama3.2:1b" smol = "llama3.2:1b"
stream = false
) )
var ( var (
@ -435,7 +436,9 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
} }
lifecycle.ServerLogFile = fp.Name() lifecycle.ServerLogFile = fp.Name()
fp.Close() fp.Close()
require.NoError(t, startServer(t, ctx, testEndpoint)) if err := startServer(t, ctx, testEndpoint); err != nil {
t.Fatal(err)
}
} }
return client, testEndpoint, func() { return client, testEndpoint, func() {
@ -468,7 +471,9 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) { func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, genReq.Model)) if err := PullIfMissing(ctx, client, genReq.Model); err != nil {
t.Fatal(err)
}
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second) DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
} }
@ -509,7 +514,9 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr) slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
return context return context
} }
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt) if genErr != nil {
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
}
// Verify the response contains the expected data // Verify the response contains the expected data
response := buf.String() response := buf.String()
atLeastOne := false atLeastOne := false
@ -519,7 +526,9 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
break break
} }
} }
require.True(t, atLeastOne, "%s: none of %v found in %s", genReq.Model, anyResp, response) if !atLeastOne {
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
}
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response) slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
case <-ctx.Done(): case <-ctx.Done():
t.Error("outer test context done while waiting for generate") t.Error("outer test context done while waiting for generate")
@ -561,17 +570,97 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
[][]string{ [][]string{
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"}, {"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"}, {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states"}, {"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
{"fourth", "july", "declaration", "independence"}, {"fourth", "july", "declaration", "independence"},
{"nitrogen", "oxygen", "carbon", "dioxide"}, {"nitrogen", "oxygen", "carbon", "dioxide"},
} }
} }
func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
role := "assistant"
fn := func(response api.ChatResponse) error {
// fmt.Print(".")
role = response.Message.Role
buf.Write([]byte(response.Message.Content))
if !stallTimer.Reset(streamTimeout) {
return errors.New("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
req.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("generate stalled. Response so far:%s", buf.String())
}
case <-done:
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
return nil
}
if genErr != nil {
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
}
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
}
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
return &api.Message{Role: role, Content: buf.String()}
}
func ChatRequests() ([]api.ChatRequest, [][]string) {
genReqs, results := GenerateRequests()
reqs := make([]api.ChatRequest, len(genReqs))
// think := api.ThinkValue{Value: "low"}
for i := range reqs {
reqs[i].Model = genReqs[i].Model
reqs[i].Stream = genReqs[i].Stream
reqs[i].KeepAlive = genReqs[i].KeepAlive
// reqs[i].Think = &think
reqs[i].Messages = []api.Message{
{
Role: "user",
Content: genReqs[i].Prompt,
},
}
}
return reqs, results
}
func skipUnderMinVRAM(t *testing.T, gb uint64) { func skipUnderMinVRAM(t *testing.T, gb uint64) {
// TODO use info API in the future // TODO use info API in the future
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" { if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err := strconv.ParseUint(s, 10, 64) maxVram, err := strconv.ParseUint(s, 10, 64)
require.NoError(t, err) if err != nil {
t.Fatal(err)
}
// Don't hammer on small VRAM cards... // Don't hammer on small VRAM cards...
if maxVram < gb*format.GibiByte { if maxVram < gb*format.GibiByte {
t.Skip("skipping with small VRAM to avoid timeouts") t.Skip("skipping with small VRAM to avoid timeouts")
@ -579,6 +668,39 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
} }
} }
// Skip if the target model isn't X% GPU loaded to avoid excessive runtime
func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
models, err := client.ListRunning(ctx)
if err != nil {
t.Fatalf("failed to list running models: %s", err)
}
loaded := []string{}
for _, m := range models.Models {
loaded = append(loaded, m.Name)
if m.Name != model {
continue
}
gpuPercent := 0
switch {
case m.SizeVRAM == 0:
gpuPercent = 0
case m.SizeVRAM == m.Size:
gpuPercent = 100
case m.SizeVRAM > m.Size || m.Size == 0:
t.Logf("unexpected size detected: %d", m.SizeVRAM)
default:
sizeCPU := m.Size - m.SizeVRAM
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 110)
gpuPercent = int(100 - cpuPercent)
}
if gpuPercent < minPercent {
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
}
return
}
t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded))
}
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) { func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
deadline, hasDeadline := t.Deadline() deadline, hasDeadline := t.Deadline()
if !hasDeadline { if !hasDeadline {

View File

@ -372,6 +372,7 @@ type Context interface {
Forward(...Tensor) Context Forward(...Tensor) Context
Compute(...Tensor) Compute(...Tensor)
ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
// Reserve is analogous to Compute but rather than executing a // Reserve is analogous to Compute but rather than executing a
// graph, simply preallocates memory. Typically called with a // graph, simply preallocates memory. Typically called with a
@ -401,6 +402,8 @@ type Tensor interface {
Bytes() []byte Bytes() []byte
Floats() []float32 Floats() []float32
SetValueFromIntSlice(s []int32)
Neg(ctx Context) Tensor Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor Sub(ctx Context, t2 Tensor) Tensor

View File

@ -82,6 +82,7 @@ type Backend struct {
// to the name that is used by the model definition // to the name that is used by the model definition
tensorLoadTargets map[string][]string tensorLoadTargets map[string][]string
schedMu sync.Mutex // Only one Compute can run at a time
sched C.ggml_backend_sched_t sched C.ggml_backend_sched_t
schedBackends []C.ggml_backend_t schedBackends []C.ggml_backend_t
schedBufts []C.ggml_backend_buffer_type_t schedBufts []C.ggml_backend_buffer_type_t
@ -758,6 +759,15 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
} }
func (c *Context) Compute(tensors ...ml.Tensor) { func (c *Context) Compute(tensors ...ml.Tensor) {
c.ComputeWithNotify(nil, tensors...)
}
func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) {
c.b.schedMu.Lock()
defer c.b.schedMu.Unlock()
if cb != nil {
go cb()
}
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS { if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
panic(fmt.Errorf("error computing ggml graph: %v", status)) panic(fmt.Errorf("error computing ggml graph: %v", status))
} }
@ -1010,6 +1020,12 @@ func (t *Tensor) Floats() (data []float32) {
return return
} }
func (t *Tensor) SetValueFromIntSlice(s []int32) {
if len(s) > 0 {
C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
}
}
func (t *Tensor) DType() ml.DType { func (t *Tensor) DType() ml.DType {
switch t.t._type { switch t.t._type {
case C.GGML_TYPE_F32: case C.GGML_TYPE_F32:

View File

@ -64,7 +64,7 @@ type MultimodalProcessor interface {
// This function is also responsible for updating MultimodalHash for any Multimodal // This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately // that is modified to ensure that there is a unique hash value that accurately
// represents the contents. // represents the contents.
PostTokenize([]input.Input) ([]input.Input, error) PostTokenize([]*input.Input) ([]*input.Input, error)
} }
// Base implements the common fields and methods for all models // Base implements the common fields and methods for all models
@ -278,7 +278,7 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice t.Kind() == reflect.Slice
} }
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) { func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) { if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences)) return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
} }
@ -287,8 +287,6 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
return nil, errors.New("batch size cannot be less than 1") return nil, errors.New("batch size cannot be less than 1")
} }
batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
cache := m.Config().Cache cache := m.Config().Cache
if cache != nil { if cache != nil {
err := cache.StartForward(ctx, batch, false) err := cache.StartForward(ctx, batch, false)
@ -302,7 +300,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
return nil, err return nil, err
} }
ctx.Forward(t).Compute(t) ctx.Forward(t)
return t, nil return t, nil
} }

View File

@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return []input.Multimodal{{Tensor: visionOutputs}}, nil return []input.Multimodal{{Tensor: visionOutputs}}, nil
} }
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []input.Input var result []*input.Input
for _, inp := range inputs { for _, inp := range inputs {
if len(inp.Multimodal) == 0 { if len(inp.Multimodal) == 0 {
@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
inputMultimodal := inp.Multimodal[0].Tensor inputMultimodal := inp.Multimodal[0].Tensor
result = append(result, result = append(result,
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" &input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
input.Input{Token: 255999}, // "<start_of_image>"" &input.Input{Token: 255999}, // "<start_of_image>""
input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder &input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
) )
// add image token placeholders // add image token placeholders
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
result = append(result, result = append(result,
input.Input{Token: 256000}, // <end_of_image> &input.Input{Token: 256000}, // <end_of_image>
input.Input{Token: 108}, // "\n\n" &input.Input{Token: 108}, // "\n\n"
) )
} }
} }

View File

@ -134,16 +134,16 @@ type separator struct {
y bool y bool
} }
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []input.Input var result []*input.Input
for _, inp := range inputs { for _, inp := range inputs {
if len(inp.Multimodal) == 0 { if len(inp.Multimodal) == 0 {
result = append(result, inp) result = append(result, inp)
continue continue
} }
var imageInputs []input.Input var imageInputs []*input.Input
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|> imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|>
for i, mm := range inp.Multimodal { for i, mm := range inp.Multimodal {
patchesPerChunk := mm.Tensor.Dim(1) patchesPerChunk := mm.Tensor.Dim(1)
@ -151,20 +151,20 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
if i < len(inp.Multimodal)-1 { if i < len(inp.Multimodal)-1 {
separator := mm.Data.(*separator) separator := mm.Data.(*separator)
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...) imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
if separator.x { if separator.x {
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|> imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|>
} }
if separator.y { if separator.y {
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|> imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|>
} }
} else { } else {
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|> imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|>
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...) imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|> imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|>
} }
} }

View File

@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END] // [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings // Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
// that can be processed together. // that can be processed together.
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []input.Input var result []*input.Input
for _, inp := range inputs { for _, inp := range inputs {
if len(inp.Multimodal) == 0 { if len(inp.Multimodal) == 0 {
result = append(result, inp) result = append(result, inp)
} else { } else {
for i, row := range inp.Multimodal { for i, row := range inp.Multimodal {
// [IMG] // [IMG]
result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)}) result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...) result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
if i == len(inp.Multimodal)-1 { if i == len(inp.Multimodal)-1 {
// [IMG_END] // [IMG_END]
result = append(result, input.Input{Token: 13}) result = append(result, &input.Input{Token: 13})
} else { } else {
// [IMG_BREAK] // [IMG_BREAK]
result = append(result, input.Input{Token: 12}) result = append(result, &input.Input{Token: 12})
} }
} }
} }

View File

@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return []input.Multimodal{{Tensor: projectedOutputs}}, nil return []input.Multimodal{{Tensor: projectedOutputs}}, nil
} }
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
for i := range inputs { for i := range inputs {
if inputs[i].Multimodal != nil { if inputs[i].Multimodal != nil {
inputs[i].Token = 128256 // <|image|> inputs[i].Token = 128256 // <|image|>

View File

@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
} }
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass // PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []input.Input var result []*input.Input
var ( var (
imageToken int32 = 151655 imageToken int32 = 151655
@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return nil, fmt.Errorf("failed to encode image prompt: %w", err) return nil, fmt.Errorf("failed to encode image prompt: %w", err)
} }
for i := range pre { for i := range pre {
result = append(result, input.Input{Token: pre[i]}) result = append(result, &input.Input{Token: pre[i]})
} }
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1) patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
// First add the vision start token // First add the vision start token
result = append(result, input.Input{Token: visionStartToken}) result = append(result, &input.Input{Token: visionStartToken})
// Add the image token with the multimodal tensor data at the first position // Add the image token with the multimodal tensor data at the first position
result = append(result, input.Input{ result = append(result, &input.Input{
Token: imageToken, Token: imageToken,
Multimodal: inp.Multimodal, Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash, MultimodalHash: inp.MultimodalHash,
@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
}) })
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1) // Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...) result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
result = append(result, input.Input{Token: visionEndToken}) result = append(result, &input.Input{Token: visionEndToken})
} }
} }

View File

@ -86,7 +86,7 @@ type InputCacheSlot struct {
Id int Id int
// Inputs that are stored in the KV cache // Inputs that are stored in the KV cache
Inputs []input.Input Inputs []*input.Input
// is this cache actively being processed as part of a sequence? // is this cache actively being processed as part of a sequence?
InUse bool InUse bool
@ -95,7 +95,7 @@ type InputCacheSlot struct {
lastUsed time.Time lastUsed time.Time
} }
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) { func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
var slot *InputCacheSlot var slot *InputCacheSlot
var numPast int32 var numPast int32
var err error var err error
@ -146,7 +146,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
return slot, prompt, nil return slot, prompt, nil
} }
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
longest := int32(-1) longest := int32(-1)
var longestSlot *InputCacheSlot var longestSlot *InputCacheSlot
@ -169,7 +169,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot
return longestSlot, longest, nil return longestSlot, longest, nil
} }
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
oldest := time.Now() oldest := time.Now()
var oldestSlot *InputCacheSlot var oldestSlot *InputCacheSlot
@ -205,7 +205,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
if longest > 0 && longestSlot != oldestSlot { if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total", slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs)) len(longestSlot.Inputs))
oldestSlot.Inputs = make([]input.Input, longest) oldestSlot.Inputs = make([]*input.Input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest]) copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil { if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest) c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@ -215,7 +215,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
return oldestSlot, longest, nil return oldestSlot, longest, nil
} }
func countCommonPrefix(a []input.Input, b []input.Input) int32 { func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
var count int32 var count int32
for i := range a { for i := range a {
@ -250,7 +250,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
} }
type ErrReprocessInputs struct { type ErrReprocessInputs struct {
Inputs []input.Input Inputs []*input.Input
} }
func (e *ErrReprocessInputs) Error() string { func (e *ErrReprocessInputs) Error() string {
@ -283,13 +283,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
"id", slot.Id, "error", err) "id", slot.Id, "error", err)
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard) // Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard)) newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep]) copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:]) copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Reset the cache // Reset the cache
_ = c.cache.Remove(slot.Id, 0, math.MaxInt32) _ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
slot.Inputs = []input.Input{} slot.Inputs = []*input.Input{}
// Return error with inputs that need to be reprocessed // Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs} return &ErrReprocessInputs{Inputs: newInputs}

View File

@ -13,50 +13,50 @@ import (
func TestCountCommon(t *testing.T) { func TestCountCommon(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
t1 []input.Input t1 []*input.Input
t2 []input.Input t2 []*input.Input
expected int32 expected int32
}{ }{
{ {
name: "Equal", name: "Equal",
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, t1: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 3, expected: 3,
}, },
{ {
name: "Prefix", name: "Prefix",
t1: []input.Input{{Token: 1}}, t1: []*input.Input{{Token: 1}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 1, expected: 1,
}, },
{ {
name: "Image Prefix", name: "Image Prefix",
t1: []input.Input{{MultimodalHash: 1}}, t1: []*input.Input{{MultimodalHash: 1}},
t2: []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}}, t2: []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
expected: 1, expected: 1,
}, },
{ {
name: "Mixed", name: "Mixed",
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}}, t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}}, t2: []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
expected: 2, expected: 2,
}, },
{ {
name: "Mixed, Same Length", name: "Mixed, Same Length",
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}}, t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {MultimodalHash: 2}}, t2: []*input.Input{{Token: 1}, {MultimodalHash: 2}},
expected: 1, expected: 1,
}, },
{ {
name: "Empty", name: "Empty",
t1: []input.Input{}, t1: []*input.Input{},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 0, expected: 0,
}, },
{ {
name: "Both Empty", name: "Both Empty",
t1: []input.Input{}, t1: []*input.Input{},
t2: []input.Input{}, t2: []*input.Input{},
expected: 0, expected: 0,
}, },
} }
@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
cache InputCache cache InputCache
prompt []input.Input prompt []*input.Input
longest expected longest expected
best expected best expected
}{ }{
@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{}, Inputs: []*input.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{}, Inputs: []*input.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}}, }},
prompt: []input.Input{{Token: 1}}, prompt: []*input.Input{{Token: 1}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0}, best: expected{result: 0, len: 0},
}, },
@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}}, Inputs: []*input.Input{{Token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input.Input{{Token: 1}, {Token: 2}}, prompt: []*input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 2}, longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2}, best: expected{result: 1, len: 2},
}, },
@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{}, Inputs: []*input.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}}, }},
prompt: []input.Input{{Token: 2}}, prompt: []*input.Input{{Token: 2}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0}, best: expected{result: 1, len: 0},
}, },
@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) {
slots: []InputCacheSlot{ slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{}, Inputs: []*input.Input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}, },
}, },
prompt: []input.Input{{Token: 1}}, prompt: []*input.Input{{Token: 1}},
longest: expected{result: 0, len: 1}, longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1}, best: expected{result: 1, len: 1},
}, },
@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}}, Inputs: []*input.Input{{Token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input.Input{{Token: 2}, {Token: 3}}, prompt: []*input.Input{{Token: 2}, {Token: 3}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0}, best: expected{result: 1, len: 0},
}, },
@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: true, InUse: true,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{{Token: 1}}, Inputs: []*input.Input{{Token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input.Input{{Token: 1}, {Token: 2}}, prompt: []*input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 1}, longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2}, best: expected{result: 1, len: 2},
}, },
@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
cache InputCache cache InputCache
prompt []input.Input prompt []*input.Input
wantErr bool wantErr bool
expectedSlotId int expectedSlotId int
expectedPrompt int // expected length of remaining prompt expectedPrompt int // expected length of remaining prompt
@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{ slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{}, Inputs: []*input.Input{},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}, },
}, },
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false, wantErr: false,
expectedSlotId: 0, expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains expectedPrompt: 1, // Only token 3 remains
@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{ slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{}, Inputs: []*input.Input{},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}, },
}, },
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false, wantErr: false,
expectedSlotId: 0, expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains expectedPrompt: 1, // Only token 3 remains
@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{ slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
}, },
}, },
prompt: []input.Input{{Token: 1}, {Token: 2}}, prompt: []*input.Input{{Token: 1}, {Token: 2}},
wantErr: false, wantErr: false,
expectedSlotId: 0, expectedSlotId: 0,
expectedPrompt: 1, // Should leave 1 token for sampling expectedPrompt: 1, // Should leave 1 token for sampling
@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{ slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: true, InUse: true,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
}, },
}, },
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: true, wantErr: true,
expectedSlotId: -1, expectedSlotId: -1,
expectedPrompt: -1, expectedPrompt: -1,
@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
numCtx int32 numCtx int32
inputs []input.Input inputs []*input.Input
numKeep int32 numKeep int32
cacheErr bool cacheErr bool
wantErr any wantErr any
@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) {
{ {
name: "Normal shift", name: "Normal shift",
numCtx: 10, numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2, numKeep: 2,
cacheErr: false, // No error cacheErr: false, // No error
wantErr: nil, wantErr: nil,
@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) {
{ {
name: "Cache removal fails", name: "Cache removal fails",
numCtx: 10, numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2, numKeep: 2,
cacheErr: true, cacheErr: true,
wantErr: &ErrReprocessInputs{}, wantErr: &ErrReprocessInputs{},
@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) {
} }
slot := &InputCacheSlot{ slot := &InputCacheSlot{
Id: 123, Id: 123,
Inputs: make([]input.Input, len(tt.inputs)), Inputs: make([]*input.Input, len(tt.inputs)),
} }
copy(slot.Inputs, tt.inputs) copy(slot.Inputs, tt.inputs)

View File

@ -17,6 +17,7 @@ import (
"reflect" "reflect"
"regexp" "regexp"
"runtime" "runtime"
"runtime/debug"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -51,10 +52,10 @@ type Sequence struct {
iBatch int iBatch int
// prompt inputs left to evaluate // prompt inputs left to evaluate
inputs []input.Input inputs []*input.Input
// inputs that have been added to a batch but not yet submitted to Forward // inputs that have been added to a batch but not yet submitted to Forward
pendingInputs []input.Input pendingInputs []*input.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences) // tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string pendingResponses []string
@ -182,8 +183,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// inputs processes the prompt and images into a list of inputs // inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and // by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images // decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) { func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
var inputs []input.Input var inputs []*input.Input
var ctxs []ml.Context var ctxs []ml.Context
var mmStore multimodalStore var mmStore multimodalStore
@ -210,7 +211,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
} }
for _, t := range tokens { for _, t := range tokens {
inputs = append(inputs, input.Input{Token: t}) inputs = append(inputs, &input.Input{Token: t})
} }
// image - decode and store // image - decode and store
@ -243,7 +244,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
mmStore.addMultimodal(imageEmbeddings) mmStore.addMultimodal(imageEmbeddings)
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash}) inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true postTokenize = true
} }
} }
@ -259,6 +260,37 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
return inputs, ctxs, mmStore, nil return inputs, ctxs, mmStore, nil
} }
type batchState struct {
// id provides a counter for trace logging batches
id int
// ctx holds the backend context used for this batch
ctx ml.Context
// modelOutput holds the outputs from this batch
modelOutput ml.Tensor
// batchInputs holds the input token pointers which may start as
// placeholders later filled in before calling ctx.Compute
batchInputs []*input.Input
// batch contains the inputs for a model forward pass
batch input.Batch
// full set of seqs at the time this batch was initiated
seqs []*Sequence
// Signaled when this batches inputs are ready and compute can proceed
inputsReadyCh chan struct{}
// Signaling when Compute is about to begin on this batch, and
// seqs have been updated to prepare for the next batch
computeStartedCh chan struct{}
// Signaled when this batches outputs are complete and the next batch can proceed
outputsReadyCh chan struct{}
}
type Server struct { type Server struct {
// modelPath is the location of the model to be loaded // modelPath is the location of the model to be loaded
modelPath string modelPath string
@ -290,6 +322,12 @@ type Server struct {
// TODO (jmorganca): make this n_batch // TODO (jmorganca): make this n_batch
batchSize int batchSize int
// Used to signal a hard failure during async processing which will panic the runner
hardErrCh chan error
// Simple counter used only for trace logging batches
batchID int
// protects access to everything below this line // protects access to everything below this line
// this is context state needed for decoding // this is context state needed for decoding
mu sync.Mutex mu sync.Mutex
@ -362,33 +400,66 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
s.seqsSem.Release(1) s.seqsSem.Release(1)
} }
// track batch state between forwardBatch, computeBatch and predictForwardBatch
func (s *Server) run(ctx context.Context) { func (s *Server) run(ctx context.Context) {
s.ready.Wait() s.ready.Wait()
var activeBatch batchState
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case err := <-s.hardErrCh:
panic(err)
default: default:
err := s.processBatch() var err error
activeBatch, err = s.forwardBatch(activeBatch)
if err != nil { if err != nil {
panic(err) panic(err)
} }
go s.computeBatch(activeBatch)
} }
} }
} }
func (s *Server) processBatch() error { // forwardBatch will calculate a batch.
func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, err error) {
// If we have a pending batch still processing, wait until Compute has started
// before setting up the next batch so the seqs inputs are ready to receive their
// token values and we get the correct input pointers for the batchInputs
if pendingBatch.ctx != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
<-pendingBatch.computeStartedCh
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
// No pendingBatch, so the inputs will be ready in the seqs immediately
nextBatch.inputsReadyCh = make(chan struct{}, 1)
nextBatch.inputsReadyCh <- struct{}{}
}
s.mu.Lock() s.mu.Lock()
for s.allNil() { for s.allNil() {
s.cond.Wait() // Wait until an item is added s.cond.Wait() // Wait until an item is added
} }
defer s.mu.Unlock() defer s.mu.Unlock()
ctx := s.model.Backend().NewContext() nextBatch.ctx = s.model.Backend().NewContext()
defer ctx.Close() defer func() {
if err != nil {
nextBatch.ctx.Close()
nextBatch.ctx = nil
}
}()
nextBatch.id = s.batchID
nextBatch.seqs = append([]*Sequence{}, s.seqs...)
nextBatch.computeStartedCh = make(chan struct{}, 1)
nextBatch.outputsReadyCh = make(chan struct{}, 1)
var batchInputs []int32 // Prepare the seqs and batch, but defer the input token values as we may not be ready yet
var batchInputs []*input.Input
var batch input.Batch var batch input.Batch
resumeSeq := -1 resumeSeq := -1
@ -396,7 +467,6 @@ func (s *Server) processBatch() error {
for range s.seqs { for range s.seqs {
seqIdx = (seqIdx + 1) % len(s.seqs) seqIdx = (seqIdx + 1) % len(s.seqs)
seq := s.seqs[seqIdx] seq := s.seqs[seqIdx]
if seq == nil { if seq == nil {
continue continue
} }
@ -404,12 +474,13 @@ func (s *Server) processBatch() error {
// if past the num predict limit // if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength) s.removeSequence(seqIdx, llm.DoneReasonLength)
nextBatch.seqs[seqIdx] = nil
continue continue
} }
if !s.cache.enabled { if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...) seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []input.Input{} seq.cache.Inputs = []*input.Input{}
} }
batchSize := s.batchSize batchSize := s.batchSize
@ -442,25 +513,28 @@ func (s *Server) processBatch() error {
break break
} }
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil { if err != nil {
var reprocess *ErrReprocessInputs var reprocess *ErrReprocessInputs
if errors.As(err, &reprocess) { if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing // Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...) seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Skip this sequence but continue processing the rest // Skip this sequence but continue processing the rest
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
err = nil
continue continue
} else { } else {
return err return
} }
} }
} }
batchInputs = append(batchInputs, inp.Token) batchInputs = append(batchInputs, seq.inputs[i])
if inp.Multimodal != nil { if inp.Multimodal != nil {
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false) var mm []input.Multimodal
mm, err = seq.mmStore.getMultimodal(s.model.Backend(), nextBatch.ctx, inp.Multimodal, false)
if err != nil { if err != nil {
return err return
} }
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm}) batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
} }
@ -472,6 +546,7 @@ func (s *Server) processBatch() error {
if i+1 == len(seq.inputs) { if i+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1)) batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
} }
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
seq.pendingInputs = append(seq.pendingInputs, inp) seq.pendingInputs = append(seq.pendingInputs, inp)
} }
@ -485,36 +560,129 @@ func (s *Server) processBatch() error {
} }
if len(batchInputs) == 0 { if len(batchInputs) == 0 {
return nil slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
nextBatch.ctx.Close()
nextBatch.ctx = nil
return
} }
s.batchID++
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch) // Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
if err != nil { if err != nil {
return fmt.Errorf("failed to decode batch: %w", err) err = fmt.Errorf("failed to build graph: %w", err)
return
}
nextBatch.batchInputs = batchInputs
nextBatch.batch = batch
return
}
// Async processing of the next batch
func (s *Server) computeBatch(activeBatch batchState) {
if activeBatch.ctx == nil {
// Nothing to compute
return
}
defer activeBatch.ctx.Close()
// Wait until inputs are ready
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
<-activeBatch.inputsReadyCh
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", activeBatch.id)
// Once we complete, signal the next batch of inputs are ready
// This will unblock the next computeBatch, or forwardBatch if new seqs come in
defer func() {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", activeBatch.id)
activeBatch.outputsReadyCh <- struct{}{}
}()
s.mu.Lock()
// Gather the actual input token values now that they're ready
batchInputs := make([]int32, len(activeBatch.batchInputs))
for i := range batchInputs {
batchInputs[i] = activeBatch.batchInputs[i].Token
} }
logits := modelOutput.Floats() // Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
// so that forwardBatch can build a batchInputs set which will eventually contain the actual
// decoded tokens.
nextBatchTokens := make([]*input.Input, len(s.seqs))
iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
for i, seq := range s.seqs { for i, seq := range s.seqs {
iBatches[i] = -1
if seq == nil { if seq == nil {
continue continue
} }
// Skip over any newly added or skipped sequences
if activeBatch.seqs[i] == nil {
continue
}
// After calling Forward, pending inputs are now in the cache // Detect if the sequence we're processing has already been completed and replaced
// with a new sequence
if seq != activeBatch.seqs[i] {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
// Pending inputs will actually be in the cache after we call Compute.
// However, we have already resolved any placeholder tokens.
//
// It's possible for incoming sequences to look at the values that we've
// added to the cache here and start relying on them before we've done
// the computation. This is OK as long as we ensure that this batch's
// computation happens before any future batch's and we never fail
// (unless we take down the whole runner).
if len(seq.pendingInputs) > 0 { if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []input.Input{} seq.pendingInputs = []*input.Input{}
} }
// don't sample prompt processing // don't sample prompt processing
if len(seq.inputs) != 0 { if len(seq.inputs) != 0 {
if !s.cache.enabled { if !s.cache.enabled {
return errors.New("caching disabled but unable to fit entire input in a batch") s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
s.mu.Unlock()
return
} }
continue continue
} }
seq.numPredicted++ seq.numPredicted++
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
seq.inputs = []*input.Input{nextToken}
nextBatchTokens[i] = nextToken
iBatches[i] = seq.iBatch
}
// At this point the seqs are ready for forwardBatch to move forward so unblock
s.mu.Unlock()
activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
activeBatch.ctx.ComputeWithNotify(
func() {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
activeBatch.computeStartedCh <- struct{}{}
},
activeBatch.modelOutput)
logits := activeBatch.modelOutput.Floats()
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", activeBatch.id)
s.mu.Lock()
defer s.mu.Unlock()
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", activeBatch.id)
for i, seq := range s.seqs {
if seq == nil || nextBatchTokens[i] == nil {
continue
}
if seq.numPredicted == 1 { if seq.numPredicted == 1 {
seq.startGenerationTime = time.Now() seq.startGenerationTime = time.Now()
} }
@ -522,36 +690,38 @@ func (s *Server) processBatch() error {
// if done processing the prompt, generate an embedding and return // if done processing the prompt, generate an embedding and return
if seq.embeddingOnly { if seq.embeddingOnly {
// TODO(jessegross): Embedding support // TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported") slog.Warn("generation of embedding outputs not yet supported", "id", activeBatch.id, "seqIdx", i)
s.removeSequence(i, llm.DoneReasonStop) s.removeSequence(i, llm.DoneReasonStop)
continue continue
} }
// sample a token // sample a token
vocabSize := len(logits) / len(batch.Outputs) vocabSize := len(logits) / len(activeBatch.batch.Outputs)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(logits), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize]) token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil { if err != nil {
return fmt.Errorf("failed to sample token: %w", err) s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
return
} }
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break // if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) { if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
// TODO (jmorganca): we should send this back // TODO (jmorganca): we should send this back
// as it's important for the /api/generate context // as it's important for the /api/generate context
// seq.responses <- piece // seq.responses <- piece
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
s.removeSequence(i, llm.DoneReasonStop) s.removeSequence(i, llm.DoneReasonStop)
continue continue
} }
piece, err := s.model.(model.TextProcessor).Decode([]int32{token}) piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
if err != nil { if err != nil {
return err s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
return
} }
seq.inputs = []input.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece) seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "") sequence := strings.Join(seq.pendingResponses, "")
@ -575,6 +745,7 @@ func (s *Server) processBatch() error {
if tokenTruncated || origLen == newLen { if tokenTruncated || origLen == newLen {
tokenLen-- tokenLen--
} }
seq.cache.Inputs = seq.cache.Inputs[:tokenLen] seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
s.removeSequence(i, llm.DoneReasonStop) s.removeSequence(i, llm.DoneReasonStop)
@ -593,8 +764,6 @@ func (s *Server) processBatch() error {
s.removeSequence(i, llm.DoneReasonConnectionClosed) s.removeSequence(i, llm.DoneReasonConnectionClosed)
} }
} }
return nil
} }
func (s *Server) completion(w http.ResponseWriter, r *http.Request) { func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
@ -736,7 +905,10 @@ func (s *Server) reserveWorstCaseGraph() error {
defer ctx.Close() defer ctx.Close()
var err error var err error
inputs := make([]input.Input, s.batchSize) inputs := make([]*input.Input, s.batchSize)
for i := range inputs {
inputs[i] = &input.Input{}
}
mmStore := newMultimodalStore() mmStore := newMultimodalStore()
// Multimodal strategy: // Multimodal strategy:
@ -778,8 +950,11 @@ func (s *Server) reserveWorstCaseGraph() error {
} }
if len(inputs) < s.batchSize { if len(inputs) < s.batchSize {
newInputs := make([]input.Input, s.batchSize) newInputs := make([]*input.Input, s.batchSize)
copy(newInputs, inputs) copy(newInputs, inputs)
for i := len(inputs); i < s.batchSize; i++ {
newInputs[i] = &input.Input{}
}
inputs = newInputs inputs = newInputs
} }
} }
@ -842,6 +1017,7 @@ func (s *Server) allocModel(
// Convert memory allocation panics to errors // Convert memory allocation panics to errors
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
debug.PrintStack()
if err, ok := r.(error); ok { if err, ok := r.(error); ok {
panicErr = err panicErr = err
} else { } else {
@ -1011,6 +1187,7 @@ func Execute(args []string) error {
server := &Server{ server := &Server{
modelPath: *mpath, modelPath: *mpath,
status: llm.ServerStatusLaunched, status: llm.ServerStatusLaunched,
hardErrCh: make(chan error, 1),
} }
server.cond = sync.NewCond(&server.mu) server.cond = sync.NewCond(&server.mu)