Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify AdvancedTreeSearch for CtcTreeBuilder support #102

Open
wants to merge 34 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
0cef183
Refactored TreeBuilder
larissakl Dec 10, 2024
ed7430d
Apply formatting
Dec 12, 2024
7cc2e2f
Fix "historyLenght" name
Dec 12, 2024
8bda825
replace vector trees_ by one tree_ and refactor related code
larissakl Jan 7, 2025
f550faa
Merge branch 'master' into treebuilder_refactor-tree-vector
larissakl Jan 9, 2025
ac690d0
Apply formatting
Jan 10, 2025
b947fe6
Update comments and tree_ initialization
larissakl Jan 22, 2025
f1691a3
Add CtcTreeBuilder
larissakl Jan 28, 2025
eafef92
Uncouple treebuilding from AdvancedTreeSearch
larissakl Feb 10, 2025
fd30f7d
Merge branch 'master' into ctc-treebuilder
curufinwe Feb 10, 2025
dea490c
replace std::vector by std::set
larissakl Feb 11, 2025
b229128
Merge branch 'ctc-treebuilder' into uncouple-treebuilder
larissakl Feb 11, 2025
64a6c50
Make label loops optional for RNA topology
larissakl Feb 18, 2025
710505f
Formatting
Feb 18, 2025
55f8f0e
Merge branch 'master' into uncouple-treebuilder
Feb 18, 2025
5125c84
Merge branch 'uncouple-treebuilder' into treebuilder-rna-topology
Feb 18, 2025
4a88b3c
Rename "labelLoops" to "enableLabelLoop"
larissakl Feb 18, 2025
9a5d218
Move createTreeBuilder to Search::Module and remove subdir
larissakl Feb 23, 2025
5990bc6
Merge branch 'uncouple-treebuilder' into treebuilder-rna-topology
larissakl Feb 24, 2025
d8149f8
Introduce RnaTreeBuilder subclass and allow-label-loop config-parameter
larissakl Feb 24, 2025
2e9c51b
Adjust AdvancedTreeSearch for CtcTreeBuilder support
larissakl Mar 1, 2025
552b5a1
Merge branch 'master' into treebuilder-rna-topology
larissakl Mar 19, 2025
9deba08
Move allow-label-loop parameter to CtcTreeBuilder and RnaTreeBuilder
larissakl Mar 19, 2025
8670b7e
Merge branch 'master' into advanced-tree-search-ctc
larissakl Mar 19, 2025
b23b180
Adjust verify statements
larissakl Mar 19, 2025
2c8669a
Merge branch 'treebuilder-rna-topology' into advanced-tree-search-ctc
larissakl Mar 20, 2025
a6d6c3a
Introduce parameter to (dis-)allow blank loops
larissakl Mar 20, 2025
b0cb5ee
Add TODO concerning skip-transitions to AdvancedTreeSearch
larissakl Mar 20, 2025
7c98058
Introduce parameter to force blank between two identical labels
larissakl Mar 21, 2025
bfc5357
Also disable blank loop before word boundary token
larissakl Mar 21, 2025
343a829
Merge branch 'master' into advanced-tree-search-ctc
larissakl Mar 21, 2025
9180f82
Apply suggestions
larissakl Mar 24, 2025
049f256
Reword if statement
curufinwe Mar 25, 2025
eb76d26
bug fix
Mar 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/Search/AdvancedTreeSearch/SearchSpace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ void StaticSearchAutomaton::buildDepths(bool onlyFromRoot) {
for (u32 a = 1; a < stateDepths.size(); ++a) {
if (stateDepths[a] != Core::Type<int>::min && stateDepths[a] != Core::Type<int>::max) {
for (HMMStateNetwork::SuccessorIterator it = network.structure.successors(a); it; ++it) {
if (!it.isLabel()) {
if (!it.isLabel() and *it != a) {
verify(stateDepths[*it] > stateDepths[a]);
}
}
Expand Down Expand Up @@ -498,7 +498,7 @@ int StaticSearchAutomaton::fillStateDepths(StateId state, int depth) {
localDepth = 0;

for (HMMStateNetwork::SuccessorIterator it = network.structure.successors(state); it; ++it) {
if (not it.isLabel()) {
if (not it.isLabel() and *it != state) {
int d = fillStateDepths(*it, depth + 1);

if (d > localDepth) {
Expand Down Expand Up @@ -723,7 +723,7 @@ void StaticSearchAutomaton::buildBatches() {
network.dumpDotGraph(paramDumpDotGraph(config), stateDepths);

// Print some useful statistics about pushed and unpushed labels
verify(!network.unpushedCoarticulatedRootStates.empty());
verify(network.coarticulatedRootStates.empty() || !network.unpushedCoarticulatedRootStates.empty());

u32 unpushedLabels = 0;
u32 pushedLabels = 0;
Expand All @@ -734,8 +734,10 @@ void StaticSearchAutomaton::buildBatches() {
bool isUnpushed = network.unpushedCoarticulatedRootStates.count(transit) || transit == network.ciRootState || transit == network.rootState;
if (isUnpushed) {
++unpushedLabels;
std::map<StateId, std::pair<Bliss::Phoneme::Id, Bliss::Phoneme::Id>>::iterator it = network.rootTransitDescriptions.find(transit);
verify(it != network.rootTransitDescriptions.end());
if (!network.rootTransitDescriptions.empty()) {
std::map<StateId, std::pair<Bliss::Phoneme::Id, Bliss::Phoneme::Id>>::iterator it = network.rootTransitDescriptions.find(transit);
verify(it != network.rootTransitDescriptions.end());
}
}
else {
++pushedLabels;
Expand Down
45 changes: 36 additions & 9 deletions src/Search/TreeBuilder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1210,9 +1210,21 @@ const Core::ParameterBool CtcTreeBuilder::paramLabelLoop(
"allow label loops in the search tree",
true);

const Core::ParameterBool CtcTreeBuilder::paramBlankLoop(
"allow-blank-loop",
"allow loops on the blank nodes in the search tree",
true);

const Core::ParameterBool CtcTreeBuilder::paramForceBlank(
"force-blank-between-repeated-labels",
"require a blank label between two identical labels (only works if label-loops are disabled)",
true);

CtcTreeBuilder::CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize)
: AbstractTreeBuilder(config, lexicon, acousticModel, network),
labelLoop_(paramLabelLoop(config)) {
labelLoop_(paramLabelLoop(config)),
blankLoop_(paramBlankLoop(config)),
forceBlank_(paramForceBlank(config)) {
auto iters = lexicon.phonemeInventory()->phonemes();
for (auto it = iters.first; it != iters.second; ++it) {
require(not(*it)->isContextDependent()); // Context dependent labels are not supported
Expand Down Expand Up @@ -1309,7 +1321,8 @@ StateId CtcTreeBuilder::extendState(StateId predecessor, StateTree::StateDesc de
}

void CtcTreeBuilder::addTransition(StateId predecessor, StateId successor) {
auto const& successorStateDesc = network_.structure.state(successor).stateDesc;
auto const& predecessorStateDesc = network_.structure.state(predecessor).stateDesc;
auto const& successorStateDesc = network_.structure.state(successor).stateDesc;

for (HMMStateNetwork::SuccessorIterator target = network_.structure.successors(predecessor); target; ++target) {
if (!target.isLabel() && network_.structure.state(*target).stateDesc == successorStateDesc) {
Expand Down Expand Up @@ -1361,8 +1374,10 @@ StateId CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronuncia
addTransition(currentState, currentState);
}

// Add transition from previous non-blank state to this state, allowing to skip the blank state in-between these two
if (prevNonBlankState != invalidTreeNodeIndex) {
bool label_repetition = prevNonBlankState != currentState and prevNonBlankState != invalidTreeNodeIndex and network_.structure.state(prevNonBlankState).stateDesc == network_.structure.state(currentState).stateDesc;
if (prevNonBlankState != invalidTreeNodeIndex and not(label_repetition and forceBlank_)) {
// Add transition from previous non-blank state to this state, allowing to skip the blank state in-between these two
// If we want to enforce blank between repeated labels, don't add a transition between two distinct states of equal description
addTransition(prevNonBlankState, currentState);
}
prevNonBlankState = currentState;
Expand All @@ -1371,8 +1386,11 @@ StateId CtcTreeBuilder::extendPronunciation(StateId startState, Bliss::Pronuncia
if (not allophoneIsBlank and not isLastStateInLemma) {
// Add blank state after the newly created state
currentState = extendState(currentState, blankDesc_);
// Add loop for this blank state
addTransition(currentState, currentState);

if (blankLoop_) {
// Add loop for this blank state
addTransition(currentState, currentState);
}
}
}
}
Expand Down Expand Up @@ -1411,8 +1429,11 @@ void CtcTreeBuilder::addWordBoundaryStates() {
for (StateId wbs : wordBoundaryLemmaStartStates) {
network_.structure.addTargetToNode(blankBefore, wbs);
}
// Add loop for this blank state
addTransition(blankBefore, blankBefore);

if (blankLoop_) {
// Add loop for this blank state
addTransition(blankBefore, blankBefore);
}
}

// -------------------- RnaTreeBuilder --------------------
Expand All @@ -1422,7 +1443,13 @@ const Core::ParameterBool RnaTreeBuilder::paramLabelLoop(
"allow label loops in the search tree",
false);

const Core::ParameterBool RnaTreeBuilder::paramForceBlank(
"force-blank-between-repeated-labels",
"require a blank label between two identical labels (only works if label-loops are disabled)",
false);

RnaTreeBuilder::RnaTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize)
: CtcTreeBuilder(config, lexicon, acousticModel, network, initialize) {
this->labelLoop_ = paramLabelLoop(config);
this->labelLoop_ = paramLabelLoop(config);
this->forceBlank_ = paramForceBlank(config);
}
5 changes: 5 additions & 0 deletions src/Search/TreeBuilder.hh
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ protected:
class CtcTreeBuilder : public AbstractTreeBuilder {
public:
static const Core::ParameterBool paramLabelLoop;
static const Core::ParameterBool paramBlankLoop;
static const Core::ParameterBool paramForceBlank;

CtcTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true);
virtual ~CtcTreeBuilder() = default;
Expand All @@ -261,6 +263,8 @@ public:

protected:
bool labelLoop_;
bool blankLoop_;
bool forceBlank_;

StateId wordBoundaryRoot_;
Search::StateTree::StateDesc blankDesc_;
Expand Down Expand Up @@ -292,6 +296,7 @@ protected:
class RnaTreeBuilder : public CtcTreeBuilder {
public:
static const Core::ParameterBool paramLabelLoop;
static const Core::ParameterBool paramForceBlank;

RnaTreeBuilder(Core::Configuration config, const Bliss::Lexicon& lexicon, const Am::AcousticModel& acousticModel, Search::PersistentStateTree& network, bool initialize = true);
virtual ~RnaTreeBuilder() = default;
Expand Down