diff --git a/convert/convert.go b/convert/convert.go
index 3e98eee1a..f6afd8a32 100644
--- a/convert/convert.go
+++ b/convert/convert.go
@@ -206,6 +206,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
 		conv = &commandrModel{}
 	case "GptOssForCausalLM":
 		conv = &gptossModel{}
+	case "DeepseekOCRForCausalLM":
+		conv = &deepseekocr{}
 	default:
 		return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
 	}
diff --git a/convert/convert_deepseekocr.go b/convert/convert_deepseekocr.go
new file mode 100644
index 000000000..cf1dfa0c4
--- /dev/null
+++ b/convert/convert_deepseekocr.go
@@ -0,0 +1,136 @@
+package convert
+
+import (
+	"fmt"
+
+	"github.com/ollama/ollama/fs/ggml"
+)
+
+type deepseekocr struct {
+	ModelParameters
+	LanguageConfig struct {
+		MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
+		HiddenSize            uint32 `json:"hidden_size"`
+		HiddenLayers          uint32 `json:"num_hidden_layers"`
+		IntermediateSize      uint32 `json:"intermediate_size"`
+		NumAttentionHeads     uint32 `json:"num_attention_heads"`
+		NumKeyValueHeads      uint32 `json:"num_key_value_heads"`
+		NumRoutedExperts      uint32 `json:"n_routed_experts"`
+		NumSharedExperts      uint32 `json:"n_shared_experts"`
+		NumExpertsPerToken    uint32 `json:"num_experts_per_tok"`
+		FirstKDenseReplace    uint32 `json:"first_k_dense_replace"`
+	} `json:"language_config"`
+
+	VisionConfig struct {
+		ImageSize uint32 `json:"image_size"`
+		Width     struct {
+			Vision struct {
+				Heads     uint32 `json:"heads"`
+				ImageSize uint32 `json:"image_size"`
+				Layers    uint32 `json:"layers"`
+				PatchSize uint32 `json:"patch_size"`
+				Width     uint32 `json:"width"`
+			} `json:"clip-l-14-224"`
+			Sam struct {
+				GlobalAttentionIndexes []int32 `json:"global_attn_indexes"`
+				Heads                  uint32  `json:"heads"`
+				Layers                 uint32  `json:"layers"`
+				Width                  uint32  `json:"width"`
+			} `json:"sam_vit_b"`
+		}
+	} `json:"vision_config"`
+}
+
+func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
+	kv := m.ModelParameters.KV(t)
+	kv["general.architecture"] = "deepseekocr"
+	kv["block_count"] = m.LanguageConfig.HiddenLayers
+	kv["context_length"] = m.LanguageConfig.MaxPositionEmbeddings
+	kv["embedding_length"] = m.LanguageConfig.HiddenSize
+	kv["feed_forward_length"] = m.LanguageConfig.IntermediateSize
+	kv["attention.head_count"] = m.LanguageConfig.NumAttentionHeads
+	kv["attention.head_count_kv"] = m.LanguageConfig.NumKeyValueHeads
+	kv["expert_count"] = m.LanguageConfig.NumRoutedExperts
+	kv["expert_used_count"] = m.LanguageConfig.NumExpertsPerToken
+	kv["leading_dense_block_count"] = m.LanguageConfig.FirstKDenseReplace
+
+	kv["vision.block_count"] = m.VisionConfig.Width.Vision.Layers
+	kv["vision.embedding_length"] = m.VisionConfig.Width.Vision.Width
+	kv["vision.head_count"] = m.VisionConfig.Width.Vision.Heads
+	kv["vision.image_size"] = m.VisionConfig.Width.Vision.ImageSize
+	kv["vision.patch_size"] = m.VisionConfig.Width.Vision.PatchSize
+
+	kv["sam.block_count"] = m.VisionConfig.Width.Sam.Layers
+	kv["sam.embedding_length"] = m.VisionConfig.Width.Sam.Width
+	kv["sam.head_count"] = m.VisionConfig.Width.Sam.Heads
+	kv["sam.global_attention_indexes"] = m.VisionConfig.Width.Sam.GlobalAttentionIndexes
+	return kv
+}
+
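+// Tensors stacks the per-expert feed-forward weights: each layer's
+// mlp.experts.{i}.{gate,up,down}_proj tensors (matched by the * wildcard)
+// are merged into single ffn_{gate,up,down}_exps tensors, and everything
+// else is passed through unchanged.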
+func (m *deepseekocr) Tensors(s []Tensor) (out []*ggml.Tensor) {
+	merges := make([]merge, m.LanguageConfig.HiddenLayers*3)
+	for i := range m.LanguageConfig.HiddenLayers {
+		merges[i*3+0] = merge{
+			fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
+			fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
+		}
+		merges[i*3+1] = merge{
+			fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
+			fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
+		}
+		merges[i*3+2] = merge{
+			fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
+			fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
+		}
+	}
+
+	out, s = mergeTensors(s, merges...)
+	for _, t := range s {
+		out = append(out, &ggml.Tensor{
+			Name:     t.Name(),
+			Kind:     t.Kind(),
+			Shape:    t.Shape(),
+			WriterTo: t,
+		})
+	}
+	return out
+}
+
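+// Replacements maps Hugging Face tensor name fragments to their GGUF
+// equivalents: the CLIP tower gets the v. prefix, the SAM tower the s.
+// prefix, and the projector the mm. prefix used by the runtime.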
+func (m *deepseekocr) Replacements() []string {
+	return []string{
+		"model.embed_tokens", "token_embd",
+		"model.layers", "blk",
+		"input_layernorm", "attn_norm",
+		"self_attn.q_proj", "attn_q",
+		"self_attn.k_proj", "attn_k",
+		"self_attn.v_proj", "attn_v",
+		"self_attn.o_proj", "attn_output",
+		"post_attention_layernorm", "ffn_norm",
+		"mlp.gate_proj", "ffn_gate",
+		"mlp.up_proj", "ffn_up",
+		"mlp.down_proj", "ffn_down",
+		"mlp.gate", "ffn_gate_inp",
+		"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
+		"mlp.shared_experts.up_proj", "ffn_up_shexp",
+		"mlp.shared_experts.down_proj", "ffn_down_shexp",
+		"model.norm", "output_norm",
+		"lm_head", "output",
+
+		"model.vision_model", "v",
+		"embeddings.patch_embedding", "patch_embd",
+		"embeddings.class_embedding", "class_embd",
+		"embeddings.position_embedding", "position_embd",
+		"transformer.layers", "blk",
+
+		"model.projector", "mm",
+		"model.image_newline", "mm.image_newline",
+		//nolint:misspell // this misspelling is upstream. fixing it breaks the model
+		"model.view_seperator", "mm.view_seperator",
+
+		"model.sam_model.patch_embed.proj", "s.patch_embd",
+		"model.sam_model.pos_embed", "s.position_embd",
+		"model.sam_model.blocks", "s.blk",
+		"model.sam_model.neck", "s.neck",
+		"model.sam_model.net_", "s.net_",
+	}
+}
diff --git a/convert/reader.go b/convert/reader.go
index b3f7a8660..75764f018 100644
--- a/convert/reader.go
+++ b/convert/reader.go
@@ -44,7 +44,10 @@ func (t tensorBase) Kind() uint32 {
 		t.name == "v.positional_embedding_vlm" ||
 		t.name == "v.tile_position_embd.weight" ||
 		t.name == "v.pre_tile_position_embd.weight" ||
-		t.name == "v.post_tile_position_embd.weight" {
+		t.name == "v.post_tile_position_embd.weight" ||
+		t.name == "s.position_embd" ||
+		strings.HasSuffix(t.name, "rel_pos_h") ||
+		strings.HasSuffix(t.name, "rel_pos_w") {
 		// these tensors are always F32
 		return tensorKindFP32
 	}
diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go
index eea0de2f5..f7d9754f0 100644
--- a/convert/reader_safetensors.go
+++ b/convert/reader_safetensors.go
@@ -96,7 +96,10 @@ type safetensor struct {
 
 func (st safetensor) Kind() uint32 {
 	kind := st.tensorBase.Kind()
-	if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
+	if st.dtype == "BF16" &&
+		!strings.HasPrefix(st.name, "v.") &&
+		!strings.HasPrefix(st.name, "s.") &&
+		kind != tensorKindFP32 {
 		kind = tensorKindBF16
 	}
diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index 0b5d37a7b..205279c67 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -249,6 +249,7 @@ func (kv KV) OllamaEngineRequired() bool {
 		"qwen25vl",
 		"qwen3", "qwen3moe",
 		"qwen3vl", "qwen3vlmoe",
+		"deepseekocr",
 	}, kv.Architecture())
 }
diff --git a/ml/backend.go b/ml/backend.go
index 99d6b146e..bf2c5851f 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -173,6 +173,7 @@ type Tensor interface {
 	Cos(ctx Context) Tensor
 	Tanh(ctx Context) Tensor
 	GELU(ctx Context, up ...Tensor) Tensor
+	QuickGELU(ctx Context, up ...Tensor) Tensor
 	SILU(ctx Context, up ...Tensor) Tensor
 	RELU(ctx Context, up ...Tensor) Tensor
 	Sigmoid(ctx Context) Tensor
@@ -207,6 +208,8 @@ type Tensor interface {
 	Stddev(ctx Context) Tensor
 	Sqr(ctx Context) Tensor
 	Sqrt(ctx Context) Tensor
+
+	Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
 }
 
 // ScaledDotProductAttention implements a fused attention
@@ -372,3 +375,10 @@ const (
 	DTypeI32
 	DTypeMXFP4
 )
+
+type SamplingMode int
+
+const (
+	SamplingModeNearest SamplingMode = iota
+	SamplingModeBilinear
+)
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 1413e044f..1d457fc4f 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -314,7 +314,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 			"altup_proj", "altup_unembd_proj",
 			"per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"):
 			createTensor(tensor{source: t}, output.bts, blocks)
-		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
+		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm.") || strings.HasPrefix(t.Name, "s."):
 			// TODO: assign vision tensors to the gpu if possible
 			createTensor(tensor{source: t}, output.bts, blocks)
 		case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
@@ -1567,6 +1567,16 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
 	}
 }
 
+func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	var tt *C.struct_ggml_tensor
+	if len(t2) > 0 {
+		tt = C.ggml_geglu_quick_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t)
+	} else {
+		tt = C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t)
+	}
+	return &Tensor{b: t.b, t: tt}
+}
+
 func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
 	if len(t2) > 0 {
 		return &Tensor{
@@ -1724,6 +1734,23 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
 	}
 }
 
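+// Interpolate wraps ggml_interpolate: it resamples t to exactly the target
+// dims (ne0..ne3, innermost first) with nearest or bilinear sampling, e.g.
+// to rescale position embeddings to a new grid size.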
+func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
+	var mode C.uint32_t
+	switch samplingMode {
+	case ml.SamplingModeNearest:
+		mode = C.GGML_SCALE_MODE_NEAREST
+	case ml.SamplingModeBilinear:
+		mode = C.GGML_SCALE_MODE_BILINEAR
+	default:
+		panic("unsupported interpolate mode")
+	}
+
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_interpolate(ctx.(*Context).ctx, t.t, C.int64_t(dims[0]), C.int64_t(dims[1]), C.int64_t(dims[2]), C.int64_t(dims[3]), mode),
+	}
+}
+
 // Slice returns a view of the tensor sliced along dim from low to high in step steps.
 // Slice panics if the dimension is invalid or the slice parameters are out of range.
 // If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.
diff --git a/model/imageproc/images.go b/model/imageproc/images.go
index 7afe36701..2cfea762f 100644
--- a/model/imageproc/images.go
+++ b/model/imageproc/images.go
@@ -25,12 +25,15 @@ const (
 
 // Composite returns an image with the alpha channel removed by drawing over a white background.
 func Composite(img image.Image) image.Image {
-	dst := image.NewRGBA(img.Bounds())
-	white := color.RGBA{255, 255, 255, 255}
-	draw.Draw(dst, dst.Bounds(), &image.Uniform{white}, image.Point{}, draw.Src)
-	draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)
+	return CompositeColor(img, color.White)
+}
 
+// CompositeColor returns an image with the alpha channel removed by drawing over a background of the given color.
+func CompositeColor(img image.Image, c color.Color) image.Image {
+	dst := image.NewRGBA(img.Bounds())
+	draw.Draw(dst, dst.Bounds(), &image.Uniform{c}, image.Point{}, draw.Src)
+	draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)
 	return dst
 }
 
@@ -55,6 +58,31 @@ func Resize(img image.Image, newSize image.Point, method int) image.Image {
 	return dst
 }
 
+// Pad returns an image which has been resized to fit within a new size, preserving aspect ratio, and padded with a color.
+func Pad(img image.Image, newSize image.Point, c color.Color, kernel draw.Interpolator) image.Image {
+	dst := image.NewRGBA(image.Rect(0, 0, newSize.X, newSize.Y))
+	draw.Draw(dst, dst.Bounds(), &image.Uniform{c}, image.Point{}, draw.Src)
+
+	var minPoint, maxPoint image.Point
+	if img.Bounds().Dx() > img.Bounds().Dy() {
+		// landscape
+		height := newSize.X * img.Bounds().Dy() / img.Bounds().Dx()
+		minPoint = image.Point{0, (newSize.Y - height) / 2}
+		maxPoint = image.Point{newSize.X, height + minPoint.Y}
+	} else {
+		// portrait
+		width := newSize.Y * img.Bounds().Dx() / img.Bounds().Dy()
+		minPoint = image.Point{(newSize.X - width) / 2, 0}
+		maxPoint = image.Point{minPoint.X + width, newSize.Y}
+	}
+
+	kernel.Scale(dst, image.Rectangle{
+		Min: minPoint,
+		Max: maxPoint,
+	}, img, img.Bounds(), draw.Over, nil)
+	return dst
+}
+
 // Normalize returns a slice of float32 containing each of the r, g, b values for an image normalized around a value.
 func Normalize(img image.Image, mean, std [3]float32, rescale bool, channelFirst bool) []float32 {
 	var pixelVals []float32
diff --git a/model/models/deepseekocr/imageprocessor.go b/model/models/deepseekocr/imageprocessor.go
new file mode 100644
index 000000000..76bcb279c
--- /dev/null
+++ b/model/models/deepseekocr/imageprocessor.go
@@ -0,0 +1,83 @@
+package deepseekocr
+
+import (
+	"bytes"
+	"image"
+	"image/color"
+	"math"
+	"slices"
+
+	"golang.org/x/image/draw"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/imageproc"
+)
+
+type ratio struct {
+	x, y int
+}
+
+func ProcessImage(ctx ml.Context, bts []byte) (ml.Tensor, ml.Tensor, []int, error) {
+	img, _, err := image.Decode(bytes.NewReader(bts))
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	minNum, maxNum, imageSize, baseSize := 2, 9, 640, 1024
+	var targetRatios []ratio
+	for n := minNum; n <= maxNum; n++ {
+		for i := 1; i <= n; i++ {
+			for j := 1; j <= n; j++ {
+				if i*j <= maxNum && i*j >= minNum && !slices.Contains(targetRatios, ratio{i, j}) {
+					targetRatios = append(targetRatios, ratio{i, j})
+				}
+			}
+		}
+	}
+
+	targetRatio := findBestAspectRatio(targetRatios, img.Bounds().Dx(), img.Bounds().Dy(), imageSize)
+	targetWidth, targetHeight := imageSize*targetRatio.x, imageSize*targetRatio.y
+	blocks := targetRatio.x * targetRatio.y
+
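+	// The candidate grids are every x×y layout with 2 to 9 tiles. For example,
+	// a 1600×1200 input (ratio 1.33) is closest to the 3×2 grid, so it is
+	// resized to 1920×1280 and cut into six 640×640 tiles below; the padded
+	// 1024×1024 global view is produced separately at the end.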
+	mean := imageproc.ImageNetStandardMean
+	std := imageproc.ImageNetStandardSTD
+
+	var patches []float32
+	resized := imageproc.Resize(img, image.Point{X: targetWidth, Y: targetHeight}, imageproc.ResizeBilinear)
+	for i := range blocks {
+		patch := image.NewRGBA(image.Rect(0, 0, imageSize, imageSize))
+		draw.Draw(patch, patch.Bounds(), resized, image.Point{
+			X: i % (targetWidth / imageSize) * imageSize,
+			Y: i / (targetWidth / imageSize) * imageSize,
+		}, draw.Over)
+
+		patches = append(patches, imageproc.Normalize(patch, mean, std, true, true)...)
+	}
+
+	img = imageproc.CompositeColor(img, color.Gray{})
+	img = imageproc.Pad(img, image.Point{X: baseSize, Y: baseSize}, color.Gray{127}, draw.BiLinear)
+
+	return ctx.Input().FromFloats(patches, imageSize, imageSize, 3, blocks),
+		ctx.Input().FromFloats(imageproc.Normalize(img, mean, std, true, true), baseSize, baseSize, 3),
+		[]int{targetRatio.x, targetRatio.y},
+		nil
+}
+
+func findBestAspectRatio(targetRatios []ratio, width, height, imageSize int) ratio {
+	bestDiff := math.MaxFloat64
+	best := ratio{1, 1}
+	realRatio := float64(width) / float64(height)
+	for _, target := range targetRatios {
+		targetRatio := float64(target.x) / float64(target.y)
+		diff := math.Abs(realRatio - targetRatio)
+		if diff < bestDiff {
+			bestDiff = diff
+			best = target
+		} else if diff == bestDiff {
+			if float64(width*height) > 0.5*float64(imageSize*imageSize*best.x*best.y) {
+				best = target
+			}
+		}
+	}
+	return best
+}
diff --git a/model/models/deepseekocr/model.go b/model/models/deepseekocr/model.go
new file mode 100644
index 000000000..4fc069b69
--- /dev/null
+++ b/model/models/deepseekocr/model.go
@@ -0,0 +1,192 @@
+package deepseekocr
+
+import (
+	"math"
+	"slices"
+
+	"github.com/ollama/ollama/fs"
+	"github.com/ollama/ollama/kvcache"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
+)
+
+type Model struct {
+	model.Base
+	model.TextProcessor
+
+	Sam    *samModel    `gguf:"s"`
+	Vision *visionModel `gguf:"v"`
+	Text   *textModel
+
+	ImageNewline ml.Tensor `gguf:"mm.image_newline"`
+	//nolint:misspell // this misspelling is upstream. fixing it breaks the model
+	ViewSeperator ml.Tensor `gguf:"mm.view_seperator"`
+
+	Projector *nn.Linear `gguf:"mm.layers"`
+}
+
+func (m *Model) EncodeMultimodal(ctx ml.Context, bts []byte) ([]input.Multimodal, error) {
+	patches, original, crop, err := ProcessImage(ctx, bts)
+	if err != nil {
+		return nil, err
+	}
+
+	var outputs []ml.Tensor
+	if true { // TODO: local features if sum(patches) != 0
+		samOutputs := m.Sam.Forward(ctx, patches)
+		visionOutputs := m.Vision.Forward(ctx, patches, samOutputs)
+
+		samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
+		visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
+		localOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
+		localOutputs = m.Projector.Forward(ctx, localOutputs)
+
+		hw := int(math.Sqrt(float64(localOutputs.Dim(1))))
+		localOutputs = localOutputs.Reshape(ctx, -1, hw, crop[0], crop[1])
+		localOutputs = localOutputs.Permute(ctx, 0, 2, 1, 3)
+		localOutputs = localOutputs.Contiguous(ctx, -1, crop[0]*hw, crop[1]*hw)
+		localOutputs = localOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, localOutputs.Dim(2)), 1)
+		localOutputs = localOutputs.Reshape(ctx, localOutputs.Dim(0), -1)
+
+		outputs = append(outputs, localOutputs)
+	}
+
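+	// The global view takes the same path as the tiles: SAM features are
+	// flattened and concatenated channel-wise with the CLIP features (minus
+	// the class token), then projected into the language embedding space.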
+	samOutputs := m.Sam.Forward(ctx, original)
+	visionOutputs := m.Vision.Forward(ctx, original, samOutputs)
+
+	samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
+	visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
+	globalOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
+	globalOutputs = m.Projector.Forward(ctx, globalOutputs)
+
+	hw := int(math.Sqrt(float64(globalOutputs.Dim(1))))
+	globalOutputs = globalOutputs.Reshape(ctx, -1, hw, hw)
+	globalOutputs = globalOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, globalOutputs.Dim(2)), 1)
+	globalOutputs = globalOutputs.Reshape(ctx, globalOutputs.Dim(0), -1)
+
+	outputs = append(outputs, globalOutputs, m.ViewSeperator)
+	return []input.Multimodal{
+		{Tensor: outputs[0].Stack(ctx, 1, outputs[1:]...)},
+	}, nil
+}
+
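+// PostTokenize expands each image into one placeholder token (128815) per
+// column of the multimodal tensor so the batch reserves exactly the right
+// number of positions; Forward later overwrites those embeddings with the
+// projected image features.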
+func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
+	outputs := make([]*input.Input, 0, len(inputs))
+	for i := range inputs {
+		if inputs[i].Multimodal == nil {
+			outputs = append(outputs, inputs[i])
+			continue
+		}
+
+		t := inputs[i].Multimodal[0].Tensor
+		outputs = append(outputs, &input.Input{
+			Token:          128815,
+			Multimodal:     inputs[i].Multimodal,
+			MultimodalHash: inputs[i].MultimodalHash,
+			SameBatch:      t.Dim(1) - 1,
+		})
+
+		outputs = slices.Grow(outputs, t.Dim(1)-1)
+		outputs = append(outputs, slices.Repeat([]*input.Input{{Token: 128815}}, t.Dim(1)-1)...)
+	}
+	return outputs, nil
+}
+
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	inputsEmbeds := m.Text.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
+	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
+
+	for _, mm := range batch.Multimodal {
+		t := mm.Multimodal[0].Tensor
+		ctx.Forward(t.Copy(ctx, inputsEmbeds.View(ctx, mm.Index*inputsEmbeds.Stride(1), t.Dim(0)*t.Dim(1))))
+	}
+
+	hiddenStates := inputsEmbeds
+	for i, block := range m.Text.Blocks {
+		if m.Cache != nil {
+			m.Cache.SetLayer(i)
+		}
+
+		var outputs ml.Tensor
+		if i == len(m.Text.Blocks)-1 {
+			outputs = batch.Outputs
+		}
+
+		hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Text.Options)
+	}
+
+	hiddenStates = m.Text.OutputNorm.Forward(ctx, hiddenStates, m.Text.Options.eps)
+	return m.Text.Output.Forward(ctx, hiddenStates), nil
+}
+
+func init() {
+	model.Register("deepseekocr", func(c fs.Config) (model.Model, error) {
+		textBlocks := make([]textBlock, c.Uint("block_count"))
+		leadingDenseBlockCount := int(c.Uint("leading_dense_block_count", 1))
+		for i := range textBlocks {
+			if i >= leadingDenseBlockCount {
+				textBlocks[i].FeedForward = &textMoe{}
+			} else {
+				textBlocks[i].FeedForward = &textMLP{}
+			}
+		}
+
+		m := Model{
+			TextProcessor: model.NewBytePairEncoding(
+				&model.Vocabulary{
+					Values: c.Strings("tokenizer.ggml.tokens"),
+					Types:  c.Ints("tokenizer.ggml.token_type"),
+					Merges: c.Strings("tokenizer.ggml.merges"),
+					AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+					BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+					AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+					EOS: append(
+						[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
+						c.Ints("tokenizer.ggml.eos_token_ids")...,
+					),
+				},
+				// Split regex into multiple parts (according to DeepSeek3's regex)
+				"\\p{N}{1,3}",
+				`[一-龥぀-ゟ゠-ヿ]+`,
+				"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
+			),
+			Text: &textModel{
+				Blocks: textBlocks,
+				Options: textOptions{
+					hiddenSize:     int(c.Uint("embedding_length")),
+					numHeads:       int(c.Uint("attention.head_count")),
+					numKVHeads:     int(c.Uint("attention.head_count_kv")),
+					numExperts:     int(c.Uint("expert_count")),
+					numExpertsUsed: int(c.Uint("expert_used_count")),
+					ropeBase:       c.Float("rope.freq_base", 10_000),
+					ropeScale:      c.Float("rope.scaling.factor", 1.0),
+					eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-6),
+				},
+			},
+			Vision: &visionModel{
+				Blocks: make([]visionBlock, c.Uint("vision.block_count")),
+				Options: visionOptions{
+					hiddenSize: int(c.Uint("vision.embedding_length")),
+					numHeads:   int(c.Uint("vision.head_count")),
+					imageSize:  int(c.Uint("vision.image_size", 224)),
+					patchSize:  int(c.Uint("vision.patch_size", 14)),
+					eps:        c.Float("vision.attention.layer_norm_epsilon", 1e-5),
+				},
+			},
+			Sam: &samModel{
+				Blocks: make([]samBlock, c.Uint("sam.block_count")),
+				Options: samOptions{
+					hiddenSize:            int(c.Uint("sam.embedding_length")),
+					numHeads:              int(c.Uint("sam.head_count")),
+					eps:                   c.Float("sam.attention.layer_norm_epsilon", 1e-6),
+					globalAttentionLayers: c.Ints("sam.global_attention_indexes"),
+				},
+			},
+		}
+
+		m.Cache = kvcache.NewCausalCache(m.Text.Shift)
+		return &m, nil
+	})
+}
diff --git a/model/models/deepseekocr/model_sam.go b/model/models/deepseekocr/model_sam.go
new file mode 100644
index 000000000..8bf30f96c
--- /dev/null
+++ b/model/models/deepseekocr/model_sam.go
@@ -0,0 +1,225 @@
+package deepseekocr
+
+import (
+	"math"
+	"slices"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+)
+
+type samModel struct {
+	PatchEmbedding    *nn.Conv2D `gguf:"patch_embd"`
+	PositionEmbedding ml.Tensor  `gguf:"position_embd"`
+
+	Blocks []samBlock `gguf:"blk"`
+
+	Neck *samNeck   `gguf:"neck"`
+	Net2 *nn.Conv2D `gguf:"net_2"`
+	Net3 *nn.Conv2D `gguf:"net_3"`
+
+	Options samOptions
+}
+
+func (m *samModel) absolutePositionEmbedding(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
+	source := m.PositionEmbedding.Dim(1)
+	target := hiddenStates.Dim(2)
+	if source != target {
+		positionEmbed := m.PositionEmbedding.Permute(ctx, 2, 0, 1, 3)
+		positionEmbed = positionEmbed.Interpolate(ctx, [4]int{target, target, hiddenStates.Dim(0), 1}, ml.SamplingModeBilinear)
+		return positionEmbed.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+	}
+
+	return m.PositionEmbedding
+}
+
+func (m *samModel) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
+	hiddenStates := m.PatchEmbedding.Forward(ctx, t, 16, 16, 0, 0, 1, 1)
+	hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
+	if m.PositionEmbedding != nil {
+		hiddenStates = hiddenStates.Add(ctx, m.absolutePositionEmbedding(ctx, hiddenStates))
+	}
+
+	for i, block := range m.Blocks {
+		var windowSize int
+		if !slices.Contains(m.Options.globalAttentionLayers, int32(i)) {
+			windowSize = 14
+		}
+
+		hiddenStates = block.Forward(ctx, hiddenStates, windowSize, m.Options)
+	}
+
+	hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
+	hiddenStates = m.Neck.Forward(ctx, hiddenStates, m.Options)
+	hiddenStates = m.Net2.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
+	hiddenStates = m.Net3.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
+	return hiddenStates
+}
+
+type samOptions struct {
+	hiddenSize,
+	numHeads int
+	eps                   float32
+	globalAttentionLayers []int32
+}
+
+func (o samOptions) headDim() int {
+	return o.hiddenSize / o.numHeads
+}
+
+type samBlock struct {
+	Norm1       *nn.LayerNorm `gguf:"norm1"`
+	Attention   *samAttention `gguf:"attn"`
+	Norm2       *nn.LayerNorm `gguf:"norm2"`
+	FeedForward *samMLP       `gguf:"mlp"`
+}
+
+func (m *samBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, windowSize int, opts samOptions) ml.Tensor {
+	c, w, h := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
+
+	residual := hiddenStates
+	hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)
+
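+	// Window partition: pad the token grid to a multiple of windowSize, then
+	// regroup it into windowSize×windowSize tiles so attention stays local.
+	// For the 1024px input the grid is 64×64, so with 14×14 windows
+	// pw = ph = 6 and the padded 70×70 grid becomes 25 windows.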
+	var pw, ph int
+	if windowSize > 0 {
+		pw = (windowSize - hiddenStates.Dim(1)%windowSize) % windowSize
+		ph = (windowSize - hiddenStates.Dim(2)%windowSize) % windowSize
+		if pw > 0 || ph > 0 {
+			hiddenStates = hiddenStates.Pad(ctx, 0, pw, ph, 0)
+		}
+
+		hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, (w+pw)/windowSize, windowSize, -1)
+		hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, c, windowSize, windowSize, -1)
+	}
+
+	hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)
+
+	if windowSize > 0 {
+		hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, windowSize, (w+pw)/windowSize, -1)
+		hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3)
+		hiddenStates = hiddenStates.Contiguous(ctx, c, w+pw, h+ph, -1)
+		hiddenStates = hiddenStates.Pad(ctx, 0, -pw, -ph, 0)
+	}
+
+	hiddenStates = hiddenStates.Add(ctx, residual)
+
+	residual = hiddenStates
+	hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
+	hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
+	return hiddenStates.Add(ctx, residual)
+}
+
+type samAttention struct {
+	QKV    *nn.Linear `gguf:"qkv"`
+	Output *nn.Linear `gguf:"proj"`
+
+	RelativePosition *struct {
+		Height ml.Tensor `gguf:"h"`
+		Width  ml.Tensor `gguf:"w"`
+	} `gguf:",pre:rel_pos_"`
+}
+
+func relativeCoordinates(ctx ml.Context, qn, kn int) ml.Tensor {
+	s := make([]int32, qn*kn)
+	for i := range qn {
+		for j := range kn {
+			q := i * max(kn/qn, 1)
+			k := j * max(qn/kn, 1)
+			s[i*kn+j] = int32(q - k + (kn-1)*max(qn/kn, 1))
+		}
+	}
+	return ctx.Input().FromInts(s, qn*kn)
+}
+
+func relativePositions(ctx ml.Context, positions ml.Tensor, qn, kn int) ml.Tensor {
+	maxRelativeDistance := 2*max(qn, kn) - 1
+	if positions.Dim(1) != maxRelativeDistance {
+		// linear interpolation kernel not available so approx. with bilinear interpolation
+		positions = positions.Interpolate(ctx, [4]int{positions.Dim(0), maxRelativeDistance, 1, 1}, ml.SamplingModeBilinear)
+	}
+
+	rc := relativeCoordinates(ctx, qn, kn)
+	return positions.Rows(ctx, rc).Reshape(ctx, positions.Dim(0), kn, qn)
+}
+
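+// decomposedRelativePositions computes the additive attention bias from
+// decomposed relative position tables (as in SAM): separate height and width
+// embeddings are gathered per query/key offset and broadcast-added to the
+// scores instead of materializing a full 2D table.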
+func (m *samAttention) decomposedRelativePositions(ctx ml.Context, query ml.Tensor, qn, kn []int) (ml.Tensor, ml.Tensor) {
+	qh, qw := qn[0], qn[1]
+	kh, kw := kn[0], kn[1]
+
+	rh := relativePositions(ctx, m.RelativePosition.Height, qh, kh)
+	rw := relativePositions(ctx, m.RelativePosition.Width, qw, kw)
+
+	query = query.Contiguous(ctx, query.Dim(0), qw, qh, -1)
+	rh = rh.Mulmat(ctx, query).Reshape(ctx, 1, kh, qh*qw, -1)
+	rw = rw.Mulmat(ctx, query.Permute(ctx, 0, 2, 1, 3)).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, kw, 1, qh*qw, -1)
+	return rh, rw
+}
+
+func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
+	w, h, b := hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)
+
+	qkv := m.QKV.Forward(ctx, hiddenStates)
+	qkv = qkv.Reshape(ctx, opts.headDim(), -1, w*h, b)
+	chunks := qkv.Chunk(ctx, 1, opts.numHeads)
+	query, key, value := chunks[0], chunks[1], chunks[2]
+
+	ctx.Forward(query, key, value)
+
+	query = query.Permute(ctx, 0, 2, 1, 3)
+	rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w})
+	mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw)
+	mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b)
+
+	key = key.Permute(ctx, 0, 2, 1, 3)
+	scores := key.MulmatFullPrec(ctx, query)
+	scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim())))
+
+	scores = scores.Add(ctx, mask)
+	scores = scores.Softmax(ctx)
+
+	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+	attention := value.Mulmat(ctx, scores)
+	attention = attention.Permute(ctx, 0, 2, 1, 3)
+	attention = attention.Contiguous(ctx, -1, w, h, b)
+	return m.Output.Forward(ctx, attention)
+}
+
+type samMLP struct {
+	Lin1 *nn.Linear `gguf:"lin1"`
+	Lin2 *nn.Linear `gguf:"lin2"`
+}
+
+func (m *samMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
+	return m.Lin2.Forward(ctx, m.Lin1.Forward(ctx, hiddenStates).GELU(ctx))
+}
+
+type LayerNorm2D struct {
+	Weight ml.Tensor `gguf:"weight"`
+	Bias   ml.Tensor `gguf:"bias"`
+}
+
+func (ln *LayerNorm2D) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
+	x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+	u := x.Mean(ctx)
+	d := x.Sub(ctx, u)
+	s := d.Sqr(ctx).Mean(ctx)
+	x = d.Div(ctx, s.Add(ctx, ctx.Input().FromFloats([]float32{eps}, 1)).Sqrt(ctx))
+	x = x.Mul(ctx, ln.Weight).Add(ctx, ln.Bias)
+	return x.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
+}
+
+type samNeck struct {
+	C1  *nn.Conv2D   `gguf:"0"`
+	LN1 *LayerNorm2D `gguf:"1"`
+	C2  *nn.Conv2D   `gguf:"2"`
+	LN2 *LayerNorm2D `gguf:"3"`
+}
+
+func (m *samNeck) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
+	hiddenStates = m.C1.Forward(ctx, hiddenStates, 1, 1, 0, 0, 1, 1)
+	hiddenStates = m.LN1.Forward(ctx, hiddenStates, opts.eps)
+	hiddenStates = m.C2.Forward(ctx, hiddenStates, 1, 1, 1, 1, 1, 1)
+	hiddenStates = m.LN2.Forward(ctx, hiddenStates, opts.eps)
+	return hiddenStates
+}
diff --git a/model/models/deepseekocr/model_text.go b/model/models/deepseekocr/model_text.go
new file mode 100644
index 000000000..1513b1388
--- /dev/null
+++ b/model/models/deepseekocr/model_text.go
@@ -0,0 +1,140 @@
+package deepseekocr
+
+import (
+	"math"
+
+	"github.com/ollama/ollama/kvcache"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
+	"github.com/ollama/ollama/ml/nn/rope"
+)
+
+type textModel struct {
+	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
+	Blocks         []textBlock   `gguf:"blk"`
+	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
+	Output         *nn.Linear    `gguf:"output"`
+
+	Options textOptions
+}
+
+func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
+	return m.Options.applyRotaryPositionalEmbedding(ctx, key, shift), nil
+}
+
+type textOptions struct {
+	hiddenSize,
+	numHeads,
+	numKVHeads,
+	numExperts,
+	numExpertsUsed int
+	ropeBase,
+	ropeScale,
+	eps float32
+}
+
+func (o textOptions) headDim() int {
+	return o.hiddenSize / o.numHeads
+}
+
+func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
+	return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
+}
+
+type textBlock struct {
+	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
+	Attention     *textAttention
+	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
+	FeedForward   textFeedForward
+}
+
+func (m *textBlock) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
+	residual := hiddenStates
+	hiddenStates = m.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
+	hiddenStates = m.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
+	if outputs != nil {
+		hiddenStates = hiddenStates.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
+	hiddenStates = hiddenStates.Add(ctx, residual)
+
+	residual = hiddenStates
+	hiddenStates = m.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
+	hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
+	return hiddenStates.Add(ctx, residual)
+}
+
+type textAttention struct {
+	Query  *nn.Linear `gguf:"attn_q"`
+	Key    *nn.Linear `gguf:"attn_k"`
+	Value  *nn.Linear `gguf:"attn_v"`
+	Output *nn.Linear `gguf:"attn_output"`
+}
+
+func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
+	query := m.Query.Forward(ctx, hiddenStates)
+	query = query.Reshape(ctx, opts.headDim(), opts.numHeads, -1)
+
+	key := m.Key.Forward(ctx, hiddenStates)
+	key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)
+
+	value := m.Value.Forward(ctx, hiddenStates)
+	value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)
+
+	query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
+	key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)
+
+	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
+	attention = attention.Reshape(ctx, -1, attention.Dim(2))
+	return m.Output.Forward(ctx, attention)
+}
+
+type textFeedForward interface {
+	Forward(ml.Context, ml.Tensor, textOptions) ml.Tensor
+}
+
+type textMoe struct {
+	Router        *nn.Linear      `gguf:"ffn_gate_inp"`
+	Gate          *nn.LinearBatch `gguf:"ffn_gate_exps"`
+	Up            *nn.LinearBatch `gguf:"ffn_up_exps"`
+	Down          *nn.LinearBatch `gguf:"ffn_down_exps"`
+	SharedExperts *textMLP        `gguf:",suf:_shexp"`
+}
+
+func (m *textMoe) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts textOptions) ml.Tensor {
+	scores := m.Router.Forward(ctx, hiddenStates).Softmax(ctx)
+	indices := scores.TopK(ctx, opts.numExpertsUsed)
+	weights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, indices)
+
+	experts := hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
+	experts = m.Gate.Forward(ctx, experts, indices).SILU(ctx, m.Up.Forward(ctx, experts, indices))
+	experts = m.Down.Forward(ctx, experts, indices)
+	experts = experts.Mul(ctx, weights)
+
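+	// Each routed expert's output occupies a strided view of the batched
+	// result; summing the numExpertsUsed views below and adding the shared
+	// expert output yields the weighted mixture without an explicit gather.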
+	expert := func(i int) ml.Tensor {
+		return experts.View(
+			ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2),
+		)
+	}
+
+	routedStates := expert(0)
+	for i := 1; i < opts.numExpertsUsed; i++ {
+		routedStates = routedStates.Add(ctx, expert(i))
+	}
+
+	sharedStates := m.SharedExperts.Forward(ctx, hiddenStates, opts)
+	return routedStates.Add(ctx, sharedStates)
+}
+
+type textMLP struct {
+	Gate *nn.Linear `gguf:"ffn_gate"`
+	Up   *nn.Linear `gguf:"ffn_up"`
+	Down *nn.Linear `gguf:"ffn_down"`
+}
+
+func (m *textMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ textOptions) ml.Tensor {
+	hiddenStates = m.Gate.Forward(ctx, hiddenStates).SILU(ctx, m.Up.Forward(ctx, hiddenStates))
+	return m.Down.Forward(ctx, hiddenStates)
+}
diff --git a/model/models/deepseekocr/model_vision.go b/model/models/deepseekocr/model_vision.go
new file mode 100644
index 000000000..61121ebfd
--- /dev/null
+++ b/model/models/deepseekocr/model_vision.go
@@ -0,0 +1,117 @@
+package deepseekocr
+
+import (
+	"math"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+)
+
+type visionModel struct {
+	PatchEmbedding    *nn.Conv2D    `gguf:"patch_embd"`
+	ClassEmbedding    ml.Tensor     `gguf:"class_embd"`
+	PositionEmbedding *nn.Embedding `gguf:"position_embd"`
+
+	PreLayerNorm *nn.LayerNorm `gguf:"pre_layrnorm"`
+	Blocks       []visionBlock `gguf:"blk"`
+
+	Options visionOptions
+}
+
+func (m *visionModel) absolutePositionEmbedding(ctx ml.Context, embeds ml.Tensor) ml.Tensor {
+	numPatches := m.Options.imageSize / m.Options.patchSize * m.Options.imageSize / m.Options.patchSize
+	positions := ctx.Arange(0, float32(numPatches+1), 1, ml.DTypeI32)
+	positionEmbeds := m.PositionEmbedding.Forward(ctx, positions)
+
+	source := int(math.Sqrt(float64(positionEmbeds.Dim(1) - 1)))
+	target := int(math.Sqrt(float64(embeds.Dim(1) - 1)))
+	if source != target {
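+		// The learned table holds one class position plus a source×source grid
+		// (16×16 for the default 224px/14px configuration); the grid part is
+		// reshaped to 2D, bilinearly resized to target×target, re-flattened,
+		// and the class position is re-attached unchanged.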
+		newPositionEmbeds := positionEmbeds.Slice(ctx, 1, 1, positionEmbeds.Dim(1), 1)
+		newPositionEmbeds = newPositionEmbeds.Reshape(ctx, -1, source, source)
+		newPositionEmbeds = newPositionEmbeds.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
+		newPositionEmbeds = newPositionEmbeds.Interpolate(ctx, [4]int{target, target, embeds.Dim(0), 1}, ml.SamplingModeBilinear)
+		newPositionEmbeds = newPositionEmbeds.Permute(ctx, 1, 2, 0, 3)
+		newPositionEmbeds = newPositionEmbeds.Contiguous(ctx, -1, target*target)
+
+		positionEmbeds = positionEmbeds.Slice(ctx, 1, 0, 1, 1).Concat(ctx, newPositionEmbeds, 1)
+	}
+
+	return positionEmbeds
+}
+
+func (m *visionModel) Forward(ctx ml.Context, pixelValues, patchEmbeds ml.Tensor) ml.Tensor {
+	if patchEmbeds == nil {
+		patchEmbeds = m.PatchEmbedding.Forward(ctx, pixelValues, m.Options.patchSize, m.Options.patchSize, 0, 0, 1, 1)
+	}
+
+	patchEmbeds = patchEmbeds.Reshape(ctx, -1, patchEmbeds.Dim(2), patchEmbeds.Dim(3))
+	patchEmbeds = patchEmbeds.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+
+	classEmbeds := m.ClassEmbedding.Repeat(ctx, 2, patchEmbeds.Dim(2))
+	embeds := classEmbeds.Concat(ctx, patchEmbeds, 1)
+	embeds = embeds.Add(ctx, m.absolutePositionEmbedding(ctx, embeds))
+
+	hiddenStates := m.PreLayerNorm.Forward(ctx, embeds, m.Options.eps)
+	for _, block := range m.Blocks {
+		hiddenStates = block.Forward(ctx, hiddenStates, m.Options)
+	}
+
+	return hiddenStates
+}
+
+type visionOptions struct {
+	hiddenSize,
+	numHeads int
+	eps float32
+
+	imageSize, patchSize int
+}
+
+func (o visionOptions) headDim() int {
+	return o.hiddenSize / o.numHeads
+}
+
+type visionBlock struct {
+	Norm1       *nn.LayerNorm    `gguf:"layer_norm1"`
+	Attention   *visionAttention `gguf:"self_attn"`
+	Norm2       *nn.LayerNorm    `gguf:"layer_norm2"`
+	FeedForward *visionMLP       `gguf:"mlp"`
+}
+
+func (m *visionBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts visionOptions) ml.Tensor {
+	residual := hiddenStates
+	hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)
+	hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)
+	hiddenStates = hiddenStates.Add(ctx, residual)
+
+	residual = hiddenStates
+	hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
+	hiddenStates = m.FeedForward.Forward(ctx, hiddenStates)
+	hiddenStates = hiddenStates.Add(ctx, residual)
+	return hiddenStates
+}
+
+type visionAttention struct {
+	QKV    *nn.Linear `gguf:"qkv_proj"`
+	Output *nn.Linear `gguf:"out_proj"`
+}
+
+func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOptions) ml.Tensor {
+	qkv := m.QKV.Forward(ctx, t)
+	qkv = qkv.Reshape(ctx, opts.headDim(), -1, qkv.Dim(1), qkv.Dim(2))
+	chunks := qkv.Chunk(ctx, 1, opts.numHeads)
+	query, key, value := chunks[0], chunks[1], chunks[2]
+
+	attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
+	attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3))
+	return m.Output.Forward(ctx, attention)
+}
+
+type visionMLP struct {
+	FC1 *nn.Linear `gguf:"fc1"`
+	FC2 *nn.Linear `gguf:"fc2"`
+}
+
+// QuickGELU (x·sigmoid(1.702x)) matches the activation used by the original
+// CLIP checkpoints, which is why the vision MLP uses it rather than exact GELU.
+func (m *visionMLP) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
+	return m.FC2.Forward(ctx, m.FC1.Forward(ctx, t).QuickGELU(ctx))
+}
diff --git a/model/models/models.go b/model/models/models.go
index deefeb58f..cb09633ef 100644
--- a/model/models/models.go
+++ b/model/models/models.go
@@ -3,6 +3,7 @@ package models
 import (
 	_ "github.com/ollama/ollama/model/models/bert"
 	_ "github.com/ollama/ollama/model/models/deepseek2"
+	_ "github.com/ollama/ollama/model/models/deepseekocr"
 	_ "github.com/ollama/ollama/model/models/gemma2"
 	_ "github.com/ollama/ollama/model/models/gemma3"
 	_ "github.com/ollama/ollama/model/models/gemma3n"