ollama/runner/common/logprob.go

80 lines
1.9 KiB
Go

package common
import (
"math"
"sort"
"github.com/ollama/ollama/llm"
)
// TokenDecoderFunc is a function that converts a token ID to its text.
// CalculateLogprobs calls it for the selected token and for each top-K
// candidate to fill in the Token field of the entries it returns.
type TokenDecoderFunc func(tokenID int) string
// CalculateLogprobs converts raw logits to log probabilities and reports the
// log probability of selectedToken together with, optionally, the topK most
// likely tokens.
//
// Log probabilities are computed as a log-softmax with the maximum logit
// subtracted first, so every exponent is <= 0 and math.Exp cannot overflow.
//
// Parameters:
//   - logits: one raw score per token; the slice index is used as the token ID.
//   - selectedToken: ID of the token whose log probability is reported.
//   - topK: number of highest-probability tokens to include in TopLogprobs.
//     Values <= 0 leave TopLogprobs unset; values larger than len(logits) are
//     clamped to len(logits).
//   - decoder: called to obtain the text of the selected token and of every
//     top-K token.
//
// It returns a single-element slice, or nil when logits is empty or
// selectedToken is not a valid index into logits. TopLogprobs is ordered by
// descending log probability; entries with equal log probability are ordered
// by ascending token ID so the output is deterministic.
func CalculateLogprobs(logits []float32, selectedToken int, topK int, decoder TokenDecoderFunc) []llm.Logprob {
	// Reject inputs that cannot be indexed instead of panicking further down.
	if len(logits) == 0 || selectedToken < 0 || selectedToken >= len(logits) {
		return nil
	}

	// Step 1: Compute the log-softmax normaliser using the max-subtraction trick.
	maxLogit := logits[0]
	for _, logit := range logits[1:] {
		if logit > maxLogit {
			maxLogit = logit
		}
	}

	var sumExp float64
	for _, logit := range logits {
		sumExp += math.Exp(float64(logit - maxLogit))
	}
	logSumExp := float32(math.Log(sumExp))

	// logprobOf maps a single logit to its log probability. Computing values
	// on demand avoids a vocabulary-sized allocation when topK is not requested.
	logprobOf := func(logit float32) float32 {
		return (logit - maxLogit) - logSumExp
	}

	// Step 2: Get selected token's information.
	result := llm.Logprob{
		TokenLogprob: llm.TokenLogprob{
			Token:   decoder(selectedToken),
			Logprob: float64(logprobOf(logits[selectedToken])),
		},
	}

	// Step 3: If topK requested, find the top K tokens.
	if topK > 0 {
		type tokenLogprobPair struct {
			tokenID int
			logprob float32
		}

		pairs := make([]tokenLogprobPair, len(logits))
		for i, logit := range logits {
			pairs[i] = tokenLogprobPair{tokenID: i, logprob: logprobOf(logit)}
		}

		// sort.Slice is not stable, so ties are broken explicitly on token ID.
		// Using both > and < (rather than !=) keeps the ordering consistent
		// even when the values are NaN and compare as neither greater nor less.
		sort.Slice(pairs, func(i, j int) bool {
			a, b := pairs[i], pairs[j]
			if a.logprob > b.logprob {
				return true
			}
			if a.logprob < b.logprob {
				return false
			}
			return a.tokenID < b.tokenID
		})

		k := min(topK, len(pairs))
		topLogprobs := make([]llm.TokenLogprob, k)
		for i := range k {
			topLogprobs[i] = llm.TokenLogprob{
				Token:   decoder(pairs[i].tokenID),
				Logprob: float64(pairs[i].logprob),
			}
		}
		result.TopLogprobs = topLogprobs
	}

	return []llm.Logprob{result}
}