diff --git a/convert/convert_gptoss.go b/convert/convert_gptoss.go
index 5338df21a..d7bfb361d 100644
--- a/convert/convert_gptoss.go
+++ b/convert/convert_gptoss.go
@@ -110,9 +110,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
 	for name, mxfp4 := range mxfp4s {
 		dims := mxfp4.blocks.Shape()
 
+		if !strings.HasSuffix(name, ".weight") {
+			name = name + ".weight"
+		}
 		if strings.Contains(name, "ffn_down_exps") {
 			out = append(out, &ggml.Tensor{
-				Name:     name + ".weight",
+				Name:     name,
 				Kind:     uint32(ggml.TensorTypeMXFP4),
 				Shape:    []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
 				WriterTo: mxfp4,
@@ -121,12 +124,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
 			// gate_up_exps is interleaved, need to split into gate_exps and up_exps
 			// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
 			out = append(out, &ggml.Tensor{
-				Name:     strings.Replace(name, "gate_up", "gate", 1) + ".weight",
+				Name:     strings.Replace(name, "gate_up", "gate", 1),
 				Kind:     uint32(ggml.TensorTypeMXFP4),
 				Shape:    []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
 				WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
 			}, &ggml.Tensor{
-				Name:     strings.Replace(name, "gate_up", "up", 1) + ".weight",
+				Name:     strings.Replace(name, "gate_up", "up", 1),
 				Kind:     uint32(ggml.TensorTypeMXFP4),
 				Shape:    []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
 				WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),
diff --git a/ml/backend.go b/ml/backend.go
index c6fadb7f9..36557e62e 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -146,7 +146,6 @@ type Tensor interface {
 	FromFloats([]float32)
 	FromInts([]int32)
 
-	Neg(ctx Context) Tensor
 	Add(ctx Context, t2 Tensor) Tensor
 	Sub(ctx Context, t2 Tensor) Tensor
 	Mul(ctx Context, t2 Tensor) Tensor
@@ -185,7 +184,6 @@ type Tensor interface {
 	View(ctx Context, offset int, shape ...int) Tensor
 	Permute(ctx Context, shape ...int) Tensor
 	Contiguous(ctx Context, shape ...int) Tensor
-	Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
 
 	Pad(ctx Context, shape ...int) Tensor
 
@@ -209,7 +207,6 @@ type Tensor interface {
 	Stddev(ctx Context) Tensor
 	Sqr(ctx Context) Tensor
 	Sqrt(ctx Context) Tensor
-	Clamp(ctx Context, min, max float32) Tensor
 }
 
 // ScaledDotProductAttention implements a fused attention
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index e18d2f387..2aa721902 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1137,13 +1137,6 @@ func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_neg(ctx.(*Context).ctx, t.t),
-	}
-}
-
 func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
 		b: t.b,
@@ -1632,20 +1625,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
-	var tt *C.struct_ggml_tensor
-	switch len(strides) {
-	case 0:
-		tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
-	case 1:
-		tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
-	default:
-		panic("unsupported number of dimensions")
-	}
-
-	return &Tensor{b: t.b, t: tt}
-}
-
 func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {
@@ -1732,13 +1711,6 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
-	}
-}
-
 // Slice returns a view of the tensor sliced along dim from low to high in step steps.
 // Slice panics if the dimension is invalid or the slice parameters are out of range.
 // If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.
diff --git a/ml/nn/pooling/pooling.go b/ml/nn/pooling/pooling.go
index 63b63b3af..47af87463 100644
--- a/ml/nn/pooling/pooling.go
+++ b/ml/nn/pooling/pooling.go
@@ -32,10 +32,9 @@ func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
 		hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
 		return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 	case TypeCLS:
-		return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
+		return hiddenStates.Slice(ctx, 1, 0, 1, 1)
 	case TypeLast:
-		hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
-		return hiddenStates
+		return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
 	default:
 		panic("unknown pooling type")
 	}
diff --git a/model/models/bert/embed.go b/model/models/bert/embed.go
index 2d78710f7..f2dd1deb4 100644
--- a/model/models/bert/embed.go
+++ b/model/models/bert/embed.go
@@ -29,7 +29,7 @@ type Model struct {
 // Forward implements model.Model.
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
-	hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
+	hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.Slice(ctx, 1, 0, 1, 1))
 	hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromInts(batch.Positions, len(batch.Positions))))
 	hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
