diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index a50d8ec9c..6a044260a 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1534,7 +1534,8 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
 			unsafe.SliceData(mropeSections),
 			C.int(opts.Type),
 			cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
-			C.float(ropeBase), C.float(ropeScale),
+			C.float(ropeBase),
+			C.float(ropeScale),
 			C.float(opts.YaRN.ExtrapolationFactor),
 			cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
 			cmp.Or(C.float(opts.YaRN.BetaFast), 32),
@@ -1546,9 +1547,11 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
 			dequant,
 			positions.(*Tensor).t,
 			opts.Factors.(*Tensor).t,
-			C.int(ropeDim), C.int(opts.Type),
+			C.int(ropeDim),
+			C.int(opts.Type),
 			cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
-			C.float(ropeBase), C.float(ropeScale),
+			C.float(ropeBase),
+			C.float(ropeScale),
 			C.float(opts.YaRN.ExtrapolationFactor),
 			cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
 			cmp.Or(C.float(opts.YaRN.BetaFast), 32),
diff --git a/ml/nn/rope/options.go b/ml/nn/rope/options.go
index 84b926773..1724128a4 100644
--- a/ml/nn/rope/options.go
+++ b/ml/nn/rope/options.go
@@ -77,6 +77,13 @@ func WithMRoPE(sections []int) func(*Options) {
 	}
 }
 
