diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index bc0c6e99194..2546a42ff61 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -117,23 +117,61 @@ ruby_whisper_normalize_model_path(VALUE model_path) * call-seq: * new("base.en") -> Whisper::Context * new("path/to/model.bin") -> Whisper::Context + * new("path/to/model.bin", use_gpu: true, flash_attn: true) -> Whisper::Context * new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context + * + * Initialize a new Whisper context with optional parameters: + * use_gpu: Enable GPU acceleration (default: true) + * flash_attn: Enable flash attention (default: true) + * gpu_device: GPU device to use (default: 0) + * dtw_token_timestamps: Enable DTW token-level timestamps (default: false) + * dtw_aheads_preset: DTW attention heads preset (default: WHISPER_AHEADS_NONE) */ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { ruby_whisper *rw; - VALUE whisper_model_file_path; + VALUE whisper_model_file_path, options; - // TODO: we can support init from buffer here too maybe another ruby object to expose - rb_scan_args(argc, argv, "01", &whisper_model_file_path); + rb_scan_args(argc, argv, "01:", &whisper_model_file_path, &options); TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path); if (!rb_respond_to(whisper_model_file_path, id_to_s)) { rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); } - rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params()); + + // Build context params from options + struct whisper_context_params cparams = whisper_context_default_params(); + + if (!NIL_P(options)) { + VALUE use_gpu = rb_hash_aref(options, ID2SYM(rb_intern("use_gpu"))); + if (!NIL_P(use_gpu)) { + cparams.use_gpu = RTEST(use_gpu); + } + + VALUE flash_attn = rb_hash_aref(options, ID2SYM(rb_intern("flash_attn"))); + if (!NIL_P(flash_attn)) { + cparams.flash_attn = RTEST(flash_attn); + } + + VALUE gpu_device = rb_hash_aref(options, ID2SYM(rb_intern("gpu_device"))); + if (!NIL_P(gpu_device)) { + cparams.gpu_device = NUM2INT(gpu_device); + } + + VALUE dtw_token_timestamps = rb_hash_aref(options, ID2SYM(rb_intern("dtw_token_timestamps"))); + if (!NIL_P(dtw_token_timestamps)) { + cparams.dtw_token_timestamps = RTEST(dtw_token_timestamps); + } + + VALUE dtw_aheads_preset = rb_hash_aref(options, ID2SYM(rb_intern("dtw_aheads_preset"))); + if (!NIL_P(dtw_aheads_preset)) { + cparams.dtw_aheads_preset = NUM2INT(dtw_aheads_preset); + } + } + + rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), cparams); if (rw->context == NULL) { rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context"); } diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 70417cb1664..d698abcfc85 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -26,7 +26,7 @@ rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \ rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1); -#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 37 +#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 42 extern VALUE cParams; extern VALUE cVADParams; @@ -75,6 +75,11 @@ static ID id_abort_callback_user_data; static ID id_vad; static ID id_vad_model_path; static ID id_vad_params; +static ID id_suppress_regex; +static ID id_grammar_penalty; +static ID id_tdrz_enable; +static ID id_audio_ctx; +static ID id_debug_mode; static void rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) @@ -1141,6 +1146,105 @@ ruby_whisper_params_get_vad_params(VALUE self) return rwp->vad_params; } +/* + * call-seq: + * suppress_regex = regex -> regex + */ +static VALUE +ruby_whisper_params_set_suppress_regex(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + if (NIL_P(value)) { + rwp->params.suppress_regex = NULL; + return value; + } + rwp->params.suppress_regex = StringValueCStr(value); + return value; +} + +static VALUE +ruby_whisper_params_get_suppress_regex(VALUE self) +{ + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + return rwp->params.suppress_regex == NULL ? Qnil : rb_str_new2(rwp->params.suppress_regex); +} + +/* + * call-seq: + * grammar_penalty = penalty -> penalty + */ +static VALUE +ruby_whisper_params_set_grammar_penalty(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + rwp->params.grammar_penalty = NUM2DBL(value); + return value; +} + +static VALUE +ruby_whisper_params_get_grammar_penalty(VALUE self) +{ + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + return DBL2NUM(rwp->params.grammar_penalty); +} + +/* + * call-seq: + * tdrz_enable = enable -> enable + */ +static VALUE +ruby_whisper_params_set_tdrz_enable(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, tdrz_enable, value) +} + +static VALUE +ruby_whisper_params_get_tdrz_enable(VALUE self) +{ + BOOL_PARAMS_GETTER(self, tdrz_enable) +} + +/* + * call-seq: + * audio_ctx = context_size -> context_size + */ +static VALUE +ruby_whisper_params_set_audio_ctx(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + rwp->params.audio_ctx = NUM2INT(value); + return value; +} + +static VALUE +ruby_whisper_params_get_audio_ctx(VALUE self) +{ + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + return INT2NUM(rwp->params.audio_ctx); +} + +/* + * call-seq: + * debug_mode = enable -> enable + */ +static VALUE +ruby_whisper_params_set_debug_mode(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, debug_mode, value) +} + +static VALUE +ruby_whisper_params_get_debug_mode(VALUE self) +{ + BOOL_PARAMS_GETTER(self, debug_mode) +} + #define SET_PARAM_IF_SAME(param_name) \ if (id == id_ ## param_name) { \ ruby_whisper_params_set_ ## param_name(self, value); \ @@ -1211,6 +1315,11 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self) SET_PARAM_IF_SAME(vad) SET_PARAM_IF_SAME(vad_model_path) SET_PARAM_IF_SAME(vad_params) + SET_PARAM_IF_SAME(suppress_regex) + SET_PARAM_IF_SAME(grammar_penalty) + SET_PARAM_IF_SAME(tdrz_enable) + SET_PARAM_IF_SAME(audio_ctx) + SET_PARAM_IF_SAME(debug_mode) } } @@ -1348,6 +1457,11 @@ init_ruby_whisper_params(VALUE *mWhisper) DEFINE_PARAM(vad, 34) DEFINE_PARAM(vad_model_path, 35) DEFINE_PARAM(vad_params, 36) + DEFINE_PARAM(suppress_regex, 37) + DEFINE_PARAM(grammar_penalty, 38) + DEFINE_PARAM(tdrz_enable, 39) + DEFINE_PARAM(audio_ctx, 40) + DEFINE_PARAM(debug_mode, 41) rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0); diff --git a/bindings/ruby/lib/whisper/context.rb b/bindings/ruby/lib/whisper/context.rb index c3a134b773d..2d5e1082790 100644 --- a/bindings/ruby/lib/whisper/context.rb +++ b/bindings/ruby/lib/whisper/context.rb @@ -1,13 +1,13 @@ module Whisper class Context def to_srt - each_segment.with_index.reduce("") {|srt, (segment, index)| + each_segment.with_index.reduce(String.new) {|srt, (segment, index)| srt << "#{index + 1}\n#{segment.to_srt_cue}\n" } end def to_webvtt - each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)| + each_segment.with_index.reduce(String.new("WEBVTT\n\n")) {|webvtt, (segment, index)| webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n" } end diff --git a/bindings/ruby/test/test_context_params.rb b/bindings/ruby/test/test_context_params.rb new file mode 100644 index 00000000000..d0461d45fd4 --- /dev/null +++ b/bindings/ruby/test/test_context_params.rb @@ -0,0 +1,59 @@ +require_relative "helper" + +class TestContextParams < TestBase + def test_context_new_with_default_params + whisper = Whisper::Context.new("base.en") + assert_instance_of Whisper::Context, whisper + end + + def test_context_new_with_use_gpu + whisper = Whisper::Context.new("base.en", use_gpu: true) + assert_instance_of Whisper::Context, whisper + + whisper = Whisper::Context.new("base.en", use_gpu: false) + assert_instance_of Whisper::Context, whisper + end + + def test_context_new_with_flash_attn + whisper = Whisper::Context.new("base.en", flash_attn: true) + assert_instance_of Whisper::Context, whisper + + whisper = Whisper::Context.new("base.en", flash_attn: false) + assert_instance_of Whisper::Context, whisper + end + + def test_context_new_with_gpu_device + whisper = Whisper::Context.new("base.en", gpu_device: 0) + assert_instance_of Whisper::Context, whisper + + whisper = Whisper::Context.new("base.en", gpu_device: 1) + assert_instance_of Whisper::Context, whisper + end + + def test_context_new_with_dtw_token_timestamps + whisper = Whisper::Context.new("base.en", dtw_token_timestamps: true) + assert_instance_of Whisper::Context, whisper + + whisper = Whisper::Context.new("base.en", dtw_token_timestamps: false) + assert_instance_of Whisper::Context, whisper + end + + def test_context_new_with_dtw_aheads_preset + whisper = Whisper::Context.new("base.en", dtw_aheads_preset: 0) + assert_instance_of Whisper::Context, whisper + + whisper = Whisper::Context.new("base.en", dtw_aheads_preset: 1) + assert_instance_of Whisper::Context, whisper + end + + def test_context_new_with_combined_params + whisper = Whisper::Context.new("base.en", + use_gpu: true, + flash_attn: true, + gpu_device: 0, + dtw_token_timestamps: false, + dtw_aheads_preset: 0 + ) + assert_instance_of Whisper::Context, whisper + end +end diff --git a/bindings/ruby/test/test_params.rb b/bindings/ruby/test/test_params.rb index 4dd9780de7d..50882187091 100644 --- a/bindings/ruby/test/test_params.rb +++ b/bindings/ruby/test/test_params.rb @@ -37,6 +37,11 @@ class TestParams < TestBase :vad, :vad_model_path, :vad_params, + :suppress_regex, + :grammar_penalty, + :tdrz_enable, + :audio_ctx, + :debug_mode, ] def setup @@ -245,13 +250,50 @@ def test_vad_model_path_with_URI end def test_vad_params - assert_kind_of Whisper::VAD::Params, @params.vad_params - default_params = @params.vad_params - assert_same default_params, @params.vad_params - assert_equal 0.5, default_params.threshold + default_params = Whisper::VAD::Params.new + # vad_params returns a new wrapper each time, so use assert_equal instead of assert_same + retrieved_params = @params.vad_params + assert_equal default_params.threshold, retrieved_params.threshold + assert_equal 0.5, retrieved_params.threshold new_params = Whisper::VAD::Params.new @params.vad_params = new_params - assert_same new_params, @params.vad_params + retrieved_params = @params.vad_params + assert_equal new_params.threshold, retrieved_params.threshold + end + + def test_suppress_regex + @params.suppress_regex = "[\\*\\[\\]]" + assert_equal @params.suppress_regex, "[\\*\\[\\]]" + @params.suppress_regex = nil + assert_nil @params.suppress_regex + end + + def test_grammar_penalty + @params.grammar_penalty = 50.0 + assert_in_delta @params.grammar_penalty, 50.0 + @params.grammar_penalty = 0.0 + assert_in_delta @params.grammar_penalty, 0.0 + end + + def test_tdrz_enable + @params.tdrz_enable = true + assert @params.tdrz_enable + @params.tdrz_enable = false + assert !@params.tdrz_enable + end + + def test_audio_ctx + @params.audio_ctx = 1024 + assert_equal @params.audio_ctx, 1024 + @params.audio_ctx = 0 + assert_equal @params.audio_ctx, 0 + end + + def test_debug_mode + @params.debug_mode = true + assert @params.debug_mode + @params.debug_mode = false + assert !@params.debug_mode end def test_new_with_kw_args @@ -284,6 +326,8 @@ def test_new_with_kw_args_default_values(param) "es" in [:initial_prompt, *] "Initial prompt" + in [:suppress_regex, *] + "[\\*\\[\\]]" in [/_callback\Z/, *] proc {} in [/_user_data\Z/, *] diff --git a/bindings/ruby/test/test_segment.rb b/bindings/ruby/test/test_segment.rb index cb4ba9eb705..08a037c01f2 100644 --- a/bindings/ruby/test/test_segment.rb +++ b/bindings/ruby/test/test_segment.rb @@ -73,7 +73,6 @@ def test_on_new_segment_twice end def test_transcription_after_segment_retrieved - params = Whisper::Params.new segment = whisper.each_segment.first assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)