diff --git a/model/models/deepseek2/model.go b/model/models/deepseek2/model.go
index cfd579ca5..68b12cd99 100644
--- a/model/models/deepseek2/model.go
+++ b/model/models/deepseek2/model.go
@@ -78,44 +78,31 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 	}
 
 	query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
-
-	qPass := query.View(ctx, 0,
-		opts.qkNopeHeadDim, query.Stride(1),
-		query.Dim(1), query.Stride(2),
-		query.Dim(2))
-
-	qRot := query.View(ctx, opts.qkNopeHeadDim*query.Stride(0),
-		opts.qkRopeHeadDim, query.Stride(1),
-		query.Dim(1), query.Stride(2),
-		query.Dim(2))
+	queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)
 
 	compressedKV := attn.KVA.Forward(ctx, hiddenStates)
-
-	kPass := compressedKV.View(ctx, 0, opts.kvLoraRank, compressedKV.Stride(1), compressedKV.Dim(1))
-	kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0),
-		opts.qkRopeHeadDim, compressedKV.Stride(1),
-		1, compressedKV.Stride(1),
-		compressedKV.Dim(1))
+	kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
+	kRot := compressedKV.View(ctx,
+		opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
+		compressedKV.Stride(1), 1,
+		compressedKV.Stride(1), compressedKV.Dim(1),
+	)
 
 	kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
 	kPass = attn.KVB.Forward(ctx, kPass)
 
 	kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
-	kPass = kv.View(ctx, 0, opts.kqNopeHeadDim, kv.Stride(1), kv.Dim(1), kv.Stride(2), kv.Dim(2))
-	value := kv.View(ctx, opts.kqNopeHeadDim*kv.Stride(0),
-		opts.vHeadDim, kv.Stride(1),
-		kv.Dim(1), kv.Stride(2),
-		kv.Dim(2)).Contiguous(ctx)
+	kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
 
-	qRot = fast.RoPE(ctx, qRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
+	qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
 	kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
-	kRot = kRot.Repeat(ctx, 1, qPass.Dim(1))
+	kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
 
-	query = qRot.Concat(ctx, qPass, 0)
-	key := kRot.Concat(ctx, kPass, 0)
+	query = qRot.Concat(ctx, queryChunks[0], 0)
+	key := kRot.Concat(ctx, kvChunks[0], 0)
 
-	attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache)
+	attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
 	return attn.Output.Forward(ctx, attention)
 }
@@ -142,6 +129,7 @@ func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml
 	experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
 	experts = experts.Mul(ctx, topKWeights)
 
+	nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
 	for i := 1; i < opts.numExpertsUsed; i++ {
 		nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go
index ec038a287..3a89afe72 100644
--- a/model/models/gemma3n/model_text.go
+++ b/model/models/gemma3n/model_text.go
@@ -64,18 +64,18 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
 		cache.(*kvcache.WrapperCache).SetLayerType(layerType)
 
-		// inputPerLayer = inputsPerLayer[:, i, :]
-		inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2)).Contiguous(ctx)
+		// inputPerLayer = inputsPerLayer[:, i, :].squeeze(1)
+		inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
 		hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
 	}
 
 	// hiddenStates = hiddenStates[:, :, 0]
-	hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1))
+	hiddenStates0 := hiddenStates.Slice(ctx, 2, 0, 1, 1)
 	targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx)
 	targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
 
 	// hiddenState = hiddenStates[:, :, 1:]
-	hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1)
+	hiddenState = hiddenStates.Slice(ctx, 2, 1, hiddenStates.Dim(2), 1)
 	altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState)
 	altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
@@ -176,10 +176,10 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
 	active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
 
 	// inactive := predictions[:, :, 1:]
-	inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1)
+	inactive := predictions.Slice(ctx, 2, 1, predictions.Dim(2), 1)
 	active = inactive.Add(ctx, active)
 
