mirror of https://github.com/ollama/ollama
fix: qwen2.5 vl rope (#13486)
* qwen25vl: bump max pixels * qwen25vl: mrope fix qwen2.5vl window * qwen25vl: vision rope
This commit is contained in:
parent
ffbe8e076d
commit
971d62595a
|
|
@ -1534,7 +1534,8 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
|
||||||
unsafe.SliceData(mropeSections),
|
unsafe.SliceData(mropeSections),
|
||||||
C.int(opts.Type),
|
C.int(opts.Type),
|
||||||
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
|
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
|
||||||
C.float(ropeBase), C.float(ropeScale),
|
C.float(ropeBase),
|
||||||
|
C.float(ropeScale),
|
||||||
C.float(opts.YaRN.ExtrapolationFactor),
|
C.float(opts.YaRN.ExtrapolationFactor),
|
||||||
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
|
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
|
||||||
cmp.Or(C.float(opts.YaRN.BetaFast), 32),
|
cmp.Or(C.float(opts.YaRN.BetaFast), 32),
|
||||||
|
|
@ -1546,9 +1547,11 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
|
||||||
dequant,
|
dequant,
|
||||||
positions.(*Tensor).t,
|
positions.(*Tensor).t,
|
||||||
opts.Factors.(*Tensor).t,
|
opts.Factors.(*Tensor).t,
|
||||||
C.int(ropeDim), C.int(opts.Type),
|
C.int(ropeDim),
|
||||||
|
C.int(opts.Type),
|
||||||
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
|
cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
|
||||||
C.float(ropeBase), C.float(ropeScale),
|
C.float(ropeBase),
|
||||||
|
C.float(ropeScale),
|
||||||
C.float(opts.YaRN.ExtrapolationFactor),
|
C.float(opts.YaRN.ExtrapolationFactor),
|
||||||
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
|
cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
|
||||||
cmp.Or(C.float(opts.YaRN.BetaFast), 32),
|
cmp.Or(C.float(opts.YaRN.BetaFast), 32),
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,13 @@ func WithMRoPE(sections []int) func(*Options) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithVision(sections []int) func(*Options) {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.Type |= 1<<3 | 1<<4
|
||||||
|
opts.MRoPE.Sections = sections
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithInterleaveMRoPE(sections []int) func(*Options) {
|
func WithInterleaveMRoPE(sections []int) func(*Options) {
|
||||||
return func(opts *Options) {
|
return func(opts *Options) {
|
||||||
opts.Type |= 1<<3 | 1<<5
|
opts.Type |= 1<<3 | 1<<5
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ package qwen25vl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
|
||||||
"image"
|
"image"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
|
|
@ -33,7 +32,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
EOS: append(
|
EOS: append(
|
||||||
|
|
@ -54,19 +53,18 @@ func New(c fs.Config) (model.Model, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) {
|
func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) {
|
||||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
|
f32s, grid, err := m.ImageProcessor.ProcessImage(img)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate tensor dimensions
|
// Calculate tensor dimensions
|
||||||
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
|
patchDim := m.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
|
||||||
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
|
|
||||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||||
|
|
||||||
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
|
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
|
||||||
|
|
@ -85,11 +83,13 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||||
}
|
}
|
||||||
|
|
||||||
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
|
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
|
||||||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
return []input.Multimodal{{Tensor: visionOutputs, Data: grid}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
||||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
|
// Reset position cache
|
||||||
|
m.positionCache = m.positionCache[:0]
|
||||||
var result []*input.Input
|
var result []*input.Input
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
@ -98,40 +98,37 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
visionEndToken int32 = 151653
|
visionEndToken int32 = 151653
|
||||||
)
|
)
|
||||||
|
|
||||||
nImg := 0
|
appendInput := func(i *input.Input, p int) int {
|
||||||
|
result = append(result, i)
|
||||||
|
m.positionCache = append(m.positionCache, int32(p))
|
||||||
|
return p + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
var p int
|
||||||
for _, inp := range inputs {
|
for _, inp := range inputs {
|
||||||
if inp.Multimodal == nil {
|
if inp.Multimodal == nil {
|
||||||
// If not a multimodal input, add it to the result unchanged
|
// If not a multimodal input, add it to the result unchanged
|
||||||
result = append(result, inp)
|
p = appendInput(inp, p)
|
||||||
} else {
|
} else {
|
||||||
// Adding the 'Picture' prefix is a hack, at the time of writing there is no way to prefix
|
|
||||||
// the image tokens with a prompt, so we add a prefix here
|
|
||||||
nImg++
|
|
||||||
pre, err := m.Encode(fmt.Sprintf(" Picture %d: ", nImg), true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
|
|
||||||
}
|
|
||||||
for i := range pre {
|
|
||||||
result = append(result, &input.Input{Token: pre[i]})
|
|
||||||
}
|
|
||||||
|
|
||||||
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
|
|
||||||
|
|
||||||
// First add the vision start token
|
// First add the vision start token
|
||||||
result = append(result, &input.Input{Token: visionStartToken})
|
p = appendInput(&input.Input{Token: visionStartToken}, p)
|
||||||
|
|
||||||
// Add the image token with the multimodal tensor data at the first position
|
// Add the image token with the multimodal tensor data at the first position
|
||||||
result = append(result, &input.Input{
|
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
|
||||||
|
appendInput(&input.Input{
|
||||||
Token: imageToken,
|
Token: imageToken,
|
||||||
Multimodal: inp.Multimodal,
|
Multimodal: inp.Multimodal,
|
||||||
MultimodalHash: inp.MultimodalHash,
|
MultimodalHash: inp.MultimodalHash,
|
||||||
SameBatch: patchesPerChunk,
|
SameBatch: tokensPerGrid,
|
||||||
})
|
}, p)
|
||||||
|
|
||||||
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
||||||
result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
for range tokensPerGrid - 1 {
|
||||||
|
appendInput(&input.Input{Token: imageToken}, p)
|
||||||
|
}
|
||||||
|
|
||||||
result = append(result, &input.Input{Token: visionEndToken})
|
grid := inp.Multimodal[0].Data.(*Grid)
|
||||||
|
p = appendInput(&input.Input{Token: visionEndToken}, p+max(grid.Width/m.spatialMergeSize, grid.Height/m.spatialMergeSize))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -139,9 +136,58 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
// Initial token embedding
|
||||||
|
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
|
positionSlice := func() [][]int32 {
|
||||||
|
s := [][]int32{
|
||||||
|
make([]int32, len(batch.Positions)),
|
||||||
|
make([]int32, len(batch.Positions)),
|
||||||
|
make([]int32, len(batch.Positions)),
|
||||||
|
make([]int32, len(batch.Positions)),
|
||||||
|
}
|
||||||
|
for i, position := range batch.Positions {
|
||||||
|
if position < int32(len(m.positionCache)) {
|
||||||
|
position = m.positionCache[position]
|
||||||
|
} else if len(m.positionCache) > 0 {
|
||||||
|
position = position - int32(len(m.positionCache)) + m.positionCache[len(m.positionCache)-1] + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
s[0][i] = position
|
||||||
|
s[1][i] = position
|
||||||
|
s[2][i] = position
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}()
|
||||||
|
|
||||||
|
for _, mi := range batch.Multimodal {
|
||||||
|
img := mi.Multimodal[0].Tensor
|
||||||
|
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
|
||||||
|
if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
|
||||||
|
for i := range img.Dim(1) {
|
||||||
|
w := grid.Width / m.spatialMergeSize
|
||||||
|
positionSlice[1][mi.Index+i] += int32(i / w)
|
||||||
|
positionSlice[2][mi.Index+i] += int32(i % w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
|
||||||
|
|
||||||
|
// Process through transformer layers
|
||||||
|
for i, layer := range m.TextModel.Layers {
|
||||||
|
m.Cache.SetLayer(i)
|
||||||
|
|
||||||
|
var lastLayerOutputs ml.Tensor
|
||||||
|
if i == len(m.TextModel.Layers)-1 {
|
||||||
|
lastLayerOutputs = batch.Outputs
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, m.Cache, m.TextOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.TextModel.eps)
|
||||||
|
return m.Output.Forward(ctx, hiddenStates), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
|
||||||
|
|
@ -8,20 +8,17 @@ import (
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/ml/nn"
|
"github.com/ollama/ollama/ml/nn"
|
||||||
"github.com/ollama/ollama/ml/nn/rope"
|
"github.com/ollama/ollama/ml/nn/rope"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type TextOptions struct {
|
type TextOptions struct {
|
||||||
hiddenSize, numHeads, numKVHeads int
|
hiddenSize, numHeads, numKVHeads int
|
||||||
ropeDim, originalContextLength int
|
ropeDim, originalContextLength int
|
||||||
eps, ropeBase, ropeScale float32
|
eps, ropeBase, ropeScale float32
|
||||||
|
mropeSections []int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||||
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale,
|
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithMRoPE(o.mropeSections))
|
||||||
rope.WithOriginalContextLength(o.originalContextLength),
|
|
||||||
rope.WithTypeNeoX(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextModel struct {
|
type TextModel struct {
|
||||||
|
|
@ -31,6 +28,7 @@ type TextModel struct {
|
||||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||||
|
|
||||||
*TextOptions
|
*TextOptions
|
||||||
|
positionCache []int32
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTextModel(c fs.Config) *TextModel {
|
func NewTextModel(c fs.Config) *TextModel {
|
||||||
|
|
@ -45,6 +43,14 @@ func NewTextModel(c fs.Config) *TextModel {
|
||||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
ropeBase: c.Float("rope.freq_base"),
|
ropeBase: c.Float("rope.freq_base"),
|
||||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||||
|
mropeSections: func() []int {
|
||||||
|
sections := c.Ints("rope.mrope_section")
|
||||||
|
s := make([]int, len(sections))
|
||||||
|
for i, section := range sections {
|
||||||
|
s[i] = int(section)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -84,6 +90,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||||
|
|
||||||
// Shift applies rotary position embeddings to the key tensor for causal attention caching
|
// Shift applies rotary position embeddings to the key tensor for causal attention caching
|
||||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
m.positionCache = nil
|
||||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -130,28 +137,3 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
||||||
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
|
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
|
||||||
return hiddenState.Add(ctx, residual)
|
return hiddenState.Add(ctx, residual)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
|
|
||||||
// Initial token embedding
|
|
||||||
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
|
|
||||||
|
|
||||||
for _, mi := range batch.Multimodal {
|
|
||||||
img := mi.Multimodal[0].Tensor
|
|
||||||
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process through transformer layers
|
|
||||||
for i, layer := range m.Layers {
|
|
||||||
cache.SetLayer(i)
|
|
||||||
|
|
||||||
var lastLayerOutputs ml.Tensor
|
|
||||||
if i == len(m.Layers)-1 {
|
|
||||||
lastLayerOutputs = outputs
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, m.TextOptions)
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
|
||||||
return m.Output.Forward(ctx, hiddenStates), nil
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -7,48 +7,28 @@ import (
|
||||||
"github.com/ollama/ollama/fs"
|
"github.com/ollama/ollama/fs"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/ml/nn"
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
"github.com/ollama/ollama/ml/nn/rope"
|
||||||
)
|
)
|
||||||
|
|
||||||
// We only support batch size of 1
|
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int) ml.Tensor {
|
||||||
var batchSize int = 1
|
// Initialize a 2D mask with -Inf
|
||||||
|
s := make([][]float32, seqLength)
|
||||||
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
for i := range s {
|
||||||
x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
|
s[i] = slices.Repeat([]float32{float32(math.Inf(-1))}, seqLength)
|
||||||
x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
|
|
||||||
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
|
|
||||||
return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
|
|
||||||
}
|
|
||||||
|
|
||||||
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
|
|
||||||
// Create a flat slice for the mask (all -inf initially to block all attention)
|
|
||||||
flat := make([]float32, seqLength*seqLength)
|
|
||||||
for i := range flat {
|
|
||||||
flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fill in the mask with zeros for tokens that CAN attend to each other
|
// Fill in the mask with zeros for tokens that CAN attend to each other
|
||||||
for i := 1; i < len(bounds); i++ {
|
for i := 1; i < len(bounds); i++ {
|
||||||
start := bounds[i-1]
|
start, end := bounds[i-1], bounds[i]
|
||||||
end := bounds[i]
|
// Enable attention within this sequence block
|
||||||
|
|
||||||
// Enable attention within this sequence block by setting values to 0
|
|
||||||
for row := start; row < end; row++ {
|
for row := start; row < end; row++ {
|
||||||
for col := start; col < end; col++ {
|
for col := start; col < end; col++ {
|
||||||
idx := row*seqLength + col
|
s[row][col] = 0.0
|
||||||
flat[idx] = 0.0 // 0 allows attention, -inf blocks it
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mask := ctx.Input().FromFloats(flat, seqLength, seqLength)
|
return ctx.Input().FromFloats(slices.Concat(s...), seqLength, seqLength)
|
||||||
|
|
||||||
// Reshape to match [seqLength, seqLength, 1] for broadcasting
|
|
||||||
mask = mask.Reshape(ctx, seqLength, seqLength, 1)
|
|
||||||
|
|
||||||
return mask
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type VisionSelfAttention struct {
|
type VisionSelfAttention struct {
|
||||||
|
|
@ -58,17 +38,17 @@ type VisionSelfAttention struct {
|
||||||
Output *nn.Linear `gguf:"attn_out"`
|
Output *nn.Linear `gguf:"attn_out"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||||
query := sa.Query.Forward(ctx, hiddenStates)
|
query := sa.Query.Forward(ctx, hiddenStates)
|
||||||
key := sa.Key.Forward(ctx, hiddenStates)
|
key := sa.Key.Forward(ctx, hiddenStates)
|
||||||
value := sa.Value.Forward(ctx, hiddenStates)
|
value := sa.Value.Forward(ctx, hiddenStates)
|
||||||
|
|
||||||
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
|
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1))
|
||||||
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
|
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1))
|
||||||
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
|
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1))
|
||||||
|
|
||||||
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
|
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||||
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
|
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||||
|
|
||||||
// Scale factor for scaled dot-product attention
|
// Scale factor for scaled dot-product attention
|
||||||
scale := 1.0 / math.Sqrt(float64(opts.headDim))
|
scale := 1.0 / math.Sqrt(float64(opts.headDim))
|
||||||
|
|
@ -77,6 +57,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m
|
||||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||||
|
|
||||||
kq := key.MulmatFullPrec(ctx, query)
|
kq := key.MulmatFullPrec(ctx, query)
|
||||||
kq = kq.Scale(ctx, scale)
|
kq = kq.Scale(ctx, scale)
|
||||||
if mask != nil {
|
if mask != nil {
|
||||||
|
|
@ -85,7 +66,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m
|
||||||
kq = kq.Softmax(ctx)
|
kq = kq.Softmax(ctx)
|
||||||
kqv := value.Mulmat(ctx, kq)
|
kqv := value.Mulmat(ctx, kq)
|
||||||
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
|
||||||
|
|
||||||
return sa.Output.Forward(ctx, attention)
|
return sa.Output.Forward(ctx, attention)
|
||||||
}
|
}
|
||||||
|
|
@ -98,10 +79,7 @@ type VisionMLP struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||||
// Using activation as specified in config (likely GELU or SiLU/Swish)
|
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||||
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
|
|
||||||
hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
|
||||||
|
|
||||||
return mlp.Down.Forward(ctx, hiddenStates)
|
return mlp.Down.Forward(ctx, hiddenStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -112,10 +90,10 @@ type VisionEncoderLayer struct {
|
||||||
MLP *VisionMLP
|
MLP *VisionMLP
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, positions, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||||
residual := hiddenStates
|
residual := hiddenStates
|
||||||
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
|
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
|
||||||
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, mask, opts)
|
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, positions, mask, opts)
|
||||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||||
|
|
||||||
residual = hiddenStates
|
residual = hiddenStates
|
||||||
|
|
@ -139,6 +117,17 @@ type VisionModelOptions struct {
|
||||||
temporalPatchSize int
|
temporalPatchSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o VisionModelOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||||
|
return nn.RoPE(ctx, states, positions, o.headDim/2, o.ropeTheta, 1,
|
||||||
|
rope.WithVision([]int{
|
||||||
|
o.headDim / 4,
|
||||||
|
o.headDim / 4,
|
||||||
|
o.headDim / 4,
|
||||||
|
o.headDim / 4,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
type PatchEmbedding struct {
|
type PatchEmbedding struct {
|
||||||
PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"`
|
PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"`
|
||||||
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
|
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
|
||||||
|
|
@ -186,7 +175,7 @@ func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, op
|
||||||
hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)
|
hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)
|
||||||
|
|
||||||
// Reshape the normalized output to view the hidden size dimension
|
// Reshape the normalized output to view the hidden size dimension
|
||||||
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize)
|
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize))
|
||||||
hidden := pm.MLP0.Forward(ctx, reshaped)
|
hidden := pm.MLP0.Forward(ctx, reshaped)
|
||||||
activated := hidden.GELU(ctx)
|
activated := hidden.GELU(ctx)
|
||||||
|
|
||||||
|
|
@ -209,36 +198,53 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||||
// Extract patch embeddings
|
// Extract patch embeddings
|
||||||
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
|
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
|
||||||
|
|
||||||
positionEmbedding := m.PositionalEmbedding(ctx, grid)
|
index, bounds := m.windowIndex(grid)
|
||||||
|
|
||||||
windowIndex, bounds := m.WindowIndex(ctx, grid)
|
|
||||||
|
|
||||||
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
|
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
|
||||||
|
|
||||||
|
windowIndex := ctx.Input().FromInts(index, len(index))
|
||||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
|
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
|
||||||
hiddenStates = hiddenStates.Rows(ctx, windowIndex)
|
hiddenStates = hiddenStates.Rows(ctx, windowIndex.Argsort(ctx))
|
||||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
|
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
|
||||||
|
|
||||||
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit)
|
positions := ctx.Input().FromInts(func() []int32 {
|
||||||
positionEmbedding = positionEmbedding.Rows(ctx, windowIndex)
|
s := [][]int32{
|
||||||
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit)
|
make([]int32, grid.Height*grid.Width),
|
||||||
positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0)
|
make([]int32, grid.Height*grid.Width),
|
||||||
|
make([]int32, grid.Height*grid.Width),
|
||||||
|
make([]int32, grid.Height*grid.Width),
|
||||||
|
}
|
||||||
|
|
||||||
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
|
var cur int
|
||||||
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
for y := 0; y < grid.Height; y += m.spatialMergeSize {
|
||||||
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
for x := 0; x < grid.Width; x += m.spatialMergeSize {
|
||||||
|
for dy := range 2 {
|
||||||
|
for dx := range 2 {
|
||||||
|
i := int(index[cur/spatialMergeUnit]) * spatialMergeUnit
|
||||||
|
i += cur % spatialMergeUnit
|
||||||
|
s[0][i] = int32(y + dy)
|
||||||
|
s[1][i] = int32(x + dx)
|
||||||
|
s[2][i] = int32(y + dy)
|
||||||
|
s[3][i] = int32(x + dx)
|
||||||
|
cur++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return slices.Concat(s...)
|
||||||
|
}(), grid.Height*grid.Width*4)
|
||||||
|
|
||||||
|
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds)
|
||||||
|
|
||||||
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
|
|
||||||
// Apply encoder layers
|
// Apply encoder layers
|
||||||
for i, layer := range m.Layers {
|
for i, layer := range m.Layers {
|
||||||
if slices.Contains(m.fullAttnBlocks, int32(i)) {
|
if slices.Contains(m.fullAttnBlocks, int32(i)) {
|
||||||
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions)
|
hiddenStates = layer.Forward(ctx, hiddenStates, positions, nil, m.VisionModelOptions)
|
||||||
} else {
|
} else {
|
||||||
hiddenStates = layer.Forward(
|
hiddenStates = layer.Forward(
|
||||||
ctx,
|
ctx,
|
||||||
hiddenStates,
|
hiddenStates,
|
||||||
cos,
|
positions,
|
||||||
sin,
|
|
||||||
mask,
|
mask,
|
||||||
m.VisionModelOptions,
|
m.VisionModelOptions,
|
||||||
)
|
)
|
||||||
|
|
@ -246,102 +252,43 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
|
hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
|
||||||
reverseWindowIndex := windowIndex.Argsort(ctx)
|
return hiddenStates.Rows(ctx, windowIndex)
|
||||||
return hiddenStates.Rows(ctx, reverseWindowIndex)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WindowIndex divides the grid into windows and returns:
|
// windowIndex divides the grid into windows and returns:
|
||||||
// 1. A tensor containing flattened indices of all grid points organized by windows
|
// 1. A slice of grid point indices organized by windows
|
||||||
// 2. A slice of boundaries that mark where each window's data begins and ends
|
// 2. A slice of boundaries that mark where each window's data begins and ends
|
||||||
// in the flattened representation, scaled by spatialMergeSize squared
|
// in the flattened representation, scaled by spatialMergeSize squared
|
||||||
//
|
//
|
||||||
// The boundaries slice always starts with 0 and contains cumulative ending
|
// The boundaries slice always starts with 0 and contains cumulative ending
|
||||||
// positions for each window, allowing downstream processing to identify
|
// positions for each window, allowing downstream processing to identify
|
||||||
// window boundaries in the tensor data.
|
// window boundaries in the tensor data.
|
||||||
func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
|
func (m *VisionModel) windowIndex(grid *Grid) (index []int32, bounds []int) {
|
||||||
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
|
height := grid.Height / m.spatialMergeSize
|
||||||
|
width := grid.Width / m.spatialMergeSize
|
||||||
|
window := m.windowSize / m.patchSize / m.spatialMergeSize
|
||||||
|
|
||||||
llmGridH := grid.Height / m.spatialMergeSize
|
index = make([]int32, height*width)
|
||||||
llmGridW := grid.Width / m.spatialMergeSize
|
|
||||||
|
|
||||||
// Calculate window parameters
|
bounds = make([]int, 0, ((height+window-1)/window)*((width+window-1)/window)+1)
|
||||||
numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize)))
|
bounds = append(bounds, 0)
|
||||||
numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize)))
|
|
||||||
|
|
||||||
// Initialize index_new slice
|
var cur int32
|
||||||
var index []int32
|
for y := 0; y < height; y += window {
|
||||||
|
for x := 0; x < width; x += window {
|
||||||
// Initialize bounds with the first element as 0
|
h1 := min(window, height-y)
|
||||||
bounds := []int{0}
|
w1 := min(window, width-x)
|
||||||
totalSeqLen := 0
|
for dy := range h1 {
|
||||||
|
for dx := range w1 {
|
||||||
// Process each window without padding
|
win := (y+dy)*width + (x + dx)
|
||||||
for wh := range numWindowsH {
|
index[win] = cur
|
||||||
for ww := range numWindowsW {
|
cur++
|
||||||
// Calculate window boundaries
|
|
||||||
hStart := wh * vitMergerWindowSize
|
|
||||||
wStart := ww * vitMergerWindowSize
|
|
||||||
hEnd := min(hStart+vitMergerWindowSize, llmGridH)
|
|
||||||
wEnd := min(wStart+vitMergerWindowSize, llmGridW)
|
|
||||||
|
|
||||||
// Calculate sequence length for this window
|
|
||||||
seqLen := (hEnd - hStart) * (wEnd - wStart)
|
|
||||||
|
|
||||||
// Collect indices for this window
|
|
||||||
for h := hStart; h < hEnd; h++ {
|
|
||||||
for w := wStart; w < wEnd; w++ {
|
|
||||||
index = append(index, int32(h*llmGridW+w))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
bounds = append(bounds, int(cur)*window)
|
||||||
totalSeqLen += seqLen
|
|
||||||
bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return index, bounds
|
||||||
t := ctx.Input().FromInts(index, len(index))
|
|
||||||
|
|
||||||
return t, bounds
|
|
||||||
}
|
|
||||||
|
|
||||||
// PositionalEmbedding generates rotary position embeddings for attention mechanisms
|
|
||||||
func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
|
|
||||||
dim := m.headDim / 2
|
|
||||||
freq := dim / 2
|
|
||||||
theta := float64(m.ropeTheta)
|
|
||||||
merge := m.spatialMergeSize
|
|
||||||
|
|
||||||
// Create frequency patterns for position encoding
|
|
||||||
maxGridSize := max(grid.Height, grid.Width)
|
|
||||||
freqVals := make([]float32, freq*maxGridSize)
|
|
||||||
for i := range maxGridSize {
|
|
||||||
for j := range freq {
|
|
||||||
freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
freqs := ctx.Input().FromFloats(freqVals, freq, maxGridSize)
|
|
||||||
|
|
||||||
// Create position coordinates (y,x pairs) for the grid
|
|
||||||
// In PyTorch: Equivalent to generating position ids with torch.arange()
|
|
||||||
coords := make([]int32, 0, grid.Height*grid.Width*2)
|
|
||||||
for y := range grid.Height {
|
|
||||||
for x := range grid.Width {
|
|
||||||
coords = append(coords, int32(y), int32(x))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pos := ctx.Input().FromInts(coords, 2, grid.Width, grid.Height)
|
|
||||||
|
|
||||||
// Reshape and permute positions to match spatial merging pattern
|
|
||||||
pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
|
|
||||||
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge)
|
|
||||||
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
pos = pos.Reshape(ctx, 2*merge*merge*grid.Width/merge*grid.Height/merge)
|
|
||||||
|
|
||||||
// Use position indices to look up corresponding frequency values
|
|
||||||
positionalEmbedding := freqs.Rows(ctx, pos)
|
|
||||||
positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
|
|
||||||
return positionalEmbedding
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newVisionModel creates a new instance of the Qwen vision model
|
// newVisionModel creates a new instance of the Qwen vision model
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,8 @@ type ImageProcessor struct {
|
||||||
maxPixels int
|
maxPixels int
|
||||||
factor int
|
factor int
|
||||||
rescaleFactor float32
|
rescaleFactor float32
|
||||||
imageMean []float32
|
imageMean [3]float32
|
||||||
imageStd []float32
|
imageStd [3]float32
|
||||||
}
|
}
|
||||||
|
|
||||||
// newImageProcessor creates a new image processor with default values
|
// newImageProcessor creates a new image processor with default values
|
||||||
|
|
@ -34,11 +34,11 @@ func newImageProcessor(c fs.Config) ImageProcessor {
|
||||||
temporalPatchSize: 2,
|
temporalPatchSize: 2,
|
||||||
mergeSize: mergeSize,
|
mergeSize: mergeSize,
|
||||||
minPixels: 56 * 56,
|
minPixels: 56 * 56,
|
||||||
maxPixels: int(c.Uint("vision.max_pixels", 28*28*1280)), // 1MP limit
|
maxPixels: int(c.Uint("vision.max_pixels", 2<<20)), // 2M limit
|
||||||
factor: patchSize * mergeSize,
|
factor: patchSize * mergeSize,
|
||||||
rescaleFactor: 1.0 / 255.0,
|
rescaleFactor: 1.0 / 255.0,
|
||||||
imageMean: imageproc.ClipDefaultMean[:],
|
imageMean: imageproc.ClipDefaultMean,
|
||||||
imageStd: imageproc.ClipDefaultSTD[:],
|
imageStd: imageproc.ClipDefaultSTD,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -90,13 +90,7 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error)
|
||||||
// Resize image using existing functions
|
// Resize image using existing functions
|
||||||
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
|
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
|
||||||
|
|
||||||
normalizedPixels := imageproc.Normalize(
|
normalizedPixels := imageproc.Normalize(resizedImg, p.imageMean, p.imageStd, true, true)
|
||||||
resizedImg,
|
|
||||||
[3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]},
|
|
||||||
[3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]},
|
|
||||||
true, // rescale
|
|
||||||
true, // channelFirst
|
|
||||||
)
|
|
||||||
|
|
||||||
// Calculate grid dimensions
|
// Calculate grid dimensions
|
||||||
grid := &Grid{
|
grid := &Grid{
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue