This commit is contained in:
Nhan Nguyen 2025-12-17 12:01:12 +09:00 committed by GitHub
commit ce9a1ec3b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 297 additions and 2 deletions

View File

@ -1484,6 +1484,11 @@ func (d DoneReason) String() string {
type TokenLogprob struct {
Token string `json:"token"`
Logprob float64 `json:"logprob"`
// Bytes contains the raw byte representation of the token.
// This field preserves the original bytes before JSON encoding,
// which is important for partial UTF-8 tokens that would otherwise
// be replaced with the replacement character during JSON marshaling.
Bytes []byte `json:"bytes,omitempty"`
}
// Logprob contains log probability information for a generated token.

View File

@ -44,6 +44,8 @@ func CalculateLogprobs(logits []float32, selectedToken int, topK int, decoder To
TokenLogprob: llm.TokenLogprob{
Token: selectedText,
Logprob: float64(selectedLogprob),
// Store raw bytes before JSON encoding to preserve partial UTF-8 sequences
Bytes: []byte(selectedText),
},
}
@ -70,6 +72,8 @@ func CalculateLogprobs(logits []float32, selectedToken int, topK int, decoder To
topLogprobs[i] = llm.TokenLogprob{
Token: tokenText,
Logprob: float64(pairs[i].logprob),
// Store raw bytes before JSON encoding to preserve partial UTF-8 sequences
Bytes: []byte(tokenText),
}
}
result.TopLogprobs = topLogprobs

View File

@ -496,3 +496,138 @@ func TestLogprobsWithStopSequences(t *testing.T) {
})
}
}
// TestCalculateLogprobsPartialUTF8Bytes verifies that partial UTF-8 sequences
// are correctly preserved in the Bytes field. This tests the fix for issue #13497
// where logprobs returned the replacement character bytes [239, 191, 189] instead
// of the actual partial UTF-8 bytes.
func TestCalculateLogprobsPartialUTF8Bytes(t *testing.T) {
// Simulate partial UTF-8 tokens for emoji 😊 (UTF-8: [0xF0, 0x9F, 0x98, 0x8A])
// When tokenized into partial bytes, each token is a single byte
tokens := map[int]string{
0: "\xF0", // First byte of 😊
1: "\x9F", // Second byte
2: "\x98", // Third byte
3: "\x8A", // Fourth byte
}
decoder := func(tokenID int) string {
return tokens[tokenID]
}
tests := []struct {
name string
tokenID int
expectedBytes []byte
}{
{
name: "First partial UTF-8 byte",
tokenID: 0,
expectedBytes: []byte{0xF0},
},
{
name: "Second partial UTF-8 byte",
tokenID: 1,
expectedBytes: []byte{0x9F},
},
{
name: "Third partial UTF-8 byte",
tokenID: 2,
expectedBytes: []byte{0x98},
},
{
name: "Fourth partial UTF-8 byte",
tokenID: 3,
expectedBytes: []byte{0x8A},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logits := []float32{1.0, 0.5, 0.3, 0.1}
result := CalculateLogprobs(logits, tt.tokenID, 0, decoder)
if len(result) != 1 {
t.Fatalf("Expected 1 result, got %d", len(result))
}
// Verify the Bytes field contains the correct raw bytes
if len(result[0].Bytes) != len(tt.expectedBytes) {
t.Errorf("Expected Bytes length %d, got %d", len(tt.expectedBytes), len(result[0].Bytes))
}
for i, b := range result[0].Bytes {
if i < len(tt.expectedBytes) && b != tt.expectedBytes[i] {
t.Errorf("Bytes[%d] = %d (0x%X), want %d (0x%X)",
i, b, b, tt.expectedBytes[i], tt.expectedBytes[i])
}
}
// Verify the bytes are NOT the replacement character bytes [239, 191, 189]
replacementBytes := []byte{239, 191, 189}
if len(result[0].Bytes) == len(replacementBytes) {
isReplacement := true
for i := range result[0].Bytes {
if result[0].Bytes[i] != replacementBytes[i] {
isReplacement = false
break
}
}
if isReplacement {
t.Errorf("Bytes field incorrectly contains replacement character bytes [239, 191, 189]")
}
}
})
}
}
// TestCalculateLogprobsBytesWithTopLogprobs verifies that partial UTF-8 bytes
// are preserved in TopLogprobs as well.
func TestCalculateLogprobsBytesWithTopLogprobs(t *testing.T) {
tokens := map[int]string{
0: "\xF0", // Partial UTF-8
1: "hello", // Normal ASCII
2: "\x9F", // Another partial UTF-8
3: "\xE2\x82", // Partial multi-byte sequence
}
decoder := func(tokenID int) string {
return tokens[tokenID]
}
logits := []float32{2.0, 3.0, 1.0, 0.5}
result := CalculateLogprobs(logits, 0, 4, decoder)
if len(result) != 1 {
t.Fatalf("Expected 1 result, got %d", len(result))
}
// Verify selected token bytes
expectedSelectedBytes := []byte{0xF0}
if len(result[0].Bytes) != len(expectedSelectedBytes) {
t.Errorf("Selected token Bytes length = %d, want %d", len(result[0].Bytes), len(expectedSelectedBytes))
}
// Verify TopLogprobs have correct bytes
expectedTopBytes := [][]byte{
{0x68, 0x65, 0x6c, 0x6c, 0x6f}, // "hello"
{0xF0}, // "\xF0"
{0x9F}, // "\x9F"
{0xE2, 0x82}, // "\xE2\x82"
}
if len(result[0].TopLogprobs) != 4 {
t.Fatalf("Expected 4 TopLogprobs, got %d", len(result[0].TopLogprobs))
}
for i, tlp := range result[0].TopLogprobs {
if len(tlp.Bytes) != len(expectedTopBytes[i]) {
t.Errorf("TopLogprobs[%d].Bytes length = %d, want %d", i, len(tlp.Bytes), len(expectedTopBytes[i]))
continue
}
for j, b := range tlp.Bytes {
if b != expectedTopBytes[i][j] {
t.Errorf("TopLogprobs[%d].Bytes[%d] = %d (0x%X), want %d (0x%X)",
i, j, b, b, expectedTopBytes[i][j], expectedTopBytes[i][j])
}
}
}
}

