mirror of https://github.com/ollama/ollama
Merge 8dac29b9a9 into a013693f80
This commit is contained in:
commit
ce9a1ec3b3
|
|
@ -1484,6 +1484,11 @@ func (d DoneReason) String() string {
|
|||
type TokenLogprob struct {
|
||||
Token string `json:"token"`
|
||||
Logprob float64 `json:"logprob"`
|
||||
// Bytes contains the raw byte representation of the token.
|
||||
// This field preserves the original bytes before JSON encoding,
|
||||
// which is important for partial UTF-8 tokens that would otherwise
|
||||
// be replaced with the replacement character during JSON marshaling.
|
||||
Bytes []byte `json:"bytes,omitempty"`
|
||||
}
|
||||
|
||||
// Logprob contains log probability information for a generated token.
|
||||
|
|
|
|||
|
|
@ -44,6 +44,8 @@ func CalculateLogprobs(logits []float32, selectedToken int, topK int, decoder To
|
|||
TokenLogprob: llm.TokenLogprob{
|
||||
Token: selectedText,
|
||||
Logprob: float64(selectedLogprob),
|
||||
// Store raw bytes before JSON encoding to preserve partial UTF-8 sequences
|
||||
Bytes: []byte(selectedText),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -70,6 +72,8 @@ func CalculateLogprobs(logits []float32, selectedToken int, topK int, decoder To
|
|||
topLogprobs[i] = llm.TokenLogprob{
|
||||
Token: tokenText,
|
||||
Logprob: float64(pairs[i].logprob),
|
||||
// Store raw bytes before JSON encoding to preserve partial UTF-8 sequences
|
||||
Bytes: []byte(tokenText),
|
||||
}
|
||||
}
|
||||
result.TopLogprobs = topLogprobs
|
||||
|
|
|
|||
|
|
@ -496,3 +496,138 @@ func TestLogprobsWithStopSequences(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCalculateLogprobsPartialUTF8Bytes verifies that partial UTF-8 sequences
|
||||
// are correctly preserved in the Bytes field. This tests the fix for issue #13497
|
||||
// where logprobs returned the replacement character bytes [239, 191, 189] instead
|
||||
// of the actual partial UTF-8 bytes.
|
||||
func TestCalculateLogprobsPartialUTF8Bytes(t *testing.T) {
|
||||
// Simulate partial UTF-8 tokens for emoji 😊 (UTF-8: [0xF0, 0x9F, 0x98, 0x8A])
|
||||
// When tokenized into partial bytes, each token is a single byte
|
||||
tokens := map[int]string{
|
||||
0: "\xF0", // First byte of 😊
|
||||
1: "\x9F", // Second byte
|
||||
2: "\x98", // Third byte
|
||||
3: "\x8A", // Fourth byte
|
||||
}
|
||||
decoder := func(tokenID int) string {
|
||||
return tokens[tokenID]
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenID int
|
||||
expectedBytes []byte
|
||||
}{
|
||||
{
|
||||
name: "First partial UTF-8 byte",
|
||||
tokenID: 0,
|
||||
expectedBytes: []byte{0xF0},
|
||||
},
|
||||
{
|
||||
name: "Second partial UTF-8 byte",
|
||||
tokenID: 1,
|
||||
expectedBytes: []byte{0x9F},
|
||||
},
|
||||
{
|
||||
name: "Third partial UTF-8 byte",
|
||||
tokenID: 2,
|
||||
expectedBytes: []byte{0x98},
|
||||
},
|
||||
{
|
||||
name: "Fourth partial UTF-8 byte",
|
||||
tokenID: 3,
|
||||
expectedBytes: []byte{0x8A},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logits := []float32{1.0, 0.5, 0.3, 0.1}
|
||||
result := CalculateLogprobs(logits, tt.tokenID, 0, decoder)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("Expected 1 result, got %d", len(result))
|
||||
}
|
||||
|
||||
// Verify the Bytes field contains the correct raw bytes
|
||||
if len(result[0].Bytes) != len(tt.expectedBytes) {
|
||||
t.Errorf("Expected Bytes length %d, got %d", len(tt.expectedBytes), len(result[0].Bytes))
|
||||
}
|
||||
|
||||
for i, b := range result[0].Bytes {
|
||||
if i < len(tt.expectedBytes) && b != tt.expectedBytes[i] {
|
||||
t.Errorf("Bytes[%d] = %d (0x%X), want %d (0x%X)",
|
||||
i, b, b, tt.expectedBytes[i], tt.expectedBytes[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the bytes are NOT the replacement character bytes [239, 191, 189]
|
||||
replacementBytes := []byte{239, 191, 189}
|
||||
if len(result[0].Bytes) == len(replacementBytes) {
|
||||
isReplacement := true
|
||||
for i := range result[0].Bytes {
|
||||
if result[0].Bytes[i] != replacementBytes[i] {
|
||||
isReplacement = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isReplacement {
|
||||
t.Errorf("Bytes field incorrectly contains replacement character bytes [239, 191, 189]")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCalculateLogprobsBytesWithTopLogprobs verifies that partial UTF-8 bytes
|
||||
// are preserved in TopLogprobs as well.
|
||||
func TestCalculateLogprobsBytesWithTopLogprobs(t *testing.T) {
|
||||
tokens := map[int]string{
|
||||
0: "\xF0", // Partial UTF-8
|
||||
1: "hello", // Normal ASCII
|
||||
2: "\x9F", // Another partial UTF-8
|
||||
3: "\xE2\x82", // Partial multi-byte sequence
|
||||
}
|
||||
decoder := func(tokenID int) string {
|
||||
return tokens[tokenID]
|
||||
}
|
||||
|
||||
logits := []float32{2.0, 3.0, 1.0, 0.5}
|
||||
result := CalculateLogprobs(logits, 0, 4, decoder)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("Expected 1 result, got %d", len(result))
|
||||
}
|
||||
|
||||
// Verify selected token bytes
|
||||
expectedSelectedBytes := []byte{0xF0}
|
||||
if len(result[0].Bytes) != len(expectedSelectedBytes) {
|
||||
t.Errorf("Selected token Bytes length = %d, want %d", len(result[0].Bytes), len(expectedSelectedBytes))
|
||||
}
|
||||
|
||||
// Verify TopLogprobs have correct bytes
|
||||
expectedTopBytes := [][]byte{
|
||||
{0x68, 0x65, 0x6c, 0x6c, 0x6f}, // "hello"
|
||||
{0xF0}, // "\xF0"
|
||||
{0x9F}, // "\x9F"
|
||||
{0xE2, 0x82}, // "\xE2\x82"
|
||||
}
|
||||
|
||||
if len(result[0].TopLogprobs) != 4 {
|
||||
t.Fatalf("Expected 4 TopLogprobs, got %d", len(result[0].TopLogprobs))
|
||||
}
|
||||
|
||||
for i, tlp := range result[0].TopLogprobs {
|
||||
if len(tlp.Bytes) != len(expectedTopBytes[i]) {
|
||||
t.Errorf("TopLogprobs[%d].Bytes length = %d, want %d", i, len(tlp.Bytes), len(expectedTopBytes[i]))
|
||||
continue
|
||||
}
|
||||
for j, b := range tlp.Bytes {
|
||||
if b != expectedTopBytes[i][j] {
|
||||
t.Errorf("TopLogprobs[%d].Bytes[%d] = %d (0x%X), want %d (0x%X)",
|
||||
i, j, b, b, expectedTopBytes[i][j], expectedTopBytes[i][j])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ func toAPILogprobs(logprobs []llm.Logprob) []api.Logprob {
|
|||
result[i] = api.Logprob{
|
||||
TokenLogprob: api.TokenLogprob{
|
||||
Token: lp.Token,
|
||||
Bytes: stringToByteInts(lp.Token),
|
||||
Bytes: bytesToInts(lp.Bytes),
|
||||
Logprob: lp.Logprob,
|
||||
},
|
||||
}
|
||||
|
|
@ -21,7 +21,7 @@ func toAPILogprobs(logprobs []llm.Logprob) []api.Logprob {
|
|||
for j, tlp := range lp.TopLogprobs {
|
||||
result[i].TopLogprobs[j] = api.TokenLogprob{
|
||||
Token: tlp.Token,
|
||||
Bytes: stringToByteInts(tlp.Token),
|
||||
Bytes: bytesToInts(tlp.Bytes),
|
||||
Logprob: tlp.Logprob,
|
||||
}
|
||||
}
|
||||
|
|
@ -30,6 +30,24 @@ func toAPILogprobs(logprobs []llm.Logprob) []api.Logprob {
|
|||
return result
|
||||
}
|
||||
|
||||
// bytesToInts converts a byte slice to an int slice.
|
||||
// This function uses the raw bytes stored in llm.TokenLogprob.Bytes,
|
||||
// which preserves partial UTF-8 sequences that would otherwise be
|
||||
// corrupted during JSON marshaling/unmarshaling.
|
||||
func bytesToInts(b []byte) []int {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ints := make([]int, len(b))
|
||||
for i, v := range b {
|
||||
ints[i] = int(v)
|
||||
}
|
||||
return ints
|
||||
}
|
||||
|
||||
// stringToByteInts converts a string to an int slice of bytes.
|
||||
// This is kept for backward compatibility with tests.
|
||||
func stringToByteInts(s string) []int {
|
||||
if s == "" {
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -0,0 +1,127 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
// TestLogprobBytesJSONRoundTrip verifies that partial UTF-8 bytes are preserved
|
||||
// through JSON marshaling and unmarshaling. This tests the fix for issue #13497.
|
||||
func TestLogprobBytesJSONRoundTrip(t *testing.T) {
|
||||
// Create a logprob with partial UTF-8 bytes (first byte of emoji 😊)
|
||||
original := llm.Logprob{
|
||||
TokenLogprob: llm.TokenLogprob{
|
||||
Token: "\xF0", // Invalid UTF-8 (partial sequence)
|
||||
Logprob: -0.5,
|
||||
Bytes: []byte{0xF0}, // Raw bytes stored before JSON encoding
|
||||
},
|
||||
TopLogprobs: []llm.TokenLogprob{
|
||||
{
|
||||
Token: "\x9F",
|
||||
Logprob: -1.0,
|
||||
Bytes: []byte{0x9F},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
jsonData, err := json.Marshal(original)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal from JSON
|
||||
var decoded llm.Logprob
|
||||
if err := json.Unmarshal(jsonData, &decoded); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
// The Token field may have been corrupted by JSON (replaced with U+FFFD),
|
||||
// but the Bytes field should preserve the original bytes
|
||||
if len(decoded.Bytes) != 1 || decoded.Bytes[0] != 0xF0 {
|
||||
t.Errorf("Bytes field corrupted: got %v, want [240]", decoded.Bytes)
|
||||
}
|
||||
|
||||
// Verify TopLogprobs bytes are also preserved
|
||||
if len(decoded.TopLogprobs) != 1 {
|
||||
t.Fatalf("TopLogprobs length = %d, want 1", len(decoded.TopLogprobs))
|
||||
}
|
||||
if len(decoded.TopLogprobs[0].Bytes) != 1 || decoded.TopLogprobs[0].Bytes[0] != 0x9F {
|
||||
t.Errorf("TopLogprobs[0].Bytes corrupted: got %v, want [159]", decoded.TopLogprobs[0].Bytes)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToAPILogprobsPreservesBytes verifies that toAPILogprobs uses the stored
|
||||
// bytes instead of converting from the (potentially corrupted) token string.
|
||||
func TestToAPILogprobsPreservesBytes(t *testing.T) {
|
||||
// Simulate logprobs that have been through JSON round-trip
|
||||
// The Token field contains the replacement character (corrupted)
|
||||
// but the Bytes field contains the correct original bytes
|
||||
logprobs := []llm.Logprob{
|
||||
{
|
||||
TokenLogprob: llm.TokenLogprob{
|
||||
Token: "\uFFFD", // Replacement character (corrupted)
|
||||
Logprob: -0.5,
|
||||
Bytes: []byte{0xF0}, // Original bytes preserved
|
||||
},
|
||||
TopLogprobs: []llm.TokenLogprob{
|
||||
{
|
||||
Token: "\uFFFD",
|
||||
Logprob: -1.0,
|
||||
Bytes: []byte{0x9F},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Convert to API logprobs
|
||||
apiLogprobs := toAPILogprobs(logprobs)
|
||||
|
||||
if len(apiLogprobs) != 1 {
|
||||
t.Fatalf("Expected 1 API logprob, got %d", len(apiLogprobs))
|
||||
}
|
||||
|
||||
// Verify that the Bytes field contains the correct bytes, not the
|
||||
// replacement character bytes [239, 191, 189]
|
||||
expectedBytes := []int{240} // 0xF0
|
||||
if len(apiLogprobs[0].Bytes) != len(expectedBytes) {
|
||||
t.Errorf("Bytes length = %d, want %d", len(apiLogprobs[0].Bytes), len(expectedBytes))
|
||||
}
|
||||
for i, b := range apiLogprobs[0].Bytes {
|
||||
if b != expectedBytes[i] {
|
||||
t.Errorf("Bytes[%d] = %d, want %d", i, b, expectedBytes[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify TopLogprobs bytes
|
||||
if len(apiLogprobs[0].TopLogprobs) != 1 {
|
||||
t.Fatalf("Expected 1 TopLogprob, got %d", len(apiLogprobs[0].TopLogprobs))
|
||||
}
|
||||
expectedTopBytes := []int{159} // 0x9F
|
||||
if len(apiLogprobs[0].TopLogprobs[0].Bytes) != len(expectedTopBytes) {
|
||||
t.Errorf("TopLogprobs[0].Bytes length = %d, want %d",
|
||||
len(apiLogprobs[0].TopLogprobs[0].Bytes), len(expectedTopBytes))
|
||||
}
|
||||
for i, b := range apiLogprobs[0].TopLogprobs[0].Bytes {
|
||||
if b != expectedTopBytes[i] {
|
||||
t.Errorf("TopLogprobs[0].Bytes[%d] = %d, want %d", i, b, expectedTopBytes[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we're NOT getting replacement character bytes
|
||||
replacementBytes := []int{239, 191, 189}
|
||||
if len(apiLogprobs[0].Bytes) == len(replacementBytes) {
|
||||
allMatch := true
|
||||
for i := range apiLogprobs[0].Bytes {
|
||||
if apiLogprobs[0].Bytes[i] != replacementBytes[i] {
|
||||
allMatch = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allMatch {
|
||||
t.Errorf("Bytes field incorrectly contains replacement character bytes")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1317,15 +1317,18 @@ func TestGenerateLogprobs(t *testing.T) {
|
|||
expectedPrimary := llm.TokenLogprob{
|
||||
Token: "Hi",
|
||||
Logprob: -0.01,
|
||||
Bytes: []byte("Hi"),
|
||||
}
|
||||
expectedAlternatives := []llm.TokenLogprob{
|
||||
{
|
||||
Token: "Hello",
|
||||
Logprob: -0.25,
|
||||
Bytes: []byte("Hello"),
|
||||
},
|
||||
{
|
||||
Token: "Hey",
|
||||
Logprob: -0.5,
|
||||
Bytes: []byte("Hey"),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -1492,15 +1495,18 @@ func TestChatLogprobs(t *testing.T) {
|
|||
expectedPrimary := llm.TokenLogprob{
|
||||
Token: "Hi",
|
||||
Logprob: -0.02,
|
||||
Bytes: []byte("Hi"),
|
||||
}
|
||||
expectedAlternatives := []llm.TokenLogprob{
|
||||
{
|
||||
Token: "Hello",
|
||||
Logprob: -0.3,
|
||||
Bytes: []byte("Hello"),
|
||||
},
|
||||
{
|
||||
Token: "Hey",
|
||||
Logprob: -0.45,
|
||||
Bytes: []byte("Hey"),
|
||||
},
|
||||
}
|
||||
expectedPrimaryBytes := stringToByteInts(expectedPrimary.Token)
|
||||
|
|
|
|||
Loading…
Reference in New Issue