+func WithVision(sections []int) func(*Options) {
+	return func(opts *Options) {
+		opts.Type |= 1<<3 | 1<<4
+		opts.MRoPE.Sections = sections
+	}
+}
+
 func WithInterleaveMRoPE(sections []int) func(*Options) {
 	return func(opts *Options) {
 		opts.Type |= 1<<3 | 1<<5
diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go
index 13fa3fee1..81296a81b 100644
--- a/model/models/qwen25vl/model.go
+++ b/model/models/qwen25vl/model.go
@@ -2,7 +2,6 @@ package qwen25vl
 
 import (
 	"bytes"
-	"fmt"
 	"image"
 	"slices"
 
@@ -33,7 +32,7 @@ func New(c fs.Config) (model.Model, error) {
 			Values: c.Strings("tokenizer.ggml.tokens"),
 			Types:  c.Ints("tokenizer.ggml.token_type"),
 			Merges: c.Strings("tokenizer.ggml.merges"),
-			AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+			AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
 			BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
 			AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
 			EOS: append(
@@ -54,19 +53,18 @@ func New(c fs.Config) (model.Model, error) {
 }
 
 func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) {
-	image, _, err := image.Decode(bytes.NewReader(multimodalData))
+	img, _, err := image.Decode(bytes.NewReader(multimodalData))
 	if err != nil {
 		return nil, nil, err
 	}
 
-	f32s, grid, err := m.ImageProcessor.ProcessImage(image)
+	f32s, grid, err := m.ImageProcessor.ProcessImage(img)
 	if err != nil {
 		return nil, nil, err
 	}
 
 	// Calculate tensor dimensions
-	patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
-		m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
+	patchDim := m.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
 	numPatches := grid.Temporal * grid.Height * grid.Width
 
 	pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
@@ -85,11 +83,13 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 	}
 
 	visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
-	return []input.Multimodal{{Tensor: visionOutputs}}, nil
+	return []input.Multimodal{{Tensor: visionOutputs, Data: grid}}, nil
 }
 
 // PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
 func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
+	// Reset position cache
+	m.positionCache = m.positionCache[:0]
 	var result []*input.Input
 
 	var (
@@ -98,40 +98,37 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 		visionEndToken   int32 = 151653
 	)
 
-	nImg := 0
+	appendInput := func(i *input.Input, p int) int {
+		result = append(result, i)
+		m.positionCache = append(m.positionCache, int32(p))
+		return p + 1
+	}
+
+	var p int
 	for _, inp := range inputs {
 		if inp.Multimodal == nil {
 			// If not a multimodal input, add it to the result unchanged
-			result = append(result, inp)
+			p = appendInput(inp, p)
 		} else {
-			// Adding the 'Picture' prefix is a hack, at the time of writing there is no way to prefix
-			// the image tokens with a prompt, so we add a prefix here
-			nImg++
-			pre, err := m.Encode(fmt.Sprintf(" Picture %d: ", nImg), true)
-			if err != nil {
-				return nil, fmt.Errorf("failed to encode image prompt: %w", err)
-			}
-			for i := range pre {
-				result = append(result, &input.Input{Token: pre[i]})
-			}
-
-			patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
-
 			// First add the vision start token
-			result = append(result, &input.Input{Token: visionStartToken})
+			p = appendInput(&input.Input{Token: visionStartToken}, p)
 
 			// Add the image token with the multimodal tensor data at the first position
-			result = append(result, &input.Input{
+			tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
+			appendInput(&input.Input{
 				Token:          imageToken,
 				Multimodal:     inp.Multimodal,
 				MultimodalHash: inp.MultimodalHash,
-				SameBatch:      patchesPerChunk,
-			})
+				SameBatch:      tokensPerGrid,
+			}, p)
 
 			// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
-			result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
+			for range tokensPerGrid - 1 {
+				appendInput(&input.Input{Token: imageToken}, p)
+			}
 
-			result = append(result, &input.Input{Token: visionEndToken})
+			grid := inp.Multimodal[0].Data.(*Grid)
+			p = appendInput(&input.Input{Token: visionEndToken}, p+max(grid.Width/m.spatialMergeSize, grid.Height/m.spatialMergeSize))
 		}
 	}
 
@@ -139,9 +136,58 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
+	// Initial token embedding
+	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
 
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
+	positionSlice := func() [][]int32 {
+		s := [][]int32{
+			make([]int32, len(batch.Positions)),
+			make([]int32, len(batch.Positions)),
+			make([]int32, len(batch.Positions)),
+			make([]int32, len(batch.Positions)),
+		}
+		for i, position := range batch.Positions {
+			if position < int32(len(m.positionCache)) {
+				position = m.positionCache[position]
+			} else if len(m.positionCache) > 0 {
+				position = position - int32(len(m.positionCache)) + m.positionCache[len(m.positionCache)-1] + 1
+			}
+
+			s[0][i] = position
+			s[1][i] = position
+			s[2][i] = position
+		}
+		return s
+	}()
+
+	for _, mi := range batch.Multimodal {
+		img := mi.Multimodal[0].Tensor
+		ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
+		if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
+			for i := range img.Dim(1) {
+				w := grid.Width / m.spatialMergeSize
+				positionSlice[1][mi.Index+i] += int32(i / w)
+				positionSlice[2][mi.Index+i] += int32(i % w)
+			}
+		}
+	}
+
+	positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
+
+	// Process through transformer layers
+	for i, layer := range m.TextModel.Layers {
+		m.Cache.SetLayer(i)
+
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.TextModel.Layers)-1 {
+			lastLayerOutputs = batch.Outputs
+		}
+
+		hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, m.Cache, m.TextOptions)
+	}
+
+	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.TextModel.eps)
+	return m.Output.Forward(ctx, hiddenStates), nil
 }
 
 func init() {
diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go
index b4db6043e..61b072d67 100644
--- a/model/models/qwen25vl/model_text.go
+++ b/model/models/qwen25vl/model_text.go
@@ -8,20 +8,17 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/ml/nn/rope"
-	"github.com/ollama/ollama/model/input"
 )
 
 type TextOptions struct {
 	hiddenSize, numHeads, numKVHeads int
 	ropeDim, originalContextLength   int
 	eps, ropeBase, ropeScale         float32
+	mropeSections                    []int
 }
 
 func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
-	return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale,
-		rope.WithOriginalContextLength(o.originalContextLength),
-		rope.WithTypeNeoX(),
-	)
+	return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithMRoPE(o.mropeSections))
 }
 
 type TextModel struct {
@@ -31,6 +28,7 @@ type TextModel struct {
 	Output *nn.Linear `gguf:"output,alt:token_embd"`
 
 	*TextOptions
+	positionCache []int32
 }
 
 func NewTextModel(c fs.Config) *TextModel {
@@ -45,6 +43,14 @@ func NewTextModel(c fs.Config) *TextModel {
 			eps:       c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:  c.Float("rope.freq_base"),
 			ropeScale: c.Float("rope.scaling.factor", 1),
+			mropeSections: func() []int {
+				sections := c.Ints("rope.mrope_section")
+				s := make([]int, len(sections))
+				for i, section := range sections {
+					s[i] = int(section)
+				}
+				return s
+			}(),
 		},
 	}
 
@@ -84,6 +90,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 // Shift applies rotary position embeddings to the key tensor for causal attention caching
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
+	m.positionCache = nil
 	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
 }
 
@@ -130,28 +137,3 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
 	return hiddenState.Add(ctx, residual)
 }
