whisper : add support for --carry-initial-prompt (#3395)

* Add support for --carry-initial-prompt

* PR fixes for ruby and go

* Refactoring for readability

* WIP 1

* WIP 2

* PR fixes

* More PR fixes

* PR fix

* Further simplification

* d'oh

* One more logic fix

* Update src/whisper.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Truncate prompt_past0 upon initialization

* Slight simplification

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Andreas Lubbe 2025-10-10 18:51:15 +02:00 committed by GitHub
parent a0ca50f3b9
commit 85871a9469
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 257 additions and 162 deletions

View File

@ -47,6 +47,7 @@ func (p *Params) SetPrintTimestamps(v bool) {
p.print_timestamps = toBool(v)
}
// Set language id
func (p *Params) SetLanguage(lang int) error {
if lang == -1 {
@ -146,6 +147,10 @@ func (p *Params) SetInitialPrompt(prompt string) {
p.initial_prompt = C.CString(prompt)
}
// Set carry_initial_prompt flag — when true, the initial prompt is prepended
// to every decode window instead of only the first one
func (p *Params) SetCarryInitialPrompt(v bool) {
p.carry_initial_prompt = toBool(v)
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
@ -199,6 +204,9 @@ func (p *Params) String() string {
if p.token_timestamps {
str += " token_timestamps"
}
if p.carry_initial_prompt {
str += " carry_initial_prompt"
}
return str + ">"
}

View File

@ -157,6 +157,8 @@ public class WhisperFullParams extends Structure {
/** Tokens to provide to the whisper decoder as an initial prompt.
* These are prepended to any existing text context from a previous call. */
public String initial_prompt;
/** Always prepend initial_prompt for every decode chunk. */
public CBool carry_initial_prompt;
/** Prompt tokens. (int*) */
public Pointer prompt_tokens;
@ -337,7 +339,7 @@ public class WhisperFullParams extends Structure {
"print_progress", "print_realtime", "print_timestamps",
"token_timestamps", "thold_pt", "thold_ptsum", "max_len",
"split_on_word", "max_tokens", "debug_mode", "audio_ctx",
"tdrz_enable", "suppress_regex", "initial_prompt",
"tdrz_enable", "suppress_regex", "initial_prompt", "carry_initial_prompt",
"prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_nst", "temperature",
"max_initial_ts", "length_penalty", "temperature_inc",

View File

@ -26,7 +26,7 @@
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 36
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 37
extern VALUE cParams;
extern VALUE cVADParams;
@ -46,6 +46,7 @@ static ID id_print_special;
static ID id_print_progress;
static ID id_print_realtime;
static ID id_print_timestamps;
static ID id_carry_initial_prompt;
static ID id_suppress_blank;
static ID id_suppress_nst;
static ID id_token_timestamps;
@ -455,6 +456,26 @@ ruby_whisper_params_get_print_timestamps(VALUE self)
{
BOOL_PARAMS_GETTER(self, print_timestamps)
}
/*
 * call-seq:
 *   carry_initial_prompt -> true or false
 *
 * Returns whether the initial prompt is prepended to every decode window
 * (expands via BOOL_PARAMS_GETTER to read the underlying whisper_full_params field).
 */
static VALUE
ruby_whisper_params_get_carry_initial_prompt(VALUE self)
{
BOOL_PARAMS_GETTER(self, carry_initial_prompt)
}
/*
 * call-seq:
 *   carry_initial_prompt = bool -> bool
 *
 * Sets whether the initial prompt should be prepended to every decode window
 * (expands via BOOL_PARAMS_SETTER to write the underlying whisper_full_params field).
 */
static VALUE
ruby_whisper_params_set_carry_initial_prompt(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, carry_initial_prompt, value)
}
/*
* call-seq:
* suppress_blank = force_suppress -> force_suppress
@ -1168,6 +1189,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
SET_PARAM_IF_SAME(max_len)
SET_PARAM_IF_SAME(split_on_word)
SET_PARAM_IF_SAME(initial_prompt)
SET_PARAM_IF_SAME(carry_initial_prompt)
SET_PARAM_IF_SAME(offset)
SET_PARAM_IF_SAME(duration)
SET_PARAM_IF_SAME(max_text_tokens)
@ -1303,28 +1325,29 @@ init_ruby_whisper_params(VALUE *mWhisper)
DEFINE_PARAM(max_len, 11)
DEFINE_PARAM(split_on_word, 12)
DEFINE_PARAM(initial_prompt, 13)
DEFINE_PARAM(diarize, 14)
DEFINE_PARAM(offset, 15)
DEFINE_PARAM(duration, 16)
DEFINE_PARAM(max_text_tokens, 17)
DEFINE_PARAM(temperature, 18)
DEFINE_PARAM(max_initial_ts, 19)
DEFINE_PARAM(length_penalty, 20)
DEFINE_PARAM(temperature_inc, 21)
DEFINE_PARAM(entropy_thold, 22)
DEFINE_PARAM(logprob_thold, 23)
DEFINE_PARAM(no_speech_thold, 24)
DEFINE_PARAM(new_segment_callback, 25)
DEFINE_PARAM(new_segment_callback_user_data, 26)
DEFINE_PARAM(progress_callback, 27)
DEFINE_PARAM(progress_callback_user_data, 28)
DEFINE_PARAM(encoder_begin_callback, 29)
DEFINE_PARAM(encoder_begin_callback_user_data, 30)
DEFINE_PARAM(abort_callback, 31)
DEFINE_PARAM(abort_callback_user_data, 32)
DEFINE_PARAM(vad, 33)
DEFINE_PARAM(vad_model_path, 34)
DEFINE_PARAM(vad_params, 35)
DEFINE_PARAM(carry_initial_prompt, 14)
DEFINE_PARAM(diarize, 15)
DEFINE_PARAM(offset, 16)
DEFINE_PARAM(duration, 17)
DEFINE_PARAM(max_text_tokens, 18)
DEFINE_PARAM(temperature, 19)
DEFINE_PARAM(max_initial_ts, 20)
DEFINE_PARAM(length_penalty, 21)
DEFINE_PARAM(temperature_inc, 22)
DEFINE_PARAM(entropy_thold, 23)
DEFINE_PARAM(logprob_thold, 24)
DEFINE_PARAM(no_speech_thold, 25)
DEFINE_PARAM(new_segment_callback, 26)
DEFINE_PARAM(new_segment_callback_user_data, 27)
DEFINE_PARAM(progress_callback, 28)
DEFINE_PARAM(progress_callback_user_data, 29)
DEFINE_PARAM(encoder_begin_callback, 30)
DEFINE_PARAM(encoder_begin_callback_user_data, 31)
DEFINE_PARAM(abort_callback, 32)
DEFINE_PARAM(abort_callback_user_data, 33)
DEFINE_PARAM(vad, 34)
DEFINE_PARAM(vad_model_path, 35)
DEFINE_PARAM(vad_params, 36)
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);

View File

@ -138,6 +138,7 @@ module Whisper
?max_len: Integer,
?split_on_word: boolish,
?initial_prompt: string | nil,
?carry_initial_prompt: boolish,
?diarize: boolish,
?offset: Integer,
?duration: Integer,
@ -236,6 +237,7 @@ module Whisper
def split_on_word: () -> (true | false)
def initial_prompt=: (_ToS) -> _ToS
def carry_initial_prompt=: (boolish) -> boolish
# Tokens to provide to the whisper decoder as initial prompt
# these are prepended to any existing text context from a previous call
@ -243,6 +245,7 @@ module Whisper
# Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
#
def initial_prompt: () -> (String | nil)
def carry_initial_prompt: () -> (true | false)
def diarize=: (boolish) -> boolish

View File

@ -16,6 +16,7 @@ class TestParams < TestBase
:max_len,
:split_on_word,
:initial_prompt,
:carry_initial_prompt,
:diarize,
:offset,
:duration,
@ -119,6 +120,13 @@ class TestParams < TestBase
assert !@params.print_timestamps
end
# Verify the carry_initial_prompt boolean round-trips through setter and getter
def test_carry_initial_prompt
@params.carry_initial_prompt = true
assert @params.carry_initial_prompt
@params.carry_initial_prompt = false
assert !@params.carry_initial_prompt
end
def test_suppress_blank
@params.suppress_blank = true
assert @params.suppress_blank

View File

@ -5,6 +5,7 @@
#include "grammar-parser.h"
#include <cmath>
#include <algorithm>
#include <fstream>
#include <cstdio>
#include <string>
@ -77,6 +78,7 @@ struct whisper_params {
bool use_gpu = true;
bool flash_attn = true;
bool suppress_nst = false;
bool carry_initial_prompt = false;
std::string language = "en";
std::string prompt;
@ -175,17 +177,18 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
else if (arg == "-fp" || arg == "--font-path") { params.font_path = ARGV_NEXT; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
else if (arg == "-ojf" || arg == "--output-json-full") { params.output_jsn_full = params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(ARGV_NEXT); }
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if ( arg == "--print-confidence"){ params.print_confidence= true; }
else if ( arg == "--print-confidence") { params.print_confidence= true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); }
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
else if ( arg == "--prompt") { params.prompt = ARGV_NEXT; }
else if ( arg == "--carry-initial-prompt") { params.carry_initial_prompt = true; }
else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = ARGV_NEXT; }
@ -266,6 +269,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
fprintf(stderr, " --carry-initial-prompt [%-7s] always prepend initial prompt\n", params.carry_initial_prompt ? "true" : "false");
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input audio file path\n", "");
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
@ -387,7 +391,11 @@ static void whisper_print_segment_callback(struct whisper_context * ctx, struct
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
const int n_colors = (int) k_colors.size();
int raw_col = (int) (std::pow(p, 3)*float(n_colors));
if (raw_col < 0) raw_col = 0;
if (raw_col > n_colors - 1) raw_col = n_colors - 1;
const int col = raw_col;
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
@ -1179,6 +1187,7 @@ int main(int argc, char ** argv) {
wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str();
wparams.initial_prompt = params.prompt.c_str();
wparams.carry_initial_prompt = params.carry_initial_prompt;
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;

View File

@ -525,6 +525,7 @@ extern "C" {
// use whisper_tokenize() to convert text to tokens
// maximum of whisper_n_text_ctx()/2 tokens are used (typically 224)
const char * initial_prompt;
bool carry_initial_prompt; // if true, always prepend initial_prompt to every decode window (may reduce conditioning on previous text)
const whisper_token * prompt_tokens;
int prompt_n_tokens;

View File

@ -140,6 +140,10 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
} while (0)
#define WHISPER_MAX_DECODERS 8
// temperature below which we condition on past text history
static constexpr float WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF = 0.5f;
#define WHISPER_MAX_NODES 4096
static std::string format(const char * fmt, ...) {
@ -882,7 +886,10 @@ struct whisper_state {
std::vector<float> logits;
std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;
// prompt history split into static prefix (prompt_past0) and dynamic rolling context (prompt_past1)
std::vector<whisper_token> prompt_past0; // static carried initial prompt (if enabled)
std::vector<whisper_token> prompt_past1; // dynamic context from decoded output
int lang_id = 0; // english by default
@ -5923,6 +5930,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/* suppress_regex =*/ nullptr,
/*.initial_prompt =*/ nullptr,
/*.carry_initial_prompt =*/ false,
/*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
@ -6880,17 +6888,22 @@ int whisper_full_with_state(
decoder.rng = std::mt19937(j);
}
// the accumulated text context so far
auto & prompt_past = state->prompt_past;
// the accumulated text context split into static (prompt_past0) and dynamic (prompt_past1)
auto & prompt_past0 = state->prompt_past0;
auto & prompt_past1 = state->prompt_past1;
if (params.no_context) {
prompt_past.clear();
prompt_past0.clear();
prompt_past1.clear();
}
// calculate the maximum context budget for prompt history
const int max_prompt_ctx = std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2);
// prepare prompt
{
std::vector<whisper_token> prompt_tokens;
// initial prompt
// tokenize the initial prompt
if (!params.prompt_tokens && params.initial_prompt) {
prompt_tokens.resize(1024);
int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
@ -6902,14 +6915,25 @@ int whisper_full_with_state(
params.prompt_tokens = prompt_tokens.data();
params.prompt_n_tokens = prompt_tokens.size();
}
// prepend the prompt tokens to the prompt_past
if (params.prompt_tokens && params.prompt_n_tokens > 0) {
// parse tokens from the pointer
for (int i = 0; i < params.prompt_n_tokens; i++) {
prompt_past.push_back(params.prompt_tokens[i]);
if (params.carry_initial_prompt) {
if (prompt_past0.empty()) {
const int max_tokens = std::max(1, max_prompt_ctx - 1);
if (params.prompt_n_tokens > max_tokens) {
WHISPER_LOG_WARN("%s: initial prompt is too long (%d tokens), will use only the last %d tokens\n",
__func__, params.prompt_n_tokens, max_tokens);
}
const int n_tokens = std::min(params.prompt_n_tokens, max_tokens);
prompt_past0.assign(params.prompt_tokens + (params.prompt_n_tokens - n_tokens), params.prompt_tokens + params.prompt_n_tokens);
}
} else {
for (int i = 0; i < params.prompt_n_tokens; ++i) {
prompt_past1.push_back(params.prompt_tokens[i]);
}
std::rotate(prompt_past1.begin(), prompt_past1.end() - params.prompt_n_tokens, prompt_past1.end());
}
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
}
}
@ -6995,7 +7019,8 @@ int whisper_full_with_state(
// if there is a very short audio segment left to process, we remove any past prompt since it tends
// to confuse the decoder and often make it repeat or hallucinate stuff
if (seek > seek_start && seek + 500 >= seek_end) {
prompt_past.clear();
prompt_past0.clear();
prompt_past1.clear();
}
int best_decoder_id = 0;
@ -7056,12 +7081,25 @@ int whisper_full_with_state(
{
prompt.clear();
// if we have already generated some text, use it as a prompt to condition the next generation
if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
if (params.n_max_text_ctx > 0 && t_cur < WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF) {
const bool can_take0 = params.carry_initial_prompt && !prompt_past0.empty();
const bool can_take1 = !prompt_past1.empty();
prompt = { whisper_token_prev(ctx) };
prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
if (max_prompt_ctx > 0 && (can_take0 || can_take1)) {
// Always start with previous token marker to connect continuity
prompt.push_back(whisper_token_prev(ctx));
// Take static tokens (initial prompt) first
int n_take0 = 0;
if (can_take0) {
n_take0 = prompt_past0.size();
prompt.insert(prompt.end(), prompt_past0.end() - n_take0, prompt_past0.end());
}
// Fill remaining budget with dynamic tokens (rolling context)
const int n_take1 = std::min<int>(max_prompt_ctx - n_take0 - 1, prompt_past1.size());
prompt.insert(prompt.end(), prompt_past1.end() - n_take1, prompt_past1.end());
}
}
// init new transcription with sot, language (opt) and task tokens
@ -7543,14 +7581,17 @@ int whisper_full_with_state(
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
// update prompt_past
prompt_past.clear();
if (prompt.front() == whisper_token_prev(ctx)) {
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
// update prompt_past1
prompt_past1.clear();
if (!params.carry_initial_prompt && !prompt.empty() && prompt.front() == whisper_token_prev(ctx)) {
prompt_past1.insert(prompt_past1.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
}
for (int i = 0; i < result_len && !is_no_speech; ++i) {
prompt_past.push_back(tokens_cur[i].id);
// Add newly decoded tokens to the rolling context
if (!is_no_speech) {
for (int i = 0; i < result_len; ++i) {
prompt_past1.push_back(tokens_cur[i].id);
}
}
if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {