diff --git a/src/Flf/Makefile b/src/Flf/Makefile index 99afeb53..62db8d3a 100644 --- a/src/Flf/Makefile +++ b/src/Flf/Makefile @@ -59,6 +59,7 @@ LIBSPRINTFLF_O = \ $(OBJDIR)/Prune.o \ $(OBJDIR)/PushForwardRescoring.o \ $(OBJDIR)/Recognizer.o \ + $(OBJDIR)/RecognizerV2.o \ $(OBJDIR)/IncrementalRecognizer.o \ $(OBJDIR)/Rescore.o \ $(OBJDIR)/RescoreLm.o \ diff --git a/src/Flf/NodeRegistration.hh b/src/Flf/NodeRegistration.hh index 8beb4e47..2b9d03c8 100644 --- a/src/Flf/NodeRegistration.hh +++ b/src/Flf/NodeRegistration.hh @@ -51,6 +51,7 @@ #include "Prune.hh" #include "PushForwardRescoring.hh" #include "Recognizer.hh" +#include "RecognizerV2.hh" #include "Rescale.hh" #include "Rescore.hh" #include "RescoreLm.hh" @@ -2145,6 +2146,23 @@ void registerNodeCreators(NodeFactory* factory) { " 0:lattice", &createRecognizerNode)); + factory->add( + NodeCreator( + "recognizer-v2", + "Second version of RASR recognizer.\n" + "Output are lattices in Flf format.\n" + "Much more minimalistic than the first recognizer node\n" + "and works with a `SearchAlgorithmV2` instead of\n" + "`SearchAlgorithm`. Performs recognition of the input segments\n" + "and sends the result lattices as outputs.\n" + "[*.network.recognizer-v2]\n" + "type = recognizer-v2\n" + "input:\n" + " 0:bliss-speech-segment\n" + "output:\n" + " 0:lattice", + &createRecognizerNodeV2)); + factory->add( NodeCreator( "incremental-recognizer", diff --git a/src/Flf/RecognizerV2.cc b/src/Flf/RecognizerV2.cc new file mode 100644 index 00000000..f3233690 --- /dev/null +++ b/src/Flf/RecognizerV2.cc @@ -0,0 +1,245 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RecognizerV2.hh" +#include +#include +#include +#include +#include "LatticeHandler.hh" +#include "Module.hh" + +namespace Flf { + +NodeRef createRecognizerNodeV2(const std::string& name, const Core::Configuration& config) { + return NodeRef(new RecognizerNodeV2(name, config)); +} + +RecognizerNodeV2::RecognizerNodeV2(const std::string& name, const Core::Configuration& config) + : Node(name, config), + latticeResultBuffer_(), + segmentResultBuffer_(), + searchAlgorithm_(Search::Module::instance().createSearchAlgorithmV2(select("search-algorithm"))), + modelCombination_() { + Core::Configuration featureExtractionConfig(config, "feature-extraction"); + DataSourceRef dataSource = DataSourceRef(Speech::Module::instance().createDataSource(featureExtractionConfig)); + featureExtractor_ = SegmentwiseFeatureExtractorRef(new SegmentwiseFeatureExtractor(featureExtractionConfig, dataSource)); +} + +void RecognizerNodeV2::recognizeSegment(const Bliss::SpeechSegment* segment) { + if (!segment->orth().empty()) { + clog() << Core::XmlOpen("orth") + Core::XmlAttribute("source", "reference") + << segment->orth() + << Core::XmlClose("orth"); + } + + // Initialize recognizer and feature extractor + searchAlgorithm_->reset(); + searchAlgorithm_->enterSegment(); + + featureExtractor_->enterSegment(segment); + DataSourceRef dataSource = featureExtractor_->extractor(); + dataSource->initialize(const_cast(segment)); + + auto timerStart = std::chrono::steady_clock::now(); + + FeatureRef feature; + dataSource->getData(feature); + Time startTimestamp = feature->timestamp().startTime(); + Time endTimestamp; + + // Loop over features and perform recognition + do { + searchAlgorithm_->putFeature(*feature->mainStream()); + endTimestamp = feature->timestamp().endTime(); + } while (dataSource->getData(feature)); + + searchAlgorithm_->finishSegment(); + dataSource->finalize(); + featureExtractor_->leaveSegment(segment); + + // Result processing and logging + auto traceback = searchAlgorithm_->getCurrentBestTraceback(); + + auto lattice = buildLattice(searchAlgorithm_->getCurrentBestWordLattice(), segment->name()); + latticeResultBuffer_ = lattice; + segmentResultBuffer_ = SegmentRef(new Flf::Segment(segment)); + + Core::XmlWriter& os(clog()); + os << Core::XmlOpen("traceback"); + traceback->write(os, modelCombination_->lexicon()->phonemeInventory()); + os << Core::XmlClose("traceback"); + + os << Core::XmlOpen("orth") + Core::XmlAttribute("source", "recognized"); + for (auto const& tracebackItem : *traceback) { + if (tracebackItem.pronunciation and tracebackItem.pronunciation->lemma()) { + os << tracebackItem.pronunciation->lemma()->preferredOrthographicForm() << Core::XmlBlank(); + } + } + os << Core::XmlClose("orth"); + + auto timerEnd = std::chrono::steady_clock::now(); + double duration = std::chrono::duration(timerEnd - timerStart).count(); + double signalDuration = (endTimestamp - startTimestamp) * 1000.; // convert duration to ms + + clog() << Core::XmlOpen("flf-recognizer-time") + Core::XmlAttribute("unit", "milliseconds") << duration << Core::XmlClose("flf-recognizer-time"); + clog() << Core::XmlOpen("flf-recognizer-rtf") << (duration / signalDuration) << Core::XmlClose("flf-recognizer-rtf"); +} + +void RecognizerNodeV2::work() { + clog() << Core::XmlOpen("layer") + Core::XmlAttribute("name", name); + recognizeSegment(static_cast(requestData(0))); + clog() << Core::XmlClose("layer"); +} + +ConstLatticeRef RecognizerNodeV2::buildLattice(Core::Ref latticeAdaptor, std::string segmentName) { + auto lmScale = modelCombination_->languageModel()->scale(); + + auto semiring = Semiring::create(Fsa::SemiringTypeTropical, 2); + semiring->setKey(0, "am"); + semiring->setScale(0, 1.0); + semiring->setKey(1, "lm"); + semiring->setScale(1, lmScale); + + auto sentenceEndLabel = Fsa::Epsilon; + const Bliss::Lemma* specialSentenceEndLemma = modelCombination_->lexicon()->specialLemma("sentence-end"); + if (specialSentenceEndLemma and specialSentenceEndLemma->nPronunciations() > 0) { + sentenceEndLabel = specialSentenceEndLemma->pronunciations().first->id(); + } + + Flf::LatticeHandler* handler = Flf::Module::instance().createLatticeHandler(config); + handler->setLexicon(Lexicon::us()); + if (latticeAdaptor->empty()) { + return ConstLatticeRef(); + } + ::Lattice::ConstWordLatticeRef lattice = latticeAdaptor->wordLattice(handler); + Core::Ref boundaries = lattice->wordBoundaries(); + Fsa::ConstAutomatonRef amFsa = lattice->part(::Lattice::WordLattice::acousticFsa); + Fsa::ConstAutomatonRef lmFsa = lattice->part(::Lattice::WordLattice::lmFsa); + require_(Fsa::isAcyclic(amFsa) && Fsa::isAcyclic(lmFsa)); + + StaticBoundariesRef flfBoundaries = StaticBoundariesRef(new StaticBoundaries); + StaticLatticeRef flfLattice = StaticLatticeRef(new StaticLattice); + flfLattice->setType(Fsa::TypeAcceptor); + flfLattice->setProperties(Fsa::PropertyAcyclic | PropertyCrossWord, Fsa::PropertyAll); + flfLattice->setInputAlphabet(modelCombination_->lexicon()->lemmaPronunciationAlphabet()); + flfLattice->setSemiring(semiring); + flfLattice->setDescription(Core::form("recog(%s)", segmentName.c_str())); + flfLattice->setBoundaries(ConstBoundariesRef(flfBoundaries)); + flfLattice->setInitialStateId(0); + + Time timeOffset = (*boundaries)[amFsa->initialStateId()].time(); + + Fsa::Stack stateStack; + Core::Vector stateIdMap(amFsa->initialStateId() + 1, Fsa::InvalidStateId); + stateIdMap[amFsa->initialStateId()] = 0; + stateStack.push_back(amFsa->initialStateId()); + Fsa::StateId nextStateId = 2; + Time finalTime = 0; + while (not stateStack.isEmpty()) { + Fsa::StateId stateId = stateStack.pop(); + verify(stateId < stateIdMap.size()); + const ::Lattice::WordBoundary& boundary((*boundaries)[stateId]); + Fsa::ConstStateRef amFsaState = amFsa->getState(stateId); + Fsa::ConstStateRef lmFsaState = lmFsa->getState(stateId); + State* flfState = new State(stateIdMap[stateId]); + flfLattice->setState(flfState); + flfBoundaries->set(flfState->id(), Boundary(boundary.time() - timeOffset, + Boundary::Transit(boundary.transit().final, boundary.transit().initial))); + if (amFsaState->isFinal()) { + auto scores = semiring->create(); + scores->set(0, amFsaState->weight()); + if (lmScale) { + scores->set(1, static_cast(lmFsaState->weight()) / lmScale); + } + else { + scores->set(1, 0.0); + } + flfState->newArc(1, scores, sentenceEndLabel); + finalTime = std::max(finalTime, boundary.time() - timeOffset); + } + for (Fsa::State::const_iterator amArc = amFsaState->begin(), lmArc = lmFsaState->begin(); (amArc != amFsaState->end()) && (lmArc != lmFsaState->end()); ++amArc, ++lmArc) { + stateIdMap.grow(amArc->target(), Fsa::InvalidStateId); + if (stateIdMap[amArc->target()] == Fsa::InvalidStateId) { + stateIdMap[amArc->target()] = nextStateId++; + stateStack.push(amArc->target()); + } + Fsa::ConstStateRef targetAmState = amFsa->getState(amArc->target()); + Fsa::ConstStateRef targetLmState = amFsa->getState(lmArc->target()); + + auto scores = semiring->create(); + scores->set(0, amArc->weight()); + + if (lmScale) { + scores->set(1, static_cast(lmArc->weight()) / lmScale); + } + else { + scores->set(1, 0); + } + + if (targetAmState->isFinal() and targetLmState->isFinal() and amArc->input() == Fsa::Epsilon) { + scores->add(0, Score(targetAmState->weight())); + if (lmScale) { + scores->add(1, Score(targetLmState->weight()) / lmScale); + } + flfState->newArc(1, scores, sentenceEndLabel); + } + else { + flfState->newArc(stateIdMap[amArc->target()], scores, amArc->input()); + } + } + } + State* finalState = new State(1); + finalState->setFinal(semiring->clone(semiring->one())); + flfLattice->setState(finalState); + flfBoundaries->set(finalState->id(), Boundary(finalTime)); + return flfLattice; +} + +void RecognizerNodeV2::init(std::vector const& arguments) { + modelCombination_ = Core::ref(new Speech::ModelCombination( + config, + searchAlgorithm_->requiredModelCombination(), + searchAlgorithm_->requiredAcousticModel(), + Lexicon::us())); + searchAlgorithm_->setModelCombination(*modelCombination_); + if (not connected(0)) { + criticalError("Speech segment at port 0 required"); + } +} + +void RecognizerNodeV2::sync() { + latticeResultBuffer_.reset(); + segmentResultBuffer_.reset(); +} + +void RecognizerNodeV2::finalize() { + searchAlgorithm_->reset(); +} + +ConstSegmentRef RecognizerNodeV2::sendSegment(RecognizerNodeV2::Port to) { + if (not segmentResultBuffer_) { + work(); + } + return segmentResultBuffer_; +} + +ConstLatticeRef RecognizerNodeV2::sendLattice(RecognizerNodeV2::Port to) { + if (not latticeResultBuffer_) { + work(); + } + return latticeResultBuffer_; +} + +} // namespace Flf diff --git a/src/Flf/RecognizerV2.hh b/src/Flf/RecognizerV2.hh new file mode 100644 index 00000000..ce6e38f9 --- /dev/null +++ b/src/Flf/RecognizerV2.hh @@ -0,0 +1,71 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef RECOGNIZER_V2_HH +#define RECOGNIZER_V2_HH + +#include +#include +#include +#include +#include "Network.hh" +#include "SegmentwiseSpeechProcessor.hh" +#include "Speech/ModelCombination.hh" + +namespace Flf { + +NodeRef createRecognizerNodeV2(std::string const& name, Core::Configuration const& config); + +/* + * Node to run recognition on speech segments using a `SearchAlgorithmV2` internally. + */ +class RecognizerNodeV2 : public Node { +public: + RecognizerNodeV2(std::string const& name, Core::Configuration const& config); + + // Inherited methods + virtual void init(std::vector const& arguments) override; + virtual void sync() override; + virtual void finalize() override; + + virtual ConstSegmentRef sendSegment(Port to) override; + virtual ConstLatticeRef sendLattice(Port to) override; + +private: + /* + * Perform recognition of `segment` using `searchAlgorithm_` and store the result in `resultBuffer_` + */ + void recognizeSegment(const Bliss::SpeechSegment* segment); + + /* + * Requests input segment and runs recognition on it + */ + void work(); + + /* + * Convert an output lattice from `searchAlgorithm_` to an Flf lattice + */ + ConstLatticeRef buildLattice(Core::Ref latticeAdaptor, std::string segmentName); + + ConstLatticeRef latticeResultBuffer_; + ConstSegmentRef segmentResultBuffer_; + + std::unique_ptr searchAlgorithm_; + Core::Ref modelCombination_; + SegmentwiseFeatureExtractorRef featureExtractor_; +}; + +} // namespace Flf + +#endif // RECOGNIZER_V2_HH diff --git a/src/Speech/ModelCombination.cc b/src/Speech/ModelCombination.cc index 075c9d8c..7bd385be 100644 --- a/src/Speech/ModelCombination.cc +++ b/src/Speech/ModelCombination.cc @@ -16,6 +16,7 @@ #include #include #include +#include "Am/AcousticModel.hh" using namespace Speech; @@ -32,34 +33,46 @@ const Core::ParameterFloat ModelCombination::paramPronunciationScale( ModelCombination::ModelCombination(const Core::Configuration& c, Mode mode, - Am::AcousticModel::Mode acousticModelMode) + Am::AcousticModel::Mode acousticModelMode, + Bliss::LexiconRef lexicon) : Core::Component(c), Mc::Component(c), pronunciationScale_(0) { - setLexicon(Bliss::Lexicon::create(select("lexicon"))); - if (!lexicon_) - criticalError("failed to initialize the lexicon"); - - /*! \todo Scalable lexicon not implemented yet */ setPronunciationScale(paramPronunciationScale(c)); + if (lexicon) { + setLexicon(lexicon); + log() << "Set lexicon in ModelCombination"; + } + else { + log() << "Create lexicon in ModelCombination"; + setLexicon(Bliss::Lexicon::create(select("lexicon"))); + } + + if (!lexicon_) { + criticalError("Failed to initialize the lexicon"); + } + if (mode & useAcousticModel) { setAcousticModel(Am::Module::instance().createAcousticModel( select("acoustic-model"), lexicon_, acousticModelMode)); - if (!acousticModel_) - criticalError("failed to initialize the acoustic model"); + if (!acousticModel_) { + criticalError("Failed to initialize the acoustic model"); + } } if (mode & useLanguageModel) { setLanguageModel(Lm::Module::instance().createScaledLanguageModel(select("lm"), lexicon_)); - if (!languageModel_) - criticalError("failed to initialize language model"); + if (!languageModel_) { + criticalError("Failed to initialize language model"); + } } if (mode & useLabelScorer) { setLabelScorer(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("label-scorer"))); - if (!labelScorer_) - criticalError("failed to initialize label scorer"); + if (!labelScorer_) { + criticalError("Failed to initialize label scorer"); + } } } @@ -92,6 +105,10 @@ void ModelCombination::setLanguageModel(Core::Ref langu languageModel_->setParentScale(scale()); } +void ModelCombination::setLabelScorer(Core::Ref ls) { + labelScorer_ = ls; +} + void ModelCombination::distributeScaleUpdate(const Mc::ScaleUpdate& scaleUpdate) { if (lexicon_) { Mm::Score scale; diff --git a/src/Speech/ModelCombination.hh b/src/Speech/ModelCombination.hh index 6b067afa..704b4297 100644 --- a/src/Speech/ModelCombination.hh +++ b/src/Speech/ModelCombination.hh @@ -64,32 +64,40 @@ protected: public: ModelCombination(const Core::Configuration&, Mode = complete, - Am::AcousticModel::Mode = Am::AcousticModel::complete); + Am::AcousticModel::Mode = Am::AcousticModel::complete, + Bliss::LexiconRef = Bliss::LexiconRef()); ModelCombination(const Core::Configuration&, Bliss::LexiconRef, Core::Ref, Core::Ref); virtual ~ModelCombination(); + void build(Mode = complete, Am::AcousticModel::Mode = Am::AcousticModel::complete, Bliss::LexiconRef = Bliss::LexiconRef()); + void getDependencies(Core::DependencySet&) const; Bliss::LexiconRef lexicon() const { return lexicon_; } - void setLexicon(Bliss::LexiconRef); + + void setLexicon(Bliss::LexiconRef); + Mm::Score pronunciationScale() const { return pronunciationScale_ * scale(); } + Core::Ref acousticModel() const { return acousticModel_; } - void setAcousticModel(Core::Ref); + + void setAcousticModel(Core::Ref); + Core::Ref languageModel() const { return languageModel_; } + void setLanguageModel(Core::Ref); - void setLabelScorer(Core::Ref ls) { - labelScorer_ = ls; - } + void setLabelScorer(Core::Ref ls); + Core::Ref labelScorer() const { return labelScorer_; }