fix: qwen2.5 vl rope (#13486)

* qwen25vl: bump max pixels

* qwen25vl: mrope

fix qwen2.5vl window

* qwen25vl: vision rope
This commit is contained in:
Michael Yang 2025-12-15 17:30:33 -08:00 committed by GitHub
parent ffbe8e076d
commit 971d62595a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 195 additions and 216 deletions

View File

@ -1534,7 +1534,8 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
unsafe.SliceData(mropeSections), unsafe.SliceData(mropeSections),
C.int(opts.Type), C.int(opts.Type),
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10), cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
C.float(ropeBase), C.float(ropeScale), C.float(ropeBase),
C.float(ropeScale),
C.float(opts.YaRN.ExtrapolationFactor), C.float(opts.YaRN.ExtrapolationFactor),
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1), cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
cmp.Or(C.float(opts.YaRN.BetaFast), 32), cmp.Or(C.float(opts.YaRN.BetaFast), 32),
@ -1546,9 +1547,11 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
dequant, dequant,
positions.(*Tensor).t, positions.(*Tensor).t,
opts.Factors.(*Tensor).t, opts.Factors.(*Tensor).t,
C.int(ropeDim), C.int(opts.Type), C.int(ropeDim),
C.int(opts.Type),
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10), cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
C.float(ropeBase), C.float(ropeScale), C.float(ropeBase),
C.float(ropeScale),
C.float(opts.YaRN.ExtrapolationFactor), C.float(opts.YaRN.ExtrapolationFactor),
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1), cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
cmp.Or(C.float(opts.YaRN.BetaFast), 32), cmp.Or(C.float(opts.YaRN.BetaFast), 32),

View File