-
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
-	// Initial token embedding
-	hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
-
-	for _, mi := range batch.Multimodal {
-		img := mi.Multimodal[0].Tensor
-		ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
-	}
-
-	// Process through transformer layers
-	for i, layer := range m.Layers {
-		cache.SetLayer(i)
-
-		var lastLayerOutputs ml.Tensor
-		if i == len(m.Layers)-1 {
-			lastLayerOutputs = outputs
-		}
-
-		hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, m.TextOptions)
-	}
-
-	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
-	return m.Output.Forward(ctx, hiddenStates), nil
-}
diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go
index bfdafabe4..f1275437f 100644
--- a/model/models/qwen25vl/model_vision.go
+++ b/model/models/qwen25vl/model_vision.go
@@ -7,48 +7,28 @@ import (
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/rope"
 )
 
-// We only support batch size of 1
-var batchSize int = 1
-
-func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-	x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
-	x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
-	return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
-}
-
-func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
-	return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
-}
-
-func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
-	// Create a flat slice for the mask (all -inf initially to block all attention)
-	flat := make([]float32, seqLength*seqLength)
-	for i := range flat {
-		flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention
+func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int) ml.Tensor {
+	// Initialize a 2D mask with -Inf
+	s := make([][]float32, seqLength)
+	for i := range s {
+		s[i] = slices.Repeat([]float32{float32(math.Inf(-1))}, seqLength)
 	}
 
 	// Fill in the mask with zeros for tokens that CAN attend to each other
 	for i := 1; i < len(bounds); i++ {
-		start := bounds[i-1]
-		end := bounds[i]
-
-		// Enable attention within this sequence block by setting values to 0
+		start, end := bounds[i-1], bounds[i]
+		// Enable attention within this sequence block
 		for row := start; row < end; row++ {
 			for col := start; col < end; col++ {
-				idx := row*seqLength + col
-				flat[idx] = 0.0 // 0 allows attention, -inf blocks it
+				s[row][col] = 0.0
 			}
 		}
 	}
 
-	mask := ctx.Input().FromFloats(flat, seqLength, seqLength)
-
-	// Reshape to match [seqLength, seqLength, 1] for broadcasting
-	mask = mask.Reshape(ctx, seqLength, seqLength, 1)
-
-	return mask
+	return ctx.Input().FromFloats(slices.Concat(s...), seqLength, seqLength)
 }
 
 type VisionSelfAttention struct {
@@ -58,17 +38,17 @@ type VisionSelfAttention struct {
 	Output *nn.Linear `gguf:"attn_out"`
 }
 
-func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	query := sa.Query.Forward(ctx, hiddenStates)
 	key := sa.Key.Forward(ctx, hiddenStates)
 	value := sa.Value.Forward(ctx, hiddenStates)
 
-	query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
-	key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
-	value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
+	query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1))
+	key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1))
+	value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1))
 
-	query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
-	key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
+	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
+	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
 
 	// Scale factor for scaled dot-product attention
 	scale := 1.0 / math.Sqrt(float64(opts.headDim))
@@ -77,6 +57,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m
 	query = query.Permute(ctx, 0, 2, 1, 3)
 	key = key.Permute(ctx, 0, 2, 1, 3)
 	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
 	kq := key.MulmatFullPrec(ctx, query)
 	kq = kq.Scale(ctx, scale)
 	if mask != nil {
@@ -85,7 +66,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m
 	kq = kq.Softmax(ctx)
 	kqv := value.Mulmat(ctx, kq)
 	attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
+	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
 
 	return sa.Output.Forward(ctx, attention)
 }
@@ -98,10 +79,7 @@ type VisionMLP struct {
 }
 
 func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	// Using activation as specified in config (likely GELU or SiLU/Swish)
-	gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
-	hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
-
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
 
@@ -112,10 +90,10 @@ type VisionEncoderLayer struct {
 	MLP  *VisionMLP
 }
 
-func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, positions, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	residual := hiddenStates
 	hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
-	hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, mask, opts)
+	hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, positions, mask, opts)
 	hiddenStates = hiddenStates.Add(ctx, residual)
 
 	residual = hiddenStates
@@ -139,6 +117,17 @@ type VisionModelOptions struct {
 	temporalPatchSize int
 }
 
+func (o VisionModelOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+	return nn.RoPE(ctx, states, positions, o.headDim/2, o.ropeTheta, 1,
+		rope.WithVision([]int{
+			o.headDim / 4,
+			o.headDim / 4,
+			o.headDim / 4,
+			o.headDim / 4,
+		}),
+	)
+}
+
 type PatchEmbedding struct {
 	PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"`
 	PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
@@ -186,7 +175,7 @@ func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, op
 	hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)
 
 	// Reshape the normalized output to view the hidden size dimension