-	predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1))
+	predictions0 := predictions.Slice(ctx, 2, 0, 1, 1)
 	return predictions0.Concat(ctx, active, 2)
 }
@@ -319,7 +319,7 @@ type TextOptions struct {
 
 func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor {
 	// t[:, :, o.altupActiveIndex]
-	return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1))
+	return t.Slice(ctx, 2, o.altupActiveIndex, o.altupActiveIndex+1, 1)
 }
 
 func (o *TextOptions) headDim() int {
diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go
index c10920f13..da08ed96f 100644
--- a/model/models/gptoss/model.go
+++ b/model/models/gptoss/model.go
@@ -121,30 +121,9 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T
 	var query, key, value ml.Tensor
 	if attn.QKV != nil {
 		qkv := attn.QKV.Forward(ctx, hiddenStates)
-
-		// query = qkv[..., : num_attention_heads * head_dim].reshape(batch_size, num_attention_heads, head_dim)
-		query = qkv.View(ctx,
-			0,
-			opts.headDim(), qkv.Stride(0)*opts.headDim(),
-			opts.numHeads, qkv.Stride(1),
-			batchSize,
-		)
-
-		// key = qkv[..., num_attention_heads * head_dim:(num_attention_heads + num_key_value_heads) * head_dim].reshape(batch_size, num_key_value_heads, head_dim)
-		key = qkv.View(ctx,
-			qkv.Stride(0)*opts.headDim()*opts.numHeads,
-			opts.headDim(), qkv.Stride(0)*opts.headDim(),
-			opts.numKVHeads, qkv.Stride(1),
-			batchSize,
-		)
-
-		// value = qkv[..., (num_attention_heads + num_key_value_heads) * head_dim:].reshape(batch_size, num_key_value_heads, head_dim)
-		value = qkv.View(ctx,
-			qkv.Stride(0)*opts.headDim()*(opts.numHeads+opts.numKVHeads),
-			opts.headDim(), qkv.Stride(0)*opts.headDim(),
-			opts.numKVHeads, qkv.Stride(1),
-			batchSize,
-		)
+		qkv = qkv.Reshape(ctx, opts.headDim(), -1, batchSize)
+		chunks := qkv.ChunkSections(ctx, 1, opts.numHeads, opts.numKVHeads, opts.numKVHeads)
+		query, key, value = chunks[0], chunks[1], chunks[2]
 	} else {
 		query = attn.Query.Forward(ctx, hiddenStates)
 		query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
@@ -195,15 +174,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
 	var gate, up ml.Tensor
 	if mlp.GateUp != nil {
 		hiddenStates = mlp.GateUp.Forward(ctx, hiddenStates, selectedExperts)
-		hiddenStates = hiddenStates.Reshape(ctx, 2, hiddenStates.Dim(0)/2, hiddenStates.Dim(1), hiddenStates.Dim(2))
-
-		dimStride := []int{hiddenStates.Dim(0) / 2, hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), hiddenStates.Dim(2), hiddenStates.Stride(3), hiddenStates.Dim(3)}
-
-		gate = hiddenStates.View(ctx, 0, dimStride...)
-		gate = gate.Contiguous(ctx, gate.Dim(0)*gate.Dim(1), gate.Dim(2), gate.Dim(3))
-
-		up = hiddenStates.View(ctx, hiddenStates.Stride(0), dimStride...)
-		up = up.Contiguous(ctx, up.Dim(0)*up.Dim(1), up.Dim(2), up.Dim(3))
+		gate = hiddenStates.Slice(ctx, 0, 0, hiddenStates.Dim(0), 2)
+		up = hiddenStates.Slice(ctx, 0, 1, hiddenStates.Dim(0), 2)
 	} else {
 		gate = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts)
 		up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go
index 5eeac07c2..4a22bc4bb 100644
--- a/model/models/llama4/model.go
+++ b/model/models/llama4/model.go
@@ -105,9 +105,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 	for range aspectRatio.Y {
 		for x := range aspectRatio.X {
-			view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
-				projectedOutputs.Dim(0), projectedOutputs.Stride(1),
-				patchesPerChunk)
+			view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
 			var separator separator
 			if x < aspectRatio.X-1 {
 				separator.x = true // <|tile_x_separator|>
@@ -120,9 +118,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 		}
 	}
 
-	view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
-		projectedOutputs.Dim(0), projectedOutputs.Stride(1),
-		patchesPerChunk)
+	view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
 	multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}})
 
 	return multimodal, nil
