mirror of https://github.com/ollama/ollama
renderers/parsers: olmo3 instruct (#13383)
This commit is contained in:
parent
0c5e5f6630
commit
2bccf8c624
|
|
@ -0,0 +1,465 @@
|
|||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type olmo3ParserState int
|
||||
|
||||
const (
|
||||
olmo3StateContent olmo3ParserState = iota
|
||||
olmo3StateToolCalls
|
||||
olmo3StateToolCallsDone
|
||||
)
|
||||
|
||||
const (
|
||||
olmo3FuncCallsOpenTag = "<function_calls>"
|
||||
olmo3FuncCallsCloseTag = "</function_calls>"
|
||||
)
|
||||
|
||||
type Olmo3Parser struct {
|
||||
state olmo3ParserState
|
||||
buffer strings.Builder
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) HasThinkingSupport() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.state = olmo3StateContent
|
||||
return tools
|
||||
}
|
||||
|
||||
type olmo3ParserEvent interface {
|
||||
isOlmo3ParserEvent()
|
||||
}
|
||||
|
||||
type olmo3ParserEventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type olmo3ParserEventToolCalls struct {
|
||||
calls []api.ToolCall
|
||||
}
|
||||
|
||||
func (olmo3ParserEventContent) isOlmo3ParserEvent() {}
|
||||
func (olmo3ParserEventToolCalls) isOlmo3ParserEvent() {}
|
||||
|
||||
func (p *Olmo3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
|
||||
if done {
|
||||
// Drain any remaining content
|
||||
bufStr := p.buffer.String()
|
||||
p.buffer.Reset()
|
||||
if p.state == olmo3StateContent && len(bufStr) > 0 {
|
||||
return bufStr, "", nil, nil
|
||||
}
|
||||
return "", "", nil, nil
|
||||
}
|
||||
|
||||
events := p.parseEvents()
|
||||
|
||||
var contentSb strings.Builder
|
||||
var allCalls []api.ToolCall
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case olmo3ParserEventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
case olmo3ParserEventToolCalls:
|
||||
allCalls = append(allCalls, event.calls...)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), "", allCalls, nil
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) parseEvents() []olmo3ParserEvent {
|
||||
var all []olmo3ParserEvent
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []olmo3ParserEvent
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "olmo3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) eat() ([]olmo3ParserEvent, bool) {
|
||||
var events []olmo3ParserEvent
|
||||
bufStr := p.buffer.String()
|
||||
if bufStr == "" {
|
||||
return events, false
|
||||
}
|
||||
|
||||
switch p.state {
|
||||
case olmo3StateContent:
|
||||
if strings.Contains(bufStr, olmo3FuncCallsOpenTag) {
|
||||
// Found <function_calls> tag
|
||||
split := strings.SplitN(bufStr, olmo3FuncCallsOpenTag, 2)
|
||||
content := split[0]
|
||||
remaining := split[1]
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = olmo3StateToolCalls
|
||||
|
||||
if len(content) > 0 {
|
||||
events = append(events, olmo3ParserEventContent{content: content})
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(bufStr, olmo3FuncCallsOpenTag); overlapLen > 0 {
|
||||
// Partial <function_calls> tag - withhold ambiguous content
|
||||
unambiguous := bufStr[:len(bufStr)-overlapLen]
|
||||
ambiguous := bufStr[len(bufStr)-overlapLen:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, olmo3ParserEventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
// Regular content - emit all
|
||||
p.buffer.Reset()
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, olmo3ParserEventContent{content: bufStr})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
case olmo3StateToolCalls:
|
||||
if strings.Contains(bufStr, olmo3FuncCallsCloseTag) {
|
||||
// Found </function_calls> tag
|
||||
split := strings.SplitN(bufStr, olmo3FuncCallsCloseTag, 2)
|
||||
toolCallsStr := split[0]
|
||||
remaining := split[1]
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = olmo3StateToolCallsDone
|
||||
|
||||
// Parse the function calls
|
||||
calls, err := parseOlmo3FunctionCalls(toolCallsStr)
|
||||
if err != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "failed to parse olmo3 function calls", "error", err, "content", toolCallsStr)
|
||||
} else if len(calls) > 0 {
|
||||
events = append(events, olmo3ParserEventToolCalls{calls: calls})
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(bufStr, olmo3FuncCallsCloseTag); overlapLen > 0 {
|
||||
// Partial </function_calls> tag - wait for more
|
||||
return events, false
|
||||
}
|
||||
// Still collecting tool calls, wait for close tag
|
||||
return events, false
|
||||
|
||||
case olmo3StateToolCallsDone:
|
||||
// After tool calls, emit remaining content
|
||||
p.buffer.Reset()
|
||||
p.state = olmo3StateContent
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, olmo3ParserEventContent{content: bufStr})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
return events, false
|
||||
}
|
||||
|
||||
// parseOlmo3FunctionCalls parses function calls in Python-esque format:
|
||||
// func_name(arg1="value1", arg2=123)
|
||||
// Multiple calls are separated by newlines
|
||||
func parseOlmo3FunctionCalls(s string) ([]api.ToolCall, error) {
|
||||
var calls []api.ToolCall
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return calls, nil
|
||||
}
|
||||
|
||||
// Split by newlines for multiple function calls
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
call, err := parseOlmo3SingleFunctionCall(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse function call %q: %w", line, err)
|
||||
}
|
||||
calls = append(calls, call)
|
||||
}
|
||||
|
||||
return calls, nil
|
||||
}
|
||||
|
||||
// Regex to match function call: func_name(args)
|
||||
var funcCallRegex = regexp.MustCompile(`^(\w+)\((.*)\)$`)
|
||||
|
||||
func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) {
|
||||
matches := funcCallRegex.FindStringSubmatch(s)
|
||||
if matches == nil {
|
||||
return api.ToolCall{}, fmt.Errorf("invalid function call format")
|
||||
}
|
||||
|
||||
funcName := matches[1]
|
||||
argsStr := matches[2]
|
||||
|
||||
args, err := parseOlmo3Arguments(argsStr)
|
||||
if err != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse arguments: %w", err)
|
||||
}
|
||||
|
||||
return api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: funcName,
|
||||
Arguments: args,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseOlmo3Arguments parses comma-separated key=value pairs
|
||||
// Handles nested parentheses, brackets, braces, and quoted strings
|
||||
func parseOlmo3Arguments(s string) (map[string]any, error) {
|
||||
args := make(map[string]any)
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// Split by commas, but respect nested structures and quotes
|
||||
parts := splitArguments(s)
|
||||
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find the first = sign
|
||||
eqIdx := strings.Index(part, "=")
|
||||
if eqIdx == -1 {
|
||||
return nil, fmt.Errorf("invalid argument format: %s", part)
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(part[:eqIdx])
|
||||
valueStr := strings.TrimSpace(part[eqIdx+1:])
|
||||
|
||||
value, err := parseOlmo3Value(valueStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse value for %s: %w", key, err)
|
||||
}
|
||||
|
||||
args[key] = value
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// splitArguments splits arguments by commas, respecting quotes and nested structures
|
||||
func splitArguments(s string) []string {
|
||||
var parts []string
|
||||
var current strings.Builder
|
||||
depth := 0
|
||||
inString := false
|
||||
stringChar := byte(0)
|
||||
escaped := false
|
||||
|
||||
for i := range s {
|
||||
c := s[i]
|
||||
|
||||
if escaped {
|
||||
current.WriteByte(c)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
|
||||
if c == '\\' && inString {
|
||||
current.WriteByte(c)
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
|
||||
if (c == '"' || c == '\'') && !inString {
|
||||
inString = true
|
||||
stringChar = c
|
||||
current.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
|
||||
if c == stringChar && inString {
|
||||
inString = false
|
||||
stringChar = 0
|
||||
current.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
|
||||
if !inString {
|
||||
switch c {
|
||||
case '(', '[', '{':
|
||||
depth++
|
||||
current.WriteByte(c)
|
||||
case ')', ']', '}':
|
||||
depth--
|
||||
current.WriteByte(c)
|
||||
case ',':
|
||||
if depth == 0 {
|
||||
parts = append(parts, current.String())
|
||||
current.Reset()
|
||||
continue
|
||||
}
|
||||
current.WriteByte(c)
|
||||
default:
|
||||
current.WriteByte(c)
|
||||
}
|
||||
} else {
|
||||
current.WriteByte(c)
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
parts = append(parts, current.String())
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
// parseOlmo3Value parses a value which can be a string, number, boolean, null, array, or object
|
||||
func parseOlmo3Value(s string) (any, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
// Check for quoted string
|
||||
if (strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)) ||
|
||||
(strings.HasPrefix(s, `'`) && strings.HasSuffix(s, `'`)) {
|
||||
// Remove quotes and unescape
|
||||
inner := s[1 : len(s)-1]
|
||||
return unescapeString(inner), nil
|
||||
}
|
||||
|
||||
// Check for boolean
|
||||
if s == "true" || s == "True" {
|
||||
return true, nil
|
||||
}
|
||||
if s == "false" || s == "False" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check for null/None
|
||||
if s == "null" || s == "None" || s == "nil" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check for number
|
||||
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
|
||||
return i, nil
|
||||
}
|
||||
if f, err := strconv.ParseFloat(s, 64); err == nil {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Check for array [...]
|
||||
if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") {
|
||||
return parseOlmo3Array(s[1 : len(s)-1])
|
||||
}
|
||||
|
||||
// Check for object {...}
|
||||
if strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}") {
|
||||
return parseOlmo3Object(s[1 : len(s)-1])
|
||||
}
|
||||
|
||||
// Default to string without quotes
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func parseOlmo3Array(s string) ([]any, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return []any{}, nil
|
||||
}
|
||||
|
||||
parts := splitArguments(s)
|
||||
var arr []any
|
||||
for _, part := range parts {
|
||||
val, err := parseOlmo3Value(part)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arr = append(arr, val)
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
|
||||
func parseOlmo3Object(s string) (map[string]any, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
|
||||
// Objects use key: value or "key": value format
|
||||
obj := make(map[string]any)
|
||||
parts := splitArguments(s)
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find colon separator
|
||||
colonIdx := strings.Index(part, ":")
|
||||
if colonIdx == -1 {
|
||||
return nil, fmt.Errorf("invalid object entry: %s", part)
|
||||
}
|
||||
|
||||
keyStr := strings.TrimSpace(part[:colonIdx])
|
||||
valueStr := strings.TrimSpace(part[colonIdx+1:])
|
||||
|
||||
// Remove quotes from key if present
|
||||
if (strings.HasPrefix(keyStr, `"`) && strings.HasSuffix(keyStr, `"`)) ||
|
||||
(strings.HasPrefix(keyStr, `'`) && strings.HasSuffix(keyStr, `'`)) {
|
||||
keyStr = keyStr[1 : len(keyStr)-1]
|
||||
}
|
||||
|
||||
val, err := parseOlmo3Value(valueStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse value for key %s: %w", keyStr, err)
|
||||
}
|
||||
|
||||
obj[keyStr] = val
|
||||
}
|
||||
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
func unescapeString(s string) string {
|
||||
// Handle common escape sequences
|
||||
s = strings.ReplaceAll(s, `\\`, "\x00") // Placeholder for backslash
|
||||
s = strings.ReplaceAll(s, `\"`, `"`)
|
||||
s = strings.ReplaceAll(s, `\'`, `'`)
|
||||
s = strings.ReplaceAll(s, `\n`, "\n")
|
||||
s = strings.ReplaceAll(s, `\t`, "\t")
|
||||
s = strings.ReplaceAll(s, `\r`, "\r")
|
||||
s = strings.ReplaceAll(s, "\x00", `\`) // Restore backslash
|
||||
return s
|
||||
}
|
||||
|
|
@ -0,0 +1,483 @@
|
|||
package parsers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestOlmo3Parser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedContent string
|
||||
expectedThinking string
|
||||
expectedCalls []api.ToolCall
|
||||
}{
|
||||
{
|
||||
name: "simple content",
|
||||
input: "Hello, how can I help you?",
|
||||
expectedContent: "Hello, how can I help you?",
|
||||
},
|
||||
{
|
||||
name: "simple tool call",
|
||||
input: `<function_calls>get_weather(location="San Francisco")</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "content then tool call",
|
||||
input: `Let me check the weather.<function_calls>get_weather(location="NYC")</function_calls>`,
|
||||
expectedContent: "Let me check the weather.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "NYC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple arguments",
|
||||
input: `<function_calls>book_flight(from="SFO", to="NYC", date="2024-01-15")</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
"date": "2024-01-15",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls",
|
||||
input: `<function_calls>get_weather(location="San Francisco")
|
||||
get_weather(location="New York")</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "New York"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with numeric argument",
|
||||
input: `<function_calls>set_temperature(value=72)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_temperature",
|
||||
Arguments: map[string]any{"value": int64(72)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with float argument",
|
||||
input: `<function_calls>set_price(amount=19.99)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_price",
|
||||
Arguments: map[string]any{"amount": 19.99},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with boolean argument",
|
||||
input: `<function_calls>toggle_setting(enabled=true)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "toggle_setting",
|
||||
Arguments: map[string]any{"enabled": true},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with null argument",
|
||||
input: `<function_calls>clear_value(field=null)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "clear_value",
|
||||
Arguments: map[string]any{"field": nil},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with array argument",
|
||||
input: `<function_calls>process_items(items=["apple", "banana", "cherry"])</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_items",
|
||||
Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with dict argument",
|
||||
input: `<function_calls>update_config(settings={"theme": "dark", "fontSize": 14})</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "update_config",
|
||||
Arguments: map[string]any{
|
||||
"settings": map[string]any{
|
||||
"theme": "dark",
|
||||
"fontSize": int64(14),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with nested dict",
|
||||
input: `<function_calls>create_request(data={"user": {"name": "John", "age": 30}, "active": true})</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_request",
|
||||
Arguments: map[string]any{
|
||||
"data": map[string]any{
|
||||
"user": map[string]any{
|
||||
"name": "John",
|
||||
"age": int64(30),
|
||||
},
|
||||
"active": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with no arguments",
|
||||
input: `<function_calls>get_current_time()</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_time",
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with single quotes",
|
||||
input: `<function_calls>search(query='hello world')</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"query": "hello world"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with escaped quotes",
|
||||
input: `<function_calls>search(query="say \"hello\"")</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"query": `say "hello"`},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with mixed argument types",
|
||||
input: `<function_calls>create_user(name="John", age=30, active=true)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_user",
|
||||
Arguments: map[string]any{
|
||||
"name": "John",
|
||||
"age": int64(30),
|
||||
"active": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Olmo3Parser{}
|
||||
p.Init(nil, nil, nil)
|
||||
|
||||
content, thinking, calls, err := p.Add(tt.input, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Drain remaining content
|
||||
finalContent, finalThinking, finalCalls, err := p.Add("", true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on done: %v", err)
|
||||
}
|
||||
content += finalContent
|
||||
thinking += finalThinking
|
||||
calls = append(calls, finalCalls...)
|
||||
|
||||
if diff := cmp.Diff(content, tt.expectedContent); diff != "" {
|
||||
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
|
||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3Parser_Streaming(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
chunks []string
|
||||
expectedContent string
|
||||
expectedCalls []api.ToolCall
|
||||
}{
|
||||
{
|
||||
name: "streaming content",
|
||||
chunks: []string{"Hello, ", "how ", "can I help?"},
|
||||
expectedContent: "Hello, how can I help?",
|
||||
},
|
||||
{
|
||||
name: "streaming tool call",
|
||||
chunks: []string{"<function_", "calls>get_weather", "(location=\"SF\")", "</function_calls>"},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "streaming content then tool call",
|
||||
chunks: []string{"Let me check.", "<function_calls>", "get_weather(location=\"NYC\")", "</function_calls>"},
|
||||
expectedContent: "Let me check.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "NYC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call tag split across chunks",
|
||||
chunks: []string{"<func", "tion_calls>test()</function_calls>"},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Olmo3Parser{}
|
||||
p.Init(nil, nil, nil)
|
||||
|
||||
var allContent string
|
||||
var allCalls []api.ToolCall
|
||||
|
||||
for _, chunk := range tt.chunks {
|
||||
content, _, calls, err := p.Add(chunk, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
allContent += content
|
||||
allCalls = append(allCalls, calls...)
|
||||
}
|
||||
|
||||
// Drain
|
||||
content, _, calls, err := p.Add("", true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on done: %v", err)
|
||||
}
|
||||
allContent += content
|
||||
allCalls = append(allCalls, calls...)
|
||||
|
||||
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
|
||||
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3Parser_HasToolSupport(t *testing.T) {
|
||||
p := &Olmo3Parser{}
|
||||
if !p.HasToolSupport() {
|
||||
t.Error("expected HasToolSupport to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3Parser_HasThinkingSupport(t *testing.T) {
|
||||
p := &Olmo3Parser{}
|
||||
if p.HasThinkingSupport() {
|
||||
t.Error("expected HasThinkingSupport to return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOlmo3FunctionCalls(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []api.ToolCall
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple call",
|
||||
input: `get_weather(location="SF")`,
|
||||
expected: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple args",
|
||||
input: `send_email(to="user@example.com", subject="Hello", body="Test message")`,
|
||||
expected: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "send_email",
|
||||
Arguments: map[string]any{
|
||||
"to": "user@example.com",
|
||||
"subject": "Hello",
|
||||
"body": "Test message",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple calls with newlines",
|
||||
input: `get_weather(location="SF")
|
||||
get_time(timezone="PST")`,
|
||||
expected: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: map[string]any{"timezone": "PST"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
input: " \n ",
|
||||
expected: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
calls, err := parseOlmo3FunctionCalls(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(calls, tt.expected); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOlmo3Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected any
|
||||
}{
|
||||
{"string double quotes", `"hello"`, "hello"},
|
||||
{"string single quotes", `'hello'`, "hello"},
|
||||
{"integer", "42", int64(42)},
|
||||
{"negative integer", "-10", int64(-10)},
|
||||
{"float", "3.14", 3.14},
|
||||
{"boolean true", "true", true},
|
||||
{"boolean True", "True", true},
|
||||
{"boolean false", "false", false},
|
||||
{"null", "null", nil},
|
||||
{"None", "None", nil},
|
||||
{"empty array", "[]", []any{}},
|
||||
{"array with strings", `["a", "b"]`, []any{"a", "b"}},
|
||||
{"array with numbers", "[1, 2, 3]", []any{int64(1), int64(2), int64(3)}},
|
||||
{"empty object", "{}", map[string]any{}},
|
||||
{"simple object", `{"name": "John"}`, map[string]any{"name": "John"}},
|
||||
{"object with number", `{"age": 30}`, map[string]any{"age": int64(30)}},
|
||||
{"object with multiple keys", `{"a": 1, "b": 2}`, map[string]any{"a": int64(1), "b": int64(2)}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := parseOlmo3Value(tt.input)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(result, tt.expected); diff != "" {
|
||||
t.Errorf("value mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -58,6 +58,8 @@ func ParserForName(name string) Parser {
|
|||
return harmony.NewHarmonyMessageHandler()
|
||||
case "cogito":
|
||||
return &CogitoParser{}
|
||||
case "olmo3":
|
||||
return &Olmo3Parser{}
|
||||
case "olmo3-think":
|
||||
return &Olmo3ThinkParser{}
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,147 @@
|
|||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
|
||||
olmo3NoFunctionsMessage = "You do not currently have access to any functions. "
|
||||
olmo3WithFunctionsMessage = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
|
||||
)
|
||||
|
||||
type Olmo3Renderer struct{}
|
||||
|
||||
func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
var systemMessage *api.Message
|
||||
filteredMessages := make([]api.Message, 0, len(messages))
|
||||
for i, message := range messages {
|
||||
if message.Role == "system" {
|
||||
if systemMessage == nil {
|
||||
systemMessage = &messages[i]
|
||||
}
|
||||
continue
|
||||
}
|
||||
filteredMessages = append(filteredMessages, message)
|
||||
}
|
||||
|
||||
// Render system message
|
||||
if systemMessage != nil {
|
||||
// Custom system message - single newline after "system"
|
||||
sb.WriteString("<|im_start|>system\n")
|
||||
sb.WriteString(systemMessage.Content)
|
||||
|
||||
if len(tools) > 0 {
|
||||
functionsJSON, err := marshalWithSpaces(tools)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString("<functions>")
|
||||
sb.WriteString(string(functionsJSON))
|
||||
sb.WriteString("</functions>")
|
||||
}
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
} else {
|
||||
// Default system message - single newline after "system"
|
||||
sb.WriteString("<|im_start|>system\n")
|
||||
sb.WriteString(olmo3DefaultSystemMessage)
|
||||
|
||||
if len(tools) > 0 {
|
||||
functionsJSON, err := marshalWithSpaces(tools)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString(olmo3WithFunctionsMessage)
|
||||
sb.WriteString("<functions>")
|
||||
sb.WriteString(string(functionsJSON))
|
||||
sb.WriteString("</functions>")
|
||||
} else {
|
||||
sb.WriteString(olmo3NoFunctionsMessage)
|
||||
sb.WriteString("<functions></functions>")
|
||||
}
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
|
||||
for i, message := range filteredMessages {
|
||||
lastMessage := i == len(filteredMessages)-1
|
||||
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|im_start|>user\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
|
||||
case "assistant":
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
|
||||
if message.Content != "" {
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
|
||||
if len(message.ToolCalls) > 0 {
|
||||
sb.WriteString("<function_calls>")
|
||||
for j, tc := range message.ToolCalls {
|
||||
// Format as function_name(arg1="value1", arg2="value2")
|
||||
sb.WriteString(tc.Function.Name)
|
||||
sb.WriteString("(")
|
||||
|
||||
// Get sorted keys for deterministic output
|
||||
keys := make([]string, 0, len(tc.Function.Arguments))
|
||||
for k := range tc.Function.Arguments {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for k, key := range keys {
|
||||
if k > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
value, err := json.Marshal(tc.Function.Arguments[key])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("%s=%s", key, string(value)))
|
||||
}
|
||||
sb.WriteString(")")
|
||||
|
||||
if j < len(message.ToolCalls)-1 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
sb.WriteString("</function_calls>")
|
||||
}
|
||||
|
||||
// Add end tag unless it's the last message with content only (prefill)
|
||||
if !lastMessage || len(message.ToolCalls) > 0 {
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
|
||||
case "tool":
|
||||
sb.WriteString("<|im_start|>environment\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Add generation prompt if needed
|
||||
needsGenerationPrompt := true
|
||||
if len(filteredMessages) > 0 {
|
||||
lastMsg := filteredMessages[len(filteredMessages)-1]
|
||||
if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" {
|
||||
needsGenerationPrompt = false
|
||||
}
|
||||
}
|
||||
|
||||
if needsGenerationPrompt {
|
||||
sb.WriteString("<|im_start|>assistant\n\n")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
|
@ -0,0 +1,290 @@
|
|||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestOlmo3Renderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msgs []api.Message
|
||||
tools []api.Tool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic without system - adds default system",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "with system message no tools",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful assistant.<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "with system message and tools",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What is the weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"What is the weather?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "default system with tools - includes function instruction",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "What is the weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. " +
|
||||
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
|
||||
`<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"What is the weather?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "assistant with tool calls - function call syntax",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What is the weather in SF?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check the weather.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "San Francisco",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"What is the weather in SF?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
`Let me check the weather.<function_calls>get_weather(location="San Francisco")</function_calls><|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "multi-turn conversation",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there!"},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful assistant.<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"Hi there!<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"How are you?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "parallel tool calls - newline separated",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Get weather in SF and NYC"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "New York"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
||||
{Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. " +
|
||||
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
|
||||
`<functions>[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Get weather in SF and NYC<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
`<function_calls>get_weather(location="San Francisco")` + "\n" +
|
||||
`get_weather(location="New York")</function_calls><|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 55}<|im_end|>` + "\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple arguments",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Book a flight"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "book_flight",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"from": {Type: api.PropertyType{"string"}},
|
||||
"to": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. " +
|
||||
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
|
||||
`<functions>[{"type": "function", "function": {"name": "book_flight", "parameters": {"type": "object", "properties": {"from": {"type": "string"}, "to": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Book a flight<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
`<function_calls>book_flight(from="SFO", to="NYC")</function_calls><|im_end|>` + "\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "assistant prefill - no generation prompt",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"Hi there!",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rendered, err := (&Olmo3Renderer{}).Render(tt.msgs, tt.tools, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -59,6 +59,9 @@ func rendererForName(name string) Renderer {
|
|||
case "cogito":
|
||||
renderer := &CogitoRenderer{isThinking: true}
|
||||
return renderer
|
||||
case "olmo3":
|
||||
renderer := &Olmo3Renderer{}
|
||||
return renderer
|
||||
case "olmo3-think":
|
||||
renderer := &Olmo3ThinkRenderer{}
|
||||
return renderer
|
||||
|
|
|
|||
Loading…
Reference in New Issue