-	reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize)
+	reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize))
 
 	hidden := pm.MLP0.Forward(ctx, reshaped)
 	activated := hidden.GELU(ctx)
@@ -209,36 +198,53 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 	// Extract patch embeddings
 	hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
 
-	positionEmbedding := m.PositionalEmbedding(ctx, grid)
-
-	windowIndex, bounds := m.WindowIndex(ctx, grid)
-
+	index, bounds := m.windowIndex(grid)
 	spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
+	windowIndex := ctx.Input().FromInts(index, len(index))
 
 	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
-	hiddenStates = hiddenStates.Rows(ctx, windowIndex)
+	hiddenStates = hiddenStates.Rows(ctx, windowIndex.Argsort(ctx))
 	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
 
-	positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit)
-	positionEmbedding = positionEmbedding.Rows(ctx, windowIndex)
-	positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit)
-	positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0)
+	positions := ctx.Input().FromInts(func() []int32 {
+		s := [][]int32{
+			make([]int32, grid.Height*grid.Width),
+			make([]int32, grid.Height*grid.Width),
+			make([]int32, grid.Height*grid.Width),
+			make([]int32, grid.Height*grid.Width),
+		}
 
-	cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
-	cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
-	sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
+		var cur int
+		for y := 0; y < grid.Height; y += m.spatialMergeSize {
+			for x := 0; x < grid.Width; x += m.spatialMergeSize {
+				for dy := range 2 {
+					for dx := range 2 {
+						i := int(index[cur/spatialMergeUnit]) * spatialMergeUnit
+						i += cur % spatialMergeUnit
+						s[0][i] = int32(y + dy)
+						s[1][i] = int32(x + dx)
+						s[2][i] = int32(y + dy)
+						s[3][i] = int32(x + dx)
+						cur++
+					}
+				}
+			}
+		}
+
+		return slices.Concat(s...)
+	}(), grid.Height*grid.Width*4)
+
+	mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds)
 
-	mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
 	// Apply encoder layers
 	for i, layer := range m.Layers {
 		if slices.Contains(m.fullAttnBlocks, int32(i)) {
-			hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions)
+			hiddenStates = layer.Forward(ctx, hiddenStates, positions, nil, m.VisionModelOptions)
 		} else {
 			hiddenStates = layer.Forward(
 				ctx,
 				hiddenStates,
-				cos,
-				sin,
+				positions,
 				mask,
 				m.VisionModelOptions,
 			)
@@ -246,102 +252,43 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 	}
 
 	hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
-	reverseWindowIndex := windowIndex.Argsort(ctx)
-	return hiddenStates.Rows(ctx, reverseWindowIndex)
+	return hiddenStates.Rows(ctx, windowIndex)
 }
 
-// WindowIndex divides the grid into windows and returns:
-// 1. A tensor containing flattened indices of all grid points organized by windows
+// windowIndex divides the grid into windows and returns:
+// 1. A slice of grid point indices organized by windows
 // 2. A slice of boundaries that mark where each window's data begins and ends
 //    in the flattened representation, scaled by spatialMergeSize squared
 //
 // The boundaries slice always starts with 0 and contains cumulative ending
 // positions for each window, allowing downstream processing to identify
 // window boundaries in the tensor data.
-func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
-	vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
+func (m *VisionModel) windowIndex(grid *Grid) (index []int32, bounds []int) {
+	height := grid.Height / m.spatialMergeSize
+	width := grid.Width / m.spatialMergeSize
+	window := m.windowSize / m.patchSize / m.spatialMergeSize
 
-	llmGridH := grid.Height / m.spatialMergeSize
-	llmGridW := grid.Width / m.spatialMergeSize
+	index = make([]int32, height*width)
 
-	// Calculate window parameters
-	numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize)))
-	numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize)))
+	bounds = make([]int, 0, ((height+window-1)/window)*((width+window-1)/window)+1)
+	bounds = append(bounds, 0)
 
