qwen25vl: vision rope

Michael Yang 2025-12-15 15:32:01 -08:00
parent 5fb395f9f9
commit d47f63e425
1 changed file with 88 additions and 141 deletions


@@ -7,48 +7,28 @@ import (
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/rope"
 )
 
-// We only support batch size of 1
-var batchSize int = 1
-
-func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-	x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
-	x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
-	return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
-}
-
-func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
-	return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
-}
-
-func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
-	// Create a flat slice for the mask (all -inf initially to block all attention)
-	flat := make([]float32, seqLength*seqLength)
-	for i := range flat {
-		flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention
+func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int) ml.Tensor {
+	// Initialize a 2D mask with -Inf
+	s := make([][]float32, seqLength)
+	for i := range s {
+		s[i] = slices.Repeat([]float32{float32(math.Inf(-1))}, seqLength)
 	}
 
-	// Fill in the mask with zeros for tokens that CAN attend to each other
 	for i := 1; i < len(bounds); i++ {
-		start := bounds[i-1]
-		end := bounds[i]
-
-		// Enable attention within this sequence block by setting values to 0
+		start, end := bounds[i-1], bounds[i]
+		// Enable attention within this sequence block
 		for row := start; row < end; row++ {
 			for col := start; col < end; col++ {
-				idx := row*seqLength + col
-				flat[idx] = 0.0 // 0 allows attention, -inf blocks it
+				s[row][col] = 0.0
 			}
 		}
 	}
 
-	mask := ctx.Input().FromFloats(flat, seqLength, seqLength)
-	// Reshape to match [seqLength, seqLength, 1] for broadcasting
-	mask = mask.Reshape(ctx, seqLength, seqLength, 1)
-	return mask
+	return ctx.Input().FromFloats(slices.Concat(s...), seqLength, seqLength)
 }
 
 type VisionSelfAttention struct {
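
Aside (not part of the patch): the rewritten mask builder is easy to sanity-check outside the tensor API. Below is a minimal standalone sketch of the same construction in plain Go, assuming Go 1.23+ for `slices.Repeat`; the `main` wrapper and the tiny `bounds` of `[0, 2, 4]` are made up for illustration. Tokens 0-1 and 2-3 form two windows that cannot attend across each other:

```go
package main

import (
	"fmt"
	"math"
	"slices"
)

func main() {
	seqLength, bounds := 4, []int{0, 2, 4}

	// All positions start blocked (-Inf), exactly as in blockDiagonalMask.
	s := make([][]float32, seqLength)
	for i := range s {
		s[i] = slices.Repeat([]float32{float32(math.Inf(-1))}, seqLength)
	}

	// Zero out the square block for each window so tokens within a
	// window can attend to each other.
	for i := 1; i < len(bounds); i++ {
		start, end := bounds[i-1], bounds[i]
		for row := start; row < end; row++ {
			for col := start; col < end; col++ {
				s[row][col] = 0.0
			}
		}
	}

	for _, row := range s {
		fmt.Println(row)
	}
	// Output:
	// [0 0 -Inf -Inf]
	// [0 0 -Inf -Inf]
	// [-Inf -Inf 0 0]
	// [-Inf -Inf 0 0]
}
```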
@@ -58,17 +38,17 @@ type VisionSelfAttention struct {
 	Output *nn.Linear `gguf:"attn_out"`
 }
 
-func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	query := sa.Query.Forward(ctx, hiddenStates)
 	key := sa.Key.Forward(ctx, hiddenStates)
 	value := sa.Value.Forward(ctx, hiddenStates)
 
-	query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
-	key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
-	value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
+	query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1))
+	key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1))
+	value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1))
 
-	query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
-	key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
+	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
+	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
 
 	// Scale factor for scaled dot-product attention
 	scale := 1.0 / math.Sqrt(float64(opts.headDim))
@@ -77,6 +57,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m
 	query = query.Permute(ctx, 0, 2, 1, 3)
 	key = key.Permute(ctx, 0, 2, 1, 3)
 	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
 	kq := key.MulmatFullPrec(ctx, query)
 	kq = kq.Scale(ctx, scale)
 	if mask != nil {
@@ -85,7 +66,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m
 	kq = kq.Softmax(ctx)
 	kqv := value.Mulmat(ctx, kq)
 	attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
+	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
 
 	return sa.Output.Forward(ctx, attention)
 }
@@ -98,10 +79,7 @@ type VisionMLP struct {
 }
 
 func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	// Using activation as specified in config (likely GELU or SiLU/Swish)
-	gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
-	hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
-
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
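
The collapsed MLP body is the standard SwiGLU form, `Down(SiLU(Gate(x)) * Up(x))`; the two-argument `SILU(ctx, other)` call appears to fuse the elementwise gate activation with the multiply by `Up(x)`. A scalar sketch of the same math (hypothetical helper, for intuition only; `math` import assumed):

```go
// swiglu is the scalar analogue of Gate.Forward(x).SILU(ctx, Up.Forward(x)):
// SiLU(gate) * up, where SiLU(g) = g * sigmoid(g).
func swiglu(gate, up float64) float64 {
	return gate / (1 + math.Exp(-gate)) * up
}
```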
@@ -112,10 +90,10 @@ type VisionEncoderLayer struct {
 	MLP           *VisionMLP
 }
 
-func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, positions, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	residual := hiddenStates
 	hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
-	hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, mask, opts)
+	hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, positions, mask, opts)
 	hiddenStates = hiddenStates.Add(ctx, residual)
 
 	residual = hiddenStates
@@ -139,6 +117,17 @@ type VisionModelOptions struct {
 	temporalPatchSize int
 }
 
+func (o VisionModelOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+	return nn.RoPE(ctx, states, positions, o.headDim/2, o.ropeTheta, 1,
+		rope.WithVision([]int{
+			o.headDim / 4,
+			o.headDim / 4,
+			o.headDim / 4,
+			o.headDim / 4,
+		}),
+	)
+}
+
 type PatchEmbedding struct {
 	PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"`
 	PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
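
For reference, the deleted `rotateHalf`/`applyRotaryPositionEmbeddings` helpers implemented the usual rotate-half form of RoPE; `nn.RoPE` with `rope.WithVision` now computes the rotation in-kernel, splitting the rotary dimensions into four equal `headDim/4` sections so that each coordinate channel of the position ids drives its own frequency bands. A scalar sketch of the rotate-half formula alone (hypothetical helper; it ignores the section layout, which the kernel handles):

```go
// ropeRotateHalf applies out = x*cos + rotateHalf(x)*sin with per-pair
// angles pos/theta^(2j/dim), pairing element j with element j+dim/2.
// This mirrors the removed helpers, where rotateHalf(x) = concat(-x2, x1).
func ropeRotateHalf(x []float64, pos, theta float64) []float64 {
	dim, half := len(x), len(x)/2
	out := make([]float64, dim)
	for j := 0; j < half; j++ {
		angle := pos / math.Pow(theta, float64(2*j)/float64(dim))
		c, s := math.Cos(angle), math.Sin(angle)
		out[j] = x[j]*c - x[j+half]*s      // x1*cos + (-x2)*sin
		out[j+half] = x[j+half]*c + x[j]*s // x2*cos + x1*sin
	}
	return out
}
```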
@@ -186,7 +175,7 @@ func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, op
 	hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)
 
 	// Reshape the normalized output to view the hidden size dimension
-	reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize)
+	reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize))
 
 	hidden := pm.MLP0.Forward(ctx, reshaped)
 	activated := hidden.GELU(ctx)
@@ -209,36 +198,53 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 	// Extract patch embeddings
 	hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
-	positionEmbedding := m.PositionalEmbedding(ctx, grid)
 
-	windowIndex, bounds := m.WindowIndex(ctx, grid)
+	index, bounds := m.windowIndex(grid)
 	spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
 
+	windowIndex := ctx.Input().FromInts(index, len(index))
+
 	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
-	hiddenStates = hiddenStates.Rows(ctx, windowIndex)
+	hiddenStates = hiddenStates.Rows(ctx, windowIndex.Argsort(ctx))
 	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
 
-	positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit)
-	positionEmbedding = positionEmbedding.Rows(ctx, windowIndex)
-	positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit)
-	positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0)
-
-	cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
-	cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
-	sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
+	positions := ctx.Input().FromInts(func() []int32 {
+		s := [][]int32{
+			make([]int32, grid.Height*grid.Width),
+			make([]int32, grid.Height*grid.Width),
+			make([]int32, grid.Height*grid.Width),
+			make([]int32, grid.Height*grid.Width),
+		}
+
+		var cur int
+		for y := 0; y < grid.Height; y += m.spatialMergeSize {
+			for x := 0; x < grid.Width; x += m.spatialMergeSize {
+				for dy := range 2 {
+					for dx := range 2 {
+						i := int(index[cur/spatialMergeUnit]) * spatialMergeUnit
+						i += cur % spatialMergeUnit
+						s[0][i] = int32(y + dy)
+						s[1][i] = int32(x + dx)
+						s[2][i] = int32(y + dy)
+						s[3][i] = int32(x + dx)
+						cur++
+					}
+				}
+			}
+		}
+
+		return slices.Concat(s...)
+	}(), grid.Height*grid.Width*4)
 
-	mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
+	mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds)
 	// Apply encoder layers
 	for i, layer := range m.Layers {
 		if slices.Contains(m.fullAttnBlocks, int32(i)) {
-			hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions)
+			hiddenStates = layer.Forward(ctx, hiddenStates, positions, nil, m.VisionModelOptions)
 		} else {
 			hiddenStates = layer.Forward(
 				ctx,
 				hiddenStates,
-				cos,
-				sin,
+				positions,
 				mask,
 				m.VisionModelOptions,
 			)
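
The inlined closure above replaces the old precomputed cos/sin tables with raw position ids: four sections of length `grid.Height*grid.Width`, filled in window order so they line up with the permuted hidden states, with each token's (y, x) coordinates duplicated across the two section pairs. A standalone sketch of that layout for a 4x4 patch grid (hypothetical `main`; a single window is assumed, so `index` is the identity permutation):

```go
package main

import "fmt"

func main() {
	const height, width, merge = 4, 4, 2
	unit := merge * merge

	// Identity window index: one window covering the whole grid.
	index := make([]int32, height/merge*width/merge)
	for i := range index {
		index[i] = int32(i)
	}

	// Four position sections, as in the closure above.
	s := make([][]int32, 4)
	for k := range s {
		s[k] = make([]int32, height*width)
	}

	var cur int
	for y := 0; y < height; y += merge {
		for x := 0; x < width; x += merge {
			for dy := 0; dy < 2; dy++ {
				for dx := 0; dx < 2; dx++ {
					i := int(index[cur/unit])*unit + cur%unit
					s[0][i], s[2][i] = int32(y+dy), int32(y+dy) // y coordinate
					s[1][i], s[3][i] = int32(x+dx), int32(x+dx) // x coordinate
					cur++
				}
			}
		}
	}

	fmt.Println(s[0][:8]) // [0 0 1 1 0 0 1 1]: y coords of the first two 2x2 blocks
	fmt.Println(s[1][:8]) // [0 1 0 1 2 3 2 3]: x coords of the first two 2x2 blocks
}
```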
@@ -246,102 +252,43 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 	}
 
 	hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
-	reverseWindowIndex := windowIndex.Argsort(ctx)
-	return hiddenStates.Rows(ctx, reverseWindowIndex)
+	return hiddenStates.Rows(ctx, windowIndex)
 }
 
-// WindowIndex divides the grid into windows and returns:
-// 1. A tensor containing flattened indices of all grid points organized by windows
+// windowIndex divides the grid into windows and returns:
+// 1. A slice of grid point indices organized by windows
 // 2. A slice of boundaries that mark where each window's data begins and ends
 //    in the flattened representation, scaled by spatialMergeSize squared
 //
 // The boundaries slice always starts with 0 and contains cumulative ending
 // positions for each window, allowing downstream processing to identify
 // window boundaries in the tensor data.
-func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
-	vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
+func (m *VisionModel) windowIndex(grid *Grid) (index []int32, bounds []int) {
+	height := grid.Height / m.spatialMergeSize
+	width := grid.Width / m.spatialMergeSize
+	window := m.windowSize / m.patchSize / m.spatialMergeSize
 
-	llmGridH := grid.Height / m.spatialMergeSize
-	llmGridW := grid.Width / m.spatialMergeSize
+	index = make([]int32, height*width)
 
-	// Calculate window parameters
-	numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize)))
-	numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize)))
+	bounds = make([]int, 0, ((height+window-1)/window)*((width+window-1)/window)+1)
+	bounds = append(bounds, 0)
 
-	// Initialize index_new slice
-	var index []int32
-
-	// Initialize bounds with the first element as 0
-	bounds := []int{0}
-	totalSeqLen := 0
-
-	// Process each window without padding
-	for wh := range numWindowsH {
-		for ww := range numWindowsW {
-			// Calculate window boundaries
-			hStart := wh * vitMergerWindowSize
-			wStart := ww * vitMergerWindowSize
-			hEnd := min(hStart+vitMergerWindowSize, llmGridH)
-			wEnd := min(wStart+vitMergerWindowSize, llmGridW)
-
-			// Calculate sequence length for this window
-			seqLen := (hEnd - hStart) * (wEnd - wStart)
-
-			// Collect indices for this window
-			for h := hStart; h < hEnd; h++ {
-				for w := wStart; w < wEnd; w++ {
-					index = append(index, int32(h*llmGridW+w))
+	var cur int32
+	for y := 0; y < height; y += window {
+		for x := 0; x < width; x += window {
+			h1 := min(window, height-y)
+			w1 := min(window, width-x)
+			for dy := range h1 {
+				for dx := range w1 {
+					win := (y+dy)*width + (x + dx)
+					index[win] = cur
+					cur++
 				}
 			}
-
-			totalSeqLen += seqLen
-			bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0])
+			bounds = append(bounds, int(cur)*window)
 		}
 	}
 
-	t := ctx.Input().FromInts(index, len(index))
-	return t, bounds
-}
-
-// PositionalEmbedding generates rotary position embeddings for attention mechanisms
-func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
-	dim := m.headDim / 2
-	freq := dim / 2
-	theta := float64(m.ropeTheta)
-	merge := m.spatialMergeSize
-
-	// Create frequency patterns for position encoding
-	maxGridSize := max(grid.Height, grid.Width)
-	freqVals := make([]float32, freq*maxGridSize)
-	for i := range maxGridSize {
-		for j := range freq {
-			freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
-		}
-	}
-	freqs := ctx.Input().FromFloats(freqVals, freq, maxGridSize)
-
-	// Create position coordinates (y,x pairs) for the grid
-	// In PyTorch: Equivalent to generating position ids with torch.arange()
-	coords := make([]int32, 0, grid.Height*grid.Width*2)
-	for y := range grid.Height {
-		for x := range grid.Width {
-			coords = append(coords, int32(y), int32(x))
-		}
-	}
-	pos := ctx.Input().FromInts(coords, 2, grid.Width, grid.Height)
-
-	// Reshape and permute positions to match spatial merging pattern
-	pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
-	pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge)
-	pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	pos = pos.Reshape(ctx, 2*merge*merge*grid.Width/merge*grid.Height/merge)
-
-	// Use position indices to look up corresponding frequency values
-	positionalEmbedding := freqs.Rows(ctx, pos)
-	positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
-	return positionalEmbedding
+	return index, bounds
 }
 
 // newVisionModel creates a new instance of the Qwen vision model
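
Note the inversion relative to the old `WindowIndex`: the returned slice now maps each grid position to its window-ordered sequence position, which is why `Forward` gathers hidden states with `windowIndex.Argsort(ctx)` and restores the original order at the end with plain `windowIndex`. To make the return values concrete, here is a standalone re-run of the same loop for an 8x8 merged grid with a window of 4 (hypothetical `main`; the `Grid`/`VisionModel` plumbing is replaced with constants):

```go
package main

import "fmt"

func main() {
	const height, width, window = 8, 8, 4

	index := make([]int32, height*width)
	bounds := []int{0}

	var cur int32
	for y := 0; y < height; y += window {
		for x := 0; x < width; x += window {
			h1 := min(window, height-y)
			w1 := min(window, width-x)
			for dy := 0; dy < h1; dy++ {
				for dx := 0; dx < w1; dx++ {
					index[(y+dy)*width+(x+dx)] = cur
					cur++
				}
			}
			bounds = append(bounds, int(cur)*window)
		}
	}

	fmt.Println(index[:8]) // [0 1 2 3 16 17 18 19]: row 0 spans windows 0 and 1
	fmt.Println(bounds)    // [0 64 128 192 256]: four 4x4 windows, 64 patch tokens each
}
```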