@ -77,6 +77,13 @@ func WithMRoPE(sections []int) func(*Options) {
} }
} }
func WithVision(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1<<3 | 1<<4
opts.MRoPE.Sections = sections
}
}
func WithInterleaveMRoPE(sections []int) func(*Options) { func WithInterleaveMRoPE(sections []int) func(*Options) {
return func(opts *Options) { return func(opts *Options) {
opts.Type |= 1<<3 | 1<<5 opts.Type |= 1<<3 | 1<<5

View File

@ -2,7 +2,6 @@ package qwen25vl
import ( import (
"bytes" "bytes"
"fmt"
"image" "image"
"slices" "slices"
@ -33,7 +32,7 @@ func New(c fs.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"), Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append( EOS: append(
@ -54,19 +53,18 @@ func New(c fs.Config) (model.Model, error) {
} }
func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) { func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) {
image, _, err := image.Decode(bytes.NewReader(multimodalData)) img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
f32s, grid, err := m.ImageProcessor.ProcessImage(image) f32s, grid, err := m.ImageProcessor.ProcessImage(img)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
// Calculate tensor dimensions // Calculate tensor dimensions
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize * patchDim := m.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
numPatches := grid.Temporal * grid.Height * grid.Width numPatches := grid.Temporal * grid.Height * grid.Width
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches) pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
@ -85,11 +83,13 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
} }
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid) visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
return []input.Multimodal{{Tensor: visionOutputs}}, nil return []input.Multimodal{{Tensor: visionOutputs, Data: grid}}, nil
} }
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass // PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
// Reset position cache
m.positionCache = m.positionCache[:0]
var result []*input.Input var result []*input.Input
var ( var (
@ -98,40 +98,37 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
visionEndToken int32 = 151653 visionEndToken int32 = 151653
) )
nImg := 0 appendInput := func(i *input.Input, p int) int {
result = append(result, i)
m.positionCache = append(m.positionCache, int32(p))
return p + 1
}
var p int
for _, inp := range inputs { for _, inp := range inputs {
if inp.Multimodal == nil { if inp.Multimodal == nil {
// If not a multimodal input, add it to the result unchanged // If not a multimodal input, add it to the result unchanged
result = append(result, inp) p = appendInput(inp, p)
} else { } else {
// Adding the 'Picture' prefix is a hack, at the time of writing there is no way to prefix
// the image tokens with a prompt, so we add a prefix here
nImg++
pre, err := m.Encode(fmt.Sprintf(" Picture %d: ", nImg), true)
if err != nil {
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
}
for i := range pre {
result = append(result, &input.Input{Token: pre[i]})
}
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
// First add the vision start token // First add the vision start token
result = append(result, &input.Input{Token: visionStartToken}) p = appendInput(&input.Input{Token: visionStartToken}, p)
// Add the image token with the multimodal tensor data at the first position // Add the image token with the multimodal tensor data at the first position
result = append(result, &input.Input{ tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
appendInput(&input.Input{
Token: imageToken, Token: imageToken,
Multimodal: inp.Multimodal, Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash, MultimodalHash: inp.MultimodalHash,
SameBatch: patchesPerChunk, SameBatch: tokensPerGrid,
}) }, p)
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1) // Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...) for range tokensPerGrid - 1 {
appendInput(&input.Input{Token: imageToken}, p)
}
result = append(result, &input.Input{Token: visionEndToken}) grid := inp.Multimodal[0].Data.(*Grid)
p = appendInput(&input.Input{Token: visionEndToken}, p+max(grid.Width/m.spatialMergeSize, grid.Height/m.spatialMergeSize))
} }
} }
@ -139,9 +136,58 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
} }
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) // Initial token embedding
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache) positionSlice := func() [][]int32 {
s := [][]int32{
make([]int32, len(batch.Positions)),
make([]int32, len(batch.Positions)),
make([]int32, len(batch.Positions)),
make([]int32, len(batch.Positions)),
}
for i, position := range batch.Positions {
if position < int32(len(m.positionCache)) {
position = m.positionCache[position]
} else if len(m.positionCache) > 0 {
position = position - int32(len(m.positionCache)) + m.positionCache[len(m.positionCache)-1] + 1
}
s[0][i] = position
s[1][i] = position
s[2][i] = position
}
return s
}()
for _, mi := range batch.Multimodal {
img := mi.Multimodal[0].Tensor
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
for i := range img.Dim(1) {
w := grid.Width / m.spatialMergeSize
positionSlice[1][mi.Index+i] += int32(i / w)
positionSlice[2][mi.Index+i] += int32(i % w)
}
}
}
positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
// Process through transformer layers
for i, layer := range m.TextModel.Layers {
m.Cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.TextModel.Layers)-1 {
lastLayerOutputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, m.Cache, m.TextOptions)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.TextModel.eps)
return m.Output.Forward(ctx, hiddenStates), nil
} }
func init() { func init() {

View File

@ -8,20 +8,17 @@ import (
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
) )
type TextOptions struct { type TextOptions struct {
hiddenSize, numHeads, numKVHeads int hiddenSize, numHeads, numKVHeads int
ropeDim, originalContextLength int ropeDim, originalContextLength int
eps, ropeBase, ropeScale float32 eps, ropeBase, ropeScale float32
mropeSections []int
} }
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithMRoPE(o.mropeSections))
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithTypeNeoX(),
)
} }
type TextModel struct { type TextModel struct {
@ -31,6 +28,7 @@ type TextModel struct {
Output *nn.Linear `gguf:"output,alt:token_embd"` Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions *TextOptions
positionCache []int32
} }
func NewTextModel(c fs.Config) *TextModel { func NewTextModel(c fs.Config) *TextModel {
@ -45,6 +43,14 @@ func NewTextModel(c fs.Config) *TextModel {
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1), ropeScale: c.Float("rope.scaling.factor", 1),
mropeSections: func() []int {
sections := c.Ints("rope.mrope_section")
s := make([]int, len(sections))
for i, section := range sections {
s[i] = int(section)
}
return s
}(),
}, },
} }
@ -84,6 +90,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
// Shift applies rotary position embeddings to the key tensor for causal attention caching // Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
m.positionCache = nil
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
} }
@ -130,28 +137,3 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
hiddenState = l.MLP.Forward(ctx, hiddenState, opts) hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
// Initial token embedding
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
for _, mi := range batch.Multimodal {
img := mi.Multimodal[0].Tensor
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
}
// Process through transformer layers
for i, layer := range m.Layers {
cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}

