Mirror of https://github.com/ollama/ollama

perf: build graph for next batch async to keep GPU busy (#11863)

* perf: build graph for next batch in parallel to keep GPU busy

  This refactors the main run loop of the ollama runner so the GPU-intensive work (Compute + Floats) runs in a goroutine, letting the next batch be prepared in parallel and reducing the time the GPU stalls waiting for its next batch of work.

* tests: tune integration tests for ollama engine

  This tunes the integration tests to focus on models supported by the new engine.

parent ead4a9a1d0
commit 517807cdf2
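The runner-loop change the commit message describes is not part of the hunks shown on this page, so the following is a minimal, self-contained Go sketch of the overlap pattern, under stated assumptions: batch, prepare, and computeWithNotify here are illustrative stand-ins, and only the Compute, Floats, and ComputeWithNotify names come from the diff below. The idea is that GPU compute for the current batch runs in a goroutine, and a notify callback fires as soon as the graph has been handed to the backend, so the loop can build the next batch instead of idling.

package main

import (
	"fmt"
	"sync"
	"time"
)

type batch struct{ id int }

// computeWithNotify stands in for ml.Context.ComputeWithNotify below: it runs
// the (simulated) GPU work and calls cb once the graph has been submitted.
func computeWithNotify(b *batch, cb func()) {
	cb()                              // graph handed off to the backend
	time.Sleep(50 * time.Millisecond) // simulated Compute + Floats readback
	fmt.Println("computed batch", b.id)
}

func main() {
	const numBatches = 4
	next := 0
	prepare := func() *batch { // simulated CPU-side batch construction
		if next >= numBatches {
			return nil
		}
		time.Sleep(20 * time.Millisecond)
		next++
		return &batch{id: next}
	}

	var wg sync.WaitGroup
	pending := prepare()
	for pending != nil {
		started := make(chan struct{})
		cur := pending
		wg.Add(1)
		go func() {
			defer wg.Done()
			computeWithNotify(cur, func() { close(started) })
		}()
		<-started           // the GPU is busy with the current batch...
		pending = prepare() // ...so the next batch is built in parallel
		wg.Wait()           // outputs must be consumed before the next submit
	}
}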
@@ -2,10 +2,13 @@
 This directory contains integration tests to exercise Ollama end-to-end to verify behavior
 
-By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`
+By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag: `go test -tags=integration ./...`. Some tests require additional tags to enable scoped testing and keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more, depending on the speed of your GPU). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`
 
 The integration tests have 2 modes of operating.
 
 1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
-2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
+2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable
+
+> [!IMPORTANT]
+> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree with `go build .`, in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
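As a small illustration of the tag combinations the README points at (an assumption about file layout, not content from this commit), a models-scoped integration test file simply carries an ordinary Go build constraint:

//go:build integration && models

package integration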
@@ -390,7 +390,7 @@ func TestAPIEmbeddings(t *testing.T) {
 	client, _, cleanup := InitServerConnection(ctx, t)
 	defer cleanup()
 	req := api.EmbeddingRequest{
-		Model:  "orca-mini",
+		Model:  libraryEmbedModels[0],
 		Prompt: "why is the sky blue?",
 		Options: map[string]interface{}{
 			"temperature": 0,
@@ -11,7 +11,6 @@ import (
 	"time"
 
 	"github.com/ollama/ollama/api"
-	"github.com/stretchr/testify/require"
 )
 
 func TestBlueSky(t *testing.T) {
@@ -37,8 +36,8 @@ func TestUnicode(t *testing.T) {
 	// Set up the test data
 	req := api.GenerateRequest{
 		// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
-		Model:  "deepseek-coder-v2:16b-lite-instruct-q2_K",
-		Prompt: "天空为什么是蓝色的?",
+		Model:  "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
+		Prompt: "天空为什么是蓝色的?", // Why is the sky blue?
 		Stream: &stream,
 		Options: map[string]any{
 			"temperature": 0,
@@ -50,8 +49,20 @@ func TestUnicode(t *testing.T) {
 	}
 	client, _, cleanup := InitServerConnection(ctx, t)
 	defer cleanup()
-	require.NoError(t, PullIfMissing(ctx, client, req.Model))
-	DoGenerate(ctx, t, client, req, []string{"散射", "频率"}, 120*time.Second, 120*time.Second)
+	if err := PullIfMissing(ctx, client, req.Model); err != nil {
+		t.Fatal(err)
+	}
+	slog.Info("loading", "model", req.Model)
+	err := client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
+	if err != nil {
+		t.Fatalf("failed to load model %s: %s", req.Model, err)
+	}
+	skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
+
+	DoGenerate(ctx, t, client, req, []string{
+		"散射", // scattering
+		"频率", // frequency
+	}, 120*time.Second, 120*time.Second)
 }
 
 func TestExtendedUnicodeOutput(t *testing.T) {
@@ -69,7 +80,9 @@ func TestExtendedUnicodeOutput(t *testing.T) {
 	}
 	client, _, cleanup := InitServerConnection(ctx, t)
 	defer cleanup()
-	require.NoError(t, PullIfMissing(ctx, client, req.Model))
+	if err := PullIfMissing(ctx, client, req.Model); err != nil {
+		t.Fatal(err)
+	}
 	DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
 }
 
@@ -84,7 +97,9 @@ func TestUnicodeModelDir(t *testing.T) {
 	}
 
 	modelDir, err := os.MkdirTemp("", "ollama_埃")
-	require.NoError(t, err)
+	if err != nil {
+		t.Fatal(err)
+	}
 	defer os.RemoveAll(modelDir)
 	slog.Info("unicode", "OLLAMA_MODELS", modelDir)
 
@@ -14,8 +14,6 @@ import (
 	"testing"
 	"time"
 
-	"github.com/stretchr/testify/require"
-
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
@@ -79,21 +77,21 @@ func TestMultiModelStress(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	// All models compatible with ollama-engine
 	smallModels := []string{
 		"llama3.2:1b",
 		"qwen3:0.6b",
-		"gemma:2b",
-		"deepseek-r1:1.5b",
-		"starcoder2:3b",
+		"gemma2:2b",
+		"deepseek-r1:1.5b", // qwen2 arch
+		"gemma3:270m",
 	}
 	mediumModels := []string{
-		"qwen3:8b",
-		"llama2",
-		"deepseek-r1:7b",
-		"mistral",
-		"dolphin-mistral",
-		"gemma:7b",
-		"codellama:7b",
+		"llama3.2:3b",    // ~3.4G
+		"qwen3:8b",       // ~6.6G
+		"gpt-oss:20b",    // ~15G
+		"deepseek-r1:7b", // ~5.6G
+		"gemma3:4b",      // ~5.8G
+		"gemma2:9b",      // ~8.1G
 	}
 
 	var chosenModels []string
@@ -114,7 +112,9 @@ func TestMultiModelStress(t *testing.T) {
 
 	// Make sure all the models are pulled before we get started
 	for _, model := range chosenModels {
-		require.NoError(t, PullIfMissing(ctx, client, model))
+		if err := PullIfMissing(ctx, client, model); err != nil {
+			t.Fatal(err)
+		}
 	}
 
 	// Determine how many models we can load in parallel before we exceed VRAM
@ -22,7 +22,7 @@ func TestLongInputContext(t *testing.T) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
// Set up the test data
|
// Set up the test data
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: "llama2",
|
Model: smol,
|
||||||
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
Options: map[string]any{
|
Options: map[string]any{
|
||||||
|
|
@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
t.Fatalf("PullIfMissing failed: %v", err)
|
t.Fatalf("PullIfMissing failed: %v", err)
|
||||||
}
|
}
|
||||||
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia"}, 120*time.Second, 10*time.Second)
|
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextExhaustion(t *testing.T) {
|
func TestContextExhaustion(t *testing.T) {
|
||||||
|
|
@ -49,7 +49,7 @@ func TestContextExhaustion(t *testing.T) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
// Set up the test data
|
// Set up the test data
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: "llama2",
|
Model: smol,
|
||||||
Prompt: "Write me a story with a ton of emojis?",
|
Prompt: "Write me a story with a ton of emojis?",
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
Options: map[string]any{
|
Options: map[string]any{
|
||||||
|
|
@ -63,10 +63,10 @@ func TestContextExhaustion(t *testing.T) {
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
t.Fatalf("PullIfMissing failed: %v", err)
|
t.Fatalf("PullIfMissing failed: %v", err)
|
||||||
}
|
}
|
||||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
|
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send multiple requests with prior context and ensure the response is coherant and expected
|
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
||||||
func TestGenerateWithHistory(t *testing.T) {
|
func TestGenerateWithHistory(t *testing.T) {
|
||||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||||
req, resp := GenerateRequests()
|
req, resp := GenerateRequests()
|
||||||
|
|
@ -111,5 +111,56 @@ func TestGenerateWithHistory(t *testing.T) {
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
||||||
|
func TestChatWithHistory(t *testing.T) {
|
||||||
|
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||||
|
req, resp := ChatRequests()
|
||||||
|
numParallel := 2
|
||||||
|
iterLimit := 2
|
||||||
|
|
||||||
|
softTimeout, hardTimeout := getTimeouts(t)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||||
|
defer cancel()
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Get the server running (if applicable) warm the model up with a single initial empty request
|
||||||
|
slog.Info("loading", "model", modelOverride)
|
||||||
|
err := client.Generate(ctx,
|
||||||
|
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||||
|
func(response api.GenerateResponse) error { return nil },
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load model %s: %s", modelOverride, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numParallel)
|
||||||
|
for i := range numParallel {
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
k := i % len(req)
|
||||||
|
req[k].Model = modelOverride
|
||||||
|
for j := 0; j < iterLimit; j++ {
|
||||||
|
if time.Now().Sub(started) > softTimeout {
|
||||||
|
slog.Info("exceeded soft timeout, winding down test")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
slog.Info("Starting", "thread", i, "iter", j)
|
||||||
|
// On slower GPUs it can take a while to process the concurrent requests
|
||||||
|
// so we allow a much longer initial timeout
|
||||||
|
assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||||
|
if assistant == nil {
|
||||||
|
t.Fatalf("didn't get an assistant response for context")
|
||||||
|
}
|
||||||
|
req[k].Messages = append(req[k].Messages,
|
||||||
|
*assistant,
|
||||||
|
api.Message{Role: "user", Content: "tell me more!"},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestVisionModels(t *testing.T) {
|
func TestVisionModels(t *testing.T) {
|
||||||
|
|
@ -32,7 +31,9 @@ func TestVisionModels(t *testing.T) {
|
||||||
for _, v := range testCases {
|
for _, v := range testCases {
|
||||||
t.Run(v.model, func(t *testing.T) {
|
t.Run(v.model, func(t *testing.T) {
|
||||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: v.model,
|
Model: v.model,
|
||||||
Prompt: "what does the text in this image say?",
|
Prompt: "what does the text in this image say?",
|
||||||
|
|
@ -52,7 +53,9 @@ func TestVisionModels(t *testing.T) {
|
||||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||||
resp := "the ollam"
|
resp := "the ollam"
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
// llava models on CPU can be quite slow to start
|
// llava models on CPU can be quite slow to start
|
||||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||||
})
|
})
|
||||||
|
|
@ -62,7 +65,9 @@ func TestVisionModels(t *testing.T) {
|
||||||
func TestIntegrationSplitBatch(t *testing.T) {
|
func TestIntegrationSplitBatch(t *testing.T) {
|
||||||
skipUnderMinVRAM(t, 6)
|
skipUnderMinVRAM(t, 6)
|
||||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: "gemma3:4b",
|
Model: "gemma3:4b",
|
||||||
// Fill up a chunk of the batch so the image will partially spill over into the next one
|
// Fill up a chunk of the batch so the image will partially spill over into the next one
|
||||||
|
|
@ -84,7 +89,9 @@ func TestIntegrationSplitBatch(t *testing.T) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
// llava models on CPU can be quite slow to start,
|
// llava models on CPU can be quite slow to start,
|
||||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,47 +0,0 @@
|
||||||
//go:build integration
|
|
||||||
|
|
||||||
package integration
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
|
|
||||||
// package to avoid circular dependencies
|
|
||||||
|
|
||||||
var (
|
|
||||||
stream = false
|
|
||||||
req = [2]api.GenerateRequest{
|
|
||||||
{
|
|
||||||
Model: smol,
|
|
||||||
Prompt: "why is the ocean blue?",
|
|
||||||
Stream: &stream,
|
|
||||||
Options: map[string]any{
|
|
||||||
"seed": 42,
|
|
||||||
"temperature": 0.0,
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
Model: smol,
|
|
||||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
|
||||||
Stream: &stream,
|
|
||||||
Options: map[string]any{
|
|
||||||
"seed": 42,
|
|
||||||
"temperature": 0.0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
resp = [2][]string{
|
|
||||||
{"sunlight", "scattering", "interact"},
|
|
||||||
{"england", "english", "massachusetts", "pilgrims"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIntegrationSimple(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
|
||||||
defer cancel()
|
|
||||||
GenerateTestHelper(ctx, t, req[0], resp[0])
|
|
||||||
}
|
|
||||||
|
|
@ -13,12 +13,12 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMaxQueue(t *testing.T) {
|
func TestMaxQueue(t *testing.T) {
|
||||||
|
t.Skip("this test needs to be re-evaluated to use a proper embedding model")
|
||||||
|
|
||||||
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
|
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
|
||||||
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
|
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
|
||||||
return
|
return
|
||||||
|
|
@ -45,7 +45,9 @@ func TestMaxQueue(t *testing.T) {
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
// Context for the worker threads so we can shut them down
|
// Context for the worker threads so we can shut them down
|
||||||
// embedCtx, embedCancel := context.WithCancel(ctx)
|
// embedCtx, embedCancel := context.WithCancel(ctx)
|
||||||
|
|
@ -89,7 +91,9 @@ func TestMaxQueue(t *testing.T) {
|
||||||
switch {
|
switch {
|
||||||
case genErr == nil:
|
case genErr == nil:
|
||||||
successCount++
|
successCount++
|
||||||
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
|
if len(resp.Embedding) < 5 { // somewhat arbitrary, but sufficient to be reasonable
|
||||||
|
t.Fatalf("embeddings shorter than expected: %d", len(resp.Embedding))
|
||||||
|
}
|
||||||
case errors.Is(genErr, context.Canceled):
|
case errors.Is(genErr, context.Canceled):
|
||||||
canceledCount++
|
canceledCount++
|
||||||
case strings.Contains(genErr.Error(), "busy"):
|
case strings.Contains(genErr.Error(), "busy"):
|
||||||
|
|
@ -97,7 +101,9 @@ func TestMaxQueue(t *testing.T) {
|
||||||
case strings.Contains(genErr.Error(), "connection reset by peer"):
|
case strings.Contains(genErr.Error(), "connection reset by peer"):
|
||||||
resetByPeerCount++
|
resetByPeerCount++
|
||||||
default:
|
default:
|
||||||
require.NoError(t, genErr, "%d request failed", i)
|
if genErr != nil {
|
||||||
|
t.Fatalf("%d request failed", i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("embed finished", "id", i)
|
slog.Info("embed finished", "id", i)
|
||||||
|
|
@ -108,8 +114,13 @@ func TestMaxQueue(t *testing.T) {
|
||||||
embedwg.Wait()
|
embedwg.Wait()
|
||||||
|
|
||||||
slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
|
slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
|
||||||
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
|
if resetByPeerCount != 0 {
|
||||||
require.True(t, busyCount > 0, "no requests hit busy error but some should have")
|
t.Fatalf("Connections reset by peer, have you updated your fd and socket limits? %d", resetByPeerCount)
|
||||||
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
|
}
|
||||||
|
if busyCount == 0 {
|
||||||
|
t.Fatalf("no requests hit busy error but some should have")
|
||||||
|
}
|
||||||
|
if canceledCount > 0 {
|
||||||
|
t.Fatalf("no requests should have been canceled due to timeout %d", canceledCount)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"math"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
@ -25,11 +26,11 @@ import (
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/app/lifecycle"
|
"github.com/ollama/ollama/app/lifecycle"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
smol = "llama3.2:1b"
|
smol = "llama3.2:1b"
|
||||||
|
stream = false
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
@ -435,7 +436,9 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||||
}
|
}
|
||||||
lifecycle.ServerLogFile = fp.Name()
|
lifecycle.ServerLogFile = fp.Name()
|
||||||
fp.Close()
|
fp.Close()
|
||||||
require.NoError(t, startServer(t, ctx, testEndpoint))
|
if err := startServer(t, ctx, testEndpoint); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return client, testEndpoint, func() {
|
return client, testEndpoint, func() {
|
||||||
|
|
@ -468,7 +471,9 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||||
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
|
if err := PullIfMissing(ctx, client, genReq.Model); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -509,7 +514,9 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||||
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
|
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
|
||||||
return context
|
return context
|
||||||
}
|
}
|
||||||
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
if genErr != nil {
|
||||||
|
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
|
||||||
|
}
|
||||||
// Verify the response contains the expected data
|
// Verify the response contains the expected data
|
||||||
response := buf.String()
|
response := buf.String()
|
||||||
atLeastOne := false
|
atLeastOne := false
|
||||||
|
|
@ -519,7 +526,9 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
require.True(t, atLeastOne, "%s: none of %v found in %s", genReq.Model, anyResp, response)
|
if !atLeastOne {
|
||||||
|
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||||
|
}
|
||||||
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Error("outer test context done while waiting for generate")
|
t.Error("outer test context done while waiting for generate")
|
||||||
|
|
@ -561,17 +570,97 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||||
[][]string{
|
[][]string{
|
||||||
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
|
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
|
||||||
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
|
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
|
||||||
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states"},
|
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
|
||||||
{"fourth", "july", "declaration", "independence"},
|
{"fourth", "july", "declaration", "independence"},
|
||||||
{"nitrogen", "oxygen", "carbon", "dioxide"},
|
{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
|
||||||
|
stallTimer := time.NewTimer(initialTimeout)
|
||||||
|
var buf bytes.Buffer
|
||||||
|
role := "assistant"
|
||||||
|
fn := func(response api.ChatResponse) error {
|
||||||
|
// fmt.Print(".")
|
||||||
|
role = response.Message.Role
|
||||||
|
buf.Write([]byte(response.Message.Content))
|
||||||
|
if !stallTimer.Reset(streamTimeout) {
|
||||||
|
return errors.New("stall was detected while streaming response, aborting")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := true
|
||||||
|
req.Stream = &stream
|
||||||
|
done := make(chan int)
|
||||||
|
var genErr error
|
||||||
|
go func() {
|
||||||
|
genErr = client.Chat(ctx, &req, fn)
|
||||||
|
done <- 0
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-stallTimer.C:
|
||||||
|
if buf.Len() == 0 {
|
||||||
|
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
|
||||||
|
} else {
|
||||||
|
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||||
|
}
|
||||||
|
case <-done:
|
||||||
|
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
|
||||||
|
slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if genErr != nil {
|
||||||
|
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the response contains the expected data
|
||||||
|
response := buf.String()
|
||||||
|
atLeastOne := false
|
||||||
|
for _, resp := range anyResp {
|
||||||
|
if strings.Contains(strings.ToLower(response), resp) {
|
||||||
|
atLeastOne = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !atLeastOne {
|
||||||
|
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Error("outer test context done while waiting for generate")
|
||||||
|
}
|
||||||
|
return &api.Message{Role: role, Content: buf.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ChatRequests() ([]api.ChatRequest, [][]string) {
|
||||||
|
genReqs, results := GenerateRequests()
|
||||||
|
reqs := make([]api.ChatRequest, len(genReqs))
|
||||||
|
// think := api.ThinkValue{Value: "low"}
|
||||||
|
for i := range reqs {
|
||||||
|
reqs[i].Model = genReqs[i].Model
|
||||||
|
reqs[i].Stream = genReqs[i].Stream
|
||||||
|
reqs[i].KeepAlive = genReqs[i].KeepAlive
|
||||||
|
// reqs[i].Think = &think
|
||||||
|
reqs[i].Messages = []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: genReqs[i].Prompt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return reqs, results
|
||||||
|
}
|
||||||
|
|
||||||
func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
||||||
// TODO use info API in the future
|
// TODO use info API in the future
|
||||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||||
maxVram, err := strconv.ParseUint(s, 10, 64)
|
maxVram, err := strconv.ParseUint(s, 10, 64)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
// Don't hammer on small VRAM cards...
|
// Don't hammer on small VRAM cards...
|
||||||
if maxVram < gb*format.GibiByte {
|
if maxVram < gb*format.GibiByte {
|
||||||
t.Skip("skipping with small VRAM to avoid timeouts")
|
t.Skip("skipping with small VRAM to avoid timeouts")
|
||||||
|
|
@ -579,6 +668,39 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip if the target model isn't X% GPU loaded to avoid excessive runtime
|
||||||
|
func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
|
||||||
|
models, err := client.ListRunning(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to list running models: %s", err)
|
||||||
|
}
|
||||||
|
loaded := []string{}
|
||||||
|
for _, m := range models.Models {
|
||||||
|
loaded = append(loaded, m.Name)
|
||||||
|
if m.Name != model {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
gpuPercent := 0
|
||||||
|
switch {
|
||||||
|
case m.SizeVRAM == 0:
|
||||||
|
gpuPercent = 0
|
||||||
|
case m.SizeVRAM == m.Size:
|
||||||
|
gpuPercent = 100
|
||||||
|
case m.SizeVRAM > m.Size || m.Size == 0:
|
||||||
|
t.Logf("unexpected size detected: %d", m.SizeVRAM)
|
||||||
|
default:
|
||||||
|
sizeCPU := m.Size - m.SizeVRAM
|
||||||
|
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 110)
|
||||||
|
gpuPercent = int(100 - cpuPercent)
|
||||||
|
}
|
||||||
|
if gpuPercent < minPercent {
|
||||||
|
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded))
|
||||||
|
}
|
||||||
|
|
||||||
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
|
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
|
||||||
deadline, hasDeadline := t.Deadline()
|
deadline, hasDeadline := t.Deadline()
|
||||||
if !hasDeadline {
|
if !hasDeadline {
|
||||||
|
|
|
||||||
|
|
@@ -372,6 +372,7 @@ type Context interface {
 
 	Forward(...Tensor) Context
 	Compute(...Tensor)
+	ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
 
 	// Reserve is analogous to Compute but rather than executing a
 	// graph, simply preallocates memory. Typically called with a
@@ -401,6 +402,8 @@ type Tensor interface {
 	Bytes() []byte
 	Floats() []float32
 
+	SetValueFromIntSlice(s []int32)
+
 	Neg(ctx Context) Tensor
 	Add(ctx Context, t2 Tensor) Tensor
 	Sub(ctx Context, t2 Tensor) Tensor
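A hypothetical caller of the extended interface, to show how the pieces above are meant to fit together. computeAsync is not part of the commit; only Forward, Compute/ComputeWithNotify, Floats, and SetValueFromIntSlice come from the interface changes, and the channel plumbing is an illustrative assumption.

package example

import "github.com/ollama/ollama/ml"

// computeAsync runs Compute and the Floats readback in a goroutine. The
// returned started channel is closed once the backend has actually begun the
// graph, so the caller can start assembling the next batch while the GPU is
// busy; outputs delivers the logits when they are ready.
func computeAsync(ctx ml.Context, inputs, logits ml.Tensor, tokens []int32) (started chan struct{}, outputs chan []float32) {
	started = make(chan struct{})
	outputs = make(chan []float32, 1)

	// write this batch's token ids directly into the input tensor
	inputs.SetValueFromIntSlice(tokens)
	ctx.Forward(logits)

	go func() {
		ctx.ComputeWithNotify(func() { close(started) }, logits)
		outputs <- logits.Floats()
	}()
	return started, outputs
}

A run loop would wait on started before preparing the next batch and drain outputs before submitting the following graph, which is exactly the stall this commit is removing.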
@@ -82,6 +82,7 @@ type Backend struct {
 	// to the name that is used by the model definition
 	tensorLoadTargets map[string][]string
 
+	schedMu       sync.Mutex // Only one Compute can run at a time
 	sched         C.ggml_backend_sched_t
 	schedBackends []C.ggml_backend_t
 	schedBufts    []C.ggml_backend_buffer_type_t
@@ -758,6 +759,15 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
 }
 
 func (c *Context) Compute(tensors ...ml.Tensor) {
+	c.ComputeWithNotify(nil, tensors...)
+}
+
+func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) {
+	c.b.schedMu.Lock()
+	defer c.b.schedMu.Unlock()
+	if cb != nil {
+		go cb()
+	}
 	if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
 		panic(fmt.Errorf("error computing ggml graph: %v", status))
 	}
@@ -1010,6 +1020,12 @@ func (t *Tensor) Floats() (data []float32) {
 	return
 }
 
+func (t *Tensor) SetValueFromIntSlice(s []int32) {
+	if len(s) > 0 {
+		C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
+	}
+}
+
 func (t *Tensor) DType() ml.DType {
 	switch t.t._type {
 	case C.GGML_TYPE_F32:
@@ -64,7 +64,7 @@ type MultimodalProcessor interface {
 	// This function is also responsible for updating MultimodalHash for any Multimodal
 	// that is modified to ensure that there is a unique hash value that accurately
 	// represents the contents.
-	PostTokenize([]input.Input) ([]input.Input, error)
+	PostTokenize([]*input.Input) ([]*input.Input, error)
 }
 
 // Base implements the common fields and methods for all models
@@ -278,7 +278,7 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }
 
-func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
 	if len(batch.Positions) != len(batch.Sequences) {
 		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
 	}
@@ -287,8 +287,6 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
 		return nil, errors.New("batch size cannot be less than 1")
 	}
 
-	batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
-
 	cache := m.Config().Cache
 	if cache != nil {
 		err := cache.StartForward(ctx, batch, false)
@@ -302,7 +300,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
 		return nil, err
 	}
 
-	ctx.Forward(t).Compute(t)
+	ctx.Forward(t)
 
 	return t, nil
 }
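Since PostTokenize moves to pointer slices here and in every model below, the following is an illustrative sketch of an implementation against the new signature. It is not from the commit: the input import path is assumed, token id 0 is a stand-in placeholder, and only the Input fields and Tensor.Dim usage come from the diffs on this page.

package example

import (
	"github.com/ollama/ollama/model/input" // import path assumed; the diff only shows the input.Input type
)

// postTokenize mirrors the pattern used by the vision models in this commit:
// plain token inputs pass through untouched, while each multimodal input is
// expanded into one entry that carries the tensor (kept in a single batch via
// SameBatch) followed by placeholder entries for the remaining patch positions.
func postTokenize(inputs []*input.Input) ([]*input.Input, error) {
	var result []*input.Input
	for _, inp := range inputs {
		if len(inp.Multimodal) == 0 {
			result = append(result, inp) // pointers are appended as-is, no copy
			continue
		}

		patches := inp.Multimodal[0].Tensor.Dim(1)
		result = append(result, &input.Input{
			Token:          0, // placeholder token id; model specific in practice
			Multimodal:     inp.Multimodal,
			MultimodalHash: inp.MultimodalHash,
			SameBatch:      patches,
		})
		for i := 1; i < patches; i++ {
			result = append(result, &input.Input{Token: 0})
		}
	}
	return result, nil
}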
@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
var result []input.Input
|
var result []*input.Input
|
||||||
|
|
||||||
for _, inp := range inputs {
|
for _, inp := range inputs {
|
||||||
if len(inp.Multimodal) == 0 {
|
if len(inp.Multimodal) == 0 {
|
||||||
|
|
@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||||
inputMultimodal := inp.Multimodal[0].Tensor
|
inputMultimodal := inp.Multimodal[0].Tensor
|
||||||
|
|
||||||
result = append(result,
|
result = append(result,
|
||||||
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
||||||
input.Input{Token: 255999}, // "<start_of_image>""
|
&input.Input{Token: 255999}, // "<start_of_image>""
|
||||||
input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
||||||
)
|
)
|
||||||
|
|
||||||
// add image token placeholders
|
// add image token placeholders
|
||||||
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
||||||
|
|
||||||
result = append(result,
|
result = append(result,
|
||||||
input.Input{Token: 256000}, // <end_of_image>
|
&input.Input{Token: 256000}, // <end_of_image>
|
||||||
input.Input{Token: 108}, // "\n\n"
|
&input.Input{Token: 108}, // "\n\n"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -134,16 +134,16 @@ type separator struct {
|
||||||
y bool
|
y bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
var result []input.Input
|
var result []*input.Input
|
||||||
for _, inp := range inputs {
|
for _, inp := range inputs {
|
||||||
if len(inp.Multimodal) == 0 {
|
if len(inp.Multimodal) == 0 {
|
||||||
result = append(result, inp)
|
result = append(result, inp)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var imageInputs []input.Input
|
var imageInputs []*input.Input
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|>
|
||||||
|
|
||||||
for i, mm := range inp.Multimodal {
|
for i, mm := range inp.Multimodal {
|
||||||
patchesPerChunk := mm.Tensor.Dim(1)
|
patchesPerChunk := mm.Tensor.Dim(1)
|
||||||
|
|
@ -151,20 +151,20 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||||
if i < len(inp.Multimodal)-1 {
|
if i < len(inp.Multimodal)-1 {
|
||||||
separator := mm.Data.(*separator)
|
separator := mm.Data.(*separator)
|
||||||
|
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||||
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||||
|
|
||||||
if separator.x {
|
if separator.x {
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|>
|
||||||
}
|
}
|
||||||
if separator.y {
|
if separator.y {
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|>
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|>
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||||
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||||
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
|
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
|
||||||
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
|
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
|
||||||
// that can be processed together.
|
// that can be processed together.
|
||||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
var result []input.Input
|
var result []*input.Input
|
||||||
for _, inp := range inputs {
|
for _, inp := range inputs {
|
||||||
if len(inp.Multimodal) == 0 {
|
if len(inp.Multimodal) == 0 {
|
||||||
result = append(result, inp)
|
result = append(result, inp)
|
||||||
} else {
|
} else {
|
||||||
for i, row := range inp.Multimodal {
|
for i, row := range inp.Multimodal {
|
||||||
// [IMG]
|
// [IMG]
|
||||||
result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
|
result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
|
||||||
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
|
result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
|
||||||
if i == len(inp.Multimodal)-1 {
|
if i == len(inp.Multimodal)-1 {
|
||||||
// [IMG_END]
|
// [IMG_END]
|
||||||
result = append(result, input.Input{Token: 13})
|
result = append(result, &input.Input{Token: 13})
|
||||||
} else {
|
} else {
|
||||||
// [IMG_BREAK]
|
// [IMG_BREAK]
|
||||||
result = append(result, input.Input{Token: 12})
|
result = append(result, &input.Input{Token: 12})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||||
return []input.Multimodal{{Tensor: projectedOutputs}}, nil
|
return []input.Multimodal{{Tensor: projectedOutputs}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
for i := range inputs {
|
for i := range inputs {
|
||||||
if inputs[i].Multimodal != nil {
|
if inputs[i].Multimodal != nil {
|
||||||
inputs[i].Token = 128256 // <|image|>
|
inputs[i].Token = 128256 // <|image|>
|
||||||
|
|
|
||||||
|
|
@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||||
}
|
}
|
||||||
|
|
||||||
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
||||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
var result []input.Input
|
var result []*input.Input
|
||||||
|
|
||||||
var (
|
var (
|
||||||
imageToken int32 = 151655
|
imageToken int32 = 151655
|
||||||
|
|
@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||||
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
|
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
|
||||||
}
|
}
|
||||||
for i := range pre {
|
for i := range pre {
|
||||||
result = append(result, input.Input{Token: pre[i]})
|
result = append(result, &input.Input{Token: pre[i]})
|
||||||
}
|
}
|
||||||
|
|
||||||
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
|
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
|
||||||
|
|
||||||
// First add the vision start token
|
// First add the vision start token
|
||||||
result = append(result, input.Input{Token: visionStartToken})
|
result = append(result, &input.Input{Token: visionStartToken})
|
||||||
|
|
||||||
// Add the image token with the multimodal tensor data at the first position
|
// Add the image token with the multimodal tensor data at the first position
|
||||||
result = append(result, input.Input{
|
result = append(result, &input.Input{
|
||||||
Token: imageToken,
|
Token: imageToken,
|
||||||
Multimodal: inp.Multimodal,
|
Multimodal: inp.Multimodal,
|
||||||
MultimodalHash: inp.MultimodalHash,
|
MultimodalHash: inp.MultimodalHash,
|
||||||
|
|
@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
||||||
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
||||||
|
|
||||||
result = append(result, input.Input{Token: visionEndToken})
|
result = append(result, &input.Input{Token: visionEndToken})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,7 @@ type InputCacheSlot struct {
|
||||||
Id int
|
Id int
|
||||||
|
|
||||||
// Inputs that are stored in the KV cache
|
// Inputs that are stored in the KV cache
|
||||||
Inputs []input.Input
|
Inputs []*input.Input
|
||||||
|
|
||||||
// is this cache actively being processed as part of a sequence?
|
// is this cache actively being processed as part of a sequence?
|
||||||
InUse bool
|
InUse bool
|
||||||
|
|
@ -95,7 +95,7 @@ type InputCacheSlot struct {
|
||||||
lastUsed time.Time
|
lastUsed time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
|
func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
|
||||||
var slot *InputCacheSlot
|
var slot *InputCacheSlot
|
||||||
var numPast int32
|
var numPast int32
|
||||||
var err error
|
var err error
|
||||||
|
|
@ -146,7 +146,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
|
||||||
return slot, prompt, nil
|
return slot, prompt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
|
||||||
longest := int32(-1)
|
longest := int32(-1)
|
||||||
var longestSlot *InputCacheSlot
|
var longestSlot *InputCacheSlot
|
||||||
|
|
||||||
|
|
@ -169,7 +169,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot
|
||||||
return longestSlot, longest, nil
|
return longestSlot, longest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
|
||||||
oldest := time.Now()
|
oldest := time.Now()
|
||||||
var oldestSlot *InputCacheSlot
|
var oldestSlot *InputCacheSlot
|
||||||
|
|
||||||
|
|
@ -205,7 +205,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
|
||||||
if longest > 0 && longestSlot != oldestSlot {
|
if longest > 0 && longestSlot != oldestSlot {
|
||||||
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
|
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
|
||||||
len(longestSlot.Inputs))
|
len(longestSlot.Inputs))
|
||||||
oldestSlot.Inputs = make([]input.Input, longest)
|
oldestSlot.Inputs = make([]*input.Input, longest)
|
||||||
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
|
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
|
||||||
if c.cache != nil {
|
if c.cache != nil {
|
||||||
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
|
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
|
||||||
|
|
@ -215,7 +215,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
|
||||||
return oldestSlot, longest, nil
|
return oldestSlot, longest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func countCommonPrefix(a []input.Input, b []input.Input) int32 {
|
func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
|
||||||
var count int32
|
var count int32
|
||||||
|
|
||||||
for i := range a {
|
for i := range a {
|
||||||
|
|
@ -250,7 +250,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ErrReprocessInputs struct {
|
type ErrReprocessInputs struct {
|
||||||
Inputs []input.Input
|
Inputs []*input.Input
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ErrReprocessInputs) Error() string {
|
func (e *ErrReprocessInputs) Error() string {
|
||||||
|
|
@ -283,13 +283,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
|
||||||
"id", slot.Id, "error", err)
|
"id", slot.Id, "error", err)
|
||||||
|
|
||||||
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||||
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
|
newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard))
|
||||||
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||||
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||||
|
|
||||||
// Reset the cache
|
// Reset the cache
|
||||||
_ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
|
_ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
|
||||||
slot.Inputs = []input.Input{}
|
slot.Inputs = []*input.Input{}
|
||||||
|
|
||||||
// Return error with inputs that need to be reprocessed
|
// Return error with inputs that need to be reprocessed
|
||||||
return &ErrReprocessInputs{Inputs: newInputs}
|
return &ErrReprocessInputs{Inputs: newInputs}
|
||||||
|
|
|
||||||
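Note on the `[]input.Input` → `[]*input.Input` change above: sharing pointers lets a placeholder input be appended to a sequence (and to the batch built from it) before its token value is known, then filled in once sampling completes. A minimal, self-contained sketch of that mechanic, using a simplified stand-in for `input.Input` (not the real type):

```go
package main

import "fmt"

// Input is a simplified stand-in for input.Input.
type Input struct {
	Token int32
}

func main() {
	// A placeholder appended before its token value is known (e.g. before sampling).
	next := &Input{Token: 0}

	seqInputs := []*Input{next}   // the sequence's pending inputs
	batchInputs := []*Input{next} // the batch built from those inputs

	// Later, once the real token has been sampled, fill in the shared element;
	// every slice holding the pointer sees the final value.
	next.Token = 42

	fmt.Println(seqInputs[0].Token, batchInputs[0].Token) // 42 42
}
```

The runner.go changes below rely on exactly this property when they hand `seq.inputs[i]` pointers to `batchInputs` and later assign `nextBatchTokens[i].Token = token`.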
|
|
@ -13,50 +13,50 @@ import (
|
||||||
func TestCountCommon(t *testing.T) {
|
func TestCountCommon(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
t1 []input.Input
|
t1 []*input.Input
|
||||||
t2 []input.Input
|
t2 []*input.Input
|
||||||
expected int32
|
expected int32
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Equal",
|
name: "Equal",
|
||||||
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
t1: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
expected: 3,
|
expected: 3,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Prefix",
|
name: "Prefix",
|
||||||
t1: []input.Input{{Token: 1}},
|
t1: []*input.Input{{Token: 1}},
|
||||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
expected: 1,
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Image Prefix",
|
name: "Image Prefix",
|
||||||
t1: []input.Input{{MultimodalHash: 1}},
|
t1: []*input.Input{{MultimodalHash: 1}},
|
||||||
t2: []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
|
t2: []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
|
||||||
expected: 1,
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Mixed",
|
name: "Mixed",
|
||||||
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
|
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
|
||||||
t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
|
t2: []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
|
||||||
expected: 2,
|
expected: 2,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Mixed, Same Length",
|
name: "Mixed, Same Length",
|
||||||
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
|
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
|
||||||
t2: []input.Input{{Token: 1}, {MultimodalHash: 2}},
|
t2: []*input.Input{{Token: 1}, {MultimodalHash: 2}},
|
||||||
expected: 1,
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Empty",
|
name: "Empty",
|
||||||
t1: []input.Input{},
|
t1: []*input.Input{},
|
||||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
expected: 0,
|
expected: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Both Empty",
|
name: "Both Empty",
|
||||||
t1: []input.Input{},
|
t1: []*input.Input{},
|
||||||
t2: []input.Input{},
|
t2: []*input.Input{},
|
||||||
expected: 0,
|
expected: 0,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
cache InputCache
|
cache InputCache
|
||||||
prompt []input.Input
|
prompt []*input.Input
|
||||||
longest expected
|
longest expected
|
||||||
best expected
|
best expected
|
||||||
}{
|
}{
|
||||||
|
|
@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) {
|
||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{},
|
Inputs: []*input.Input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Time{},
|
lastUsed: time.Time{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{},
|
Inputs: []*input.Input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Time{},
|
lastUsed: time.Time{},
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 1}},
|
prompt: []*input.Input{{Token: 1}},
|
||||||
longest: expected{result: 0, len: 0},
|
longest: expected{result: 0, len: 0},
|
||||||
best: expected{result: 0, len: 0},
|
best: expected{result: 0, len: 0},
|
||||||
},
|
},
|
||||||
|
|
@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) {
|
||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}},
|
Inputs: []*input.Input{{Token: 1}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
prompt: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
longest: expected{result: 1, len: 2},
|
longest: expected{result: 1, len: 2},
|
||||||
best: expected{result: 1, len: 2},
|
best: expected{result: 1, len: 2},
|
||||||
},
|
},
|
||||||
|
|
@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) {
|
||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{},
|
Inputs: []*input.Input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Time{},
|
lastUsed: time.Time{},
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 2}},
|
prompt: []*input.Input{{Token: 2}},
|
||||||
longest: expected{result: 0, len: 0},
|
longest: expected{result: 0, len: 0},
|
||||||
best: expected{result: 1, len: 0},
|
best: expected{result: 1, len: 0},
|
||||||
},
|
},
|
||||||
|
|
@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) {
|
||||||
slots: []InputCacheSlot{
|
slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{},
|
Inputs: []*input.Input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Time{},
|
lastUsed: time.Time{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
prompt: []input.Input{{Token: 1}},
|
prompt: []*input.Input{{Token: 1}},
|
||||||
longest: expected{result: 0, len: 1},
|
longest: expected{result: 0, len: 1},
|
||||||
best: expected{result: 1, len: 1},
|
best: expected{result: 1, len: 1},
|
||||||
},
|
},
|
||||||
|
|
@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) {
|
||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}},
|
Inputs: []*input.Input{{Token: 1}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 2}, {Token: 3}},
|
prompt: []*input.Input{{Token: 2}, {Token: 3}},
|
||||||
longest: expected{result: 0, len: 0},
|
longest: expected{result: 0, len: 0},
|
||||||
best: expected{result: 1, len: 0},
|
best: expected{result: 1, len: 0},
|
||||||
},
|
},
|
||||||
|
|
@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) {
|
||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: true,
|
InUse: true,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{{Token: 1}},
|
Inputs: []*input.Input{{Token: 1}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
prompt: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
longest: expected{result: 1, len: 1},
|
longest: expected{result: 1, len: 1},
|
||||||
best: expected{result: 1, len: 2},
|
best: expected{result: 1, len: 2},
|
||||||
},
|
},
|
||||||
|
|
@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
cache InputCache
|
cache InputCache
|
||||||
prompt []input.Input
|
prompt []*input.Input
|
||||||
wantErr bool
|
wantErr bool
|
||||||
expectedSlotId int
|
expectedSlotId int
|
||||||
expectedPrompt int // expected length of remaining prompt
|
expectedPrompt int // expected length of remaining prompt
|
||||||
|
|
@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) {
|
||||||
slots: []InputCacheSlot{
|
slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{},
|
Inputs: []*input.Input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
expectedSlotId: 0,
|
expectedSlotId: 0,
|
||||||
expectedPrompt: 1, // Only token 3 remains
|
expectedPrompt: 1, // Only token 3 remains
|
||||||
|
|
@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) {
|
||||||
slots: []InputCacheSlot{
|
slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{},
|
Inputs: []*input.Input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
expectedSlotId: 0,
|
expectedSlotId: 0,
|
||||||
expectedPrompt: 1, // Only token 3 remains
|
expectedPrompt: 1, // Only token 3 remains
|
||||||
|
|
@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) {
|
||||||
slots: []InputCacheSlot{
|
slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
prompt: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
expectedSlotId: 0,
|
expectedSlotId: 0,
|
||||||
expectedPrompt: 1, // Should leave 1 token for sampling
|
expectedPrompt: 1, // Should leave 1 token for sampling
|
||||||
|
|
@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) {
|
||||||
slots: []InputCacheSlot{
|
slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: true,
|
InUse: true,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
expectedSlotId: -1,
|
expectedSlotId: -1,
|
||||||
expectedPrompt: -1,
|
expectedPrompt: -1,
|
||||||
|
|
@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
numCtx int32
|
numCtx int32
|
||||||
inputs []input.Input
|
inputs []*input.Input
|
||||||
numKeep int32
|
numKeep int32
|
||||||
cacheErr bool
|
cacheErr bool
|
||||||
wantErr any
|
wantErr any
|
||||||
|
|
@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "Normal shift",
|
name: "Normal shift",
|
||||||
numCtx: 10,
|
numCtx: 10,
|
||||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||||
numKeep: 2,
|
numKeep: 2,
|
||||||
cacheErr: false, // No error
|
cacheErr: false, // No error
|
||||||
wantErr: nil,
|
wantErr: nil,
|
||||||
|
|
@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "Cache removal fails",
|
name: "Cache removal fails",
|
||||||
numCtx: 10,
|
numCtx: 10,
|
||||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||||
numKeep: 2,
|
numKeep: 2,
|
||||||
cacheErr: true,
|
cacheErr: true,
|
||||||
wantErr: &ErrReprocessInputs{},
|
wantErr: &ErrReprocessInputs{},
|
||||||
|
|
@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
||||||
}
|
}
|
||||||
slot := &InputCacheSlot{
|
slot := &InputCacheSlot{
|
||||||
Id: 123,
|
Id: 123,
|
||||||
Inputs: make([]input.Input, len(tt.inputs)),
|
Inputs: make([]*input.Input, len(tt.inputs)),
|
||||||
}
|
}
|
||||||
copy(slot.Inputs, tt.inputs)
|
copy(slot.Inputs, tt.inputs)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"runtime/debug"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
@ -51,10 +52,10 @@ type Sequence struct {
|
||||||
iBatch int
|
iBatch int
|
||||||
|
|
||||||
// prompt inputs left to evaluate
|
// prompt inputs left to evaluate
|
||||||
inputs []input.Input
|
inputs []*input.Input
|
||||||
|
|
||||||
// inputs that have been added to a batch but not yet submitted to Forward
|
// inputs that have been added to a batch but not yet submitted to Forward
|
||||||
pendingInputs []input.Input
|
pendingInputs []*input.Input
|
||||||
|
|
||||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||||
pendingResponses []string
|
pendingResponses []string
|
||||||
|
|
@ -182,8 +183,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||||
// inputs processes the prompt and images into a list of inputs
|
// inputs processes the prompt and images into a list of inputs
|
||||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||||
// decoding images
|
// decoding images
|
||||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
|
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
|
||||||
var inputs []input.Input
|
var inputs []*input.Input
|
||||||
var ctxs []ml.Context
|
var ctxs []ml.Context
|
||||||
var mmStore multimodalStore
|
var mmStore multimodalStore
|
||||||
|
|
||||||
|
|
@ -210,7 +211,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range tokens {
|
for _, t := range tokens {
|
||||||
inputs = append(inputs, input.Input{Token: t})
|
inputs = append(inputs, &input.Input{Token: t})
|
||||||
}
|
}
|
||||||
|
|
||||||
// image - decode and store
|
// image - decode and store
|
||||||
|
|
@ -243,7 +244,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
|
||||||
|
|
||||||
mmStore.addMultimodal(imageEmbeddings)
|
mmStore.addMultimodal(imageEmbeddings)
|
||||||
|
|
||||||
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
||||||
postTokenize = true
|
postTokenize = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -259,6 +260,37 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
|
||||||
return inputs, ctxs, mmStore, nil
|
return inputs, ctxs, mmStore, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type batchState struct {
|
||||||
|
// id provides a counter for trace logging batches
|
||||||
|
id int
|
||||||
|
|
||||||
|
// ctx holds the backend context used for this batch
|
||||||
|
ctx ml.Context
|
||||||
|
|
||||||
|
// modelOutput holds the outputs from this batch
|
||||||
|
modelOutput ml.Tensor
|
||||||
|
|
||||||
|
// batchInputs holds the input token pointers which may start as
|
||||||
|
// placeholders and are filled in before calling ctx.Compute
|
||||||
|
batchInputs []*input.Input
|
||||||
|
|
||||||
|
// batch contains the inputs for a model forward pass
|
||||||
|
batch input.Batch
|
||||||
|
|
||||||
|
// full set of seqs at the time this batch was initiated
|
||||||
|
seqs []*Sequence
|
||||||
|
|
||||||
|
// Signaled when this batch's inputs are ready and compute can proceed
|
||||||
|
inputsReadyCh chan struct{}
|
||||||
|
|
||||||
|
// Signaled when Compute is about to begin on this batch, and
|
||||||
|
// seqs have been updated to prepare for the next batch
|
||||||
|
computeStartedCh chan struct{}
|
||||||
|
|
||||||
|
// Signaled when this batch's outputs are complete and the next batch can proceed
|
||||||
|
outputsReadyCh chan struct{}
|
||||||
|
}
|
||||||
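The three channels above coordinate the pipelining between forwardBatch and computeBatch: a batch's `inputsReadyCh` is chained to the previous batch's `outputsReadyCh`, and `computeStartedCh` tells the run loop it is safe to start building the next batch. A hedged sketch of that handshake with simplified stand-in types (not the runner's actual code):

```go
package main

import (
	"fmt"
	"time"
)

type batch struct {
	id               int
	inputsReadyCh    chan struct{}
	computeStartedCh chan struct{}
	outputsReadyCh   chan struct{}
}

// compute stands in for computeBatch: it waits for its inputs, signals that
// compute has started (so the next batch can be prepared), does the "GPU"
// work, and finally signals that its outputs are ready.
func compute(b batch) {
	<-b.inputsReadyCh
	b.computeStartedCh <- struct{}{}
	time.Sleep(10 * time.Millisecond) // stand-in for Compute + Floats
	fmt.Println("batch", b.id, "done")
	b.outputsReadyCh <- struct{}{}
}

func main() {
	// The first batch has no predecessor, so its inputs are ready immediately.
	prevOutputs := make(chan struct{}, 1)
	prevOutputs <- struct{}{}

	for id := 0; id < 3; id++ {
		b := batch{
			id:               id,
			inputsReadyCh:    prevOutputs, // chain: prior outputs gate these inputs
			computeStartedCh: make(chan struct{}, 1),
			outputsReadyCh:   make(chan struct{}, 1),
		}
		go compute(b)
		<-b.computeStartedCh // safe to build the next batch while this one runs
		prevOutputs = b.outputsReadyCh
	}
	<-prevOutputs // wait for the final batch to finish
}
```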
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
// modelPath is the location of the model to be loaded
|
// modelPath is the location of the model to be loaded
|
||||||
modelPath string
|
modelPath string
|
||||||
|
|
@ -290,6 +322,12 @@ type Server struct {
|
||||||
// TODO (jmorganca): make this n_batch
|
// TODO (jmorganca): make this n_batch
|
||||||
batchSize int
|
batchSize int
|
||||||
|
|
||||||
|
// Used to signal a hard failure during async processing which will panic the runner
|
||||||
|
hardErrCh chan error
|
||||||
|
|
||||||
|
// Simple counter used only for trace logging batches
|
||||||
|
batchID int
|
||||||
|
|
||||||
// protects access to everything below this line
|
// protects access to everything below this line
|
||||||
// this is context state needed for decoding
|
// this is context state needed for decoding
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
@ -362,33 +400,66 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||||
s.seqsSem.Release(1)
|
s.seqsSem.Release(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// track batch state between forwardBatch, computeBatch and predictForwardBatch
|
||||||
|
|
||||||
func (s *Server) run(ctx context.Context) {
|
func (s *Server) run(ctx context.Context) {
|
||||||
s.ready.Wait()
|
s.ready.Wait()
|
||||||
|
|
||||||
|
var activeBatch batchState
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
|
case err := <-s.hardErrCh:
|
||||||
|
panic(err)
|
||||||
default:
|
default:
|
||||||
err := s.processBatch()
|
var err error
|
||||||
|
activeBatch, err = s.forwardBatch(activeBatch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
go s.computeBatch(activeBatch)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
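The new `hardErrCh` select case above replaces the old pattern of returning an error from processBatch: once compute runs on its own goroutine, fatal errors can no longer bubble up through a return value, so they are funneled through a buffered channel and turned back into a panic in the run loop. An illustrative sketch of that pattern (simplified, not the runner's code):

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// run mirrors the shape of the loop above: async workers report fatal
// errors on hardErrCh and the loop converts them into a panic.
func run(ctx context.Context, hardErrCh <-chan error, work func()) {
	for {
		select {
		case <-ctx.Done():
			return
		case err := <-hardErrCh:
			panic(err)
		default:
			work() // build and launch the next batch in the real runner
		}
	}
}

func main() {
	hardErrCh := make(chan error, 1) // buffered so the worker never blocks

	defer func() {
		if r := recover(); r != nil {
			fmt.Println("runner went down:", r)
		}
	}()

	run(context.Background(), hardErrCh, func() {
		// Stand-in for a failure inside the async compute goroutine.
		hardErrCh <- errors.New("failed to sample token")
	})
}
```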
func (s *Server) processBatch() error {
|
// forwardBatch will calculate a batch.
|
||||||
|
func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, err error) {
|
||||||
|
// If we have a pending batch still processing, wait until Compute has started
|
||||||
|
// before setting up the next batch so the seqs inputs are ready to receive their
|
||||||
|
// token values and we get the correct input pointers for the batchInputs
|
||||||
|
if pendingBatch.ctx != nil {
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
|
||||||
|
<-pendingBatch.computeStartedCh
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
|
||||||
|
nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the outputs from the pending batch to the next batch's inputs
|
||||||
|
} else {
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
|
||||||
|
// No pendingBatch, so the inputs will be ready in the seqs immediately
|
||||||
|
nextBatch.inputsReadyCh = make(chan struct{}, 1)
|
||||||
|
nextBatch.inputsReadyCh <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
for s.allNil() {
|
for s.allNil() {
|
||||||
s.cond.Wait() // Wait until an item is added
|
s.cond.Wait() // Wait until an item is added
|
||||||
}
|
}
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
ctx := s.model.Backend().NewContext()
|
nextBatch.ctx = s.model.Backend().NewContext()
|
||||||
defer ctx.Close()
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
nextBatch.ctx.Close()
|
||||||
|
nextBatch.ctx = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
nextBatch.id = s.batchID
|
||||||
|
nextBatch.seqs = append([]*Sequence{}, s.seqs...)
|
||||||
|
nextBatch.computeStartedCh = make(chan struct{}, 1)
|
||||||
|
nextBatch.outputsReadyCh = make(chan struct{}, 1)
|
||||||
|
|
||||||
var batchInputs []int32
|
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
|
||||||
|
var batchInputs []*input.Input
|
||||||
var batch input.Batch
|
var batch input.Batch
|
||||||
|
|
||||||
resumeSeq := -1
|
resumeSeq := -1
|
||||||
|
|
@ -396,7 +467,6 @@ func (s *Server) processBatch() error {
|
||||||
for range s.seqs {
|
for range s.seqs {
|
||||||
seqIdx = (seqIdx + 1) % len(s.seqs)
|
seqIdx = (seqIdx + 1) % len(s.seqs)
|
||||||
seq := s.seqs[seqIdx]
|
seq := s.seqs[seqIdx]
|
||||||
|
|
||||||
if seq == nil {
|
if seq == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
@ -404,12 +474,13 @@ func (s *Server) processBatch() error {
|
||||||
// if past the num predict limit
|
// if past the num predict limit
|
||||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||||
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
||||||
|
nextBatch.seqs[seqIdx] = nil
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.cache.enabled {
|
if !s.cache.enabled {
|
||||||
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
||||||
seq.cache.Inputs = []input.Input{}
|
seq.cache.Inputs = []*input.Input{}
|
||||||
}
|
}
|
||||||
|
|
||||||
batchSize := s.batchSize
|
batchSize := s.batchSize
|
||||||
|
|
@ -442,25 +513,28 @@ func (s *Server) processBatch() error {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var reprocess *ErrReprocessInputs
|
var reprocess *ErrReprocessInputs
|
||||||
if errors.As(err, &reprocess) {
|
if errors.As(err, &reprocess) {
|
||||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||||
// Skip this sequence but continue processing the rest
|
// Skip this sequence but continue processing the rest
|
||||||
|
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
|
||||||
|
err = nil
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
batchInputs = append(batchInputs, inp.Token)
|
batchInputs = append(batchInputs, seq.inputs[i])
|
||||||
if inp.Multimodal != nil {
|
if inp.Multimodal != nil {
|
||||||
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
|
var mm []input.Multimodal
|
||||||
|
mm, err = seq.mmStore.getMultimodal(s.model.Backend(), nextBatch.ctx, inp.Multimodal, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
|
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
|
||||||
}
|
}
|
||||||
|
|
@ -472,6 +546,7 @@ func (s *Server) processBatch() error {
|
||||||
if i+1 == len(seq.inputs) {
|
if i+1 == len(seq.inputs) {
|
||||||
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
||||||
}
|
}
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
|
||||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -485,36 +560,129 @@ func (s *Server) processBatch() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(batchInputs) == 0 {
|
if len(batchInputs) == 0 {
|
||||||
return nil
|
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
|
||||||
|
nextBatch.ctx.Close()
|
||||||
|
nextBatch.ctx = nil
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
s.batchID++
|
||||||
|
|
||||||
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
|
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
|
||||||
|
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
|
||||||
|
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to decode batch: %w", err)
|
err = fmt.Errorf("failed to build graph: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nextBatch.batchInputs = batchInputs
|
||||||
|
nextBatch.batch = batch
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
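forwardBatch above builds the graph against an empty `ml.DTypeI32` input tensor and only records the `*input.Input` pointers; the actual token values are written with `SetValueFromIntSlice` in computeBatch once the previous batch has produced them. A rough sketch of that "build now, fill later" shape with mocked-up tensor and graph types (assumptions for illustration, not the real `ml` package API):

```go
package main

import "fmt"

// tensor is a minimal stand-in; only the call used in the diff is mocked.
type tensor struct{ vals []int32 }

func (t *tensor) SetValueFromIntSlice(v []int32) { copy(t.vals, v) }

type graph struct {
	inputs *tensor
	out    []float32
}

// forward "builds" the graph: it can run before the input values exist,
// because it only captures a reference to the (still empty) input tensor.
func forward(inputs *tensor) *graph {
	return &graph{inputs: inputs}
}

// compute executes the graph once the inputs have been filled in.
func (g *graph) compute() {
	g.out = make([]float32, len(g.inputs.vals))
	for i, v := range g.inputs.vals {
		g.out[i] = float32(v) * 0.5 // stand-in for the model forward pass
	}
}

func main() {
	// 1. Build the graph for the next batch with an empty input tensor
	//    (this is the work that now overlaps with the previous batch's GPU time).
	inputs := &tensor{vals: make([]int32, 3)}
	g := forward(inputs)

	// 2. Once the previous batch has produced its tokens, fill in the values...
	inputs.SetValueFromIntSlice([]int32{7, 8, 9})

	// 3. ...and only then run the computation.
	g.compute()
	fmt.Println(g.out) // [3.5 4 4.5]
}
```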
|
// Async processing of the next batch
|
||||||
|
func (s *Server) computeBatch(activeBatch batchState) {
|
||||||
|
if activeBatch.ctx == nil {
|
||||||
|
// Nothing to compute
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer activeBatch.ctx.Close()
|
||||||
|
|
||||||
|
// Wait until inputs are ready
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
|
||||||
|
<-activeBatch.inputsReadyCh
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", activeBatch.id)
|
||||||
|
|
||||||
|
// Once we complete, signal the next batch of inputs are ready
|
||||||
|
// This will unblock the next computeBatch, or forwardBatch if new seqs come in
|
||||||
|
defer func() {
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", activeBatch.id)
|
||||||
|
activeBatch.outputsReadyCh <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
|
||||||
|
// Gather the actual input token values now that they're ready
|
||||||
|
batchInputs := make([]int32, len(activeBatch.batchInputs))
|
||||||
|
for i := range batchInputs {
|
||||||
|
batchInputs[i] = activeBatch.batchInputs[i].Token
|
||||||
}
|
}
|
||||||
|
|
||||||
logits := modelOutput.Floats()
|
// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
|
||||||
|
// so that forwardBatch can build a batchInputs set which will eventually contain the actual
|
||||||
|
// decoded tokens.
|
||||||
|
nextBatchTokens := make([]*input.Input, len(s.seqs))
|
||||||
|
iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
|
||||||
for i, seq := range s.seqs {
|
for i, seq := range s.seqs {
|
||||||
|
iBatches[i] = -1
|
||||||
if seq == nil {
|
if seq == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// Skip over any newly added or skipped sequences
|
||||||
|
if activeBatch.seqs[i] == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// After calling Forward, pending inputs are now in the cache
|
// Detect if the sequence we're processing has already been completed and replaced
|
||||||
|
// with a new sequence
|
||||||
|
if seq != activeBatch.seqs[i] {
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pending inputs will actually be in the cache after we call Compute.
|
||||||
|
// However, we have already resolved any placeholder tokens.
|
||||||
|
//
|
||||||
|
// It's possible for incoming sequences to look at the values that we've
|
||||||
|
// added to the cache here and start relying on them before we've done
|
||||||
|
// the computation. This is OK as long as we ensure that this batch's
|
||||||
|
// computation happens before any future batch's and we never fail
|
||||||
|
// (unless we take down the whole runner).
|
||||||
if len(seq.pendingInputs) > 0 {
|
if len(seq.pendingInputs) > 0 {
|
||||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||||
seq.pendingInputs = []input.Input{}
|
seq.pendingInputs = []*input.Input{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// don't sample prompt processing
|
// don't sample prompt processing
|
||||||
if len(seq.inputs) != 0 {
|
if len(seq.inputs) != 0 {
|
||||||
if !s.cache.enabled {
|
if !s.cache.enabled {
|
||||||
return errors.New("caching disabled but unable to fit entire input in a batch")
|
s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
|
||||||
|
s.mu.Unlock()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
seq.numPredicted++
|
seq.numPredicted++
|
||||||
|
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
|
||||||
|
seq.inputs = []*input.Input{nextToken}
|
||||||
|
nextBatchTokens[i] = nextToken
|
||||||
|
iBatches[i] = seq.iBatch
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point the seqs are ready for forwardBatch to move forward so unblock
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
|
||||||
|
activeBatch.ctx.ComputeWithNotify(
|
||||||
|
func() {
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
|
||||||
|
activeBatch.computeStartedCh <- struct{}{}
|
||||||
|
},
|
||||||
|
activeBatch.modelOutput)
|
||||||
|
logits := activeBatch.modelOutput.Floats()
|
||||||
|
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", activeBatch.id)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", activeBatch.id)
|
||||||
|
for i, seq := range s.seqs {
|
||||||
|
if seq == nil || nextBatchTokens[i] == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if seq.numPredicted == 1 {
|
if seq.numPredicted == 1 {
|
||||||
seq.startGenerationTime = time.Now()
|
seq.startGenerationTime = time.Now()
|
||||||
}
|
}
|
||||||
|
|
@ -522,36 +690,38 @@ func (s *Server) processBatch() error {
|
||||||
// if done processing the prompt, generate an embedding and return
|
// if done processing the prompt, generate an embedding and return
|
||||||
if seq.embeddingOnly {
|
if seq.embeddingOnly {
|
||||||
// TODO(jessegross): Embedding support
|
// TODO(jessegross): Embedding support
|
||||||
slog.Warn("generation of embedding outputs not yet supported")
|
slog.Warn("generation of embedding outputs not yet supported", "id", activeBatch.id, "seqIdx", i)
|
||||||
s.removeSequence(i, llm.DoneReasonStop)
|
s.removeSequence(i, llm.DoneReasonStop)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// sample a token
|
// sample a token
|
||||||
vocabSize := len(logits) / len(batch.Outputs)
|
vocabSize := len(logits) / len(activeBatch.batch.Outputs)
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(logits), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
|
||||||
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to sample token: %w", err)
|
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nextBatchTokens[i].Token = token
|
||||||
|
|
||||||
// if it's an end of sequence token, break
|
// if it's an end of sequence token, break
|
||||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||||
// TODO (jmorganca): we should send this back
|
// TODO (jmorganca): we should send this back
|
||||||
// as it's important for the /api/generate context
|
// as it's important for the /api/generate context
|
||||||
// seq.responses <- piece
|
// seq.responses <- piece
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
|
||||||
s.removeSequence(i, llm.DoneReasonStop)
|
s.removeSequence(i, llm.DoneReasonStop)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
seq.inputs = []input.Input{{Token: token}}
|
|
||||||
|
|
||||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||||
sequence := strings.Join(seq.pendingResponses, "")
|
sequence := strings.Join(seq.pendingResponses, "")
|
||||||
|
|
||||||
|
|
@ -575,6 +745,7 @@ func (s *Server) processBatch() error {
|
||||||
if tokenTruncated || origLen == newLen {
|
if tokenTruncated || origLen == newLen {
|
||||||
tokenLen--
|
tokenLen--
|
||||||
}
|
}
|
||||||
|
|
||||||
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
|
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
|
||||||
|
|
||||||
s.removeSequence(i, llm.DoneReasonStop)
|
s.removeSequence(i, llm.DoneReasonStop)
|
||||||
|
|
@ -593,8 +764,6 @@ func (s *Server) processBatch() error {
|
||||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
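A pattern worth noting in computeBatch above: everything that touches shared sequence state is snapshotted under `s.mu` (token values, `iBatch` indices, placeholder pointers), the lock is released for the expensive `Compute`/`Floats` call, and then the lock is retaken to publish the sampled tokens. A simplified sketch of that snapshot-compute-publish discipline (not the runner's code):

```go
package main

import (
	"fmt"
	"sync"
)

type seq struct{ token int32 }

type server struct {
	mu   sync.Mutex
	seqs []*seq
}

func (s *server) computeBatch() {
	// Snapshot shared state under the lock.
	s.mu.Lock()
	snapshot := make([]int32, len(s.seqs))
	for i, sq := range s.seqs {
		snapshot[i] = sq.token
	}
	s.mu.Unlock()

	// Expensive work happens without holding the lock, so other goroutines
	// (e.g. the one building the next batch) are not blocked.
	results := make([]int32, len(snapshot))
	for i, v := range snapshot {
		results[i] = v + 1 // stand-in for Compute + sampling
	}

	// Retake the lock only to publish the results.
	s.mu.Lock()
	defer s.mu.Unlock()
	for i, sq := range s.seqs {
		sq.token = results[i]
	}
}

func main() {
	s := &server{seqs: []*seq{{token: 1}, {token: 2}}}
	s.computeBatch()
	fmt.Println(s.seqs[0].token, s.seqs[1].token) // 2 3
}
```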
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
@ -736,7 +905,10 @@ func (s *Server) reserveWorstCaseGraph() error {
|
||||||
defer ctx.Close()
|
defer ctx.Close()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
inputs := make([]input.Input, s.batchSize)
|
inputs := make([]*input.Input, s.batchSize)
|
||||||
|
for i := range inputs {
|
||||||
|
inputs[i] = &input.Input{}
|
||||||
|
}
|
||||||
mmStore := newMultimodalStore()
|
mmStore := newMultimodalStore()
|
||||||
|
|
||||||
// Multimodal strategy:
|
// Multimodal strategy:
|
||||||
|
|
@ -778,8 +950,11 @@ func (s *Server) reserveWorstCaseGraph() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(inputs) < s.batchSize {
|
if len(inputs) < s.batchSize {
|
||||||
newInputs := make([]input.Input, s.batchSize)
|
newInputs := make([]*input.Input, s.batchSize)
|
||||||
copy(newInputs, inputs)
|
copy(newInputs, inputs)
|
||||||
|
for i := len(inputs); i < s.batchSize; i++ {
|
||||||
|
newInputs[i] = &input.Input{}
|
||||||
|
}
|
||||||
inputs = newInputs
|
inputs = newInputs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
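One consequence of switching to `[]*input.Input` shows up in reserveWorstCaseGraph above: `make([]*input.Input, n)` yields a slice of nil pointers, so each element has to be populated with `&input.Input{}` before the worst-case graph can dereference it. A tiny illustration of the gotcha, using a simplified stand-in type:

```go
package main

import "fmt"

type Input struct{ Token int32 }

func main() {
	inputs := make([]*Input, 4) // all elements start as nil pointers

	// Without this loop, inputs[i].Token below would panic with a nil dereference.
	for i := range inputs {
		inputs[i] = &Input{}
	}

	for _, in := range inputs {
		fmt.Print(in.Token, " ") // 0 0 0 0
	}
	fmt.Println()
}
```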
|
|
@ -842,6 +1017,7 @@ func (s *Server) allocModel(
|
||||||
// Convert memory allocation panics to errors
|
// Convert memory allocation panics to errors
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
|
debug.PrintStack()
|
||||||
if err, ok := r.(error); ok {
|
if err, ok := r.(error); ok {
|
||||||
panicErr = err
|
panicErr = err
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -1011,6 +1187,7 @@ func Execute(args []string) error {
|
||||||
server := &Server{
|
server := &Server{
|
||||||
modelPath: *mpath,
|
modelPath: *mpath,
|
||||||
status: llm.ServerStatusLaunched,
|
status: llm.ServerStatusLaunched,
|
||||||
|
hardErrCh: make(chan error, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
server.cond = sync.NewCond(&server.mu)
|
server.cond = sync.NewCond(&server.mu)
|
||||||
|
|
|
||||||