View File

@ -12,7 +12,7 @@ func toAPILogprobs(logprobs []llm.Logprob) []api.Logprob {
result[i] = api.Logprob{
TokenLogprob: api.TokenLogprob{
Token: lp.Token,
Bytes: stringToByteInts(lp.Token),
Bytes: bytesToInts(lp.Bytes),
Logprob: lp.Logprob,
},
}
@ -21,7 +21,7 @@ func toAPILogprobs(logprobs []llm.Logprob) []api.Logprob {
for j, tlp := range lp.TopLogprobs {
result[i].TopLogprobs[j] = api.TokenLogprob{
Token: tlp.Token,
Bytes: stringToByteInts(tlp.Token),
Bytes: bytesToInts(tlp.Bytes),
Logprob: tlp.Logprob,
}
}
@ -30,6 +30,24 @@ func toAPILogprobs(logprobs []llm.Logprob) []api.Logprob {
return result
}
// bytesToInts converts a byte slice to an int slice.
// This function uses the raw bytes stored in llm.TokenLogprob.Bytes,
// which preserves partial UTF-8 sequences that would otherwise be
// corrupted during JSON marshaling/unmarshaling.
func bytesToInts(b []byte) []int {
if len(b) == 0 {
return nil
}
ints := make([]int, len(b))
for i, v := range b {
ints[i] = int(v)
}
return ints
}
// stringToByteInts converts a string to an int slice of bytes.
// This is kept for backward compatibility with tests.
func stringToByteInts(s string) []int {
if s == "" {
return nil

127
server/logprob_test.go Normal file
View File

@ -0,0 +1,127 @@
package server
import (
"encoding/json"
"testing"
"github.com/ollama/ollama/llm"
)
// TestLogprobBytesJSONRoundTrip verifies that partial UTF-8 bytes are preserved
// through JSON marshaling and unmarshaling. This tests the fix for issue #13497.
func TestLogprobBytesJSONRoundTrip(t *testing.T) {
// Create a logprob with partial UTF-8 bytes (first byte of emoji 😊)
original := llm.Logprob{
TokenLogprob: llm.TokenLogprob{
Token: "\xF0", // Invalid UTF-8 (partial sequence)
Logprob: -0.5,
Bytes: []byte{0xF0}, // Raw bytes stored before JSON encoding
},
TopLogprobs: []llm.TokenLogprob{
{
Token: "\x9F",
Logprob: -1.0,
Bytes: []byte{0x9F},
},
},
}
// Marshal to JSON
jsonData, err := json.Marshal(original)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Unmarshal from JSON
var decoded llm.Logprob
if err := json.Unmarshal(jsonData, &decoded); err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
// The Token field may have been corrupted by JSON (replaced with U+FFFD),
// but the Bytes field should preserve the original bytes
if len(decoded.Bytes) != 1 || decoded.Bytes[0] != 0xF0 {
t.Errorf("Bytes field corrupted: got %v, want [240]", decoded.Bytes)
}
// Verify TopLogprobs bytes are also preserved
if len(decoded.TopLogprobs) != 1 {
t.Fatalf("TopLogprobs length = %d, want 1", len(decoded.TopLogprobs))
}
if len(decoded.TopLogprobs[0].Bytes) != 1 || decoded.TopLogprobs[0].Bytes[0] != 0x9F {
t.Errorf("TopLogprobs[0].Bytes corrupted: got %v, want [159]", decoded.TopLogprobs[0].Bytes)
}
}
// TestToAPILogprobsPreservesBytes verifies that toAPILogprobs uses the stored
// bytes instead of converting from the (potentially corrupted) token string.
func TestToAPILogprobsPreservesBytes(t *testing.T) {
// Simulate logprobs that have been through JSON round-trip
// The Token field contains the replacement character (corrupted)
// but the Bytes field contains the correct original bytes
logprobs := []llm.Logprob{
{
TokenLogprob: llm.TokenLogprob{
Token: "\uFFFD", // Replacement character (corrupted)
Logprob: -0.5,
Bytes: []byte{0xF0}, // Original bytes preserved
},
TopLogprobs: []llm.TokenLogprob{
{
Token: "\uFFFD",
Logprob: -1.0,
Bytes: []byte{0x9F},
},
},
},
}
// Convert to API logprobs
apiLogprobs := toAPILogprobs(logprobs)
if len(apiLogprobs) != 1 {
t.Fatalf("Expected 1 API logprob, got %d", len(apiLogprobs))
}
// Verify that the Bytes field contains the correct bytes, not the
// replacement character bytes [239, 191, 189]
expectedBytes := []int{240} // 0xF0
if len(apiLogprobs[0].Bytes) != len(expectedBytes) {
t.Errorf("Bytes length = %d, want %d", len(apiLogprobs[0].Bytes), len(expectedBytes))
}
for i, b := range apiLogprobs[0].Bytes {
if b != expectedBytes[i] {
t.Errorf("Bytes[%d] = %d, want %d", i, b, expectedBytes[i])
}
}
// Verify TopLogprobs bytes
if len(apiLogprobs[0].TopLogprobs) != 1 {
t.Fatalf("Expected 1 TopLogprob, got %d", len(apiLogprobs[0].TopLogprobs))
}
expectedTopBytes := []int{159} // 0x9F
if len(apiLogprobs[0].TopLogprobs[0].Bytes) != len(expectedTopBytes) {
t.Errorf("TopLogprobs[0].Bytes length = %d, want %d",
len(apiLogprobs[0].TopLogprobs[0].Bytes), len(expectedTopBytes))
}
for i, b := range apiLogprobs[0].TopLogprobs[0].Bytes {
if b != expectedTopBytes[i] {
t.Errorf("TopLogprobs[0].Bytes[%d] = %d, want %d", i, b, expectedTopBytes[i])
}
}
// Ensure we're NOT getting replacement character bytes
replacementBytes := []int{239, 191, 189}
if len(apiLogprobs[0].Bytes) == len(replacementBytes) {
allMatch := true
for i := range apiLogprobs[0].Bytes {
if apiLogprobs[0].Bytes[i] != replacementBytes[i] {
allMatch = false
break
}
}
if allMatch {
t.Errorf("Bytes field incorrectly contains replacement character bytes")
}
}
}

View File

@ -1317,15 +1317,18 @@ func TestGenerateLogprobs(t *testing.T) {
expectedPrimary := llm.TokenLogprob{
Token: "Hi",
Logprob: -0.01,
Bytes: []byte("Hi"),
}
expectedAlternatives := []llm.TokenLogprob{
{
Token: "Hello",
Logprob: -0.25,
Bytes: []byte("Hello"),
},
{
Token: "Hey",
Logprob: -0.5,
Bytes: []byte("Hey"),
},
}
@ -1492,15 +1495,18 @@ func TestChatLogprobs(t *testing.T) {
expectedPrimary := llm.TokenLogprob{
Token: "Hi",
Logprob: -0.02,
Bytes: []byte("Hi"),
}
expectedAlternatives := []llm.TokenLogprob{
{
Token: "Hello",
Logprob: -0.3,
Bytes: []byte("Hello"),
},
{
Token: "Hey",
Logprob: -0.45,
Bytes: []byte("Hey"),
},
}
expectedPrimaryBytes := stringToByteInts(expectedPrimary.Token)