This commit is contained in:
Kelvin Schoofs 2025-12-15 16:59:17 -06:00 committed by GitHub
commit ae1b138a20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 74 additions and 19 deletions

View File

@ -688,17 +688,19 @@ func SchemaToGrammar(schema []byte) []byte {
cStr := C.CString(string(schema)) cStr := C.CString(string(schema))
defer C.free(unsafe.Pointer(cStr)) defer C.free(unsafe.Pointer(cStr))
// Allocate buffer for grammar based on schema length but with upper bound
maxLen := max(32768, min(1024*1024, len(schema)*4))
buf := make([]byte, maxLen)
// Call C function to convert schema to grammar // Call C function to convert schema to grammar
n := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen)) result := C.schema_to_grammar(cStr)
if n == 0 { defer C.free(unsafe.Pointer(result))
// preserve nil
if int(result.length) == 0 {
// Preserve nil when we get an empty string back
return nil return nil
} }
return buf[:n]
// Copy the string to a Go version and free the char*
str := C.GoStringN(result.grammar, C.int(result.length))
C.free(unsafe.Pointer(result.grammar))
return []byte(str)
} }
type TokenData struct { type TokenData struct {

View File

@ -103,3 +103,54 @@ func TestSchemaToGrammar(t *testing.T) {
}) })
} }
} }
// growingJSONSchema is the JSON schema that TestGrowingSchema passes to
// SchemaToGrammar. It describes an object with five properties: two objects
// that allow additional properties (execSchema, combineSchema) and three
// strings (execInstructions, combineInstructions, question).
const growingJSONSchema = `{
"type": "object",
"properties": {
"execSchema": { "type": "object", "additionalProperties": true },
"execInstructions": { "type": "string" },
"combineSchema": { "type": "object", "additionalProperties": true },
"combineInstructions": { "type": "string" },
"question": { "type": "string" }
}
}`
// growingJSONSchemaGrammar is the grammar text expected from SchemaToGrammar
// for growingJSONSchema, one rule per line. The order of the rule lines has
// been observed to differ between platforms, so it should not be relied on.
const growingJSONSchemaGrammar = `root ::= "{" space (execSchema-kv execSchema-rest | execInstructions-kv execInstructions-rest | combineSchema-kv combineSchema-rest | combineInstructions-kv combineInstructions-rest | question-kv )? "}" space
execSchema-rest ::= ( "," space execInstructions-kv )? execInstructions-rest
execInstructions-rest ::= ( "," space combineSchema-kv )? combineSchema-rest
combineInstructions-rest ::= ( "," space question-kv )?
question-kv ::= "\"question\"" space ":" space string
combineInstructions-kv ::= "\"combineInstructions\"" space ":" space string
combineSchema-kv ::= "\"combineSchema\"" space ":" space combineSchema
execInstructions-kv ::= "\"execInstructions\"" space ":" space string
combineSchema-rest ::= ( "," space combineInstructions-kv )? combineInstructions-rest
space ::= | " " | "\n"{1,2} [ \t]{0,20}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
boolean ::= ("true" | "false") space
combineSchema ::= object
string ::= "\"" char* "\"" space
value ::= object | array | string | number | boolean | null
integral-part ::= [0] | [1-9] [0-9]{0,15}
null ::= "null" space
execSchema-kv ::= "\"execSchema\"" space ":" space execSchema
array ::= "[" space ( value ("," space value)* )? "]" space
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
execSchema ::= object`
// TestGrowingSchema converts growingJSONSchema with SchemaToGrammar and checks
// that the result contains exactly the rule lines of growingJSONSchemaGrammar.
//
// The rule lines are compared as a multiset rather than as one string because
// their order differed on the macos-latest GitHub runner. Comparing only the
// total length (as was done before) tolerates reordering too, but it would
// also accept any wrong grammar that happens to have the same length.
func TestGrowingSchema(t *testing.T) {
	g := SchemaToGrammar([]byte(growingJSONSchema))
	if g == nil {
		t.Fatal("failed to convert JSON schema to grammar")
	}

	// balance[line] ends at 0 when a rule line occurs equally often in the
	// expected and the produced grammar: +1 per expected line, -1 per actual.
	balance := make(map[string]int)
	for _, line := range strings.Split(strings.TrimSpace(growingJSONSchemaGrammar), "\n") {
		balance[line]++
	}
	for _, line := range strings.Split(strings.TrimSpace(string(g)), "\n") {
		balance[line]--
	}

	for line, n := range balance {
		switch {
		case n > 0:
			t.Errorf("missing grammar rule (%d occurrence(s)): %q", n, line)
		case n < 0:
			t.Errorf("unexpected grammar rule (%d occurrence(s)): %q", -n, line)
		}
	}

	// Dump both grammars once so a failure can be diagnosed from the log alone.
	if t.Failed() {
		t.Logf("got:\n%s\nwant:\n%s", g, growingJSONSchemaGrammar)
	}
}

View File

@ -8,6 +8,7 @@
#include "llama-grammar.h"
#include "nlohmann/json.hpp"

#include <cstdlib>
#include <cstring>
struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) { struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
try { try {
common_params_sampling sparams; common_params_sampling sparams;
@ -46,25 +47,22 @@ llama_token common_sampler_csample(struct common_sampler *sampler, struct llama_
return common_sampler_sample(sampler, ctx, idx); return common_sampler_sample(sampler, ctx, idx);
} }
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len) struct parsed_grammar * schema_to_grammar(const char *json_schema)
{ {
struct parsed_grammar *result = new parsed_grammar();
try try
{ {
nlohmann::ordered_json schema = nlohmann::ordered_json::parse(json_schema); nlohmann::ordered_json schema = nlohmann::ordered_json::parse(json_schema);
std::string grammar_str = json_schema_to_grammar(schema); std::string grammar_str = json_schema_to_grammar(schema);
size_t len = grammar_str.length(); result->length = grammar_str.length();
if (len >= max_len) result->grammar = new char[result->length];
{ std::strcpy(result->grammar, grammar_str.c_str());
len = max_len - 1;
}
strncpy(grammar, grammar_str.c_str(), len);
return len;
} }
catch (const std::exception &e) catch (const std::exception &e)
{ {
strncpy(grammar, "", max_len - 1); result->length = 0;
return 0;
} }
return result;
} }
struct llama_vocab * llama_load_vocab_from_file(const char * fname) { struct llama_vocab * llama_load_vocab_from_file(const char * fname) {

View File

@ -30,7 +30,11 @@ extern "C"
void common_sampler_caccept(struct common_sampler *sampler, llama_token id, bool apply_grammar); void common_sampler_caccept(struct common_sampler *sampler, llama_token id, bool apply_grammar);
llama_token common_sampler_csample(struct common_sampler *sampler, struct llama_context *ctx, int idx); llama_token common_sampler_csample(struct common_sampler *sampler, struct llama_context *ctx, int idx);
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len); struct parsed_grammar {
char *grammar;
size_t length;
};
struct parsed_grammar *schema_to_grammar(const char *json_schema);
struct llama_grammar *grammar_init(char* grammar, uint32_t* tokens, size_t n_tokens, const char** pieces, uint32_t* eog_tokens, size_t n_eog_tokens); struct llama_grammar *grammar_init(char* grammar, uint32_t* tokens, size_t n_tokens, const char** pieces, uint32_t* eog_tokens, size_t n_eog_tokens);