-	// Initialize index_new slice
-	var index []int32
-
-	// Initialize bounds with the first element as 0
-	bounds := []int{0}
-	totalSeqLen := 0
-
-	// Process each window without padding
-	for wh := range numWindowsH {
-		for ww := range numWindowsW {
-			// Calculate window boundaries
-			hStart := wh * vitMergerWindowSize
-			wStart := ww * vitMergerWindowSize
-			hEnd := min(hStart+vitMergerWindowSize, llmGridH)
-			wEnd := min(wStart+vitMergerWindowSize, llmGridW)
-
-			// Calculate sequence length for this window
-			seqLen := (hEnd - hStart) * (wEnd - wStart)
-
-			// Collect indices for this window
-			for h := hStart; h < hEnd; h++ {
-				for w := wStart; w < wEnd; w++ {
-					index = append(index, int32(h*llmGridW+w))
+	var cur int32
+	for y := 0; y < height; y += window {
+		for x := 0; x < width; x += window {
+			h1 := min(window, height-y)
+			w1 := min(window, width-x)
+			for dy := range h1 {
+				for dx := range w1 {
+					win := (y+dy)*width + (x + dx)
+					index[win] = cur
+					cur++
 				}
 			}
-
-			totalSeqLen += seqLen
-			bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0])
+			bounds = append(bounds, int(cur)*window)
 		}
 	}
-
-	t := ctx.Input().FromInts(index, len(index))
-
-	return t, bounds
-}
-
-// PositionalEmbedding generates rotary position embeddings for attention mechanisms
-func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
-	dim := m.headDim / 2
-	freq := dim / 2
-	theta := float64(m.ropeTheta)
-	merge := m.spatialMergeSize
-
-	// Create frequency patterns for position encoding
-	maxGridSize := max(grid.Height, grid.Width)
-	freqVals := make([]float32, freq*maxGridSize)
-	for i := range maxGridSize {
-		for j := range freq {
-			freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
-		}
-	}
-	freqs := ctx.Input().FromFloats(freqVals, freq, maxGridSize)
-
-	// Create position coordinates (y,x pairs) for the grid
-	// In PyTorch: Equivalent to generating position ids with torch.arange()
-	coords := make([]int32, 0, grid.Height*grid.Width*2)
-	for y := range grid.Height {
-		for x := range grid.Width {
-			coords = append(coords, int32(y), int32(x))
-		}
-	}
-	pos := ctx.Input().FromInts(coords, 2, grid.Width, grid.Height)
-
-	// Reshape and permute positions to match spatial merging pattern
-	pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
-	pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge)
-	pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	pos = pos.Reshape(ctx, 2*merge*merge*grid.Width/merge*grid.Height/merge)
-
-	// Use position indices to look up corresponding frequency values
-	positionalEmbedding := freqs.Rows(ctx, pos)
-	positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
-	return positionalEmbedding
+	return index, bounds
 }
 
 // newVisionModel creates a new instance of the Qwen vision model
diff --git a/model/models/qwen25vl/process_image.go b/model/models/qwen25vl/process_image.go
index ce5ded295..66803d4a0 100644
--- a/model/models/qwen25vl/process_image.go
+++ b/model/models/qwen25vl/process_image.go
@@ -19,8 +19,8 @@ type ImageProcessor struct {
 	maxPixels         int
 	factor            int
 	rescaleFactor     float32
-	imageMean         []float32
-	imageStd          []float32
+	imageMean         [3]float32
+	imageStd          [3]float32
 }
 
 // newImageProcessor creates a new image processor with default values
@@ -34,11 +34,11 @@ func newImageProcessor(c fs.Config) ImageProcessor {
 		temporalPatchSize: 2,
 		mergeSize:         mergeSize,
 		minPixels:         56 * 56,
-		maxPixels:         int(c.Uint("vision.max_pixels", 28*28*1280)), // 1MP limit
+		maxPixels:         int(c.Uint("vision.max_pixels", 2<<20)), // 2M limit
 		factor:            patchSize * mergeSize,
 		rescaleFactor:     1.0 / 255.0,
-		imageMean:         imageproc.ClipDefaultMean[:],
-		imageStd:          imageproc.ClipDefaultSTD[:],
+		imageMean:         imageproc.ClipDefaultMean,
+		imageStd:          imageproc.ClipDefaultSTD,
 	}
 }
 
@@ -90,13 +90,7 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error)
 	// Resize image using existing functions
 	resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
 
-	normalizedPixels := imageproc.Normalize(
-		resizedImg,
-		[3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]},
-		[3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]},
-		true, // rescale
-		true, // channelFirst
-	)
+	normalizedPixels := imageproc.Normalize(resizedImg, p.imageMean, p.imageStd, true, true)
 
 	// Calculate grid dimensions
 	grid := &Grid{