diff --git a/model/models/llama4/model_vision.go b/model/models/llama4/model_vision.go
index 1aa50aec4..ff6b7fcf2 100644
--- a/model/models/llama4/model_vision.go
+++ b/model/models/llama4/model_vision.go
@@ -37,27 +37,23 @@ type VisionAttention struct {
 
 func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
 	width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)
-	t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))
-
 	// t1 = t[..., 0::2]
-	t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
-	t1 = t1.Reshape(ctx, width/2, height, channels, tiles)
+	t1 := t.Slice(ctx, 0, 0, t.Dim(0), 2)
 
 	// t2 = t[..., 1::2]
-	t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
-	t2 = t2.Reshape(ctx, width/2, height, channels, tiles)
+	t2 := t.Slice(ctx, 0, 1, t.Dim(0), 2)
 
 	// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
 	cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
-	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
-	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)
+	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, -1)
+	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3)
+	cosOut = cosOut.Contiguous(ctx, width, height, channels, tiles)
 
 	// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
-	sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
-	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
-	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)
+	sinOut := t2.Scale(ctx, -1).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
+	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, -1)
+	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3)
+	sinOut = sinOut.Contiguous(ctx, width, height, channels, tiles)
 
 	return cosOut.Add(ctx, sinOut)
 }
diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go
index ce3110c7c..d763df7a0 100644
--- a/model/models/mistral3/model_vision.go
+++ b/model/models/mistral3/model_vision.go
@@ -11,9 +11,9 @@ import (
 var batchSize int = 1
 
 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-	x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
-	x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
-	return x2.Neg(ctx).Concat(ctx, x1, 0)
+	x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
+	x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
+	return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }
 
 func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go
index 88b2c005c..5cbb01f7e 100644
--- a/model/models/qwen25vl/model_vision.go
+++ b/model/models/qwen25vl/model_vision.go
@@ -13,9 +13,9 @@ import (
 var batchSize int = 1
 
 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-	x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
-	x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
-	return x2.Neg(ctx).Concat(ctx, x1, 0)
+	x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
+	x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
+	return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }
 
 func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
diff --git a/model/models/qwen3vl/model_vision.go b/model/models/qwen3vl/model_vision.go
index 69118666b..b22ac305c 100644
--- a/model/models/qwen3vl/model_vision.go
+++ b/model/models/qwen3vl/model_vision.go
@@ -18,8 +18,8 @@ type VisionAttention struct {
 }
 
 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-	x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
-	x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
+	x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
+	x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
 	return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }
@@ -160,10 +160,11 @@ func (m *VisionPositionEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor
 	positionEmbeds = positionEmbeds.Mul(ctx, weights)
 	positionEmbeds = positionEmbeds.Reshape(ctx, n, -1, 4)
 
-	positionEmbeds = positionEmbeds.View(ctx, 0, n, positionEmbeds.Stride(1), grid.Height*grid.Width).
-		Add(ctx, positionEmbeds.View(ctx, 1*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
-		Add(ctx, positionEmbeds.View(ctx, 2*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
-		Add(ctx, positionEmbeds.View(ctx, 3*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width))
+	positionEmbedsChunks := positionEmbeds.Chunk(ctx, 2, 1)
+	positionEmbeds = positionEmbedsChunks[0].
+		Add(ctx, positionEmbedsChunks[1]).
+		Add(ctx, positionEmbedsChunks[2]).
+		Add(ctx, positionEmbedsChunks[3])
 
 	positionEmbeds = positionEmbeds.Reshape(ctx, -1, grid.Width/opts.spatialMergeSize, opts.spatialMergeSize, grid.Height/opts.spatialMergeSize)
 	positionEmbeds = positionEmbeds.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, n, -1)