whisper : add support for --carry-initial-prompt (#3395)
* Add support for --carry-initial-prompt * PR fixes for ruby and go * Refactoring for readability * WIP 1 * WIP 2 * PR fixes * More PR fixes * PR fix * Further simplification * d'oh * One more logic fix * Update src/whisper.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Truncate prompt_past0 upon initialization * Slight simplification --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
a0ca50f3b9
commit
85871a9469
|
|
@ -47,6 +47,7 @@ func (p *Params) SetPrintTimestamps(v bool) {
|
|||
p.print_timestamps = toBool(v)
|
||||
}
|
||||
|
||||
|
||||
// Set language id
|
||||
func (p *Params) SetLanguage(lang int) error {
|
||||
if lang == -1 {
|
||||
|
|
@ -146,6 +147,10 @@ func (p *Params) SetInitialPrompt(prompt string) {
|
|||
p.initial_prompt = C.CString(prompt)
|
||||
}
|
||||
|
||||
func (p *Params) SetCarryInitialPrompt(v bool) {
|
||||
p.carry_initial_prompt = toBool(v)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
|
|
@ -199,6 +204,9 @@ func (p *Params) String() string {
|
|||
if p.token_timestamps {
|
||||
str += " token_timestamps"
|
||||
}
|
||||
if p.carry_initial_prompt {
|
||||
str += " carry_initial_prompt"
|
||||
}
|
||||
|
||||
return str + ">"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -157,6 +157,8 @@ public class WhisperFullParams extends Structure {
|
|||
/** Tokens to provide to the whisper decoder as an initial prompt.
|
||||
* These are prepended to any existing text context from a previous call. */
|
||||
public String initial_prompt;
|
||||
/** Always prepend initial_prompt for every decode chunk. */
|
||||
public CBool carry_initial_prompt;
|
||||
|
||||
/** Prompt tokens. (int*) */
|
||||
public Pointer prompt_tokens;
|
||||
|
|
@ -337,7 +339,7 @@ public class WhisperFullParams extends Structure {
|
|||
"print_progress", "print_realtime", "print_timestamps",
|
||||
"token_timestamps", "thold_pt", "thold_ptsum", "max_len",
|
||||
"split_on_word", "max_tokens", "debug_mode", "audio_ctx",
|
||||
"tdrz_enable", "suppress_regex", "initial_prompt",
|
||||
"tdrz_enable", "suppress_regex", "initial_prompt", "carry_initial_prompt",
|
||||
"prompt_tokens", "prompt_n_tokens", "language", "detect_language",
|
||||
"suppress_blank", "suppress_nst", "temperature",
|
||||
"max_initial_ts", "length_penalty", "temperature_inc",
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
|
||||
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
|
||||
|
||||
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 36
|
||||
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 37
|
||||
|
||||
extern VALUE cParams;
|
||||
extern VALUE cVADParams;
|
||||
|
|
@ -46,6 +46,7 @@ static ID id_print_special;
|
|||
static ID id_print_progress;
|
||||
static ID id_print_realtime;
|
||||
static ID id_print_timestamps;
|
||||
static ID id_carry_initial_prompt;
|
||||
static ID id_suppress_blank;
|
||||
static ID id_suppress_nst;
|
||||
static ID id_token_timestamps;
|
||||
|
|
@ -455,6 +456,26 @@ ruby_whisper_params_get_print_timestamps(VALUE self)
|
|||
{
|
||||
BOOL_PARAMS_GETTER(self, print_timestamps)
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* carry_initial_prompt -> true or false
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_params_get_carry_initial_prompt(VALUE self)
|
||||
{
|
||||
BOOL_PARAMS_GETTER(self, carry_initial_prompt)
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* carry_initial_prompt = bool -> bool
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_params_set_carry_initial_prompt(VALUE self, VALUE value)
|
||||
{
|
||||
BOOL_PARAMS_SETTER(self, carry_initial_prompt, value)
|
||||
}
|
||||
/*
|
||||
* call-seq:
|
||||
* suppress_blank = force_suppress -> force_suppress
|
||||
|
|
@ -1168,6 +1189,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
|
|||
SET_PARAM_IF_SAME(max_len)
|
||||
SET_PARAM_IF_SAME(split_on_word)
|
||||
SET_PARAM_IF_SAME(initial_prompt)
|
||||
SET_PARAM_IF_SAME(carry_initial_prompt)
|
||||
SET_PARAM_IF_SAME(offset)
|
||||
SET_PARAM_IF_SAME(duration)
|
||||
SET_PARAM_IF_SAME(max_text_tokens)
|
||||
|
|
@ -1303,28 +1325,29 @@ init_ruby_whisper_params(VALUE *mWhisper)
|
|||
DEFINE_PARAM(max_len, 11)
|
||||
DEFINE_PARAM(split_on_word, 12)
|
||||
DEFINE_PARAM(initial_prompt, 13)
|
||||
DEFINE_PARAM(diarize, 14)
|
||||
DEFINE_PARAM(offset, 15)
|
||||
DEFINE_PARAM(duration, 16)
|
||||
DEFINE_PARAM(max_text_tokens, 17)
|
||||
DEFINE_PARAM(temperature, 18)
|
||||
DEFINE_PARAM(max_initial_ts, 19)
|
||||
DEFINE_PARAM(length_penalty, 20)
|
||||
DEFINE_PARAM(temperature_inc, 21)
|
||||
DEFINE_PARAM(entropy_thold, 22)
|
||||
DEFINE_PARAM(logprob_thold, 23)
|
||||
DEFINE_PARAM(no_speech_thold, 24)
|
||||
DEFINE_PARAM(new_segment_callback, 25)
|
||||
DEFINE_PARAM(new_segment_callback_user_data, 26)
|
||||
DEFINE_PARAM(progress_callback, 27)
|
||||
DEFINE_PARAM(progress_callback_user_data, 28)
|
||||
DEFINE_PARAM(encoder_begin_callback, 29)
|
||||
DEFINE_PARAM(encoder_begin_callback_user_data, 30)
|
||||
DEFINE_PARAM(abort_callback, 31)
|
||||
DEFINE_PARAM(abort_callback_user_data, 32)
|
||||
DEFINE_PARAM(vad, 33)
|
||||
DEFINE_PARAM(vad_model_path, 34)
|
||||
DEFINE_PARAM(vad_params, 35)
|
||||
DEFINE_PARAM(carry_initial_prompt, 14)
|
||||
DEFINE_PARAM(diarize, 15)
|
||||
DEFINE_PARAM(offset, 16)
|
||||
DEFINE_PARAM(duration, 17)
|
||||
DEFINE_PARAM(max_text_tokens, 18)
|
||||
DEFINE_PARAM(temperature, 19)
|
||||
DEFINE_PARAM(max_initial_ts, 20)
|
||||
DEFINE_PARAM(length_penalty, 21)
|
||||
DEFINE_PARAM(temperature_inc, 22)
|
||||
DEFINE_PARAM(entropy_thold, 23)
|
||||
DEFINE_PARAM(logprob_thold, 24)
|
||||
DEFINE_PARAM(no_speech_thold, 25)
|
||||
DEFINE_PARAM(new_segment_callback, 26)
|
||||
DEFINE_PARAM(new_segment_callback_user_data, 27)
|
||||
DEFINE_PARAM(progress_callback, 28)
|
||||
DEFINE_PARAM(progress_callback_user_data, 29)
|
||||
DEFINE_PARAM(encoder_begin_callback, 30)
|
||||
DEFINE_PARAM(encoder_begin_callback_user_data, 31)
|
||||
DEFINE_PARAM(abort_callback, 32)
|
||||
DEFINE_PARAM(abort_callback_user_data, 33)
|
||||
DEFINE_PARAM(vad, 34)
|
||||
DEFINE_PARAM(vad_model_path, 35)
|
||||
DEFINE_PARAM(vad_params, 36)
|
||||
|
||||
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
|
||||
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
|
||||
|
|
|
|||
|
|
@ -138,6 +138,7 @@ module Whisper
|
|||
?max_len: Integer,
|
||||
?split_on_word: boolish,
|
||||
?initial_prompt: string | nil,
|
||||
?carry_initial_prompt: boolish,
|
||||
?diarize: boolish,
|
||||
?offset: Integer,
|
||||
?duration: Integer,
|
||||
|
|
@ -236,6 +237,7 @@ module Whisper
|
|||
def split_on_word: () -> (true | false)
|
||||
|
||||
def initial_prompt=: (_ToS) -> _ToS
|
||||
def carry_initial_prompt=: (boolish) -> boolish
|
||||
|
||||
# Tokens to provide to the whisper decoder as initial prompt
|
||||
# these are prepended to any existing text context from a previous call
|
||||
|
|
@ -243,6 +245,7 @@ module Whisper
|
|||
# Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
|
||||
#
|
||||
def initial_prompt: () -> (String | nil)
|
||||
def carry_initial_prompt: () -> (true | false)
|
||||
|
||||
def diarize=: (boolish) -> boolish
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ class TestParams < TestBase
|
|||
:max_len,
|
||||
:split_on_word,
|
||||
:initial_prompt,
|
||||
:carry_initial_prompt,
|
||||
:diarize,
|
||||
:offset,
|
||||
:duration,
|
||||
|
|
@ -119,6 +120,13 @@ class TestParams < TestBase
|
|||
assert !@params.print_timestamps
|
||||
end
|
||||
|
||||
def test_carry_initial_prompt
|
||||
@params.carry_initial_prompt = true
|
||||
assert @params.carry_initial_prompt
|
||||
@params.carry_initial_prompt = false
|
||||
assert !@params.carry_initial_prompt
|
||||
end
|
||||
|
||||
def test_suppress_blank
|
||||
@params.suppress_blank = true
|
||||
assert @params.suppress_blank
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include "grammar-parser.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
|
|
@ -77,6 +78,7 @@ struct whisper_params {
|
|||
bool use_gpu = true;
|
||||
bool flash_attn = true;
|
||||
bool suppress_nst = false;
|
||||
bool carry_initial_prompt = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string prompt;
|
||||
|
|
@ -175,17 +177,18 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
|
|||
else if (arg == "-fp" || arg == "--font-path") { params.font_path = ARGV_NEXT; }
|
||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
||||
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
|
||||
else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
|
||||
else if (arg == "-ojf" || arg == "--output-json-full") { params.output_jsn_full = params.output_jsn = true; }
|
||||
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(ARGV_NEXT); }
|
||||
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if ( arg == "--print-confidence"){ params.print_confidence= true; }
|
||||
else if ( arg == "--print-confidence") { params.print_confidence= true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); }
|
||||
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
|
||||
else if ( arg == "--prompt") { params.prompt = ARGV_NEXT; }
|
||||
else if ( arg == "--carry-initial-prompt") { params.carry_initial_prompt = true; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); }
|
||||
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = ARGV_NEXT; }
|
||||
|
|
@ -266,6 +269,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
|
|||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
||||
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
|
||||
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
|
||||
fprintf(stderr, " --carry-initial-prompt [%-7s] always prepend initial prompt\n", params.carry_initial_prompt ? "true" : "false");
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input audio file path\n", "");
|
||||
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
|
||||
|
|
@ -387,7 +391,11 @@ static void whisper_print_segment_callback(struct whisper_context * ctx, struct
|
|||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
const int n_colors = (int) k_colors.size();
|
||||
int raw_col = (int) (std::pow(p, 3)*float(n_colors));
|
||||
if (raw_col < 0) raw_col = 0;
|
||||
if (raw_col > n_colors - 1) raw_col = n_colors - 1;
|
||||
const int col = raw_col;
|
||||
|
||||
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
|
|
@ -1179,6 +1187,7 @@ int main(int argc, char ** argv) {
|
|||
wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str();
|
||||
|
||||
wparams.initial_prompt = params.prompt.c_str();
|
||||
wparams.carry_initial_prompt = params.carry_initial_prompt;
|
||||
|
||||
wparams.greedy.best_of = params.best_of;
|
||||
wparams.beam_search.beam_size = params.beam_size;
|
||||
|
|
|
|||
|
|
@ -525,6 +525,7 @@ extern "C" {
|
|||
// use whisper_tokenize() to convert text to tokens
|
||||
// maximum of whisper_n_text_ctx()/2 tokens are used (typically 224)
|
||||
const char * initial_prompt;
|
||||
bool carry_initial_prompt; // if true, always prepend initial_prompt to every decode window (may reduce conditioning on previous text)
|
||||
const whisper_token * prompt_tokens;
|
||||
int prompt_n_tokens;
|
||||
|
||||
|
|
|
|||
|
|
@ -140,6 +140,10 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
|
|||
} while (0)
|
||||
|
||||
#define WHISPER_MAX_DECODERS 8
|
||||
|
||||
// temperature below which we condition on past text history
|
||||
static constexpr float WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF = 0.5f;
|
||||
|
||||
#define WHISPER_MAX_NODES 4096
|
||||
|
||||
static std::string format(const char * fmt, ...) {
|
||||
|
|
@ -882,7 +886,10 @@ struct whisper_state {
|
|||
std::vector<float> logits;
|
||||
|
||||
std::vector<whisper_segment> result_all;
|
||||
std::vector<whisper_token> prompt_past;
|
||||
|
||||
// prompt history split into static prefix (prompt_past0) and dynamic rolling context (prompt_past1)
|
||||
std::vector<whisper_token> prompt_past0; // static carried initial prompt (if enabled)
|
||||
std::vector<whisper_token> prompt_past1; // dynamic context from decoded output
|
||||
|
||||
int lang_id = 0; // english by default
|
||||
|
||||
|
|
@ -5923,6 +5930,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|||
/* suppress_regex =*/ nullptr,
|
||||
|
||||
/*.initial_prompt =*/ nullptr,
|
||||
/*.carry_initial_prompt =*/ false,
|
||||
/*.prompt_tokens =*/ nullptr,
|
||||
/*.prompt_n_tokens =*/ 0,
|
||||
|
||||
|
|
@ -6880,17 +6888,22 @@ int whisper_full_with_state(
|
|||
decoder.rng = std::mt19937(j);
|
||||
}
|
||||
|
||||
// the accumulated text context so far
|
||||
auto & prompt_past = state->prompt_past;
|
||||
// the accumulated text context split into static (prompt_past0) and dynamic (prompt_past1)
|
||||
auto & prompt_past0 = state->prompt_past0;
|
||||
auto & prompt_past1 = state->prompt_past1;
|
||||
if (params.no_context) {
|
||||
prompt_past.clear();
|
||||
prompt_past0.clear();
|
||||
prompt_past1.clear();
|
||||
}
|
||||
|
||||
// calculate the maximum context budget for prompt history
|
||||
const int max_prompt_ctx = std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2);
|
||||
|
||||
// prepare prompt
|
||||
{
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
|
||||
// initial prompt
|
||||
// tokenize the initial prompt
|
||||
if (!params.prompt_tokens && params.initial_prompt) {
|
||||
prompt_tokens.resize(1024);
|
||||
int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
|
||||
|
|
@ -6902,14 +6915,25 @@ int whisper_full_with_state(
|
|||
params.prompt_tokens = prompt_tokens.data();
|
||||
params.prompt_n_tokens = prompt_tokens.size();
|
||||
}
|
||||
|
||||
// prepend the prompt tokens to the prompt_past
|
||||
if (params.prompt_tokens && params.prompt_n_tokens > 0) {
|
||||
// parse tokens from the pointer
|
||||
for (int i = 0; i < params.prompt_n_tokens; i++) {
|
||||
prompt_past.push_back(params.prompt_tokens[i]);
|
||||
if (params.carry_initial_prompt) {
|
||||
if (prompt_past0.empty()) {
|
||||
const int max_tokens = std::max(1, max_prompt_ctx - 1);
|
||||
|
||||
if (params.prompt_n_tokens > max_tokens) {
|
||||
WHISPER_LOG_WARN("%s: initial prompt is too long (%d tokens), will use only the last %d tokens\n",
|
||||
__func__, params.prompt_n_tokens, max_tokens);
|
||||
}
|
||||
|
||||
const int n_tokens = std::min(params.prompt_n_tokens, max_tokens);
|
||||
prompt_past0.assign(params.prompt_tokens + (params.prompt_n_tokens - n_tokens), params.prompt_tokens + params.prompt_n_tokens);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < params.prompt_n_tokens; ++i) {
|
||||
prompt_past1.push_back(params.prompt_tokens[i]);
|
||||
}
|
||||
std::rotate(prompt_past1.begin(), prompt_past1.end() - params.prompt_n_tokens, prompt_past1.end());
|
||||
}
|
||||
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -6995,7 +7019,8 @@ int whisper_full_with_state(
|
|||
// if there is a very short audio segment left to process, we remove any past prompt since it tends
|
||||
// to confuse the decoder and often make it repeat or hallucinate stuff
|
||||
if (seek > seek_start && seek + 500 >= seek_end) {
|
||||
prompt_past.clear();
|
||||
prompt_past0.clear();
|
||||
prompt_past1.clear();
|
||||
}
|
||||
|
||||
int best_decoder_id = 0;
|
||||
|
|
@ -7056,12 +7081,25 @@ int whisper_full_with_state(
|
|||
{
|
||||
prompt.clear();
|
||||
|
||||
// if we have already generated some text, use it as a prompt to condition the next generation
|
||||
if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
|
||||
int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
|
||||
if (params.n_max_text_ctx > 0 && t_cur < WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF) {
|
||||
const bool can_take0 = params.carry_initial_prompt && !prompt_past0.empty();
|
||||
const bool can_take1 = !prompt_past1.empty();
|
||||
|
||||
prompt = { whisper_token_prev(ctx) };
|
||||
prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
|
||||
if (max_prompt_ctx > 0 && (can_take0 || can_take1)) {
|
||||
// Always start with previous token marker to connect continuity
|
||||
prompt.push_back(whisper_token_prev(ctx));
|
||||
|
||||
// Take static tokens (initial prompt) first
|
||||
int n_take0 = 0;
|
||||
if (can_take0) {
|
||||
n_take0 = prompt_past0.size();
|
||||
prompt.insert(prompt.end(), prompt_past0.end() - n_take0, prompt_past0.end());
|
||||
}
|
||||
|
||||
// Fill remaining budget with dynamic tokens (rolling context)
|
||||
const int n_take1 = std::min<int>(max_prompt_ctx - n_take0 - 1, prompt_past1.size());
|
||||
prompt.insert(prompt.end(), prompt_past1.end() - n_take1, prompt_past1.end());
|
||||
}
|
||||
}
|
||||
|
||||
// init new transcription with sot, language (opt) and task tokens
|
||||
|
|
@ -7543,14 +7581,17 @@ int whisper_full_with_state(
|
|||
|
||||
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
|
||||
|
||||
// update prompt_past
|
||||
prompt_past.clear();
|
||||
if (prompt.front() == whisper_token_prev(ctx)) {
|
||||
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
|
||||
// update prompt_past1
|
||||
prompt_past1.clear();
|
||||
if (!params.carry_initial_prompt && !prompt.empty() && prompt.front() == whisper_token_prev(ctx)) {
|
||||
prompt_past1.insert(prompt_past1.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
|
||||
}
|
||||
|
||||
for (int i = 0; i < result_len && !is_no_speech; ++i) {
|
||||
prompt_past.push_back(tokens_cur[i].id);
|
||||
// Add newly decoded tokens to the rolling context
|
||||
if (!is_no_speech) {
|
||||
for (int i = 0; i < result_len; ++i) {
|
||||
prompt_past1.push_back(tokens_cur[i].id);
|
||||
}
|
||||
}
|
||||
|
||||
if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue