From c86855e0942f1e499d09130fd698f13762a995c7 Mon Sep 17 00:00:00 2001 From: Simon Berger <simon.berger95@gmail.com> Date: Wed, 5 Mar 2025 20:50:44 +0100 Subject: [PATCH 1/5] Add RecognizerNodeV2 --- src/Flf/Makefile | 1 + src/Flf/NodeRegistration.hh | 18 +++ src/Flf/RecognizerV2.cc | 231 +++++++++++++++++++++++++++++++++ src/Flf/RecognizerV2.hh | 74 +++++++++++ src/Speech/ModelCombination.cc | 54 +++++--- src/Speech/ModelCombination.hh | 14 +- 6 files changed, 369 insertions(+), 23 deletions(-) create mode 100644 src/Flf/RecognizerV2.cc create mode 100644 src/Flf/RecognizerV2.hh diff --git a/src/Flf/Makefile b/src/Flf/Makefile index 99afeb53..62db8d3a 100644 --- a/src/Flf/Makefile +++ b/src/Flf/Makefile @@ -59,6 +59,7 @@ LIBSPRINTFLF_O = \ $(OBJDIR)/Prune.o \ $(OBJDIR)/PushForwardRescoring.o \ $(OBJDIR)/Recognizer.o \ + $(OBJDIR)/RecognizerV2.o \ $(OBJDIR)/IncrementalRecognizer.o \ $(OBJDIR)/Rescore.o \ $(OBJDIR)/RescoreLm.o \ diff --git a/src/Flf/NodeRegistration.hh b/src/Flf/NodeRegistration.hh index 8beb4e47..2b9d03c8 100644 --- a/src/Flf/NodeRegistration.hh +++ b/src/Flf/NodeRegistration.hh @@ -51,6 +51,7 @@ #include "Prune.hh" #include "PushForwardRescoring.hh" #include "Recognizer.hh" +#include "RecognizerV2.hh" #include "Rescale.hh" #include "Rescore.hh" #include "RescoreLm.hh" @@ -2145,6 +2146,23 @@ void registerNodeCreators(NodeFactory* factory) { " 0:lattice", &createRecognizerNode)); + factory->add( + NodeCreator( + "recognizer-v2", + "Second version of RASR recognizer.\n" + "Output are lattices in Flf format.\n" + "Much more minimalistic than the first recognizer node\n" + "and works with a `SearchAlgorithmV2` instead of\n" + "`SearchAlgorithm`. Performs recognition of the input segments\n" + "and sends the result lattices as outputs.\n" + "[*.network.recognizer-v2]\n" + "type = recognizer-v2\n" + "input:\n" + " 0:bliss-speech-segment\n" + "output:\n" + " 0:lattice", + &createRecognizerNodeV2)); + factory->add( NodeCreator( "incremental-recognizer", diff --git a/src/Flf/RecognizerV2.cc b/src/Flf/RecognizerV2.cc new file mode 100644 index 00000000..f8a26a02 --- /dev/null +++ b/src/Flf/RecognizerV2.cc @@ -0,0 +1,231 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RecognizerV2.hh" +#include <Speech/ModelCombination.hh> +#include <chrono> +#include "Core/XmlStream.hh" +#include "LatticeHandler.hh" +#include "Module.hh" + +namespace Flf { + +NodeRef createRecognizerNodeV2(const std::string& name, const Core::Configuration& config) { + return NodeRef(new RecognizerNodeV2(name, config)); +} + +RecognizerNodeV2::RecognizerNodeV2(const std::string& name, const Core::Configuration& config) + : Node(name, config), + searchAlgorithm_(Search::Module::instance().createSearchAlgorithm(select("search-algorithm"))), + modelCombination_(config) { + Core::Configuration featureExtractionConfig(config, "feature-extraction"); + DataSourceRef dataSource = DataSourceRef(Speech::Module::instance().createDataSource(featureExtractionConfig)); + featureExtractor_ = SegmentwiseFeatureExtractorRef(new SegmentwiseFeatureExtractor(featureExtractionConfig, dataSource)); +} + +void RecognizerNodeV2::recognizeSegment(const Bliss::SpeechSegment* segment) { + if (!segment->orth().empty()) { + clog() << Core::XmlOpen("orth") + Core::XmlAttribute("source", "reference") + << segment->orth() + << Core::XmlClose("orth"); + } + + // Initialize recognizer and feature extractor + searchAlgorithm_->reset(); + searchAlgorithm_->enterSegment(); + + featureExtractor_->enterSegment(segment); + DataSourceRef dataSource = featureExtractor_->extractor(); + dataSource->initialize(const_cast<Bliss::SpeechSegment*>(segment)); + FeatureRef feature; + dataSource->getData(feature); + Time startTime = feature->timestamp().startTime(); + Time endTime; + + auto timerStart = std::chrono::steady_clock::now(); + + // Loop over features and perform recognition + do { + searchAlgorithm_->putFeature(*feature->mainStream()); + endTime = feature->timestamp().endTime(); + } while (dataSource->getData(feature)); + + searchAlgorithm_->finishSegment(); + searchAlgorithm_->decodeManySteps(); + dataSource->finalize(); + featureExtractor_->leaveSegment(segment); + + // Result processing and logging + auto traceback = searchAlgorithm_->getCurrentBestTraceback(); + + auto lattice = buildLattice(searchAlgorithm_->getCurrentBestWordLattice(), segment->name()); + resultBuffer_ = std::make_pair(lattice, SegmentRef(new Flf::Segment(segment))); + + Core::XmlWriter& os(clog()); + os << Core::XmlOpen("traceback"); + traceback->write(os, modelCombination_.lexicon()->phonemeInventory()); + os << Core::XmlClose("traceback"); + + os << Core::XmlOpen("orth") + Core::XmlAttribute("source", "recognized"); + for (auto const& tracebackItem : *traceback) { + if (tracebackItem.pronunciation and tracebackItem.pronunciation->lemma()) { + os << tracebackItem.pronunciation->lemma()->preferredOrthographicForm() << Core::XmlBlank(); + } + } + os << Core::XmlClose("orth"); + + auto timerEnd = std::chrono::steady_clock::now(); + double duration = std::chrono::duration<double, std::milli>(timerEnd - timerStart).count(); + double signalDuration = (endTime - startTime) * 1000.; // convert duration to ms + + clog() << Core::XmlOpen("flf-recognizer-time") + Core::XmlAttribute("unit", "milliseconds") << duration << Core::XmlClose("flf-recognizer-time"); + clog() << Core::XmlOpen("flf-recognizer-rtf") << (duration / signalDuration) << Core::XmlClose("flf-recognizer-rtf"); +} + +void RecognizerNodeV2::work() { + clog() << Core::XmlOpen("layer") + Core::XmlAttribute("name", name); + recognizeSegment(static_cast<const Bliss::SpeechSegment*>(requestData(0))); + clog() << Core::XmlClose("layer"); +} + +ConstLatticeRef RecognizerNodeV2::buildLattice(Core::Ref<const Search::LatticeAdaptor> latticeAdaptor, std::string segmentName) { + auto semiring = Semiring::create(Fsa::SemiringTypeTropical, 2); + semiring->setKey(0, "am"); + semiring->setScale(0, 1.0); + semiring->setKey(1, "lm"); + semiring->setScale(1, modelCombination_.languageModel()->scale()); + + auto sentenceEndLabel = Fsa::Epsilon; + const Bliss::Lemma* specialSentenceEndLemma = modelCombination_.lexicon()->specialLemma("sentence-end"); + if (specialSentenceEndLemma and specialSentenceEndLemma->nPronunciations() > 0) { + sentenceEndLabel = specialSentenceEndLemma->pronunciations().first->id(); + } + + Flf::LatticeHandler* handler = Flf::Module::instance().createLatticeHandler(config); + handler->setLexicon(Lexicon::us()); + if (latticeAdaptor->empty()) { + return ConstLatticeRef(); + } + ::Lattice::ConstWordLatticeRef lattice = latticeAdaptor->wordLattice(handler); + Core::Ref<const ::Lattice::WordBoundaries> boundaries = lattice->wordBoundaries(); + Fsa::ConstAutomatonRef amFsa = lattice->part(::Lattice::WordLattice::acousticFsa); + Fsa::ConstAutomatonRef lmFsa = lattice->part(::Lattice::WordLattice::lmFsa); + require_(Fsa::isAcyclic(amFsa) && Fsa::isAcyclic(lmFsa)); + + StaticBoundariesRef b = StaticBoundariesRef(new StaticBoundaries); + StaticLatticeRef s = StaticLatticeRef(new StaticLattice); + s->setType(Fsa::TypeAcceptor); + s->setProperties(Fsa::PropertyAcyclic | PropertyCrossWord, Fsa::PropertyAll); + s->setInputAlphabet(modelCombination_.lexicon()->lemmaPronunciationAlphabet()); + s->setSemiring(semiring); + s->setDescription(Core::form("recog(%s)", segmentName.c_str())); + s->setBoundaries(ConstBoundariesRef(b)); + s->setInitialStateId(0); + + Time timeOffset = (*boundaries)[amFsa->initialStateId()].time(); + + Fsa::Stack<Fsa::StateId> stateStack; + Core::Vector<Fsa::StateId> sidMap(amFsa->initialStateId() + 1, Fsa::InvalidStateId); + sidMap[amFsa->initialStateId()] = 0; + stateStack.push_back(amFsa->initialStateId()); + Fsa::StateId nextSid = 2; + Time finalTime = 0; + while (not stateStack.isEmpty()) { + Fsa::StateId sid = stateStack.pop(); + verify(sid < sidMap.size()); + const ::Lattice::WordBoundary& boundary((*boundaries)[sid]); + Fsa::ConstStateRef amSr = amFsa->getState(sid); + Fsa::ConstStateRef lmSr = lmFsa->getState(sid); + State* sp = new State(sidMap[sid]); + s->setState(sp); + b->set(sp->id(), Boundary(boundary.time() - timeOffset, + Boundary::Transit(boundary.transit().final, boundary.transit().initial))); + if (amSr->isFinal()) { + auto scores = semiring->create(); + scores->set(0, amSr->weight()); + scores->set(1, static_cast<Score>(lmSr->weight()) / semiring->scale(1)); + sp->newArc(1, scores, sentenceEndLabel); + finalTime = std::max(finalTime, boundary.time() - timeOffset); + } + for (Fsa::State::const_iterator am_a = amSr->begin(), lm_a = lmSr->begin(); (am_a != amSr->end()) && (lm_a != lmSr->end()); ++am_a, ++lm_a) { + sidMap.grow(am_a->target(), Fsa::InvalidStateId); + if (sidMap[am_a->target()] == Fsa::InvalidStateId) { + sidMap[am_a->target()] = nextSid++; + stateStack.push(am_a->target()); + } + Fsa::ConstStateRef targetAmSr = amFsa->getState(am_a->target()); + Fsa::ConstStateRef targetLmSr = amFsa->getState(lm_a->target()); + if (targetAmSr->isFinal() && targetLmSr->isFinal()) { + if (am_a->input() == Fsa::Epsilon) { + auto scores = semiring->create(); + scores->set(0, am_a->weight()); + scores->set(1, static_cast<Score>(lm_a->weight()) / semiring->scale(1)); + scores->add(0, Score(targetAmSr->weight())); + scores->add(1, Score(targetLmSr->weight()) / semiring->scale(1)); + sp->newArc(1, scores, sentenceEndLabel); + } + else { + auto scores = semiring->create(); + scores->set(0, am_a->weight()); + scores->set(1, static_cast<Score>(lm_a->weight()) / semiring->scale(1)); + sp->newArc(sidMap[am_a->target()], scores, am_a->input()); + } + } + else { + auto scores = semiring->create(); + scores->set(0, am_a->weight()); + scores->set(1, static_cast<Score>(lm_a->weight()) / semiring->scale(1)); + sp->newArc(sidMap[am_a->target()], scores, am_a->input()); + } + } + } + State* sp = new State(1); + sp->setFinal(semiring->clone(semiring->one())); + s->setState(sp); + b->set(sp->id(), Boundary(finalTime)); + return s; +} + +void RecognizerNodeV2::init(std::vector<std::string> const& arguments) { + modelCombination_.build(searchAlgorithm_->requiredModelCombination(), searchAlgorithm_->requiredAcousticModel(), Lexicon::us()); + searchAlgorithm_->setModelCombination(modelCombination_); + if (not connected(0)) { + criticalError("Speech segment at port 1 required"); + } +} + +void RecognizerNodeV2::sync() { + resultBuffer_.first.reset(); + resultBuffer_.second.reset(); +} + +void RecognizerNodeV2::finalize() { + searchAlgorithm_->reset(); +} + +ConstSegmentRef RecognizerNodeV2::sendSegment(RecognizerNodeV2::Port to) { + if (!resultBuffer_.second) { + work(); + } + return resultBuffer_.second; +} + +ConstLatticeRef RecognizerNodeV2::sendLattice(RecognizerNodeV2::Port to) { + if (!resultBuffer_.first) { + work(); + } + return resultBuffer_.first; +} + +} // namespace Flf diff --git a/src/Flf/RecognizerV2.hh b/src/Flf/RecognizerV2.hh new file mode 100644 index 00000000..ea6b04e2 --- /dev/null +++ b/src/Flf/RecognizerV2.hh @@ -0,0 +1,74 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef RECOGNIZER_V2_HH +#define RECOGNIZER_V2_HH + +#include <Flf/FlfCore/Lattice.hh> +#include <Search/Module.hh> +#include <Search/SearchV2.hh> +#include <Speech/Module.hh> +#include "Network.hh" +#include "SegmentwiseSpeechProcessor.hh" +#include "Speech/ModelCombination.hh" + +namespace Flf { + +NodeRef createRecognizerNodeV2(std::string const& name, Core::Configuration const& config); + +/* + * Node to run recognition on speech segments using a `SearchAlgorithmV2` internally. + */ +class RecognizerNodeV2 : public Node { +public: + RecognizerNodeV2(std::string const& name, Core::Configuration const& config); + + virtual ~RecognizerNodeV2() { + delete searchAlgorithm_; + } + + // Inherited methods + virtual void init(std::vector<std::string> const& arguments) override; + virtual void sync() override; + virtual void finalize() override; + + virtual ConstSegmentRef sendSegment(Port to) override; + virtual ConstLatticeRef sendLattice(Port to) override; + +private: + /* + * Perform recognition of `segment` using `searchAlgorithm_` and store the result in `resultBuffer_` + */ + void recognizeSegment(const Bliss::SpeechSegment* segment); + + /* + * Requests input segment and runs recognition on it + */ + void work(); + + /* + * Convert an output lattice from `searchAlgorithm_` to an Flf lattice + */ + ConstLatticeRef buildLattice(Core::Ref<const Search::LatticeAdaptor> latticeAdaptor, std::string segmentName); + + std::pair<ConstLatticeRef, ConstSegmentRef> resultBuffer_; + + Search::SearchAlgorithmV2* searchAlgorithm_; + Speech::ModelCombination modelCombination_; + SegmentwiseFeatureExtractorRef featureExtractor_; +}; + +} // namespace Flf + +#endif // RECOGNIZER_V2_HH diff --git a/src/Speech/ModelCombination.cc b/src/Speech/ModelCombination.cc index 075c9d8c..d417327e 100644 --- a/src/Speech/ModelCombination.cc +++ b/src/Speech/ModelCombination.cc @@ -16,6 +16,7 @@ #include <Am/Module.hh> #include <Lm/Module.hh> #include <Nn/Module.hh> +#include "Am/AcousticModel.hh" using namespace Speech; @@ -32,16 +33,43 @@ const Core::ParameterFloat ModelCombination::paramPronunciationScale( ModelCombination::ModelCombination(const Core::Configuration& c, Mode mode, - Am::AcousticModel::Mode acousticModelMode) + Am::AcousticModel::Mode acousticModelMode, + Bliss::LexiconRef lexicon) : Core::Component(c), Mc::Component(c), pronunciationScale_(0) { - setLexicon(Bliss::Lexicon::create(select("lexicon"))); - if (!lexicon_) - criticalError("failed to initialize the lexicon"); + setPronunciationScale(paramPronunciationScale(c)); + build(mode, acousticModelMode, lexicon); +} - /*! \todo Scalable lexicon not implemented yet */ +ModelCombination::ModelCombination(const Core::Configuration& c, + Bliss::LexiconRef lexicon, + Core::Ref<Am::AcousticModel> acousticModel, + Core::Ref<Lm::ScaledLanguageModel> languageModel) + : Core::Component(c), Mc::Component(c), pronunciationScale_(0) { setPronunciationScale(paramPronunciationScale(c)); + setLexicon(lexicon); + setAcousticModel(acousticModel); + setLanguageModel(languageModel); +} + +ModelCombination::~ModelCombination() {} + +void ModelCombination::build(Mode mode, + Am::AcousticModel::Mode acousticModelMode, + Bliss::LexiconRef lexicon) { + if (lexicon) { + setLexicon(lexicon); + log() << "Set lexicon in ModelCombination"; + } + else { + log() << "Create lexicon in ModelCombination"; + setLexicon(Bliss::Lexicon::create(select("lexicon"))); + } + + if (!lexicon_) { + criticalError("failed to initialize the lexicon"); + } if (mode & useAcousticModel) { setAcousticModel(Am::Module::instance().createAcousticModel( @@ -58,24 +86,12 @@ ModelCombination::ModelCombination(const Core::Configuration& c, if (mode & useLabelScorer) { setLabelScorer(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("label-scorer"))); - if (!labelScorer_) + if (!labelScorer_) { criticalError("failed to initialize label scorer"); + } } } -ModelCombination::ModelCombination(const Core::Configuration& c, - Bliss::LexiconRef lexicon, - Core::Ref<Am::AcousticModel> acousticModel, - Core::Ref<Lm::ScaledLanguageModel> languageModel) - : Core::Component(c), Mc::Component(c), pronunciationScale_(0) { - setPronunciationScale(paramPronunciationScale(c)); - setLexicon(lexicon); - setAcousticModel(acousticModel); - setLanguageModel(languageModel); -} - -ModelCombination::~ModelCombination() {} - void ModelCombination::setLexicon(Bliss::LexiconRef lexicon) { lexicon_ = lexicon; } diff --git a/src/Speech/ModelCombination.hh b/src/Speech/ModelCombination.hh index 27466749..bfbcdb68 100644 --- a/src/Speech/ModelCombination.hh +++ b/src/Speech/ModelCombination.hh @@ -23,7 +23,6 @@ #include <Mc/Component.hh> #include <Nn/LabelScorer/LabelScorer.hh> - namespace Speech { /** Combination of a lexicon, an acoustic model or label scorer, and a language model. @@ -65,11 +64,14 @@ protected: public: ModelCombination(const Core::Configuration&, Mode = complete, - Am::AcousticModel::Mode = Am::AcousticModel::complete); + Am::AcousticModel::Mode = Am::AcousticModel::complete, + Bliss::LexiconRef = Bliss::LexiconRef()); ModelCombination(const Core::Configuration&, Bliss::LexiconRef, Core::Ref<Am::AcousticModel>, Core::Ref<Lm::ScaledLanguageModel>); virtual ~ModelCombination(); + void build(Mode = complete, Am::AcousticModel::Mode = Am::AcousticModel::complete, Bliss::LexiconRef = Bliss::LexiconRef()); + void getDependencies(Core::DependencySet&) const; Bliss::LexiconRef lexicon() const { @@ -88,8 +90,12 @@ public: } void setLanguageModel(Core::Ref<Lm::ScaledLanguageModel>); - void setLabelScorer(Core::Ref<Nn::LabelScorer> ls) { labelScorer_ = ls; } - Core::Ref<Nn::LabelScorer> labelScorer() const { return labelScorer_; } + void setLabelScorer(Core::Ref<Nn::LabelScorer> ls) { + labelScorer_ = ls; + } + Core::Ref<Nn::LabelScorer> labelScorer() const { + return labelScorer_; + } }; typedef Core::Ref<ModelCombination> ModelCombinationRef; From 1dc47e6e6d5dd30fe3db23019d8f67f48cb19ba8 Mon Sep 17 00:00:00 2001 From: Simon Berger <simon.berger95@gmail.com> Date: Thu, 6 Mar 2025 15:24:42 +0100 Subject: [PATCH 2/5] Make `modelCombination_` a ref + some formatting --- src/Flf/RecognizerV2.cc | 18 ++++++++----- src/Flf/RecognizerV2.hh | 6 ++--- src/Speech/ModelCombination.cc | 49 +++++++++++++++++----------------- src/Speech/ModelCombination.hh | 15 +++++++---- 4 files changed, 49 insertions(+), 39 deletions(-) diff --git a/src/Flf/RecognizerV2.cc b/src/Flf/RecognizerV2.cc index f8a26a02..bcf7be5b 100644 --- a/src/Flf/RecognizerV2.cc +++ b/src/Flf/RecognizerV2.cc @@ -28,7 +28,7 @@ NodeRef createRecognizerNodeV2(const std::string& name, const Core::Configuratio RecognizerNodeV2::RecognizerNodeV2(const std::string& name, const Core::Configuration& config) : Node(name, config), searchAlgorithm_(Search::Module::instance().createSearchAlgorithm(select("search-algorithm"))), - modelCombination_(config) { + modelCombination_() { Core::Configuration featureExtractionConfig(config, "feature-extraction"); DataSourceRef dataSource = DataSourceRef(Speech::Module::instance().createDataSource(featureExtractionConfig)); featureExtractor_ = SegmentwiseFeatureExtractorRef(new SegmentwiseFeatureExtractor(featureExtractionConfig, dataSource)); @@ -74,7 +74,7 @@ void RecognizerNodeV2::recognizeSegment(const Bliss::SpeechSegment* segment) { Core::XmlWriter& os(clog()); os << Core::XmlOpen("traceback"); - traceback->write(os, modelCombination_.lexicon()->phonemeInventory()); + traceback->write(os, modelCombination_->lexicon()->phonemeInventory()); os << Core::XmlClose("traceback"); os << Core::XmlOpen("orth") + Core::XmlAttribute("source", "recognized"); @@ -104,10 +104,10 @@ ConstLatticeRef RecognizerNodeV2::buildLattice(Core::Ref<const Search::LatticeAd semiring->setKey(0, "am"); semiring->setScale(0, 1.0); semiring->setKey(1, "lm"); - semiring->setScale(1, modelCombination_.languageModel()->scale()); + semiring->setScale(1, modelCombination_->languageModel()->scale()); auto sentenceEndLabel = Fsa::Epsilon; - const Bliss::Lemma* specialSentenceEndLemma = modelCombination_.lexicon()->specialLemma("sentence-end"); + const Bliss::Lemma* specialSentenceEndLemma = modelCombination_->lexicon()->specialLemma("sentence-end"); if (specialSentenceEndLemma and specialSentenceEndLemma->nPronunciations() > 0) { sentenceEndLabel = specialSentenceEndLemma->pronunciations().first->id(); } @@ -127,7 +127,7 @@ ConstLatticeRef RecognizerNodeV2::buildLattice(Core::Ref<const Search::LatticeAd StaticLatticeRef s = StaticLatticeRef(new StaticLattice); s->setType(Fsa::TypeAcceptor); s->setProperties(Fsa::PropertyAcyclic | PropertyCrossWord, Fsa::PropertyAll); - s->setInputAlphabet(modelCombination_.lexicon()->lemmaPronunciationAlphabet()); + s->setInputAlphabet(modelCombination_->lexicon()->lemmaPronunciationAlphabet()); s->setSemiring(semiring); s->setDescription(Core::form("recog(%s)", segmentName.c_str())); s->setBoundaries(ConstBoundariesRef(b)); @@ -198,8 +198,12 @@ ConstLatticeRef RecognizerNodeV2::buildLattice(Core::Ref<const Search::LatticeAd } void RecognizerNodeV2::init(std::vector<std::string> const& arguments) { - modelCombination_.build(searchAlgorithm_->requiredModelCombination(), searchAlgorithm_->requiredAcousticModel(), Lexicon::us()); - searchAlgorithm_->setModelCombination(modelCombination_); + modelCombination_ = Core::ref(new Speech::ModelCombination( + config, + searchAlgorithm_->requiredModelCombination(), + searchAlgorithm_->requiredAcousticModel(), + Lexicon::us())); + searchAlgorithm_->setModelCombination(*modelCombination_); if (not connected(0)) { criticalError("Speech segment at port 1 required"); } diff --git a/src/Flf/RecognizerV2.hh b/src/Flf/RecognizerV2.hh index ea6b04e2..3898216f 100644 --- a/src/Flf/RecognizerV2.hh +++ b/src/Flf/RecognizerV2.hh @@ -64,9 +64,9 @@ private: std::pair<ConstLatticeRef, ConstSegmentRef> resultBuffer_; - Search::SearchAlgorithmV2* searchAlgorithm_; - Speech::ModelCombination modelCombination_; - SegmentwiseFeatureExtractorRef featureExtractor_; + Search::SearchAlgorithmV2* searchAlgorithm_; + Core::Ref<Speech::ModelCombination> modelCombination_; + SegmentwiseFeatureExtractorRef featureExtractor_; }; } // namespace Flf diff --git a/src/Speech/ModelCombination.cc b/src/Speech/ModelCombination.cc index d417327e..7bd385be 100644 --- a/src/Speech/ModelCombination.cc +++ b/src/Speech/ModelCombination.cc @@ -39,25 +39,7 @@ ModelCombination::ModelCombination(const Core::Configuration& c, Mc::Component(c), pronunciationScale_(0) { setPronunciationScale(paramPronunciationScale(c)); - build(mode, acousticModelMode, lexicon); -} - -ModelCombination::ModelCombination(const Core::Configuration& c, - Bliss::LexiconRef lexicon, - Core::Ref<Am::AcousticModel> acousticModel, - Core::Ref<Lm::ScaledLanguageModel> languageModel) - : Core::Component(c), Mc::Component(c), pronunciationScale_(0) { - setPronunciationScale(paramPronunciationScale(c)); - setLexicon(lexicon); - setAcousticModel(acousticModel); - setLanguageModel(languageModel); -} -ModelCombination::~ModelCombination() {} - -void ModelCombination::build(Mode mode, - Am::AcousticModel::Mode acousticModelMode, - Bliss::LexiconRef lexicon) { if (lexicon) { setLexicon(lexicon); log() << "Set lexicon in ModelCombination"; @@ -68,30 +50,45 @@ void ModelCombination::build(Mode mode, } if (!lexicon_) { - criticalError("failed to initialize the lexicon"); + criticalError("Failed to initialize the lexicon"); } if (mode & useAcousticModel) { setAcousticModel(Am::Module::instance().createAcousticModel( select("acoustic-model"), lexicon_, acousticModelMode)); - if (!acousticModel_) - criticalError("failed to initialize the acoustic model"); + if (!acousticModel_) { + criticalError("Failed to initialize the acoustic model"); + } } if (mode & useLanguageModel) { setLanguageModel(Lm::Module::instance().createScaledLanguageModel(select("lm"), lexicon_)); - if (!languageModel_) - criticalError("failed to initialize language model"); + if (!languageModel_) { + criticalError("Failed to initialize language model"); + } } if (mode & useLabelScorer) { setLabelScorer(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("label-scorer"))); if (!labelScorer_) { - criticalError("failed to initialize label scorer"); + criticalError("Failed to initialize label scorer"); } } } +ModelCombination::ModelCombination(const Core::Configuration& c, + Bliss::LexiconRef lexicon, + Core::Ref<Am::AcousticModel> acousticModel, + Core::Ref<Lm::ScaledLanguageModel> languageModel) + : Core::Component(c), Mc::Component(c), pronunciationScale_(0) { + setPronunciationScale(paramPronunciationScale(c)); + setLexicon(lexicon); + setAcousticModel(acousticModel); + setLanguageModel(languageModel); +} + +ModelCombination::~ModelCombination() {} + void ModelCombination::setLexicon(Bliss::LexiconRef lexicon) { lexicon_ = lexicon; } @@ -108,6 +105,10 @@ void ModelCombination::setLanguageModel(Core::Ref<Lm::ScaledLanguageModel> langu languageModel_->setParentScale(scale()); } +void ModelCombination::setLabelScorer(Core::Ref<Nn::LabelScorer> ls) { + labelScorer_ = ls; +} + void ModelCombination::distributeScaleUpdate(const Mc::ScaleUpdate& scaleUpdate) { if (lexicon_) { Mm::Score scale; diff --git a/src/Speech/ModelCombination.hh b/src/Speech/ModelCombination.hh index bfbcdb68..d94c9d1f 100644 --- a/src/Speech/ModelCombination.hh +++ b/src/Speech/ModelCombination.hh @@ -77,25 +77,30 @@ public: Bliss::LexiconRef lexicon() const { return lexicon_; } - void setLexicon(Bliss::LexiconRef); + + void setLexicon(Bliss::LexiconRef); + Mm::Score pronunciationScale() const { return pronunciationScale_ * scale(); } + Core::Ref<Am::AcousticModel> acousticModel() const { return acousticModel_; } - void setAcousticModel(Core::Ref<Am::AcousticModel>); + + void setAcousticModel(Core::Ref<Am::AcousticModel>); + Core::Ref<Lm::ScaledLanguageModel> languageModel() const { return languageModel_; } + void setLanguageModel(Core::Ref<Lm::ScaledLanguageModel>); - void setLabelScorer(Core::Ref<Nn::LabelScorer> ls) { - labelScorer_ = ls; - } Core::Ref<Nn::LabelScorer> labelScorer() const { return labelScorer_; } + + void setLabelScorer(Core::Ref<Nn::LabelScorer> ls); }; typedef Core::Ref<ModelCombination> ModelCombinationRef; From 600999b06b8fb524a4cc0494e4d5bc89f6655397 Mon Sep 17 00:00:00 2001 From: Simon Berger <simon.berger95@gmail.com> Date: Thu, 6 Mar 2025 18:03:31 +0100 Subject: [PATCH 3/5] Better readable lattice building function --- src/Flf/RecognizerV2.cc | 123 +++++++++++++++++++++------------------- 1 file changed, 65 insertions(+), 58 deletions(-) diff --git a/src/Flf/RecognizerV2.cc b/src/Flf/RecognizerV2.cc index bcf7be5b..35121a11 100644 --- a/src/Flf/RecognizerV2.cc +++ b/src/Flf/RecognizerV2.cc @@ -13,9 +13,10 @@ * limitations under the License. */ #include "RecognizerV2.hh" +#include <Core/XmlStream.hh> +#include <Fsa/Types.hh> #include <Speech/ModelCombination.hh> #include <chrono> -#include "Core/XmlStream.hh" #include "LatticeHandler.hh" #include "Module.hh" @@ -100,11 +101,13 @@ void RecognizerNodeV2::work() { } ConstLatticeRef RecognizerNodeV2::buildLattice(Core::Ref<const Search::LatticeAdaptor> latticeAdaptor, std::string segmentName) { + auto lmScale = modelCombination_->languageModel()->scale(); + auto semiring = Semiring::create(Fsa::SemiringTypeTropical, 2); semiring->setKey(0, "am"); semiring->setScale(0, 1.0); semiring->setKey(1, "lm"); - semiring->setScale(1, modelCombination_->languageModel()->scale()); + semiring->setScale(1, lmScale); auto sentenceEndLabel = Fsa::Epsilon; const Bliss::Lemma* specialSentenceEndLemma = modelCombination_->lexicon()->specialLemma("sentence-end"); @@ -123,78 +126,82 @@ ConstLatticeRef RecognizerNodeV2::buildLattice(Core::Ref<const Search::LatticeAd Fsa::ConstAutomatonRef lmFsa = lattice->part(::Lattice::WordLattice::lmFsa); require_(Fsa::isAcyclic(amFsa) && Fsa::isAcyclic(lmFsa)); - StaticBoundariesRef b = StaticBoundariesRef(new StaticBoundaries); - StaticLatticeRef s = StaticLatticeRef(new StaticLattice); - s->setType(Fsa::TypeAcceptor); - s->setProperties(Fsa::PropertyAcyclic | PropertyCrossWord, Fsa::PropertyAll); - s->setInputAlphabet(modelCombination_->lexicon()->lemmaPronunciationAlphabet()); - s->setSemiring(semiring); - s->setDescription(Core::form("recog(%s)", segmentName.c_str())); - s->setBoundaries(ConstBoundariesRef(b)); - s->setInitialStateId(0); + StaticBoundariesRef flfBoundaries = StaticBoundariesRef(new StaticBoundaries); + StaticLatticeRef flfLattice = StaticLatticeRef(new StaticLattice); + flfLattice->setType(Fsa::TypeAcceptor); + flfLattice->setProperties(Fsa::PropertyAcyclic | PropertyCrossWord, Fsa::PropertyAll); + flfLattice->setInputAlphabet(modelCombination_->lexicon()->lemmaPronunciationAlphabet()); + flfLattice->setSemiring(semiring); + flfLattice->setDescription(Core::form("recog(%s)", segmentName.c_str())); + flfLattice->setBoundaries(ConstBoundariesRef(flfBoundaries)); + flfLattice->setInitialStateId(0); Time timeOffset = (*boundaries)[amFsa->initialStateId()].time(); Fsa::Stack<Fsa::StateId> stateStack; - Core::Vector<Fsa::StateId> sidMap(amFsa->initialStateId() + 1, Fsa::InvalidStateId); - sidMap[amFsa->initialStateId()] = 0; + Core::Vector<Fsa::StateId> stateIdMap(amFsa->initialStateId() + 1, Fsa::InvalidStateId); + stateIdMap[amFsa->initialStateId()] = 0; stateStack.push_back(amFsa->initialStateId()); - Fsa::StateId nextSid = 2; - Time finalTime = 0; + Fsa::StateId nextStateId = 2; + Time finalTime = 0; while (not stateStack.isEmpty()) { - Fsa::StateId sid = stateStack.pop(); - verify(sid < sidMap.size()); - const ::Lattice::WordBoundary& boundary((*boundaries)[sid]); - Fsa::ConstStateRef amSr = amFsa->getState(sid); - Fsa::ConstStateRef lmSr = lmFsa->getState(sid); - State* sp = new State(sidMap[sid]); - s->setState(sp); - b->set(sp->id(), Boundary(boundary.time() - timeOffset, - Boundary::Transit(boundary.transit().final, boundary.transit().initial))); - if (amSr->isFinal()) { + Fsa::StateId stateId = stateStack.pop(); + verify(stateId < stateIdMap.size()); + const ::Lattice::WordBoundary& boundary((*boundaries)[stateId]); + Fsa::ConstStateRef amFsaState = amFsa->getState(stateId); + Fsa::ConstStateRef lmFsaState = lmFsa->getState(stateId); + State* flfState = new State(stateIdMap[stateId]); + flfLattice->setState(flfState); + flfBoundaries->set(flfState->id(), Boundary(boundary.time() - timeOffset, + Boundary::Transit(boundary.transit().final, boundary.transit().initial))); + if (amFsaState->isFinal()) { auto scores = semiring->create(); - scores->set(0, amSr->weight()); - scores->set(1, static_cast<Score>(lmSr->weight()) / semiring->scale(1)); - sp->newArc(1, scores, sentenceEndLabel); + scores->set(0, amFsaState->weight()); + if (lmScale) { + scores->set(1, static_cast<Score>(lmFsaState->weight()) / lmScale); + } + else { + scores->set(1, 0.0); + } + flfState->newArc(1, scores, sentenceEndLabel); finalTime = std::max(finalTime, boundary.time() - timeOffset); } - for (Fsa::State::const_iterator am_a = amSr->begin(), lm_a = lmSr->begin(); (am_a != amSr->end()) && (lm_a != lmSr->end()); ++am_a, ++lm_a) { - sidMap.grow(am_a->target(), Fsa::InvalidStateId); - if (sidMap[am_a->target()] == Fsa::InvalidStateId) { - sidMap[am_a->target()] = nextSid++; - stateStack.push(am_a->target()); + for (Fsa::State::const_iterator amArc = amFsaState->begin(), lmArc = lmFsaState->begin(); (amArc != amFsaState->end()) && (lmArc != lmFsaState->end()); ++amArc, ++lmArc) { + stateIdMap.grow(amArc->target(), Fsa::InvalidStateId); + if (stateIdMap[amArc->target()] == Fsa::InvalidStateId) { + stateIdMap[amArc->target()] = nextStateId++; + stateStack.push(amArc->target()); } - Fsa::ConstStateRef targetAmSr = amFsa->getState(am_a->target()); - Fsa::ConstStateRef targetLmSr = amFsa->getState(lm_a->target()); - if (targetAmSr->isFinal() && targetLmSr->isFinal()) { - if (am_a->input() == Fsa::Epsilon) { - auto scores = semiring->create(); - scores->set(0, am_a->weight()); - scores->set(1, static_cast<Score>(lm_a->weight()) / semiring->scale(1)); - scores->add(0, Score(targetAmSr->weight())); - scores->add(1, Score(targetLmSr->weight()) / semiring->scale(1)); - sp->newArc(1, scores, sentenceEndLabel); - } - else { - auto scores = semiring->create(); - scores->set(0, am_a->weight()); - scores->set(1, static_cast<Score>(lm_a->weight()) / semiring->scale(1)); - sp->newArc(sidMap[am_a->target()], scores, am_a->input()); + Fsa::ConstStateRef targetAmState = amFsa->getState(amArc->target()); + Fsa::ConstStateRef targetLmState = amFsa->getState(lmArc->target()); + + auto scores = semiring->create(); + scores->set(0, amArc->weight()); + + if (lmScale) { + scores->set(1, static_cast<Score>(lmArc->weight()) / lmScale); + } + else { + scores->set(1, 0); + } + + if (targetAmState->isFinal() and targetLmState->isFinal() and amArc->input() == Fsa::Epsilon) { + scores->add(0, Score(targetAmState->weight())); + if (lmScale) { + scores->add(1, Score(targetLmState->weight()) / lmScale); } + flfState->newArc(1, scores, sentenceEndLabel); } else { - auto scores = semiring->create(); - scores->set(0, am_a->weight()); - scores->set(1, static_cast<Score>(lm_a->weight()) / semiring->scale(1)); - sp->newArc(sidMap[am_a->target()], scores, am_a->input()); + flfState->newArc(stateIdMap[amArc->target()], scores, amArc->input()); } } } - State* sp = new State(1); - sp->setFinal(semiring->clone(semiring->one())); - s->setState(sp); - b->set(sp->id(), Boundary(finalTime)); - return s; + State* finalState = new State(1); + finalState->setFinal(semiring->clone(semiring->one())); + flfLattice->setState(finalState); + flfBoundaries->set(finalState->id(), Boundary(finalTime)); + return flfLattice; } void RecognizerNodeV2::init(std::vector<std::string> const& arguments) { From e309602a9bb037ff2b57c3aba0dca186384e2007 Mon Sep 17 00:00:00 2001 From: Simon Berger <simon.berger95@gmail.com> Date: Wed, 26 Mar 2025 19:06:38 +0100 Subject: [PATCH 4/5] Fix error string --- src/Flf/RecognizerV2.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Flf/RecognizerV2.cc b/src/Flf/RecognizerV2.cc index 35121a11..3740b488 100644 --- a/src/Flf/RecognizerV2.cc +++ b/src/Flf/RecognizerV2.cc @@ -212,7 +212,7 @@ void RecognizerNodeV2::init(std::vector<std::string> const& arguments) { Lexicon::us())); searchAlgorithm_->setModelCombination(*modelCombination_); if (not connected(0)) { - criticalError("Speech segment at port 1 required"); + criticalError("Speech segment at port 0 required"); } } From 0121868ae9f20b791ec8138568cb366254fda609 Mon Sep 17 00:00:00 2001 From: Simon Berger <simon.berger95@gmail.com> Date: Mon, 31 Mar 2025 11:54:58 +0200 Subject: [PATCH 5/5] Remove additional `decodeManySteps()` call --- src/Flf/RecognizerV2.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Flf/RecognizerV2.cc b/src/Flf/RecognizerV2.cc index 3740b488..bf6c5c61 100644 --- a/src/Flf/RecognizerV2.cc +++ b/src/Flf/RecognizerV2.cc @@ -63,7 +63,6 @@ void RecognizerNodeV2::recognizeSegment(const Bliss::SpeechSegment* segment) { } while (dataSource->getData(feature)); searchAlgorithm_->finishSegment(); - searchAlgorithm_->decodeManySteps(); dataSource->finalize(); featureExtractor_->leaveSegment(segment);