diff --git a/model/bytepairencoding.go b/model/bytepairencoding.go index 3d51f70e8..acb58743b 100644 --- a/model/bytepairencoding.go +++ b/model/bytepairencoding.go @@ -237,7 +237,7 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { } } - if addSpecial && len(ids) > 0 { + if addSpecial { ids = bpe.vocab.addSpecials(ids) } diff --git a/model/sentencepiece.go b/model/sentencepiece.go index db07beee9..2c178ec0c 100644 --- a/model/sentencepiece.go +++ b/model/sentencepiece.go @@ -181,7 +181,7 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) { } } - if addSpecial && len(ids) > 0 { + if addSpecial { ids = spm.vocab.addSpecials(ids) } diff --git a/model/vocabulary.go b/model/vocabulary.go index 9b7fc789e..d977c4957 100644 --- a/model/vocabulary.go +++ b/model/vocabulary.go @@ -45,7 +45,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool { func (v *Vocabulary) addSpecials(ids []int32) []int32 { if v.AddBOS && len(v.BOS) > 0 { - if slices.Contains(v.BOS, ids[0]) { + if len(ids) > 0 && slices.Contains(v.BOS, ids[0]) { slog.Warn("adding bos token to prompt which already has it", "id", v.BOS) } @@ -54,7 +54,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 { } if v.AddEOS && len(v.EOS) > 0 { - if slices.Contains(v.BOS, ids[len(ids)-1]) { + if len(ids) > 0 && slices.Contains(v.BOS, ids[len(ids)-1]) { slog.Warn("adding eos token to prompt which already has it", "id", v.EOS) } diff --git a/model/vocabulary_test.go b/model/vocabulary_test.go index 46f0ead23..ccfc39e69 100644 --- a/model/vocabulary_test.go +++ b/model/vocabulary_test.go @@ -1,8 +1,12 @@ package model -import "testing" +import ( + "testing" -func TestVocabulary_SpecialVocabulary(t *testing.T) { + "github.com/google/go-cmp/cmp" +) + +func TestSpecialVocabulary(t *testing.T) { vocab := &Vocabulary{ Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"}, Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL}, @@ -14,3 +18,90 @@ func TestVocabulary_SpecialVocabulary(t *testing.T) { t.Errorf("expected 4 special tokens, got %d", len(specialVocab)) } } + +func TestAddSpecialVocabulary(t *testing.T) { + cases := []struct { + name string + vocab *Vocabulary + input []int32 + want []int32 + }{ + { + name: "add bos", + vocab: &Vocabulary{ + BOS: []int32{0}, + EOS: []int32{1}, + AddBOS: true, + AddEOS: false, + }, + input: []int32{2, 3, 4}, + want: []int32{0, 2, 3, 4}, + }, + { + // TODO(mxyng): this is to match previous behaviour + name: "add bos when already present", + vocab: &Vocabulary{ + BOS: []int32{0}, + EOS: []int32{1}, + AddBOS: true, + AddEOS: false, + }, + input: []int32{0, 2, 3, 4}, + want: []int32{0, 0, 2, 3, 4}, + }, + { + name: "add eos", + vocab: &Vocabulary{ + BOS: []int32{0}, + EOS: []int32{1}, + AddBOS: false, + AddEOS: true, + }, + input: []int32{2, 3, 4}, + want: []int32{2, 3, 4, 1}, + }, + { + // TODO(mxyng): this is to match previous behaviour + name: "add eos when already present", + vocab: &Vocabulary{ + BOS: []int32{0}, + EOS: []int32{1}, + AddBOS: false, + AddEOS: true, + }, + input: []int32{2, 3, 4, 1}, + want: []int32{2, 3, 4, 1, 1}, + }, + { + name: "add both", + vocab: &Vocabulary{ + BOS: []int32{0}, + EOS: []int32{1}, + AddBOS: true, + AddEOS: true, + }, + input: []int32{2, 3, 4}, + want: []int32{0, 2, 3, 4, 1}, + }, + { + name: "add bos to empty inputs", + vocab: &Vocabulary{ + BOS: []int32{0}, + EOS: []int32{1}, + AddBOS: true, + AddEOS: false, + }, + input: []int32{}, + want: []int32{0}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got := tt.vocab.addSpecials(tt.input) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("no match (-want +got):\n%s", diff) + } + }) + } +} diff --git a/model/wordpiece.go b/model/wordpiece.go index e8d5e848a..ef451c73a 100644 --- a/model/wordpiece.go +++ b/model/wordpiece.go @@ -140,7 +140,7 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) { } } - if addSpecial && len(ids) > 0 { + if addSpecial { ids = wpm.vocab.addSpecials(ids) }