View File

@ -7,48 +7,28 @@ import (
"github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
) )
// We only support batch size of 1 func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int) ml.Tensor {
var batchSize int = 1 // Initialize a 2D mask with -Inf
s := make([][]float32, seqLength)
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor { for i := range s {
x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1) s[i] = slices.Repeat([]float32{float32(math.Inf(-1))}, seqLength)
x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
}
func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
}
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
// Create a flat slice for the mask (all -inf initially to block all attention)
flat := make([]float32, seqLength*seqLength)
for i := range flat {
flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention
} }
// Fill in the mask with zeros for tokens that CAN attend to each other // Fill in the mask with zeros for tokens that CAN attend to each other
for i := 1; i < len(bounds); i++ { for i := 1; i < len(bounds); i++ {
start := bounds[i-1] start, end := bounds[i-1], bounds[i]
end := bounds[i] // Enable attention within this sequence block
// Enable attention within this sequence block by setting values to 0
for row := start; row < end; row++ { for row := start; row < end; row++ {
for col := start; col < end; col++ { for col := start; col < end; col++ {
idx := row*seqLength + col s[row][col] = 0.0
flat[idx] = 0.0 // 0 allows attention, -inf blocks it
} }
} }
} }
mask := ctx.Input().FromFloats(flat, seqLength, seqLength) return ctx.Input().FromFloats(slices.Concat(s...), seqLength, seqLength)
// Reshape to match [seqLength, seqLength, 1] for broadcasting
mask = mask.Reshape(ctx, seqLength, seqLength, 1)
return mask
} }
type VisionSelfAttention struct { type VisionSelfAttention struct {
@ -58,17 +38,17 @@ type VisionSelfAttention struct {
Output *nn.Linear `gguf:"attn_out"` Output *nn.Linear `gguf:"attn_out"`
} }
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor { func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates) query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates) key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates) value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize) query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1))
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize) key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1))
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize) value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1))
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin) query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
// Scale factor for scaled dot-product attention // Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim)) scale := 1.0 / math.Sqrt(float64(opts.headDim))
@ -77,6 +57,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m
query = query.Permute(ctx, 0, 2, 1, 3) query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3) key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query) kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale) kq = kq.Scale(ctx, scale)
if mask != nil { if mask != nil {
@ -85,7 +66,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m
kq = kq.Softmax(ctx) kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq) kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
return sa.Output.Forward(ctx, attention) return sa.Output.Forward(ctx, attention)
} }
@ -98,10 +79,7 @@ type VisionMLP struct {
} }
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Using activation as specified in config (likely GELU or SiLU/Swish) hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates) return mlp.Down.Forward(ctx, hiddenStates)
} }
@ -112,10 +90,10 @@ type VisionEncoderLayer struct {
MLP *VisionMLP MLP *VisionMLP
} }
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor { func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, positions, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates residual := hiddenStates
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps) hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, mask, opts) hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, positions, mask, opts)
hiddenStates = hiddenStates.Add(ctx, residual) hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates residual = hiddenStates
@ -139,6 +117,17 @@ type VisionModelOptions struct {
temporalPatchSize int temporalPatchSize int
} }
func (o VisionModelOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.headDim/2, o.ropeTheta, 1,
rope.WithVision([]int{
o.headDim / 4,
o.headDim / 4,
o.headDim / 4,
o.headDim / 4,
}),
)
}
type PatchEmbedding struct { type PatchEmbedding struct {
PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"` PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"`
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"` PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
@ -186,7 +175,7 @@ func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, op
hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize) hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)
// Reshape the normalized output to view the hidden size dimension // Reshape the normalized output to view the hidden size dimension
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize) reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize))
hidden := pm.MLP0.Forward(ctx, reshaped) hidden := pm.MLP0.Forward(ctx, reshaped)
activated := hidden.GELU(ctx) activated := hidden.GELU(ctx)
@ -209,36 +198,53 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
// Extract patch embeddings // Extract patch embeddings
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions) hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
positionEmbedding := m.PositionalEmbedding(ctx, grid) index, bounds := m.windowIndex(grid)
windowIndex, bounds := m.WindowIndex(ctx, grid)
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
windowIndex := ctx.Input().FromInts(index, len(index))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit) hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
hiddenStates = hiddenStates.Rows(ctx, windowIndex) hiddenStates = hiddenStates.Rows(ctx, windowIndex.Argsort(ctx))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit) hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit) positions := ctx.Input().FromInts(func() []int32 {
positionEmbedding = positionEmbedding.Rows(ctx, windowIndex) s := [][]int32{
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit) make([]int32, grid.Height*grid.Width),
positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0) make([]int32, grid.Height*grid.Width),
make([]int32, grid.Height*grid.Width),
make([]int32, grid.Height*grid.Width),
}
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx) var cur int
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1)) for y := 0; y < grid.Height; y += m.spatialMergeSize {
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1)) for x := 0; x < grid.Width; x += m.spatialMergeSize {
for dy := range 2 {
for dx := range 2 {
i := int(index[cur/spatialMergeUnit]) * spatialMergeUnit
i += cur % spatialMergeUnit
s[0][i] = int32(y + dy)
s[1][i] = int32(x + dx)
s[2][i] = int32(y + dy)
s[3][i] = int32(x + dx)
cur++
}
}
}
}
return slices.Concat(s...)
}(), grid.Height*grid.Width*4)
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds)
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
// Apply encoder layers // Apply encoder layers
for i, layer := range m.Layers { for i, layer := range m.Layers {
if slices.Contains(m.fullAttnBlocks, int32(i)) { if slices.Contains(m.fullAttnBlocks, int32(i)) {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions) hiddenStates = layer.Forward(ctx, hiddenStates, positions, nil, m.VisionModelOptions)
} else { } else {
hiddenStates = layer.Forward( hiddenStates = layer.Forward(
ctx, ctx,
hiddenStates, hiddenStates,
cos, positions,
sin,
mask, mask,
m.VisionModelOptions, m.VisionModelOptions,
) )
@ -246,102 +252,43 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
} }
hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions) hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
reverseWindowIndex := windowIndex.Argsort(ctx) return hiddenStates.Rows(ctx, windowIndex)
return hiddenStates.Rows(ctx, reverseWindowIndex)
} }
// WindowIndex divides the grid into windows and returns: // windowIndex divides the grid into windows and returns:
// 1. A tensor containing flattened indices of all grid points organized by windows // 1. A slice of grid point indices organized by windows
// 2. A slice of boundaries that mark where each window's data begins and ends // 2. A slice of boundaries that mark where each window's data begins and ends
// in the flattened representation, scaled by spatialMergeSize squared // in the flattened representation, scaled by spatialMergeSize squared
// //
// The boundaries slice always starts with 0 and contains cumulative ending // The boundaries slice always starts with 0 and contains cumulative ending
// positions for each window, allowing downstream processing to identify // positions for each window, allowing downstream processing to identify
// window boundaries in the tensor data. // window boundaries in the tensor data.
func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) { func (m *VisionModel) windowIndex(grid *Grid) (index []int32, bounds []int) {
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize height := grid.Height / m.spatialMergeSize
width := grid.Width / m.spatialMergeSize
window := m.windowSize / m.patchSize / m.spatialMergeSize
llmGridH := grid.Height / m.spatialMergeSize index = make([]int32, height*width)
llmGridW := grid.Width / m.spatialMergeSize
// Calculate window parameters bounds = make([]int, 0, ((height+window-1)/window)*((width+window-1)/window)+1)
numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize))) bounds = append(bounds, 0)
numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize)))
// Initialize index_new slice var cur int32
var index []int32 for y := 0; y < height; y += window {
for x := 0; x < width; x += window {
// Initialize bounds with the first element as 0 h1 := min(window, height-y)
bounds := []int{0} w1 := min(window, width-x)
totalSeqLen := 0 for dy := range h1 {
for dx := range w1 {
// Process each window without padding win := (y+dy)*width + (x + dx)
for wh := range numWindowsH { index[win] = cur
for ww := range numWindowsW { cur++
// Calculate window boundaries
hStart := wh * vitMergerWindowSize
wStart := ww * vitMergerWindowSize
hEnd := min(hStart+vitMergerWindowSize, llmGridH)
wEnd := min(wStart+vitMergerWindowSize, llmGridW)
// Calculate sequence length for this window
seqLen := (hEnd - hStart) * (wEnd - wStart)
// Collect indices for this window
for h := hStart; h < hEnd; h++ {
for w := wStart; w < wEnd; w++ {
index = append(index, int32(h*llmGridW+w))
} }
} }
bounds = append(bounds, int(cur)*window)
totalSeqLen += seqLen
bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0])
} }
} }
return index, bounds
t := ctx.Input().FromInts(index, len(index))
return t, bounds
}
// PositionalEmbedding generates rotary position embeddings for attention mechanisms
func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
dim := m.headDim / 2
freq := dim / 2
theta := float64(m.ropeTheta)
merge := m.spatialMergeSize
// Create frequency patterns for position encoding
maxGridSize := max(grid.Height, grid.Width)
freqVals := make([]float32, freq*maxGridSize)
for i := range maxGridSize {
for j := range freq {
freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
}
}
freqs := ctx.Input().FromFloats(freqVals, freq, maxGridSize)
// Create position coordinates (y,x pairs) for the grid
// In PyTorch: Equivalent to generating position ids with torch.arange()
coords := make([]int32, 0, grid.Height*grid.Width*2)
for y := range grid.Height {
for x := range grid.Width {
coords = append(coords, int32(y), int32(x))
}
}
pos := ctx.Input().FromInts(coords, 2, grid.Width, grid.Height)
// Reshape and permute positions to match spatial merging pattern
pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge)
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
pos = pos.Reshape(ctx, 2*merge*merge*grid.Width/merge*grid.Height/merge)
// Use position indices to look up corresponding frequency values
positionalEmbedding := freqs.Rows(ctx, pos)
positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
return positionalEmbedding
} }
// newVisionModel creates a new instance of the Qwen vision model // newVisionModel creates a new instance of the Qwen vision model

