mirror of https://github.com/ollama/ollama
model: add rnj-1 inference support (#13354)
parent 603ceefaa6
commit d2f334c1f7
@@ -2,6 +2,7 @@ package convert

 import (
 	"cmp"
+	"slices"

 	"github.com/ollama/ollama/fs/ggml"
 )
@@ -26,16 +27,26 @@ type gemma3Model struct {
 		NumChannels uint32 `json:"num_channels"` // num_channels 3
 		PatchSize   uint32 `json:"patch_size"`   // patch_size 14
 	} `json:"vision_config"`
 	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
 	NumAttentionHeads     uint32  `json:"num_attention_heads"`
 	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
 	RMSNormEPS            float32 `json:"rms_norm_eps"`
 	HeadDim               uint32  `json:"head_dim"`
 	FinalLogitSoftcap     float32 `json:"final_logit_softcapping"`
 	RopeLocalTheta        float32 `json:"rope_local_base_freq"`
-	RopeGlobalTheta       float32 `json:"rope_global_base_freq"`
+	RopeTheta             float32 `json:"rope_theta"`
 	SlidingWindow         uint32  `json:"sliding_window"`
-	MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
+	SlidingWindowPattern  *uint32  `json:"sliding_window_pattern"`
+	LayerTypes            []string `json:"layer_types"`
+	MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
+	RopeScaling *struct {
+		Type                          string  `json:"rope_type"`
+		Factor                        float32 `json:"factor"`
+		OriginalMaxPositionEmbeddings uint32  `json:"original_max_position_embeddings"`
+		ExtrapolationFactor           float32 `json:"extrapolation_factor"`
+		BetaFast                      float32 `json:"beta_fast"`
+		BetaSlow                      float32 `json:"beta_slow"`
+	} `json:"rope_scaling"`
 }

 const (
@@ -81,9 +92,38 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
 		kv["gemma3.attention.key_length"] = p.HeadDim
 		kv["gemma3.attention.value_length"] = p.HeadDim
 		kv["gemma3.attention.sliding_window"] = p.SlidingWindow
-		kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
+		// The sliding window pattern is either provided as the sliding_window_pattern
+		// key (an int) or as the layer_types key (a list of strings).
+		if p.SlidingWindowPattern != nil || len(p.LayerTypes) > 0 {
+			kv["gemma3.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
+				for i := range numBlocks {
+					var isLocal bool
+					if len(p.LayerTypes) > 0 && int(i) < len(p.LayerTypes) {
+						isLocal = p.LayerTypes[i] == "sliding_attention"
+					} else if p.SlidingWindowPattern != nil && *p.SlidingWindowPattern > 0 {
+						isLocal = (i+1)%*p.SlidingWindowPattern != 0
+					}
+					if !yield(isLocal) {
+						break
+					}
+				}
+			})
+		}
+		if p.FinalLogitSoftcap > 0 {
+			kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap
+		}
 		kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
-		kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
+		kv["gemma3.rope.freq_base"] = cmp.Or(p.RopeTheta, 1000000.0)
+		if p.RopeScaling != nil && p.RopeScaling.Type == "yarn" && p.RopeScaling.Factor > 0 {
+			kv["gemma3.rope.scaling.type"] = "yarn"
+			kv["gemma3.rope.scaling.factor"] = p.RopeScaling.Factor
+			kv["gemma3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
+			kv["gemma3.rope.scaling.extrapolation_factor"] = cmp.Or(p.RopeScaling.ExtrapolationFactor, float32(1.0))
+			kv["gemma3.rope.scaling.beta_fast"] = cmp.Or(p.RopeScaling.BetaFast, float32(64.0))
+			kv["gemma3.rope.scaling.beta_slow"] = cmp.Or(p.RopeScaling.BetaSlow, float32(1.0))
+		}
+
 		kv["gemma3.embedding_length"] = p.HiddenSize
 		kv["gemma3.feed_forward_length"] = p.IntermediateSize
 	default:
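A minimal standalone sketch of the per-layer flag derivation in the hunk above: the converter turns either the layer_types list or the integer sliding_window_pattern key into a boolean-per-layer array. slidingWindowFlags is a hypothetical helper written for illustration, not the converter's own code.

package main

import "fmt"

// slidingWindowFlags mirrors how the hunk above decides, per layer, whether
// sliding (local) attention is used.
func slidingWindowFlags(numBlocks int, layerTypes []string, pattern uint32) []bool {
	flags := make([]bool, 0, numBlocks)
	for i := 0; i < numBlocks; i++ {
		var isLocal bool
		if len(layerTypes) > 0 && i < len(layerTypes) {
			// layer_types spells it out explicitly per layer
			isLocal = layerTypes[i] == "sliding_attention"
		} else if pattern > 0 {
			// a pattern of n means every n-th layer is global
			isLocal = (i+1)%int(pattern) != 0
		}
		flags = append(flags, isLocal)
	}
	return flags
}

func main() {
	// with a pattern of 6, layers 6 and 12 (1-indexed) come out global
	fmt.Println(slidingWindowFlags(12, nil, 6))
}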
@@ -58,6 +58,18 @@ func WithAttentionFactor(attentionFactor float32) func(*Options) {
 	}
 }

+func WithBetaFast(betaFast float32) func(*Options) {
+	return func(opts *Options) {
+		opts.YaRN.BetaFast = betaFast
+	}
+}
+
+func WithBetaSlow(betaSlow float32) func(*Options) {
+	return func(opts *Options) {
+		opts.YaRN.BetaSlow = betaSlow
+	}
+}
+
 func WithMRoPE(sections []int) func(*Options) {
 	return func(opts *Options) {
 		opts.Type |= 1 << 3
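The two new helpers follow the rope package's functional-options pattern. A self-contained sketch of that pattern with simplified stand-in types, not the real rope.Options:

package main

import "fmt"

type yarnOptions struct {
	BetaFast, BetaSlow float32
}

// Options is a stand-in for the rope package's options struct.
type Options struct {
	YaRN yarnOptions
}

func WithBetaFast(betaFast float32) func(*Options) {
	return func(opts *Options) { opts.YaRN.BetaFast = betaFast }
}

func WithBetaSlow(betaSlow float32) func(*Options) {
	return func(opts *Options) { opts.YaRN.BetaSlow = betaSlow }
}

func main() {
	// callers accumulate option functions and apply them to one Options value
	opts := Options{}
	for _, apply := range []func(*Options){WithBetaFast(64), WithBetaSlow(1)} {
		apply(&opts)
	}
	fmt.Printf("%+v\n", opts)
}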
@@ -16,7 +16,7 @@ import (

 type Model struct {
 	model.Base
-	model.SentencePiece
+	model.TextProcessor

 	*VisionModel `gguf:"v"`
 	*TextModel
@@ -54,24 +54,35 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
 }

 func New(c fs.Config) (model.Model, error) {
-	m := Model{
-		SentencePiece: model.NewSentencePiece(
-			&model.Vocabulary{
-				Values: c.Strings("tokenizer.ggml.tokens"),
-				Scores: c.Floats("tokenizer.ggml.scores"),
-				Types:  c.Ints("tokenizer.ggml.token_type"),
-				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
-				BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
-				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
-				EOS: append(
-					[]int32{
-						int32(c.Uint("tokenizer.ggml.eos_token_id")),
-						int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
-					},
-					c.Ints("tokenizer.ggml.eos_token_ids")...,
-				),
+	vocabulary := model.Vocabulary{
+		Values: c.Strings("tokenizer.ggml.tokens"),
+		Scores: c.Floats("tokenizer.ggml.scores"),
+		Types:  c.Ints("tokenizer.ggml.token_type"),
+		Merges: c.Strings("tokenizer.ggml.merges"),
+		AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+		BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+		AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+		EOS: append(
+			[]int32{
+				int32(c.Uint("tokenizer.ggml.eos_token_id")),
 			},
-		),
+			c.Ints("tokenizer.ggml.eos_token_ids")...,
+		),
+	}
+
+	var processor model.TextProcessor
+	switch c.String("tokenizer.ggml.model") {
+	case "gpt2":
+		processor = model.NewBytePairEncoding(&vocabulary)
+	default:
+		// Previous uploads of Gemma 3 on Ollama did not have token 106
+		// (i.e. "<end_of_turn>") so we need to add in case it's not already present
+		vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
+		processor = model.NewSentencePiece(&vocabulary)
+	}
+
+	m := Model{
+		TextProcessor:  processor,
 		ImageProcessor: newImageProcessor(c),
 		VisionModel:    newVisionModel(c),
 		TextModel:      newTextModel(c),
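The text processor is now chosen from the tokenizer.ggml.model metadata: a "gpt2" value selects byte-pair encoding (presumably the path rnj-1 checkpoints take), anything else keeps the existing SentencePiece path. A reduced sketch of the dispatch, with strings standing in for the real processor types:

package main

import "fmt"

// pickProcessor is an illustrative stand-in for the switch in New above; the
// real code constructs model.NewBytePairEncoding or model.NewSentencePiece.
func pickProcessor(tokenizerModel string) string {
	switch tokenizerModel {
	case "gpt2":
		return "byte-pair encoding"
	default:
		return "sentencepiece"
	}
}

func main() {
	fmt.Println(pickProcessor("gpt2"))  // byte-pair encoding
	fmt.Println(pickProcessor("llama")) // sentencepiece
}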
@@ -141,8 +152,16 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 }

 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
-	return m.Output.Forward(ctx, hiddenStates), nil
+	hiddenState := m.TextModel.Forward(ctx, batch, m.Cache)
+	hiddenState = m.Output.Forward(ctx, hiddenState)
+
+	if m.TextConfig.finalLogitSoftcap > 0.0 {
+		hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextConfig.finalLogitSoftcap))
+		hiddenState = hiddenState.Tanh(ctx)
+		hiddenState = hiddenState.Scale(ctx, float64(m.TextConfig.finalLogitSoftcap))
+	}
+
+	return hiddenState, nil
 }

 func init() {
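The Scale/Tanh/Scale sequence above is logit softcapping, softcap(x) = c * tanh(x / c), which squashes logits smoothly into (-c, c). A scalar sketch for illustration only:

package main

import (
	"fmt"
	"math"
)

// softcap applies c * tanh(x / c), the same shape as the tensor ops above.
func softcap(x, c float64) float64 {
	return c * math.Tanh(x/c)
}

func main() {
	// with c = 30, large logits saturate just below 30
	for _, x := range []float64{1, 10, 100, 1000} {
		fmt.Printf("softcap(%v, 30) = %.2f\n", x, softcap(x, 30))
	}
}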
@@ -2,6 +2,7 @@ package gemma3

 import (
 	"math"
+	"slices"

 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
@@ -15,12 +16,32 @@ type TextConfig struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
 	eps, ropeScale                   float32
-	ropeLocalBase, ropeGlobalBase    float32
+	ropeLocalBase                    float32
 	largeModelScaling                bool
+	slidingWindowPattern             []bool
+	ropeBase                         float32
+	ropeType                         string
+	ropeOriginalContext              int
+	ropeExtrapolation                float32
+	ropeBetaFast                     float32
+	ropeBetaSlow                     float32
+	finalLogitSoftcap                float32
 }

 func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
-	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, rope.WithTypeNeoX())
+	ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
+	if o.ropeType == "yarn" {
+		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
+		ropeOpts = append(ropeOpts,
+			rope.WithOriginalContextLength(o.ropeOriginalContext),
+			rope.WithExtrapolationFactor(o.ropeExtrapolation),
+			rope.WithAttentionFactor(attnFactor),
+			rope.WithBetaFast(o.ropeBetaFast),
+			rope.WithBetaSlow(o.ropeBetaSlow),
+		)
+	}
+
+	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
 }

 type TextModel struct {
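When the GGUF declares YaRN rope scaling, the attention factor is derived from the scale as 1 / (1 + 0.1 * ln(scale)), so a scale of 1 leaves attention untouched. A scalar sketch of just that computation:

package main

import (
	"fmt"
	"math"
)

// attentionFactor reproduces the attnFactor expression in the yarn branch above.
func attentionFactor(ropeScale float64) float64 {
	return 1.0 / (1.0 + 0.1*math.Log(ropeScale))
}

func main() {
	for _, s := range []float64{1, 4, 8} {
		fmt.Printf("scale %.0f -> attention factor %.4f\n", s, attentionFactor(s))
	}
}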
@@ -48,21 +69,35 @@ func newTextModel(c fs.Config) *TextModel {
 	m := TextModel{
 		Layers: make([]TextLayer, numBlocks),
 		TextConfig: &TextConfig{
 			hiddenSize:    int(c.Uint("embedding_length")),
 			numHeads:      int(c.Uint("attention.head_count")),
 			numKVHeads:    int(c.Uint("attention.head_count_kv")),
 			attnKeyLen:    int(c.Uint("attention.key_length", 256)),
 			attnValLen:    int(c.Uint("attention.value_length", 256)),
 			eps:           c.Float("attention.layer_norm_rms_epsilon", 1e-06),
 			ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
-			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
-			ropeScale:      1,
-			// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
-			// (8 instead of 1)
-			// ropeScale: c.Float("rope.scaling.factor", 1.0),
+			ropeBase:             c.Float("rope.freq_base", 1000000.0),
+			slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
+			ropeType:             c.String("rope.scaling.type"),
+			ropeOriginalContext:  int(c.Uint("rope.scaling.original_context_length")),
+			ropeExtrapolation:    c.Float("rope.scaling.extrapolation_factor", 1.0),
+			ropeBetaFast:         c.Float("rope.scaling.beta_fast", 64.0),
+			ropeBetaSlow:         c.Float("rope.scaling.beta_slow", 1.0),
+			ropeScale:            c.Float("rope.scaling.factor", 1.0),
+			finalLogitSoftcap:    c.Float("final_logit_softcapping", 0.0),
 		},
 	}

+	// Google's Gemma 3 release with sliding window attention does
+	// not use final logit softcapping, and so force it to 0.0
+	// TODO (jmorganca): this should ideally be set to 0.0 in the
+	// model configuration instead of here, as future versions of
+	// models may include both sliding window attention and final
+	// logit softcapping.
+	if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
+		m.TextConfig.finalLogitSoftcap = 0.0
+	}
+
 	if numBlocks == gemma27BLayerCount {
 		m.largeModelScaling = true
 	}
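The guard added after the config literal disables final logit softcapping whenever any layer uses sliding-window attention. A small sketch of the same rule in isolation:

package main

import (
	"fmt"
	"slices"
)

// effectiveSoftcap mirrors the guard above: any sliding-window layer forces
// the softcap to zero, otherwise the configured value is kept.
func effectiveSoftcap(softcap float32, slidingWindowPattern []bool) float32 {
	if slices.Contains(slidingWindowPattern, true) {
		return 0.0
	}
	return softcap
}

func main() {
	fmt.Println(effectiveSoftcap(30, []bool{true, true, false})) // 0
	fmt.Println(effectiveSoftcap(30, nil))                       // 30
}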
@@ -79,13 +114,26 @@ type TextSelfAttention struct {
 	Output *nn.Linear `gguf:"attn_output"`
 }

+func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
+	if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
+		return opts.ropeLocalBase
+	}
+
+	// Standard Gemma3: only every n-th layer is global,
+	// where n = gemmaGlobalCacheCount, otherwise use
+	// the local rope base
+	if (layer+1)%gemmaGlobalCacheCount > 0 {
+		return opts.ropeLocalBase
+	}
+
+	// default to global rope base
+	return opts.ropeBase
+}
+
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
 	batchSize := hiddenState.Dim(1)

-	ropeBase := opts.ropeLocalBase
-	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeBase = opts.ropeGlobalBase
-	}
+	ropeBase := opts.ropeBaseForLayer(layer)

 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
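ropeBaseForLayer centralizes the local-vs-global rope base choice: an explicit per-layer pattern is consulted first, then the fixed every-n-th-layer-is-global rule. A standalone sketch with example values; globalEvery stands in for gemmaGlobalCacheCount and the bases are placeholders, not the model's real configuration:

package main

import "fmt"

const globalEvery = 6 // hypothetical stand-in for gemmaGlobalCacheCount

func ropeBaseForLayer(layer int, pattern []bool, localBase, globalBase float32) float32 {
	// an explicit per-layer pattern wins when present
	if pattern != nil && pattern[layer] {
		return localBase
	}
	// otherwise fall back to the fixed "every n-th layer is global" rule
	if (layer+1)%globalEvery > 0 {
		return localBase
	}
	return globalBase
}

func main() {
	for layer := 0; layer < 12; layer++ {
		fmt.Println(layer, ropeBaseForLayer(layer, nil, 10000, 1000000))
	}
}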
@@ -114,12 +162,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }

 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	ropeBase := m.TextConfig.ropeLocalBase
-	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeBase = m.TextConfig.ropeGlobalBase
-	}
-
-	return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil
+	return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
 }

 type TextMLP struct {
@@ -207,6 +250,5 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
 		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
 	}

-	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-	return hiddenState
+	return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 }
@@ -300,18 +300,13 @@ func filesForModel(path string) ([]string, error) {
 	}
 	files = append(files, js...)

-	// only include tokenizer.model is tokenizer.json is not present
-	if !slices.ContainsFunc(files, func(s string) bool {
-		return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
-	}) {
-		if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
-			// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
-			// tokenizer.model might be a unresolved git lfs reference; error if it is
-			files = append(files, tks...)
-		} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
-			// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
-			files = append(files, tks...)
-		}
+	// add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob)
+	// tokenizer.model might be a unresolved git lfs reference; error if it is
+	if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
+		files = append(files, tks...)
+	} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
+		// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
+		files = append(files, tks...)
 	}

 	return files, nil
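The simplified lookup prefers a binary tokenizer.model in the model root and only then looks in subdirectories. A rough sketch of that ordering using the standard library's filepath.Glob; the project's own glob helper additionally filters by detected content type (and rejects unresolved git-lfs pointers), which this sketch omits:

package main

import (
	"fmt"
	"path/filepath"
)

// tokenizerModelCandidates is an illustrative approximation, not the real helper.
func tokenizerModelCandidates(path string) []string {
	// prefer a tokenizer.model directly in the model root...
	if tks, _ := filepath.Glob(filepath.Join(path, "tokenizer.model")); len(tks) > 0 {
		return tks
	}
	// ...otherwise look one directory level down (e.g. meta-llama/Meta-Llama-3-8B layouts)
	tks, _ := filepath.Glob(filepath.Join(path, "*", "tokenizer.model"))
	return tks
}

func main() {
	fmt.Println(tokenizerModelCandidates("."))
}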
@@ -888,6 +888,37 @@ func TestFilesForModel(t *testing.T) {
 				"tokenizer.json",
 			},
 		},
+		{
+			name: "safetensors with both tokenizer.json and tokenizer.model",
+			setup: func(dir string) error {
+				// Create binary content for tokenizer.model (application/octet-stream)
+				binaryContent := make([]byte, 512)
+				for i := range binaryContent {
+					binaryContent[i] = byte(i % 256)
+				}
+				files := []string{
+					"model-00001-of-00001.safetensors",
+					"config.json",
+					"tokenizer.json",
+				}
+				for _, file := range files {
+					if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
+						return err
+					}
+				}
+				// Write tokenizer.model as binary
+				if err := os.WriteFile(filepath.Join(dir, "tokenizer.model"), binaryContent, 0o644); err != nil {
+					return err
+				}
+				return nil
+			},
+			wantFiles: []string{
+				"model-00001-of-00001.safetensors",
+				"config.json",
+				"tokenizer.json",
+				"tokenizer.model",
+			},
+		},
 		{
 			name: "safetensors with consolidated files - prefers model files",
 			setup: func(dir string) error {