From 0cef183882eb2b815969689ebe3d29ef52b5ddf9 Mon Sep 17 00:00:00 2001 From: Larissa Date: Tue, 10 Dec 2024 18:57:29 +0100 Subject: [PATCH 01/24] Refactored TreeBuilder --- .../AdvancedTreeSearch/PersistentStateTree.cc | 9 +- .../AdvancedTreeSearch/PersistentStateTree.hh | 7 +- src/Search/AdvancedTreeSearch/SearchSpace.cc | 52 +- src/Search/AdvancedTreeSearch/SearchSpace.hh | 20 +- src/Search/AdvancedTreeSearch/TreeBuilder.cc | 1168 +++++++++-------- src/Search/AdvancedTreeSearch/TreeBuilder.hh | 183 +-- 6 files changed, 764 insertions(+), 675 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc b/src/Search/AdvancedTreeSearch/PersistentStateTree.cc index ed8a5d08c..83cd5030b 100644 --- a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc +++ b/src/Search/AdvancedTreeSearch/PersistentStateTree.cc @@ -166,14 +166,15 @@ struct ConvertTree { } }; -PersistentStateTree::PersistentStateTree(Core::Configuration config, Core::Ref acousticModel, Bliss::LexiconRef lexicon) +PersistentStateTree::PersistentStateTree(Core::Configuration config, Core::Ref acousticModel, Bliss::LexiconRef lexicon, TreeBuilderFactory treeBuilderFactory) : masterTree(0), rootState(0), ciRootState(0), archive_(paramCacheArchive(Core::Configuration(config, "search-network"))), acousticModel_(acousticModel), lexicon_(lexicon), - config_(config) { + config_(config), + treeBuilderFactory_(treeBuilderFactory) { if (acousticModel_.get() && lexicon_.get()) { const Am::ClassicAcousticModel* am = required_cast(const Am::ClassicAcousticModel*, acousticModel.get()); Core::DependencySet d; @@ -320,7 +321,7 @@ bool PersistentStateTree::read(Core::MappedArchiveReader in) { in >> masterTree >> dependenciesChecksum; if (dependenciesChecksum != dependencies_.getChecksum()) { - Core::Application::us()->log() << "dependencies of the network image don't equal the requiered dependencies with checksum " << dependenciesChecksum; + Core::Application::us()->log() << "dependencies of the network 
image don't equal the required dependencies with checksum " << dependenciesChecksum; return false; } @@ -512,7 +513,7 @@ void PersistentStateTree::dumpDotGraph(std::string file, const std::vector& int depth = 0; if (!nodeDepths.empty()) depth = nodeDepths[node]; - os << Core::form("n%d [label=\"%d\\nd=%d\\nm=%d", node, node, depth, structure.state(node).stateDesc.acousticModel); + os << Core::form("n%d [label=\"%d\\nd=%d\\nm=%d\\nt=%d", node, node, depth, structure.state(node).stateDesc.acousticModel, structure.state(node).stateDesc.transitionModelIndex); for (HMMStateNetwork::SuccessorIterator target = structure.successors(node); target; ++target) if (target.isLabel() && exits[target.label()].pronunciation != Bliss::LemmaPronunciation::invalidId) diff --git a/src/Search/AdvancedTreeSearch/PersistentStateTree.hh b/src/Search/AdvancedTreeSearch/PersistentStateTree.hh index 2e8db8bdc..4607b39c1 100644 --- a/src/Search/AdvancedTreeSearch/PersistentStateTree.hh +++ b/src/Search/AdvancedTreeSearch/PersistentStateTree.hh @@ -31,14 +31,18 @@ struct MyStandardValueHash { } }; +class AbstractTreeBuilder; + namespace Search { class HMMStateNetwork; class StateTree; class PersistentStateTree { public: + using TreeBuilderFactory = std::function(Core::Configuration, const Bliss::Lexicon&, const Am::AcousticModel&, PersistentStateTree&, bool)>; + ///@param lexicon This must be given if the resulting exits are supposed to be functional - PersistentStateTree(Core::Configuration config, Core::Ref acousticModel, Bliss::LexiconRef lexicon); + PersistentStateTree(Core::Configuration config, Core::Ref acousticModel, Bliss::LexiconRef lexicon, TreeBuilderFactory treeBuilderFactory); ///Builds this state tree. 
void build(); @@ -128,6 +132,7 @@ private: Core::Ref acousticModel_; Bliss::LexiconRef lexicon_; Core::Configuration config_; + TreeBuilderFactory treeBuilderFactory_; //Writes the whole state network into the given stream void write(Core::MappedArchiveWriter writer); diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.cc b/src/Search/AdvancedTreeSearch/SearchSpace.cc index 72c3b822e..d6ea06ad7 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.cc +++ b/src/Search/AdvancedTreeSearch/SearchSpace.cc @@ -131,6 +131,18 @@ const Core::ParameterBool paramBuildMinimizedTreeFromScratch( "", true); +const Core::Choice choiceTreeBuilderType( + "classic-hmm", static_cast(StaticSearchAutomaton::TreeBuilderType::classicHmm), + "minimized-hmm", static_cast(StaticSearchAutomaton::TreeBuilderType::minimizedHmm), + "ctc", static_cast(StaticSearchAutomaton::TreeBuilderType::ctc), + Core::Choice::endMark()); + +const Core::ParameterChoice paramTreeBuilderType( + "tree-builder-type", + &choiceTreeBuilderType, + "which tree builder to use", + static_cast(StaticSearchAutomaton::TreeBuilderType::previousBehavior)); + const Core::ParameterBool paramConditionPredecessorWord( "condition-on-predecessor-word", "", @@ -374,10 +386,15 @@ StaticSearchAutomaton::StaticSearchAutomaton(Core::Configuration config, Core::R : Precursor(config), hmmLength(acousticModel->hmmTopologySet()->getDefault().nPhoneStates() * acousticModel->hmmTopologySet()->getDefault().nSubStates()), minimized(paramBuildMinimizedTreeFromScratch(config)), - network(config, acousticModel, lexicon), + network(config, acousticModel, lexicon, std::bind(&StaticSearchAutomaton::createTreeBuilder, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)), prefixFilter(nullptr), + treeBuilderType_(static_cast(paramTreeBuilderType(config))), acousticModel_(acousticModel), lexicon_(lexicon) { + + if (treeBuilderType_ == TreeBuilderType::previousBehavior) { + 
treeBuilderType_ = minimized ? TreeBuilderType::minimizedHmm : TreeBuilderType::classicHmm; + } } StaticSearchAutomaton::~StaticSearchAutomaton() { @@ -388,19 +405,19 @@ StaticSearchAutomaton::~StaticSearchAutomaton() { void StaticSearchAutomaton::buildNetwork() { /// @todo Track the TreeBuilder configuration in transformation if minimizedTree - int transformation = minimized ? 32 : 0; + int transformation = minimized ? 32 : 0; if (!network.read(transformation)) { log() << "persistent network image could not be loaded, building it"; - if (minimized) { // Use TreeStructure.hh - TreeBuilder builder(config, *lexicon_, *acousticModel_, network); - builder.build(); - } - else { // Use StateTree.hh + std::unique_ptr builder = createTreeBuilder(config, *lexicon_, *acousticModel_, network); + if (not builder) { network.build(); network.cleanup(); network.cleanup(); // Additional cleanup, to make sure that the exits are ordered correctly } + else { + builder->build(); + } if (network.write(transformation)) { log() << "writing network image ready"; @@ -752,6 +769,21 @@ void StaticSearchAutomaton::buildBatches() { network.removeOutputs(); } +std::unique_ptr StaticSearchAutomaton::createTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) { + switch (treeBuilderType_) { + case TreeBuilderType::classicHmm: { // Use StateTree.hh + return std::unique_ptr(nullptr); + } break; + case TreeBuilderType::minimizedHmm: { // Use TreeStructure.hh + return std::unique_ptr(new MinimizedTreeBuilder(config, *lexicon_, *acousticModel_, network, initialize)); + } break; + case TreeBuilderType::ctc: { + defect(); // TODO: add CTC implementation + } break; + default: defect(); + } +} + // ------------------------------- Search Space -------------------------------- SearchSpace::SearchSpace(const Core::Configuration& config, @@ -3078,7 +3110,7 @@ void SearchSpace::doStateStatistics() { 
int len = 0; if (h.isValid()) - len = backOffLm->historyLength(h); + len = backOffLm->historyLenght(h); if (mt.lookahead.get()) statesInTreesWithLookAhead += mt.states.size(); @@ -3722,7 +3754,7 @@ Instance* SearchSpace::getBackOffInstance(Instance* instance) { Lm::History useHistory = instance->lookaheadHistory; - int length = lm->historyLength(useHistory); + int length = lm->historyLenght(useHistory); if (length == 0) return 0; @@ -3730,7 +3762,7 @@ Instance* SearchSpace::getBackOffInstance(Instance* instance) { // Create a back-off network for history-length length-1 Lm::History reduced = lm->reducedHistory(useHistory, length - 1); - verify(lm->historyLength(reduced) == length - 1); + verify(lm->historyLenght(reduced) == length - 1); verify(reduced.isValid()); diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.hh b/src/Search/AdvancedTreeSearch/SearchSpace.hh index 2dcdb1c34..b22279485 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.hh +++ b/src/Search/AdvancedTreeSearch/SearchSpace.hh @@ -50,6 +50,13 @@ class StaticSearchAutomaton : public Core::Component { public: using Precursor = Core::Component; + enum class TreeBuilderType { + previousBehavior = 0, + classicHmm = 1, + minimizedHmm = 2, + ctc = 3, + }; + /// HMM length of a common phoneme const u32 hmmLength; bool minimized; @@ -68,7 +75,7 @@ public: std::vector singleLabels; // LM- and acoustic look-ahead ids together, for quicker access std::vector> lookAheadIds; - // Sparse LM lookahead hash and acoustic lookahead id paired togeter + // Sparse LM lookahead hash and acoustic lookahead id paired together std::vector> lookAheadIdAndHash; std::vector stateDepths; @@ -79,7 +86,7 @@ public: std::vector truncatedInvertedStateDepths; std::vector truncatedStateDepths; - // number of transitions needed untill the next word end (that is not silence) + // number of transitions needed until the next word end (that is not silence) std::vector labelDistance; /// Optional filter which allows limiting the 
search space to a certain word sequence prefix @@ -115,6 +122,11 @@ public: // Creates fast look-up structures like singleOutputs_, quickOutputBatches_ and secondOrderEdgeTargetBatches_. void buildBatches(); +protected: + TreeBuilderType treeBuilderType_; + + std::unique_ptr createTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); + private: Core::Ref acousticModel_; Bliss::LexiconRef lexicon_; @@ -299,7 +311,7 @@ public: // Creates early word end hypotheses from the active state hypotheses void findWordEnds(); - // Prunes early word end hypotheses, and expands them to normal word end hypothses + // Prunes early word end hypotheses, and expands them to normal word end hypotheses void pruneEarlyWordEnds(); // Applies time-, score- and transit-modification to the given trace-id, and returns the corrected trace item (as successor of the original trace item) @@ -332,7 +344,7 @@ public: // deletes all traces that did not survive in stateHypotheses and rootStateHypotheses of activeTrees void cleanup(); - // Optimize the lattice, removing redundant silence occurances + // Optimize the lattice, removing redundant silence occurrences void optimizeSilenceInWordLattice(const Bliss::Lemma* silence); Core::Ref getSentenceEnd(TimeframeIndex time, bool shallCreateLattice); diff --git a/src/Search/AdvancedTreeSearch/TreeBuilder.cc b/src/Search/AdvancedTreeSearch/TreeBuilder.cc index b2b5a75c5..c42e53a09 100644 --- a/src/Search/AdvancedTreeSearch/TreeBuilder.cc +++ b/src/Search/AdvancedTreeSearch/TreeBuilder.cc @@ -20,69 +20,75 @@ #include #include "PersistentStateTree.hh" +using namespace Search; + +// -------------------- AbstractTreeBuilder -------------------- + +AbstractTreeBuilder::AbstractTreeBuilder(Core::Configuration config, + const Bliss::Lexicon& lexicon, + const Am::AcousticModel& acousticModel, + Search::PersistentStateTree& network) + : 
Core::Component(config), + lexicon_(lexicon), + acousticModel_(acousticModel), + network_(network) { +} + +// -------------------- MinimizedTreeBuilder -------------------- + // TODO: Verify that pushed word-ends have the same transition penalty as the corresponding unpushed word-ends -const Core::ParameterBool paramAddCiTransitions( +const Core::ParameterInt MinimizedTreeBuilder::paramMinPhones( + "min-phones", + "minimum number of phones which are expanded without pushing the word ends", + 1); + +const Core::ParameterBool MinimizedTreeBuilder::paramAddCiTransitions( "add-ci-transitions", "whether context-independent acoustic transitions should be inserted between words. Useful for non-fluid speech, specifically when the training data consistent of fluid speech", false); // if this is false, then an additional special-root is used, which is followed only by non-words. it is labeled #|[sil] (where [sil] is the first special-phone) -const Core::ParameterBool paramUseRootForCiExits( +const Core::ParameterBool MinimizedTreeBuilder::paramUseRootForCiExits( "use-root-for-ci-exits", "whether the root-node should be used as target for exits behind context-independent phones", true); -const Core::ParameterInt paramMinPhones( - "min-phones", - "minimum number of phones which are expanded without pushing the word ends", - 1); - -const Core::ParameterInt paramMinimizeIterations( - "minimization-iterations", - "usually only the first 2 iterations show an effect", - 2); - -const Core::ParameterBool paramForceExactWordEnds( +const Core::ParameterBool MinimizedTreeBuilder::paramForceExactWordEnds( "force-exact-word-ends", "", false); -const Core::ParameterBool paramKeepRoots( +const Core::ParameterBool MinimizedTreeBuilder::paramKeepRoots( "keep-roots", "keep roots as they were after initial building (i.e. don't minimize them). 
might become useful to insert new words on-the-fly in the future, or to have correct boundary-information right after decoding.", false); -const Core::ParameterBool paramAllowCrossWordSkips( +const Core::ParameterBool MinimizedTreeBuilder::paramAllowCrossWordSkips( "allow-cross-word-skips", "add additional word labels to allow skips over word boundaries; equal skip penalties for all states are recommended", false); -const Core::ParameterBool paramRepeatSilence( +const Core::ParameterBool MinimizedTreeBuilder::paramRepeatSilence( "repeat-silence", "repeat silence. this makes cross-word skipping consistent in forward/backward case, given that all forward/skip penalties are the same", false); -using namespace Search; - -typedef std::set PhonemeList; +const Core::ParameterInt MinimizedTreeBuilder::paramMinimizeIterations( + "minimization-iterations", + "usually only the first 2 iterations show an effect", + 2); -TreeBuilder::TreeBuilder(Core::Configuration config, - const Bliss::Lexicon& lexicon, - const Am::AcousticModel& acousticModel, - Search::PersistentStateTree& network, - bool initialize, - bool arcBased) - : lexicon_(lexicon), - acousticModel_(acousticModel), - network_(network), - config_(config), +MinimizedTreeBuilder::MinimizedTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) + : AbstractTreeBuilder(config, lexicon, acousticModel, network), minPhones_(paramMinPhones(config)), + addCiTransitions_(paramAddCiTransitions(config)), + useRootForCiExits_(paramUseRootForCiExits(config)), forceExactWordEnds_(paramForceExactWordEnds(config)), keepRoots_(paramKeepRoots(config)), allowCrossWordSkips_(paramAllowCrossWordSkips(config)), repeatSilence_(paramRepeatSilence(config)), - reverse_(isBackwardRecognition(config)), - arcBased_(arcBased) { + minimizeIterations_(paramMinimizeIterations(config)), + reverse_(isBackwardRecognition(config)) { if 
(allowCrossWordSkips_) { Score skipPenalty = acousticModel_.stateTransition(0)->operator[](Am::StateTransitionModel::skip); Score forwardPenalty = acousticModel_.stateTransition(0)->operator[](Am::StateTransitionModel::forward); @@ -100,10 +106,12 @@ TreeBuilder::TreeBuilder(Core::Configuration config, } } - if (reverse_) + if (reverse_) { log() << "building backward network"; - else + } + else { log() << "building forward network"; + } if (initialize) { verify(!network_.rootState); @@ -114,125 +122,68 @@ TreeBuilder::TreeBuilder(Core::Configuration config, } } -TreeBuilder::HMMSequence TreeBuilder::arcSequence(u32 acousticModelIndex) const { - verify(acousticModelIndex < arcSequences_.size()); - return arcSequences_[acousticModelIndex]; +std::unique_ptr MinimizedTreeBuilder::newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) { + return std::unique_ptr(new MinimizedTreeBuilder(config, lexicon, acousticModel, network, initialize)); } -std::string TreeBuilder::arcDesc(u32 acousticModelIndex) const { - verify(acousticModelIndex < arcSequences_.size()); - ArcDesc desc = arcDescs_[acousticModelIndex]; - std::ostringstream os; - if (desc.central == Core::Type::max) { - os << "*"; - } - else { - if (isContextDependent(desc.central)) { - if (desc.left == Core::Type::max) - os << "*"; - else if (desc.left == Bliss::Phoneme::term || !isContextDependent(desc.left)) - os << "#"; - else - os << acousticModel_.phonemeInventory()->phoneme(desc.left)->symbol(); - os << "/"; - } - os << acousticModel_.phonemeInventory()->phoneme(desc.central)->symbol(); - if (isContextDependent(desc.central)) { - os << "/"; - if (desc.right == Core::Type::max) - os << "*"; - else if (desc.right == Bliss::Phoneme::term || !isContextDependent(desc.right)) - os << "#"; - else - os << acousticModel_.phonemeInventory()->phoneme(desc.right)->symbol(); - } - } - return os.str(); -} +void 
MinimizedTreeBuilder::build() { + buildBody(); -void TreeBuilder::hmmFromAllophone(TreeBuilder::HMMSequence& ret, - Bliss::Phoneme::Id left, - Bliss::Phoneme::Id central, - Bliss::Phoneme::Id right, - u32 boundary, - bool allowNonStandard) { - verify(ret.length == 0); - verify(central != Bliss::Phoneme::term); - verify(acousticModel_.phonemeInventory()->isValidPhonemeId(central)); - Bliss::ContextPhonology::SemiContext history, future; + buildFanInOutStructure(); - if (reverse_) { - std::swap(left, right); - if (boundary == Am::Allophone::isFinalPhone) - boundary = Am::Allophone::isInitialPhone; - else if (boundary == Am::Allophone::isInitialPhone) - boundary = Am::Allophone::isFinalPhone; - } + skipRootTransitions(); - if (isContextDependent(central)) { - if (acousticModel_.phonemeInventory()->isValidPhonemeId(left) && isContextDependent(left)) - history.append(1, left); - if (acousticModel_.phonemeInventory()->isValidPhonemeId(right) && isContextDependent(right)) - future.append(1, right); + for (u32 i = 0; i < minimizeIterations_; ++i) { + minimize(); } - const Am::Allophone* allophone = acousticModel_.allophoneAlphabet()->allophone(Am::Allophone(Bliss::ContextPhonology::PhonemeInContext(central, history, future), boundary)); - - const Am::ClassicHmmTopology* hmmTopology = acousticModel_.hmmTopology(central); - - for (u32 phoneState = 0; phoneState < hmmTopology->nPhoneStates(); ++phoneState) { - Am::AllophoneState alloState = acousticModel_.allophoneStateAlphabet()->allophoneState(allophone, phoneState); - StateTree::StateDesc desc; - desc.acousticModel = acousticModel_.emissionIndex(alloState); // Decision tree look-up for CART id. 
+ if (allowCrossWordSkips_) { + addCrossWordSkips(); + } - for (u32 subState = 0; subState < hmmTopology->nSubStates(); ++subState) { - desc.transitionModelIndex = acousticModel_.stateTransitionIndex(alloState, subState); - verify(desc.transitionModelIndex < Core::Type::max); + log() << "building ready"; +} - verify(ret.length < HMMSequence::MaxLength); // So far hard-wired to MaxLength = 12. - ret.hmm[ret.length] = desc; - ++ret.length; +void MinimizedTreeBuilder::printStats(std::string occasion) { + log() << "stats " << occasion << ":"; + log() << "states: " << network_.structure.stateCount() << " exits: " << network_.exits.size(); + log() << "coarticulated roots: " << network_.coarticulatedRootStates.size() << " unpushed: " << network_.unpushedCoarticulatedRootStates.size(); + u32 roots = 0; + for (std::set::iterator it = network_.uncoarticulatedWordEndStates.begin(); it != network_.uncoarticulatedWordEndStates.end(); ++it) { + if (network_.coarticulatedRootStates.count(*it)) { + ++roots; } } + log() << "number of uncoarticulated pushed word-end nodes: " << network_.uncoarticulatedWordEndStates.size() << " out of those are roots: " << roots; +} - if (arcBased_) { - HMMSequence newRet; - newRet.length = 1; - newRet.hmm[0].transitionModelIndex = boundary; - - ArcSequenceHash::const_iterator it = arcSequencesHash_.find(ret); - if (it != arcSequencesHash_.end()) { - newRet.hmm[0].acousticModel = it->second; - if (arcDescs_[it->second].central != central) - arcDescs_[it->second].central = Core::Type::max; - if (arcDescs_[it->second].left != left) - arcDescs_[it->second].left = Core::Type::max; - if (arcDescs_[it->second].right != right) - arcDescs_[it->second].right = Core::Type::max; - } - else { - newRet.hmm[0].acousticModel = arcSequences_.size(); - verify(newRet.hmm[0].acousticModel == arcSequences_.size()); - arcSequencesHash_[ret] = arcSequences_.size(); - arcSequences_.push_back(ret); - ArcDesc desc; - desc.left = left; - desc.central = central; - desc.right = 
right; - arcDescs_.push_back(desc); - } - ret = newRet; +std::string MinimizedTreeBuilder::describe(std::pair desc) { + std::ostringstream os; + + if (desc.first == Bliss::Phoneme::term) { + os << "#"; + } + else { + os << lexicon_.phonemeInventory()->phoneme(desc.first)->symbol(); } - if (reverse_) - ret.reverse(); - if (repeatSilence_ && ret.length == 1 && central == acousticModel_.silence()) { - ret.hmm[1] = ret.hmm[0]; - ret.length = 2; + os << "<->"; + + if (desc.second == Bliss::Phoneme::term) { + os << "#"; + } + else { + os << lexicon_.phonemeInventory()->phoneme(desc.second)->symbol(); } + + return os.str(); +} + +bool MinimizedTreeBuilder::isContextDependent(Bliss::Phoneme::Id phone) const { + return acousticModel_.phonemeInventory()->phoneme(phone)->isContextDependent(); } -void TreeBuilder::build() { +void MinimizedTreeBuilder::buildBody() { std::pair prons = lexicon_.pronunciations(); u32 coarticulatedInitial = 0, uncoarticulatedInitial = 0, coarticulatedFinal = 0, uncoarticulatedFinal = 0; @@ -242,23 +193,25 @@ void TreeBuilder::build() { const Bliss::Pronunciation& pron(**pronIt); if (pron.length()) { Bliss::Phoneme::Id initial = pron[0], fin = pron[pron.length() - 1]; - if (reverse_) + if (reverse_) { std::swap(initial, fin); + } if (!initialPhonemes_.count(initial)) { initialPhonemes_.insert(initial); - if (isContextDependent(initial)) + if (isContextDependent(initial)) { coarticulatedInitial += 1; - else + } else { uncoarticulatedInitial += 1; + } } if (!finalPhonemes_.count(fin)) { - if (isContextDependent(fin)) + if (isContextDependent(fin)) { coarticulatedFinal += 1; - else + } else { uncoarticulatedFinal += 1; - + } finalPhonemes_.insert(fin); } } @@ -267,35 +220,40 @@ void TreeBuilder::build() { } } - if ((uncoarticulatedFinal == 0 || uncoarticulatedInitial == 0) && !paramAddCiTransitions(config_)) + if ((uncoarticulatedFinal == 0 || uncoarticulatedInitial == 0) && !addCiTransitions_) { Core::Application::us()->error() << "There are no 
context-independent initial or final phonemes in the lexicon, word-end detection will not work properly. Consider adding context-independent phonemes, or setting add-ci-transitions=true"; + } log() << "coarticulated initial phones: " << coarticulatedInitial << " uncoarticulated: " << uncoarticulatedInitial << ", coarticulated final phones: " << uncoarticulatedFinal << " uncoarticulated: " << uncoarticulatedFinal; - bool useRootForCiExits = paramUseRootForCiExits(config_) && !paramAddCiTransitions(config_); + bool useRootForCiExits = useRootForCiExits_ && !addCiTransitions_; // Build the network-like non-coarticulated portion starting at the context-independent root log() << "building"; for (Bliss::Lexicon::PronunciationIterator pronIt = prons.first; pronIt != prons.second; ++pronIt) { const Bliss::Pronunciation& pron(**pronIt); u32 pronLength = pron.length(); - if (pronLength == 0) + if (pronLength == 0) { continue; + } std::pair currentState(0, network_.rootState); std::vector phones; - for (u32 phoneIndex = 0; phoneIndex < pronLength; ++phoneIndex) + for (u32 phoneIndex = 0; phoneIndex < pronLength; ++phoneIndex) { phones.push_back(pron[phoneIndex]); + } - if (reverse_) + if (reverse_) { std::reverse(phones.begin(), phones.end()); + } - for (u32 phoneIndex = 0; phoneIndex < pronLength - 1; ++phoneIndex) + for (u32 phoneIndex = 0; phoneIndex < pronLength - 1; ++phoneIndex) { currentState = extendPhone(currentState.second, phoneIndex, phones); + } std::pair lemmaProns = pron.lemmas(); @@ -305,20 +263,21 @@ void TreeBuilder::build() { std::pair tail = extendPhone(currentState.second, pronLength - 1, phones, Bliss::Phoneme::term, *initialIt); for (Bliss::Pronunciation::LemmaIterator lemmaPron = lemmaProns.first; lemmaPron != lemmaProns.second; ++lemmaPron) { u32 exit; - if (!isContextDependent(phones[pronLength - 1]) && useRootForCiExits) - exit = addExit(tail.first, tail.second, Bliss::Phoneme::term, Bliss::Phoneme::term, 0, lemmaPron->id()); // Use the 
non-coarticulated root node - else - exit = addExit(tail.first, tail.second, phones[pronLength - 1], *initialIt, 0, lemmaPron->id()); - if (pronLength == 1) + if (!isContextDependent(phones[pronLength - 1]) && useRootForCiExits) { + exit = addExit(tail.second, Bliss::Phoneme::term, Bliss::Phoneme::term, 0, lemmaPron->id()); // Use the non-coarticulated root node + } else { + exit = addExit(tail.second, phones[pronLength - 1], *initialIt, 0, lemmaPron->id()); + } + if (pronLength == 1) { initialFinalPhoneSuffix_[RootKey(phones[0], *initialIt, 1)].insert(ID_FROM_LABEL(exit)); + } } } - } - else { + } else { // Minimize the remaining phoneme, insert corresponding word-ends. for (Bliss::Pronunciation::LemmaIterator lemmaPron = lemmaProns.first; lemmaPron != lemmaProns.second; ++lemmaPron) { if (pronLength == 1) { - addExit(currentState.first, currentState.second, Bliss::Phoneme::term, phones[0], -1, lemmaPron->id()); + addExit(currentState.second, Bliss::Phoneme::term, phones[0], -1, lemmaPron->id()); for (std::set::const_iterator finalIt = finalPhonemes_.begin(); finalIt != finalPhonemes_.end(); ++finalIt) { Search::PersistentStateTree::Exit exit; @@ -326,33 +285,115 @@ void TreeBuilder::build() { exit.pronunciation = lemmaPron->id(); addSuccessor(createRoot(*finalIt, phones[0], 0), ID_FROM_LABEL(createExit(exit))); } - } - else { - u32 exit = addExit(currentState.first, currentState.second, phones[pronLength - 2], phones[pronLength - 1], -1, lemmaPron->id()); - if (pronLength == 2) + } else { + u32 exit = addExit(currentState.second, phones[pronLength - 2], phones[pronLength - 1], -1, lemmaPron->id()); + if (pronLength == 2) { initialPhoneSuffix_[RootKey(phones[0], phones[1], 1)].insert(ID_FROM_LABEL(exit)); + } } } } } log() << "states: " << network_.structure.stateCount() << " exits: " << network_.exits.size() << " roots: " << roots_.size(); +} - buildFanInOutStructure(); +void MinimizedTreeBuilder::buildFanInOutStructure() { + // Create temporary coarticulated 
roots + for (std::set::const_iterator finalIt = finalPhonemes_.begin(); finalIt != finalPhonemes_.end(); ++finalIt) { + for (std::set::const_iterator initialIt = initialPhonemes_.begin(); initialIt != initialPhonemes_.end(); ++initialIt) { + createRoot(*finalIt, *initialIt, 0); + } + } - skipRootTransitions(); + log() << "building fan-in"; + // Build the fan-in structure (e.g. the HMM structure representing the initial word phonemes, behind roots, up to the joints) + for (RootHash::const_iterator rootIt = roots_.begin(); rootIt != roots_.end(); ++rootIt) { + if (rootIt->first.depth != 0 || rootIt->second == network_.rootState) { + continue; + } + Bliss::Phoneme::Id initial = rootIt->first.right; + verify(initialPhonemes_.count(initial)); + verify(initial != Bliss::Phoneme::term); + u32 paths = 0; - u32 it = paramMinimizeIterations(config_); - for (u32 i = 0; i < it; ++i) - minimize(); + for (CoarticulationJointHash::const_iterator initialSuffixIt = initialPhoneSuffix_.begin(); initialSuffixIt != initialPhoneSuffix_.end(); ++initialSuffixIt) { + if (initialSuffixIt->first.left != initial) { + continue; + } + ++paths; + HMMSequence hmm; + hmmFromAllophone(hmm, rootIt->first.left, initial, initialSuffixIt->first.right, Am::Allophone::isInitialPhone); + verify(hmm.length > 0); + StateId currentNode = extendFanIn(initialSuffixIt->second, hmm[hmm.length - 1]); + for (s32 s = hmm.length - 2; s >= 0; --s) { + currentNode = extendFanIn(currentNode, hmm[s]); + } - if (allowCrossWordSkips_) - addCrossWordSkips(); + addSuccessor(rootIt->second, currentNode); + } - log() << "building ready"; + for (CoarticulationJointHash::const_iterator initialSuffixIt = initialFinalPhoneSuffix_.begin(); initialSuffixIt != initialFinalPhoneSuffix_.end(); ++initialSuffixIt) { + if (initialSuffixIt->first.left != initial) { + continue; + } + ++paths; + HMMSequence hmm; + hmmFromAllophone(hmm, rootIt->first.left, initial, initialSuffixIt->first.right, Am::Allophone::isInitialPhone | 
Am::Allophone::isFinalPhone); + verify(hmm.length > 0); + StateId currentNode = extendFanIn(initialSuffixIt->second, hmm[hmm.length - 1]); + for (s32 s = hmm.length - 2; s >= 0; --s) { + currentNode = extendFanIn(currentNode, hmm[s]); + } + + addSuccessor(rootIt->second, currentNode); + } + } + + log() << "states: " << network_.structure.stateCount() << " exits: " << network_.exits.size() << " roots: " << roots_.size(); + log() << "building fan-out"; + + // Build the fan-out structure (e.g. the HMM structure representing the final word phonemes, behind special roots) + // On the left side delimited by the roots of depth -1, and on the right side by the roots of depth 0 + for (RootHash::const_iterator leftRootIt = roots_.begin(); leftRootIt != roots_.end(); ++leftRootIt) { + if (leftRootIt->first.depth != -1) { + continue; + } + Bliss::Phoneme::Id fin = leftRootIt->first.right; + verify(finalPhonemes_.count(fin)); + + u32 paths = 0; + for (RootHash::const_iterator rightRootIt = roots_.begin(); rightRootIt != roots_.end(); ++rightRootIt) { + if (rightRootIt->first.depth != 0 || (rightRootIt->first.left != fin && (!addCiTransitions_ || rightRootIt->first.left != Bliss::Phoneme::term))) { + continue; + } + ++paths; + HMMSequence hmm; + hmmFromAllophone(hmm, leftRootIt->first.left, fin, rightRootIt->first.right, Am::Allophone::isFinalPhone); + verify(hmm.length > 0); + + // The last state in the pushed fan-in is equivalent with the corresponding root state + + StateId lastNode = extendFanIn(network_.structure.targetSet(rightRootIt->second), hmm[hmm.length - 1]); + + StateId currentNode = lastNode; + for (s32 s = hmm.length - 2; s >= 0; --s) { + currentNode = extendFanIn(currentNode, hmm[s]); + } + + if (rightRootIt->first.right == Bliss::Phoneme::term || !isContextDependent(rightRootIt->first.right)) { + network_.uncoarticulatedWordEndStates.insert(lastNode); + } + + addSuccessor(leftRootIt->second, currentNode); + } + verify(paths); + } + + printStats("after fan-in/out 
structure"); } -void TreeBuilder::addCrossWordSkips() { +void MinimizedTreeBuilder::addCrossWordSkips() { log() << "adding cross-word skips"; u32 oldNodes = network_.structure.stateCount(); @@ -361,10 +402,12 @@ void TreeBuilder::addCrossWordSkips() { bool hasWordEnd = false; bool hasSuccessor = false; for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(node); target; ++target) { - if (!target.isLabel()) + if (!target.isLabel()) { hasSuccessor = true; - if (target.isLabel() && network_.structure.state(network_.exits[target.label()].transitState).stateDesc.transitionModelIndex != Am::TransitionModel::entryM2) + } + if (target.isLabel() && network_.structure.state(network_.exits[target.label()].transitState).stateDesc.transitionModelIndex != Am::TransitionModel::entryM2) { hasWordEnd = true; + } } verify(hasSuccessor || hasWordEnd); } @@ -372,19 +415,23 @@ void TreeBuilder::addCrossWordSkips() { std::set skipRoots; for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(node); target; ++target) { - if (target.isLabel()) + if (target.isLabel()) { continue; - for (HMMStateNetwork::SuccessorIterator target2 = network_.structure.successors(*target); target2; ++target2) - if (target2.isLabel()) + } + for (HMMStateNetwork::SuccessorIterator target2 = network_.structure.successors(*target); target2; ++target2) { + if (target2.isLabel()) { skipRoots.insert(network_.exits[target2.label()]); + } + } } if (skipRoots.size()) { for (std::set::iterator it = skipRoots.begin(); it != skipRoots.end(); ++it) { PersistentStateTree::Exit e(*it); verify(e.pronunciation != Bliss::LemmaPronunciation::invalidId); - if (network_.structure.state(e.transitState).stateDesc.transitionModelIndex == Am::TransitionModel::entryM2) + if (network_.structure.state(e.transitState).stateDesc.transitionModelIndex == Am::TransitionModel::entryM2) { continue; + } e.transitState = createSkipRoot(e.transitState); network_.structure.addOutputToNode(node, 
createExit(e)); } @@ -394,10 +441,12 @@ void TreeBuilder::addCrossWordSkips() { bool hasWordEnd = false; bool hasSuccessor = false; for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(node); target; ++target) { - if (!target.isLabel()) + if (!target.isLabel()) { hasSuccessor = true; - if (target.isLabel() && network_.structure.state(network_.exits[target.label()].transitState).stateDesc.transitionModelIndex != Am::TransitionModel::entryM2) + } + if (target.isLabel() && network_.structure.state(network_.exits[target.label()].transitState).stateDesc.transitionModelIndex != Am::TransitionModel::entryM2) { hasWordEnd = true; + } } verify(hasSuccessor || hasWordEnd); } @@ -407,10 +456,12 @@ void TreeBuilder::addCrossWordSkips() { bool hasWordEnd = false; bool hasSuccessor = false; for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(node); target; ++target) { - if (!target.isLabel()) + if (!target.isLabel()) { hasSuccessor = true; - if (target.isLabel() && network_.structure.state(network_.exits[target.label()].transitState).stateDesc.transitionModelIndex != Am::TransitionModel::entryM2) + } + if (target.isLabel() && network_.structure.state(network_.exits[target.label()].transitState).stateDesc.transitionModelIndex != Am::TransitionModel::entryM2) { hasWordEnd = true; + } } verify(hasSuccessor || hasWordEnd); } @@ -420,55 +471,314 @@ void TreeBuilder::addCrossWordSkips() { network_.cleanup(); } -void TreeBuilder::skipRootTransitions() { - for (StateId node = 1; node < network_.structure.stateCount(); ++node) { - if (network_.structure.state(node).stateDesc.acousticModel == Search::StateTree::invalidAcousticModel) +void MinimizedTreeBuilder::skipRootTransitions(StateId start) { + for (StateId node = start; node < network_.structure.stateCount(); ++node) { + if (network_.structure.state(node).stateDesc.acousticModel == Search::StateTree::invalidAcousticModel) { continue; + } HMMStateNetwork::ChangePlan change = 
network_.structure.change(node); for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(node); target; ++target) { - if (target.isLabel()) + if (target.isLabel()) { continue; + } if (network_.structure.state(*target).stateDesc.acousticModel == Search::StateTree::invalidAcousticModel) { change.removeSuccessor(*target); - for (HMMStateNetwork::SuccessorIterator target2 = network_.structure.successors(*target); target2; ++target2) + for (HMMStateNetwork::SuccessorIterator target2 = network_.structure.successors(*target); target2; ++target2) { change.addSuccessor(*target2); + } } } change.apply(); } } -std::vector TreeBuilder::minimize(bool forceDeterminization, bool onlyMinimizeBackwards, bool allowLost) { - log() << "minimizing"; +StateTree::StateDesc MinimizedTreeBuilder::rootDesc() const { + StateTree::StateDesc desc; + desc.acousticModel = Search::StateTree::invalidAcousticModel; + desc.transitionModelIndex = Am::TransitionModel::entryM1; + return desc; +} - if (forceExactWordEnds_) - log() << "forcing exact word-ends"; +AbstractTreeBuilder::StateId MinimizedTreeBuilder::createSkipRoot(StateId baseRoot) { + SkipRootsHash::const_iterator it = skipRoots_.find(baseRoot); + if (it != skipRoots_.end()) { + return it->second; + } + StateTree::StateDesc desc = rootDesc(); + desc.transitionModelIndex = Am::TransitionModel::entryM2; + StateId ret = createState(desc); - for (std::set::iterator it = network_.unpushedCoarticulatedRootStates.begin(); it != network_.unpushedCoarticulatedRootStates.end(); ++it) - verify(network_.coarticulatedRootStates.count(*it)); + skipRoots_.insert(std::make_pair(baseRoot, ret)); + network_.structure.addTargetToNode(ret, baseRoot); + skipRootSet_.insert(ret); + network_.coarticulatedRootStates.insert(ret); + verify(network_.rootTransitDescriptions.count(baseRoot)); + network_.rootTransitDescriptions.insert(std::make_pair(ret, network_.rootTransitDescriptions[baseRoot])); + return ret; +} - std::set usedRoots; - 
std::deque active; +AbstractTreeBuilder::StateId MinimizedTreeBuilder::createRoot(Bliss::Phoneme::Id left, Bliss::Phoneme::Id right, int depth) { + RootKey key(left, right, depth); + RootHash::const_iterator it = roots_.find(key); + if (it != roots_.end()) { + return it->second; + } - std::vector fanIn(network_.structure.stateCount(), 0); + // record newly inserted RootStates to reset network_ later + StateId ret = createState(rootDesc()); - // Collect all zero-depth roots to skip them during clean-up - std::set usefulRoots; - for (RootHash::const_iterator rootIt = roots_.begin(); rootIt != roots_.end(); ++rootIt) { - if (rootIt->first.depth == 0) { - usefulRoots.insert(rootIt->second); - } + if (depth == 0 && (left != Bliss::Phoneme::term || right != Bliss::Phoneme::term)) { + network_.unpushedCoarticulatedRootStates.insert(ret); } - for (StateId node = 1; node < network_.structure.stateCount(); ++node) { - active.push_back(node); - for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(node); target; ++target) { - if (target.isLabel()) { - usedRoots.insert(network_.exits[target.label()].transitState); - fanIn[network_.exits[target.label()].transitState] += 1; - } + if (right == Bliss::Phoneme::term || !acousticModel_.phonemeInventory()->phoneme(right)->isContextDependent()) { + network_.uncoarticulatedWordEndStates.insert(ret); + } + + if (left != Bliss::Phoneme::term || right != Bliss::Phoneme::term) { + network_.coarticulatedRootStates.insert(ret); + } + + roots_.insert(std::make_pair(key, ret)); + + network_.rootTransitDescriptions.insert(std::make_pair(ret, std::make_pair(left, right))); + + return ret; +} + +AbstractTreeBuilder::StateId MinimizedTreeBuilder::createState(StateTree::StateDesc desc) { + StateId ret = network_.structure.allocateTreeNode(network_.masterTree); + network_.structure.state(ret).stateDesc = desc; + return ret; +} + +u32 MinimizedTreeBuilder::createExit(PersistentStateTree::Exit exit) { + ExitHash::iterator 
exitHashIt = exitHash_.find(exit); + if (exitHashIt != exitHash_.end()) { + return exitHashIt->second; + } else { + // Exit does not exist yet, add it + network_.exits.push_back(exit); + u32 exitIndex = network_.exits.size() - 1; + exitHash_.insert(std::make_pair(exit, exitIndex)); + return exitIndex; + } +} + +u32 MinimizedTreeBuilder::addExit(StateId predecessor, + Bliss::Phoneme::Id leftPhoneme, + Bliss::Phoneme::Id rightPhoneme, + int depth, + Bliss::LemmaPronunciation::Id pron) { + PersistentStateTree::Exit exit; + exit.transitState = createRoot(leftPhoneme, rightPhoneme, depth); + exit.pronunciation = pron; + + u32 exitIndex = createExit(exit); + + for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) { + if (target.isLabel() && target.label() == exitIndex) { + return exitIndex; + } + } + + network_.structure.addOutputToNode(predecessor, ID_FROM_LABEL(exitIndex)); + return exitIndex; +} + +void MinimizedTreeBuilder::hmmFromAllophone(HMMSequence& ret, + Bliss::Phoneme::Id left, + Bliss::Phoneme::Id central, + Bliss::Phoneme::Id right, + u32 boundary) { + verify(ret.length == 0); + verify(central != Bliss::Phoneme::term); + verify(acousticModel_.phonemeInventory()->isValidPhonemeId(central)); + Bliss::ContextPhonology::SemiContext history, future; + + if (reverse_) { + std::swap(left, right); + if (boundary == Am::Allophone::isFinalPhone) { + boundary = Am::Allophone::isInitialPhone; + } else if (boundary == Am::Allophone::isInitialPhone) { + boundary = Am::Allophone::isFinalPhone; + } + } + + if (isContextDependent(central)) { + if (acousticModel_.phonemeInventory()->isValidPhonemeId(left) && isContextDependent(left)) { + history.append(1, left); + } + if (acousticModel_.phonemeInventory()->isValidPhonemeId(right) && isContextDependent(right)) { + future.append(1, right); + } + } + + const Am::Allophone* allophone = 
acousticModel_.allophoneAlphabet()->allophone(Am::Allophone(Bliss::ContextPhonology::PhonemeInContext(central, history, future), boundary)); + + const Am::ClassicHmmTopology* hmmTopology = acousticModel_.hmmTopology(central); + + for (u32 phoneState = 0; phoneState < hmmTopology->nPhoneStates(); ++phoneState) { + Am::AllophoneState alloState = acousticModel_.allophoneStateAlphabet()->allophoneState(allophone, phoneState); + StateTree::StateDesc desc; + desc.acousticModel = acousticModel_.emissionIndex(alloState); // Decision tree look-up for CART id. + + for (u32 subState = 0; subState < hmmTopology->nSubStates(); ++subState) { + desc.transitionModelIndex = acousticModel_.stateTransitionIndex(alloState, subState); + verify(desc.transitionModelIndex < Core::Type::max); + + verify(ret.length < HMMSequence::MaxLength); // So far hard-wired to MaxLength = 12. + ret.hmm[ret.length] = desc; + ++ret.length; + } + } + + if (reverse_) { + ret.reverse(); + } + + if (repeatSilence_ && ret.length == 1 && central == acousticModel_.silence()) { + ret.hmm[1] = ret.hmm[0]; + ret.length = 2; + } +} + +bool MinimizedTreeBuilder::addSuccessor(StateId predecessor, StateId successor) { + for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) { + if (*target == successor) { + return false; + } + } + + network_.structure.addTargetToNode(predecessor, successor); + return true; +} + +std::pair MinimizedTreeBuilder::extendPhone(StateId currentState, + u32 phoneIndex, + const std::vector& phones, + Bliss::Phoneme::Id left, + Bliss::Phoneme::Id right) { + u8 boundary = 0; + if (phoneIndex != 0) { + left = phones[phoneIndex - 1]; + } else { + boundary |= Am::Allophone::isInitialPhone; + } + + if (phoneIndex != phones.size() - 1) { + right = phones[phoneIndex + 1]; + } else { + boundary |= Am::Allophone::isFinalPhone; + } + + HMMSequence hmm; + hmmFromAllophone(hmm, left, phones[phoneIndex], right, boundary); + + verify(hmm.length >= 1); + 
+ u32 hmmState = 0; + StateId previousState = 0; + + if (phoneIndex == 1 && hmmState == 0) { + currentState = extendBodyState(currentState, left, phones[phoneIndex], hmm[hmmState++]); + } + + for (; hmmState < hmm.length; ++hmmState) { + previousState = currentState; + currentState = extendState(currentState, hmm[hmmState]); + } + + return std::make_pair(previousState, currentState); +} + +AbstractTreeBuilder::StateId MinimizedTreeBuilder::extendState(StateId predecessor, StateTree::StateDesc desc, MinimizedTreeBuilder::RootKey uniqueKey) { + for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) { + if (!target.isLabel() && network_.structure.state(*target).stateDesc == desc) { + if (uniqueKey.isValid()) { + Core::HashMap::const_iterator it = stateUniqueKeys_.find(*target); + verify(it != stateUniqueKeys_.end()); + if (!(it->second == uniqueKey)) { + continue; + } + } + return *target; + } + } + + // No matching successor found, extend + StateId ret = createState(desc); + if (uniqueKey.isValid()) { + stateUniqueKeys_.insert(std::make_pair(ret, uniqueKey)); + } + network_.structure.addTargetToNode(predecessor, ret); + return ret; +} + +AbstractTreeBuilder::StateId MinimizedTreeBuilder::extendBodyState(StateId state, + Bliss::Phoneme::Id first, + Bliss::Phoneme::Id second, + Search::StateTree::StateDesc desc) { + RootKey key(first, second, 1); + StateId ret = extendState(state, desc, key); + initialPhoneSuffix_[key].insert(ret); + + return ret; +} + +AbstractTreeBuilder::StateId MinimizedTreeBuilder::extendFanIn(StateId successorOrExit, StateTree::StateDesc desc) { + std::set successors; + successors.insert(successorOrExit); + return extendFanIn(successors, desc); +} + +AbstractTreeBuilder::StateId MinimizedTreeBuilder::extendFanIn(const std::set& successorsOrExits, Search::StateTree::StateDesc desc) { + StatePredecessor pred(successorsOrExits, desc); + PredecessorsHash::iterator it = predecessors_.find(pred); 
+ if (it != predecessors_.end()) { + return it->second; + } + StateId ret = createState(desc); + for (std::set::const_iterator it = successorsOrExits.begin(); it != successorsOrExits.end(); ++it) { + network_.structure.addTargetToNode(ret, *it); + } + predecessors_.insert(std::make_pair(pred, ret)); + return ret; +} + +std::vector MinimizedTreeBuilder::minimize(bool forceDeterminization, bool onlyMinimizeBackwards, bool allowLost) { + log() << "minimizing"; + + if (forceExactWordEnds_) { + log() << "forcing exact word-ends"; + } + + for (std::set::iterator it = network_.unpushedCoarticulatedRootStates.begin(); it != network_.unpushedCoarticulatedRootStates.end(); ++it) { + verify(network_.coarticulatedRootStates.count(*it)); + } + + std::set usedRoots; + std::deque active; + + std::vector fanIn(network_.structure.stateCount(), 0); + + // Collect all zero-depth roots to skip them during clean-up + std::set usefulRoots; + for (RootHash::const_iterator rootIt = roots_.begin(); rootIt != roots_.end(); ++rootIt) { + if (rootIt->first.depth == 0) { + usefulRoots.insert(rootIt->second); + } + } + + for (StateId node = 1; node < network_.structure.stateCount(); ++node) { + active.push_back(node); + for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(node); target; ++target) { + if (target.isLabel()) { + usedRoots.insert(network_.exits[target.label()].transitState); + fanIn[network_.exits[target.label()].transitState] += 1; + } else { fanIn[*target] += 1; } @@ -493,10 +803,10 @@ std::vector TreeBuilder::minimize(bool forceDeterminizatio if (onlyMinimizeBackwards) { log() << "skipping determinization"; - for (StateId node = 1; node < network_.structure.stateCount(); ++node) + for (StateId node = 1; node < network_.structure.stateCount(); ++node) { determinizeMap[node] = node; - } - else { + } + } else { // Determinize states: Join successor states with the same state-desc while (!active.empty()) { StateId state = active.front(); @@ -504,9 +814,11 
@@ std::vector TreeBuilder::minimize(bool forceDeterminizatio HMMStateNetwork::ChangePlan change = network_.structure.change(state); typedef std::unordered_multimap SuccessorHash; SuccessorHash successors; - for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(state); target; ++target) - if (!target.isLabel() && (forceDeterminization || fanIn[*target] == 1)) + for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(state); target; ++target) { + if (!target.isLabel() && (forceDeterminization || fanIn[*target] == 1)) { successors.insert(std::make_pair(network_.structure.state(*target).stateDesc, *target)); + } + } while (!successors.empty()) { std::pair items = successors.equal_range(successors.begin()->first); @@ -514,22 +826,27 @@ std::vector TreeBuilder::minimize(bool forceDeterminizatio SuccessorHash::iterator it = items.first; if (++it != items.second) { StateId newNode = network_.structure.allocateTreeNode(network_.masterTree); - if (newNode >= determinizeMap.size()) + if (newNode >= determinizeMap.size()) { determinizeMap.resize(newNode + 1, 0); + } network_.structure.state(newNode).stateDesc = items.first->first; - if (network_.uncoarticulatedWordEndStates.count(items.first->second)) + if (network_.uncoarticulatedWordEndStates.count(items.first->second)) { network_.uncoarticulatedWordEndStates.insert(newNode); + } HMMStateNetwork::ChangePlan newChange = network_.structure.change(newNode); // There are multiple successors with the same state-desc, join them for (it = items.first; it != items.second; ++it) { verify(it->second < determinizeMap.size()); - if (forceExactWordEnds_ && network_.uncoarticulatedWordEndStates.count(it->second)) + if (forceExactWordEnds_ && network_.uncoarticulatedWordEndStates.count(it->second)) { network_.uncoarticulatedWordEndStates.insert(newNode); - if (determinizeMap[it->second]) + } + if (determinizeMap[it->second]) { ++determinizeClashes; + } determinizeMap[it->second] = newNode; - 
for (HMMStateNetwork::SuccessorIterator target2 = network_.structure.successors(it->second); target2; ++target2) + for (HMMStateNetwork::SuccessorIterator target2 = network_.structure.successors(it->second); target2; ++target2) { newChange.addSuccessor(*target2); + } change.removeSuccessor(it->second); } newChange.apply(); @@ -552,16 +869,19 @@ std::vector TreeBuilder::minimize(bool forceDeterminizatio std::vector minimizeMap(network_.structure.stateCount(), 0); minimizeState(network_.rootState, minimizeMap); - for (std::set::iterator it = network_.coarticulatedRootStates.begin(); it != network_.coarticulatedRootStates.end(); ++it) + for (std::set::iterator it = network_.coarticulatedRootStates.begin(); it != network_.coarticulatedRootStates.end(); ++it) { minimizeState(*it, minimizeMap); - for (std::set::iterator it = skipRootSet_.begin(); it != skipRootSet_.end(); ++it) + } + for (std::set::iterator it = skipRootSet_.begin(); it != skipRootSet_.end(); ++it) { minimizeState(*it, minimizeMap); + } // loop over 0-depth roots to make sure they are mapped and connected with updated successors for (std::set::iterator it = usefulRoots.begin(); it != usefulRoots.end(); ++it) { if (determinizeMap[*it]) { minimizeState(determinizeMap[*it], minimizeMap); - } else { + } + else { minimizeState(*it, minimizeMap); } } @@ -587,11 +907,13 @@ std::vector TreeBuilder::minimize(bool forceDeterminizatio log() << "joining exits, coarticulated roots before: " << network_.coarticulatedRootStates.size(); u32 oldNodeCount = network_.structure.stateCount(); // New nodes may be added during this procedure + // the joined transitRoot is specific to an individual state, so roots_ is not updated for the general key for (StateId state = 1; state < oldNodeCount; ++state) { - if (minimizeMap[state] == state) + if (minimizeMap[state] == state) { minimizeExits(state, minimizeExitsMap); - else + } else { network_.structure.clearOutputEdges(state); + } } } @@ -611,10 +933,10 @@ std::vector TreeBuilder::minimize(bool
forceDeterminizatio for (std::map>::iterator it = oldTransitDescs.begin(); it != oldTransitDescs.end(); ++it) { StateId orig = it->first; if (orig == network_.rootState || orig >= minimizeMap.size()) { - if (orig == network_.rootState || network_.coarticulatedRootStates.count(orig)) + if (orig == network_.rootState || network_.coarticulatedRootStates.count(orig)) { network_.rootTransitDescriptions.insert(*it); - } - else { + } + } else { StateId mapped = minimizeMap[it->first]; verify(mapped); verify(network_.coarticulatedRootStates.count(mapped)); @@ -632,13 +954,15 @@ std::vector TreeBuilder::minimize(bool forceDeterminizatio log() << "cleaning"; u32 lost = 0, kept = 0; for (StateId state = 1; state < determinizeMap.size(); ++state) { - if (determinizeMap[state]) + if (determinizeMap[state]) { determinizeMap[state] = minimizeMap[determinizeMap[state]]; - else + } else { determinizeMap[state] = minimizeMap[state]; + } } minimizeMap = determinizeMap; + // cleanup also changes structure, need to update map accordingly HMMStateNetwork::CleanupResult cleanupResult = network_.cleanup(); for (std::vector::iterator it = minimizeMap.begin(); it != minimizeMap.end(); ++it) { if (*it) { @@ -646,8 +970,7 @@ std::vector TreeBuilder::minimize(bool forceDeterminizatio *it = cleanupResult.nodeMap[*it]; kept += 1; verify(*it); - } - else { + } else { lost += 1; *it = 0; } @@ -664,25 +987,11 @@ std::vector TreeBuilder::minimize(bool forceDeterminizatio return minimizeMap; } -void TreeBuilder::mapSet(std::set& set, const std::vector& minimizeMap, bool force) { - std::set oldSet; - oldSet.swap(set); - for (std::set::iterator it = oldSet.begin(); it != oldSet.end(); ++it) { - if (*it >= minimizeMap.size()) - set.insert(*it); - else if (!minimizeMap[*it]) { - verify(!force); - } - else { - set.insert(minimizeMap[*it]); - } - } -} - -void TreeBuilder::minimizeState(StateId state, std::vector& minimizeMap) { +void MinimizedTreeBuilder::minimizeState(StateId state, std::vector& 
minimizeMap) { verify(state < minimizeMap.size()); - if (minimizeMap[state]) + if (minimizeMap[state]) { return; + } minimizeMap[state] = Core::Type::max; @@ -699,8 +1008,7 @@ void TreeBuilder::minimizeState(StateId state, std::vector& minimizeMap if (minimizeMap[*target] == Core::Type::max) { // std::cout << "detected recursion while minimization on " << *target << std::endl; successors.insert(*target); - } - else { + }else { successors.insert(minimizeMap[*target]); } } @@ -711,8 +1019,7 @@ void TreeBuilder::minimizeState(StateId state, std::vector& minimizeMap std::unordered_map::iterator it = predecessors_.find(pred); if (it != predecessors_.end()) { minimizeMap[state] = it->second; - } - else { + } else { minimizeMap[state] = state; predecessors_.insert(std::make_pair(pred, state)); for (std::set::iterator succIt = successors.begin(); succIt != successors.end(); ++succIt) @@ -720,7 +1027,7 @@ void TreeBuilder::minimizeState(StateId state, std::vector& minimizeMap } } -void TreeBuilder::minimizeExits(StateId state, const std::vector& minimizeExitsMap) { +void MinimizedTreeBuilder::minimizeExits(StateId state, const std::vector& minimizeExitsMap) { typedef std::multimap ExitMap; ExitMap successorExits; @@ -734,12 +1041,14 @@ void TreeBuilder::minimizeExits(StateId state, const std::vector& minimizeE successorStates.insert(*target); } - if (successorExits.empty()) + if (successorExits.empty()) { return; + } network_.structure.clearOutputEdges(state); - for (std::set::iterator it = successorStates.begin(); it != successorStates.end(); ++it) + for (std::set::iterator it = successorStates.begin(); it != successorStates.end(); ++it) { network_.structure.addTargetToNode(state, *it); + } } // Join multiple exits for the same pronunciation to one @@ -748,14 +1057,14 @@ void TreeBuilder::minimizeExits(StateId state, const std::vector& minimizeE ExitMap::iterator i = range.first; if (++i == range.second) { network_.structure.addOutputToNode(state, range.first->second); - } 
- else { + } else { // Join std::set newRootSuccessors; std::set left, right; for (i = range.first; i != range.second; ++i) { - for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(network_.exits[i->second].transitState); target; ++target) + for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(network_.exits[i->second].transitState); target; ++target) { newRootSuccessors.insert(*target); + } left.insert(network_.rootTransitDescriptions[network_.exits[i->second].transitState].first); right.insert(network_.rootTransitDescriptions[network_.exits[i->second].transitState].second); } @@ -770,10 +1079,12 @@ void TreeBuilder::minimizeExits(StateId state, const std::vector& minimizeE network_.rootTransitDescriptions.insert(std::make_pair(exit.transitState, std::make_pair(left.size() == 1 ? *left.begin() : Bliss::Phoneme::term, right.size() == 1 ? *right.begin() : Bliss::Phoneme::term))); for (i = range.first; i != range.second; ++i) { verify(i->second < network_.exits.size()); - if (network_.unpushedCoarticulatedRootStates.count(network_.exits[i->second].transitState)) + if (network_.unpushedCoarticulatedRootStates.count(network_.exits[i->second].transitState)) { network_.unpushedCoarticulatedRootStates.insert(exit.transitState); - if (network_.uncoarticulatedWordEndStates.count(network_.exits[i->second].transitState)) + } + if (network_.uncoarticulatedWordEndStates.count(network_.exits[i->second].transitState)) { network_.uncoarticulatedWordEndStates.insert(exit.transitState); + } } } } @@ -781,335 +1092,55 @@ void TreeBuilder::minimizeExits(StateId state, const std::vector& minimizeE } } -void TreeBuilder::buildFanInOutStructure() { - bool ciTransitions = paramAddCiTransitions(config_); - - // Create temporary coarticulated roots - for (std::set::const_iterator finalIt = finalPhonemes_.begin(); finalIt != finalPhonemes_.end(); ++finalIt) - for (std::set::const_iterator initialIt = initialPhonemes_.begin(); initialIt != 
initialPhonemes_.end(); ++initialIt) - createRoot(*finalIt, *initialIt, 0); - - log() << "building fan-in"; - // Build the fan-in structure (eg. the HMM structure representing the initial word phonemes, behind roots, up to the joints) - for (RootHash::const_iterator rootIt = roots_.begin(); rootIt != roots_.end(); ++rootIt) { - if (rootIt->first.depth != 0 || rootIt->second == network_.rootState) - continue; - Bliss::Phoneme::Id initial = rootIt->first.right; - verify(initialPhonemes_.count(initial)); - verify(initial != Bliss::Phoneme::term); - u32 paths = 0; - - for (CoarticulationJointHash::const_iterator initialSuffixIt = initialPhoneSuffix_.begin(); initialSuffixIt != initialPhoneSuffix_.end(); ++initialSuffixIt) { - if (initialSuffixIt->first.left != initial) - continue; - ++paths; - HMMSequence hmm; - hmmFromAllophone(hmm, rootIt->first.left, initial, initialSuffixIt->first.right, Am::Allophone::isInitialPhone); - verify(hmm.length > 0); - StateId currentNode = extendFanIn(initialSuffixIt->second, hmm[hmm.length - 1]); - for (s32 s = hmm.length - 2; s >= 0; --s) - currentNode = extendFanIn(currentNode, hmm[s]); - - addSuccessor(rootIt->second, currentNode); - } - - for (CoarticulationJointHash::const_iterator initialSuffixIt = initialFinalPhoneSuffix_.begin(); initialSuffixIt != initialFinalPhoneSuffix_.end(); ++initialSuffixIt) { - if (initialSuffixIt->first.left != initial) - continue; - ++paths; - HMMSequence hmm; - hmmFromAllophone(hmm, rootIt->first.left, initial, initialSuffixIt->first.right, Am::Allophone::isInitialPhone | Am::Allophone::isFinalPhone); - verify(hmm.length > 0); - StateId currentNode = extendFanIn(initialSuffixIt->second, hmm[hmm.length - 1]); - for (s32 s = hmm.length - 2; s >= 0; --s) - currentNode = extendFanIn(currentNode, hmm[s]); - - addSuccessor(rootIt->second, currentNode); - } - } - - log() << "states: " << network_.structure.stateCount() << " exits: " << network_.exits.size() << " roots: " << roots_.size(); - - log() << 
"building fan-out"; - - // Build the fan-out structure (eg. the HMM structure representing the final word phonemes, behind special roots) - // At the left side delimited by the roots of depth -1, and at the right side by the roots of depth 0 - for (RootHash::const_iterator leftRootIt = roots_.begin(); leftRootIt != roots_.end(); ++leftRootIt) { - if (leftRootIt->first.depth != -1) - continue; - Bliss::Phoneme::Id fin = leftRootIt->first.right; - verify(finalPhonemes_.count(fin)); - - u32 paths = 0; - for (RootHash::const_iterator rightRootIt = roots_.begin(); rightRootIt != roots_.end(); ++rightRootIt) { - if (rightRootIt->first.depth != 0 || (rightRootIt->first.left != fin && (!ciTransitions || rightRootIt->first.left != Bliss::Phoneme::term))) - continue; - ++paths; - HMMSequence hmm; - hmmFromAllophone(hmm, leftRootIt->first.left, fin, rightRootIt->first.right, Am::Allophone::isFinalPhone, false); - verify(hmm.length > 0); - - // The last state in the pushed fan-in is equivalent with the corresponding root state - - StateId lastNode = extendFanIn(network_.structure.targetSet(rightRootIt->second), hmm[hmm.length - 1]); - StateId currentNode = lastNode; - for (s32 s = hmm.length - 2; s >= 0; --s) - currentNode = extendFanIn(currentNode, hmm[s]); - - if (rightRootIt->first.right == Bliss::Phoneme::term || !isContextDependent(rightRootIt->first.right)) - network_.uncoarticulatedWordEndStates.insert(lastNode); - - addSuccessor(leftRootIt->second, currentNode); - } - verify(paths); - } - - printStats("after fan-in/out structure"); -} - -void TreeBuilder::printStats(std::string occasion) { - log() << "stats " << occasion << ":"; - log() << "states: " << network_.structure.stateCount() << " exits: " << network_.exits.size(); - log() << "coarticulated roots: " << network_.coarticulatedRootStates.size() << " unpushed: " << network_.unpushedCoarticulatedRootStates.size(); - u32 roots = 0; - for (std::set::iterator it = network_.uncoarticulatedWordEndStates.begin(); it != 
network_.uncoarticulatedWordEndStates.end(); ++it) - if (network_.coarticulatedRootStates.count(*it)) - ++roots; - log() << "number of uncoarticulated pushed word-end nodes: " << network_.uncoarticulatedWordEndStates.size() << " out of those are roots: " << roots; -} - -TreeBuilder::StateId TreeBuilder::createSkipRoot(StateId baseRoot) { - SkipRootsHash::const_iterator it = skipRoots_.find(baseRoot); - if (it != skipRoots_.end()) - return it->second; - StateTree::StateDesc desc = rootDesc(); - desc.transitionModelIndex = Am::TransitionModel::entryM2; - StateId ret = createState(desc); - - skipRoots_.insert(std::make_pair(baseRoot, ret)); - network_.structure.addTargetToNode(ret, baseRoot); - skipRootSet_.insert(ret); - network_.coarticulatedRootStates.insert(ret); - verify(network_.rootTransitDescriptions.count(baseRoot)); - network_.rootTransitDescriptions.insert(std::make_pair(ret, network_.rootTransitDescriptions[baseRoot])); - return ret; -} - -TreeBuilder::StateId TreeBuilder::createRoot(Bliss::Phoneme::Id left, Bliss::Phoneme::Id right, int depth) { - RootKey key(left, right, depth); - RootHash::const_iterator it = roots_.find(key); - if (it != roots_.end()) - return it->second; - StateId ret = createState(rootDesc()); - if (depth == 0 && (left != Bliss::Phoneme::term || right != Bliss::Phoneme::term)) - network_.unpushedCoarticulatedRootStates.insert(ret); - - if (right == Bliss::Phoneme::term || !acousticModel_.phonemeInventory()->phoneme(right)->isContextDependent()) - network_.uncoarticulatedWordEndStates.insert(ret); - - if (left != Bliss::Phoneme::term || right != Bliss::Phoneme::term) - network_.coarticulatedRootStates.insert(ret); - roots_.insert(std::make_pair(key, ret)); - - network_.rootTransitDescriptions.insert(std::make_pair(ret, std::make_pair(left, right))); - - return ret; -} - -TreeBuilder::StateId TreeBuilder::createState(StateTree::StateDesc desc) { - StateId ret = network_.structure.allocateTreeNode(network_.masterTree); - 
network_.structure.state(ret).stateDesc = desc; - return ret; -} - -TreeBuilder::StateId TreeBuilder::extendFanIn(StateId successorOrExit, StateTree::StateDesc desc) { - std::set successors; - successors.insert(successorOrExit); - return extendFanIn(successors, desc); -} - -TreeBuilder::StateId TreeBuilder::extendFanIn(const std::set& successorsOrExits, Search::StateTree::StateDesc desc) { - StatePredecessor pred(successorsOrExits, desc); - PredecessorsHash::iterator it = predecessors_.find(pred); - if (it != predecessors_.end()) - return it->second; - StateId ret = createState(desc); - for (std::set::const_iterator it = successorsOrExits.begin(); it != successorsOrExits.end(); ++it) - network_.structure.addTargetToNode(ret, *it); - predecessors_.insert(std::make_pair(pred, ret)); - return ret; -} - -bool TreeBuilder::addSuccessor(StateId predecessor, StateId successor) { - for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) - if (*target == successor) - return false; - - network_.structure.addTargetToNode(predecessor, successor); - return true; -} - -std::pair TreeBuilder::extendPhone(StateId currentState, - u32 phoneIndex, - const std::vector& phones, - Bliss::Phoneme::Id left, - Bliss::Phoneme::Id right) { - u8 boundary = 0; - if (phoneIndex != 0) - left = phones[phoneIndex - 1]; - else - boundary |= Am::Allophone::isInitialPhone; - - if (phoneIndex != phones.size() - 1) - right = phones[phoneIndex + 1]; - else - boundary |= Am::Allophone::isFinalPhone; - - HMMSequence hmm; - hmmFromAllophone(hmm, left, phones[phoneIndex], right, boundary, phoneIndex == 0); - - verify(hmm.length >= 1); - - u32 hmmState = 0; - StateId previousState = 0; - - if (phoneIndex == 1 && hmmState == 0) - currentState = extendBodyState(currentState, left, phones[phoneIndex], hmm[hmmState++]); - - for (; hmmState < hmm.length; ++hmmState) { - previousState = currentState; - currentState = extendState(currentState, hmm[hmmState]); - 
} - - return std::make_pair(previousState, currentState); -} - -TreeBuilder::StateId TreeBuilder::extendState(StateId predecessor, StateTree::StateDesc desc, TreeBuilder::RootKey uniqueKey) { - for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) - if (!target.isLabel() && network_.structure.state(*target).stateDesc == desc) { - if (uniqueKey.isValid()) { - Core::HashMap::const_iterator it = stateUniqueKeys_.find(*target); - verify(it != stateUniqueKeys_.end()); - if (!(it->second == uniqueKey)) - continue; - } - return *target; +void MinimizedTreeBuilder::mapSet(std::set& set, const std::vector& minimizeMap, bool force) { + std::set oldSet; + oldSet.swap(set); + for (std::set::iterator it = oldSet.begin(); it != oldSet.end(); ++it) { + if (*it >= minimizeMap.size()) { + set.insert(*it); + } else if (!minimizeMap[*it]) { + verify(!force); + } else { + set.insert(minimizeMap[*it]); } - - // No matching successor found, extend - StateId ret = createState(desc); - if (uniqueKey.isValid()) - stateUniqueKeys_.insert(std::make_pair(ret, uniqueKey)); - network_.structure.addTargetToNode(predecessor, ret); - return ret; -} - -u32 TreeBuilder::createExit(PersistentStateTree::Exit exit) { - ExitHash::iterator exitHashIt = exitHash_.find(exit); - if (exitHashIt != exitHash_.end()) { - return exitHashIt->second; } - else { - // Exit does not exist yet, add it - network_.exits.push_back(exit); - u32 exitIndex = network_.exits.size() - 1; - exitHash_.insert(std::make_pair(exit, exitIndex)); - return exitIndex; - } -} - -u32 TreeBuilder::addExit(StateId prePredecessor, - StateId predecessor, - Bliss::Phoneme::Id leftPhoneme, - Bliss::Phoneme::Id rightPhoneme, - int depth, - Bliss::LemmaPronunciation::Id pron) { - PersistentStateTree::Exit exit; - exit.transitState = createRoot(leftPhoneme, rightPhoneme, depth); - exit.pronunciation = pron; - - u32 exitIndex = createExit(exit); - - for (HMMStateNetwork::SuccessorIterator 
target = network_.structure.successors(predecessor); target; ++target) - if (target.isLabel() && target.label() == exitIndex) - return exitIndex; - - network_.structure.addOutputToNode(predecessor, ID_FROM_LABEL(exitIndex)); - return exitIndex; -} - -TreeBuilder::StateId TreeBuilder::extendBodyState(StateId state, - Bliss::Phoneme::Id first, - Bliss::Phoneme::Id second, - Search::StateTree::StateDesc desc) { - RootKey key(first, second, 1); - StateId ret = extendState(state, desc, key); - initialPhoneSuffix_[key].insert(ret); - return ret; -} - -bool TreeBuilder::isContextDependent(Bliss::Phoneme::Id phone) const { - return acousticModel_.phonemeInventory()->phoneme(phone)->isContextDependent(); -} - -StateTree::StateDesc TreeBuilder::rootDesc() const { - StateTree::StateDesc desc; - desc.acousticModel = Search::StateTree::invalidAcousticModel; - desc.transitionModelIndex = Am::TransitionModel::entryM1; - return desc; -} - -std::string TreeBuilder::describe(std::pair desc) { - std::ostringstream os; - - if (desc.first == Bliss::Phoneme::term) - os << "#"; - else - os << lexicon_.phonemeInventory()->phoneme(desc.first)->symbol(); - - os << "<->"; - - if (desc.second == Bliss::Phoneme::term) - os << "#"; - else - os << lexicon_.phonemeInventory()->phoneme(desc.second)->symbol(); - - return os.str(); -} - -Core::Component::Message TreeBuilder::log() const { - return Core::Application::us()->log("TreeBuilder: "); } // update hash structures according to minimizeMap (invalid ones are removed) // should be ok for any number of minimize iterations -void TreeBuilder::updateHashFromMap(const std::vector& map, const std::vector& exitMap) { +void MinimizedTreeBuilder::updateHashFromMap(const std::vector& map, const std::vector& exitMap) { Core::HashMap tmpKeyHash; - for (Core::HashMap::iterator iter = stateUniqueKeys_.begin(); iter != stateUniqueKeys_.end(); ++iter) - if (iter->first < map.size() && map[iter->first]) + for (Core::HashMap::iterator iter = 
stateUniqueKeys_.begin(); iter != stateUniqueKeys_.end(); ++iter) { + if (iter->first < map.size() && map[iter->first]) { tmpKeyHash.insert(std::make_pair(map[iter->first], iter->second)); + } + } stateUniqueKeys_.swap(tmpKeyHash); mapCoarticulationJointHash(initialPhoneSuffix_, map, exitMap); mapCoarticulationJointHash(initialFinalPhoneSuffix_, map, exitMap); RootHash tmpRootHash; - for (RootHash::const_iterator rootIt = roots_.begin(); rootIt != roots_.end(); ++rootIt) - if (rootIt->second < map.size() && map[rootIt->second]) + for (RootHash::const_iterator rootIt = roots_.begin(); rootIt != roots_.end(); ++rootIt) { + if (rootIt->second < map.size() && map[rootIt->second]) { tmpRootHash.insert(std::make_pair(rootIt->first, map[rootIt->second])); + } + } roots_.swap(tmpRootHash); // exits are changed in cleanup exitHash_.clear(); - for (u32 idx = 0; idx != network_.exits.size(); ++idx) + for (u32 idx = 0; idx != network_.exits.size(); ++idx) { exitHash_.insert(std::make_pair(network_.exits[idx], idx)); + } // PredecessorsHash still the FanIn/Out ones at this point (recorded in minimize()) PredecessorsHash tmpPredHash; for (PredecessorsHash::iterator pIt = predecessors_.begin(); pIt != predecessors_.end(); ++pIt) { - if (pIt->second >= map.size() || !map[pIt->second]) + if (pIt->second >= map.size() || !map[pIt->second]) { continue; + } const StatePredecessor& sp = pIt->first; std::set tmpSet; mapSuccessors(sp.successors, tmpSet, map, exitMap); @@ -1121,26 +1152,29 @@ void TreeBuilder::updateHashFromMap(const std::vector& map, const std:: predecessors_.swap(tmpPredHash); } -inline void TreeBuilder::mapCoarticulationJointHash(CoarticulationJointHash& hash, const std::vector& map, const std::vector& exitMap) { +inline void MinimizedTreeBuilder::mapCoarticulationJointHash(CoarticulationJointHash& hash, const std::vector& map, const std::vector& exitMap) { CoarticulationJointHash tmpHash; for (CoarticulationJointHash::iterator iter = hash.begin(); iter != hash.end(); 
++iter) { std::set tmpSet; mapSuccessors(iter->second, tmpSet, map, exitMap); - if (!tmpSet.empty()) + if (!tmpSet.empty()) { tmpHash.insert(std::make_pair(iter->first, tmpSet)); + } } hash.swap(tmpHash); } -inline void TreeBuilder::mapSuccessors(const std::set& successors, std::set& tmpSet, const std::vector& map, const std::vector& exitMap) { - for (std::set::const_iterator sIt = successors.cbegin(); sIt != successors.cend(); ++sIt) +inline void MinimizedTreeBuilder::mapSuccessors(const std::set& successors, std::set& tmpSet, const std::vector& map, const std::vector& exitMap) { + for (std::set::const_iterator sIt = successors.cbegin(); sIt != successors.cend(); ++sIt) { if (IS_LABEL(*sIt)) { u32 eIdx = LABEL_FROM_ID(*sIt); - if (exitMap.empty() || eIdx >= exitMap.size()) + if (exitMap.empty() || eIdx >= exitMap.size()) { tmpSet.insert(*sIt); - else + } else { tmpSet.insert(ID_FROM_LABEL(exitMap[eIdx])); - } - else if (*sIt < map.size() && map[*sIt]) + } + } else if (*sIt < map.size() && map[*sIt]) { tmpSet.insert(map[*sIt]); + } + } } diff --git a/src/Search/AdvancedTreeSearch/TreeBuilder.hh b/src/Search/AdvancedTreeSearch/TreeBuilder.hh index 52f209f9f..c99f9db66 100644 --- a/src/Search/AdvancedTreeSearch/TreeBuilder.hh +++ b/src/Search/AdvancedTreeSearch/TreeBuilder.hh @@ -33,29 +33,57 @@ namespace Core { class Configuration; } -class TreeBuilder { +class AbstractTreeBuilder : public Core::Component { public: typedef u32 StateId; - TreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true, bool arcBased = false); + AbstractTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network); + virtual ~AbstractTreeBuilder() = default; + + virtual std::unique_ptr newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& 
acousticModel, Search::PersistentStateTree& network, bool initialize = true) = 0; + // Build a new persistent state network. - void build(); - // Returns a mapping of state-indices. Zero means 'invalid'. - // If onlyMinimizeBackwards is true, then no forward determinization is performed, but rather only backwards minimization. - // If allowLost is true, losing states is allowed. Happens if there are unreachable garbage states. - std::vector minimize(bool forceDeterminization = true, bool onlyMinimizeBackwards = false, bool allowLost = false); + virtual void build() = 0; + +protected: + const Bliss::Lexicon& lexicon_; + const Am::AcousticModel& acousticModel_; + Search::PersistentStateTree& network_; +}; + +class MinimizedTreeBuilder : public AbstractTreeBuilder { +public: + static const Core::ParameterInt paramMinPhones; + static const Core::ParameterBool paramAddCiTransitions; + static const Core::ParameterBool paramUseRootForCiExits; + static const Core::ParameterBool paramForceExactWordEnds; + static const Core::ParameterBool paramKeepRoots; + static const Core::ParameterBool paramAllowCrossWordSkips; + static const Core::ParameterBool paramRepeatSilence; + static const Core::ParameterInt paramMinimizeIterations; + typedef u32 StateId; + + MinimizedTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); + virtual ~MinimizedTreeBuilder() = default; + + virtual std::unique_ptr newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); + + virtual void build(); +protected: struct HMMSequence { HMMSequence() : length(0) {} enum { MaxLength = 12 }; - s32 length; + + s32 length; + Search::StateTree::StateDesc hmm[MaxLength]; + inline const Search::StateTree::StateDesc& operator[](u32 index) const { return hmm[index]; } - 
Search::StateTree::StateDesc hmm[MaxLength]; bool operator==(const HMMSequence& rhs) const { verify(length < MaxLength); @@ -80,18 +108,6 @@ public: }; }; - HMMSequence arcSequence(u32 acousticModelIndex) const; - std::string arcDesc(u32 acousticModelIndex) const; - - // If this function returns true, then the hmm states are placeholders for hmm sequences which - // can be acquired through arcSequence(...). The transition model index then contains word boundary information. - bool arcBased() const { - return arcBased_; - } - -protected: - Core::Component::Message log() const; - struct RootKey { public: RootKey(Bliss::Phoneme::Id _left = Core::Type::max, Bliss::Phoneme::Id _right = Core::Type::max, int _depth = 0) @@ -127,28 +143,63 @@ protected: isWordEnd(_isWordEnd), hash(StandardValueHash()(SetHash()(successors) + Search::StateTree::StateDesc::Hash()(desc) + (isWordEnd ? 1312 : 0))) {} + bool operator==(const StatePredecessor& rhs) const { + return successors == rhs.successors && desc == rhs.desc && isWordEnd == rhs.isWordEnd; + } + struct Hash { u32 operator()(const StatePredecessor& pred) const { return pred.hash; } }; - bool operator==(const StatePredecessor& rhs) const { - return successors == rhs.successors && desc == rhs.desc && isWordEnd == rhs.isWordEnd; - } - const std::set successors; const Search::StateTree::StateDesc desc; bool isWordEnd; const u32 hash; }; - void printStats(std::string occasion); + typedef std::set PhonemeIdSet; + typedef Core::HashMap RootHash; + typedef Core::HashMap SkipRootsHash; + typedef Core::HashMap ExitHash; + typedef Core::HashMap, RootKey::Hash> CoarticulationJointHash; + typedef Core::HashMap PredecessorsHash; + + s32 minPhones_; + bool addCiTransitions_; + bool useRootForCiExits_; + bool forceExactWordEnds_; + bool keepRoots_; + bool allowCrossWordSkips_; + bool repeatSilence_; + u32 minimizeIterations_; + bool reverse_; + + PhonemeIdSet initialPhonemes_; + PhonemeIdSet finalPhonemes_; + + // Keys according to which 
specific states are supposed to be unique + // Required to omit merging of paths in some critical locations + Core::HashMap stateUniqueKeys_; + + RootHash roots_; // Contains roots and joint-states + SkipRootsHash skipRoots_; + std::set skipRootSet_; + ExitHash exitHash_; + CoarticulationJointHash initialPhoneSuffix_; + CoarticulationJointHash initialFinalPhoneSuffix_; + PredecessorsHash predecessors_; + + void printStats(std::string occasion); + std::string describe(std::pair); + + bool isContextDependent(Bliss::Phoneme::Id phone) const; + + void buildBody(); void buildFanInOutStructure(); void addCrossWordSkips(); - void skipRootTransitions(); - void propagateExits(StateId state, Search::HMMStateNetwork::ChangePlan change); - bool isContextDependent(Bliss::Phoneme::Id phone) const; + void skipRootTransitions(StateId start = 1); Search::StateTree::StateDesc rootDesc() const; @@ -156,18 +207,17 @@ protected: StateId createRoot(Bliss::Phoneme::Id left, Bliss::Phoneme::Id right, int depth); StateId createState(Search::StateTree::StateDesc desc); u32 createExit(Search::PersistentStateTree::Exit exit); - u32 addExit(StateId prePredecessor, - StateId predecessor, + u32 addExit(StateId predecessor, Bliss::Phoneme::Id leftPhoneme, Bliss::Phoneme::Id rightPhoneme, int depth, Bliss::LemmaPronunciation::Id pron); - void hmmFromAllophone(HMMSequence& ret, - Bliss::Phoneme::Id left, - Bliss::Phoneme::Id central, - Bliss::Phoneme::Id right, - u32 boundary = 0, - bool allowNonStandard = true); + + void hmmFromAllophone(HMMSequence& ret, + Bliss::Phoneme::Id left, + Bliss::Phoneme::Id central, + Bliss::Phoneme::Id right, + u32 boundary = 0); // Adds the successor as successor of the predecessor, if it isn't in the list yet bool addSuccessor(StateId predecessor, StateId successor); @@ -180,60 +230,15 @@ protected: StateId extendBodyState(StateId state, Bliss::Phoneme::Id first, Bliss::Phoneme::Id second, Search::StateTree::StateDesc desc); StateId extendFanIn(StateId successor, 
Search::StateTree::StateDesc desc); StateId extendFanIn(const std::set& successors, Search::StateTree::StateDesc desc); - void minimizeState(StateId state, std::vector& minimizeMap); - void minimizeExits(StateId state, const std::vector& minimizeExitsMap); - static void mapSet(std::set& set, const std::vector& minimizeMap, bool force); - std::string describe(std::pair); - - const Bliss::Lexicon& lexicon_; - const Am::AcousticModel& acousticModel_; - Search::PersistentStateTree& network_; - Core::Configuration config_; - s32 minPhones_; - bool forceExactWordEnds_; - bool keepRoots_; - bool allowCrossWordSkips_; - bool repeatSilence_; - bool reverse_; - bool arcBased_; - std::set initialPhonemes_, finalPhonemes_; - - // Keys according to which specific states are supposed to be unique - // Required to omit merging of paths in some critical locations - Core::HashMap stateUniqueKeys_; - - typedef Core::HashMap ArcSequenceHash; - ArcSequenceHash arcSequencesHash_; - std::vector arcSequences_; - struct ArcDesc { - ArcDesc() - : left(Bliss::Phoneme::term), - central(Bliss::Phoneme::term), - right(Bliss::Phoneme::term) { - } - Bliss::Phoneme::Id left; - Bliss::Phoneme::Id central; - Bliss::Phoneme::Id right; - }; - std::vector arcDescs_; - - typedef Core::HashMap RootHash; - RootHash roots_; // Contains roots and joint-states - - typedef Core::HashMap SkipRootsHash; - SkipRootsHash skipRoots_; - std::set skipRootSet_; - - typedef Core::HashMap ExitHash; - ExitHash exitHash_; - - typedef Core::HashMap, RootKey::Hash> CoarticulationJointHash; - CoarticulationJointHash initialPhoneSuffix_, initialFinalPhoneSuffix_; + // Returns a mapping of state-indices. Zero means 'invalid'. + // If onlyMinimizeBackwards is true, then no forward determinization is performed, but rather only backwards minimization. + // If allowLost is true, losing states is allowed. Happens if there are unreachable garbage states. 
+ std::vector minimize(bool forceDeterminization = true, bool onlyMinimizeBackwards = false, bool allowLost = false); + void minimizeState(StateId state, std::vector& minimizeMap); + void minimizeExits(StateId state, const std::vector& minimizeExitsMap); + static void mapSet(std::set& set, const std::vector& minimizeMap, bool force); - typedef Core::HashMap PredecessorsHash; - PredecessorsHash predecessors_; -protected: void updateHashFromMap(const std::vector& map, const std::vector& exitMap); void mapCoarticulationJointHash(CoarticulationJointHash& hash, const std::vector& map, const std::vector& exitMap); void mapSuccessors(const std::set&, std::set&, const std::vector&, const std::vector&); From ed7430d5f99f18d378d0118b4ed4ba5fc00e7e65 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 12 Dec 2024 14:22:25 +0100 Subject: [PATCH 02/24] Apply formatting --- .../AdvancedTreeSearch/PersistentStateTree.cc | 12 ++-- .../AdvancedTreeSearch/PersistentStateTree.hh | 22 +++---- src/Search/AdvancedTreeSearch/SearchSpace.cc | 48 +++++++------- src/Search/AdvancedTreeSearch/SearchSpace.hh | 13 ++-- src/Search/AdvancedTreeSearch/TreeBuilder.cc | 66 ++++++++++++------- src/Search/AdvancedTreeSearch/TreeBuilder.hh | 2 +- 6 files changed, 91 insertions(+), 72 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc b/src/Search/AdvancedTreeSearch/PersistentStateTree.cc index 83cd5030b..7a0dae10a 100644 --- a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc +++ b/src/Search/AdvancedTreeSearch/PersistentStateTree.cc @@ -41,7 +41,7 @@ struct ConvertTree { TreeIndex masterTreeIndex; StateId rootSubTree; StateId ciRootNode; - std::map exits; //Maps exits to label-indices @todo Make this a hash_map + std::map exits; // Maps exits to label-indices @todo Make this a hash_map std::vector exitVector; Core::HashMap statesForNodes; Core::HashMap nodesForStates; @@ -73,7 +73,7 @@ struct ConvertTree { } } - ///Make sure a node is created for every single state, 
so that also the coarticulated roots are respected + /// Make sure a node is created for every single state, so that also the coarticulated roots are respected for (std::set::iterator stateIt = coarticulatedRootStates.begin(); stateIt != coarticulatedRootStates.end(); ++stateIt) { StateTree::StateId state = *stateIt; @@ -121,7 +121,7 @@ struct ConvertTree { exitIndices.insert(exitEntry->second); } - //Add connections to the attached outputs/exits + // Add connections to the attached outputs/exits for (std::set::iterator it = exitIndices.begin(); it != exitIndices.end(); ++it) subtrees.addOutputToEdge(subtrees.state(node).successors, *it); } @@ -150,10 +150,10 @@ struct ConvertTree { subtrees.state(node).stateDesc = state; - //Build successor structure + // Build successor structure std::pair successors = tree->successors(stateId); - StateId current = node; //Just to verify the order + StateId current = node; // Just to verify the order for (; successors.first != successors.second; ++successors.first) { std::unordered_map::iterator nodeIt = nodesForStates.find(*successors.first); @@ -437,7 +437,7 @@ HMMStateNetwork::CleanupResult PersistentStateTree::cleanup(bool cleanupExits) { Core::HashMap::const_iterator targetNodeIt; if (rootState) { - verify(cleanupResult.nodeMap.find(rootState) != cleanupResult.nodeMap.end()); //Root-node must stay unchanged + verify(cleanupResult.nodeMap.find(rootState) != cleanupResult.nodeMap.end()); // Root-node must stay unchanged verify(cleanupResult.nodeMap.find(rootState)->second == rootState); targetNodeIt = cleanupResult.nodeMap.find(rootState); verify(targetNodeIt != cleanupResult.nodeMap.end()); diff --git a/src/Search/AdvancedTreeSearch/PersistentStateTree.hh b/src/Search/AdvancedTreeSearch/PersistentStateTree.hh index 4607b39c1..e83dd1073 100644 --- a/src/Search/AdvancedTreeSearch/PersistentStateTree.hh +++ b/src/Search/AdvancedTreeSearch/PersistentStateTree.hh @@ -44,25 +44,25 @@ public: ///@param lexicon This must be given if 
the resulting exits are supposed to be functional PersistentStateTree(Core::Configuration config, Core::Ref acousticModel, Bliss::LexiconRef lexicon, TreeBuilderFactory treeBuilderFactory); - ///Builds this state tree. + /// Builds this state tree. void build(); - ///Writes the current state of the state tree into the file, - ///Returns whether writing was successful + /// Writes the current state of the state tree into the file, + /// Returns whether writing was successful bool write(int transformation = 0); - ///Reads the state tree from the file. + /// Reads the state tree from the file. ///@return Whether the reading was successful. bool read(int transformation = 0); - ///Cleans up the structure, saving memory and allowing a more efficient iteration. - ///Node and tree IDs may be changed. + /// Cleans up the structure, saving memory and allowing a more efficient iteration. + /// Node and tree IDs may be changed. ///@return An object that contains a mapping representing the index changes. HMMStateNetwork::CleanupResult cleanup(bool cleanupExits = true); - ///Removes all outputs from the network - ///Also performs a cleanup, so the search network must already be clean - ///for indices to stay equal + /// Removes all outputs from the network + /// Also performs a cleanup, so the search network must already be clean + /// for indices to stay equal void removeOutputs(); u32 getChecksum() const; @@ -134,10 +134,10 @@ private: Core::Configuration config_; TreeBuilderFactory treeBuilderFactory_; - //Writes the whole state network into the given stream + // Writes the whole state network into the given stream void write(Core::MappedArchiveWriter writer); - //Reads the state network from the given stream. + // Reads the state network from the given stream. //@return Whether the reading was successful. 
bool read(Core::MappedArchiveReader reader); }; diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.cc b/src/Search/AdvancedTreeSearch/SearchSpace.cc index d6ea06ad7..4cbfac68e 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.cc +++ b/src/Search/AdvancedTreeSearch/SearchSpace.cc @@ -230,7 +230,7 @@ const Core::ParameterBool paramReducedContextTreeKey( const Core::ParameterBool paramOnTheFlyRescoring( "on-the-fly-rescoring", - "keep track of recombined histories and use those aswell when searching for word ends", + "keep track of recombined histories and use those as well when searching for word ends", false); const Core::ParameterInt paramOnTheFlyRescoringMaxHistories( @@ -391,7 +391,6 @@ StaticSearchAutomaton::StaticSearchAutomaton(Core::Configuration config, Core::R treeBuilderType_(static_cast(paramTreeBuilderType(config))), acousticModel_(acousticModel), lexicon_(lexicon) { - if (treeBuilderType_ == TreeBuilderType::previousBehavior) { treeBuilderType_ = minimized ? TreeBuilderType::minimizedHmm : TreeBuilderType::classicHmm; } @@ -405,7 +404,7 @@ StaticSearchAutomaton::~StaticSearchAutomaton() { void StaticSearchAutomaton::buildNetwork() { /// @todo Track the TreeBuilder configuration in transformation if minimizedTree - int transformation = minimized ? 32 : 0; + int transformation = minimized ? 32 : 0; if (!network.read(transformation)) { log() << "persistent network image could not be loaded, building it"; @@ -1165,7 +1164,7 @@ inline bool SearchSpace::eventuallyDeactivateTree(Instance* at, bool increaseIna void SearchSpace::activateOrUpdateStateHypothesisLoop(const Search::StateHypothesis& hyp, Score score) { StateHypothesisIndex& recombination = stateHypothesisRecombinationArray[hyp.state]; // Look-up at node index, contains positions in newStateHypotheses. 
- StateHypothesis& sh(newStateHypotheses.data()[recombination]); //We may be referencing a not allocated position, so we use data() + StateHypothesis& sh(newStateHypotheses.data()[recombination]); // We may be referencing a not allocated position, so we use data() // Check if present in current tree (starting at currentTreeFirstNewStateHypothesis). if (recombination < currentTreeFirstNewStateHypothesis || recombination >= newStateHypotheses.size() || sh.state != hyp.state) { recombination = newStateHypotheses.size(); @@ -1173,7 +1172,7 @@ void SearchSpace::activateOrUpdateStateHypothesisLoop(const Search::StateHypothe newStateHypotheses.back().score = score; } else { - //Update existing hypothesis + // Update existing hypothesis if (sh.score >= score) { sh.score = score; sh.trace = hyp.trace; @@ -1183,7 +1182,7 @@ void SearchSpace::activateOrUpdateStateHypothesisLoop(const Search::StateHypothe void SearchSpace::activateOrUpdateStateHypothesisTransition(const Search::StateHypothesis& hyp, Score score, StateId successorState) { StateHypothesisIndex& recombination = stateHypothesisRecombinationArray[successorState]; - StateHypothesis& sh(newStateHypotheses.data()[recombination]); //We may be referencing a not allocated position, so we use data() + StateHypothesis& sh(newStateHypotheses.data()[recombination]); // We may be referencing a not allocated position, so we use data() // Check if present in current tree (starting at currentTreeFirstNewStateHypothesis). 
if (recombination < currentTreeFirstNewStateHypothesis || recombination >= newStateHypotheses.size() || sh.state != successorState) { recombination = newStateHypotheses.size(); @@ -1192,7 +1191,7 @@ void SearchSpace::activateOrUpdateStateHypothesisTransition(const Search::StateH newStateHypotheses.back().state = successorState; } else { - //Update existing hypothesis + // Update existing hypothesis if (sh.score >= score) { sh.score = score; sh.trace = hyp.trace; @@ -1202,14 +1201,14 @@ void SearchSpace::activateOrUpdateStateHypothesisTransition(const Search::StateH void SearchSpace::activateOrUpdateStateHypothesisDirectly(const Search::StateHypothesis& hyp) { StateHypothesisIndex& recombination = stateHypothesisRecombinationArray[hyp.state]; - StateHypothesis& sh(newStateHypotheses.data()[recombination]); //We may be referencing a not allocated position, so we use data() + StateHypothesis& sh(newStateHypotheses.data()[recombination]); // We may be referencing a not allocated position, so we use data() if (recombination < currentTreeFirstNewStateHypothesis || recombination >= newStateHypotheses.size() || sh.state != hyp.state) { recombination = newStateHypotheses.size(); addNewStateHypothesis(hyp); } else { - //Update existing hypothesis + // Update existing hypothesis if (sh.score >= hyp.score) { sh.score = hyp.score; sh.trace = hyp.trace; @@ -1245,7 +1244,7 @@ void SearchSpace::expandStateSlow(const Search::StateHypothesis& hyp) { if (forwardScore < Core::Type::max) { std::pair successors = net.structure.batchSuccessorsSimple(state.successors); if (successors.first != -1) { - //Fast iteration + // Fast iteration for (StateId successor = successors.first; successor != successors.second; ++successor) { if (expandForward) activateOrUpdateStateHypothesisTransition(hyp, forwardScore, successor); // Already covered by expandState? 
@@ -1254,7 +1253,7 @@ void SearchSpace::expandStateSlow(const Search::StateHypothesis& hyp) { { // Second order expansion (successors of successor). std::pair skipSuccessors = net.structure.batchSuccessorsSimple(net.structure.state(successor).successors); if (skipSuccessors.first != -1) { - //Fast iteration + // Fast iteration for (StateId skipSuccessor = skipSuccessors.first; skipSuccessor != skipSuccessors.second; ++skipSuccessor) activateOrUpdateStateHypothesisTransition(hyp, skipScore, skipSuccessor); } @@ -1346,7 +1345,7 @@ inline void SearchSpace::expandState(const Search::StateHypothesis& hyp) { activateOrUpdateStateHypothesisTransition(hyp, skipScore, successor2); } else if (secondStart == 0) { - //The secondOrderEdgeSuccessorBatches_ structure can not hold the successors, so use slow expansion to expand the second-order followers + // The secondOrderEdgeSuccessorBatches_ structure can not hold the successors, so use slow expansion to expand the second-order followers expandStateSlow(hyp); } } @@ -1561,7 +1560,7 @@ void SearchSpace::applyLookaheadInInstanceInternal(Instance* _instance, Acoustic fail = !la->getScoreForLookAheadHashSparse(ids.first, lmScore); if (fail) { - //This state needs to transfer into the back-off network + // This state needs to transfer into the back-off network // Set the prospect to max, so this state will be pruned away from this network sh->prospect = F32_MAX; @@ -1654,7 +1653,7 @@ void SearchSpace::addAcousticScoresInternal(Instance const& instance, Pruning& p // Omit overhead of a virtual function-call for cached scores by calling the score function directly with qualification for (; sh != sh_end; ++sh) { if (sh->prospect == F32_MAX) - continue; //This state will be pruned + continue; // This state will be pruned const HMMState& state = network().structure.state(sh->state); Mm::MixtureIndex mix = state.stateDesc.acousticModel; @@ -1672,7 +1671,7 @@ void SearchSpace::addAcousticScoresInternal(Instance const& instance, Pruning& p 
else { for (; sh != sh_end; ++sh) { if (sh->prospect == F32_MAX) - continue; //This state will be pruned + continue; // This state will be pruned const HMMState& state = network().structure.state(sh->state); Mm::MixtureIndex mix = state.stateDesc.acousticModel; @@ -1905,7 +1904,6 @@ void SearchSpace::pruneStates(Pruning& pruning) { activeInstances.resize(instOut); } - void SearchSpace::updateSsaLm() { if (!ssaLm_) { return; @@ -3011,7 +3009,7 @@ void SearchSpace::doStateStatisticsBeforePruning() { for (InstanceList::reverse_iterator it = activeInstances.rbegin(); it != activeInstances.rend(); ++it) { if (backOffLm) { - //Do statistics over the count of states in back-off instances + // Do statistics over the count of states in back-off instances Instance& mt = dynamic_cast(**it); if (mt.lookahead.get()) @@ -3102,7 +3100,7 @@ void SearchSpace::doStateStatistics() { for (InstanceList::reverse_iterator it = activeInstances.rbegin(); it != activeInstances.rend(); ++it) { if (backOffLm) { - //Do statistics over the count of states in back-off instances + // Do statistics over the count of states in back-off instances Instance& mt = dynamic_cast(**it); Lm::History h = mt.lookaheadHistory; @@ -3243,8 +3241,8 @@ void SearchSpace::recombineWordEndsInternal(bool shallCreateLattice) { for (in = out = wordEndHypotheses.begin(); in != wordEndHypotheses.end(); ++in) { auto key = std::make_pair(recombinationLm_->reducedHistory(in->recombinationHistory, - reducedContextWordRecombinationLimit_), - in->transitState); + reducedContextWordRecombinationLimit_), + in->transitState); ReducedContextRecombinationMap::iterator i = wordEndHypothesisMap.find(key); if (i != wordEndHypothesisMap.end()) { WordEndHypothesis& a(*in); @@ -3600,7 +3598,7 @@ Instance* SearchSpace::activateOrUpdateTree(const Core::Ref& trace, Score score) { /// TODO (Nolden): getLastSyntacticToken is inefficient for long sequences. A simple rule would be better: Stay in same instance, or follow most recent pron. 
InstanceKey key(recombinationHistory, conditionPredecessorWord_ ? getLastSyntacticToken(trace) : Bliss::LemmaPronunciation::invalidId); - Instance* instance = instanceForKey(true, key, lookaheadHistory, scoreHistory); + Instance* instance = instanceForKey(true, key, lookaheadHistory, scoreHistory); if (!instance) return 0; @@ -3619,7 +3617,7 @@ void SearchSpace::processOneWordEnd(Instance const& at, StateHypothesis const& h PersistentStateTree::Exit const* we = &network().exits[exit]; TraceItem const& item = trace_manager_.traceItem(hyp.trace); - //We can do a more efficient word end handling if there is only one item in the trace, which is the standard case + // We can do a more efficient word end handling if there is only one item in the trace, which is the standard case verify_(item.scoreHistory.isValid()); EarlyWordEndHypothesis weh(hyp.trace, @@ -3697,10 +3695,10 @@ void SearchSpace::findWordEndsInternal() { Score exitPenalty = (*transitionModel(state.stateDesc))[Am::StateTransitionModel::exit]; if (earlyWordEndPruning && hyp.score + exitPenalty + earlyWordEndPruningAnticipatedLmScore_ > bestWordEndPruning) { - continue; //Apply early word-end pruning (If the best score can not be reached, do not even try) + continue; // Apply early word-end pruning (If the best score can not be reached, do not even try) } - ///With pushing, ca. 80% of all label-lists are single-labels, so optimize for this case + /// With pushing, ca. 
80% of all label-lists are single-labels, so optimize for this case if (exit >= 0) { // There is 1 label processOneWordEnd(*inst, hyp, exit, exitPenalty, relativePruning, bestWordEndPruning); @@ -3726,7 +3724,7 @@ void SearchSpace::findWordEndsInternal() { } void SearchSpace::findWordEnds() { - //std::cerr << "best hyp: " << bestScore() << std::endl; + // std::cerr << "best hyp: " << bestScore() << std::endl; if (earlyWordEndPruning_) { if (onTheFlyRescoring_) { findWordEndsInternal(); diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.hh b/src/Search/AdvancedTreeSearch/SearchSpace.hh index b22279485..4c7428803 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.hh +++ b/src/Search/AdvancedTreeSearch/SearchSpace.hh @@ -132,7 +132,6 @@ private: Bliss::LexiconRef lexicon_; }; - class SearchSpace : public Core::Component { public: /// Statistics: @@ -338,7 +337,9 @@ public: void rescale(Score offset, bool ignoreWordEnds = false); // Returns the info from trace manager about whether it needs cleanup - inline bool needCleanup() const { return trace_manager_.needCleanup();} + inline bool needCleanup() const { + return trace_manager_.needCleanup(); + } // Needs to be called once in a while, but not every timeframe, // deletes all traces that did not survive in stateHypotheses and rootStateHypotheses of activeTrees @@ -370,18 +371,18 @@ public: void setLookAhead(const std::deque>&); Search::SearchAlgorithm::RecognitionContext setContext(Search::SearchAlgorithm::RecognitionContext); - ///Returns the best prospect, eg. the score of the best state hypothesis including the look-ahead score + /// Returns the best prospect, eg. 
the score of the best state hypothesis including the look-ahead score Score bestProspect() const; ///@warning: Expensive, without caching StateHypothesesList::const_iterator bestProspectStateHypothesis() const; - ///Returns the best score (the look-ahead score is not included) + /// Returns the best score (the look-ahead score is not included) ///@warning: Expensive, but with caching Score bestScore() const; ///@warning: Expensive, without caching StateHypothesesList::const_iterator bestScoreStateHypothesis() const; Score quantileStateScore(Score min, Score max, u32 nHyp) const; - ///Returns the lowest word end score (without look-ahead) - ///Always valid after findWordEnds was called + /// Returns the lowest word end score (without look-ahead) + /// Always valid after findWordEnds was called Score minimumWordEndScore() const; Score quantileWordEndScore(Score min, Score max, u32 nHyp) const; diff --git a/src/Search/AdvancedTreeSearch/TreeBuilder.cc b/src/Search/AdvancedTreeSearch/TreeBuilder.cc index c42e53a09..2a941e32d 100644 --- a/src/Search/AdvancedTreeSearch/TreeBuilder.cc +++ b/src/Search/AdvancedTreeSearch/TreeBuilder.cc @@ -201,7 +201,8 @@ void MinimizedTreeBuilder::buildBody() { initialPhonemes_.insert(initial); if (isContextDependent(initial)) { coarticulatedInitial += 1; - } else { + } + else { uncoarticulatedInitial += 1; } } @@ -209,7 +210,8 @@ void MinimizedTreeBuilder::buildBody() { if (!finalPhonemes_.count(fin)) { if (isContextDependent(fin)) { coarticulatedFinal += 1; - } else { + } + else { uncoarticulatedFinal += 1; } finalPhonemes_.insert(fin); @@ -265,7 +267,8 @@ void MinimizedTreeBuilder::buildBody() { u32 exit; if (!isContextDependent(phones[pronLength - 1]) && useRootForCiExits) { exit = addExit(tail.second, Bliss::Phoneme::term, Bliss::Phoneme::term, 0, lemmaPron->id()); // Use the non-coarticulated root node - } else { + } + else { exit = addExit(tail.second, phones[pronLength - 1], *initialIt, 0, lemmaPron->id()); } if (pronLength == 1) 
{ @@ -273,7 +276,8 @@ void MinimizedTreeBuilder::buildBody() { } } } - } else { + } + else { // Minimize the remaining phoneme, insert corresponding word-ends. for (Bliss::Pronunciation::LemmaIterator lemmaPron = lemmaProns.first; lemmaPron != lemmaProns.second; ++lemmaPron) { if (pronLength == 1) { @@ -285,7 +289,8 @@ void MinimizedTreeBuilder::buildBody() { exit.pronunciation = lemmaPron->id(); addSuccessor(createRoot(*finalIt, phones[0], 0), ID_FROM_LABEL(createExit(exit))); } - } else { + } + else { u32 exit = addExit(currentState.second, phones[pronLength - 2], phones[pronLength - 1], -1, lemmaPron->id()); if (pronLength == 2) { initialPhoneSuffix_[RootKey(phones[0], phones[1], 1)].insert(ID_FROM_LABEL(exit)); @@ -558,7 +563,8 @@ u32 MinimizedTreeBuilder::createExit(PersistentStateTree::Exit exit) { ExitHash::iterator exitHashIt = exitHash_.find(exit); if (exitHashIt != exitHash_.end()) { return exitHashIt->second; - } else { + } + else { // Exit does not exist yet, add it network_.exits.push_back(exit); u32 exitIndex = network_.exits.size() - 1; @@ -602,7 +608,8 @@ void MinimizedTreeBuilder::hmmFromAllophone(HMMSequence& ret, std::swap(left, right); if (boundary == Am::Allophone::isFinalPhone) { boundary = Am::Allophone::isInitialPhone; - } else if (boundary == Am::Allophone::isInitialPhone) { + } + else if (boundary == Am::Allophone::isInitialPhone) { boundary = Am::Allophone::isFinalPhone; } } @@ -664,13 +671,15 @@ std::pair MinimizedT u8 boundary = 0; if (phoneIndex != 0) { left = phones[phoneIndex - 1]; - } else { + } + else { boundary |= Am::Allophone::isInitialPhone; } if (phoneIndex != phones.size() - 1) { right = phones[phoneIndex + 1]; - } else { + } + else { boundary |= Am::Allophone::isFinalPhone; } @@ -806,7 +815,8 @@ std::vector MinimizedTreeBuilder::minimize(bool fo for (StateId node = 1; node < network_.structure.stateCount(); ++node) { determinizeMap[node] = node; } - } else { + } + else { // Determinize states: Join successor states with the 
same state-desc while (!active.empty()) { StateId state = active.front(); @@ -911,7 +921,8 @@ std::vector MinimizedTreeBuilder::minimize(bool fo for (StateId state = 1; state < oldNodeCount; ++state) { if (minimizeMap[state] == state) { minimizeExits(state, minimizeExitsMap); - } else { + } + else { network_.structure.clearOutputEdges(state); } } @@ -936,7 +947,8 @@ std::vector MinimizedTreeBuilder::minimize(bool fo if (orig == network_.rootState || network_.coarticulatedRootStates.count(orig)) { network_.rootTransitDescriptions.insert(*it); } - } else { + } + else { StateId mapped = minimizeMap[it->first]; verify(mapped); verify(network_.coarticulatedRootStates.count(mapped)); @@ -956,7 +968,8 @@ std::vector MinimizedTreeBuilder::minimize(bool fo for (StateId state = 1; state < determinizeMap.size(); ++state) { if (determinizeMap[state]) { determinizeMap[state] = minimizeMap[determinizeMap[state]]; - } else { + } + else { determinizeMap[state] = minimizeMap[state]; } } @@ -970,7 +983,8 @@ std::vector MinimizedTreeBuilder::minimize(bool fo *it = cleanupResult.nodeMap[*it]; kept += 1; verify(*it); - } else { + } + else { lost += 1; *it = 0; } @@ -979,7 +993,7 @@ std::vector MinimizedTreeBuilder::minimize(bool fo log() << "transformed states: " << kept << " lost: " << lost; // verify( allowLost || !lost ); - // update necessary hashs w.r.t. minimizeMap + // update necessary hashes w.r.t. 
minimizeMap predecessors_.swap(oldPredecessors); updateHashFromMap(minimizeMap, minimizeExitsMap); @@ -1008,7 +1022,8 @@ void MinimizedTreeBuilder::minimizeState(StateId state, std::vector& mi if (minimizeMap[*target] == Core::Type::max) { // std::cout << "detected recursion while minimization on " << *target << std::endl; successors.insert(*target); - }else { + } + else { successors.insert(minimizeMap[*target]); } } @@ -1019,7 +1034,8 @@ void MinimizedTreeBuilder::minimizeState(StateId state, std::vector& mi std::unordered_map::iterator it = predecessors_.find(pred); if (it != predecessors_.end()) { minimizeMap[state] = it->second; - } else { + } + else { minimizeMap[state] = state; predecessors_.insert(std::make_pair(pred, state)); for (std::set::iterator succIt = successors.begin(); succIt != successors.end(); ++succIt) @@ -1057,7 +1073,8 @@ void MinimizedTreeBuilder::minimizeExits(StateId state, const std::vector& ExitMap::iterator i = range.first; if (++i == range.second) { network_.structure.addOutputToNode(state, range.first->second); - } else { + } + else { // Join std::set newRootSuccessors; std::set left, right; @@ -1098,15 +1115,16 @@ void MinimizedTreeBuilder::mapSet(std::set& set, const std::vector::iterator it = oldSet.begin(); it != oldSet.end(); ++it) { if (*it >= minimizeMap.size()) { set.insert(*it); - } else if (!minimizeMap[*it]) { + } + else if (!minimizeMap[*it]) { verify(!force); - } else { + } + else { set.insert(minimizeMap[*it]); } } } - // update hash structures according to minimizeMap (invalid ones are removed) // should be ok for any number of minimize iterations void MinimizedTreeBuilder::updateHashFromMap(const std::vector& map, const std::vector& exitMap) { @@ -1170,10 +1188,12 @@ inline void MinimizedTreeBuilder::mapSuccessors(const std::set& success u32 eIdx = LABEL_FROM_ID(*sIt); if (exitMap.empty() || eIdx >= exitMap.size()) { tmpSet.insert(*sIt); - } else { + } + else { tmpSet.insert(ID_FROM_LABEL(exitMap[eIdx])); } - } else if 
(*sIt < map.size() && map[*sIt]) { + } + else if (*sIt < map.size() && map[*sIt]) { tmpSet.insert(map[*sIt]); } } diff --git a/src/Search/AdvancedTreeSearch/TreeBuilder.hh b/src/Search/AdvancedTreeSearch/TreeBuilder.hh index c99f9db66..3ddac517f 100644 --- a/src/Search/AdvancedTreeSearch/TreeBuilder.hh +++ b/src/Search/AdvancedTreeSearch/TreeBuilder.hh @@ -61,7 +61,7 @@ public: static const Core::ParameterBool paramAllowCrossWordSkips; static const Core::ParameterBool paramRepeatSilence; static const Core::ParameterInt paramMinimizeIterations; - typedef u32 StateId; + typedef u32 StateId; MinimizedTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); virtual ~MinimizedTreeBuilder() = default; From 7cc2e2fd42126a51cc195ed7accd394c25d9ff33 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 12 Dec 2024 14:30:49 +0100 Subject: [PATCH 03/24] Fix "historyLenght" name --- src/Search/AdvancedTreeSearch/SearchSpace.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.cc b/src/Search/AdvancedTreeSearch/SearchSpace.cc index 4cbfac68e..712eec397 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.cc +++ b/src/Search/AdvancedTreeSearch/SearchSpace.cc @@ -3108,7 +3108,7 @@ void SearchSpace::doStateStatistics() { int len = 0; if (h.isValid()) - len = backOffLm->historyLenght(h); + len = backOffLm->historyLength(h); if (mt.lookahead.get()) statesInTreesWithLookAhead += mt.states.size(); @@ -3240,10 +3240,10 @@ void SearchSpace::recombineWordEndsInternal(bool shallCreateLattice) { ReducedContextRecombinationMap wordEndHypothesisMap; for (in = out = wordEndHypotheses.begin(); in != wordEndHypotheses.end(); ++in) { - auto key = std::make_pair(recombinationLm_->reducedHistory(in->recombinationHistory, - reducedContextWordRecombinationLimit_), - in->transitState); - 
ReducedContextRecombinationMap::iterator i = wordEndHypothesisMap.find(key); + auto key = std::make_pair(recombinationLm_->reducedHistory(in->recombinationHistory, reducedContextWordRecombinationLimit_), + in->transitState); + + ReducedContextRecombinationMap::iterator i = wordEndHypothesisMap.find(key); if (i != wordEndHypothesisMap.end()) { WordEndHypothesis& a(*in); WordEndHypothesis& b(*(i->second)); @@ -3752,7 +3752,7 @@ Instance* SearchSpace::getBackOffInstance(Instance* instance) { Lm::History useHistory = instance->lookaheadHistory; - int length = lm->historyLenght(useHistory); + int length = lm->historyLength(useHistory); if (length == 0) return 0; @@ -3760,7 +3760,7 @@ Instance* SearchSpace::getBackOffInstance(Instance* instance) { // Create a back-off network for history-length length-1 Lm::History reduced = lm->reducedHistory(useHistory, length - 1); - verify(lm->historyLenght(reduced) == length - 1); + verify(lm->historyLength(reduced) == length - 1); verify(reduced.isValid()); From 8bda8250182df3347c054c7e651b7ff7190e9c2d Mon Sep 17 00:00:00 2001 From: Larissa Date: Tue, 7 Jan 2025 14:38:45 +0100 Subject: [PATCH 04/24] replace vector trees_ by one tree_ and refactor related code --- .../AdvancedTreeSearch/PersistentStateTree.cc | 29 +-- .../AdvancedTreeSearch/PersistentStateTree.hh | 3 - src/Search/AdvancedTreeSearch/TreeBuilder.cc | 5 +- .../AdvancedTreeSearch/TreeStructure.cc | 188 ++++++++---------- .../AdvancedTreeSearch/TreeStructure.hh | 63 ++---- src/Search/AdvancedTreeSearch/TreeWalker.hh | 1 - 6 files changed, 111 insertions(+), 178 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc b/src/Search/AdvancedTreeSearch/PersistentStateTree.cc index 7a0dae10a..1dd113b1c 100644 --- a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc +++ b/src/Search/AdvancedTreeSearch/PersistentStateTree.cc @@ -38,7 +38,6 @@ namespace Search { struct ConvertTree { const Search::StateTree* tree; HMMStateNetwork& subtrees; - TreeIndex 
masterTreeIndex; StateId rootSubTree; StateId ciRootNode; std::map exits; // Maps exits to label-indices @todo Make this a hash_map @@ -50,12 +49,12 @@ struct ConvertTree { std::map> rootTransitDescriptions; ConvertTree(const Search::StateTree* _tree, HMMStateNetwork& _subtrees) - : tree(_tree), subtrees(_subtrees), masterTreeIndex(subtrees.allocateTree()), lostNodeIndices(0) { + : tree(_tree), subtrees(_subtrees), lostNodeIndices(0) { } void convert() { for (u32 a = 0; a < tree->states_.size(); ++a) { - StateId created = subtrees.allocateTreeNode(masterTreeIndex); + StateId created = subtrees.allocateTreeNode(); verify(a + 1 == created); } @@ -167,8 +166,7 @@ struct ConvertTree { }; PersistentStateTree::PersistentStateTree(Core::Configuration config, Core::Ref acousticModel, Bliss::LexiconRef lexicon, TreeBuilderFactory treeBuilderFactory) - : masterTree(0), - rootState(0), + : rootState(0), ciRootState(0), archive_(paramCacheArchive(Core::Configuration(config, "search-network"))), acousticModel_(acousticModel), @@ -229,7 +227,6 @@ void PersistentStateTree::build() { convert.convert(); exits = convert.exitVector; - masterTree = convert.masterTreeIndex; rootState = convert.rootSubTree; ciRootState = convert.ciRootNode; coarticulatedRootStates = convert.coarticulatedRootNodes; @@ -293,7 +290,8 @@ MappedArchiveWriter& operator<<(MappedArchiveWriter& writer, const std::maplog() << "Loading persistent network format version " << formatVersion; u32 dependenciesChecksum = 0; - - in >> masterTree >> dependenciesChecksum; + u32 dummyIndex; // only needed for backwards compatibility, has no further effect + in >> dummyIndex >> dependenciesChecksum; if (dependenciesChecksum != dependencies_.getChecksum()) { Core::Application::us()->log() << "dependencies of the network image don't equal the required dependencies with checksum " << dependenciesChecksum; @@ -370,15 +368,13 @@ void PersistentStateTree::removeOutputs() { rootsList.push_back(*it); } - 
HMMStateNetwork::CleanupResult cleanupResult = structure.cleanup(rootsList, masterTree, false, true); + HMMStateNetwork::CleanupResult cleanupResult = structure.cleanup(rootsList, false, true); for (std::unordered_map::const_iterator it = cleanupResult.nodeMap.begin(); it != cleanupResult.nodeMap.end(); ++it) { if (it->first != it->second) std::cout << "mapped " << it->first << " to " << it->second << std::endl; verify(it->first == it->second); } - for (std::unordered_map::const_iterator it = cleanupResult.treeMap.begin(); it != cleanupResult.treeMap.end(); ++it) - verify(it->first == it->second); } HMMStateNetwork::CleanupResult PersistentStateTree::cleanup(bool cleanupExits) { @@ -433,7 +429,7 @@ HMMStateNetwork::CleanupResult PersistentStateTree::cleanup(bool cleanupExits) { ///@todo Go through the search tree, and collect the required coarticulated root nodes - HMMStateNetwork::CleanupResult cleanupResult = structure.cleanup(rootsList, masterTree); + HMMStateNetwork::CleanupResult cleanupResult = structure.cleanup(rootsList); Core::HashMap::const_iterator targetNodeIt; if (rootState) { @@ -482,13 +478,6 @@ HMMStateNetwork::CleanupResult PersistentStateTree::cleanup(bool cleanupExits) { exitIt->transitState = (*targetNodeIt).second; } - Core::HashMap::const_iterator targetTreeIt; - - targetTreeIt = cleanupResult.treeMap.find(masterTree); - verify(targetTreeIt != cleanupResult.treeMap.end()); - - masterTree = (*targetTreeIt).second; - // pushedWordEndNodes = cleanupResult.mapNodes( pushedWordEndNodes ); uncoarticulatedWordEndStates = cleanupResult.mapNodes(uncoarticulatedWordEndStates); // uncoarticulatedPushedWordEndNodes = cleanupResult.mapNodes( uncoarticulatedPushedWordEndNodes ); diff --git a/src/Search/AdvancedTreeSearch/PersistentStateTree.hh b/src/Search/AdvancedTreeSearch/PersistentStateTree.hh index e83dd1073..2e1f5e47e 100644 --- a/src/Search/AdvancedTreeSearch/PersistentStateTree.hh +++ b/src/Search/AdvancedTreeSearch/PersistentStateTree.hh @@ -94,9 
+94,6 @@ public: /** ----- state tree data: ------ */ - // Identity of the main search network - TreeIndex masterTree; - // Root node StateId rootState; diff --git a/src/Search/AdvancedTreeSearch/TreeBuilder.cc b/src/Search/AdvancedTreeSearch/TreeBuilder.cc index 2a941e32d..ba89b2036 100644 --- a/src/Search/AdvancedTreeSearch/TreeBuilder.cc +++ b/src/Search/AdvancedTreeSearch/TreeBuilder.cc @@ -115,7 +115,6 @@ MinimizedTreeBuilder::MinimizedTreeBuilder(Core::Configuration config, const Bli if (initialize) { verify(!network_.rootState); - network_.masterTree = network_.structure.allocateTree(); // Non-coarticulated root state network_.ciRootState = network_.rootState = createRoot(Bliss::Phoneme::term, Bliss::Phoneme::term, 0); @@ -554,7 +553,7 @@ AbstractTreeBuilder::StateId MinimizedTreeBuilder::createRoot(Bliss::Phoneme::Id } AbstractTreeBuilder::StateId MinimizedTreeBuilder::createState(StateTree::StateDesc desc) { - StateId ret = network_.structure.allocateTreeNode(network_.masterTree); + StateId ret = network_.structure.allocateTreeNode(); network_.structure.state(ret).stateDesc = desc; return ret; } @@ -835,7 +834,7 @@ std::vector MinimizedTreeBuilder::minimize(bool fo SuccessorHash::iterator it = items.first; if (++it != items.second) { - StateId newNode = network_.structure.allocateTreeNode(network_.masterTree); + StateId newNode = network_.structure.allocateTreeNode(); if (newNode >= determinizeMap.size()) { determinizeMap.resize(newNode + 1, 0); } diff --git a/src/Search/AdvancedTreeSearch/TreeStructure.cc b/src/Search/AdvancedTreeSearch/TreeStructure.cc index e3043d342..f23ed9bdd 100644 --- a/src/Search/AdvancedTreeSearch/TreeStructure.cc +++ b/src/Search/AdvancedTreeSearch/TreeStructure.cc @@ -19,7 +19,7 @@ namespace Search { HMMStateNetwork::HMMStateNetwork() : subTreeManager_(subTreeListBatches_, states_), edgeTargetManager_(edgeTargetBatches_, states_) { //The zero index is reserved as "invalid", so push one dummy item into all arrays - 
trees_.push_back(Tree()); + tree_ = Tree(); states_.push_back(HMMState()); subTreeListBatches_.push_back(0); edgeTargetBatches_.push_back(0); @@ -27,14 +27,8 @@ HMMStateNetwork::HMMStateNetwork() verify(sizeof(HMMState) % sizeof(u32) == 0); } -TreeIndex HMMStateNetwork::allocateTree() { - trees_.push_back(Tree()); - return TreeIndex(trees_.size() - 1); -} - -StateId HMMStateNetwork::allocateTreeNode(TreeIndex parent) { - verify(parent != EmptyTreeIndex); - StateId ret = subTreeManager_.appendOne(tree(parent).nodes, HMMState()); +StateId HMMStateNetwork::allocateTreeNode() { + StateId ret = subTreeManager_.appendOne(tree_.nodes, HMMState()); return ret; } @@ -113,17 +107,14 @@ void HMMStateNetwork::addOutputToEdge(SuccessorBatchId& list, u32 outputIndex) { addTargetToEdge(list, ID_FROM_LABEL(outputIndex)); } -u32 HMMStateNetwork::treeCount() const { - return trees_.size(); -} - u32 HMMStateNetwork::stateCount() const { return states_.size(); } bool HMMStateNetwork::write(Core::MappedArchiveWriter writer) { u32 version = DiskFormatVersionV2; - writer << version << subTreeListBatches_ << states_ << edgeTargetLists_ << edgeTargetBatches_ << trees_; + std::vector trees = {Tree(), tree_}; // need to write a vector of trees because HMMStateNetwork::read() expectes one + writer << version << subTreeListBatches_ << states_ << edgeTargetLists_ << edgeTargetBatches_ << trees; return writer.good(); } @@ -133,7 +124,9 @@ bool HMMStateNetwork::read(Core::MappedArchiveReader reader) { if (version == DiskFormatVersionV1) { std::vector states; - reader >> subTreeListBatches_ >> states >> edgeTargetLists_ >> edgeTargetBatches_ >> trees_; + std::vector trees; // need to read into a vector for backward compatibility + reader >> subTreeListBatches_ >> states >> edgeTargetLists_ >> edgeTargetBatches_ >> trees; + tree_ = trees[1]; if (!reader.good()) { return false; @@ -149,7 +142,9 @@ bool HMMStateNetwork::read(Core::MappedArchiveReader reader) { [](HMMStateV1 s){ return s.toHMMState(); 
}); } else if (version == DiskFormatVersionV2) { - reader >> subTreeListBatches_ >> states_ >> edgeTargetLists_ >> edgeTargetBatches_ >> trees_; + std::vector trees; // need to read into a vector for backward compatibility + reader >> subTreeListBatches_ >> states_ >> edgeTargetLists_ >> edgeTargetBatches_ >> trees; + tree_ = trees[1]; } else { return false; @@ -173,7 +168,7 @@ u32 HMMStateNetwork::countReachableEnds(std::vector& counts, StateId node) return counts[node]; } -HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::list startNodes, Search::TreeIndex masterTree, bool clearDeadEnds, bool onlyBatches) { +HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::list startNodes, bool clearDeadEnds, bool onlyBatches) { if (clearDeadEnds && !onlyBatches) { u32 deadEndNodes = 0; std::vector reachableEnds(states_.size(), Core::Type::max); @@ -212,23 +207,20 @@ HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::listlog() << "calculating reachable nodes and trees"; + ///Mark all reachable nodes + Core::Application::us()->log() << "calculating reachable nodes"; for (std::list::const_iterator it = startNodes.begin(); it != startNodes.end(); ++it) counter.visit(*it, 1); } { - std::vector newTrees; + Tree newTree; std::vector newSubTreeListBatches; std::vector newNodes; std::vector newEdgeTargetLists; @@ -236,7 +228,6 @@ HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::list newSubTreeManager(newSubTreeListBatches, newNodes); - newTrees.push_back(Tree()); newEdgeTargetLists.push_back(0); newSubTreeListBatches.push_back(0); newNodes.push_back(HMMState()); @@ -244,106 +235,91 @@ HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::list orderBehind(states_.size(), 0); std::vector follow(states_.size(), 0); - std::vector> orderedPerTree; - orderedPerTree.push_back(std::vector()); - - for (u32 tree = 1; tree < trees_.size(); ++tree) { - /// @todo Build a topology and order the nodes in a stable way based on that - //Build the order 
so that the second-order batches are continuous - Tools::BatchManager::Iterator it = subTreeManager_.getIterator(trees_[tree].nodes); - for (; it; ++it) { - StateId node = *it; - if (counter.visited.count(node) == 0) - continue; - // 2nd order - u32 previousSkipTarget = 0; + /// @todo Build a topology and order the nodes in a stable way based on that + //Build the order so that the second-order batches are continuous + Tools::BatchManager::Iterator it = subTreeManager_.getIterator(tree_.nodes); + for (; it; ++it) { + StateId node = *it; + if (counter.visited.count(node) == 0) + continue; + // 2nd order + u32 previousSkipTarget = 0; - // first order - u32 previousTarget = 0; + // first order + u32 previousTarget = 0; - for (HMMStateNetwork::SuccessorIterator targetIt = successors(node); targetIt; ++targetIt) { - if (targetIt.isLabel()) - break; + for (HMMStateNetwork::SuccessorIterator targetIt = successors(node); targetIt; ++targetIt) { + if (targetIt.isLabel()) + break; - StateId target = *targetIt; + StateId target = *targetIt; - if (!orderBehind[target]) - orderBehind[target] = previousTarget; - if (!follow[previousTarget]) - follow[previousTarget] = target; - previousTarget = target; - verify(target < states_.size()); + if (!orderBehind[target]) + orderBehind[target] = previousTarget; + if (!follow[previousTarget]) + follow[previousTarget] = target; + previousTarget = target; + verify(target < states_.size()); - for (HMMStateNetwork::SuccessorIterator skipTargetIt = successors(target); skipTargetIt; ++skipTargetIt) { - if (skipTargetIt.isLabel()) - break; - StateId skipTarget = *skipTargetIt; - orderBehind[skipTarget] = previousSkipTarget; - follow[previousSkipTarget] = skipTarget; - previousSkipTarget = skipTarget; - } + for (HMMStateNetwork::SuccessorIterator skipTargetIt = successors(target); skipTargetIt; ++skipTargetIt) { + if (skipTargetIt.isLabel()) + break; + StateId skipTarget = *skipTargetIt; + orderBehind[skipTarget] = previousSkipTarget; + 
follow[previousSkipTarget] = skipTarget; + previousSkipTarget = skipTarget; } } - std::vector ordered; - std::unordered_set had; + } + std::vector ordered; + std::unordered_set had; + + { + Tools::BatchManager::Iterator it = subTreeManager_.getIterator(tree_.nodes); + for (; it; ++it) { + StateId current = *it; - { - Tools::BatchManager::Iterator it = subTreeManager_.getIterator(trees_[tree].nodes); - for (; it; ++it) { - StateId current = *it; + if (counter.visited.count(current) == 0) + continue; - if (counter.visited.count(current) == 0) - continue; + if (onlyBatches) { + ordered.push_back(current); + } + else { + while (current) { + if (had.count(current)) + break; - if (onlyBatches) { ordered.push_back(current); - } - else { - while (current) { - if (had.count(current)) - break; - - ordered.push_back(current); - had.insert(current); - current = follow[current]; - } + had.insert(current); + current = follow[current]; } } } - - orderedPerTree.push_back(ordered); } - for (u32 tree = 1; tree < trees_.size(); ++tree) { - if (counter.visitedTrees.find(tree) != counter.visitedTrees.end()) { - //Build the order so that the second-order batches are continuous - const std::vector& ordered(orderedPerTree[tree]); - - //Transfer network into new list - ret.treeMap.insert(std::make_pair(tree, newTrees.size())); - newTrees.push_back(trees_[tree]); - newTrees.back().nodes = InvalidBatchId; - - //Transfer nodes into new batches - - for (u32 idx = 0; idx < ordered.size(); ++idx) { - StateId node = ordered[idx]; - if (counter.visited.find(node) != counter.visited.end()) { - verify(newNodes.size() > 0); - StateId newNode = newSubTreeManager.appendOne(newTrees.back().nodes, states_[node]); - ret.nodeMap.insert(std::make_pair(node, newNode)); - } + //Build the order so that the second-order batches are continuous + { + //Transfer network into new list + newTree = tree_; + newTree.nodes = InvalidBatchId; + + //Transfer nodes into new batches + for (u32 idx = 0; idx < ordered.size(); 
++idx) { + StateId node = ordered[idx]; + if (counter.visited.find(node) != counter.visited.end()) { + verify(newNodes.size() > 0); + StateId newNode = newSubTreeManager.appendOne(newTree.nodes, states_[node]); + ret.nodeMap.insert(std::make_pair(node, newNode)); } - //No empty trees - verify(newTrees.back().nodes != InvalidBatchId); - } - else { - //This network is removed } + //No empty trees + verify(newTree.nodes != InvalidBatchId); } + Core::Application::us()->log() << "count of new nodes: " << newNodes.size(); verify(newNodes.size()); - trees_.swap(newTrees); + tree_ = newTree; states_.swap(newNodes); subTreeListBatches_.swap(newSubTreeListBatches); edgeTargetLists_.swap(newEdgeTargetLists); @@ -358,7 +334,6 @@ HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::list::iterator it; //Map the edge-targets of single batches SuccessorBatchId newBatch = InvalidBatchId; @@ -410,7 +385,7 @@ HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::listlog() << "re-calculating reachable nodes and trees"; + Core::Application::us()->log() << "re-calculating reachable nodes"; for (std::list::const_iterator it = startNodes.begin(); it != startNodes.end(); ++it) { std::unordered_map::const_iterator mapIt = ret.nodeMap.find(*it); verify(mapIt != ret.nodeMap.end()); @@ -418,13 +393,10 @@ HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::listlog() << "previous reachable nodes: " << counter.visited.size() << " new reachable nodes: " << counter2.visited.size() << " new total nodes: " << states_.size(); - Core::Application::us()->log() << "previous trees: " << counter.visitedTrees.size() << " new trees: " << counter2.visitedTrees.size(); Core::Application::us()->log() << "previous exits: " << counter.visitedFinalOutputs << " new exits: " << counter2.visitedFinalOutputs; verify(counter2.visited.size() == counter.visited.size()); - verify(counter2.visitedTrees.size() == counter.visitedTrees.size()); } return ret; diff --git 
a/src/Search/AdvancedTreeSearch/TreeStructure.hh b/src/Search/AdvancedTreeSearch/TreeStructure.hh index f470dde6e..d4d697dc7 100644 --- a/src/Search/AdvancedTreeSearch/TreeStructure.hh +++ b/src/Search/AdvancedTreeSearch/TreeStructure.hh @@ -42,9 +42,6 @@ enum { LabelMask = 1 << 27 }; -///Global index of a tree or subtree -typedef u32 TreeIndex; - ///Index of a state or label (see IS_LABEL, ID_FROM_LABEL, and LABEL_FROM_ID) typedef u32 StateId; @@ -147,12 +144,6 @@ struct Tree { class HMMStateNetwork { public: - enum { - //Index of the empty network - //The empty network has no node, and exactly one label that is to be activated directly - EmptyTreeIndex = 0 - }; - enum { DiskFormatVersionV1 = 1, DiskFormatVersionV2 = 2, @@ -237,12 +228,6 @@ public: ///****** STATE MANAGEMENT ****************************************************************************************** - ///Do not keep pointers/references to the returned tree, the address may change - inline Tree& tree(TreeIndex index) { - verify_(index > 0 && index < trees_.size()); - return trees_[index]; - } - ///Do not keep pointers to the returned state, the address may change when the network is manipulated inline_ HMMState& state(StateId state) { verify_(state > 0 && state < (int)states_.size()); @@ -255,43 +240,35 @@ public: return states_[state]; } - ///Allocates a new tree - TreeIndex allocateTree(); - - ///Allocates a new subtree, and adds it into the subtree list of the given parent. - ///As many subtrees for the same parent should be allocated in a row as possible, so batch-merging can happen - ///Returns a fully valid subtree(With initialized edge-list) - StateId allocateTreeNode(TreeIndex parent); + ///Allocates a new subtree, and adds it into the subtree list. 
+ ///Returns a fully valid subtree (with initialized edge-list) + StateId allocateTreeNode(); ///Returns the count of nodes contained by the tree - inline u32 getNodeCount(TreeIndex parent); + inline u32 getNodeCount(); - ///Returns the @p number th node contained in the given parent tree - inline StateId getTreeNode(TreeIndex parent, u32 number); + ///Returns the @p number th node contained in the tree + inline StateId getTreeNode(u32 number); - ///Returns the number of nodes contained by the given parent tree - inline u32 getNodeNumber(TreeIndex parent, StateId node); + ///Returns the number of the @p node in the tree + inline u32 getNodeNumber(StateId node); ///Much faster version of getNodeNumber, that only works when the structure has been cleaned - inline u32 getNodeNumberCleanStructure(TreeIndex parent, StateId node) { - return node - subTreeListBatches_[tree(parent).nodes]; + inline u32 getNodeNumberCleanStructure(StateId node) { + return node - subTreeListBatches_[tree_.nodes]; } - ///Returns the total number of trees, which is the maximum upper bound for a valid TreeIndex - u32 treeCount() const; - ///Returns the total number of nodes, which is the maximum upper bound for a valid TreeNodeIndex u32 stateCount() const; struct CleanupResult { Core::HashMap nodeMap; - Core::HashMap treeMap; std::set mapNodes(const std::set& nodes) const; }; - ///Completely removes all trees and nodes that are not reachable from the given start-nodes, compressing the structure - CleanupResult cleanup(std::list startNodes, Search::TreeIndex masterTree, bool clearDeadEnds = true, bool onlyBatches = false); + ///Completely removes all nodes that are not reachable from the given start-nodes, compressing the structure + CleanupResult cleanup(std::list startNodes, bool clearDeadEnds = true, bool onlyBatches = false); ///****** EDGE MANAGEMENT ******************************************************************************************* @@ -317,7 +294,7 @@ public: void 
clearOutputEdges(StateId node); u32 getChecksum() const { - return states_.size() + edgeTargetBatches_.size() + edgeTargetLists_.size() + trees_.size() + subTreeListBatches_.size(); + return states_.size() + edgeTargetBatches_.size() + edgeTargetLists_.size() + subTreeListBatches_.size() + 2; // + 2 is needed for backwards compatibility } ///The change is applied when apply() is called @@ -455,7 +432,7 @@ private: std::vector edgeTargetBatches_; Tools::BatchManager edgeTargetManager_; - std::vector trees_; + Tree tree_; }; inline HMMStateNetwork::SuccessorIterator HMMStateNetwork::batchSuccessors(Search::SuccessorBatchId list) const { @@ -504,18 +481,18 @@ std::pair HMMStateNetwork::batchNodeRange(SuccessorBatchId batch) cons return std::make_pair((int)edgeTargetBatches_[batch], (int)edgeTargetBatches_[batch + 2]); } -u32 HMMStateNetwork::getNodeCount(Search::TreeIndex parent) { - return subTreeManager_.getIterator(tree(parent).nodes).countToEnd(); +u32 HMMStateNetwork::getNodeCount() { + return subTreeManager_.getIterator(tree_.nodes).countToEnd(); } -StateId HMMStateNetwork::getTreeNode(Search::TreeIndex parent, u32 number) { - Tools::BatchManager::Iterator it = subTreeManager_.getIterator(tree(parent).nodes); +StateId HMMStateNetwork::getTreeNode(u32 number) { + Tools::BatchManager::Iterator it = subTreeManager_.getIterator(tree_.nodes); it += number; return *it; } -u32 HMMStateNetwork::getNodeNumber(TreeIndex parent, StateId node) { - Tools::BatchManager::Iterator it = subTreeManager_.getIterator(tree(parent).nodes); +u32 HMMStateNetwork::getNodeNumber(StateId node) { + Tools::BatchManager::Iterator it = subTreeManager_.getIterator(tree_.nodes); return it.countUntil(node); } } // namespace Search diff --git a/src/Search/AdvancedTreeSearch/TreeWalker.hh b/src/Search/AdvancedTreeSearch/TreeWalker.hh index ad742b0b9..1a5f62450 100644 --- a/src/Search/AdvancedTreeSearch/TreeWalker.hh +++ b/src/Search/AdvancedTreeSearch/TreeWalker.hh @@ -84,7 +84,6 @@ struct 
CountSizeTreeWalkerBackend { } std::unordered_set visited; - std::unordered_set visitedTrees; u32 totalVisited; bool stopAtVisited; u32 visitedFinalOutputs; From ac690d0b4eaab68590b21b47896468bacf5f20b8 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Fri, 10 Jan 2025 10:39:33 +0100 Subject: [PATCH 05/24] Apply formatting --- .../AdvancedTreeSearch/PersistentStateTree.cc | 6 +- .../AdvancedTreeSearch/TreeStructure.cc | 51 ++++++++-------- .../AdvancedTreeSearch/TreeStructure.hh | 60 +++++++++---------- src/Search/AdvancedTreeSearch/TreeWalker.hh | 10 ++-- 4 files changed, 63 insertions(+), 64 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc b/src/Search/AdvancedTreeSearch/PersistentStateTree.cc index 1dd113b1c..0d5aa4a94 100644 --- a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc +++ b/src/Search/AdvancedTreeSearch/PersistentStateTree.cc @@ -290,7 +290,7 @@ MappedArchiveWriter& operator<<(MappedArchiveWriter& writer, const std::maplog() << "Loading persistent network format version " << formatVersion; u32 dependenciesChecksum = 0; - u32 dummyIndex; // only needed for backwards compatibility, has no further effect - in >> dummyIndex >> dependenciesChecksum; + u32 dummyIndex; // only needed for backwards compatibility, has no further effect + in >> dummyIndex >> dependenciesChecksum; if (dependenciesChecksum != dependencies_.getChecksum()) { Core::Application::us()->log() << "dependencies of the network image don't equal the required dependencies with checksum " << dependenciesChecksum; diff --git a/src/Search/AdvancedTreeSearch/TreeStructure.cc b/src/Search/AdvancedTreeSearch/TreeStructure.cc index f23ed9bdd..50542ca16 100644 --- a/src/Search/AdvancedTreeSearch/TreeStructure.cc +++ b/src/Search/AdvancedTreeSearch/TreeStructure.cc @@ -18,7 +18,7 @@ namespace Search { HMMStateNetwork::HMMStateNetwork() : subTreeManager_(subTreeListBatches_, states_), edgeTargetManager_(edgeTargetBatches_, states_) { - //The zero index is reserved 
as "invalid", so push one dummy item into all arrays + // The zero index is reserved as "invalid", so push one dummy item into all arrays tree_ = Tree(); states_.push_back(HMMState()); subTreeListBatches_.push_back(0); @@ -97,7 +97,7 @@ void HMMStateNetwork::addNodeToEdge(SuccessorBatchId& list, StateId target) { void HMMStateNetwork::addTargetToEdge(SuccessorBatchId& batch, u32 target) { verify(target >= 0); - //Special case for only one item + // Special case for only one item edgeTargetManager_.appendToBatch(batch, target, target + 1); verify(batch != InvalidBatchId); @@ -112,8 +112,8 @@ u32 HMMStateNetwork::stateCount() const { } bool HMMStateNetwork::write(Core::MappedArchiveWriter writer) { - u32 version = DiskFormatVersionV2; - std::vector trees = {Tree(), tree_}; // need to write a vector of trees because HMMStateNetwork::read() expectes one + u32 version = DiskFormatVersionV2; + std::vector trees = {Tree(), tree_}; // need to write a vector of trees because HMMStateNetwork::read() expectes one writer << version << subTreeListBatches_ << states_ << edgeTargetLists_ << edgeTargetBatches_ << trees; return writer.good(); } @@ -124,7 +124,7 @@ bool HMMStateNetwork::read(Core::MappedArchiveReader reader) { if (version == DiskFormatVersionV1) { std::vector states; - std::vector trees; // need to read into a vector for backward compatibility + std::vector trees; // need to read into a vector for backward compatibility reader >> subTreeListBatches_ >> states >> edgeTargetLists_ >> edgeTargetBatches_ >> trees; tree_ = trees[1]; @@ -139,10 +139,10 @@ bool HMMStateNetwork::read(Core::MappedArchiveReader reader) { states.begin(), states.end(), std::back_inserter(states_), - [](HMMStateV1 s){ return s.toHMMState(); }); + [](HMMStateV1 s) { return s.toHMMState(); }); } else if (version == DiskFormatVersionV2) { - std::vector trees; // need to read into a vector for backward compatibility + std::vector trees; // need to read into a vector for backward compatibility reader 
>> subTreeListBatches_ >> states_ >> edgeTargetLists_ >> edgeTargetBatches_ >> trees; tree_ = trees[1]; } @@ -196,7 +196,7 @@ HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::listlog() << "cleared " << cleared << " dead-end nodes"; @@ -213,7 +213,7 @@ HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::listlog() << "calculating reachable nodes"; for (std::list::const_iterator it = startNodes.begin(); it != startNodes.end(); ++it) counter.visit(*it, 1); @@ -225,7 +225,7 @@ HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::list newNodes; std::vector newEdgeTargetLists; - //Must be created before adding the initial items, since it clears the lists + // Must be created before adding the initial items, since it clears the lists Tools::BatchManager newSubTreeManager(newSubTreeListBatches, newNodes); newEdgeTargetLists.push_back(0); @@ -236,7 +236,7 @@ HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::list follow(states_.size(), 0); /// @todo Build a topology and order the nodes in a stable way based on that - //Build the order so that the second-order batches are continuous + // Build the order so that the second-order batches are continuous Tools::BatchManager::Iterator it = subTreeManager_.getIterator(tree_.nodes); for (; it; ++it) { StateId node = *it; @@ -298,13 +298,13 @@ HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::list newEdgeTargetBatches; Tools::BatchManager newEdgeTargetManager(newEdgeTargetBatches, states_); - //Must be created before adding the initial items, since it clears the lists + // Must be created before adding the initial items, since it clears the lists newEdgeTargetBatches.push_back(0); - //Map the direct node members + // Map the direct node members for (u32 node = 1; node < states_.size(); ++node) { - - //Map the edge-targets of single batches + // Map the edge-targets of single batches SuccessorBatchId newBatch = InvalidBatchId; SuccessorBatchId oldBatch = states_[node].successors; for 
(Tools::BatchManager::Iterator it = edgeTargetManager_.getIterator(oldBatch); it; ++it) { if (IS_LABEL(*it)) { - //It's an label encoded as negative number + // It's an label encoded as negative number newEdgeTargetManager.appendToBatch(newBatch, *it, *it + 1); } else { - //It's a node + // It's a node verify(counter.visited.find(*it) != counter.visited.end()); std::unordered_map::const_iterator nodeMapIt = ret.nodeMap.find(*it); verify(nodeMapIt != ret.nodeMap.end()); @@ -363,11 +362,11 @@ HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::list::Iterator it = edgeTargetManager_.getIterator(oldBatch); it; ++it) { if (IS_LABEL(*it)) { - //It's an label encoded as negative number + // It's an label encoded as negative number newEdgeTargetManager.appendToBatch(edgeTargetLists_[batchNum], *it, *it + 1); } else { - //It's a node + // It's a node verify(counter.visited.find(*it) != counter.visited.end()); std::unordered_map::const_iterator nodeMapIt = ret.nodeMap.find(*it); verify(nodeMapIt != ret.nodeMap.end()); @@ -384,7 +383,7 @@ HMMStateNetwork::CleanupResult HMMStateNetwork::cleanup(std::listlog() << "re-calculating reachable nodes"; for (std::list::const_iterator it = startNodes.begin(); it != startNodes.end(); ++it) { std::unordered_map::const_iterator mapIt = ret.nodeMap.find(*it); diff --git a/src/Search/AdvancedTreeSearch/TreeStructure.hh b/src/Search/AdvancedTreeSearch/TreeStructure.hh index d4d697dc7..bf030e20f 100644 --- a/src/Search/AdvancedTreeSearch/TreeStructure.hh +++ b/src/Search/AdvancedTreeSearch/TreeStructure.hh @@ -42,7 +42,7 @@ enum { LabelMask = 1 << 27 }; -///Index of a state or label (see IS_LABEL, ID_FROM_LABEL, and LABEL_FROM_ID) +/// Index of a state or label (see IS_LABEL, ID_FROM_LABEL, and LABEL_FROM_ID) typedef u32 StateId; ///@todo Maybe this should be zero! 
@@ -82,10 +82,10 @@ struct HMMState { return *this; } - ///This must be initialized explicitly after creating the state + /// This must be initialized explicitly after creating the state StateTree::StateDesc stateDesc; - ///Batch of successor states, managed through a batch-manager in TreeStructure + /// Batch of successor states, managed through a batch-manager in TreeStructure SuccessorBatchId successors; /// Returns true if the label-edges batch represents only one single successor, which can be handled more efficiently @@ -122,9 +122,9 @@ struct HMMStateV1 { StateTree::StateDesc desc; HMMState result; - result.stateDesc.acousticModel = this->stateDesc.acousticModel; + result.stateDesc.acousticModel = this->stateDesc.acousticModel; result.stateDesc.transitionModelIndex = this->stateDesc.transitionModelIndex; - result.successors = this->successors; + result.successors = this->successors; if (result.stateDesc.acousticModel == StateDescV1::invalidAcousticModel) { result.stateDesc.acousticModel = StateTree::invalidAcousticModel; @@ -138,7 +138,7 @@ struct Tree { Tree() : nodes(InvalidBatchId) { } - ///All nodes contained by this tree. Managed as a batch by TreeStructure + /// All nodes contained by this tree. 
Managed as a batch by TreeStructure SubTreeListId nodes; }; @@ -228,54 +228,54 @@ public: ///****** STATE MANAGEMENT ****************************************************************************************** - ///Do not keep pointers to the returned state, the address may change when the network is manipulated + /// Do not keep pointers to the returned state, the address may change when the network is manipulated inline_ HMMState& state(StateId state) { verify_(state > 0 && state < (int)states_.size()); return states_[state]; } - ///Do not keep pointers to the returned state, the address may change when the network is manipulated + /// Do not keep pointers to the returned state, the address may change when the network is manipulated inline const HMMState& state(StateId state) const { verify_(state > 0 && state < (int)states_.size()); return states_[state]; } - ///Allocates a new subtree, and adds it into the subtree list. - ///Returns a fully valid subtree (with initialized edge-list) + /// Allocates a new subtree, and adds it into the subtree list. 
+ /// Returns a fully valid subtree (with initialized edge-list) StateId allocateTreeNode(); - ///Returns the count of nodes contained by the tree + /// Returns the count of nodes contained by the tree inline u32 getNodeCount(); - ///Returns the @p number th node contained in the tree + /// Returns the @p number th node contained in the tree inline StateId getTreeNode(u32 number); - ///Returns the number of the @p node in the tree + /// Returns the number of the @p node in the tree inline u32 getNodeNumber(StateId node); - ///Much faster version of getNodeNumber, that only works when the structure has been cleaned + /// Much faster version of getNodeNumber, that only works when the structure has been cleaned inline u32 getNodeNumberCleanStructure(StateId node) { return node - subTreeListBatches_[tree_.nodes]; } - ///Returns the total number of nodes, which is the maximum upper bound for a valid TreeNodeIndex + /// Returns the total number of nodes, which is the maximum upper bound for a valid TreeNodeIndex u32 stateCount() const; struct CleanupResult { - Core::HashMap nodeMap; + Core::HashMap nodeMap; std::set mapNodes(const std::set& nodes) const; }; - ///Completely removes all nodes that are not reachable from the given start-nodes, compressing the structure + /// Completely removes all nodes that are not reachable from the given start-nodes, compressing the structure CleanupResult cleanup(std::list startNodes, bool clearDeadEnds = true, bool onlyBatches = false); ///****** EDGE MANAGEMENT ******************************************************************************************* - ///Adds the given target to the list of targets for the given edge. The referenced id will be changed. + /// Adds the given target to the list of targets for the given edge. The referenced id will be changed. void addNodeToEdge(SuccessorBatchId& list, StateId target); - ///Adds the given target to the list of targets for the given edge. The referenced id will be changed. 
+ /// Adds the given target to the list of targets for the given edge. The referenced id will be changed. void addOutputToEdge(SuccessorBatchId& list, u32 outputIndex); void addTargetToNode(StateId node, StateId target) { @@ -290,14 +290,14 @@ public: void removeOutputFromNode(StateId node, u32 outputIndex); - ///Clears all connections behind the given node. The memory will be lost unless a cleanup is done afterwards. + /// Clears all connections behind the given node. The memory will be lost unless a cleanup is done afterwards. void clearOutputEdges(StateId node); u32 getChecksum() const { - return states_.size() + edgeTargetBatches_.size() + edgeTargetLists_.size() + subTreeListBatches_.size() + 2; // + 2 is needed for backwards compatibility + return states_.size() + edgeTargetBatches_.size() + edgeTargetLists_.size() + subTreeListBatches_.size() + 2; // + 2 is needed for backwards compatibility } - ///The change is applied when apply() is called + /// The change is applied when apply() is called class ChangePlan { public: void addSuccessor(StateId state) { @@ -396,15 +396,15 @@ public: return ret; } - ///Returns -1, -1 if this simple version does not work. Then "edgeTargets" has to be used. + /// Returns -1, -1 if this simple version does not work. Then "edgeTargets" has to be used. template inline std::pair batchSuccessorsSimple(SuccessorBatchId list) const; - ///Does not work with single-batches! Those must be checked before. + /// Does not work with single-batches! Those must be checked before. inline std::pair batchSuccessorsSimpleIgnoreLabels(SuccessorBatchId list) const; - ///Reads out the node-range associated to the given batch. Does not verify whether - ///the batch is a single-batch or has successor-batches, this has to be checked beforehand. + /// Reads out the node-range associated to the given batch. Does not verify whether + /// the batch is a single-batch or has successor-batches, this has to be checked beforehand. 
inline std::pair batchNodeRange(SuccessorBatchId batch) const; ///********************************************************************************************************************* @@ -420,15 +420,15 @@ public: private: void addTargetToEdge(SuccessorBatchId& batch, u32 target); u32 countReachableEnds(std::vector& counts, StateId node) const; - //This manager manages lists of sub-trees, one subtree-list for each network + // This manager manages lists of sub-trees, one subtree-list for each network std::vector subTreeListBatches_; std::vector states_; Tools::BatchManager subTreeManager_; - //Contains one SuccessorBatchId for each label of a subtree, as a linear list + // Contains one SuccessorBatchId for each label of a subtree, as a linear list std::vector edgeTargetLists_; - //This manager groups together edge successors, for edges coming from a common source(usually an label of a subtree) + // This manager groups together edge successors, for edges coming from a common source(usually an label of a subtree) std::vector edgeTargetBatches_; Tools::BatchManager edgeTargetManager_; @@ -470,7 +470,7 @@ inline std::pair HMMStateNetwork::batchSuccessorsSimple(SuccessorBatch const Search::StateId start = edgeTargetBatches_[batch]; if (not considerOutputs && IS_LABEL(start)) return std::pair(0, 0); - //Everything ok, this is a simple continous batch without a follower-batch + // Everything ok, this is a simple continous batch without a follower-batch return std::pair(start, edgeTargetBatches_[batch + 2]); } diff --git a/src/Search/AdvancedTreeSearch/TreeWalker.hh b/src/Search/AdvancedTreeSearch/TreeWalker.hh index 1a5f62450..a373d7c67 100644 --- a/src/Search/AdvancedTreeSearch/TreeWalker.hh +++ b/src/Search/AdvancedTreeSearch/TreeWalker.hh @@ -31,7 +31,7 @@ public: : tree(_tree) { } - ///Visits the nodes and all its followers, in the correct order + /// Visits the nodes and all its followers, in the correct order void visit(StateId node, Token token) { bool hadToken = 
token; @@ -83,10 +83,10 @@ struct CountSizeTreeWalkerBackend { ++visitedFinalOutputs; } - std::unordered_set visited; - u32 totalVisited; - bool stopAtVisited; - u32 visitedFinalOutputs; + std::unordered_set visited; + u32 totalVisited; + bool stopAtVisited; + u32 visitedFinalOutputs; }; typedef Search::SubTreeWalker CountSizeTreeWalker; From b947fe6d6778d6c8786fd89ba7cf69da14c42beb Mon Sep 17 00:00:00 2001 From: Larissa Date: Wed, 22 Jan 2025 12:29:50 +0100 Subject: [PATCH 06/24] Update comments and tree_ initialization --- .../AdvancedTreeSearch/PersistentStateTree.cc | 11 +++++++++-- src/Search/AdvancedTreeSearch/TreeStructure.cc | 15 ++++++++++----- src/Search/AdvancedTreeSearch/TreeStructure.hh | 10 +++++++--- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc b/src/Search/AdvancedTreeSearch/PersistentStateTree.cc index 0d5aa4a94..ed88b82f4 100644 --- a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc +++ b/src/Search/AdvancedTreeSearch/PersistentStateTree.cc @@ -290,7 +290,10 @@ MappedArchiveWriter& operator<<(MappedArchiveWriter& writer, const std::maplog() << "Loading persistent network format version " << formatVersion; u32 dependenciesChecksum = 0; - u32 dummyIndex; // only needed for backwards compatibility, has no further effect + + // In the previous version, a master tree was used and the index was saved in the cache. + // For backward compatibility, read this into a dummy index. + // This index is not used further and has no effect on functionality. 
+ u32 dummyIndex; in >> dummyIndex >> dependenciesChecksum; if (dependenciesChecksum != dependencies_.getChecksum()) { diff --git a/src/Search/AdvancedTreeSearch/TreeStructure.cc b/src/Search/AdvancedTreeSearch/TreeStructure.cc index 50542ca16..993692ec6 100644 --- a/src/Search/AdvancedTreeSearch/TreeStructure.cc +++ b/src/Search/AdvancedTreeSearch/TreeStructure.cc @@ -17,9 +17,8 @@ namespace Search { HMMStateNetwork::HMMStateNetwork() - : subTreeManager_(subTreeListBatches_, states_), edgeTargetManager_(edgeTargetBatches_, states_) { + : subTreeManager_(subTreeListBatches_, states_), edgeTargetManager_(edgeTargetBatches_, states_), tree_() { // The zero index is reserved as "invalid", so push one dummy item into all arrays - tree_ = Tree(); states_.push_back(HMMState()); subTreeListBatches_.push_back(0); edgeTargetBatches_.push_back(0); @@ -113,7 +112,9 @@ u32 HMMStateNetwork::stateCount() const { bool HMMStateNetwork::write(Core::MappedArchiveWriter writer) { u32 version = DiskFormatVersionV2; - std::vector trees = {Tree(), tree_}; // need to write a vector of trees because HMMStateNetwork::read() expectes one + // The previous version used a vector of trees, where index 0 represented an invalid tree and index 1 contained the actual master tree. + // Therefore, to mainain backward compatibility, the tree needs to be written into a similar vector structure, which is then saved to the cache. + std::vector trees = {Tree(), tree_}; writer << version << subTreeListBatches_ << states_ << edgeTargetLists_ << edgeTargetBatches_ << trees; return writer.good(); } @@ -123,8 +124,10 @@ bool HMMStateNetwork::read(Core::MappedArchiveReader reader) { reader >> version; if (version == DiskFormatVersionV1) { + // The previous version used a vector of trees, where index 0 represented an invalid tree and index 1 contained the actual master tree. 
+ // Therefore, to maintain backward compatibility, the cache needs to be read into a vector again and the tree can then be retrieved from index 1. + std::vector trees; std::vector states; - std::vector trees; // need to read into a vector for backward compatibility reader >> subTreeListBatches_ >> states >> edgeTargetLists_ >> edgeTargetBatches_ >> trees; tree_ = trees[1]; @@ -142,7 +145,9 @@ bool HMMStateNetwork::read(Core::MappedArchiveReader reader) { [](HMMStateV1 s) { return s.toHMMState(); }); } else if (version == DiskFormatVersionV2) { - std::vector trees; // need to read into a vector for backward compatibility + // The previous version used a vector of trees, where index 0 represented an invalid tree and index 1 contained the actual master tree. + // Therefore, to maintain backward compatibility, the cache needs to be read into a vector again and the tree can then be retrieved from index 1. + std::vector trees; reader >> subTreeListBatches_ >> states_ >> edgeTargetLists_ >> edgeTargetBatches_ >> trees; tree_ = trees[1]; } diff --git a/src/Search/AdvancedTreeSearch/TreeStructure.hh b/src/Search/AdvancedTreeSearch/TreeStructure.hh index bf030e20f..2e2f658c6 100644 --- a/src/Search/AdvancedTreeSearch/TreeStructure.hh +++ b/src/Search/AdvancedTreeSearch/TreeStructure.hh @@ -240,8 +240,10 @@ public: return states_[state]; } - /// Allocates a new subtree, and adds it into the subtree list. - /// Returns a fully valid subtree (with initialized edge-list) + /// Creates a new node (HMMState) and appends it to the tree's list of nodes. + /// Specifically, new subtree is allocated and added to the subtree list by the subTreeManager_. + /// Returns the StateId of the newly created node (a fully valid subtree with initialized edge-list). + /// Note that the stateDesc must be set separately after this operation. 
StateId allocateTreeNode(); /// Returns the count of nodes contained by the tree @@ -294,7 +296,9 @@ public: void clearOutputEdges(StateId node); u32 getChecksum() const { - return states_.size() + edgeTargetBatches_.size() + edgeTargetLists_.size() + subTreeListBatches_.size() + 2; // + 2 is needed for backwards compatibility + // In the previous version, a vector of trees was always used with a fixed length of two (index 0 = invalid tree, index 1 = the actual master tree). + // This fixed length was included in the checksum calculation, therefore a hardcoded +2 is applied here to ensure backward compatibility + return states_.size() + edgeTargetBatches_.size() + edgeTargetLists_.size() + subTreeListBatches_.size() + 2; } /// The change is applied when apply() is called From f1691a32af9a3f430e9ecb00d5fabe0e4e47aa1c Mon Sep 17 00:00:00 2001 From: Larissa Date: Tue, 28 Jan 2025 17:13:29 +0100 Subject: [PATCH 07/24] Add CtcTreeBuilder --- .../AdvancedTreeSearch/PersistentStateTree.cc | 17 +- .../AdvancedTreeSearch/PersistentStateTree.hh | 17 +- src/Search/AdvancedTreeSearch/SearchSpace.cc | 8 +- src/Search/AdvancedTreeSearch/TreeBuilder.cc | 239 ++++++++++++++++-- src/Search/AdvancedTreeSearch/TreeBuilder.hh | 42 ++- src/Search/StateTree.hh | 4 + 6 files changed, 289 insertions(+), 38 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc b/src/Search/AdvancedTreeSearch/PersistentStateTree.cc index ed88b82f4..05ddd42e7 100644 --- a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc +++ b/src/Search/AdvancedTreeSearch/PersistentStateTree.cc @@ -422,6 +422,9 @@ HMMStateNetwork::CleanupResult PersistentStateTree::cleanup(bool cleanupExits) { std::set roots = coarticulatedRootStates; roots.insert(rootState); roots.insert(ciRootState); + for (StateId s : otherRootStates) { + roots.insert(s); + } // Also collect all transition-successors as coarticulated roots for (StateId node = 1; node < structure.stateCount(); ++node) { @@ -506,10 +509,13 @@ 
void PersistentStateTree::dumpDotGraph(std::string file, const std::vector& << "edge [fontname=\"Helvetica\"]" << std::endl; for (StateId node = 1; node < structure.stateCount(); ++node) { - int depth = 0; - if (!nodeDepths.empty()) - depth = nodeDepths[node]; - os << Core::form("n%d [label=\"%d\\nd=%d\\nm=%d\\nt=%d", node, node, depth, structure.state(node).stateDesc.acousticModel, structure.state(node).stateDesc.transitionModelIndex); + if (!nodeDepths.empty()) { + int depth = nodeDepths[node]; + os << Core::form("n%d [label=\"%d\\nd=%d\\nm=%d\\nt=%d", node, node, depth, structure.state(node).stateDesc.acousticModel, structure.state(node).stateDesc.transitionModelIndex); + } + else { + os << Core::form("n%d [label=\"%d\\nm=%d\\nt=%d", node, node, structure.state(node).stateDesc.acousticModel, structure.state(node).stateDesc.transitionModelIndex); + } for (HMMStateNetwork::SuccessorIterator target = structure.successors(node); target; ++target) if (target.isLabel() && exits[target.label()].pronunciation != Bliss::LemmaPronunciation::invalidId) @@ -518,7 +524,8 @@ void PersistentStateTree::dumpDotGraph(std::string file, const std::vector& << Core::form(" tr=%d", exits[target.label()].transitState); os << "\""; - if (node == rootState || node == ciRootState || uncoarticulatedWordEndStates.count(node)) + bool is_other_root = std::find(otherRootStates.begin(), otherRootStates.end(), node) != otherRootStates.end(); + if (node == rootState || node == ciRootState || uncoarticulatedWordEndStates.count(node) || is_other_root) os << ",shape=box"; os << "]" << std::endl; diff --git a/src/Search/AdvancedTreeSearch/PersistentStateTree.hh b/src/Search/AdvancedTreeSearch/PersistentStateTree.hh index 2e1f5e47e..5c028f962 100644 --- a/src/Search/AdvancedTreeSearch/PersistentStateTree.hh +++ b/src/Search/AdvancedTreeSearch/PersistentStateTree.hh @@ -37,7 +37,7 @@ namespace Search { class HMMStateNetwork; class StateTree; -class PersistentStateTree { +class PersistentStateTree : 
public Core::ReferenceCounted { public: using TreeBuilderFactory = std::function(Core::Configuration, const Bliss::Lexicon&, const Am::AcousticModel&, PersistentStateTree&, bool)>; @@ -68,19 +68,23 @@ public: u32 getChecksum() const; /// Dump the search network as a dot graph into the given file - void dumpDotGraph(std::string file, const std::vector& nodeDepths); + void dumpDotGraph(std::string file, const std::vector& nodeDepths = {}); struct Exit { Bliss::LemmaPronunciation::Id pronunciation; StateId transitState; - bool operator==(const Exit& rhs) const { - return pronunciation == rhs.pronunciation && transitState == rhs.transitState; - } + + Exit() : pronunciation(Bliss::LemmaPronunciation::invalidId), transitState(invalidTreeNodeIndex) {} + struct Hash { u32 operator()(const Exit& exit) const { return MyStandardValueHash()(exit.pronunciation + MyStandardValueHash()(exit.transitState)); } }; + + bool operator==(const Exit& rhs) const { + return pronunciation == rhs.pronunciation && transitState == rhs.transitState; + } bool operator<(const Exit& rhs) const { return pronunciation < rhs.pronunciation || (pronunciation == rhs.pronunciation && transitState < rhs.transitState); } @@ -100,6 +104,9 @@ public: // Context-independent root node StateId ciRootState; + // Other root nodes (currently used for the wordBoundaryRoot in CtcTreeBuilder) + std::vector otherRootStates; + // The word-end exits std::vector exits; diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.cc b/src/Search/AdvancedTreeSearch/SearchSpace.cc index 712eec397..0374cb613 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.cc +++ b/src/Search/AdvancedTreeSearch/SearchSpace.cc @@ -499,9 +499,9 @@ void StaticSearchAutomaton::clearDepths() { int StaticSearchAutomaton::fillStateDepths(StateId state, int depth) { if (stateDepths[state] != Core::Type::min) { - if (stateDepths[state] != depth) { /// @todo Find out why this happens on some languages - std::cout << "conflicting state depths: " << 
stateDepths[state] << " vs " << depth << std::endl; - } + //if (stateDepths[state] != depth) { /// @todo Find out why this happens on some languages + // std::cout << "conflicting state depths: " << stateDepths[state] << " vs " << depth << std::endl; + //} if (depth > stateDepths[state]) { stateDepths[state] = Core::Type::min; // Re-fill successor depths } @@ -777,7 +777,7 @@ std::unique_ptr StaticSearchAutomaton::createTreeBuilder(Co return std::unique_ptr(new MinimizedTreeBuilder(config, *lexicon_, *acousticModel_, network, initialize)); } break; case TreeBuilderType::ctc: { - defect(); // TODO: add CTC implementation + return std::unique_ptr(new CtcTreeBuilder(config, *lexicon_, *acousticModel_, network, initialize)); } break; default: defect(); } diff --git a/src/Search/AdvancedTreeSearch/TreeBuilder.cc b/src/Search/AdvancedTreeSearch/TreeBuilder.cc index ba89b2036..b51cbf3f2 100644 --- a/src/Search/AdvancedTreeSearch/TreeBuilder.cc +++ b/src/Search/AdvancedTreeSearch/TreeBuilder.cc @@ -34,6 +34,26 @@ AbstractTreeBuilder::AbstractTreeBuilder(Core::Configuration config, network_(network) { } +StateId AbstractTreeBuilder::createState(StateTree::StateDesc desc) { + StateId ret = network_.structure.allocateTreeNode(); + network_.structure.state(ret).stateDesc = desc; + return ret; +} + +u32 AbstractTreeBuilder::createExit(PersistentStateTree::Exit exit) { + ExitHash::iterator exitHashIt = exitHash_.find(exit); + if (exitHashIt != exitHash_.end()) { + return exitHashIt->second; + } + else { + // Exit does not exist yet, add it + network_.exits.push_back(exit); + u32 exitIndex = network_.exits.size() - 1; + exitHash_.insert(std::make_pair(exit, exitIndex)); + return exitIndex; + } +} + // -------------------- MinimizedTreeBuilder -------------------- // TODO: Verify that pushed word-ends have the same transition penalty as the corresponding unpushed word-ends @@ -552,26 +572,6 @@ AbstractTreeBuilder::StateId MinimizedTreeBuilder::createRoot(Bliss::Phoneme::Id return 
ret; } -AbstractTreeBuilder::StateId MinimizedTreeBuilder::createState(StateTree::StateDesc desc) { - StateId ret = network_.structure.allocateTreeNode(); - network_.structure.state(ret).stateDesc = desc; - return ret; -} - -u32 MinimizedTreeBuilder::createExit(PersistentStateTree::Exit exit) { - ExitHash::iterator exitHashIt = exitHash_.find(exit); - if (exitHashIt != exitHash_.end()) { - return exitHashIt->second; - } - else { - // Exit does not exist yet, add it - network_.exits.push_back(exit); - u32 exitIndex = network_.exits.size() - 1; - exitHash_.insert(std::make_pair(exit, exitIndex)); - return exitIndex; - } -} - u32 MinimizedTreeBuilder::addExit(StateId predecessor, Bliss::Phoneme::Id leftPhoneme, Bliss::Phoneme::Id rightPhoneme, @@ -1197,3 +1197,202 @@ inline void MinimizedTreeBuilder::mapSuccessors(const std::set& success } } } + +// -------------------- CtcTreeBuilder -------------------- + +CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) + : AbstractTreeBuilder(config, lexicon, acousticModel, network) { + auto iters = lexicon.phonemeInventory()->phonemes(); + for (auto it = iters.first; it != iters.second; ++it) { + require(not(*it)->isContextDependent()); // Context dependent labels are not supported + } + + // Set the StateDesc for blank + blankAllophoneStateIndex_ = acousticModel_.blankAllophoneStateIndex(); + blankDesc_.acousticModel = acousticModel_.emissionIndex(blankAllophoneStateIndex_); + blankDesc_.transitionModelIndex = acousticModel_.stateTransitionIndex(blankAllophoneStateIndex_); + require_lt(blankDesc_.transitionModelIndex, Core::Type::max); + + if (initialize) { + verify(!network_.rootState); + network_.ciRootState = network_.rootState = createRoot(); + + // Create a special root for the word-boundary token if it exists in the lexicon + if (lexicon.specialLemma("word-boundary") != nullptr) { + 
wordBoundaryRoot_ = createRoot(); + network_.otherRootStates.push_back(wordBoundaryRoot_); + } + } +} + +std::unique_ptr CtcTreeBuilder::newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) { + return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network)); +} + +void CtcTreeBuilder::build() { + const Bliss::Lemma* wordBoundaryLemma = nullptr; + if (lexicon_.specialLemma("word-boundary") != nullptr) { + addWordBoundaryStates(); + wordBoundaryLemma = lexicon_.specialLemma("word-boundary"); + } + + auto blankLemma = lexicon_.specialLemma("blank"); + auto silenceLemma = lexicon_.specialLemma("silence"); + auto iters = lexicon_.lemmaPronunciations(); + + // Iterate over the lemmata and add them to the tree + for (auto it = iters.first; it != iters.second; ++it) { + if ((*it)->lemma() == wordBoundaryLemma) { + continue; + } + + StateId lastState = extendPronunciation(network_.rootState, (*it)->pronunciation()); + + if (wordBoundaryLemma != nullptr && (*it)->lemma() != blankLemma && (*it)->lemma() != silenceLemma) { + // If existing, the wordBoundaryRoot_ should be the transit state for all word ends except blank and silence + addExit(lastState, wordBoundaryRoot_, (*it)->id()); + } + else { + addExit(lastState, network_.rootState, (*it)->id()); + } + } +} + +StateId CtcTreeBuilder::createRoot() { + return createState(StateTree::StateDesc(Search::StateTree::invalidAcousticModel, Am::TransitionModel::entryM1)); +} + +u32 CtcTreeBuilder::addExit(StateId state, StateId transitState, Bliss::LemmaPronunciation::Id pron) { + PersistentStateTree::Exit exit; + exit.transitState = transitState; + exit.pronunciation = pron; + + u32 exitIndex = createExit(exit); + + // Check if the exit is already a successor + // This should only happen if the same lemma is contained multiple times in the lexicon + for (HMMStateNetwork::SuccessorIterator target = 
network_.structure.successors(state); target; ++target) { + if (target.isLabel() && target.label() == exitIndex) { + return exitIndex; + } + } + + // The exit is not part of the successors yet, add it + network_.structure.addOutputToNode(state, ID_FROM_LABEL(exitIndex)); + return exitIndex; +} + +StateId CtcTreeBuilder::extendState(StateId predecessor, StateTree::StateDesc desc) { + // Check if the successor already exists + for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) { + if (!target.isLabel() && network_.structure.state(*target).stateDesc == desc) { + return *target; + } + } + + // No matching successor found, extend + StateId ret = createState(desc); + network_.structure.addTargetToNode(predecessor, ret); + return ret; +} + +void CtcTreeBuilder::addTransition(StateId predecessor, StateId successor) { + bool found = false; + for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) { + if(!target.isLabel() && network_.structure.state(*target).stateDesc == network_.structure.state(successor).stateDesc) { + // The node is already a successor of the predecessor, so the transition already exists + found = true; + } + } + if (!found) { + // The transition does not exists yet, add it + network_.structure.addTargetToNode(predecessor, successor); + } +} + +StateId CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronunciation const* pron) { + StateId currentState = startState; + StateId prevNonBlankState = invalidTreeNodeIndex; + + for (u32 i = 0u; i < pron->length(); i++) { + Bliss::Phoneme::Id phoneme = (*pron)[i]; + + u32 boundary = 0u; + if (i == 0) { + boundary |= Am::Allophone::isInitialPhone; + } + if ((i + 1) == pron->length()) { + boundary |= Am::Allophone::isFinalPhone; + } + + Bliss::ContextPhonology::SemiContext history, future; + const Am::Allophone* allophone = 
acousticModel_.allophoneAlphabet()->allophone(Am::Allophone(Bliss::ContextPhonology::PhonemeInContext(phoneme, history, future), boundary)); + const Am::ClassicHmmTopology* hmmTopology = acousticModel_.hmmTopology(phoneme); + const bool allophone_is_blank = acousticModel_.allophoneStateAlphabet()->index(allophone, 0, false) == blankAllophoneStateIndex_; /* the allophone counts as blank if its allophone-state index matches blankAllophoneStateIndex_ */ + + for (u32 phoneState = 0; phoneState < hmmTopology->nPhoneStates(); ++phoneState) { + Am::AllophoneState alloState = acousticModel_.allophoneStateAlphabet()->allophoneState(allophone, phoneState); + StateTree::StateDesc desc; + desc.acousticModel = acousticModel_.emissionIndex(alloState); // Decision tree look-up for CART id. + + for (u32 subState = 0; subState < hmmTopology->nSubStates(); ++subState) { + desc.transitionModelIndex = acousticModel_.stateTransitionIndex(alloState, subState); + verify(desc.transitionModelIndex < Core::Type::max); + + // Add new (non-blank) state + currentState = extendState(currentState, desc); + // Add loop for this state + addTransition(currentState, currentState); + // Add transition from previous non-blank state to this state, allowing to skip the blank state in-between these two + if (prevNonBlankState != invalidTreeNodeIndex) { + addTransition(prevNonBlankState, currentState); + } + prevNonBlankState = currentState; /* NOTE(review): also set when the allophone itself is blank - confirm that is intended */ + + bool is_last_state_in_lemma = ((phoneState + 1) == hmmTopology->nPhoneStates()) and ((subState + 1) == hmmTopology->nSubStates()) and (boundary & Am::Allophone::isFinalPhone); /* last sub-state of the last HMM state of the word-final phoneme */ + if (not allophone_is_blank and not is_last_state_in_lemma) { /* no blank is appended after a blank allophone or after the final state of the word */ + // Add blank state after the newly created state + currentState = extendState(currentState, blankDesc_); + // Add loop for this blank state + addTransition(currentState, currentState); + } + } + } + } + + return currentState; /* last state added for the pronunciation, or startState if the pronunciation is empty */ +} + +void CtcTreeBuilder::addWordBoundaryStates() { /* Adds the "word-boundary" special lemma to the tree; build() only calls this after checking that the lexicon defines that lemma. */ + Bliss::Lemma const* wordBoundaryLemma = lexicon_.specialLemma("word-boundary"); + Bliss::LemmaPronunciation const* wordBoundaryPronLemma = nullptr; 
+ StateId wordBoundaryEnd = 0; + + // Add the word-boundary to the tree, starting from the wordBoundaryRoot_ + // If the word-boundary has several pronunciations, only the first one is considered + auto prons = wordBoundaryLemma->pronunciations(); + wordBoundaryPronLemma = prons.first; + require(wordBoundaryPronLemma != nullptr); /* checked before the pronunciation is dereferenced on the next line; previously this guard ran only after the dereference */ + wordBoundaryEnd = extendPronunciation(wordBoundaryRoot_, wordBoundaryPronLemma->pronunciation()); + + require(wordBoundaryEnd != 0); + + // The "normal" root is the transition state from the word-boundary token, such that a new word can be started afterwards + addExit(wordBoundaryEnd, network_.rootState, wordBoundaryPronLemma->id()); + + std::vector wordBoundaryLemmaStartStates; /* state successors of wordBoundaryRoot_, collected before the blank is added so the blank itself is not among them */ + for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(wordBoundaryRoot_); target; ++target) { + if (!target.isLabel()) { + wordBoundaryLemmaStartStates.push_back(*target); + } + } + // Add optional blank before the word-boundary lemma + StateId blankBefore = extendState(wordBoundaryRoot_, blankDesc_); + for (StateId wbs : wordBoundaryLemmaStartStates) { + network_.structure.addTargetToNode(blankBefore, wbs); + } + // Add loop for this blank state + addTransition(blankBefore, blankBefore); +} + + diff --git a/src/Search/AdvancedTreeSearch/TreeBuilder.hh b/src/Search/AdvancedTreeSearch/TreeBuilder.hh index 3ddac517f..706635520 100644 --- a/src/Search/AdvancedTreeSearch/TreeBuilder.hh +++ b/src/Search/AdvancedTreeSearch/TreeBuilder.hh @@ -46,9 +46,16 @@ public: virtual void build() = 0; protected: + typedef Core::HashMap ExitHash; + const Bliss::Lexicon& lexicon_; const Am::AcousticModel& acousticModel_; Search::PersistentStateTree& network_; + + ExitHash exitHash_; + + StateId createState(Search::StateTree::StateDesc desc); + u32 createExit(Search::PersistentStateTree::Exit exit); }; class MinimizedTreeBuilder : public AbstractTreeBuilder { @@ -162,7 +169,6 @@ protected: typedef std::set PhonemeIdSet; typedef Core::HashMap RootHash; typedef 
Core::HashMap SkipRootsHash; - typedef Core::HashMap ExitHash; typedef Core::HashMap, RootKey::Hash> CoarticulationJointHash; typedef Core::HashMap PredecessorsHash; @@ -186,7 +192,6 @@ protected: RootHash roots_; // Contains roots and joint-states SkipRootsHash skipRoots_; std::set skipRootSet_; - ExitHash exitHash_; CoarticulationJointHash initialPhoneSuffix_; CoarticulationJointHash initialFinalPhoneSuffix_; PredecessorsHash predecessors_; @@ -205,8 +210,6 @@ protected: StateId createSkipRoot(StateId baseRoot); StateId createRoot(Bliss::Phoneme::Id left, Bliss::Phoneme::Id right, int depth); - StateId createState(Search::StateTree::StateDesc desc); - u32 createExit(Search::PersistentStateTree::Exit exit); u32 addExit(StateId predecessor, Bliss::Phoneme::Id leftPhoneme, Bliss::Phoneme::Id rightPhoneme, @@ -244,4 +247,35 @@ protected: void mapSuccessors(const std::set&, std::set&, const std::vector&, const std::vector&); }; +class CtcTreeBuilder : public AbstractTreeBuilder { /* Tree builder for CTC-style networks: every added state gets a self-loop and blank states are inserted between consecutive states of a pronunciation (see extendPronunciation in TreeBuilder.cc). */ +public: + CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); + virtual ~CtcTreeBuilder() = default; + + virtual std::unique_ptr newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); + + // Build a new persistent state network. 
+ virtual void build(); + +protected: + StateId wordBoundaryRoot_; /* transit state used by build() for word ends when the lexicon has a word-boundary lemma */ + Search::StateTree::StateDesc blankDesc_; /* state description used for the inserted blank states */ + Am::AllophoneStateIndex blankAllophoneStateIndex_; /* allophone-state index that extendPronunciation compares against to detect a blank allophone */ + + // Create a node with invalid AM and TM indices which serves as a root + StateId createRoot(); + // @param state is the last state of the word with pronunciation ID @param pron, add an exit leading to the root node @param transitState + // The exit is appended to the state's successors + u32 addExit(StateId state, StateId transitState, Bliss::LemmaPronunciation::Id pron); + + // Check if the node with @param desc is already a successor of the @param predecessor and add it if not + StateId extendState(StateId predecessor, Search::StateTree::StateDesc desc); + // Starting in @param startState (usually the root), include the lemma with pronunciation @param pron in the tree + StateId extendPronunciation(StateId startState, Bliss::Pronunciation const* pron); + // Add a transition between two already existing states, used to insert loops and skip-transitions + void addTransition(StateId predecessor, StateId successor); + // If the lexicon contains a word-boundary token, it is added starting from the wordBoundaryRoot_ + void addWordBoundaryStates(); +}; + #endif diff --git a/src/Search/StateTree.hh b/src/Search/StateTree.hh index 8755a015a..311056eb0 100644 --- a/src/Search/StateTree.hh +++ b/src/Search/StateTree.hh @@ -137,6 +137,10 @@ public: : acousticModel(0), transitionModelIndex(0) { } + StateDesc(ModelIndex ami, TransitionModelIndex tmi) + : acousticModel(ami), transitionModelIndex(tmi) { + } + std::string toString() const { std::ostringstream target; target << (u32)acousticModel << "_" << (u32)transitionModelIndex; From eafef9225eb6c360910287c17bdedaa1c4f7e4e4 Mon Sep 17 00:00:00 2001 From: Larissa Date: Mon, 10 Feb 2025 17:05:23 +0100 Subject: [PATCH 08/24] Uncouple treebuilding from AdvancedTreeSearch --- Modules.make | 1 + .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + 
.../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + .../makefiles/Modules.make | 1 + .../AdvancedTreeSearch/AcousticLookAhead.hh | 3 +- .../AdvancedTreeSearch/AdvancedTreeSearch.cc | 3 +- .../LanguageModelLookahead.hh | 6 ++-- .../AdvancedTreeSearch/LinearPrediction.hh | 2 +- src/Search/AdvancedTreeSearch/Makefile | 5 +-- .../AdvancedTreeSearch/PathRecombination.hh | 5 +-- .../PathRecombinationApproximation.hh | 2 +- src/Search/AdvancedTreeSearch/PrefixFilter.hh | 2 +- .../SearchNetworkTransformation.hh | 4 +-- src/Search/AdvancedTreeSearch/SearchSpace.cc | 32 ++++++------------- src/Search/AdvancedTreeSearch/SearchSpace.hh | 16 +++------- .../AdvancedTreeSearch/SearchSpaceHelpers.hh | 2 +- src/Search/AdvancedTreeSearch/Trace.hh | 2 +- src/Search/LanguageModelLookahead.hh | 2 +- src/Search/Makefile | 7 ++-- src/Search/TreeBuilder/Makefile | 28 ++++++++++++++++ .../PersistentStateTree.cc | 8 +++-- .../PersistentStateTree.hh | 5 +-- src/Search/{ => TreeBuilder}/StateTree.cc | 0 src/Search/{ => TreeBuilder}/StateTree.hh | 0 src/Search/{ => TreeBuilder}/StateTreeIo.cc | 0 src/Search/{ => TreeBuilder}/StateTreeIo.hh | 0 .../TreeBuilder.cc | 2 +- .../TreeBuilder.hh | 28 ++++++++++++++-- .../TreeStructure.cc | 0 .../TreeStructure.hh | 7 ++-- .../TreeWalker.hh | 0 src/Search/Wfst/Makefile | 1 + src/Search/Wfst/StateTree.hh | 2 +- src/Search/WordConditionedTreeSearch.cc | 2 +- src/Search/check.cc | 2 +- src/Speech/Makefile | 1 + src/Test/Makefile | 1 + src/Tools/Archiver/Makefile | 1 + src/Tools/NnTrainer/Makefile | 1 + 41 files changed, 116 insertions(+), 72 deletions(-) create mode 100644 src/Search/TreeBuilder/Makefile rename src/Search/{AdvancedTreeSearch => TreeBuilder}/PersistentStateTree.cc (99%) rename src/Search/{AdvancedTreeSearch => TreeBuilder}/PersistentStateTree.hh (96%) rename src/Search/{ => TreeBuilder}/StateTree.cc (100%) rename src/Search/{ => TreeBuilder}/StateTree.hh (100%) rename src/Search/{ => TreeBuilder}/StateTreeIo.cc (100%) rename 
src/Search/{ => TreeBuilder}/StateTreeIo.hh (100%) rename src/Search/{AdvancedTreeSearch => TreeBuilder}/TreeBuilder.cc (99%) rename src/Search/{AdvancedTreeSearch => TreeBuilder}/TreeBuilder.hh (91%) rename src/Search/{AdvancedTreeSearch => TreeBuilder}/TreeStructure.cc (100%) rename src/Search/{AdvancedTreeSearch => TreeBuilder}/TreeStructure.hh (99%) rename src/Search/{AdvancedTreeSearch => TreeBuilder}/TreeWalker.hh (100%) diff --git a/Modules.make b/Modules.make index 9f2a39dbd..487b6d11b 100644 --- a/Modules.make +++ b/Modules.make @@ -150,6 +150,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/TreeBuilder/libSprintTreeBuilder.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make b/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make index 765fd3849..a418b3345 100644 --- a/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make +++ b/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make @@ -145,6 +145,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/TreeBuilder/libSprintTreeBuilder.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make b/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make index 765fd3849..a418b3345 100644 --- a/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make +++ b/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make @@ -145,6 +145,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/TreeBuilder/libSprintTreeBuilder.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += 
src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make b/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make index 7a5b1d1ee..06376620c 100644 --- a/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make +++ b/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make @@ -145,6 +145,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/TreeBuilder/libSprintTreeBuilder.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make b/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make index 56b50dbbe..7007af632 100644 --- a/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make +++ b/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make @@ -149,6 +149,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/TreeBuilder/libSprintTreeBuilder.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make b/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make index 7a45546db..94c8f8e74 100644 --- a/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make +++ b/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make @@ -150,6 +150,7 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) +LIBS_SEARCH += src/Search/TreeBuilder/libSprintTreeBuilder.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git 
a/src/Search/AdvancedTreeSearch/AcousticLookAhead.hh b/src/Search/AdvancedTreeSearch/AcousticLookAhead.hh index 745ca963f..4c94d2314 100644 --- a/src/Search/AdvancedTreeSearch/AcousticLookAhead.hh +++ b/src/Search/AdvancedTreeSearch/AcousticLookAhead.hh @@ -17,9 +17,10 @@ #include #include +#include + #include "Helpers.hh" #include "SearchSpace.hh" -#include "TreeStructure.hh" // #include "LinearMiniHash.hh" namespace Search { diff --git a/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc b/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc index fc81c0852..41d22c8de 100644 --- a/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc +++ b/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc @@ -21,8 +21,9 @@ #include #include #include -#include #include +#include + #include "SearchSpace.hh" #include "SearchSpaceStatistics.hh" diff --git a/src/Search/AdvancedTreeSearch/LanguageModelLookahead.hh b/src/Search/AdvancedTreeSearch/LanguageModelLookahead.hh index d5a871dca..0e490ca6d 100644 --- a/src/Search/AdvancedTreeSearch/LanguageModelLookahead.hh +++ b/src/Search/AdvancedTreeSearch/LanguageModelLookahead.hh @@ -26,11 +26,11 @@ #include #include #include -#include +#include +#include +#include #include "LinearPrediction.hh" -#include "PersistentStateTree.hh" -#include "TreeStructure.hh" // #define EXTENSIVE_SPARSE_COLLISION_STATS diff --git a/src/Search/AdvancedTreeSearch/LinearPrediction.hh b/src/Search/AdvancedTreeSearch/LinearPrediction.hh index ad96adc17..e53a680df 100644 --- a/src/Search/AdvancedTreeSearch/LinearPrediction.hh +++ b/src/Search/AdvancedTreeSearch/LinearPrediction.hh @@ -18,7 +18,7 @@ #include #include #include -#include "TreeStructure.hh" +#include namespace Search { class LinearPrediction { diff --git a/src/Search/AdvancedTreeSearch/Makefile b/src/Search/AdvancedTreeSearch/Makefile index 7a82b627b..24478f4ac 100644 --- a/src/Search/AdvancedTreeSearch/Makefile +++ b/src/Search/AdvancedTreeSearch/Makefile @@ -16,16 +16,13 @@ 
LIBSPRINTADVANCEDTREESEARCH_O = $(OBJDIR)/AcousticLookAhead.o \ $(OBJDIR)/LanguageModelLookahead.o \ $(OBJDIR)/PathRecombination.o \ $(OBJDIR)/PathRecombinationApproximation.o \ - $(OBJDIR)/PersistentStateTree.o \ $(OBJDIR)/PrefixFilter.o \ $(OBJDIR)/ScoreDependentStatistics.o \ $(OBJDIR)/SearchSpace.o \ $(OBJDIR)/SearchSpaceHelpers.o \ $(OBJDIR)/SearchSpaceStatistics.o \ $(OBJDIR)/SimpleThreadPool.o \ - $(OBJDIR)/Trace.o \ - $(OBJDIR)/TreeBuilder.o \ - $(OBJDIR)/TreeStructure.o + $(OBJDIR)/Trace.o ifeq ($(OS),darwin) CCFLAGS += -fexceptions diff --git a/src/Search/AdvancedTreeSearch/PathRecombination.hh b/src/Search/AdvancedTreeSearch/PathRecombination.hh index d94bfb8b0..0217b1979 100644 --- a/src/Search/AdvancedTreeSearch/PathRecombination.hh +++ b/src/Search/AdvancedTreeSearch/PathRecombination.hh @@ -16,9 +16,10 @@ #define PATHRECOMBINATION_HH #include +#include +#include + #include "Helpers.hh" -#include "PersistentStateTree.hh" -#include "TreeStructure.hh" namespace Search { class PathRecombination { diff --git a/src/Search/AdvancedTreeSearch/PathRecombinationApproximation.hh b/src/Search/AdvancedTreeSearch/PathRecombinationApproximation.hh index b3d5362bb..e2349477e 100644 --- a/src/Search/AdvancedTreeSearch/PathRecombinationApproximation.hh +++ b/src/Search/AdvancedTreeSearch/PathRecombinationApproximation.hh @@ -15,8 +15,8 @@ #ifndef SEARCH_PATHRECOMBINATIONAPPROXIMATION_HH #define SEARCH_PATHRECOMBINATIONAPPROXIMATION_HH +#include #include "PathRecombination.hh" -#include "PersistentStateTree.hh" namespace Search { class PathRecombinationApproximation { diff --git a/src/Search/AdvancedTreeSearch/PrefixFilter.hh b/src/Search/AdvancedTreeSearch/PrefixFilter.hh index 10a741a7a..8497bca11 100644 --- a/src/Search/AdvancedTreeSearch/PrefixFilter.hh +++ b/src/Search/AdvancedTreeSearch/PrefixFilter.hh @@ -18,7 +18,7 @@ #include #include -#include "PersistentStateTree.hh" +#include #include "SearchSpaceHelpers.hh" namespace Core { diff --git 
a/src/Search/AdvancedTreeSearch/SearchNetworkTransformation.hh b/src/Search/AdvancedTreeSearch/SearchNetworkTransformation.hh index cf78ec52c..3ff961d1e 100644 --- a/src/Search/AdvancedTreeSearch/SearchNetworkTransformation.hh +++ b/src/Search/AdvancedTreeSearch/SearchNetworkTransformation.hh @@ -18,8 +18,8 @@ /** * This file contains a collection of helper classes useful for transformations of the search network * */ -#include -#include "TreeStructure.hh" +#include +#include namespace AdvancedTreeSearch { struct StateWithSuccessors { diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.cc b/src/Search/AdvancedTreeSearch/SearchSpace.cc index 0374cb613..b5d4238cc 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.cc +++ b/src/Search/AdvancedTreeSearch/SearchSpace.cc @@ -23,13 +23,14 @@ #include #include +#include +#include + #include "AcousticLookAhead.hh" -#include "PersistentStateTree.hh" #include "PrefixFilter.hh" #include "Pruning.hh" #include "SearchNetworkTransformation.hh" #include "SearchSpaceStatistics.hh" -#include "TreeBuilder.hh" using namespace AdvancedTreeSearch; @@ -132,16 +133,16 @@ const Core::ParameterBool paramBuildMinimizedTreeFromScratch( true); const Core::Choice choiceTreeBuilderType( - "classic-hmm", static_cast(StaticSearchAutomaton::TreeBuilderType::classicHmm), - "minimized-hmm", static_cast(StaticSearchAutomaton::TreeBuilderType::minimizedHmm), - "ctc", static_cast(StaticSearchAutomaton::TreeBuilderType::ctc), + "classic-hmm", static_cast(TreeBuilderType::classicHmm), + "minimized-hmm", static_cast(TreeBuilderType::minimizedHmm), + "ctc", static_cast(TreeBuilderType::ctc), Core::Choice::endMark()); const Core::ParameterChoice paramTreeBuilderType( "tree-builder-type", &choiceTreeBuilderType, "which tree builder to use", - static_cast(StaticSearchAutomaton::TreeBuilderType::previousBehavior)); + static_cast(TreeBuilderType::previousBehavior)); const Core::ParameterBool paramConditionPredecessorWord( "condition-on-predecessor-word", 
@@ -386,7 +387,7 @@ StaticSearchAutomaton::StaticSearchAutomaton(Core::Configuration config, Core::R : Precursor(config), hmmLength(acousticModel->hmmTopologySet()->getDefault().nPhoneStates() * acousticModel->hmmTopologySet()->getDefault().nSubStates()), minimized(paramBuildMinimizedTreeFromScratch(config)), - network(config, acousticModel, lexicon, std::bind(&StaticSearchAutomaton::createTreeBuilder, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5)), + network(config, acousticModel, lexicon, std::bind(&createTreeBuilder, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)), prefixFilter(nullptr), treeBuilderType_(static_cast(paramTreeBuilderType(config))), acousticModel_(acousticModel), @@ -408,7 +409,7 @@ void StaticSearchAutomaton::buildNetwork() { if (!network.read(transformation)) { log() << "persistent network image could not be loaded, building it"; - std::unique_ptr builder = createTreeBuilder(config, *lexicon_, *acousticModel_, network); + std::unique_ptr builder = createTreeBuilder(treeBuilderType_, config, *lexicon_, *acousticModel_, network); if (not builder) { network.build(); network.cleanup(); @@ -768,21 +769,6 @@ void StaticSearchAutomaton::buildBatches() { network.removeOutputs(); } -std::unique_ptr StaticSearchAutomaton::createTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) { - switch (treeBuilderType_) { - case TreeBuilderType::classicHmm: { // Use StateTree.hh - return std::unique_ptr(nullptr); - } break; - case TreeBuilderType::minimizedHmm: { // Use TreeStructure.hh - return std::unique_ptr(new MinimizedTreeBuilder(config, *lexicon_, *acousticModel_, network, initialize)); - } break; - case TreeBuilderType::ctc: { - return std::unique_ptr(new CtcTreeBuilder(config, 
*lexicon_, *acousticModel_, network, initialize)); - } break; - default: defect(); - } -} - // ------------------------------- Search Space -------------------------------- SearchSpace::SearchSpace(const Core::Configuration& config, diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.hh b/src/Search/AdvancedTreeSearch/SearchSpace.hh index 4c7428803..ea52e8e39 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.hh +++ b/src/Search/AdvancedTreeSearch/SearchSpace.hh @@ -21,16 +21,17 @@ #include #include #include -#include +#include +#include +#include +#include #include "Helpers.hh" #include "LanguageModelLookahead.hh" -#include "PersistentStateTree.hh" #include "ScoreDependentStatistics.hh" #include "SearchSpaceHelpers.hh" #include "SimpleThreadPool.hh" #include "Trace.hh" -#include "TreeStructure.hh" struct EmissionSetCounter; namespace AdvancedTreeSearch { @@ -50,13 +51,6 @@ class StaticSearchAutomaton : public Core::Component { public: using Precursor = Core::Component; - enum class TreeBuilderType { - previousBehavior = 0, - classicHmm = 1, - minimizedHmm = 2, - ctc = 3, - }; - /// HMM length of a common phoneme const u32 hmmLength; bool minimized; @@ -125,8 +119,6 @@ public: protected: TreeBuilderType treeBuilderType_; - std::unique_ptr createTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); - private: Core::Ref acousticModel_; Bliss::LexiconRef lexicon_; diff --git a/src/Search/AdvancedTreeSearch/SearchSpaceHelpers.hh b/src/Search/AdvancedTreeSearch/SearchSpaceHelpers.hh index 38274a629..c1d4aaf88 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpaceHelpers.hh +++ b/src/Search/AdvancedTreeSearch/SearchSpaceHelpers.hh @@ -16,10 +16,10 @@ #define SEARCH_SPACE_HELPERS #include +#include #include "LanguageModelLookahead.hh" #include "TraceManager.hh" -#include "TreeStructure.hh" namespace Search { typedef s32 StateHypothesisIndex; diff 
--git a/src/Search/AdvancedTreeSearch/Trace.hh b/src/Search/AdvancedTreeSearch/Trace.hh index 73729a1fe..961da914f 100644 --- a/src/Search/AdvancedTreeSearch/Trace.hh +++ b/src/Search/AdvancedTreeSearch/Trace.hh @@ -17,8 +17,8 @@ #include #include -#include #include +#include #include "PathTrace.hh" namespace Search { diff --git a/src/Search/LanguageModelLookahead.hh b/src/Search/LanguageModelLookahead.hh index faa83bb10..29d3540ea 100644 --- a/src/Search/LanguageModelLookahead.hh +++ b/src/Search/LanguageModelLookahead.hh @@ -25,7 +25,7 @@ #include #include #include -#include "StateTree.hh" +#include namespace Search { diff --git a/src/Search/Makefile b/src/Search/Makefile index 11514593b..01e600170 100644 --- a/src/Search/Makefile +++ b/src/Search/Makefile @@ -15,8 +15,6 @@ LIBSPRINTSEARCH_O = \ $(OBJDIR)/LanguageModelLookahead.o \ $(OBJDIR)/Module.o \ $(OBJDIR)/Search.o \ - $(OBJDIR)/StateTree.o \ - $(OBJDIR)/StateTreeIo.o \ $(OBJDIR)/WordConditionedTreeSearch.o CHECK_O = $(OBJDIR)/check.o \ @@ -24,7 +22,6 @@ CHECK_O = $(OBJDIR)/check.o \ ../Bliss/libSprintBliss.$(a) \ ../Fsa/libSprintFsa.$(a) \ ../Core/libSprintCore.$(a) - ifdef MODULE_SEARCH_MBR LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskSearch.o @@ -32,6 +29,7 @@ LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskAStarSearch.o LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskNBestListSearch.o LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskSearchUtil.o endif +SUBDIRS += TreeBuilder ifdef MODULE_SEARCH_WFST SUBDIRS += Wfst endif @@ -59,6 +57,9 @@ libSprintSearch.$(a): $(LIBSPRINTSEARCH_O) check$(exe): $(CHECK_O) $(LD) $(CHECK_O) -o check$(exe) $(LDFLAGS) +TreeBuilder: + $(MAKE) -C $@ libSprintTreeBuilder.$(a) + Wfst: $(MAKE) -C $@ libSprintSearchWfst.$(a) diff --git a/src/Search/TreeBuilder/Makefile b/src/Search/TreeBuilder/Makefile new file mode 100644 index 000000000..09ce6fbba --- /dev/null +++ b/src/Search/TreeBuilder/Makefile @@ -0,0 +1,28 @@ +#!gmake + +TOPDIR = ../../.. 
+ +include $(TOPDIR)/Makefile.cfg + +# ----------------------------------------------------------------------------- + +SUBDIRS = +TARGETS = libSprintTreeBuilder.$(a) + +LIBSPRINTTREEBUILDER_O = $(OBJDIR)/PersistentStateTree.o \ + $(OBJDIR)/StateTree.o \ + $(OBJDIR)/StateTreeIo.o \ + $(OBJDIR)/TreeBuilder.o \ + $(OBJDIR)/TreeStructure.o + +# ----------------------------------------------------------------------------- + +all: $(TARGETS) + +libSprintTreeBuilder.$(a): $(LIBSPRINTTREEBUILDER_O) + $(MAKELIB) $@ $^ + +include $(TOPDIR)/Rules.make + +sinclude $(LIBSPRINTTREEBUILDER_O:.o=.d) +include $(patsubst %.o,%.d,$(filter %.o,$(CHECK_O))) \ No newline at end of file diff --git a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc b/src/Search/TreeBuilder/PersistentStateTree.cc similarity index 99% rename from src/Search/AdvancedTreeSearch/PersistentStateTree.cc rename to src/Search/TreeBuilder/PersistentStateTree.cc index 05ddd42e7..8ee5d40e8 100644 --- a/src/Search/AdvancedTreeSearch/PersistentStateTree.cc +++ b/src/Search/TreeBuilder/PersistentStateTree.cc @@ -19,11 +19,13 @@ #include #include -#include -#include "BatchManager.hh" -#include "Helpers.hh" +#include +#include + +#include "StateTree.hh" #include "TreeStructure.hh" + using namespace Search; using namespace Core; diff --git a/src/Search/AdvancedTreeSearch/PersistentStateTree.hh b/src/Search/TreeBuilder/PersistentStateTree.hh similarity index 96% rename from src/Search/AdvancedTreeSearch/PersistentStateTree.hh rename to src/Search/TreeBuilder/PersistentStateTree.hh index 5c028f962..9f8b9ac78 100644 --- a/src/Search/AdvancedTreeSearch/PersistentStateTree.hh +++ b/src/Search/TreeBuilder/PersistentStateTree.hh @@ -32,14 +32,15 @@ struct MyStandardValueHash { }; class AbstractTreeBuilder; +enum class TreeBuilderType; namespace Search { class HMMStateNetwork; class StateTree; -class PersistentStateTree : public Core::ReferenceCounted { +class PersistentStateTree { public: - using TreeBuilderFactory = 
std::function(Core::Configuration, const Bliss::Lexicon&, const Am::AcousticModel&, PersistentStateTree&, bool)>; + using TreeBuilderFactory = std::function(TreeBuilderType, Core::Configuration, const Bliss::Lexicon&, const Am::AcousticModel&, PersistentStateTree&, bool)>; ///@param lexicon This must be given if the resulting exits are supposed to be functional PersistentStateTree(Core::Configuration config, Core::Ref acousticModel, Bliss::LexiconRef lexicon, TreeBuilderFactory treeBuilderFactory); diff --git a/src/Search/StateTree.cc b/src/Search/TreeBuilder/StateTree.cc similarity index 100% rename from src/Search/StateTree.cc rename to src/Search/TreeBuilder/StateTree.cc diff --git a/src/Search/StateTree.hh b/src/Search/TreeBuilder/StateTree.hh similarity index 100% rename from src/Search/StateTree.hh rename to src/Search/TreeBuilder/StateTree.hh diff --git a/src/Search/StateTreeIo.cc b/src/Search/TreeBuilder/StateTreeIo.cc similarity index 100% rename from src/Search/StateTreeIo.cc rename to src/Search/TreeBuilder/StateTreeIo.cc diff --git a/src/Search/StateTreeIo.hh b/src/Search/TreeBuilder/StateTreeIo.hh similarity index 100% rename from src/Search/StateTreeIo.hh rename to src/Search/TreeBuilder/StateTreeIo.hh diff --git a/src/Search/AdvancedTreeSearch/TreeBuilder.cc b/src/Search/TreeBuilder/TreeBuilder.cc similarity index 99% rename from src/Search/AdvancedTreeSearch/TreeBuilder.cc rename to src/Search/TreeBuilder/TreeBuilder.cc index b51cbf3f2..d0efec14d 100644 --- a/src/Search/AdvancedTreeSearch/TreeBuilder.cc +++ b/src/Search/TreeBuilder/TreeBuilder.cc @@ -16,8 +16,8 @@ #include #include #include -#include #include +#include "StateTree.hh" #include "PersistentStateTree.hh" using namespace Search; diff --git a/src/Search/AdvancedTreeSearch/TreeBuilder.hh b/src/Search/TreeBuilder/TreeBuilder.hh similarity index 91% rename from src/Search/AdvancedTreeSearch/TreeBuilder.hh rename to src/Search/TreeBuilder/TreeBuilder.hh index 706635520..2b84ee4b4 100644 --- 
a/src/Search/AdvancedTreeSearch/TreeBuilder.hh +++ b/src/Search/TreeBuilder/TreeBuilder.hh @@ -16,9 +16,9 @@ #define TREEBUILDER_HH #include -#include -#include "Helpers.hh" -#include "LmCache.hh" +#include +#include +#include "StateTree.hh" #include "PersistentStateTree.hh" namespace Bliss { @@ -33,6 +33,13 @@ namespace Core { class Configuration; } +enum class TreeBuilderType { /* Selects which tree builder createTreeBuilder() instantiates. */ + previousBehavior = 0, + classicHmm = 1, + minimizedHmm = 2, + ctc = 3, +}; + class AbstractTreeBuilder : public Core::Component { public: typedef u32 StateId; @@ -278,4 +285,19 @@ protected: void addWordBoundaryStates(); }; +inline std::unique_ptr createTreeBuilder(TreeBuilderType treeBuilderType, Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true) { /* Factory selecting the builder for `treeBuilderType`; returns a null pointer for classicHmm. It is defined in this header, hence `inline`: without it every translation unit that includes the header emits its own definition and linking fails with duplicate symbols. */ + switch (treeBuilderType) { + case TreeBuilderType::classicHmm: { // Use StateTree.hh + return std::unique_ptr(nullptr); + } break; + case TreeBuilderType::minimizedHmm: { // Use TreeStructure.hh + return std::unique_ptr(new MinimizedTreeBuilder(config, lexicon, acousticModel, network, initialize)); + } break; + case TreeBuilderType::ctc: { + return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network, initialize)); + } break; + default: defect(); /* NOTE(review): previousBehavior has no case and ends up here - presumably callers resolve it to a concrete type first; confirm */ + } +} + #endif diff --git a/src/Search/AdvancedTreeSearch/TreeStructure.cc b/src/Search/TreeBuilder/TreeStructure.cc similarity index 100% rename from src/Search/AdvancedTreeSearch/TreeStructure.cc rename to src/Search/TreeBuilder/TreeStructure.cc diff --git a/src/Search/AdvancedTreeSearch/TreeStructure.hh b/src/Search/TreeBuilder/TreeStructure.hh similarity index 99% rename from src/Search/AdvancedTreeSearch/TreeStructure.hh rename to src/Search/TreeBuilder/TreeStructure.hh index 2e2f658c6..127783bdf 100644 --- a/src/Search/AdvancedTreeSearch/TreeStructure.hh +++ b/src/Search/TreeBuilder/TreeStructure.hh @@ -15,10 +15,11 @@ #ifndef SEARCH_TREESTRUCTURE_HH #define 
SEARCH_TREESTRUCTURE_HH -#include -#include #include -#include "BatchManager.hh" + +#include +#include +#include "StateTree.hh" #define inline_ __attribute__((always_inline)) inline diff --git a/src/Search/AdvancedTreeSearch/TreeWalker.hh b/src/Search/TreeBuilder/TreeWalker.hh similarity index 100% rename from src/Search/AdvancedTreeSearch/TreeWalker.hh rename to src/Search/TreeBuilder/TreeWalker.hh diff --git a/src/Search/Wfst/Makefile b/src/Search/Wfst/Makefile index a2890549c..4acaf1a6f 100644 --- a/src/Search/Wfst/Makefile +++ b/src/Search/Wfst/Makefile @@ -61,6 +61,7 @@ CHECK_CTRANS_O = $(OBJDIR)/check_ctrans.o \ CHECK_O = $(OBJDIR)/check.o \ libSprintSearchWfst.$(a) \ ../libSprintSearch.$(a) \ + ../TreeBuilder/libSprintTreeBuilder.$(a) \ ../../Speech/libSprintSpeech.$(a) \ ../../Am/libSprintAm.$(a) \ ../../Mm/libSprintMm.$(a) \ diff --git a/src/Search/Wfst/StateTree.hh b/src/Search/Wfst/StateTree.hh index 4eae2c40b..d67d37c4e 100644 --- a/src/Search/Wfst/StateTree.hh +++ b/src/Search/Wfst/StateTree.hh @@ -18,7 +18,7 @@ #include #include #include -#include +#include namespace Search { namespace Wfst { diff --git a/src/Search/WordConditionedTreeSearch.cc b/src/Search/WordConditionedTreeSearch.cc index f37da4646..4cbce7b10 100644 --- a/src/Search/WordConditionedTreeSearch.cc +++ b/src/Search/WordConditionedTreeSearch.cc @@ -29,10 +29,10 @@ #include #include #include +#include #include #include "Histogram.hh" #include "LanguageModelLookahead.hh" -#include "StateTree.hh" #ifdef MODULE_LM_FSA #include diff --git a/src/Search/check.cc b/src/Search/check.cc index 58db27399..162c5de3c 100644 --- a/src/Search/check.cc +++ b/src/Search/check.cc @@ -14,7 +14,7 @@ */ #include #include -#include "StateTree.hh" +#include #ifdef MODULE_SEARCH_WFST #include #include diff --git a/src/Speech/Makefile b/src/Speech/Makefile index 60bee95d3..d6670c2d3 100644 --- a/src/Speech/Makefile +++ b/src/Speech/Makefile @@ -46,6 +46,7 @@ CHECK_O = $(OBJDIR)/check.o \ ../Mm/libSprintMm.$(a) \ 
../Mc/libSprintMc.$(a) \ ../Search/libSprintSearch.$(a) \ + ../Search/TreeBuilder/libSprintTreeBuilder.$(a) \ ../Bliss/libSprintBliss.$(a) \ ../Flow/libSprintFlow.$(a) \ ../Fsa/libSprintFsa.$(a) \ diff --git a/src/Test/Makefile b/src/Test/Makefile index 7de6c4d37..7f7bafa4c 100644 --- a/src/Test/Makefile +++ b/src/Test/Makefile @@ -62,6 +62,7 @@ UNIT_TEST_O = $(OBJDIR)/UnitTester.o $(TEST_O) \ ../Core/libSprintCore.$(a)\ ../Speech/libSprintSpeech.$(a) \ ../Search/libSprintSearch.$(a) \ + ../Search/TreeBuilder/libSprintTreeBuilder.$(a) \ ../Lattice/libSprintLattice.$(a) \ ../Am/libSprintAm.$(a) \ ../Mm/libSprintMm.$(a) \ diff --git a/src/Tools/Archiver/Makefile b/src/Tools/Archiver/Makefile index f31539366..4520c120e 100644 --- a/src/Tools/Archiver/Makefile +++ b/src/Tools/Archiver/Makefile @@ -12,6 +12,7 @@ TARGETS = archiver$(exe) ARCHIVER_O = $(OBJDIR)/Archiver.o \ ../../Speech/libSprintSpeech.$(a) \ ../../Search/libSprintSearch.$(a) \ + ../../Search/TreeBuilder/libSprintTreeBuilder.$(a) \ ../../Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a) \ ../../Lattice/libSprintLattice.$(a) \ ../../Lm/libSprintLm.$(a) \ diff --git a/src/Tools/NnTrainer/Makefile b/src/Tools/NnTrainer/Makefile index 3b1be0d24..529d3ec8d 100644 --- a/src/Tools/NnTrainer/Makefile +++ b/src/Tools/NnTrainer/Makefile @@ -25,6 +25,7 @@ NN_TRAINER_O = $(OBJDIR)/NnTrainer.o \ ../../Mm/libSprintMm.$(a) \ ../../Nn/libSprintNn.$(a) \ ../../Search/libSprintSearch.$(a) \ + ../../Search/TreeBuilder/libSprintTreeBuilder.$(a) \ ../../Signal/libSprintSignal.$(a) \ ../../Speech/libSprintSpeech.$(a) From dea490c4fd671ab856f5a8d535123a299808fb19 Mon Sep 17 00:00:00 2001 From: Larissa Date: Tue, 11 Feb 2025 11:38:53 +0100 Subject: [PATCH 09/24] replace std::vector by std::set --- src/Search/AdvancedTreeSearch/PersistentStateTree.hh | 2 +- src/Search/AdvancedTreeSearch/TreeBuilder.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/PersistentStateTree.hh 
b/src/Search/AdvancedTreeSearch/PersistentStateTree.hh index 5c028f962..cd67123ba 100644 --- a/src/Search/AdvancedTreeSearch/PersistentStateTree.hh +++ b/src/Search/AdvancedTreeSearch/PersistentStateTree.hh @@ -105,7 +105,7 @@ public: StateId ciRootState; // Other root nodes (currently used for the wordBoundaryRoot in CtcTreeBuilder) - std::vector otherRootStates; + std::set otherRootStates; // The word-end exits std::vector exits; diff --git a/src/Search/AdvancedTreeSearch/TreeBuilder.cc b/src/Search/AdvancedTreeSearch/TreeBuilder.cc index b51cbf3f2..27a2a5158 100644 --- a/src/Search/AdvancedTreeSearch/TreeBuilder.cc +++ b/src/Search/AdvancedTreeSearch/TreeBuilder.cc @@ -1220,7 +1220,7 @@ CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& // Create a special root for the word-boundary token if it exists in the lexicon if (lexicon.specialLemma("word-boundary") != nullptr) { wordBoundaryRoot_ = createRoot(); - network_.otherRootStates.push_back(wordBoundaryRoot_); + network_.otherRootStates.insert(wordBoundaryRoot_); } } } From 64a6c500e28c92f8e85698121deabc0653c4f118 Mon Sep 17 00:00:00 2001 From: Larissa Date: Tue, 18 Feb 2025 16:46:22 +0100 Subject: [PATCH 10/24] Make label loops optional for RNA topology --- src/Search/AdvancedTreeSearch/SearchSpace.cc | 1 + src/Search/TreeBuilder/TreeBuilder.cc | 15 ++++++++++----- src/Search/TreeBuilder/TreeBuilder.hh | 14 ++++++++++++-- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.cc b/src/Search/AdvancedTreeSearch/SearchSpace.cc index b5d4238cc..83249d91b 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.cc +++ b/src/Search/AdvancedTreeSearch/SearchSpace.cc @@ -136,6 +136,7 @@ const Core::Choice choiceTreeBuilderType( "classic-hmm", static_cast(TreeBuilderType::classicHmm), "minimized-hmm", static_cast(TreeBuilderType::minimizedHmm), "ctc", static_cast(TreeBuilderType::ctc), + "rna", static_cast(TreeBuilderType::rna), 
Core::Choice::endMark()); const Core::ParameterChoice paramTreeBuilderType( diff --git a/src/Search/TreeBuilder/TreeBuilder.cc b/src/Search/TreeBuilder/TreeBuilder.cc index 284b26f96..c9e94cc7d 100644 --- a/src/Search/TreeBuilder/TreeBuilder.cc +++ b/src/Search/TreeBuilder/TreeBuilder.cc @@ -1200,8 +1200,9 @@ inline void MinimizedTreeBuilder::mapSuccessors(const std::set& success // -------------------- CtcTreeBuilder -------------------- -CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) - : AbstractTreeBuilder(config, lexicon, acousticModel, network) { +CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize, bool labelLoops) + : AbstractTreeBuilder(config, lexicon, acousticModel, network), + labelLoops_(labelLoops) { auto iters = lexicon.phonemeInventory()->phonemes(); for (auto it = iters.first; it != iters.second; ++it) { require(not(*it)->isContextDependent()); // Context dependent labels are not supported @@ -1226,7 +1227,7 @@ CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& } std::unique_ptr CtcTreeBuilder::newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) { - return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network)); + return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network, initialize, labelLoops_)); } void CtcTreeBuilder::build() { @@ -1341,8 +1342,12 @@ StateId CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronuncia // Add new (non-blank) state currentState = extendState(currentState, desc); - // Add loop for this state - addTransition(currentState, currentState); + + if 
(labelLoops_) { + // Add loop for this state + addTransition(currentState, currentState); + } + // Add transition from previous non-blank state to this state, allowing to skip the blank state in-between these two if (prevNonBlankState != invalidTreeNodeIndex) { addTransition(prevNonBlankState, currentState); diff --git a/src/Search/TreeBuilder/TreeBuilder.hh b/src/Search/TreeBuilder/TreeBuilder.hh index 2b84ee4b4..88fdfdee7 100644 --- a/src/Search/TreeBuilder/TreeBuilder.hh +++ b/src/Search/TreeBuilder/TreeBuilder.hh @@ -38,6 +38,7 @@ enum class TreeBuilderType { classicHmm = 1, minimizedHmm = 2, ctc = 3, + rna = 4, }; class AbstractTreeBuilder : public Core::Component { @@ -254,9 +255,13 @@ protected: void mapSuccessors(const std::set&, std::set&, const std::vector&, const std::vector&); }; +// Tree builder for constructing search trees with either CTC or RNA topology. +// The topology depends on the 'labelLoops' parameter: +// CTC topology is used when 'labelLoops' is true, adding label self-loops. +// RNA (Transducer) topology is used when 'labelLoops' is false. 
class CtcTreeBuilder : public AbstractTreeBuilder { public: - CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); + CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true, bool labelLoops = true); virtual ~CtcTreeBuilder() = default; virtual std::unique_ptr newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); @@ -265,6 +270,8 @@ public: virtual void build(); protected: + bool labelLoops_; + StateId wordBoundaryRoot_; Search::StateTree::StateDesc blankDesc_; Am::AllophoneStateIndex blankAllophoneStateIndex_; @@ -294,7 +301,10 @@ std::unique_ptr createTreeBuilder(TreeBuilderType treeBuild return std::unique_ptr(new MinimizedTreeBuilder(config, lexicon, acousticModel, network, initialize)); } break; case TreeBuilderType::ctc: { - return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network, initialize)); + return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network, initialize, true)); + } break; + case TreeBuilderType::rna: { + return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network, initialize, false)); } break; default: defect(); } From 710505fac6c059736bfcb4d8abea375012bf68ec Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Tue, 18 Feb 2025 17:07:42 +0100 Subject: [PATCH 11/24] Formatting --- src/Search/TreeBuilder/TreeBuilder.cc | 16 +++++++--------- src/Search/TreeBuilder/TreeBuilder.hh | 18 +++++++++--------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/Search/TreeBuilder/TreeBuilder.cc b/src/Search/TreeBuilder/TreeBuilder.cc index c9e94cc7d..bada7daf9 100644 --- a/src/Search/TreeBuilder/TreeBuilder.cc +++ 
b/src/Search/TreeBuilder/TreeBuilder.cc @@ -17,8 +17,8 @@ #include #include #include -#include "StateTree.hh" #include "PersistentStateTree.hh" +#include "StateTree.hh" using namespace Search; @@ -1237,9 +1237,9 @@ void CtcTreeBuilder::build() { wordBoundaryLemma = lexicon_.specialLemma("word-boundary"); } - auto blankLemma = lexicon_.specialLemma("blank"); - auto silenceLemma = lexicon_.specialLemma("silence"); - auto iters = lexicon_.lemmaPronunciations(); + auto blankLemma = lexicon_.specialLemma("blank"); + auto silenceLemma = lexicon_.specialLemma("silence"); + auto iters = lexicon_.lemmaPronunciations(); // Iterate over the lemmata and add them to the tree for (auto it = iters.first; it != iters.second; ++it) { @@ -1300,7 +1300,7 @@ StateId CtcTreeBuilder::extendState(StateId predecessor, StateTree::StateDesc de void CtcTreeBuilder::addTransition(StateId predecessor, StateId successor) { bool found = false; for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) { - if(!target.isLabel() && network_.structure.state(*target).stateDesc == network_.structure.state(successor).stateDesc) { + if (!target.isLabel() && network_.structure.state(*target).stateDesc == network_.structure.state(successor).stateDesc) { // The node is already a successor of the predecessor, so the transition already exists found = true; } @@ -1312,7 +1312,7 @@ void CtcTreeBuilder::addTransition(StateId predecessor, StateId successor) { } StateId CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronunciation const* pron) { - StateId currentState = startState; + StateId currentState = startState; StateId prevNonBlankState = invalidTreeNodeIndex; for (u32 i = 0u; i < pron->length(); i++) { @@ -1375,7 +1375,7 @@ void CtcTreeBuilder::addWordBoundaryStates() { // Add the word-boundary to the tree, starting from the wordBoundaryRoot_ // If the word-boundary has several pronunciation, only the first one is considered - auto prons = 
wordBoundaryLemma->pronunciations(); + auto prons = wordBoundaryLemma->pronunciations(); wordBoundaryEnd = extendPronunciation(wordBoundaryRoot_, (prons.first)->pronunciation()); wordBoundaryPronLemma = prons.first; @@ -1399,5 +1399,3 @@ void CtcTreeBuilder::addWordBoundaryStates() { // Add loop for this blank state addTransition(blankBefore, blankBefore); } - - diff --git a/src/Search/TreeBuilder/TreeBuilder.hh b/src/Search/TreeBuilder/TreeBuilder.hh index 88fdfdee7..8f9cbad28 100644 --- a/src/Search/TreeBuilder/TreeBuilder.hh +++ b/src/Search/TreeBuilder/TreeBuilder.hh @@ -18,8 +18,8 @@ #include #include #include -#include "StateTree.hh" #include "PersistentStateTree.hh" +#include "StateTree.hh" namespace Bliss { class Lexicon; @@ -174,11 +174,11 @@ protected: const u32 hash; }; - typedef std::set PhonemeIdSet; - typedef Core::HashMap RootHash; - typedef Core::HashMap SkipRootsHash; - typedef Core::HashMap, RootKey::Hash> CoarticulationJointHash; - typedef Core::HashMap PredecessorsHash; + typedef std::set PhonemeIdSet; + typedef Core::HashMap RootHash; + typedef Core::HashMap SkipRootsHash; + typedef Core::HashMap, RootKey::Hash> CoarticulationJointHash; + typedef Core::HashMap PredecessorsHash; s32 minPhones_; bool addCiTransitions_; @@ -280,16 +280,16 @@ protected: StateId createRoot(); // @param state is the last state of the word with pronunciation ID @param pron, add an exit leading to the root node @param transitState // The exit is appended to the state's successors - u32 addExit(StateId state, StateId transitState, Bliss::LemmaPronunciation::Id pron); + u32 addExit(StateId state, StateId transitState, Bliss::LemmaPronunciation::Id pron); // Check if the node with @param desc is already a successor of the @param predecessor and add it if not StateId extendState(StateId predecessor, Search::StateTree::StateDesc desc); // Starting in @param startState (usually the root), include the lemma with pronunciation @param pron in the tree StateId 
extendPronunciation(StateId startState, Bliss::Pronunciation const* pron); // Add a transition between two already existing states, used to insert loops and skip-transitions - void addTransition(StateId predecessor, StateId successor); + void addTransition(StateId predecessor, StateId successor); // If the lexicon contains a word-boundary token, it is added starting from the wordBoundaryRoot_ - void addWordBoundaryStates(); + void addWordBoundaryStates(); }; std::unique_ptr createTreeBuilder(TreeBuilderType treeBuilderType, Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true) { From 4a88b3c91c53d831c834abae4fff7d673f1e6300 Mon Sep 17 00:00:00 2001 From: Larissa Date: Tue, 18 Feb 2025 18:23:24 +0100 Subject: [PATCH 12/24] Rename "labelLoops" to "enableLabelLoop" --- src/Search/TreeBuilder/TreeBuilder.cc | 4 ++-- src/Search/TreeBuilder/TreeBuilder.hh | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Search/TreeBuilder/TreeBuilder.cc b/src/Search/TreeBuilder/TreeBuilder.cc index ea101c52e..5c64efd71 100644 --- a/src/Search/TreeBuilder/TreeBuilder.cc +++ b/src/Search/TreeBuilder/TreeBuilder.cc @@ -1200,9 +1200,9 @@ inline void MinimizedTreeBuilder::mapSuccessors(const std::set& success // -------------------- CtcTreeBuilder -------------------- -CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize, bool labelLoops) +CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize, bool enableLabelLoop) : AbstractTreeBuilder(config, lexicon, acousticModel, network), - labelLoops_(labelLoops) { + labelLoops_(enableLabelLoop) { auto iters = lexicon.phonemeInventory()->phonemes(); for (auto it 
= iters.first; it != iters.second; ++it) { require(not(*it)->isContextDependent()); // Context dependent labels are not supported diff --git a/src/Search/TreeBuilder/TreeBuilder.hh b/src/Search/TreeBuilder/TreeBuilder.hh index 63c31661d..eb7f7ac11 100644 --- a/src/Search/TreeBuilder/TreeBuilder.hh +++ b/src/Search/TreeBuilder/TreeBuilder.hh @@ -256,12 +256,12 @@ protected: }; // Tree builder for constructing search trees with either CTC or RNA topology. -// The topology depends on the 'labelLoops' parameter: -// CTC topology is used when 'labelLoops' is true, adding label self-loops. -// RNA (Transducer) topology is used when 'labelLoops' is false. +// The topology depends on the 'enableLabelLoop' parameter: +// CTC topology is used when 'enableLabelLoop' is true, adding label self-loops. +// RNA (Transducer) topology is used when 'enableLabelLoop' is false. class CtcTreeBuilder : public AbstractTreeBuilder { public: - CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true, bool labelLoops = true); + CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true, bool enableLabelLoop = true); virtual ~CtcTreeBuilder() = default; virtual std::unique_ptr newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); From 9a5d21887bdc7e27fc83c2de14169e4f36d64319 Mon Sep 17 00:00:00 2001 From: Larissa Date: Sun, 23 Feb 2025 20:04:46 +0100 Subject: [PATCH 13/24] Move createTreeBuilder to Search::Module and remove subdir --- Modules.make | 1 - .../makefiles/Modules.make | 1 - .../makefiles/Modules.make | 1 - .../makefiles/Modules.make | 1 - .../makefiles/Modules.make | 1 - .../makefiles/Modules.make | 1 - 
.../AdvancedTreeSearch/AcousticLookAhead.hh | 2 +- .../AdvancedTreeSearch/AdvancedTreeSearch.cc | 3 +- .../LanguageModelLookahead.hh | 6 ++-- .../AdvancedTreeSearch/LinearPrediction.hh | 2 +- .../AdvancedTreeSearch/PathRecombination.hh | 4 +-- .../PathRecombinationApproximation.hh | 2 +- src/Search/AdvancedTreeSearch/PrefixFilter.hh | 2 +- .../SearchNetworkTransformation.hh | 4 +-- src/Search/AdvancedTreeSearch/SearchSpace.cc | 9 +++--- src/Search/AdvancedTreeSearch/SearchSpace.hh | 9 +++--- .../AdvancedTreeSearch/SearchSpaceHelpers.hh | 2 +- src/Search/AdvancedTreeSearch/Trace.hh | 2 +- src/Search/LanguageModelLookahead.hh | 2 +- src/Search/Makefile | 10 ++++--- src/Search/Module.cc | 16 +++++++++++ src/Search/Module.hh | 14 ++++++++-- .../{TreeBuilder => }/PersistentStateTree.cc | 3 +- .../{TreeBuilder => }/PersistentStateTree.hh | 4 +-- src/Search/{TreeBuilder => }/StateTree.cc | 0 src/Search/{TreeBuilder => }/StateTree.hh | 0 src/Search/{TreeBuilder => }/StateTreeIo.cc | 0 src/Search/{TreeBuilder => }/StateTreeIo.hh | 0 src/Search/{TreeBuilder => }/TreeBuilder.cc | 2 +- src/Search/{TreeBuilder => }/TreeBuilder.hh | 23 +-------------- src/Search/TreeBuilder/Makefile | 28 ------------------- src/Search/{TreeBuilder => }/TreeStructure.cc | 0 src/Search/{TreeBuilder => }/TreeStructure.hh | 2 +- src/Search/{TreeBuilder => }/TreeWalker.hh | 0 src/Search/Wfst/Makefile | 1 - src/Search/Wfst/StateTree.hh | 2 +- src/Search/WordConditionedTreeSearch.cc | 2 +- src/Search/check.cc | 2 +- src/Speech/Makefile | 1 - src/Test/Makefile | 1 - src/Tools/Archiver/Makefile | 1 - src/Tools/NnTrainer/Makefile | 1 - 42 files changed, 67 insertions(+), 101 deletions(-) rename src/Search/{TreeBuilder => }/PersistentStateTree.cc (99%) rename src/Search/{TreeBuilder => }/PersistentStateTree.hh (99%) rename src/Search/{TreeBuilder => }/StateTree.cc (100%) rename src/Search/{TreeBuilder => }/StateTree.hh (100%) rename src/Search/{TreeBuilder => }/StateTreeIo.cc (100%) rename 
src/Search/{TreeBuilder => }/StateTreeIo.hh (100%) rename src/Search/{TreeBuilder => }/TreeBuilder.cc (99%) rename src/Search/{TreeBuilder => }/TreeBuilder.hh (92%) delete mode 100644 src/Search/TreeBuilder/Makefile rename src/Search/{TreeBuilder => }/TreeStructure.cc (100%) rename src/Search/{TreeBuilder => }/TreeStructure.hh (99%) rename src/Search/{TreeBuilder => }/TreeWalker.hh (100%) diff --git a/Modules.make b/Modules.make index 578e9edfd..a9ee0ae7c 100644 --- a/Modules.make +++ b/Modules.make @@ -148,7 +148,6 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) -LIBS_SEARCH += src/Search/TreeBuilder/libSprintTreeBuilder.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make b/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make index 03b60d6fe..f171381f7 100644 --- a/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make +++ b/apptainer/2022-10-21_tensorflow-1.15_arm_v1/makefiles/Modules.make @@ -143,7 +143,6 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) -LIBS_SEARCH += src/Search/TreeBuilder/libSprintTreeBuilder.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make b/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make index 03b60d6fe..f171381f7 100644 --- a/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make +++ b/apptainer/2022-10-21_tensorflow-1.15_v1/makefiles/Modules.make @@ -143,7 +143,6 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) -LIBS_SEARCH += src/Search/TreeBuilder/libSprintTreeBuilder.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += 
src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make b/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make index 373ec61e6..2ea9bf106 100644 --- a/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make +++ b/apptainer/2023-05-08_tensorflow-2.8_v1/makefiles/Modules.make @@ -143,7 +143,6 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) -LIBS_SEARCH += src/Search/TreeBuilder/libSprintTreeBuilder.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make b/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make index 3606464ea..af199b57a 100644 --- a/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make +++ b/apptainer/2023-08-09_tensorflow-2.8_onnx-1.15_v1/makefiles/Modules.make @@ -147,7 +147,6 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) -LIBS_SEARCH += src/Search/TreeBuilder/libSprintTreeBuilder.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make b/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make index e23bf5002..bc36c260b 100644 --- a/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make +++ b/apptainer/2023-11-08_tensorflow-2.14_v1/makefiles/Modules.make @@ -148,7 +148,6 @@ endif # ****** Libraries ****** LIBS_SEARCH = src/Search/libSprintSearch.$(a) -LIBS_SEARCH += src/Search/TreeBuilder/libSprintTreeBuilder.$(a) ifdef MODULE_SEARCH_WFST LIBS_SEARCH += src/Search/Wfst/libSprintSearchWfst.$(a) LIBS_SEARCH += src/OpenFst/libSprintOpenFst.$(a) diff --git a/src/Search/AdvancedTreeSearch/AcousticLookAhead.hh 
b/src/Search/AdvancedTreeSearch/AcousticLookAhead.hh index 4c94d2314..4d7234a4f 100644 --- a/src/Search/AdvancedTreeSearch/AcousticLookAhead.hh +++ b/src/Search/AdvancedTreeSearch/AcousticLookAhead.hh @@ -17,7 +17,7 @@ #include #include -#include +#include #include "Helpers.hh" #include "SearchSpace.hh" diff --git a/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc b/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc index 5d829e465..246ef90ea 100644 --- a/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc +++ b/src/Search/AdvancedTreeSearch/AdvancedTreeSearch.cc @@ -21,10 +21,9 @@ #include #include #include +#include #include -#include #include - #include "SearchSpace.hh" #include "SearchSpaceStatistics.hh" diff --git a/src/Search/AdvancedTreeSearch/LanguageModelLookahead.hh b/src/Search/AdvancedTreeSearch/LanguageModelLookahead.hh index 0e490ca6d..2538e513b 100644 --- a/src/Search/AdvancedTreeSearch/LanguageModelLookahead.hh +++ b/src/Search/AdvancedTreeSearch/LanguageModelLookahead.hh @@ -26,9 +26,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include "LinearPrediction.hh" diff --git a/src/Search/AdvancedTreeSearch/LinearPrediction.hh b/src/Search/AdvancedTreeSearch/LinearPrediction.hh index e53a680df..57329106f 100644 --- a/src/Search/AdvancedTreeSearch/LinearPrediction.hh +++ b/src/Search/AdvancedTreeSearch/LinearPrediction.hh @@ -18,7 +18,7 @@ #include #include #include -#include +#include namespace Search { class LinearPrediction { diff --git a/src/Search/AdvancedTreeSearch/PathRecombination.hh b/src/Search/AdvancedTreeSearch/PathRecombination.hh index 0217b1979..ed0848255 100644 --- a/src/Search/AdvancedTreeSearch/PathRecombination.hh +++ b/src/Search/AdvancedTreeSearch/PathRecombination.hh @@ -16,8 +16,8 @@ #define PATHRECOMBINATION_HH #include -#include -#include +#include +#include #include "Helpers.hh" diff --git a/src/Search/AdvancedTreeSearch/PathRecombinationApproximation.hh 
b/src/Search/AdvancedTreeSearch/PathRecombinationApproximation.hh index e2349477e..25f731b37 100644 --- a/src/Search/AdvancedTreeSearch/PathRecombinationApproximation.hh +++ b/src/Search/AdvancedTreeSearch/PathRecombinationApproximation.hh @@ -15,7 +15,7 @@ #ifndef SEARCH_PATHRECOMBINATIONAPPROXIMATION_HH #define SEARCH_PATHRECOMBINATIONAPPROXIMATION_HH -#include +#include #include "PathRecombination.hh" namespace Search { diff --git a/src/Search/AdvancedTreeSearch/PrefixFilter.hh b/src/Search/AdvancedTreeSearch/PrefixFilter.hh index 8497bca11..0a5dc31e5 100644 --- a/src/Search/AdvancedTreeSearch/PrefixFilter.hh +++ b/src/Search/AdvancedTreeSearch/PrefixFilter.hh @@ -18,7 +18,7 @@ #include #include -#include +#include #include "SearchSpaceHelpers.hh" namespace Core { diff --git a/src/Search/AdvancedTreeSearch/SearchNetworkTransformation.hh b/src/Search/AdvancedTreeSearch/SearchNetworkTransformation.hh index 3ff961d1e..91b763197 100644 --- a/src/Search/AdvancedTreeSearch/SearchNetworkTransformation.hh +++ b/src/Search/AdvancedTreeSearch/SearchNetworkTransformation.hh @@ -18,8 +18,8 @@ /** * This file contains a collection of helper classes useful for transformations of the search network * */ -#include -#include +#include +#include namespace AdvancedTreeSearch { struct StateWithSuccessors { diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.cc b/src/Search/AdvancedTreeSearch/SearchSpace.cc index dd80c396b..0fdeb320e 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.cc +++ b/src/Search/AdvancedTreeSearch/SearchSpace.cc @@ -22,10 +22,9 @@ #include #include #include +#include #include - -#include -#include +#include #include "AcousticLookAhead.hh" #include "PrefixFilter.hh" @@ -388,7 +387,7 @@ StaticSearchAutomaton::StaticSearchAutomaton(Core::Configuration config, Core::R : Precursor(config), hmmLength(acousticModel->hmmTopologySet()->getDefault().nPhoneStates() * acousticModel->hmmTopologySet()->getDefault().nSubStates()), 
minimized(paramBuildMinimizedTreeFromScratch(config)), - network(config, acousticModel, lexicon, std::bind(&createTreeBuilder, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)), + network(config, acousticModel, lexicon, std::bind(&Module_::createTreeBuilder, &Search::Module::instance(), std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)), prefixFilter(nullptr), treeBuilderType_(static_cast(paramTreeBuilderType(config))), acousticModel_(acousticModel), @@ -410,7 +409,7 @@ void StaticSearchAutomaton::buildNetwork() { if (!network.read(transformation)) { log() << "persistent network image could not be loaded, building it"; - std::unique_ptr builder = createTreeBuilder(treeBuilderType_, config, *lexicon_, *acousticModel_, network); + std::unique_ptr builder = Search::Module::instance().createTreeBuilder(treeBuilderType_, config, *lexicon_, *acousticModel_, network); if (not builder) { network.build(); network.cleanup(); diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.hh b/src/Search/AdvancedTreeSearch/SearchSpace.hh index ea52e8e39..c487c68ad 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.hh +++ b/src/Search/AdvancedTreeSearch/SearchSpace.hh @@ -20,11 +20,12 @@ #include #include +#include #include -#include -#include -#include -#include +#include +#include +#include +#include #include "Helpers.hh" #include "LanguageModelLookahead.hh" diff --git a/src/Search/AdvancedTreeSearch/SearchSpaceHelpers.hh b/src/Search/AdvancedTreeSearch/SearchSpaceHelpers.hh index 846826cff..60dc4b0b0 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpaceHelpers.hh +++ b/src/Search/AdvancedTreeSearch/SearchSpaceHelpers.hh @@ -16,7 +16,7 @@ #define SEARCH_SPACE_HELPERS #include -#include +#include #include "LanguageModelLookahead.hh" #include "TraceManager.hh" diff --git 
a/src/Search/AdvancedTreeSearch/Trace.hh b/src/Search/AdvancedTreeSearch/Trace.hh index ecf794f25..b99b55a1d 100644 --- a/src/Search/AdvancedTreeSearch/Trace.hh +++ b/src/Search/AdvancedTreeSearch/Trace.hh @@ -17,8 +17,8 @@ #include #include +#include #include -#include #include #include "PathTrace.hh" diff --git a/src/Search/LanguageModelLookahead.hh b/src/Search/LanguageModelLookahead.hh index 29d3540ea..faa83bb10 100644 --- a/src/Search/LanguageModelLookahead.hh +++ b/src/Search/LanguageModelLookahead.hh @@ -25,7 +25,7 @@ #include #include #include -#include +#include "StateTree.hh" namespace Search { diff --git a/src/Search/Makefile b/src/Search/Makefile index cd955cc8d..bba7b5f3b 100644 --- a/src/Search/Makefile +++ b/src/Search/Makefile @@ -14,8 +14,13 @@ LIBSPRINTSEARCH_O = \ $(OBJDIR)/LatticeHandler.o \ $(OBJDIR)/LanguageModelLookahead.o \ $(OBJDIR)/Module.o \ + $(OBJDIR)/PersistentStateTree.o \ $(OBJDIR)/Search.o \ + $(OBJDIR)/StateTree.o \ + $(OBJDIR)/StateTreeIo.o \ $(OBJDIR)/Traceback.o \ + $(OBJDIR)/TreeBuilder.o \ + $(OBJDIR)/TreeStructure.o \ $(OBJDIR)/WordConditionedTreeSearch.o CHECK_O = $(OBJDIR)/check.o \ @@ -24,13 +29,13 @@ CHECK_O = $(OBJDIR)/check.o \ ../Fsa/libSprintFsa.$(a) \ ../Core/libSprintCore.$(a) + ifdef MODULE_SEARCH_MBR LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskSearch.o LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskAStarSearch.o LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskNBestListSearch.o LIBSPRINTSEARCH_O += $(OBJDIR)/MinimumBayesRiskSearchUtil.o endif -SUBDIRS += TreeBuilder ifdef MODULE_SEARCH_WFST SUBDIRS += Wfst endif @@ -55,9 +60,6 @@ libSprintSearch.$(a): $(LIBSPRINTSEARCH_O) check$(exe): $(CHECK_O) $(LD) $(CHECK_O) -o check$(exe) $(LDFLAGS) -TreeBuilder: - $(MAKE) -C $@ libSprintTreeBuilder.$(a) - Wfst: $(MAKE) -C $@ libSprintSearchWfst.$(a) diff --git a/src/Search/Module.cc b/src/Search/Module.cc index 666ce408d..456e237ea 100644 --- a/src/Search/Module.cc +++ b/src/Search/Module.cc @@ -16,6 +16,7 @@ #include 
#include #include +#include "TreeBuilder.hh" #ifdef MODULE_SEARCH_WFST #include #include @@ -32,6 +33,21 @@ using namespace Search; Module_::Module_() { } +std::unique_ptr Module_::createTreeBuilder(TreeBuilderType treeBuilderType, Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) const { + switch (treeBuilderType) { + case Search::TreeBuilderType::classicHmm: { // Use StateTree.hh + return std::unique_ptr(nullptr); + } break; + case Search::TreeBuilderType::minimizedHmm: { // Use TreeStructure.hh + return std::unique_ptr(new MinimizedTreeBuilder(config, lexicon, acousticModel, network, initialize)); + } break; + case Search::TreeBuilderType::ctc: { + return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network, initialize)); + } break; + default: defect(); + } +} + SearchAlgorithm* Module_::createRecognizer(SearchType type, const Core::Configuration& config) const { SearchAlgorithm* recognizer = 0; switch (type) { diff --git a/src/Search/Module.hh b/src/Search/Module.hh index 86be91222..d72d43999 100644 --- a/src/Search/Module.hh +++ b/src/Search/Module.hh @@ -18,11 +18,20 @@ #include #include +#include "TreeBuilder.hh" + namespace Search { class SearchAlgorithm; class LatticeHandler; +enum class TreeBuilderType { + previousBehavior = 0, + classicHmm = 1, + minimizedHmm = 2, + ctc = 3, +}; + enum SearchType { WordConditionedTreeSearchType, AdvancedTreeSearch, @@ -34,8 +43,9 @@ class Module_ { public: Module_(); - SearchAlgorithm* createRecognizer(SearchType type, const Core::Configuration& config) const; - LatticeHandler* createLatticeHandler(const Core::Configuration& c) const; + std::unique_ptr createTreeBuilder(TreeBuilderType treeBuilderType, Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true) const; + SearchAlgorithm* 
createRecognizer(SearchType type, const Core::Configuration& config) const; + LatticeHandler* createLatticeHandler(const Core::Configuration& c) const; }; typedef Core::SingletonHolder Module; diff --git a/src/Search/TreeBuilder/PersistentStateTree.cc b/src/Search/PersistentStateTree.cc similarity index 99% rename from src/Search/TreeBuilder/PersistentStateTree.cc rename to src/Search/PersistentStateTree.cc index 8ee5d40e8..10e2bd029 100644 --- a/src/Search/TreeBuilder/PersistentStateTree.cc +++ b/src/Search/PersistentStateTree.cc @@ -19,13 +19,12 @@ #include #include +#include #include #include -#include "StateTree.hh" #include "TreeStructure.hh" - using namespace Search; using namespace Core; diff --git a/src/Search/TreeBuilder/PersistentStateTree.hh b/src/Search/PersistentStateTree.hh similarity index 99% rename from src/Search/TreeBuilder/PersistentStateTree.hh rename to src/Search/PersistentStateTree.hh index 615af28cd..511ed8596 100644 --- a/src/Search/TreeBuilder/PersistentStateTree.hh +++ b/src/Search/PersistentStateTree.hh @@ -32,11 +32,11 @@ struct MyStandardValueHash { }; class AbstractTreeBuilder; -enum class TreeBuilderType; namespace Search { class HMMStateNetwork; class StateTree; +enum class TreeBuilderType; class PersistentStateTree : public Core::ReferenceCounted { public: @@ -148,4 +148,4 @@ private: bool read(Core::MappedArchiveReader reader); }; } // namespace Search -#endif // STATETREECOMPRESSION_H +#endif // PERSISTENT_STATE_TREE_H diff --git a/src/Search/TreeBuilder/StateTree.cc b/src/Search/StateTree.cc similarity index 100% rename from src/Search/TreeBuilder/StateTree.cc rename to src/Search/StateTree.cc diff --git a/src/Search/TreeBuilder/StateTree.hh b/src/Search/StateTree.hh similarity index 100% rename from src/Search/TreeBuilder/StateTree.hh rename to src/Search/StateTree.hh diff --git a/src/Search/TreeBuilder/StateTreeIo.cc b/src/Search/StateTreeIo.cc similarity index 100% rename from src/Search/TreeBuilder/StateTreeIo.cc rename to 
src/Search/StateTreeIo.cc diff --git a/src/Search/TreeBuilder/StateTreeIo.hh b/src/Search/StateTreeIo.hh similarity index 100% rename from src/Search/TreeBuilder/StateTreeIo.hh rename to src/Search/StateTreeIo.hh diff --git a/src/Search/TreeBuilder/TreeBuilder.cc b/src/Search/TreeBuilder.cc similarity index 99% rename from src/Search/TreeBuilder/TreeBuilder.cc rename to src/Search/TreeBuilder.cc index 5cd2ad9c1..e99f034d6 100644 --- a/src/Search/TreeBuilder/TreeBuilder.cc +++ b/src/Search/TreeBuilder.cc @@ -16,8 +16,8 @@ #include #include #include +#include #include -#include "StateTree.hh" #include "PersistentStateTree.hh" using namespace Search; diff --git a/src/Search/TreeBuilder/TreeBuilder.hh b/src/Search/TreeBuilder.hh similarity index 92% rename from src/Search/TreeBuilder/TreeBuilder.hh rename to src/Search/TreeBuilder.hh index 3dc841551..ee68f9aad 100644 --- a/src/Search/TreeBuilder/TreeBuilder.hh +++ b/src/Search/TreeBuilder.hh @@ -16,10 +16,10 @@ #define TREEBUILDER_HH #include +#include #include #include #include "PersistentStateTree.hh" -#include "StateTree.hh" namespace Bliss { class Lexicon; @@ -33,13 +33,6 @@ namespace Core { class Configuration; } -enum class TreeBuilderType { - previousBehavior = 0, - classicHmm = 1, - minimizedHmm = 2, - ctc = 3, -}; - class AbstractTreeBuilder : public Core::Component { public: typedef u32 StateId; @@ -292,18 +285,4 @@ protected: void addWordBoundaryStates(); }; -std::unique_ptr createTreeBuilder(TreeBuilderType treeBuilderType, Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true) { - switch (treeBuilderType) { - case TreeBuilderType::classicHmm: { // Use StateTree.hh - return std::unique_ptr(nullptr); - } break; - case TreeBuilderType::minimizedHmm: { // Use TreeStructure.hh - return std::unique_ptr(new MinimizedTreeBuilder(config, lexicon, acousticModel, network, initialize)); - } break; - case 
TreeBuilderType::ctc: { - return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network, initialize)); - } break; - default: defect(); - } -} #endif diff --git a/src/Search/TreeBuilder/Makefile b/src/Search/TreeBuilder/Makefile deleted file mode 100644 index 09ce6fbba..000000000 --- a/src/Search/TreeBuilder/Makefile +++ /dev/null @@ -1,28 +0,0 @@ -#!gmake - -TOPDIR = ../../.. - -include $(TOPDIR)/Makefile.cfg - -# ----------------------------------------------------------------------------- - -SUBDIRS = -TARGETS = libSprintTreeBuilder.$(a) - -LIBSPRINTTREEBUILDER_O = $(OBJDIR)/PersistentStateTree.o \ - $(OBJDIR)/StateTree.o \ - $(OBJDIR)/StateTreeIo.o \ - $(OBJDIR)/TreeBuilder.o \ - $(OBJDIR)/TreeStructure.o - -# ----------------------------------------------------------------------------- - -all: $(TARGETS) - -libSprintTreeBuilder.$(a): $(LIBSPRINTTREEBUILDER_O) - $(MAKELIB) $@ $^ - -include $(TOPDIR)/Rules.make - -sinclude $(LIBSPRINTTREEBUILDER_O:.o=.d) -include $(patsubst %.o,%.d,$(filter %.o,$(CHECK_O))) \ No newline at end of file diff --git a/src/Search/TreeBuilder/TreeStructure.cc b/src/Search/TreeStructure.cc similarity index 100% rename from src/Search/TreeBuilder/TreeStructure.cc rename to src/Search/TreeStructure.cc diff --git a/src/Search/TreeBuilder/TreeStructure.hh b/src/Search/TreeStructure.hh similarity index 99% rename from src/Search/TreeBuilder/TreeStructure.hh rename to src/Search/TreeStructure.hh index 127783bdf..4329bc637 100644 --- a/src/Search/TreeBuilder/TreeStructure.hh +++ b/src/Search/TreeStructure.hh @@ -18,8 +18,8 @@ #include #include +#include #include -#include "StateTree.hh" #define inline_ __attribute__((always_inline)) inline diff --git a/src/Search/TreeBuilder/TreeWalker.hh b/src/Search/TreeWalker.hh similarity index 100% rename from src/Search/TreeBuilder/TreeWalker.hh rename to src/Search/TreeWalker.hh diff --git a/src/Search/Wfst/Makefile b/src/Search/Wfst/Makefile index 4acaf1a6f..a2890549c 100644 --- 
a/src/Search/Wfst/Makefile +++ b/src/Search/Wfst/Makefile @@ -61,7 +61,6 @@ CHECK_CTRANS_O = $(OBJDIR)/check_ctrans.o \ CHECK_O = $(OBJDIR)/check.o \ libSprintSearchWfst.$(a) \ ../libSprintSearch.$(a) \ - ../TreeBuilder/libSprintTreeBuilder.$(a) \ ../../Speech/libSprintSpeech.$(a) \ ../../Am/libSprintAm.$(a) \ ../../Mm/libSprintMm.$(a) \ diff --git a/src/Search/Wfst/StateTree.hh b/src/Search/Wfst/StateTree.hh index d67d37c4e..4eae2c40b 100644 --- a/src/Search/Wfst/StateTree.hh +++ b/src/Search/Wfst/StateTree.hh @@ -18,7 +18,7 @@ #include #include #include -#include +#include namespace Search { namespace Wfst { diff --git a/src/Search/WordConditionedTreeSearch.cc b/src/Search/WordConditionedTreeSearch.cc index 4cbce7b10..f37da4646 100644 --- a/src/Search/WordConditionedTreeSearch.cc +++ b/src/Search/WordConditionedTreeSearch.cc @@ -29,10 +29,10 @@ #include #include #include -#include #include #include "Histogram.hh" #include "LanguageModelLookahead.hh" +#include "StateTree.hh" #ifdef MODULE_LM_FSA #include diff --git a/src/Search/check.cc b/src/Search/check.cc index 162c5de3c..58db27399 100644 --- a/src/Search/check.cc +++ b/src/Search/check.cc @@ -14,7 +14,7 @@ */ #include #include -#include +#include "StateTree.hh" #ifdef MODULE_SEARCH_WFST #include #include diff --git a/src/Speech/Makefile b/src/Speech/Makefile index ba4dc2752..4b5ab3528 100644 --- a/src/Speech/Makefile +++ b/src/Speech/Makefile @@ -46,7 +46,6 @@ CHECK_O = $(OBJDIR)/check.o \ ../Mm/libSprintMm.$(a) \ ../Mc/libSprintMc.$(a) \ ../Search/libSprintSearch.$(a) \ - ../Search/TreeBuilder/libSprintTreeBuilder.$(a) \ ../Bliss/libSprintBliss.$(a) \ ../Flow/libSprintFlow.$(a) \ ../Fsa/libSprintFsa.$(a) \ diff --git a/src/Test/Makefile b/src/Test/Makefile index f745a9d20..b9f3f2fcf 100644 --- a/src/Test/Makefile +++ b/src/Test/Makefile @@ -62,7 +62,6 @@ UNIT_TEST_O = $(OBJDIR)/UnitTester.o $(TEST_O) \ ../Core/libSprintCore.$(a)\ ../Speech/libSprintSpeech.$(a) \ ../Search/libSprintSearch.$(a) \ - 
../Search/TreeBuilder/libSprintTreeBuilder.$(a) \ ../Lattice/libSprintLattice.$(a) \ ../Am/libSprintAm.$(a) \ ../Mm/libSprintMm.$(a) \ diff --git a/src/Tools/Archiver/Makefile b/src/Tools/Archiver/Makefile index 2c9746d4e..e9000e78a 100644 --- a/src/Tools/Archiver/Makefile +++ b/src/Tools/Archiver/Makefile @@ -12,7 +12,6 @@ TARGETS = archiver$(exe) ARCHIVER_O = $(OBJDIR)/Archiver.o \ ../../Speech/libSprintSpeech.$(a) \ ../../Search/libSprintSearch.$(a) \ - ../../Search/TreeBuilder/libSprintTreeBuilder.$(a) \ ../../Search/AdvancedTreeSearch/libSprintAdvancedTreeSearch.$(a) \ ../../Lattice/libSprintLattice.$(a) \ ../../Lm/libSprintLm.$(a) \ diff --git a/src/Tools/NnTrainer/Makefile b/src/Tools/NnTrainer/Makefile index c198c0bb8..fc71a4bd2 100644 --- a/src/Tools/NnTrainer/Makefile +++ b/src/Tools/NnTrainer/Makefile @@ -25,7 +25,6 @@ NN_TRAINER_O = $(OBJDIR)/NnTrainer.o \ ../../Mm/libSprintMm.$(a) \ ../../Nn/libSprintNn.$(a) \ ../../Search/libSprintSearch.$(a) \ - ../../Search/TreeBuilder/libSprintTreeBuilder.$(a) \ ../../Signal/libSprintSignal.$(a) \ ../../Speech/libSprintSpeech.$(a) From d8149f87a8c9e47bf1826194e6f58169361164c4 Mon Sep 17 00:00:00 2001 From: Larissa Date: Mon, 24 Feb 2025 15:04:01 +0100 Subject: [PATCH 14/24] Introduce RnaTreeBuilder subclass and allow-label-loop config-parameter --- src/Search/AdvancedTreeSearch/SearchSpace.cc | 14 ++++++++++++-- src/Search/AdvancedTreeSearch/SearchSpace.hh | 1 + src/Search/Module.cc | 7 +++++-- src/Search/Module.hh | 3 ++- src/Search/PersistentStateTree.hh | 2 +- src/Search/TreeBuilder.cc | 15 +++++++++++---- src/Search/TreeBuilder.hh | 10 +++++++++- 7 files changed, 41 insertions(+), 11 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.cc b/src/Search/AdvancedTreeSearch/SearchSpace.cc index bb36df687..dc114b95b 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.cc +++ b/src/Search/AdvancedTreeSearch/SearchSpace.cc @@ -145,6 +145,11 @@ const Core::ParameterChoice paramTreeBuilderType( "which 
tree builder to use", static_cast(TreeBuilderType::previousBehavior)); +const Core::ParameterBool paramLabelLoop( + "allow-label-loop", + "allow label loops in the search tree", + true); + const Core::ParameterBool paramConditionPredecessorWord( "condition-on-predecessor-word", "", @@ -388,7 +393,7 @@ StaticSearchAutomaton::StaticSearchAutomaton(Core::Configuration config, Core::R : Precursor(config), hmmLength(acousticModel->hmmTopologySet()->getDefault().nPhoneStates() * acousticModel->hmmTopologySet()->getDefault().nSubStates()), minimized(paramBuildMinimizedTreeFromScratch(config)), - network(config, acousticModel, lexicon, std::bind(&Module_::createTreeBuilder, &Search::Module::instance(), std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6)), + network(config, acousticModel, lexicon, std::bind(&Module_::createTreeBuilder, &Search::Module::instance(), std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6, std::placeholders::_7)), prefixFilter(nullptr), treeBuilderType_(static_cast(paramTreeBuilderType(config))), acousticModel_(acousticModel), @@ -396,6 +401,11 @@ StaticSearchAutomaton::StaticSearchAutomaton(Core::Configuration config, Core::R if (treeBuilderType_ == TreeBuilderType::previousBehavior) { treeBuilderType_ = minimized ? TreeBuilderType::minimizedHmm : TreeBuilderType::classicHmm; } + bool defaultUsed; + labelLoop_ = paramLabelLoop(config, true, &defaultUsed); + if (treeBuilderType_ == TreeBuilderType::rna) { + labelLoop_ = defaultUsed ? 
false : labelLoop_; // If allow-label-loop is not set in config, set it to false for RNA topology + } } StaticSearchAutomaton::~StaticSearchAutomaton() { @@ -410,7 +420,7 @@ void StaticSearchAutomaton::buildNetwork() { if (!network.read(transformation)) { log() << "persistent network image could not be loaded, building it"; - std::unique_ptr builder = Search::Module::instance().createTreeBuilder(treeBuilderType_, config, *lexicon_, *acousticModel_, network); + std::unique_ptr builder = Search::Module::instance().createTreeBuilder(treeBuilderType_, config, *lexicon_, *acousticModel_, network, labelLoop_); if (not builder) { network.build(); network.cleanup(); diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.hh b/src/Search/AdvancedTreeSearch/SearchSpace.hh index c487c68ad..5e85f9b43 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.hh +++ b/src/Search/AdvancedTreeSearch/SearchSpace.hh @@ -119,6 +119,7 @@ public: protected: TreeBuilderType treeBuilderType_; + bool labelLoop_; private: Core::Ref acousticModel_; diff --git a/src/Search/Module.cc b/src/Search/Module.cc index 456e237ea..b42814ff6 100644 --- a/src/Search/Module.cc +++ b/src/Search/Module.cc @@ -33,7 +33,7 @@ using namespace Search; Module_::Module_() { } -std::unique_ptr Module_::createTreeBuilder(TreeBuilderType treeBuilderType, Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) const { +std::unique_ptr Module_::createTreeBuilder(TreeBuilderType treeBuilderType, Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool allowLabelLoop, bool initialize) const { switch (treeBuilderType) { case Search::TreeBuilderType::classicHmm: { // Use StateTree.hh return std::unique_ptr(nullptr); @@ -42,7 +42,10 @@ std::unique_ptr Module_::createTreeBuilder(TreeBuilderType return std::unique_ptr(new 
MinimizedTreeBuilder(config, lexicon, acousticModel, network, initialize)); } break; case Search::TreeBuilderType::ctc: { - return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network, initialize)); + return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network, allowLabelLoop, initialize)); + } break; + case Search::TreeBuilderType::rna: { + return std::unique_ptr(new RnaTreeBuilder(config, lexicon, acousticModel, network, allowLabelLoop, initialize)); } break; default: defect(); } diff --git a/src/Search/Module.hh b/src/Search/Module.hh index d72d43999..00d0cdef9 100644 --- a/src/Search/Module.hh +++ b/src/Search/Module.hh @@ -30,6 +30,7 @@ enum class TreeBuilderType { classicHmm = 1, minimizedHmm = 2, ctc = 3, + rna = 4 }; enum SearchType { @@ -43,7 +44,7 @@ class Module_ { public: Module_(); - std::unique_ptr createTreeBuilder(TreeBuilderType treeBuilderType, Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true) const; + std::unique_ptr createTreeBuilder(TreeBuilderType treeBuilderType, Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool allowLabelLoop = true, bool initialize = true) const; SearchAlgorithm* createRecognizer(SearchType type, const Core::Configuration& config) const; LatticeHandler* createLatticeHandler(const Core::Configuration& c) const; }; diff --git a/src/Search/PersistentStateTree.hh b/src/Search/PersistentStateTree.hh index 511ed8596..1d4fdcbbe 100644 --- a/src/Search/PersistentStateTree.hh +++ b/src/Search/PersistentStateTree.hh @@ -40,7 +40,7 @@ enum class TreeBuilderType; class PersistentStateTree : public Core::ReferenceCounted { public: - using TreeBuilderFactory = std::function(TreeBuilderType, Core::Configuration, const Bliss::Lexicon&, const Am::AcousticModel&, PersistentStateTree&, bool)>; + 
using TreeBuilderFactory = std::function(TreeBuilderType, Core::Configuration, const Bliss::Lexicon&, const Am::AcousticModel&, PersistentStateTree&, bool, bool)>; ///@param lexicon This must be given if the resulting exits are supposed to be functional PersistentStateTree(Core::Configuration config, Core::Ref acousticModel, Bliss::LexiconRef lexicon, TreeBuilderFactory treeBuilderFactory); diff --git a/src/Search/TreeBuilder.cc b/src/Search/TreeBuilder.cc index 7dad9dbf7..e8fa55b83 100644 --- a/src/Search/TreeBuilder.cc +++ b/src/Search/TreeBuilder.cc @@ -1201,9 +1201,9 @@ inline void MinimizedTreeBuilder::mapSuccessors(const std::set& success // -------------------- CtcTreeBuilder -------------------- -CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize, bool enableLabelLoop) +CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool allowLabelLoop, bool initialize) : AbstractTreeBuilder(config, lexicon, acousticModel, network), - labelLoops_(enableLabelLoop) { + labelLoop_(allowLabelLoop) { auto iters = lexicon.phonemeInventory()->phonemes(); for (auto it = iters.first; it != iters.second; ++it) { require(not(*it)->isContextDependent()); // Context dependent labels are not supported @@ -1228,7 +1228,7 @@ CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& } std::unique_ptr CtcTreeBuilder::newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) { - return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network, initialize, labelLoops_)); + return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network, labelLoop_, initialize)); } void 
CtcTreeBuilder::build() { @@ -1347,7 +1347,7 @@ StateId CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronuncia // Add new (non-blank) state currentState = extendState(currentState, desc); - if (labelLoops_) { + if (labelLoop_) { // Add loop for this state addTransition(currentState, currentState); } @@ -1405,3 +1405,10 @@ void CtcTreeBuilder::addWordBoundaryStates() { // Add loop for this blank state addTransition(blankBefore, blankBefore); } + +// -------------------- RnaTreeBuilder -------------------- + +RnaTreeBuilder::RnaTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool allowLabelLoop, bool initialize) + : CtcTreeBuilder(config, lexicon, acousticModel, network, allowLabelLoop, initialize) {} + + diff --git a/src/Search/TreeBuilder.hh b/src/Search/TreeBuilder.hh index ee68f9aad..e2ca9fb8f 100644 --- a/src/Search/TreeBuilder.hh +++ b/src/Search/TreeBuilder.hh @@ -249,7 +249,7 @@ protected: class CtcTreeBuilder : public AbstractTreeBuilder { public: - CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); + CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool allowLabelLoop = true, bool initialize = true); virtual ~CtcTreeBuilder() = default; virtual std::unique_ptr newInstance(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); @@ -258,6 +258,8 @@ public: virtual void build(); protected: + bool labelLoop_; + StateId wordBoundaryRoot_; Search::StateTree::StateDesc blankDesc_; Am::AllophoneStateIndex blankAllophoneStateIndex_; @@ -285,4 +287,10 @@ protected: void addWordBoundaryStates(); }; +class RnaTreeBuilder : 
public CtcTreeBuilder { +public: + RnaTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool allowLabelLoop = false, bool initialize = true); + virtual ~RnaTreeBuilder() = default; +}; + #endif From 2e9c51bd0a39b7b0ed5f41b54dbb90ad1ecf2aec Mon Sep 17 00:00:00 2001 From: Larissa Date: Sat, 1 Mar 2025 17:29:52 +0100 Subject: [PATCH 15/24] Adjust AdvancedTreeSearch for CtcTreeBuilder support --- src/Search/AdvancedTreeSearch/SearchSpace.cc | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.cc b/src/Search/AdvancedTreeSearch/SearchSpace.cc index f3520293d..fdb7719d0 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.cc +++ b/src/Search/AdvancedTreeSearch/SearchSpace.cc @@ -478,7 +478,7 @@ void StaticSearchAutomaton::buildDepths(bool onlyFromRoot) { for (u32 a = 1; a < stateDepths.size(); ++a) { if (stateDepths[a] != Core::Type::min && stateDepths[a] != Core::Type::max) { for (HMMStateNetwork::SuccessorIterator it = network.structure.successors(a); it; ++it) { - if (!it.isLabel()) { + if (!it.isLabel() and *it != a) { verify(stateDepths[*it] > stateDepths[a]); } } @@ -515,7 +515,7 @@ int StaticSearchAutomaton::fillStateDepths(StateId state, int depth) { localDepth = 0; for (HMMStateNetwork::SuccessorIterator it = network.structure.successors(state); it; ++it) { - if (not it.isLabel()) { + if (not it.isLabel() and *it != state) { int d = fillStateDepths(*it, depth + 1); if (d > localDepth) { @@ -740,7 +740,9 @@ void StaticSearchAutomaton::buildBatches() { network.dumpDotGraph(paramDumpDotGraph(config), stateDepths); // Print some useful statistics about pushed and unpushed labels - verify(!network.unpushedCoarticulatedRootStates.empty()); + if (treeBuilderType_ != TreeBuilderType::ctc) { + verify(!network.unpushedCoarticulatedRootStates.empty()); + } u32 unpushedLabels = 0; u32 pushedLabels = 0; @@ 
-751,8 +753,10 @@ void StaticSearchAutomaton::buildBatches() { bool isUnpushed = network.unpushedCoarticulatedRootStates.count(transit) || transit == network.ciRootState || transit == network.rootState; if (isUnpushed) { ++unpushedLabels; - std::map>::iterator it = network.rootTransitDescriptions.find(transit); - verify(it != network.rootTransitDescriptions.end()); + if (treeBuilderType_ != TreeBuilderType::ctc) { + std::map>::iterator it = network.rootTransitDescriptions.find(transit); + verify(it != network.rootTransitDescriptions.end()); + } } else { ++pushedLabels; From 9deba08e28caa30d2970313f8b42aa1a0cb1ed09 Mon Sep 17 00:00:00 2001 From: Larissa Date: Wed, 19 Mar 2025 14:43:01 +0100 Subject: [PATCH 16/24] Move allow-label-loop parameter to CtcTreeBuilder and RnaTreeBuilder --- src/Search/Module.cc | 4 ++++ src/Search/Module.hh | 1 + src/Search/PersistentStateTree.hh | 3 +-- src/Search/TreeBuilder.cc | 28 +++++++++++++++++++++++++--- src/Search/TreeBuilder.hh | 12 ++++++++++++ 5 files changed, 43 insertions(+), 5 deletions(-) diff --git a/src/Search/Module.cc b/src/Search/Module.cc index 2fb7dc482..7dbaefdd5 100644 --- a/src/Search/Module.cc +++ b/src/Search/Module.cc @@ -37,6 +37,7 @@ const Core::Choice choiceTreeBuilderType( "classic-hmm", static_cast(TreeBuilderType::classicHmm), "minimized-hmm", static_cast(TreeBuilderType::minimizedHmm), "ctc", static_cast(TreeBuilderType::ctc), + "rna", static_cast(TreeBuilderType::rna), Core::Choice::endMark()); const Core::ParameterChoice paramTreeBuilderType( @@ -57,6 +58,9 @@ std::unique_ptr Module_::createTreeBuilder(Core::Configurat case TreeBuilderType::ctc: { return std::unique_ptr(new CtcTreeBuilder(config, lexicon, acousticModel, network, initialize)); } break; + case Search::TreeBuilderType::rna: { + return std::unique_ptr(new RnaTreeBuilder(config, lexicon, acousticModel, network, initialize)); + } break; default: defect(); } } diff --git a/src/Search/Module.hh b/src/Search/Module.hh index 
c50b00afe..88079a68c 100644 --- a/src/Search/Module.hh +++ b/src/Search/Module.hh @@ -30,6 +30,7 @@ enum TreeBuilderType { classicHmm, minimizedHmm, ctc, + rna, }; enum SearchType { diff --git a/src/Search/PersistentStateTree.hh b/src/Search/PersistentStateTree.hh index 1d4fdcbbe..a200ce9b2 100644 --- a/src/Search/PersistentStateTree.hh +++ b/src/Search/PersistentStateTree.hh @@ -36,11 +36,10 @@ class AbstractTreeBuilder; namespace Search { class HMMStateNetwork; class StateTree; -enum class TreeBuilderType; class PersistentStateTree : public Core::ReferenceCounted { public: - using TreeBuilderFactory = std::function(TreeBuilderType, Core::Configuration, const Bliss::Lexicon&, const Am::AcousticModel&, PersistentStateTree&, bool, bool)>; + using TreeBuilderFactory = std::function(Core::Configuration, const Bliss::Lexicon&, const Am::AcousticModel&, PersistentStateTree&, bool)>; ///@param lexicon This must be given if the resulting exits are supposed to be functional PersistentStateTree(Core::Configuration config, Core::Ref acousticModel, Bliss::LexiconRef lexicon, TreeBuilderFactory treeBuilderFactory); diff --git a/src/Search/TreeBuilder.cc b/src/Search/TreeBuilder.cc index d0474b358..4d6a66d6b 100644 --- a/src/Search/TreeBuilder.cc +++ b/src/Search/TreeBuilder.cc @@ -1205,8 +1205,14 @@ inline void MinimizedTreeBuilder::mapSuccessors(const std::set& success // -------------------- CtcTreeBuilder -------------------- +const Core::ParameterBool CtcTreeBuilder::paramLabelLoop( + "allow-label-loop", + "allow label loops in the search tree", + true); + CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) - : AbstractTreeBuilder(config, lexicon, acousticModel, network) { + : AbstractTreeBuilder(config, lexicon, acousticModel, network), + labelLoop_(paramLabelLoop(config)) { auto iters = lexicon.phonemeInventory()->phonemes(); for (auto it = 
iters.first; it != iters.second; ++it) { require(not(*it)->isContextDependent()); // Context dependent labels are not supported @@ -1349,8 +1355,12 @@ StateId CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronuncia // Add new (non-blank) state currentState = extendState(currentState, desc); - // Add loop for this state - addTransition(currentState, currentState); + + if (labelLoop_) { + // Add loop for this state + addTransition(currentState, currentState); + } + // Add transition from previous non-blank state to this state, allowing to skip the blank state in-between these two if (prevNonBlankState != invalidTreeNodeIndex) { addTransition(prevNonBlankState, currentState); @@ -1404,3 +1414,15 @@ void CtcTreeBuilder::addWordBoundaryStates() { // Add loop for this blank state addTransition(blankBefore, blankBefore); } + +// -------------------- RnaTreeBuilder -------------------- + +const Core::ParameterBool RnaTreeBuilder::paramLabelLoop( + "allow-label-loop", + "allow label loops in the search tree", + false); + +RnaTreeBuilder::RnaTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) + : CtcTreeBuilder(config, lexicon, acousticModel, network, initialize) { + this->labelLoop_ = paramLabelLoop(config); +} diff --git a/src/Search/TreeBuilder.hh b/src/Search/TreeBuilder.hh index 64ce5bcd3..53bfb4c16 100644 --- a/src/Search/TreeBuilder.hh +++ b/src/Search/TreeBuilder.hh @@ -257,7 +257,11 @@ public: // Build a new persistent state network. 
virtual void build(); + static const Core::ParameterBool paramLabelLoop; + protected: + bool labelLoop_; + StateId wordBoundaryRoot_; Search::StateTree::StateDesc blankDesc_; Am::AllophoneStateIndex blankAllophoneStateIndex_; @@ -285,4 +289,12 @@ protected: void addWordBoundaryStates(); }; +class RnaTreeBuilder : public CtcTreeBuilder { +public: + RnaTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true); + virtual ~RnaTreeBuilder() = default; + + static const Core::ParameterBool paramLabelLoop; +}; + #endif From b23b180964d4df0550389bacdd2a8d68e686b760 Mon Sep 17 00:00:00 2001 From: Larissa Date: Wed, 19 Mar 2025 17:32:02 +0100 Subject: [PATCH 17/24] Adjust verify statements --- src/Search/AdvancedTreeSearch/SearchSpace.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.cc b/src/Search/AdvancedTreeSearch/SearchSpace.cc index eee18ab1d..96b87e120 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.cc +++ b/src/Search/AdvancedTreeSearch/SearchSpace.cc @@ -723,9 +723,7 @@ void StaticSearchAutomaton::buildBatches() { network.dumpDotGraph(paramDumpDotGraph(config), stateDepths); // Print some useful statistics about pushed and unpushed labels - if (treeBuilderType_ != TreeBuilderType::ctc) { - verify(!network.unpushedCoarticulatedRootStates.empty()); - } + verify(network.coarticulatedRootStates.empty() || !network.unpushedCoarticulatedRootStates.empty()); u32 unpushedLabels = 0; u32 pushedLabels = 0; @@ -736,7 +734,7 @@ void StaticSearchAutomaton::buildBatches() { bool isUnpushed = network.unpushedCoarticulatedRootStates.count(transit) || transit == network.ciRootState || transit == network.rootState; if (isUnpushed) { ++unpushedLabels; - if (treeBuilderType_ != TreeBuilderType::ctc) { + if (!network.rootTransitDescriptions.empty()) { std::map>::iterator it = 
network.rootTransitDescriptions.find(transit); verify(it != network.rootTransitDescriptions.end()); } From a6d6c3aa1b8ec2d388c6b5a82d5ec0451d27832e Mon Sep 17 00:00:00 2001 From: Larissa Date: Thu, 20 Mar 2025 13:20:18 +0100 Subject: [PATCH 18/24] Introduce parameter to (dis-)allow blank loops --- src/Search/TreeBuilder.cc | 15 ++++++++++++--- src/Search/TreeBuilder.hh | 2 ++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/Search/TreeBuilder.cc b/src/Search/TreeBuilder.cc index 4d6a66d6b..2bb2a5a2a 100644 --- a/src/Search/TreeBuilder.cc +++ b/src/Search/TreeBuilder.cc @@ -1210,9 +1210,15 @@ const Core::ParameterBool CtcTreeBuilder::paramLabelLoop( "allow label loops in the search tree", true); +const Core::ParameterBool CtcTreeBuilder::paramBlankLoop( + "allow-blank-loop", + "allow loops on the blank nodes in the search tree", + true); + CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) : AbstractTreeBuilder(config, lexicon, acousticModel, network), - labelLoop_(paramLabelLoop(config)) { + labelLoop_(paramLabelLoop(config)), + blankLoop_(paramBlankLoop(config)) { auto iters = lexicon.phonemeInventory()->phonemes(); for (auto it = iters.first; it != iters.second; ++it) { require(not(*it)->isContextDependent()); // Context dependent labels are not supported @@ -1371,8 +1377,11 @@ StateId CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronuncia if (not allophoneIsBlank and not isLastStateInLemma) { // Add blank state after the newly created state currentState = extendState(currentState, blankDesc_); - // Add loop for this blank state - addTransition(currentState, currentState); + + if (blankLoop_) { + // Add loop for this blank state + addTransition(currentState, currentState); + } } } } diff --git a/src/Search/TreeBuilder.hh b/src/Search/TreeBuilder.hh index 53bfb4c16..16a4445f1 100644 --- 
a/src/Search/TreeBuilder.hh +++ b/src/Search/TreeBuilder.hh @@ -258,9 +258,11 @@ public: virtual void build(); static const Core::ParameterBool paramLabelLoop; + static const Core::ParameterBool paramBlankLoop; protected: bool labelLoop_; + bool blankLoop_; StateId wordBoundaryRoot_; Search::StateTree::StateDesc blankDesc_; From b0cb5eee639f7d94562f33c8010e413ae206c3e4 Mon Sep 17 00:00:00 2001 From: Larissa Date: Thu, 20 Mar 2025 16:23:05 +0100 Subject: [PATCH 19/24] Add TODO concerning skip-transitions to AdvancedTreeSearch --- src/Search/AdvancedTreeSearch/SearchSpace.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.cc b/src/Search/AdvancedTreeSearch/SearchSpace.cc index 96b87e120..8434ea33b 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.cc +++ b/src/Search/AdvancedTreeSearch/SearchSpace.cc @@ -1273,6 +1273,8 @@ inline void SearchSpace::expandState(const Search::StateHypothesis& hyp) { if (loopScore < Core::Type::max) activateOrUpdateStateHypothesisLoop(hyp, loopScore); + // TODO: When using the CtcTreeBuilder or RnaTreeBuilder without label-loops, the tree includes skip-transitions + // over blank between two identical labels. These transitions should be explicitly ignored/disallowed here. 
// forward transition if ((state.successors & SingleSuccessorBatchMask) == SingleSuccessorBatchMask) { // The common case: Usually one hyp is connected to exactly one follower hyp From 7c98058ba51f40ee402172bbc0e77566c1c7f051 Mon Sep 17 00:00:00 2001 From: Larissa Date: Fri, 21 Mar 2025 19:46:50 +0100 Subject: [PATCH 20/24] Introduce parameter to force blank between two identical labels --- src/Search/AdvancedTreeSearch/SearchSpace.cc | 2 -- src/Search/TreeBuilder.cc | 26 +++++++++++++++++--- src/Search/TreeBuilder.hh | 3 +++ 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/Search/AdvancedTreeSearch/SearchSpace.cc b/src/Search/AdvancedTreeSearch/SearchSpace.cc index 8434ea33b..96b87e120 100644 --- a/src/Search/AdvancedTreeSearch/SearchSpace.cc +++ b/src/Search/AdvancedTreeSearch/SearchSpace.cc @@ -1273,8 +1273,6 @@ inline void SearchSpace::expandState(const Search::StateHypothesis& hyp) { if (loopScore < Core::Type::max) activateOrUpdateStateHypothesisLoop(hyp, loopScore); - // TODO: When using the CtcTreeBuilder or RnaTreeBuilder without label-loops, the tree includes skip-transitions - // over blank between two identical labels. These transitions should be explicitly ignored/disallowed here. 
// forward transition if ((state.successors & SingleSuccessorBatchMask) == SingleSuccessorBatchMask) { // The common case: Usually one hyp is connected to exactly one follower hyp diff --git a/src/Search/TreeBuilder.cc b/src/Search/TreeBuilder.cc index 2bb2a5a2a..236c4489a 100644 --- a/src/Search/TreeBuilder.cc +++ b/src/Search/TreeBuilder.cc @@ -1215,10 +1215,16 @@ const Core::ParameterBool CtcTreeBuilder::paramBlankLoop( "allow loops on the blank nodes in the search tree", true); +const Core::ParameterBool CtcTreeBuilder::paramForceBlank( + "force-blank-between-identical-labels", + "require a blank label between two identical labels (only works if label-loops are disabled)", + true); + CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) : AbstractTreeBuilder(config, lexicon, acousticModel, network), labelLoop_(paramLabelLoop(config)), - blankLoop_(paramBlankLoop(config)) { + blankLoop_(paramBlankLoop(config)), + forceBlank_(paramForceBlank(config)) { auto iters = lexicon.phonemeInventory()->phonemes(); for (auto it = iters.first; it != iters.second; ++it) { require(not(*it)->isContextDependent()); // Context dependent labels are not supported @@ -1315,7 +1321,15 @@ StateId CtcTreeBuilder::extendState(StateId predecessor, StateTree::StateDesc de } void CtcTreeBuilder::addTransition(StateId predecessor, StateId successor) { - auto const& successorStateDesc = network_.structure.state(successor).stateDesc; + auto const& predecessorStateDesc = network_.structure.state(predecessor).stateDesc; + auto const& successorStateDesc = network_.structure.state(successor).stateDesc; + + if (forceBlank_) { + // Don't add a transition between two distinct states of equal description + if (predecessorStateDesc == successorStateDesc and predecessor != successor) { + return; + } + } for (HMMStateNetwork::SuccessorIterator target = 
network_.structure.successors(predecessor); target; ++target) { if (!target.isLabel() && network_.structure.state(*target).stateDesc == successorStateDesc) { @@ -1431,7 +1445,13 @@ const Core::ParameterBool RnaTreeBuilder::paramLabelLoop( "allow label loops in the search tree", false); +const Core::ParameterBool RnaTreeBuilder::paramForceBlank( + "force-blank-between-identical-labels", + "require a blank label between two identical labels (only works if label-loops are disabled)", + false); + RnaTreeBuilder::RnaTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize) : CtcTreeBuilder(config, lexicon, acousticModel, network, initialize) { - this->labelLoop_ = paramLabelLoop(config); + this->labelLoop_ = paramLabelLoop(config); + this->forceBlank_ = paramForceBlank(config); } diff --git a/src/Search/TreeBuilder.hh b/src/Search/TreeBuilder.hh index 16a4445f1..eb8f16d7a 100644 --- a/src/Search/TreeBuilder.hh +++ b/src/Search/TreeBuilder.hh @@ -259,10 +259,12 @@ public: static const Core::ParameterBool paramLabelLoop; static const Core::ParameterBool paramBlankLoop; + static const Core::ParameterBool paramForceBlank; protected: bool labelLoop_; bool blankLoop_; + bool forceBlank_; StateId wordBoundaryRoot_; Search::StateTree::StateDesc blankDesc_; @@ -297,6 +299,7 @@ public: virtual ~RnaTreeBuilder() = default; static const Core::ParameterBool paramLabelLoop; + static const Core::ParameterBool paramForceBlank; }; #endif From bfc535752ea493662442ba511f9cbd6d909983c7 Mon Sep 17 00:00:00 2001 From: Larissa Date: Fri, 21 Mar 2025 20:13:57 +0100 Subject: [PATCH 21/24] Also disable blank loop before word boundary token --- src/Search/TreeBuilder.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Search/TreeBuilder.cc b/src/Search/TreeBuilder.cc index 236c4489a..680d331a5 100644 --- a/src/Search/TreeBuilder.cc +++ b/src/Search/TreeBuilder.cc 
@@ -1434,8 +1434,11 @@ void CtcTreeBuilder::addWordBoundaryStates() { for (StateId wbs : wordBoundaryLemmaStartStates) { network_.structure.addTargetToNode(blankBefore, wbs); } - // Add loop for this blank state - addTransition(blankBefore, blankBefore); + + if (blankLoop_) { + // Add loop for this blank state + addTransition(blankBefore, blankBefore); + } } // -------------------- RnaTreeBuilder -------------------- From 9180f823120f3e42098af447b68e870c778da6fb Mon Sep 17 00:00:00 2001 From: Larissa Date: Mon, 24 Mar 2025 11:16:05 +0100 Subject: [PATCH 22/24] Apply suggestions --- src/Search/TreeBuilder.cc | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/Search/TreeBuilder.cc b/src/Search/TreeBuilder.cc index 680d331a5..f68099d5d 100644 --- a/src/Search/TreeBuilder.cc +++ b/src/Search/TreeBuilder.cc @@ -1216,7 +1216,7 @@ const Core::ParameterBool CtcTreeBuilder::paramBlankLoop( true); const Core::ParameterBool CtcTreeBuilder::paramForceBlank( - "force-blank-between-identical-labels", + "force-blank-between-repeated-labels", "require a blank label between two identical labels (only works if label-loops are disabled)", true); @@ -1324,13 +1324,6 @@ void CtcTreeBuilder::addTransition(StateId predecessor, StateId successor) { auto const& predecessorStateDesc = network_.structure.state(predecessor).stateDesc; auto const& successorStateDesc = network_.structure.state(successor).stateDesc; - if (forceBlank_) { - // Don't add a transition between two distinct states of equal description - if (predecessorStateDesc == successorStateDesc and predecessor != successor) { - return; - } - } - for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) { if (!target.isLabel() && network_.structure.state(*target).stateDesc == successorStateDesc) { // The node is already a successor of the predecessor, so the transition already exists @@ -1381,8 +1374,9 @@ StateId 
CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronuncia addTransition(currentState, currentState); } - // Add transition from previous non-blank state to this state, allowing to skip the blank state in-between these two - if (prevNonBlankState != invalidTreeNodeIndex) { + if (prevNonBlankState != invalidTreeNodeIndex and not(forceBlank_ and network_.structure.state(prevNonBlankState).stateDesc == network_.structure.state(currentState).stateDesc and prevNonBlankState != currentState)) { + // Add transition from previous non-blank state to this state, allowing to skip the blank state in-between these two + // If we want to enforce blank between repeated labels, don't add a transition between two distinct states of equal description addTransition(prevNonBlankState, currentState); } prevNonBlankState = currentState; @@ -1449,7 +1443,7 @@ const Core::ParameterBool RnaTreeBuilder::paramLabelLoop( false); const Core::ParameterBool RnaTreeBuilder::paramForceBlank( - "force-blank-between-identical-labels", + "force-blank-between-repeated-labels", "require a blank label between two identical labels (only works if label-loops are disabled)", false); From 049f2561524bfe853d117cc43daccd529415dd87 Mon Sep 17 00:00:00 2001 From: Eugen Beck Date: Tue, 25 Mar 2025 10:00:57 +0000 Subject: [PATCH 23/24] Reword if statement --- src/Search/TreeBuilder.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Search/TreeBuilder.cc b/src/Search/TreeBuilder.cc index f68099d5d..cf9bee502 100644 --- a/src/Search/TreeBuilder.cc +++ b/src/Search/TreeBuilder.cc @@ -1374,7 +1374,8 @@ StateId CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronuncia addTransition(currentState, currentState); } - if (prevNonBlankState != invalidTreeNodeIndex and not(forceBlank_ and network_.structure.state(prevNonBlankState).stateDesc == network_.structure.state(currentState).stateDesc and prevNonBlankState != currentState)) { + bool label_repetition = 
prevNonBlankState != currentState and network_.structure.state(prevNonBlankState).stateDesc == network_.structure.state(currentState).stateDesc; + if (prevNonBlankState != invalidTreeNodeIndex and not(label_repetition and forceBlank_)) { // Add transition from previous non-blank state to this state, allowing to skip the blank state in-between these two // If we want to enforce blank between repeated labels, don't add a transition between two distinct states of equal description addTransition(prevNonBlankState, currentState); From eb76d26cc7cf4c63cd9ad9de2e02f46e86c5f3ec Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Fri, 28 Mar 2025 11:37:17 -0400 Subject: [PATCH 24/24] bug fix --- src/Search/TreeBuilder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Search/TreeBuilder.cc b/src/Search/TreeBuilder.cc index cf9bee502..ebda597c2 100644 --- a/src/Search/TreeBuilder.cc +++ b/src/Search/TreeBuilder.cc @@ -1374,7 +1374,7 @@ StateId CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronuncia addTransition(currentState, currentState); } - bool label_repetition = prevNonBlankState != currentState and network_.structure.state(prevNonBlankState).stateDesc == network_.structure.state(currentState).stateDesc; + bool label_repetition = prevNonBlankState != currentState and prevNonBlankState != invalidTreeNodeIndex and network_.structure.state(prevNonBlankState).stateDesc == network_.structure.state(currentState).stateDesc; if (prevNonBlankState != invalidTreeNodeIndex and not(label_repetition and forceBlank_)) { // Add transition from previous non-blank state to this state, allowing to skip the blank state in-between these two // If we want to enforce blank between repeated labels, don't add a transition between two distinct states of equal description