mirror of https://github.com/ollama/ollama
499 lines
14 KiB
Go
499 lines
14 KiB
Go
package common
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/llm"
|
|
)
|
|
|
|
func TestCalculateLogprobs(t *testing.T) {
|
|
tokens := map[int]string{
|
|
0: "hello",
|
|
1: "hi",
|
|
2: "hey",
|
|
3: "world",
|
|
}
|
|
decoder := func(tokenID int) string {
|
|
if text, ok := tokens[tokenID]; ok {
|
|
return text
|
|
}
|
|
return ""
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
logits []float32
|
|
selectedToken int
|
|
topK int
|
|
wantLen int
|
|
wantToken string
|
|
}{
|
|
{
|
|
name: "Empty logits",
|
|
logits: []float32{},
|
|
selectedToken: 0,
|
|
topK: 0,
|
|
wantLen: 0,
|
|
},
|
|
{
|
|
name: "Single token without top logprobs",
|
|
logits: []float32{1.0, 0.5, 0.3, 0.1},
|
|
selectedToken: 0,
|
|
topK: 0,
|
|
wantLen: 1,
|
|
wantToken: "hello",
|
|
},
|
|
{
|
|
name: "Single token with top logprobs",
|
|
logits: []float32{1.0, 0.5, 0.3, 0.1},
|
|
selectedToken: 0,
|
|
topK: 3,
|
|
wantLen: 1,
|
|
wantToken: "hello",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := CalculateLogprobs(tt.logits, tt.selectedToken, tt.topK, decoder)
|
|
if len(result) != tt.wantLen {
|
|
t.Errorf("CalculateLogprobs() returned %d results, want %d", len(result), tt.wantLen)
|
|
}
|
|
if tt.wantLen > 0 && result[0].Token != tt.wantToken {
|
|
t.Errorf("CalculateLogprobs() token = %s, want %s", result[0].Token, tt.wantToken)
|
|
}
|
|
if tt.topK > 0 && len(result) > 0 {
|
|
if len(result[0].TopLogprobs) != tt.topK {
|
|
t.Errorf("CalculateLogprobs() top logprobs count = %d, want %d", len(result[0].TopLogprobs), tt.topK)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCalculateLogprobsNumericalStability(t *testing.T) {
|
|
tokens := map[int]string{
|
|
0: "a",
|
|
1: "b",
|
|
2: "c",
|
|
}
|
|
decoder := func(tokenID int) string {
|
|
if text, ok := tokens[tokenID]; ok {
|
|
return text
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// Test with very large logits to ensure numerical stability
|
|
logits := []float32{1000.0, 999.0, 998.0}
|
|
result := CalculateLogprobs(logits, 0, 3, decoder)
|
|
|
|
if len(result) != 1 {
|
|
t.Fatalf("Expected 1 result, got %d", len(result))
|
|
}
|
|
|
|
// Check that log probabilities are finite and reasonable
|
|
if math.IsInf(result[0].Logprob, 0) || math.IsNaN(result[0].Logprob) {
|
|
t.Errorf("Selected token logprob is not finite: %f", result[0].Logprob)
|
|
}
|
|
|
|
for i, tlp := range result[0].TopLogprobs {
|
|
if math.IsInf(tlp.Logprob, 0) || math.IsNaN(tlp.Logprob) {
|
|
t.Errorf("Top logprob[%d] is not finite: %f", i, tlp.Logprob)
|
|
}
|
|
}
|
|
|
|
// Top logprobs should be in descending order
|
|
for i := 1; i < len(result[0].TopLogprobs); i++ {
|
|
if result[0].TopLogprobs[i].Logprob > result[0].TopLogprobs[i-1].Logprob {
|
|
t.Errorf("Top logprobs not in descending order: %f > %f",
|
|
result[0].TopLogprobs[i].Logprob, result[0].TopLogprobs[i-1].Logprob)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCalculateLogprobsProbabilityCorrectness(t *testing.T) {
|
|
tokens := map[int]string{
|
|
0: "hello",
|
|
1: "world",
|
|
2: "foo",
|
|
3: "bar",
|
|
}
|
|
decoder := func(tokenID int) string {
|
|
if text, ok := tokens[tokenID]; ok {
|
|
return text
|
|
}
|
|
return ""
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
logits []float32
|
|
selectedToken int
|
|
topK int
|
|
}{
|
|
{
|
|
name: "Uniform logits",
|
|
logits: []float32{1.0, 1.0, 1.0, 1.0},
|
|
selectedToken: 0,
|
|
topK: 4,
|
|
},
|
|
{
|
|
name: "Different logits",
|
|
logits: []float32{2.0, 1.0, 0.5, 0.1},
|
|
selectedToken: 0,
|
|
topK: 4,
|
|
},
|
|
{
|
|
name: "Negative logits",
|
|
logits: []float32{-1.0, -2.0, -3.0, -4.0},
|
|
selectedToken: 0,
|
|
topK: 4,
|
|
},
|
|
{
|
|
name: "Mixed logits",
|
|
logits: []float32{5.0, -5.0, 0.0, 2.5},
|
|
selectedToken: 0,
|
|
topK: 4,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := CalculateLogprobs(tt.logits, tt.selectedToken, tt.topK, decoder)
|
|
|
|
if len(result) != 1 {
|
|
t.Fatalf("Expected 1 result, got %d", len(result))
|
|
}
|
|
|
|
// Verify all probabilities are non-positive (log probabilities should be <= 0)
|
|
if result[0].Logprob > 0 {
|
|
t.Errorf("Selected token logprob should be <= 0, got %f", result[0].Logprob)
|
|
}
|
|
|
|
for i, tlp := range result[0].TopLogprobs {
|
|
if tlp.Logprob > 0 {
|
|
t.Errorf("Top logprob[%d] should be <= 0, got %f", i, tlp.Logprob)
|
|
}
|
|
}
|
|
|
|
// Verify that probabilities sum to approximately 1
|
|
// Sum of exp(logprob) for all tokens should equal 1
|
|
var probSum float64
|
|
for _, lp := range result[0].TopLogprobs {
|
|
probSum += math.Exp(lp.Logprob)
|
|
}
|
|
|
|
// For uniform logits, each probability should be 1/n
|
|
if tt.name == "Uniform logits" {
|
|
expectedProb := 1.0 / float64(len(tt.logits))
|
|
actualProb := math.Exp(result[0].Logprob)
|
|
if math.Abs(actualProb-expectedProb) > 1e-6 {
|
|
t.Errorf("For uniform logits, expected probability %f, got %f",
|
|
expectedProb, actualProb)
|
|
}
|
|
}
|
|
|
|
// Verify top logprobs are sorted in descending order
|
|
for i := 1; i < len(result[0].TopLogprobs); i++ {
|
|
if result[0].TopLogprobs[i].Logprob > result[0].TopLogprobs[i-1].Logprob {
|
|
t.Errorf("Top logprobs not sorted: position %d (%f) > position %d (%f)",
|
|
i, result[0].TopLogprobs[i].Logprob,
|
|
i-1, result[0].TopLogprobs[i-1].Logprob)
|
|
}
|
|
}
|
|
|
|
// Verify the selected token appears in top logprobs
|
|
selectedText := decoder(tt.selectedToken)
|
|
found := false
|
|
for _, tlp := range result[0].TopLogprobs {
|
|
if tlp.Token == selectedText {
|
|
found = true
|
|
// The logprob in top logprobs should match the selected token's logprob
|
|
if math.Abs(tlp.Logprob-result[0].Logprob) > 1e-6 {
|
|
t.Errorf("Selected token logprob mismatch: main=%f, in top=%f",
|
|
result[0].Logprob, tlp.Logprob)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Errorf("Selected token %q not found in top logprobs", selectedText)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCalculateLogprobsSoftmaxCorrectness(t *testing.T) {
|
|
// Test that softmax calculation is correct by verifying probabilities sum to 1
|
|
decoder := func(tokenID int) string {
|
|
return string(rune('A' + tokenID))
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
logits []float32
|
|
}{
|
|
{
|
|
name: "Small vocabulary",
|
|
logits: []float32{1.0, 2.0, 3.0},
|
|
},
|
|
{
|
|
name: "Large differences",
|
|
logits: []float32{10.0, 0.0, -10.0},
|
|
},
|
|
{
|
|
name: "All equal",
|
|
logits: []float32{5.0, 5.0, 5.0, 5.0, 5.0},
|
|
},
|
|
{
|
|
name: "Very large values",
|
|
logits: []float32{500.0, 499.0, 498.0},
|
|
},
|
|
{
|
|
name: "Very small values",
|
|
logits: []float32{-500.0, -499.0, -498.0},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Calculate logprobs for all tokens
|
|
var totalProb float64
|
|
for i := range tt.logits {
|
|
result := CalculateLogprobs(tt.logits, i, 0, decoder)
|
|
if len(result) != 1 {
|
|
t.Fatalf("Expected 1 result, got %d", len(result))
|
|
}
|
|
prob := math.Exp(result[0].Logprob)
|
|
totalProb += prob
|
|
|
|
// Verify each probability is between 0 and 1
|
|
if prob < 0 || prob > 1 {
|
|
t.Errorf("Token %d probability %f is out of range [0, 1]", i, prob)
|
|
}
|
|
}
|
|
|
|
// Total probability should be very close to 1.0 (allowing for floating point errors)
|
|
if math.Abs(totalProb-1.0) > 1e-5 {
|
|
t.Errorf("Total probability sum is %f, expected 1.0", totalProb)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCalculateLogprobsSelectedTokenCorrectness(t *testing.T) {
|
|
decoder := func(tokenID int) string {
|
|
return string(rune('A' + tokenID))
|
|
}
|
|
|
|
logits := []float32{3.0, 1.0, 2.0, 0.5}
|
|
|
|
// Test that selecting different tokens gives the correct probabilities
|
|
// and that the highest logit has the highest probability
|
|
maxLogitIndex := 0
|
|
maxLogitValue := logits[0]
|
|
for i, logit := range logits[1:] {
|
|
if logit > maxLogitValue {
|
|
maxLogitValue = logit
|
|
maxLogitIndex = i + 1
|
|
}
|
|
}
|
|
|
|
var maxProb float64
|
|
var maxProbIndex int
|
|
|
|
for i := range logits {
|
|
result := CalculateLogprobs(logits, i, 0, decoder)
|
|
prob := math.Exp(result[0].Logprob)
|
|
|
|
if prob > maxProb {
|
|
maxProb = prob
|
|
maxProbIndex = i
|
|
}
|
|
|
|
// Verify the token matches
|
|
expectedToken := decoder(i)
|
|
if result[0].Token != expectedToken {
|
|
t.Errorf("Token %d: expected token %q, got %q", i, expectedToken, result[0].Token)
|
|
}
|
|
}
|
|
|
|
// The token with the highest logit should have the highest probability
|
|
if maxProbIndex != maxLogitIndex {
|
|
t.Errorf("Token with highest probability (%d) doesn't match token with highest logit (%d)",
|
|
maxProbIndex, maxLogitIndex)
|
|
}
|
|
}
|
|
|
|
func TestCalculateLogprobsTopKOrdering(t *testing.T) {
|
|
tokens := map[int]string{
|
|
0: "first",
|
|
1: "second",
|
|
2: "third",
|
|
3: "fourth",
|
|
4: "fifth",
|
|
}
|
|
decoder := func(tokenID int) string {
|
|
return tokens[tokenID]
|
|
}
|
|
|
|
// Logits in non-sorted order
|
|
logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0}
|
|
// Expected order by probability: 1 (5.0), 3 (4.0), 4 (3.0), 0 (2.0), 2 (1.0)
|
|
expectedOrder := []string{"second", "fourth", "fifth", "first", "third"}
|
|
|
|
result := CalculateLogprobs(logits, 0, 5, decoder)
|
|
|
|
if len(result) != 1 {
|
|
t.Fatalf("Expected 1 result, got %d", len(result))
|
|
}
|
|
|
|
if len(result[0].TopLogprobs) != 5 {
|
|
t.Fatalf("Expected 5 top logprobs, got %d", len(result[0].TopLogprobs))
|
|
}
|
|
|
|
// Verify ordering matches expected
|
|
for i, tlp := range result[0].TopLogprobs {
|
|
if tlp.Token != expectedOrder[i] {
|
|
t.Errorf("Position %d: expected token %q, got %q", i, expectedOrder[i], tlp.Token)
|
|
}
|
|
}
|
|
|
|
// Verify probabilities are in descending order
|
|
for i := 1; i < len(result[0].TopLogprobs); i++ {
|
|
if result[0].TopLogprobs[i].Logprob > result[0].TopLogprobs[i-1].Logprob {
|
|
t.Errorf("Probabilities not in descending order at position %d: %f > %f",
|
|
i, result[0].TopLogprobs[i].Logprob, result[0].TopLogprobs[i-1].Logprob)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestLogprobsWithStopSequences(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
pendingResponses []string
|
|
pendingLogprobs []llm.Logprob
|
|
stop string
|
|
expectedResponses []string
|
|
expectedLogprobs int
|
|
}{
|
|
{
|
|
name: "Single token stop",
|
|
pendingResponses: []string{"Hello", " world", "!"},
|
|
pendingLogprobs: []llm.Logprob{
|
|
{TokenLogprob: llm.TokenLogprob{Token: "Hello", Logprob: -0.1}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: " world", Logprob: -0.2}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: "!", Logprob: -0.3}},
|
|
},
|
|
stop: "!",
|
|
expectedResponses: []string{"Hello", " world"},
|
|
expectedLogprobs: 2,
|
|
},
|
|
{
|
|
name: "Multi-token stop sequence",
|
|
pendingResponses: []string{"Hello", " ", "there", "STOP"},
|
|
pendingLogprobs: []llm.Logprob{
|
|
{TokenLogprob: llm.TokenLogprob{Token: "Hello", Logprob: -0.1}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: " ", Logprob: -0.2}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: "there", Logprob: -0.3}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: "STOP", Logprob: -0.4}},
|
|
},
|
|
stop: "STOP",
|
|
expectedResponses: []string{"Hello", " ", "there"},
|
|
expectedLogprobs: 3,
|
|
},
|
|
{
|
|
name: "Partial token stop",
|
|
pendingResponses: []string{"Hello", " the", "re!"},
|
|
pendingLogprobs: []llm.Logprob{
|
|
{TokenLogprob: llm.TokenLogprob{Token: "Hello", Logprob: -0.1}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: " the", Logprob: -0.2}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: "re!", Logprob: -0.3}},
|
|
},
|
|
stop: "there!",
|
|
expectedResponses: []string{"Hello", " "},
|
|
expectedLogprobs: 2,
|
|
},
|
|
{
|
|
name: "Stop at beginning of last token",
|
|
pendingResponses: []string{"Hello", " world", "END"},
|
|
pendingLogprobs: []llm.Logprob{
|
|
{TokenLogprob: llm.TokenLogprob{Token: "Hello", Logprob: -0.1}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: " world", Logprob: -0.2}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: "END", Logprob: -0.3}},
|
|
},
|
|
stop: "END",
|
|
expectedResponses: []string{"Hello", " world"},
|
|
expectedLogprobs: 2,
|
|
},
|
|
{
|
|
name: "Multi-token stop across tokens",
|
|
pendingResponses: []string{"Text", " ", "with", " ", "stop", " ", "word"},
|
|
pendingLogprobs: []llm.Logprob{
|
|
{TokenLogprob: llm.TokenLogprob{Token: "Text", Logprob: -0.1}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: " ", Logprob: -0.2}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: "with", Logprob: -0.3}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: " ", Logprob: -0.4}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: "stop", Logprob: -0.5}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: " ", Logprob: -0.6}},
|
|
{TokenLogprob: llm.TokenLogprob{Token: "word", Logprob: -0.7}},
|
|
},
|
|
stop: "stop word",
|
|
expectedResponses: []string{"Text", " ", "with", " "},
|
|
expectedLogprobs: 4,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Simulate the stop sequence detection and truncation
|
|
origLen := len(tt.pendingResponses)
|
|
responses, tokenTruncated := TruncateStop(tt.pendingResponses, tt.stop)
|
|
newLen := len(responses)
|
|
|
|
// Simulate logprobs truncation
|
|
logprobs := make([]llm.Logprob, len(tt.pendingLogprobs))
|
|
copy(logprobs, tt.pendingLogprobs)
|
|
|
|
origLogprobsLen := len(logprobs)
|
|
numTokensRemoved := origLen - newLen
|
|
newLogprobsLen := origLogprobsLen - numTokensRemoved
|
|
if newLogprobsLen < 0 {
|
|
newLogprobsLen = 0
|
|
}
|
|
logprobs = logprobs[:newLogprobsLen]
|
|
|
|
// Verify responses were truncated correctly
|
|
if len(responses) != len(tt.expectedResponses) {
|
|
t.Errorf("Expected %d responses, got %d", len(tt.expectedResponses), len(responses))
|
|
}
|
|
|
|
// Verify logprobs count matches truncated responses
|
|
if len(logprobs) != tt.expectedLogprobs {
|
|
t.Errorf("Expected %d logprobs after truncation, got %d", tt.expectedLogprobs, len(logprobs))
|
|
}
|
|
|
|
// Verify logprobs count matches response count
|
|
if len(logprobs) != len(responses) {
|
|
t.Errorf("Logprobs count (%d) doesn't match responses count (%d)", len(logprobs), len(responses))
|
|
}
|
|
|
|
// Verify the correct logprobs were kept (skip last token if it was truncated)
|
|
// When tokenTruncated is true, the last response token may not match the logprob token
|
|
checkLen := len(logprobs)
|
|
if tokenTruncated && checkLen > 0 {
|
|
checkLen-- // Skip checking the last token when it was partially truncated
|
|
}
|
|
|
|
for i := range checkLen {
|
|
if i < len(responses) && logprobs[i].Token != responses[i] {
|
|
t.Errorf("Logprob[%d] token %q doesn't match response[%d] %q",
|
|
i, logprobs[i].Token, i, responses[i])
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|