View File

@ -19,8 +19,8 @@ type ImageProcessor struct {
maxPixels int maxPixels int
factor int factor int
rescaleFactor float32 rescaleFactor float32
imageMean []float32 imageMean [3]float32
imageStd []float32 imageStd [3]float32
} }
// newImageProcessor creates a new image processor with default values // newImageProcessor creates a new image processor with default values
@ -34,11 +34,11 @@ func newImageProcessor(c fs.Config) ImageProcessor {
temporalPatchSize: 2, temporalPatchSize: 2,
mergeSize: mergeSize, mergeSize: mergeSize,
minPixels: 56 * 56, minPixels: 56 * 56,
maxPixels: int(c.Uint("vision.max_pixels", 28*28*1280)), // 1MP limit maxPixels: int(c.Uint("vision.max_pixels", 2<<20)), // 2M limit
factor: patchSize * mergeSize, factor: patchSize * mergeSize,
rescaleFactor: 1.0 / 255.0, rescaleFactor: 1.0 / 255.0,
imageMean: imageproc.ClipDefaultMean[:], imageMean: imageproc.ClipDefaultMean,
imageStd: imageproc.ClipDefaultSTD[:], imageStd: imageproc.ClipDefaultSTD,
} }
} }
@ -90,13 +90,7 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error)
// Resize image using existing functions // Resize image using existing functions
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear) resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
normalizedPixels := imageproc.Normalize( normalizedPixels := imageproc.Normalize(resizedImg, p.imageMean, p.imageStd, true, true)
resizedImg,
[3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]},
[3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]},
true, // rescale
true, // channelFirst
)
// Calculate grid dimensions // Calculate grid dimensions
grid := &Grid{ grid := &Grid{