Skip to content

Commit

Permalink
main : add stereo-channel-based diarization (#64)
Browse files Browse the repository at this point in the history
Not tested - I don't have stereo dialog audio
  • Loading branch information
ggerganov committed Nov 25, 2022
1 parent 1246dd0 commit 0f619b5
Showing 1 changed file with 71 additions and 10 deletions.
81 changes: 71 additions & 10 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ std::string to_timestamp(int64_t t, bool comma = false) {
return std::string(buf);
}

int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}

// helper function to replace substrings
void replace_all(std::string & s, const std::string & search, const std::string & replace) {
for (size_t pos = 0; ; pos += replace.length()) {
Expand All @@ -60,6 +64,7 @@ struct whisper_params {

bool speed_up = false;
bool translate = false;
bool diarize = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
Expand Down Expand Up @@ -99,6 +104,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
Expand Down Expand Up @@ -135,6 +141,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
Expand All @@ -148,8 +155,15 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, "\n");
}

struct whisper_print_user_data {
const whisper_params * params;

const std::vector<std::vector<float>> * pcmf32s;
};

void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const whisper_params & params = *(whisper_params *) user_data;
const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;

const int n_segments = whisper_full_n_segments(ctx);

Expand Down Expand Up @@ -186,6 +200,33 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);

std::string speaker = "";

if (params.diarize && pcmf32s.size() == 2) {
const int64_t n_samples = pcmf32s[0].size();

const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);

double energy0 = 0.0f;
double energy1 = 0.0f;

for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}

if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}

//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}

if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
Expand All @@ -201,13 +242,13 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi

const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));

printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);

printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
}
}
}
Expand Down Expand Up @@ -235,7 +276,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return 9;
return false;
}

fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
Expand Down Expand Up @@ -425,6 +466,7 @@ int main(int argc, char ** argv) {
const auto fname_inp = params.fname_inp[f];

std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM

// WAV input
{
Expand Down Expand Up @@ -453,22 +495,27 @@ int main(int argc, char ** argv) {
}
else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 4;
return 5;
}

if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
return 5;
return 6;
}

if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
return 6;
}

if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
return 6;
return 8;
}

if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
return 7;
return 9;
}

const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
Expand All @@ -489,6 +536,18 @@ int main(int argc, char ** argv) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}

if (params.diarize) {
// convert to stereo, float
pcmf32s.resize(2);

pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (int i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
}

// print system information
Expand Down Expand Up @@ -540,15 +599,17 @@ int main(int argc, char ** argv) {

wparams.speed_up = params.speed_up;

whisper_print_user_data user_data = { &params, &pcmf32s };

// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &params;
wparams.new_segment_callback_user_data = &user_data;
}

if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 8;
return 10;
}
}

Expand Down

0 comments on commit 0f619b5

Please sign in to comment.