diff --git a/docs/gpu.mdx b/docs/gpu.mdx
index 36bfd3da7..da04ce240 100644
--- a/docs/gpu.mdx
+++ b/docs/gpu.mdx
@@ -163,4 +163,34 @@ To select specific Vulkan GPU(s), you can set the environment variable
 `GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
 described in the [FAQ](faq.md#how-do-i-configure-ollama-server).
 If you encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
-by setting `GGML_VK_VISIBLE_DEVICES=-1`
\ No newline at end of file
+by setting `GGML_VK_VISIBLE_DEVICES=-1`
+
+
+## Advanced: Manually overriding multi-GPU layer split
+
+By default, Ollama decides how many layers to offload to each GPU based on free
+memory and other heuristics. For some large models and specific hardware mixes you may
+prefer a manual split.
+
+Set the environment variable `OLLAMA_OVERRIDE_CONFIG` to the path of an INI file (or place one at
+`~/.ollama.ini`) and add a section named after the model’s short name with a `tensor-split`
+value:
+
+```ini
+[llama3:70b]
+tensor-split=18,21,21,21
+```
+
+The list represents proportions for each visible GPU, in order. The sum of
+the values determines how many of the last layers are offloaded to GPU VRAM
+(`n_gpu_layers = sum(tensor-split)`). The assignment is proportional across devices:
+larger numbers get more of those last layers.
+
+**Constraints**
+
+- The number of values must be ≤ the number of visible GPUs; if fewer values than GPUs are given, the remaining GPUs receive no layers.
+- All values must be non-negative integers.
+- If the values are invalid or no section exists for the model, Ollama falls back to its heuristics.
+
+This feature is intended for expert tuning when the automatic split under-utilizes
+your GPUs for a given model/context configuration.
\ No newline at end of file
diff --git a/envconfig/config.go b/envconfig/config.go
index 238e5e6e1..3eec78f43 100644
--- a/envconfig/config.go
+++ b/envconfig/config.go
@@ -206,6 +206,9 @@ var (
 	UseAuth = Bool("OLLAMA_AUTH")
 	// Enable Vulkan backend
 	EnableVulkan = Bool("OLLAMA_VULKAN")
+	// Optional path to per-model override file (INI).
+	// If unset, defaults to ~/.ollama.ini
+	OverrideConfigPath = String("OLLAMA_OVERRIDE_CONFIG")
 )
 
 func String(s string) func() string {
@@ -293,6 +296,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
 		"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
 		"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
+		"OLLAMA_OVERRIDE_CONFIG": {"OLLAMA_OVERRIDE_CONFIG", OverrideConfigPath(), "Path to model override config (default: ~/.ollama.ini)"},
 
 		// Informational
 		"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
diff --git a/envconfig/override.go b/envconfig/override.go
new file mode 100644
index 000000000..71f10168a
--- /dev/null
+++ b/envconfig/override.go
@@ -0,0 +1,116 @@
+package envconfig
+
+import (
+	"bufio"
+	"os"
+	"path/filepath"
+	"strings"
+)
+
+type Override struct {
+	ModelName    string
+	NumGPULayers int   // -1 means unset
+	TensorSplit  []int // nil means unset
+}
+
+// LoadOverride loads overrides for the given model section name (e.g. "llama3.2-vision:90b").
+// The INI format is:
+//	[model-name:params]
+//	tensor-split=
+// Note: n-gpu-layers is not read from the file; it is always derived as the sum of tensor-split.
+// Returns nil if no file or no matching section.
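+//
+// Example file contents (illustrative; matches the docs/gpu.mdx example):
+//
+//	[llama3:70b]
+//	tensor-split=18,21,21,21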
+func LoadOverride(model string) *Override { + // Resolve config path + path := OverrideConfigPath() + if path == "" { + home, _ := os.UserHomeDir() + if home == "" { + return nil + } + path = filepath.Join(home, ".ollama.ini") + } + f, err := os.Open(path) + if err != nil { + return nil + } + defer f.Close() + + sectionHdr := "[" + model + "]" + var inSection bool + ovr := &Override{ModelName: model, NumGPULayers: -1} + + sc := bufio.NewScanner(f) + for sc.Scan() { + line := strings.TrimSpace(sc.Text()) + if line == "" || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") { + continue + } + if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { + inSection = (line == sectionHdr) + continue + } + if !inSection { + continue + } + kv := strings.SplitN(line, "=", 2) + if len(kv) != 2 { + continue + } + k := strings.TrimSpace(strings.ToLower(kv[0])) + v := strings.TrimSpace(kv[1]) + + switch k { + case "tensor-split": + if arr := parseUintList(v); len(arr) > 0 { + ovr.TensorSplit = arr + } + } + } + + // If a tensor-split is provided, NumGPULayers is always the sum of entries. + if len(ovr.TensorSplit) > 0 { + total := 0 + for _, n := range ovr.TensorSplit { + total += n + } + ovr.NumGPULayers = total + } + + // If nothing set, return nil + if ovr.NumGPULayers < 0 && len(ovr.TensorSplit) == 0 { + return nil + } + return ovr +} + +func parseUint(s string) int { + s = strings.TrimSpace(s) + if s == "" { + return -1 + } + var n int + for _, r := range s { + if r < '0' || r > '9' { + return -1 + } + n = n*10 + int(r-'0') + } + return n +} + +func parseUintList(s string) []int { + s = strings.TrimSpace(s) + if s == "" { + return nil + } + parts := strings.Split(s, ",") + out := make([]int, 0, len(parts)) + for _, p := range parts { + n := parseUint(strings.TrimSpace(p)) + if n < 0 { + return nil + } + out = append(out, n) + } + return out +} diff --git a/envconfig/override_test.go b/envconfig/override_test.go new file mode 100644 index 000000000..a8f124020 --- /dev/null +++ b/envconfig/override_test.go @@ -0,0 +1,140 @@ +package envconfig + +import ( + "os" + "path/filepath" + "testing" +) + +func TestParseUint(t *testing.T) { + tests := []struct { + in string + want int + }{ + {"0", 0}, + {"1", 1}, + {"123", 123}, + {"", -1}, + {" 42 ", 42}, + {"-1", -1}, + {"abc", -1}, + {"12x", -1}, + } + for _, tt := range tests { + if got := parseUint(tt.in); got != tt.want { + t.Fatalf("parseUint(%q) = %d; want %d", tt.in, got, tt.want) + } + } +} + +func TestParseUintList(t *testing.T) { + tests := []struct { + in string + want []int + }{ + {"", nil}, + {"1", []int{1}}, + {"1,2,3", []int{1, 2, 3}}, + {" 4 , 5 , 6 ", []int{4, 5, 6}}, + {"1, -2", nil}, // invalid -> whole list rejected + {"a,b", nil}, + } + for _, tt := range tests { + got := parseUintList(tt.in) + if (got == nil) != (tt.want == nil) { + t.Fatalf("parseUintList(%q) = %v; want %v", tt.in, got, tt.want) + } + if got == nil { + continue + } + if len(got) != len(tt.want) { + t.Fatalf("parseUintList(%q) len=%d; want %d", tt.in, len(got), len(tt.want)) + } + for i := range got { + if got[i] != tt.want[i] { + t.Fatalf("parseUintList(%q)[%d]=%d; want %d", tt.in, i, got[i], tt.want[i]) + } + } + } +} + +func TestLoadOverride_Basic(t *testing.T) { + dir := t.TempDir() + cfg := filepath.Join(dir, "over.ini") + content := ` +; comment +[llama3.2-vision:90b] +n-gpu-layers=33 +tensor-split=10,20,30 + +[other] +n-gpu-layers=1 +` + if err := os.WriteFile(cfg, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + 
t.Setenv("OLLAMA_OVERRIDE_CONFIG", cfg) + + ovr := LoadOverride("llama3.2-vision:90b") + if ovr == nil { + t.Fatalf("LoadOverride returned nil") + } + if ovr.ModelName != "llama3.2-vision:90b" { + t.Fatalf("ModelName=%q; want %q", ovr.ModelName, "llama3.2-vision:90b") + } + // n-gpu-layers must be the sum of tensor-split entries (10+20+30=60). + if ovr.NumGPULayers != 60 { + t.Fatalf("NumGPULayers=%d; want %d", ovr.NumGPULayers, 60) + } + wantSplit := []int{10, 20, 30} + if len(ovr.TensorSplit) != len(wantSplit) { + t.Fatalf("TensorSplit len=%d; want %d", len(ovr.TensorSplit), len(wantSplit)) + } + for i := range wantSplit { + if ovr.TensorSplit[i] != wantSplit[i] { + t.Fatalf("TensorSplit[%d]=%d; want %d", i, ovr.TensorSplit[i], wantSplit[i]) + } + } +} + +func TestLoadOverride_NoMatchOrEmpty(t *testing.T) { + dir := t.TempDir() + cfg := filepath.Join(dir, "over.ini") + content := ` +[some-model] +n-gpu-layers=7 +` + if err := os.WriteFile(cfg, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("OLLAMA_OVERRIDE_CONFIG", cfg) + + // Section exists but different model -> nil + if got := LoadOverride("another-model"); got != nil { + t.Fatalf("expected nil for unmatched section, got %#v", got) + } + + // File missing -> nil + t.Setenv("OLLAMA_OVERRIDE_CONFIG", filepath.Join(dir, "missing.ini")) + if got := LoadOverride("some-model"); got != nil { + t.Fatalf("expected nil for missing file, got %#v", got) + } +} + +func TestLoadOverride_BadValuesIgnored(t *testing.T) { + dir := t.TempDir() + cfg := filepath.Join(dir, "over.ini") + content := ` +[m] +n-gpu-layers=abc +tensor-split=1,2,x +` + if err := os.WriteFile(cfg, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("OLLAMA_OVERRIDE_CONFIG", cfg) + + if got := LoadOverride("m"); got != nil { + t.Fatalf("expected nil when no valid keys parsed, got %#v", got) + } +} diff --git a/llm/server.go b/llm/server.go index a89027b06..e72f55b0b 100644 --- a/llm/server.go +++ b/llm/server.go @@ -103,7 +103,8 @@ type llmServer struct { loadStart time.Time // Record how long it took the model to load loadProgress float32 - sem *semaphore.Weighted + sem *semaphore.Weighted + override *envconfig.Override } type llamaServer struct { @@ -118,6 +119,80 @@ type ollamaServer struct { textProcessor model.TextProcessor // textProcessor handles text encoding/decoding } +// buildGPULayersFromOverride constructs an explicit ml.GPULayersList from a +// per-model override configuration. +// +// It takes: +// - totalLayers: the total number of layers in the model (i.e. block_count+1, +// including the output layer). +// - gpus: the visible GPUs, whose order is used to map tensor-split entries to +// device IDs (tensor-split index i -> gpus[i]). +// - override: the parsed override which provides: +// * NumGPULayers: how many of the last model layers to offload, and +// * TensorSplit: integer weights describing how to distribute those layers +// across the GPUs (proportional split). +// +// The function assigns the last NumGPULayers layers in the range +// [blocks-NumGPULayers, blocks] to GPUs according to the cumulative proportions +// derived from TensorSplit. If TensorSplit has more entries than visible GPUs, +// any required value is non-positive, the proportional total is zero, or the +// computed span is empty, the function returns nil to signal "no override". +// +// On success it returns a non-empty GPULayersList; otherwise it returns nil. 
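+//
+// For example, with totalLayers=81 (blocks=80), NumGPULayers=81, and
+// TensorSplit=[18,21,21,21], GPU 0 receives roughly the first 18 offloaded
+// layers and GPUs 1-3 roughly 21 each; boundaries can shift by a layer due
+// to floating-point rounding of the cumulative proportions.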
+// +func buildGPULayersFromOverride(totalLayers int, gpus []ml.DeviceInfo, override *envconfig.Override) ml.GPULayersList { + if totalLayers <= 0 || len(gpus) == 0 || override == nil { + return nil + } + if len(override.TensorSplit) > len(gpus) { + return nil + } + // cumulative proportions + var total int + for _, v := range override.TensorSplit { + total += v + } + if total <= 0 { + return nil + } + cum := make([]float32, len(override.TensorSplit)) + var run float32 + for i, v := range override.TensorSplit { + run += float32(v) / float32(total) + cum[i] = run + } + + // totalLayers = blocks + 1 + blocks := totalLayers - 1 + start := max(0, blocks-override.NumGPULayers) + stop := min(start+override.NumGPULayers, blocks+1) + + gl := make(ml.GPULayersList, len(gpus)) + for i := range gpus { + gl[i].DeviceID = gpus[i].DeviceID + } + + span := float32(stop - start) + if span <= 0 { + return nil + } + for layer := start; layer < stop; layer++ { + ratio := float32(layer-start) / span + idx := 0 + for i := range cum { + if ratio < cum[i] { + idx = i + break + } + } + gl[idx].Layers = append(gl[idx].Layers, layer) + } + if gl.Sum() == 0 { + return nil + } + return gl +} + // LoadModel will load a model from disk. The model must be in the GGML format. // // It collects array values for arrays with a size less than or equal to @@ -139,7 +214,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) { } // NewLlamaServer will run a server for the given GPUs -func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) { +func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int, override *envconfig.Override) (LlamaServer, error) { var llamaModel *llama.Model var textProcessor model.TextProcessor var err error @@ -280,6 +355,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st totalLayers: f.KV().BlockCount() + 1, loadStart: time.Now(), done: make(chan error, 1), + override: override, } if err != nil { @@ -492,6 +568,75 @@ type LoadResponse struct { var ErrLoadRequiredFull = errors.New("unable to load full model on GPU") +// maybeApplyOverride attempts to replace a heuristic GPU layer layout with a +// per-model override read from OLLAMA_OVERRIDE_CONFIG. +// +// Inputs: +// - gpus: the set of visible GPUs (their order is used to map tensor-split +// entries onto devices: tensor-split index i -> gpus[i]). +// - gpuLayers: the current, heuristic ml.GPULayersList that would be used if +// no override is applied. +// +// Behavior: +// * If no override is configured, or it is incomplete (missing NumGPULayers +// or TensorSplit), the function returns the original gpuLayers and false. +// * If TensorSplit has more entries than visible GPUs, or NumGPULayers +// exceeds the model's total layers (block_count+1), the override is ignored +// and the function returns the original gpuLayers and false (with a log +// warning). +// * Otherwise, it builds a replacement assignment via buildGPULayersFromOverride. +// On success, it logs the application, updates s.options.NumGPU to match +// the override's NumGPULayers (so downstream logging/heuristics see a +// consistent value), and returns (override, true). If the mapping produces +// no layers, the heuristic layout is kept and false is returned. 
+// +func (s *llmServer) maybeApplyOverride(gpus []ml.DeviceInfo, gpuLayers ml.GPULayersList) (ml.GPULayersList, bool) { + // If no override loaded, or incomplete, bail out + if s.override == nil || s.override.NumGPULayers <= 0 || len(s.override.TensorSplit) == 0 { + return gpuLayers, false + } + + // Too many split entries for visible GPUs? Warn and fallback. + if len(s.override.TensorSplit) > len(gpus) { + slog.Warn( + "Override ignored: tensor-split override has more entries than visible GPUs; using heuristic split instead", + "model", s.override.ModelName, + "tensor_split_entries", len(s.override.TensorSplit), + "visible_gpus", len(gpus), + ) + return gpuLayers, false + } + + // Clamp to model size (totalLayers == block_count + 1) + maxLayers := int(s.totalLayers) + if s.override.NumGPULayers > maxLayers { + slog.Warn( + "Override ignored: n_gpu_layers is larger than the maximum supported; using heuristic split instead", + "model", s.override.ModelName, + "max_layers", maxLayers, + "n_gpu_layers", s.override.NumGPULayers, + ) + return gpuLayers, false + } + + override := buildGPULayersFromOverride(int(s.totalLayers), gpus, s.override) + if override == nil || override.Sum() == 0 { + slog.Warn("Override ignored: override mapping produced no layers; using heuristic layout instead") + return gpuLayers, false + } + + slog.Info( + "Applying override from OLLAMA_OVERRIDE_CONFIG", + "model", s.override.ModelName, + "n_gpu_layers", s.override.NumGPULayers, + "tensor_split", s.override.TensorSplit, + "layers_offloaded", override.Sum(), + ) + // Align NumGPU with override for downstream logging / heuristics that read it + s.options.NumGPU = s.override.NumGPULayers + return override, true +} + func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) { slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU) @@ -626,6 +771,11 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system } } + // Apply per-model override + if newLayers, ok := s.maybeApplyOverride(gpus, gpuLayers); ok { + gpuLayers = newLayers + } + // This maintains the historical assignment of graph sizes, though it isn't fully accurate graphSize := graphFullOffload if gpuLayers.Sum() < int(s.totalLayers) { @@ -761,6 +911,11 @@ nextOperation: for operation := LoadOperationFit; operation < LoadOperationCommit; operation++ { nextLoad: for { + // Apply per-model override if present + if newLayers, ok := s.maybeApplyOverride(gpus, gpuLayers); ok { + gpuLayers = newLayers + } + s.loadRequest.GPULayers = gpuLayers resp, err := s.initModel(ctx, s.loadRequest, operation) if err != nil { diff --git a/llm/server_test.go b/llm/server_test.go index 5dc0aa9bc..b121a2e69 100644 --- a/llm/server_test.go +++ b/llm/server_test.go @@ -9,6 +9,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/ml" "golang.org/x/sync/semaphore" ) @@ -279,3 +280,125 @@ func TestLLMServerCompletionFormat(t *testing.T) { }, nil) checkValid(err) } + +func TestBuildGPULayersFromOverride_Basic(t *testing.T) { + // totalLayers = blocks + 1. With totalLayers=5 -> blocks in [0..4]. 
+ totalLayers := 5 + gpus := []ml.DeviceInfo{ + {DeviceID: ml.DeviceID{ID: "gpu0"}}, + {DeviceID: ml.DeviceID{ID: "gpu1"}}, + } + ov := &envconfig.Override{ + ModelName: "dummy", + NumGPULayers: 4, // assign last 4 layers: indices 0..3 with our simplified test + TensorSplit: []int{1,1}, // even split across 2 GPUs + } + + gl := buildGPULayersFromOverride(totalLayers, gpus, ov) + if gl == nil || gl.Sum() == 0 { + t.Fatalf("expected non-empty GPULayersList, got %#v", gl) + } + + // Expect gpu0 to get first half (layers 0,1) and gpu1 to get (2,3) + want := ml.GPULayersList{ + {DeviceID: gpus[0].DeviceID, Layers: []int{0, 1}}, + {DeviceID: gpus[1].DeviceID, Layers: []int{2, 3}}, + } + if gl.Hash() != want.Hash() { + t.Errorf("override mapping = %v, want %v", gl, want) + } +} + +func TestBuildGPULayersFromOverride_TooManySplits(t *testing.T) { + totalLayers := 5 + gpus := []ml.DeviceInfo{ + {DeviceID: ml.DeviceID{ID: "gpu0"}}, + {DeviceID: ml.DeviceID{ID: "gpu1"}}, + } + ov := &envconfig.Override{ + ModelName: "dummy", + NumGPULayers: 4, + TensorSplit: []int{1, 1, 1}, // 3 entries, only 2 GPUs + } + gl := buildGPULayersFromOverride(totalLayers, gpus, ov) + if gl != nil { + t.Fatalf("expected nil due to too many tensor-split entries, got %v", gl) + } +} + +func TestBuildGPULayersFromOverride_ZeroTotalSplit(t *testing.T) { + totalLayers := 5 + gpus := []ml.DeviceInfo{ + {DeviceID: ml.DeviceID{ID: "gpu0"}}, + {DeviceID: ml.DeviceID{ID: "gpu1"}}, + } + ov := &envconfig.Override{ + ModelName: "dummy", + NumGPULayers: 4, + TensorSplit: []int{0, 0}, // totals to zero + } + gl := buildGPULayersFromOverride(totalLayers, gpus, ov) + if gl != nil { + t.Fatalf("expected nil due to zero/invalid tensor-split total, got %v", gl) + } +} + +func TestMaybeApplyOverride_Applies(t *testing.T) { + // Model with 5 total layers (blocks 0..4). 
+ s := &llmServer{ + totalLayers: 5, + options: api.Options{}, + override: &envconfig.Override{ + ModelName: "dummy", + NumGPULayers: 4, + TensorSplit: []int{1, 1}, + }, + } + gpus := []ml.DeviceInfo{ + {DeviceID: ml.DeviceID{ID: "gpu0"}}, + {DeviceID: ml.DeviceID{ID: "gpu1"}}, + } + // Heuristic layout (will be replaced) + heuristic := ml.GPULayersList{ + {DeviceID: gpus[1].DeviceID, Layers: []int{0, 1}}, + } + got, ok := s.maybeApplyOverride(gpus, heuristic) + if !ok { + t.Fatalf("expected override to be applied") + } + // Expect override mapping (even split) + want := ml.GPULayersList{ + {DeviceID: gpus[0].DeviceID, Layers: []int{0, 1}}, + {DeviceID: gpus[1].DeviceID, Layers: []int{2, 3}}, + } + if got.Hash() != want.Hash() { + t.Errorf("maybeApplyOverride = %v, want %v", got, want) + } + // options.NumGPU should align with override.NumGPULayers + if s.options.NumGPU != s.override.NumGPULayers { + t.Errorf("options.NumGPU = %d, want %d", s.options.NumGPU, s.override.NumGPULayers) + } +} + +func TestMaybeApplyOverride_RejectsTooManySplits(t *testing.T) { + s := &llmServer{ + totalLayers: 5, + options: api.Options{}, + override: &envconfig.Override{ + ModelName: "dummy", + NumGPULayers: 4, + TensorSplit: []int{1, 1, 1}, // 3 entries, 2 GPUs -> reject + }, + } + gpus := []ml.DeviceInfo{ + {DeviceID: ml.DeviceID{ID: "gpu0"}}, + {DeviceID: ml.DeviceID{ID: "gpu1"}}, + } + heuristic := ml.GPULayersList{ + {DeviceID: gpus[1].DeviceID, Layers: []int{0, 1}}, + } + got, ok := s.maybeApplyOverride(gpus, heuristic) + if ok || got.Hash() != heuristic.Hash() { + t.Fatalf("expected override to be ignored and heuristic preserved; got=%v ok=%v", got, ok) + } +} \ No newline at end of file diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 13befff2a..95da5bf38 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -17,6 +17,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/ml" @@ -48,8 +49,8 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error return } -func newMockServer(mock *mockRunner) func(ml.SystemInfo, []ml.DeviceInfo, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) { - return func(_ ml.SystemInfo, _ []ml.DeviceInfo, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) { +func newMockServer(mock *mockRunner) func(ml.SystemInfo, []ml.DeviceInfo, string, *ggml.GGML, []string, []string, api.Options, int, *envconfig.Override) (llm.LlamaServer, error) { + return func(_ ml.SystemInfo, _ []ml.DeviceInfo, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int, _ *envconfig.Override) (llm.LlamaServer, error) { return mock, nil } } diff --git a/server/sched.go b/server/sched.go index c5bc6692d..3ffb8384d 100644 --- a/server/sched.go +++ b/server/sched.go @@ -50,7 +50,7 @@ type Scheduler struct { loaded map[string]*runnerRef loadFn func(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool - newServerFn func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) + newServerFn func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel 
int, override *envconfig.Override) (llm.LlamaServer, error) getGpuFn func(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.DeviceInfo getSystemInfoFn func() ml.SystemInfo waitForRecovery time.Duration @@ -414,7 +414,9 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo if llama == nil { var err error - llama, err = s.newServerFn(systemInfo, gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel) + // Load per-model override (by short name) + override := envconfig.LoadOverride(req.model.ShortName) + llama, err = s.newServerFn(systemInfo, gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel, override) if err != nil { // some older models are not compatible with newer versions of llama.cpp // show a generalized compatibility error until there is a better way to diff --git a/server/sched_test.go b/server/sched_test.go index 480aafa4e..832ca47af 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/ollama/ollama/envconfig" "github.com/stretchr/testify/require" "github.com/ollama/ollama/api" @@ -49,7 +50,7 @@ func TestSchedLoad(t *testing.T) { sessionDuration: &api.Duration{Duration: 2 * time.Second}, } // Fail to load model first - s.newServerFn = func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { + s.newServerFn = func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int, override *envconfig.Override) (llm.LlamaServer, error) { return nil, errors.New("something failed to load model blah") } gpus := []ml.DeviceInfo{} @@ -64,7 +65,7 @@ func TestSchedLoad(t *testing.T) { require.Contains(t, err.Error(), "this model may be incompatible") server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}} - s.newServerFn = func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { + s.newServerFn = func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int, override *envconfig.Override) (llm.LlamaServer, error) { server.modelPath = model return server, nil } @@ -106,7 +107,7 @@ type reqBundle struct { f *ggml.GGML } -func (scenario *reqBundle) newServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { +func (scenario *reqBundle) newServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int, override *envconfig.Override) (llm.LlamaServer, error) { scenario.srv.modelPath = model return scenario.srv, nil } @@ -466,7 +467,7 @@ func TestSchedExpireRunner(t *testing.T) { gpus := []ml.DeviceInfo{} systemInfo := ml.SystemInfo{} server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}} - s.newServerFn = func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { + s.newServerFn = func(systemInfo 
ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int, override *envconfig.Override) (llm.LlamaServer, error) { server.modelPath = model return server, nil }
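
For reference, the proportional assignment is easier to follow with a concrete walk-through. The standalone sketch below is not part of the patch; it simply mirrors the cumulative-share loop in `buildGPULayersFromOverride` using the `tensor-split=18,21,21,21` example from the docs hunk (a four-GPU machine is assumed) and prints how many of the offloaded layers each GPU would receive.

```go
// Standalone illustration of the proportional tensor-split mapping.
package main

import "fmt"

func main() {
	// tensor-split from the docs example; four visible GPUs are assumed.
	split := []int{18, 21, 21, 21}

	total := 0
	for _, v := range split {
		total += v // n_gpu_layers = sum(tensor-split)
	}

	// Cumulative proportions, mirroring buildGPULayersFromOverride.
	cum := make([]float32, len(split))
	var run float32
	for i, v := range split {
		run += float32(v) / float32(total)
		cum[i] = run
	}

	// Each offloaded layer goes to the first GPU whose cumulative share
	// exceeds the layer's relative position in the offloaded span.
	counts := make([]int, len(split))
	for layer := 0; layer < total; layer++ {
		ratio := float32(layer) / float32(total)
		for i := range cum {
			if ratio < cum[i] {
				counts[i]++
				break
			}
		}
	}

	fmt.Println("layers per GPU:", counts) // approximately [18 21 21 21]
}
```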