diff --git a/Modules.make b/Modules.make index a9ee0ae7c..4ae7efcee 100644 --- a/Modules.make +++ b/Modules.make @@ -148,6 +148,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make b/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make index f171381f7..f427caceb 100644 --- a/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make +++ b/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make @@ -143,6 +143,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make b/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make index f171381f7..f427caceb 100644 --- a/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make +++ b/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make @@ -143,6 +143,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make b/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make index 2ea9bf106..34a902935 100644 --- a/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make +++ b/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make @@ -143,6 +143,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make b/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make index af199b57a..0a597c8cb 100644 --- a/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make +++ b/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make @@ -147,6 +147,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make b/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make index bc36c260b..1daa8d5dc 100644 --- a/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make +++ b/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make @@ -148,6 +148,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/LexiconfreeTimesyncBeamSearch/libSprintLexiconfreeTimesyncBeamSearch.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index b732df8b5..788edbb76 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -83,6 +83,8 @@ public: LABEL_TO_BLANK, BLANK_TO_LABEL, BLANK_LOOP, + INITIAL_LABEL, + INITIAL_BLANK, }; // Request for scoring or context extension diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc new file mode 100644 index 000000000..ffea5c8ce --- /dev/null +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -0,0 +1,499 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "LexiconfreeTimesyncBeamSearch.hh" + +#include +#include + +#include +#include +#include +#include +#include + +namespace Search { + +/* + * ======================= + * === LabelHypothesis === + * ======================= + */ + +LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis() + : scoringContext(), + currentToken(Core::Type::max), + score(0.0), + trace(Core::ref(new LatticeTrace(0, {0, 0}, {}))) {} + +LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( + LexiconfreeTimesyncBeamSearch::LabelHypothesis const& base, + LexiconfreeTimesyncBeamSearch::ExtensionCandidate const& extension, + Nn::ScoringContextRef const& newScoringContext) + : scoringContext(newScoringContext), + currentToken(extension.nextToken), + score(extension.score), + trace() { + switch (extension.transitionType) { + case Nn::LabelScorer::INITIAL_BLANK: + case Nn::LabelScorer::INITIAL_LABEL: + case Nn::LabelScorer::LABEL_TO_LABEL: + case Nn::LabelScorer::LABEL_TO_BLANK: + case Nn::LabelScorer::BLANK_TO_LABEL: + trace = Core::ref(new LatticeTrace( + base.trace, + extension.pron, + extension.timeframe + 1, + {extension.score, 0}, + {})); + break; + case Nn::LabelScorer::LABEL_LOOP: + case Nn::LabelScorer::BLANK_LOOP: + // Copy base trace and update it + trace = Core::ref(new LatticeTrace(*base.trace)); + trace->sibling = {}; + trace->score.acoustic = extension.score; + trace->time = extension.timeframe + 1; + break; + } +} + +std::string LexiconfreeTimesyncBeamSearch::LabelHypothesis::toString() const { + std::stringstream ss; + ss << "Score: " << score << ", traceback: "; + + auto traceback = trace->performTraceback(); + + for (auto& item : *traceback) { + if (item.pronunciation and item.pronunciation->lemma()) { + ss << item.pronunciation->lemma()->symbol() << " "; + } + } + return ss.str(); +} + +/* + * ===================================== + * === LexiconfreeTimesyncBeamSearch === + * ===================================== + */ + +const Core::ParameterInt LexiconfreeTimesyncBeamSearch::paramMaxBeamSize( + "max-beam-size", + "Maximum number of elements in the search beam.", + 1, 1); + +const Core::ParameterFloat LexiconfreeTimesyncBeamSearch::paramScoreThreshold( + "score-threshold", + "Prune any hypotheses with a score that is at least this much worse than the best hypothesis. If not set, no score pruning will be done.", + Core::Type::max, 0); + +const Core::ParameterInt LexiconfreeTimesyncBeamSearch::paramBlankLabelIndex( + "blank-label-index", + "Index of the blank label in the lexicon. Can also be inferred from lexicon if it has a lemma with `special='blank'`. If not set, the search will not use blank.", + Core::Type::max); + +const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramCollapseRepeatedLabels( + "collapse-repeated-labels", + "Collapse repeated emission of the same label into one output. If false, every emission is treated like a new output.", + false); + +const Core::ParameterBool LexiconfreeTimesyncBeamSearch::paramLogStepwiseStatistics( + "log-stepwise-statistics", + "Log statistics about the beam at every search step.", + false); + +LexiconfreeTimesyncBeamSearch::LexiconfreeTimesyncBeamSearch(Core::Configuration const& config) + : Core::Component(config), + SearchAlgorithmV2(config), + maxBeamSize_(paramMaxBeamSize(config)), + scoreThreshold_(paramScoreThreshold(config)), + blankLabelIndex_(paramBlankLabelIndex(config)), + collapseRepeatedLabels_(paramCollapseRepeatedLabels(config)), + logStepwiseStatistics_(paramLogStepwiseStatistics(config)), + debugChannel_(config, "debug"), + labelScorer_(), + beam_(), + extensions_(), + newBeam_(), + requests_(), + recombinedHypotheses_(), + initializationTime_(), + featureProcessingTime_(), + scoringTime_(), + contextExtensionTime_(), + numHypsAfterScorePruning_("num-hyps-after-score-pruning"), + numHypsAfterBeamPruning_("num-hyps-after-beam-pruning"), + numActiveHyps_("num-active-hyps"), + finishedSegment_(false) { + beam_.reserve(maxBeamSize_); + newBeam_.reserve(maxBeamSize_); + recombinedHypotheses_.reserve(maxBeamSize_); + useBlank_ = blankLabelIndex_ != Core::Type::max; + if (useBlank_) { + log() << "Use blank label with index " << blankLabelIndex_; + } + useScorePruning_ = scoreThreshold_ != Core::Type::max; +} + +Speech::ModelCombination::Mode LexiconfreeTimesyncBeamSearch::requiredModelCombination() const { + return Speech::ModelCombination::useLabelScorer | Speech::ModelCombination::useLexicon; +} + +bool LexiconfreeTimesyncBeamSearch::setModelCombination(Speech::ModelCombination const& modelCombination) { + lexicon_ = modelCombination.lexicon(); + labelScorer_ = modelCombination.labelScorer(); + + extensions_.reserve(maxBeamSize_ * lexicon_->nLemmas()); + requests_.reserve(extensions_.size()); + + auto blankLemma = lexicon_->specialLemma("blank"); + if (blankLemma) { + if (blankLabelIndex_ == Core::Type::max) { + blankLabelIndex_ = blankLemma->id(); + useBlank_ = true; + log() << "Use blank index " << blankLabelIndex_ << " inferred from lexicon"; + } + else if (blankLabelIndex_ != static_cast(blankLemma->id())) { + warning() << "Blank lemma exists in lexicon with id " << blankLemma->id() << " but is overwritten by config parameter with value " << blankLabelIndex_; + } + } + + reset(); + return true; +} + +void LexiconfreeTimesyncBeamSearch::reset() { + initializationTime_.start(); + + labelScorer_->reset(); + + // Reset beam to a single empty hypothesis + beam_.clear(); + beam_.push_back(LabelHypothesis()); + beam_.front().scoringContext = labelScorer_->getInitialScoringContext(); + + finishedSegment_ = false; + + initializationTime_.stop(); +} + +void LexiconfreeTimesyncBeamSearch::enterSegment(Bliss::SpeechSegment const* segment) { + initializationTime_.start(); + labelScorer_->reset(); + resetStatistics(); + initializationTime_.stop(); + finishedSegment_ = false; +} + +void LexiconfreeTimesyncBeamSearch::finishSegment() { + featureProcessingTime_.start(); + labelScorer_->signalNoMoreFeatures(); + featureProcessingTime_.stop(); + decodeManySteps(); + logStatistics(); + finishedSegment_ = true; +} + +void LexiconfreeTimesyncBeamSearch::putFeature(std::shared_ptr const& data, size_t featureSize) { + featureProcessingTime_.start(); + labelScorer_->addInput(data, featureSize); + featureProcessingTime_.stop(); +} + +void LexiconfreeTimesyncBeamSearch::putFeature(std::vector const& data) { + featureProcessingTime_.start(); + labelScorer_->addInput(data); + featureProcessingTime_.stop(); +} + +void LexiconfreeTimesyncBeamSearch::putFeatures(std::shared_ptr const& data, size_t timeSize, size_t featureSize) { + featureProcessingTime_.start(); + labelScorer_->addInputs(data, timeSize, featureSize); + featureProcessingTime_.stop(); +} + +Core::Ref LexiconfreeTimesyncBeamSearch::getCurrentBestTraceback() const { + return getBestHypothesis().trace->performTraceback(); +} + +Core::Ref LexiconfreeTimesyncBeamSearch::getCurrentBestWordLattice() const { + auto& bestHypothesis = getBestHypothesis(); + LatticeTrace endTrace(bestHypothesis.trace, 0, bestHypothesis.trace->time + 1, bestHypothesis.trace->score, {}); + + for (size_t hypIdx = 1ul; hypIdx < beam_.size(); ++hypIdx) { + auto& hyp = beam_[hypIdx]; + auto siblingTrace = Core::ref(new LatticeTrace(hyp.trace, 0, hyp.trace->time, hyp.trace->score, {})); + endTrace.appendSiblingToChain(siblingTrace); + } + + return endTrace.buildWordLattice(lexicon_); +} + +bool LexiconfreeTimesyncBeamSearch::decodeStep() { + if (finishedSegment_) { + return false; + } + + // Assume the output labels are stored as lexicon lemma orth and ordered consistently with NN output index + auto lemmas = lexicon_->lemmas(); + + /* + * Collect all possible extensions for all hypotheses in the beam. + * Also Create scoring requests for the label scorer. + * Each extension candidate makes up a request. + */ + extensions_.clear(); + requests_.clear(); + + for (size_t hypIndex = 0ul; hypIndex < beam_.size(); ++hypIndex) { + auto& hyp = beam_[hypIndex]; + + // Iterate over possible successors (all lemmas) + for (auto lemmaIt = lemmas.first; lemmaIt != lemmas.second; ++lemmaIt) { + const Bliss::Lemma* lemma(*lemmaIt); + Nn::LabelIndex tokenIdx = lemma->id(); + + auto transitionType = inferTransitionType(hyp.currentToken, tokenIdx); + + extensions_.push_back( + {tokenIdx, + lemma->pronunciations().first, + hyp.score, + 0, + transitionType, + hypIndex}); + requests_.push_back({beam_[hypIndex].scoringContext, tokenIdx, transitionType}); + } + } + + /* + * Perform scoring of all the requests with the label scorer. + */ + scoringTime_.start(); + auto result = labelScorer_->computeScoresWithTimes(requests_); + scoringTime_.stop(); + + if (not result) { + // LabelScorer could not compute scores -> no search step can be made. + return false; + } + + for (size_t extensionIdx = 0ul; extensionIdx < extensions_.size(); ++extensionIdx) { + extensions_[extensionIdx].score += result->scores[extensionIdx]; + extensions_[extensionIdx].timeframe = result->timeframes[extensionIdx]; + } + + if (logStepwiseStatistics_) { + clog() << Core::XmlOpen("search-step-stats"); + } + + /* + * Prune set of possible extensions by max beam size and possibly also by score. + */ + + if (useScorePruning_) { + scorePruning(extensions_); + + numHypsAfterScorePruning_ += extensions_.size(); + + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("num-hyps-after-score-pruning", extensions_.size()); + } + } + + beamSizePruning(extensions_); + numHypsAfterBeamPruning_ += extensions_.size(); + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("num-hyps-after-beam-pruning", extensions_.size()); + } + + /* + * Create new beam from surviving extensions. + */ + newBeam_.clear(); + + for (auto const& extension : extensions_) { + auto const& baseHyp = beam_[extension.baseHypIndex]; + + auto newScoringContext = labelScorer_->extendedScoringContext( + {baseHyp.scoringContext, + extension.nextToken, + extension.transitionType}); + + newBeam_.push_back({baseHyp, extension, newScoringContext}); + } + + /* + * For all hypotheses with the same scoring context keep only the best since they will + * all develop in the same way. + */ + recombination(newBeam_); + numActiveHyps_ += newBeam_.size(); + + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("active-hyps", newBeam_.size()); + } + + if (debugChannel_.isOpen()) { + std::stringstream ss; + for (size_t hypIdx = 0ul; hypIdx < newBeam_.size(); ++hypIdx) { + ss << "Hypothesis " << hypIdx + 1ul << ": " << newBeam_[hypIdx].toString() << "\n"; + } + ss << "\n"; + debugChannel_ << ss.str(); + } + + beam_.swap(newBeam_); + + if (logStepwiseStatistics_) { + clog() << Core::XmlFull("best-hyp-score", getBestHypothesis().score); + clog() << Core::XmlFull("worst-hyp-score", getWorstHypothesis().score); + clog() << Core::XmlClose("search-step-stats"); + } + + return true; +} + +LexiconfreeTimesyncBeamSearch::LabelHypothesis const& LexiconfreeTimesyncBeamSearch::getBestHypothesis() const { + verify(not beam_.empty()); + + return *std::min_element(beam_.begin(), beam_.end()); +} + +LexiconfreeTimesyncBeamSearch::LabelHypothesis const& LexiconfreeTimesyncBeamSearch::getWorstHypothesis() const { + verify(not beam_.empty()); + + return *std::max_element(beam_.begin(), beam_.end()); +} + +void LexiconfreeTimesyncBeamSearch::resetStatistics() { + initializationTime_.reset(); + featureProcessingTime_.reset(); + scoringTime_.reset(); + contextExtensionTime_.reset(); + numHypsAfterScorePruning_.clear(); + numHypsAfterBeamPruning_.clear(); + numActiveHyps_.clear(); +} + +void LexiconfreeTimesyncBeamSearch::logStatistics() const { + clog() << Core::XmlOpen("timing-statistics") + Core::XmlAttribute("unit", "milliseconds"); + clog() << Core::XmlOpen("initialization-time") << initializationTime_.elapsedMilliseconds() << Core::XmlClose("initialization-time"); + clog() << Core::XmlOpen("feature-processing-time") << featureProcessingTime_.elapsedMilliseconds() << Core::XmlClose("feature-processing-time"); + clog() << Core::XmlOpen("scoring-time") << scoringTime_.elapsedMilliseconds() << Core::XmlClose("scoring-time"); + clog() << Core::XmlOpen("context-extension-time") << contextExtensionTime_.elapsedMilliseconds() << Core::XmlClose("context-extension-time"); + clog() << Core::XmlClose("timing-statistics"); + numHypsAfterScorePruning_.write(clog()); + numHypsAfterBeamPruning_.write(clog()); + numActiveHyps_.write(clog()); +} + +Nn::LabelScorer::TransitionType LexiconfreeTimesyncBeamSearch::inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const { + bool prevIsBlank = (useBlank_ and prevLabel == blankLabelIndex_); + bool nextIsBlank = (useBlank_ and nextLabel == blankLabelIndex_); + + if (prevLabel == Core::Type::max) { + if (nextIsBlank) { + return Nn::LabelScorer::TransitionType::INITIAL_BLANK; + } + else { + return Nn::LabelScorer::TransitionType::INITIAL_LABEL; + } + } + + if (prevIsBlank) { + if (nextIsBlank) { + return Nn::LabelScorer::TransitionType::BLANK_LOOP; + } + else { + return Nn::LabelScorer::TransitionType::BLANK_TO_LABEL; + } + } + else { + if (nextIsBlank) { + return Nn::LabelScorer::TransitionType::LABEL_TO_BLANK; + } + else if (collapseRepeatedLabels_ and prevLabel == nextLabel) { + return Nn::LabelScorer::TransitionType::LABEL_LOOP; + } + else { + return Nn::LabelScorer::TransitionType::LABEL_TO_LABEL; + } + } +} + +void LexiconfreeTimesyncBeamSearch::beamSizePruning(std::vector& extensions) const { + if (extensions.size() <= maxBeamSize_) { + return; + } + + // Reorder the hypotheses by associated score value such that the first `beamSize_` elements are the best + std::nth_element(extensions.begin(), extensions.begin() + maxBeamSize_, extensions.end()); + extensions.resize(maxBeamSize_); // Get rid of excessive elements +} + +void LexiconfreeTimesyncBeamSearch::scorePruning(std::vector& extensions) const { + if (extensions.empty()) { + return; + } + + // Compute the pruning threshold + auto bestScore = std::min_element(extensions.begin(), extensions.end())->score; + auto pruningThreshold = bestScore + scoreThreshold_; + + // Remove elements with score > pruningThreshold + extensions.erase( + std::remove_if( + extensions.begin(), + extensions.end(), + [=](auto const& ext) { return ext.score > pruningThreshold; }), + extensions.end()); +} + +void LexiconfreeTimesyncBeamSearch::recombination(std::vector& hypotheses) { + recombinedHypotheses_.clear(); + // Map each unique ScoringContext in newHypotheses to its hypothesis + std::unordered_map seenScoringContexts; + for (auto const& hyp : hypotheses) { + // Use try_emplace to check if the scoring context already exists and create a new entry if not at the same time + auto [it, inserted] = seenScoringContexts.try_emplace(hyp.scoringContext, nullptr); + + if (inserted) { + // First time seeing this scoring context so move it over to `newHypotheses` + recombinedHypotheses_.push_back(std::move(hyp)); + it->second = &recombinedHypotheses_.back(); + } + else { + verify(not hyp.trace->sibling); + + auto* existingHyp = it->second; + if (hyp.score < existingHyp->score) { + // New hyp is better -> replace in `newHypotheses` and add existing one as sibling + hyp.trace->sibling = existingHyp->trace; + *existingHyp = std::move(hyp); // Overwrite in-place + } + else { + // New hyp is worse -> add to existing one as sibling + hyp.trace->sibling = existingHyp->trace->sibling; + existingHyp->trace->sibling = hyp.trace; + } + } + } + + hypotheses.swap(recombinedHypotheses_); +} + +} // namespace Search diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh new file mode 100644 index 000000000..5258a27b0 --- /dev/null +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh @@ -0,0 +1,169 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LEXICONFREE_TIMESYNC_BEAM_SEARCH_HH +#define LEXICONFREE_TIMESYNC_BEAM_SEARCH_HH + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Search { + +/* + * Simple time synchronous beam search algorithm without pronunciation lexicon, word-level LM or transition model. + * Can handle a blank symbol if a blank index is set. + * Main purpose is open vocabulary search with CTC/Neural Transducer (or similar) models. + * Supports global pruning by max beam-size and by score difference to the best hypothesis. + * Uses a LabelScorer to context initialization/extension and scoring. + * + * The search requires a lexicon that represents the vocabulary. Each lemma is viewed as a token with its index + * in the lexicon corresponding to the associated output index of the label scorer. + */ +class LexiconfreeTimesyncBeamSearch : public SearchAlgorithmV2 { +protected: + /* + * Possible extension for some label hypothesis in the beam + */ + struct ExtensionCandidate { + Nn::LabelIndex nextToken; // Proposed token to extend the hypothesis with + const Bliss::LemmaPronunciation* pron; // Pronunciation of lemma corresponding to `nextToken` for traceback + Score score; // Would-be score of full hypothesis after extension + Search::TimeframeIndex timeframe; // Timestamp of `nextToken` for traceback + Nn::LabelScorer::TransitionType transitionType; // Type of transition toward `nextToken` + size_t baseHypIndex; // Index of base hypothesis in global beam + + bool operator<(ExtensionCandidate const& other) const { + return score < other.score; + } + }; + + /* + * Struct containing all information about a single hypothesis in the beam + */ + struct LabelHypothesis { + Nn::ScoringContextRef scoringContext; // Context to compute scores based on this hypothesis + Nn::LabelIndex currentToken; // Most recent token in associated label sequence (useful to infer transition type) + Score score; // Full score of hypothesis + Core::Ref trace; // Associated trace for traceback or lattice building off of hypothesis + + LabelHypothesis(); + LabelHypothesis(LabelHypothesis const& base, ExtensionCandidate const& extension, Nn::ScoringContextRef const& newScoringContext); + + bool operator<(LabelHypothesis const& other) const { + return score < other.score; + } + + /* + * Get string representation for debugging. + */ + std::string toString() const; + }; + +public: + static const Core::ParameterInt paramMaxBeamSize; + static const Core::ParameterFloat paramScoreThreshold; + static const Core::ParameterInt paramBlankLabelIndex; + static const Core::ParameterBool paramCollapseRepeatedLabels; + static const Core::ParameterBool paramLogStepwiseStatistics; + + LexiconfreeTimesyncBeamSearch(Core::Configuration const&); + + // Inherited methods from `SearchAlgorithmV2` + + Speech::ModelCombination::Mode requiredModelCombination() const override; + bool setModelCombination(Speech::ModelCombination const& modelCombination) override; + void reset() override; + void enterSegment(Bliss::SpeechSegment const* = nullptr) override; + void finishSegment() override; + void putFeature(std::shared_ptr const& data, size_t featureSize) override; + void putFeature(std::vector const& data) override; + void putFeatures(std::shared_ptr const& data, size_t timeSize, size_t featureSize) override; + Core::Ref getCurrentBestTraceback() const override; + Core::Ref getCurrentBestWordLattice() const override; + bool decodeStep() override; + +private: + LabelHypothesis const& getBestHypothesis() const; + LabelHypothesis const& getWorstHypothesis() const; + + void resetStatistics(); + void logStatistics() const; + + /* + * Infer type of transition between two tokens based on whether each of them is blank + * and/or whether they are the same + */ + Nn::LabelScorer::TransitionType inferTransitionType(Nn::LabelIndex prevLabel, Nn::LabelIndex nextLabel) const; + + /* + * Helper function for pruning to maxBeamSize_ + */ + void beamSizePruning(std::vector& extensions) const; + + /* + * Helper function for pruning to scoreThreshold_ + */ + void scorePruning(std::vector& extensions) const; + + /* + * Helper function for recombination of hypotheses with the same scoring context + */ + void recombination(std::vector& hypotheses); + + size_t maxBeamSize_; + + bool useScorePruning_; + Score scoreThreshold_; + + bool useBlank_; + Nn::LabelIndex blankLabelIndex_; + + bool collapseRepeatedLabels_; + + bool logStepwiseStatistics_; + + Core::Channel debugChannel_; + + Core::Ref labelScorer_; + Bliss::LexiconRef lexicon_; + std::vector beam_; + + // Pre-allocated intermediate vectors + std::vector extensions_; + std::vector newBeam_; + std::vector requests_; + std::vector recombinedHypotheses_; + + Core::StopWatch initializationTime_; + Core::StopWatch featureProcessingTime_; + Core::StopWatch scoringTime_; + Core::StopWatch contextExtensionTime_; + + Core::Statistics numHypsAfterScorePruning_; + Core::Statistics numHypsAfterBeamPruning_; + Core::Statistics numActiveHyps_; + + bool finishedSegment_; +}; + +} // namespace Search + +#endif // LEXICONFREE_TIMESYNC_BEAM_SEARCH_HH diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/Makefile b/src/Search/LexiconfreeTimesyncBeamSearch/Makefile new file mode 100644 index 000000000..c9834e9ad --- /dev/null +++ b/src/Search/LexiconfreeTimesyncBeamSearch/Makefile @@ -0,0 +1,24 @@ +#!gmake + +TOPDIR = ../../.. + +include $(TOPDIR)/Makefile.cfg + +# ----------------------------------------------------------------------------- + +SUBDIRS = +TARGETS = libSprintLexiconfreeTimesyncBeamSearch.$(a) + +LIBSPRINTLEXICONFREETIMESYNCBEAMSEARCH_O = $(OBJDIR)/LexiconfreeTimesyncBeamSearch.o + + +# ----------------------------------------------------------------------------- + +all: $(TARGETS) + +libSprintLexiconfreeTimesyncBeamSearch.$(a): $(LIBSPRINTLEXICONFREETIMESYNCBEAMSEARCH_O) + $(MAKELIB) $@ $^ + +include $(TOPDIR)/Rules.make + +sinclude $(LIBSPRINTLEXICONFREETIMESYNCBEAMSEARCH_O:.o=.d) diff --git a/src/Search/Makefile b/src/Search/Makefile index bba7b5f3b..69fcc1a74 100644 --- a/src/Search/Makefile +++ b/src/Search/Makefile @@ -36,6 +36,7 @@ LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskAStarSearch.o LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskNBestListSearch.o LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskSearchUtil.o endif +SUBDIRS += LexiconfreeTimesyncBeamSearch ifdef MODULE_SEARCH_WFST SUBDIRS += Wfst endif @@ -66,6 +67,9 @@ Wfst: AdvancedTreeSearch: $(MAKE) -C $@ libSprintAdvancedTreeSearch.$(a) +LexiconfreeTimesyncBeamSearch: + $(MAKE) -C $@ libSprintLexiconfreeTimesyncBeamSearch.$(a) + include $(TOPDIR)/Rules.make sinclude $(LIBSPRINTSEARCH_O:.o=.d) diff --git a/src/Search/Module.cc b/src/Search/Module.cc index 7dbaefdd5..ddfb30280 100644 --- a/src/Search/Module.cc +++ b/src/Search/Module.cc @@ -16,6 +16,7 @@ #include #include #include +#include "LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.hh" #include "TreeBuilder.hh" #ifdef MODULE_SEARCH_WFST #include @@ -33,6 +34,13 @@ using namespace Search; Module_::Module_() { } +const Core::Choice Module_::searchTypeV2Choice( + "lexiconfree-timesync-beam-search", SearchTypeV2::LexiconfreeTimesyncBeamSearchType, + Core::Choice::endMark()); + +const Core::ParameterChoice Module_::searchTypeV2Param( + "type", &Module_::searchTypeV2Choice, "type of search", SearchTypeV2::LexiconfreeTimesyncBeamSearchType); + const Core::Choice choiceTreeBuilderType( "classic-hmm", static_cast(TreeBuilderType::classicHmm), "minimized-hmm", static_cast(TreeBuilderType::minimizedHmm), @@ -101,6 +109,19 @@ SearchAlgorithm* Module_::createRecognizer(SearchType type, const Core::Configur return recognizer; } +SearchAlgorithmV2* Module_::createSearchAlgorithmV2(const Core::Configuration& config) const { + SearchAlgorithmV2* searchAlgorithm = 0; + switch (searchTypeV2Param(config)) { + case LexiconfreeTimesyncBeamSearchType: + searchAlgorithm = new Search::LexiconfreeTimesyncBeamSearch(config); + break; + default: + Core::Application::us()->criticalError("Unknown search algorithm type: %d", searchTypeV2Param(config)); + break; + } + return searchAlgorithm; +} + LatticeHandler* Module_::createLatticeHandler(const Core::Configuration& c) const { LatticeHandler* handler = new LatticeHandler(c); #ifdef MODULE_SEARCH_WFST diff --git a/src/Search/Module.hh b/src/Search/Module.hh index 88079a68c..72256811f 100644 --- a/src/Search/Module.hh +++ b/src/Search/Module.hh @@ -17,6 +17,7 @@ #include #include +#include "SearchV2.hh" #include "TreeBuilder.hh" @@ -40,12 +41,21 @@ enum SearchType { ExpandingFsaSearchType }; +enum SearchTypeV2 { + LexiconfreeTimesyncBeamSearchType +}; + class Module_ { +private: + static const Core::Choice searchTypeV2Choice; + static const Core::ParameterChoice searchTypeV2Param; + public: Module_(); std::unique_ptr createTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true) const; SearchAlgorithm* createRecognizer(SearchType type, const Core::Configuration& config) const; + SearchAlgorithmV2* createSearchAlgorithmV2(const Core::Configuration& config) const; LatticeHandler* createLatticeHandler(const Core::Configuration& c) const; }; diff --git a/src/Search/SearchV2.hh b/src/Search/SearchV2.hh index 4548b29a0..74f9b5250 100644 --- a/src/Search/SearchV2.hh +++ b/src/Search/SearchV2.hh @@ -43,13 +43,11 @@ namespace Search { * 4. Pass audio features via `putFeature` or `putFeatures`. * (5. Call `decodeStep` or `decodeManySteps` to run the next search step(s) given the currently available features.) * (6. Optionally retrieve intermediate results via `getCurrentBestTraceback` or `getCurrentBestWordLattice`.) - * 7. Call `finishSegment` to signal that all features have been passed and the search doesn't need to wait for more. - * 8. Call `decodeMore` to finalize the search with all the segment features. - * 9. Retrieve the final result via `getCurrentBestTraceback` or `getCurrentBestWordLattice`. - * (10. Optionally log search statistics via `logStatistics`). - * 11. Call `reset` to clean up any buffered features, hypotheses, flags etc. from the previous segment and prepare the algorithm for the next one. - * (12. Optionally also reset search statistics via `resetStatistics`). - * 13. Continue again at step 3. + * 7. Call `finishSegment` to signal that all features have been passed and finalize the search with all the segment features. + * 8. Retrieve the final result via `getCurrentBestTraceback` or `getCurrentBestWordLattice`. + * 9. Call `reset` to clean up any buffered features, hypotheses, flags etc. from the previous segment and prepare the algorithm for the next one. + * (10. Optionally also reset search statistics via `resetStatistics`). + * 11. Continue again at step 3. */ class SearchAlgorithmV2 : public virtual Core::Component { public: diff --git a/src/Search/Traceback.cc b/src/Search/Traceback.cc index cc16ecc22..91ab49867 100644 --- a/src/Search/Traceback.cc +++ b/src/Search/Traceback.cc @@ -17,6 +17,9 @@ #include +#include +#include + namespace Search { void Traceback::write(std::ostream& os, Core::Ref phi) const { diff --git a/src/Speech/Makefile b/src/Speech/Makefile index 4b5ab3528..cdd707568 100644 --- a/src/Speech/Makefile +++ b/src/Speech/Makefile @@ -45,7 +45,7 @@ CHECK_O = $(OBJDIR)/check.o \ ../Am/libSprintAm.$(a) \ ../Mm/libSprintMm.$(a) \ ../Mc/libSprintMc.$(a) \ - ../Search/libSprintSearch.$(a) \ + $(subst src,..,$(LIBS_SEARCH)) \ ../Bliss/libSprintBliss.$(a) \ ../Flow/libSprintFlow.$(a) \ ../Fsa/libSprintFsa.$(a) \ @@ -137,12 +137,8 @@ ifdef MODULE_NN_SEQUENCE_TRAINING CHECK_O += ../Nn/libSprintNn.$(a) endif ifdef MODULE_SEARCH_WFST -CHECK_O += ../Search/Wfst/libSprintSearchWfst.$(a) CHECK_O += ../OpenFst/libSprintOpenFst.$(a) endif -ifdef MODULE_ADVANCED_TREE_SEARCH -CHECK_O += ../Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a) -endif ifdef MODULE_PYTHON CHECK_O += ../Python/libSprintPython.$(a) endif diff --git a/src/Test/Makefile b/src/Test/Makefile index b9f3f2fcf..b045a385b 100644 --- a/src/Test/Makefile +++ b/src/Test/Makefile @@ -61,7 +61,7 @@ UNIT_TEST_O = $(OBJDIR)/UnitTester.o $(TEST_O) \ ../Fsa/libSprintFsa.$(a) \ ../Core/libSprintCore.$(a)\ ../Speech/libSprintSpeech.$(a) \ - ../Search/libSprintSearch.$(a) \ + $(subst src,..,$(LIBS_SEARCH)) \ ../Lattice/libSprintLattice.$(a) \ ../Am/libSprintAm.$(a) \ ../Mm/libSprintMm.$(a) \ @@ -82,9 +82,6 @@ endif ifdef MODULE_FLF_EXT UNIT_TEST_O += ../Flf/FlfExt/libSprintFlfExt.$(a) endif -ifdef MODULE_ADVANCED_TREE_SEARCH -UNIT_TEST_O += ../Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a) -endif ifdef MODULE_PYTHON UNIT_TEST_O += ../Python/libSprintPython.$(a) endif @@ -97,9 +94,6 @@ endif ifdef MODULE_MATH_NR UNIT_TEST_O += ../Math/Nr/libSprintMathNr.$(a) endif -ifdef MODULE_SEARCH_WFST -UNIT_TEST_O += ../Search/Wfst/libSprintSearchWfst.$(a) -endif ifdef MODULE_OPENFST UNIT_TEST_O += ../OpenFst/libSprintOpenFst.$(a) endif diff --git a/src/Tools/Archiver/Makefile b/src/Tools/Archiver/Makefile index e9000e78a..1814b4d7d 100644 --- a/src/Tools/Archiver/Makefile +++ b/src/Tools/Archiver/Makefile @@ -11,8 +11,7 @@ TARGETS = archiver$(exe) ARCHIVER_O = $(OBJDIR)/Archiver.o \ ../../Speech/libSprintSpeech.$(a) \ - ../../Search/libSprintSearch.$(a) \ - ../../Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a) \ + $(subst src,../..,$(LIBS_SEARCH)) \ ../../Lattice/libSprintLattice.$(a) \ ../../Lm/libSprintLm.$(a) \ ../../Flf/libSprintFlf.$(a) \ @@ -36,9 +35,6 @@ endif ifdef MODULE_NN ARCHIVER_O += ../../Nn/libSprintNn.$(a) endif -ifdef MODULE_SEARCH_WFST -ARCHIVER_O += ../../Search/Wfst/libSprintSearchWfst.$(a) -endif ifdef MODULE_OPENFST ARCHIVER_O += ../../OpenFst/libSprintOpenFst.$(a) endif diff --git a/src/Tools/NnTrainer/Makefile b/src/Tools/NnTrainer/Makefile index fc71a4bd2..9bb85b9df 100644 --- a/src/Tools/NnTrainer/Makefile +++ b/src/Tools/NnTrainer/Makefile @@ -24,7 +24,7 @@ NN_TRAINER_O = $(OBJDIR)/NnTrainer.o \ ../../Mc/libSprintMc.$(a) \ ../../Mm/libSprintMm.$(a) \ ../../Nn/libSprintNn.$(a) \ - ../../Search/libSprintSearch.$(a) \ + $(subst src,../..,$(LIBS_SEARCH)) \ ../../Signal/libSprintSignal.$(a) \ ../../Speech/libSprintSpeech.$(a) @@ -37,12 +37,6 @@ endif ifdef MODULE_MATH_NR NN_TRAINER_O += ../../Math/Nr/libSprintMathNr.$(a) endif -ifdef MODULE_ADVANCED_TREE_SEARCH -NN_TRAINER_O += ../../Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a) -endif -ifdef MODULE_SEARCH_WFST -NN_TRAINER_O += ../../Search/Wfst/libSprintSearchWfst.$(a) ../../OpenFst/libSprintOpenFst.$(a) -endif ifdef MODULE_PYTHON NN_TRAINER_O += ../../Python/libSprintPython.$(a) endif