mirror of https://github.com/ollama/ollama
80 lines
1.9 KiB
Go
80 lines
1.9 KiB
Go
package common
|
|
|
|
import (
|
|
"math"
|
|
"sort"
|
|
|
|
"github.com/ollama/ollama/llm"
|
|
)
|
|
|
|
// TokenDecoderFunc maps a single token ID to the text it represents.
// CalculateLogprobs calls it once for the selected token and once for each
// top-K entry it reports.
type TokenDecoderFunc func(tokenID int) string
|
|
|
|
// CalculateLogprobs converts raw logits to log probabilities and finds top K tokens.
|
|
// It uses numerically stable softmax to compute log probabilities.
|
|
func CalculateLogprobs(logits []float32, selectedToken int, topK int, decoder TokenDecoderFunc) []llm.Logprob {
|
|
if len(logits) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Step 1: Convert logits to log probabilities using numerically stable softmax
|
|
maxLogit := logits[0]
|
|
for _, logit := range logits[1:] {
|
|
if logit > maxLogit {
|
|
maxLogit = logit
|
|
}
|
|
}
|
|
|
|
var sumExp float64
|
|
for _, logit := range logits {
|
|
sumExp += math.Exp(float64(logit - maxLogit))
|
|
}
|
|
logSumExp := float32(math.Log(sumExp))
|
|
|
|
logProbs := make([]float32, len(logits))
|
|
for i, logit := range logits {
|
|
logProbs[i] = (logit - maxLogit) - logSumExp
|
|
}
|
|
|
|
// Step 2: Get selected token's information
|
|
selectedLogprob := logProbs[selectedToken]
|
|
selectedText := decoder(selectedToken)
|
|
|
|
result := llm.Logprob{
|
|
TokenLogprob: llm.TokenLogprob{
|
|
Token: selectedText,
|
|
Logprob: float64(selectedLogprob),
|
|
},
|
|
}
|
|
|
|
// Step 3: If topK requested, find the top K tokens
|
|
if topK > 0 {
|
|
type tokenLogprobPair struct {
|
|
tokenID int
|
|
logprob float32
|
|
}
|
|
|
|
pairs := make([]tokenLogprobPair, len(logProbs))
|
|
for i, lp := range logProbs {
|
|
pairs[i] = tokenLogprobPair{tokenID: i, logprob: lp}
|
|
}
|
|
|
|
sort.Slice(pairs, func(i, j int) bool {
|
|
return pairs[i].logprob > pairs[j].logprob
|
|
})
|
|
|
|
k := min(topK, len(pairs))
|
|
topLogprobs := make([]llm.TokenLogprob, k)
|
|
for i := range k {
|
|
tokenText := decoder(pairs[i].tokenID)
|
|
topLogprobs[i] = llm.TokenLogprob{
|
|
Token: tokenText,
|
|
Logprob: float64(pairs[i].logprob),
|
|
}
|
|
}
|
|
result.TopLogprobs = topLogprobs
|
|
}
|
|
|
|
return []llm.Logprob{result}
|
|
}
|