Skip to content

Commit cf5884a

Browse files
committedJun 4, 2024·
adding latest from main
1 parent 701e21d commit cf5884a

20 files changed

+1329
-301
lines changed
 

‎common/datasets/tedlium2_v2/corpus.py

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import os
2+
from functools import lru_cache
3+
from typing import Dict, Optional, Any
4+
5+
from sisyphus import tk
6+
7+
from i6_core.audio.encoding import BlissChangeEncodingJob
8+
9+
from i6_core.meta import CorpusObject
10+
11+
from ..tedlium2.constants import DURATIONS
12+
from .download import download_data_dict
13+
14+
15+
@lru_cache()
def get_bliss_corpus_dict(audio_format: str = "wav", output_prefix: str = "datasets") -> Dict[str, tk.Path]:
    """
    Build a mapping of all corpora in the TedLiumV2 dataset to their bliss xml form.

    :param audio_format: options: wav, ogg, flac, sph, nist. nist (NIST sphere format) and sph are the same.
    :param output_prefix: prefix directory for job aliases
    :return: corpus name -> bliss corpus path
    """
    assert audio_format in ["flac", "ogg", "wav", "sph", "nist"]

    output_prefix = os.path.join(output_prefix, "Ted-Lium-2")

    bliss_corpus_dict = download_data_dict(output_prefix=output_prefix).bliss_nist

    # the downloaded corpora are already sphere/NIST, so no conversion job is needed
    if audio_format in ("sph", "nist"):
        return bliss_corpus_dict

    conversion_options = {
        "wav": {"output_format": "wav", "codec": "pcm_s16le"},
        "ogg": {"output_format": "ogg", "codec": "libvorbis"},
        "flac": {"output_format": "flac", "codec": "flac"},
    }

    converted = {}
    for corpus_name, sph_corpus in bliss_corpus_dict.items():
        conversion_job = BlissChangeEncodingJob(
            corpus_file=sph_corpus,
            sample_rate=16000,
            recover_duration=False,
            **conversion_options[audio_format],
        )
        conversion_job.add_alias(
            os.path.join(output_prefix, "%s_conversion" % audio_format, corpus_name)
        )
        converted[corpus_name] = conversion_job.out_corpus

    return converted
60+
61+
62+
@lru_cache()
def get_corpus_object_dict(audio_format: str = "flac", output_prefix: str = "datasets") -> Dict[str, CorpusObject]:
    """
    Build a dict of all corpora in the TedLiumV2 dataset as `meta.CorpusObject`.

    :param audio_format: options: wav, ogg, flac, sph, nist. nist (NIST sphere format) and sph are the same.
    :param output_prefix: prefix directory for job aliases
    :return: corpus name -> CorpusObject
    """
    bliss_corpus_dict = get_bliss_corpus_dict(audio_format=audio_format, output_prefix=output_prefix)

    corpus_objects = {}
    for name, bliss_corpus in bliss_corpus_dict.items():
        obj = CorpusObject()
        obj.corpus_file = bliss_corpus
        obj.audio_format = audio_format
        obj.audio_dir = None
        obj.duration = DURATIONS[name]  # durations are shared with the tedlium2 (v1) setup
        corpus_objects[name] = obj

    return corpus_objects
85+
86+
87+
@lru_cache()
def get_stm_dict(output_prefix: str = "datasets") -> Dict[str, tk.Path]:
    """
    Fetch the STM reference files of the TedLiumV2 dataset.

    :param output_prefix: prefix directory for job aliases
    :return: corpus name -> stm file path
    """
    data = download_data_dict(output_prefix=output_prefix)
    return data.stm
96+
97+
98+
def get_ogg_zip_dict(
    subdir_prefix: str = "datasets",
    returnn_python_exe: Optional[tk.Path] = None,
    returnn_root: Optional[tk.Path] = None,
    bliss_to_ogg_job_rqmt: Optional[Dict[str, Any]] = None,
    extra_args: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Dict[str, tk.Path]:
    """
    Get a dictionary containing the paths to the ogg_zip for each corpus part.

    No outputs will be registered.

    :param subdir_prefix: dir name prefix for aliases and outputs
    :param returnn_python_exe: path to returnn python executable
    :param returnn_root: path to returnn root
    :param bliss_to_ogg_job_rqmt: rqmt for bliss to ogg job
    :param extra_args: extra args for each dataset for bliss to ogg job
    :return: dictionary with ogg zip paths for each corpus (train, dev, test)
    """
    # local import mirrors the original code; avoids a hard dependency at module import time
    from i6_core.returnn.oggzip import BlissToOggZipJob

    if extra_args is None:
        extra_args = {}

    bliss_corpus_dict = get_bliss_corpus_dict(audio_format="wav", output_prefix=subdir_prefix)

    ogg_zip_dict = {}
    for name, bliss_corpus in bliss_corpus_dict.items():
        ogg_job = BlissToOggZipJob(
            bliss_corpus,
            no_conversion=False,  # cannot be used for corpus with multiple segments per recording
            returnn_python_exe=returnn_python_exe,
            returnn_root=returnn_root,
            **extra_args.get(name, {}),
        )
        if bliss_to_ogg_job_rqmt:
            ogg_job.rqmt = bliss_to_ogg_job_rqmt
        ogg_job.add_alias(os.path.join(subdir_prefix, "Ted-Lium-2", "%s_ogg_zip_job" % name))
        ogg_zip_dict[name] = ogg_job.out_ogg_zip

    return ogg_zip_dict
+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import os
2+
from dataclasses import dataclass
3+
from functools import lru_cache
4+
from typing import Any, Dict
5+
6+
from sisyphus import tk
7+
8+
from i6_core.datasets.tedlium2 import (
9+
DownloadTEDLIUM2CorpusJob,
10+
CreateTEDLIUM2BlissCorpusJobV2,
11+
)
12+
13+
14+
@dataclass(frozen=True)
class TedLium2Data:
    """Immutable bundle of all raw TedLium2 data artifacts."""

    data_dir: Dict[str, tk.Path]  # downloaded corpus folders per split
    lm_dir: tk.Path  # folder containing the LM training text data
    vocab: tk.Path  # vocabulary file
    bliss_nist: Dict[str, tk.Path]  # bliss corpora (NIST sphere audio) per split
    stm: Dict[str, tk.Path]  # STM reference files per split
23+
24+
25+
@lru_cache()
def download_data_dict(output_prefix: str = "datasets") -> TedLium2Data:
    """
    Download the TedLiumV2 dataset and run the initial data processing steps.

    Uses the fixed job CreateTEDLIUM2BlissCorpusJobV2 from: https://github.com/rwth-i6/i6_core/pull/490

    :param output_prefix: prefix directory for job aliases
    :return: bundle of all downloaded and processed data artifacts
    """
    download_job = DownloadTEDLIUM2CorpusJob()
    download_job.add_alias(os.path.join(output_prefix, "download", "raw_corpus_job"))

    bliss_job = CreateTEDLIUM2BlissCorpusJobV2(download_job.out_corpus_folders)
    bliss_job.add_alias(os.path.join(output_prefix, "create_bliss", "bliss_corpus_job"))

    return TedLium2Data(
        data_dir=download_job.out_corpus_folders,
        lm_dir=download_job.out_lm_folder,
        vocab=download_job.out_vocab_dict,
        bliss_nist=bliss_job.out_corpus_files,
        stm=bliss_job.out_stm_files,
    )

‎common/datasets/tedlium2_v2/export.py

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import os
2+
3+
from sisyphus import tk
4+
5+
from .corpus import get_bliss_corpus_dict, get_stm_dict
6+
from .lexicon import get_bliss_lexicon, get_g2p_augmented_bliss_lexicon
7+
from .textual_data import get_text_data_dict
8+
9+
TEDLIUM_PREFIX = "Ted-Lium-2"
10+
11+
12+
def _export_datasets(output_prefix: str = "datasets"):
    """
    Register outputs for all TedLiumV2 corpora in every available audio format.

    :param output_prefix: prefix directory for the registered outputs
    """
    for audio_format in ["flac", "ogg", "wav", "nist", "sph"]:
        corpora = get_bliss_corpus_dict(audio_format=audio_format, output_prefix=output_prefix)
        for name, bliss_corpus in corpora.items():
            out_path = os.path.join(
                output_prefix,
                TEDLIUM_PREFIX,
                "corpus",
                f"{name}-{audio_format}.xml.gz",
            )
            tk.register_output(out_path, bliss_corpus)
31+
32+
33+
def _export_stms(output_prefix: str = "datasets"):
    """
    Register the STM reference files of TedLiumV2 as outputs.

    :param output_prefix: prefix directory for the registered outputs
    """
    for name, stm_file in get_stm_dict(output_prefix=output_prefix).items():
        out_path = os.path.join(output_prefix, TEDLIUM_PREFIX, "stm", f"{name}.txt")
        tk.register_output(out_path, stm_file)
51+
52+
53+
def _export_text_data(output_prefix: str = "datasets"):
    """
    Register all textual data of the TedLiumV2 dataset as outputs.

    :param output_prefix: prefix directory for the registered outputs
    """
    for name, text_data in get_text_data_dict(output_prefix=output_prefix).items():
        tk.register_output(
            os.path.join(output_prefix, TEDLIUM_PREFIX, "text_data", f"{name}.gz"),
            text_data,
        )
63+
64+
65+
def _export_lexicon(output_prefix: str = "datasets"):
    """
    Register the plain and the G2P-augmented TedLiumV2 lexicon as outputs.

    :param output_prefix: prefix directory for the registered outputs
    """
    lexicon_output_prefix = os.path.join(output_prefix, TEDLIUM_PREFIX, "lexicon")

    tk.register_output(
        os.path.join(lexicon_output_prefix, "tedlium2.lexicon.xml.gz"),
        get_bliss_lexicon(output_prefix=output_prefix),
    )

    tk.register_output(
        os.path.join(lexicon_output_prefix, "tedlium2.lexicon_with_g2p.xml.gz"),
        get_g2p_augmented_bliss_lexicon(
            add_unknown_phoneme_and_mapping=False, output_prefix=output_prefix
        ),
    )
84+
85+
86+
def export_all(output_prefix: str = "datasets"):
    """
    Register every TedLiumV2 output: corpora, STMs, textual data and lexica.

    :param output_prefix: prefix directory for the registered outputs
    """
    for export_fn in (_export_datasets, _export_stms, _export_text_data, _export_lexicon):
        export_fn(output_prefix=output_prefix)
+171
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import os
2+
from functools import lru_cache
3+
from sisyphus import tk
4+
5+
from i6_core.lexicon import LexiconFromTextFileJob
6+
from i6_core.lexicon.modification import WriteLexiconJob, MergeLexiconJob
7+
from i6_core.lib import lexicon
8+
from i6_experiments.common.helpers.g2p import G2PBasedOovAugmenter
9+
10+
from ..tedlium2.constants import SILENCE_PHONEME, UNKNOWN_PHONEME
11+
from .corpus import get_bliss_corpus_dict
12+
from .download import download_data_dict
13+
14+
15+
@lru_cache()
def _get_special_lemma_lexicon(
    add_unknown_phoneme_and_mapping: bool = False,
    add_silence: bool = True,
) -> lexicon.Lexicon:
    """
    Create the special lemmas used in RASR.

    :param add_unknown_phoneme_and_mapping: adds [unknown] as label with [UNK] as phoneme and <unk> as LM token
    :param add_silence: adds [silence] label with [SILENCE] phoneme,
        use False for CTC/RNN-T setups without silence modelling.
    :return: lexicon containing only the special lemmas (and their phonemes)
    """
    lex = lexicon.Lexicon()

    if add_silence:
        lex.add_lemma(
            lexicon.Lemma(
                orth=["[silence]", ""],
                phon=[SILENCE_PHONEME],
                synt=[],
                special="silence",
                eval=[[]],
            )
        )

    # the unknown lemma only carries a phoneme when requested
    unknown_kwargs = dict(orth=["[unknown]"], synt=["<unk>"], special="unknown", eval=[[]])
    if add_unknown_phoneme_and_mapping:
        lex.add_lemma(lexicon.Lemma(phon=[UNKNOWN_PHONEME], **unknown_kwargs))
    else:
        lex.add_lemma(lexicon.Lemma(**unknown_kwargs))

    lex.add_lemma(
        lexicon.Lemma(
            orth=["[sentence-begin]"],
            synt=["<s>"],
            special="sentence-begin",
            eval=[[]],
        )
    )
    lex.add_lemma(
        lexicon.Lemma(
            orth=["[sentence-end]"],
            synt=["</s>"],
            special="sentence-end",
            eval=[[]],
        )
    )

    if add_silence:
        lex.add_phoneme(SILENCE_PHONEME, variation="none")
    if add_unknown_phoneme_and_mapping:
        lex.add_phoneme(UNKNOWN_PHONEME, variation="none")

    return lex
81+
82+
83+
@lru_cache()
def _get_raw_bliss_lexicon(
    output_prefix: str,
) -> tk.Path:
    """
    Create a bliss lexicon from the downloaded TedLiumV2 vocabulary file.

    :param output_prefix: prefix directory for job aliases
    :return: path to the converted bliss lexicon
    """
    vocab_file = download_data_dict(output_prefix=output_prefix).vocab

    to_bliss_job = LexiconFromTextFileJob(
        text_file=vocab_file,
        compressed=True,
    )
    to_bliss_job.add_alias(os.path.join(output_prefix, "convert_text_to_bliss_lexicon_job"))

    return to_bliss_job.out_bliss_lexicon
102+
103+
104+
@lru_cache()
def get_bliss_lexicon(
    add_unknown_phoneme_and_mapping: bool = True,
    add_silence: bool = True,
    output_prefix: str = "datasets",
) -> tk.Path:
    """
    Merge the lexicon holding the special RASR tokens with the lexicon created
    from the downloaded TedLiumV2 vocabulary.

    :param add_unknown_phoneme_and_mapping: add an unknown phoneme and mapping unknown phoneme:lemma
    :param add_silence: include silence lemma and phoneme
    :param output_prefix: prefix directory for job aliases
    :return: path to the merged bliss lexicon
    """
    special_lexicon = _get_special_lemma_lexicon(add_unknown_phoneme_and_mapping, add_silence)
    write_special_job = WriteLexiconJob(special_lexicon, sort_phonemes=True, sort_lemmata=False)
    write_special_job.add_alias(os.path.join(output_prefix, "static_lexicon_job"))

    merge_job = MergeLexiconJob(
        bliss_lexica=[
            write_special_job.out_bliss_lexicon,
            _get_raw_bliss_lexicon(output_prefix=output_prefix),
        ],
        sort_phonemes=True,
        sort_lemmata=True,
        compressed=True,
    )
    merge_job.add_alias(os.path.join(output_prefix, "merge_lexicon_job"))

    return merge_job.out_bliss_lexicon
136+
137+
138+
@lru_cache()
def get_g2p_augmented_bliss_lexicon(
    add_unknown_phoneme_and_mapping: bool = False,
    add_silence: bool = True,
    audio_format: str = "wav",
    output_prefix: str = "datasets",
) -> tk.Path:
    """
    Augment the kernel lexicon with unknown words from the training corpus via G2P.

    :param add_unknown_phoneme_and_mapping: add an unknown phoneme and mapping unknown phoneme:lemma
    :param add_silence: include silence lemma and phoneme
    :param audio_format: options: wav, ogg, flac, sph, nist. nist (NIST sphere format) and sph are the same.
    :param output_prefix: prefix directory for job aliases
    :return: path to the G2P augmented bliss lexicon
    """
    base_lexicon = get_bliss_lexicon(
        add_unknown_phoneme_and_mapping, add_silence=add_silence, output_prefix=output_prefix
    )
    corpus_name = "train"
    train_corpus = get_bliss_corpus_dict(audio_format=audio_format, output_prefix=output_prefix)[corpus_name]

    # NOTE(review): original and train lexicon are the same object here, as in the
    # original code — confirm this is intended before changing
    augmenter = G2PBasedOovAugmenter(
        original_bliss_lexicon=base_lexicon,
        train_lexicon=base_lexicon,
    )
    return augmenter.get_g2p_augmented_bliss_lexicon(
        bliss_corpus=train_corpus,
        corpus_name=corpus_name,
        alias_path=os.path.join(output_prefix, "g2p"),
        casing="lower",
    )
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from functools import lru_cache
2+
from typing import Dict
3+
4+
from sisyphus import tk
5+
6+
from i6_core.corpus import CorpusToTxtJob
7+
from i6_core.text import ConcatenateJob
8+
9+
from i6_experiments.common.datasets.tedlium2.corpus_v2 import get_bliss_corpus_dict
10+
11+
from .download import download_data_dict
12+
13+
14+
@lru_cache()
def get_text_data_dict(output_prefix: str = "datasets") -> Dict[str, tk.Path]:
    """
    Gather all the textual data provided within the TedLiumV2 dataset.

    :param output_prefix: prefix directory for job aliases
    :return: name -> gzipped text data mapping
    """
    lm_dir = download_data_dict(output_prefix=output_prefix).lm_dir

    text_corpora = (
        "commoncrawl-9pc",
        "europarl-v7-6pc",
        "giga-fren-4pc",
        "news-18pc",
        "news-commentary-v8-9pc",
        "yandex-1m-31pc",
    )

    txt_dict = {corpus: lm_dir.join_right("%s.en.gz" % corpus) for corpus in text_corpora}

    # NOTE(review): output_prefix is intentionally not forwarded here — the literal
    # "corpora" prefix matches the original code; confirm before unifying.
    train_corpus = get_bliss_corpus_dict(audio_format="wav", output_prefix="corpora")["train"]
    txt_dict["audio-transcriptions"] = CorpusToTxtJob(train_corpus).out_txt
    txt_dict["background-data"] = ConcatenateJob(list(txt_dict.values())).out

    return txt_dict

‎common/datasets/tedlium2_v2/vocab.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from i6_experiments.common.helpers.text_labels.subword_nmt_bpe import (
2+
get_returnn_subword_nmt,
3+
get_bpe_settings,
4+
BPESettings,
5+
)
6+
from .corpus import get_bliss_corpus_dict
7+
8+
9+
def get_subword_nmt_bpe(bpe_size: int, unk_label: str = "<unk>", subdir_prefix: str = "") -> BPESettings:
    """
    Get the BPE tokens via the Returnn subword-nmt for a Tedlium2 setup.

    :param bpe_size: the number of BPE merge operations. This is NOT the resulting vocab size!
    :param unk_label: unknown label symbol
    :param subdir_prefix: dir name prefix for aliases and outputs
    """
    repo = get_returnn_subword_nmt(output_prefix=subdir_prefix)
    train_corpus = get_bliss_corpus_dict()["train"]
    return get_bpe_settings(
        train_corpus,
        bpe_size=bpe_size,
        unk_label=unk_label,
        output_prefix=subdir_prefix,
        subword_nmt_repo_path=repo,
    )
27+
28+
29+
def get_subword_nmt_bpe_v2(bpe_size: int, unk_label: str = "<unk>", subdir_prefix: str = "") -> BPESettings:
    """
    Get the BPE tokens via the Returnn subword-nmt for a Tedlium2 setup.

    V2: Uses subword-nmt version corrected for Apptainer related bug, adds hash overwrite for repo

    :param bpe_size: the number of BPE merge operations. This is NOT the resulting vocab size!
    :param unk_label: unknown label symbol
    :param subdir_prefix: dir name prefix for aliases and outputs
    """
    repo = get_returnn_subword_nmt(
        commit_hash="5015a45e28a958f800ef1c50e7880c0c9ef414cf", output_prefix=subdir_prefix
    )
    # keep sisyphus hashes stable regardless of the pinned commit
    repo.hash_overwrite = "I6_SUBWORD_NMT_V2"
    train_corpus = get_bliss_corpus_dict()["train"]
    return get_bpe_settings(
        train_corpus,
        bpe_size=bpe_size,
        unk_label=unk_label,
        output_prefix=subdir_prefix,
        subword_nmt_repo_path=repo,
    )

‎users/raissi/experiments/librispeech/configs/LFR_factored/baseline/config.py

+9
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,15 @@
7979
out_joint_diphone="output/output_batch_major",
8080
)
8181

82+
CONF_FH_TRIPHONE_FS_DECODING_TENSOR_CONFIG_V2 = dataclasses.replace(
83+
DecodingTensorMap.default(),
84+
in_encoder_output="conformer_12_output/add",
85+
out_encoder_output="encoder__output/output_batch_major",
86+
out_right_context="right__output/output_batch_major",
87+
out_left_context="left__output/output_batch_major",
88+
out_center_state="center__output/output_batch_major",
89+
out_joint_diphone="output/output_batch_major",
90+
)
8291

8392
BLSTM_FH_DECODING_TENSOR_CONFIG = dataclasses.replace(
8493
CONF_FH_DECODING_TENSOR_CONFIG,

‎users/raissi/setups/common/BASE_factored_hybrid_system.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def _set_native_lstm_path(self, search_numpy_blas=True, blas_lib=None):
530530
self.native_lstm2_path = compile_native_op_job.out_op
531531

532532
def set_local_flf_tool_for_decoding(self, path):
533-
self.csp["base"].flf_tool_exe = path
533+
self.crp["base"].flf_tool_exe = path
534534

535535
# --------------------- Init procedure -----------------
536536
def set_initial_nn_args(self, initial_nn_args):

‎users/raissi/setups/common/TF_factored_hybrid_system.py

+35-7
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
import i6_experiments.users.raissi.setups.common.helpers.train as train_helpers
4848
import i6_experiments.users.raissi.setups.common.helpers.decode as decode_helpers
4949

50+
from i6_experiments.users.raissi.setups.common.helpers.priors.factored_estimation import get_triphone_priors
51+
from i6_experiments.users.raissi.setups.common.helpers.priors.util import PartitionDataSetup
5052

5153
# user based modules
5254
from i6_experiments.users.raissi.setups.common.data.backend import BackendInfo
@@ -74,7 +76,7 @@
7476

7577
from i6_experiments.users.raissi.setups.common.data.backend import Backend, BackendInfo
7678

77-
79+
from i6_experiments.users.raissi.setups.common.decoder.BASE_factored_hybrid_search import DecodingTensorMap
7880
from i6_experiments.users.raissi.setups.common.decoder.config import (
7981
PriorInfo,
8082
PriorConfig,
@@ -160,9 +162,6 @@ def get_model_checkpoint(self, model_job, epoch):
160162
def get_model_path(self, model_job, epoch):
161163
return model_job.out_checkpoints[epoch].ckpt_path
162164

163-
def set_local_flf_tool_for_decoding(self, path=None):
164-
self.csp["base"].flf_tool_exe = path
165-
166165
# -------------------------------------------- Training --------------------------------------------------------
167166

168167
# -------------encoder architectures -------------------------------
@@ -279,7 +278,7 @@ def get_conformer_network_zhou_variant(
279278
network["classes_"]["from"] = "slice_classes"
280279

281280
else:
282-
network=encoder_net
281+
network = encoder_net
283282

284283
return network
285284

@@ -736,9 +735,38 @@ def set_diphone_priors_returnn_rasr(
736735

737736
self.experiments[key]["priors"] = p_info
738737

739-
740-
def set_triphone_priors_factored(self):
738+
def set_triphone_priors_factored(
739+
self,
740+
key: str,
741+
epoch: int,
742+
tensor_map: DecodingTensorMap,
743+
partition_data_setup: PartitionDataSetup = None,
744+
model_path: tk.Path = None,
745+
):
741746
self.create_hdf()
747+
if self.experiments[key]["graph"].get("inference", None) is None:
748+
self.set_graph_for_experiment(key)
749+
if partition_data_setup is None:
750+
partition_data_setup = PartitionDataSetup()
751+
752+
if model_path is None:
753+
model_path = DelayedFormat(self.get_model_path(model_job=self.experiments[key]["train_job"], epoch=epoch))
754+
triphone_priors = get_triphone_priors(
755+
name=f"{self.experiments[key]['name']}/e{epoch}",
756+
graph_path=self.experiments[key]["graph"]["inference"],
757+
model_path=model_path,
758+
data_paths=self.hdfs[self.train_key],
759+
tensor_map=tensor_map,
760+
partition_data_setup=partition_data_setup,
761+
label_info=self.label_info,
762+
)
763+
764+
p_info = PriorInfo(
765+
center_state_prior=PriorConfig(file=triphone_priors[1], scale=0.0),
766+
left_context_prior=PriorConfig(file=triphone_priors[2], scale=0.0),
767+
right_context_prior=PriorConfig(file=triphone_priors[0], scale=0.0),
768+
)
769+
self.experiments[key]["priors"] = p_info
742770

743771
def set_triphone_priors_returnn_rasr(
744772
self,

‎users/raissi/setups/common/decoder/BASE_factored_hybrid_search.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -671,8 +671,11 @@ def recognize(
671671
if search_parameters.tdp_scale is not None:
672672
if name_override is None:
673673
name += f"-tdpScale-{search_parameters.tdp_scale}"
674-
name += f"-spTdp-{format_tdp(search_parameters.tdp_speech)}"
675674
name += f"-silTdp-{format_tdp(search_parameters.tdp_silence)}"
675+
if search_parameters.tdp_nonword is not None:
676+
name += f"-nwTdp-{format_tdp(search_parameters.tdp_nonword)}"
677+
name += f"-spTdp-{format_tdp(search_parameters.tdp_speech)}"
678+
676679

677680
if self.feature_scorer_type.is_factored():
678681
if search_parameters.transition_scales is not None:
@@ -758,6 +761,12 @@ def recognize(
758761
adv_search_extra_config = (
759762
copy.deepcopy(adv_search_extra_config) if adv_search_extra_config is not None else rasr.RasrConfig()
760763
)
764+
765+
if search_parameters.word_recombination_limit is not None:
766+
adv_search_extra_config.flf_lattice_tool.network.recognizer.recognizer.reduce_context_word_recombination = True
767+
adv_search_extra_config.flf_lattice_tool.network.recognizer.recognizer.reduce_context_word_recombination_limit = search_parameters.word_recombination_limit
768+
name += f"recombLim{search_parameters.word_recombination_limit}"
769+
761770
if search_parameters.altas is not None:
762771
adv_search_extra_config.flf_lattice_tool.network.recognizer.recognizer.acoustic_lookahead_temporal_approximation_scale = (
763772
search_parameters.altas
@@ -907,7 +916,7 @@ def recognize(
907916
if add_sis_alias_and_output:
908917
tk.register_output(f"{pre_path}/{name}.wer", scorer.out_report_dir)
909918

910-
if opt_lm_am and search_parameters.altas is None:
919+
if opt_lm_am and (search_parameters.altas is None or search_parameters.altas < 3.0):
911920
assert search_parameters.beam >= 15.0
912921
if pron_scale is not None:
913922
if isinstance(pron_scale, DelayedBase) and pron_scale.is_set():
@@ -1311,14 +1320,16 @@ def push_delayed_tuple(
13111320
best_priors = best_overall_n.out_argmin[0]
13121321
best_tdp_scale = best_overall_n.out_argmin[1]
13131322
best_tdp_sil = best_overall_n.out_argmin[2]
1314-
best_tdp_sp = best_overall_n.out_argmin[3]
1323+
best_tdp_nw = best_overall_n.out_argmin[3]
1324+
best_tdp_sp = best_overall_n.out_argmin[4]
13151325
if use_pron:
1316-
best_pron = best_overall_n.out_argmin[4]
1326+
best_pron = best_overall_n.out_argmin[5]
13171327

13181328
base_cfg = dataclasses.replace(
13191329
search_parameters,
13201330
tdp_scale=best_tdp_scale,
13211331
tdp_silence=push_delayed_tuple(best_tdp_sil),
1332+
tdp_nonword=push_delayed_tuple(best_tdp_nw),
13221333
tdp_speech=push_delayed_tuple(best_tdp_sp),
13231334
pron_scale=best_pron,
13241335
)
@@ -1327,6 +1338,7 @@ def push_delayed_tuple(
13271338
search_parameters,
13281339
tdp_scale=best_tdp_scale,
13291340
tdp_silence=push_delayed_tuple(best_tdp_sil),
1341+
tdp_nonword=push_delayed_tuple(best_tdp_nw),
13301342
tdp_speech=push_delayed_tuple(best_tdp_sp),
13311343
)
13321344

‎users/raissi/setups/common/decoder/config.py

+6
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ class SearchParameters:
157157
altas: Optional[float] = None
158158
lm_lookahead_scale: Optional[float] = None
159159
lm_lookahead_history_limit: Int = 1
160+
word_recombination_limit: Optional[Int] = None
160161
posterior_scales: Optional[PosteriorScales] = None
161162
silence_penalties: Optional[Tuple[Float, Float]] = None # loop, fwd
162163
state_dependent_tdps: Optional[Union[str, tk.Path]] = None
@@ -189,6 +190,11 @@ def with_lm_lookahead_scale(self, scale: Float) -> "SearchParameters":
189190
def with_lm_lookahead_history_limit(self, history_limit: Int) -> "SearchParameters":
190191
return dataclasses.replace(self, lm_lookahead_history_limit=history_limit)
191192

193+
def with_word_recombination_limit(self, word_recombination_limit: Int) -> "SearchParameters":
194+
return dataclasses.replace(self, word_recombination_limit=word_recombination_limit)
195+
196+
197+
192198
def with_prior_scale(
193199
self,
194200
center: Optional[Float] = None,

‎users/raissi/setups/common/helpers/network/augment.py

+199-1
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,29 @@
2929
class LogLinearScales:
3030
label_posterior_scale: float
3131
transition_scale: float
32+
context_label_posterior_scale: float = 1.0
3233
label_prior_scale: Optional[float] = None
3334

3435
@classmethod
3536
def default(cls) -> "LogLinearScales":
36-
return cls(label_posterior_scale=0.3, label_prior_scale=None, transition_scale=0.3)
37+
return cls(label_posterior_scale=0.3, transition_scale=0.3, label_prior_scale=None, context_label_posterior_scale=1.0)
38+
39+
@dataclass(frozen=True, eq=True)
class LossScales:
    """Per-output loss scales for the factored (left/center/right) training losses."""

    # fields were annotated `int` but defaulted to floats; float is the intended type
    center_scale: float = 1.0
    right_scale: float = 1.0
    left_scale: float = 1.0

    def get_scale(self, label_name: str) -> float:
        """
        Return the loss scale for the given output layer name.

        :param label_name: output layer name containing 'center', 'right' or 'left'
        :raises NotImplementedError: if the name matches none of the known outputs
        """
        if "center" in label_name:
            return self.center_scale
        elif "right" in label_name:
            return self.right_scale
        elif "left" in label_name:
            return self.left_scale
        else:
            # original raised the non-callable `NotImplemented` constant, which
            # surfaced as a confusing TypeError instead of the intended exception
            raise NotImplementedError("Not recognized label name for output loss scale")
54+
3755

3856

3957
Layer = Dict[str, Any]
@@ -889,3 +907,183 @@ def add_fast_bw_layer_to_returnn_config(
889907
# ToDo: handel the import model part
890908

891909
return returnn_config
910+
911+
def add_fast_bw_factored_layer_to_network(
    crp: rasr.CommonRasrParameters,
    network: Network,
    log_linear_scales: LogLinearScales,
    loss_scales: LossScales,
    label_info: LabelInfo,
    # BUGFIX: original default was ["left-output", "center-output" "right-output"]
    # (missing comma) which silently concatenated to "center-outputright-output",
    # so the right-context output never received its BW loss. A tuple also avoids
    # the mutable-default-argument pitfall.
    reference_layers=("left-output", "center-output", "right-output"),
    label_prior_type: Optional[PriorType] = None,
    label_prior: Optional[returnn.CodeWrapper] = None,
    label_prior_estimation_axes: Optional[str] = None,
    extra_rasr_config: Optional[rasr.RasrConfig] = None,
    extra_rasr_post_config: Optional[rasr.RasrConfig] = None,
) -> Network:
    """
    Attach a factored fast-baum-welch training setup to the given RETURNN network.

    For each reference layer the existing cross-entropy loss is removed and replaced
    by a "via_layer" loss aligned against a shared "fast_bw_factored" layer; a RASR
    config for the alignment automaton is written and wired in via sprint_opts.

    :param crp: RASR parameters used to build the automaton config
    :param network: RETURNN network dict, modified in place and returned
    :param log_linear_scales: posterior/prior/transition scales
    :param loss_scales: per-output loss scales (left/center/right)
    :param label_info: provides n_contexts and the number of center state classes
    :param reference_layers: output layers to attach BW losses to
    :param label_prior_type: how to obtain the label prior (None disables priors)
    :param label_prior: prior constant, required for PriorType.TRANSCRIPT
    :param label_prior_estimation_axes: axes to average over for PriorType.ONTHEFLY, e.g. "bt"
    :param extra_rasr_config: additional RASR config merged into the automaton config
    :param extra_rasr_post_config: additional RASR post config
    :return: the modified network
    """
    crp = correct_rasr_FSA_bug(crp)

    if label_prior_type is not None:
        assert log_linear_scales.label_prior_scale is not None, "If you plan to use the prior, please set the scale for it"
        if label_prior_type == PriorType.TRANSCRIPT:
            assert label_prior is not None, "You forgot to set the label prior file"

    inputs = []
    for reference_layer in reference_layers:
        # strip the previous (e.g. CE) loss configuration from the reference layer
        for attribute in ["loss", "loss_opts", "target"]:
            if reference_layer in network:
                network[reference_layer].pop(attribute, None)

        out_denot = reference_layer.split("-")[0]
        # center output uses the label posterior scale, context outputs their own scale
        am_scale = (
            log_linear_scales.label_posterior_scale
            if "center" in reference_layer
            else log_linear_scales.context_label_posterior_scale
        )

        # prior calculation
        if label_prior_type is not None:
            prior_name = ("_").join(["label_prior", out_denot])
            comb_name = ("_").join(["comb-prior", out_denot])
            prior_eval_string = "(safe_log(source(1)) * prior_scale)"
            inputs.append(comb_name)
            if label_prior_type == PriorType.TRANSCRIPT:
                network[prior_name] = {"class": "constant", "dtype": "float32", "value": label_prior}
            elif label_prior_type == PriorType.AVERAGE:
                network[prior_name] = {
                    "class": "accumulate_mean",
                    "exp_average": 0.001,
                    "from": reference_layer,
                    "is_prob_distribution": True,
                }
            elif label_prior_type == PriorType.ONTHEFLY:
                assert label_prior_estimation_axes is not None, "You forgot to set one which axis you want to average the prior, eg. bt"
                network[prior_name] = {
                    "class": "reduce",
                    "mode": "mean",
                    "from": reference_layer,
                    "axis": label_prior_estimation_axes,
                }
                # stop_gradient: the on-the-fly prior must not receive gradients
                prior_eval_string = "tf.stop_gradient((safe_log(source(1)) * prior_scale))"
            else:
                raise NotImplementedError("Unknown PriorType")

            network[comb_name] = {
                "class": "combine",
                "kind": "eval",
                "eval": f"am_scale*(safe_log(source(0)) - {prior_eval_string})",
                "eval_locals": {
                    "am_scale": am_scale,
                    "prior_scale": log_linear_scales.label_prior_scale,
                },
                "from": [reference_layer, prior_name],
            }
        else:
            comb_name = ("_").join(["multiply-scale", out_denot])
            inputs.append(comb_name)
            network[comb_name] = {
                "class": "combine",
                "kind": "eval",
                "eval": "am_scale*(safe_log(source(0)))",
                "eval_locals": {"am_scale": am_scale},
                "from": [reference_layer],
            }

        bw_out = ("_").join(["output-bw", out_denot])
        network[bw_out] = {
            "class": "copy",
            "from": reference_layer,
            "loss": "via_layer",
            "loss_opts": {
                "align_layer": ("/").join(["fast_bw", out_denot]),
                "loss_wrt_to_act_in": "softmax",
            },
            "loss_scale": loss_scales.get_scale(reference_layer),
        }

    network["fast_bw"] = {
        "class": "fast_bw_factored",
        "align_target": "hmm-monophone",
        "hmm_opts": {"num_contexts": label_info.n_contexts},
        "from": inputs,
        "tdp_scale": log_linear_scales.transition_scale,
        # left + right contexts plus the center state classes
        "n_out": label_info.n_contexts * 2 + label_info.get_n_state_classes(),
    }

    # Create additional Rasr config file for the automaton
    mapping = {
        "corpus": "neural-network-trainer.corpus",
        "lexicon": ["neural-network-trainer.alignment-fsa-exporter.model-combination.lexicon"],
        "acoustic_model": ["neural-network-trainer.alignment-fsa-exporter.model-combination.acoustic-model"],
    }
    config, post_config = rasr.build_config_from_mapping(crp, mapping)
    post_config["*"].output_channel.file = "fastbw.log"

    # Define action
    config.neural_network_trainer.action = "python-control"
    # neural_network_trainer.alignment_fsa_exporter.allophone_state_graph_builder
    config.neural_network_trainer.alignment_fsa_exporter.allophone_state_graph_builder.orthographic_parser.allow_for_silence_repetitions = (
        False
    )
    config.neural_network_trainer.alignment_fsa_exporter.allophone_state_graph_builder.orthographic_parser.normalize_lemma_sequence_scores = (
        False
    )
    # neural_network_trainer.alignment_fsa_exporter
    config.neural_network_trainer.alignment_fsa_exporter.model_combination.acoustic_model.fix_allophone_context_at_word_boundaries = (
        True
    )
    config.neural_network_trainer.alignment_fsa_exporter.model_combination.acoustic_model.transducer_builder_filter_out_invalid_allophones = (
        True
    )

    # additional config
    config._update(extra_rasr_config)
    post_config._update(extra_rasr_post_config)

    automaton_config = rasr.WriteRasrConfigJob(config, post_config).out_config
    tk.register_output("train/bw.config", automaton_config)

    network["fast_bw"]["sprint_opts"] = {
        "sprintExecPath": rasr.RasrCommand.select_exe(crp.nn_trainer_exe, "nn-trainer"),
        "sprintConfigStr": DelayedFormat("--config={}", automaton_config),
        "sprintControlConfig": {"verbose": True},
        "usePythonSegmentOrder": False,
        "numInstances": 1,
    }

    return network
1052+
1053+
1054+
def add_fast_bw_factored_layer_to_returnn_config(
1055+
crp: rasr.CommonRasrParameters,
1056+
returnn_config: returnn.ReturnnConfig,
1057+
log_linear_scales: LogLinearScales,
1058+
loss_scales: LossScales,
1059+
label_info: LabelInfo,
1060+
import_model: [tk.Path, str] = None,
1061+
reference_layers: [str] = ["left-output", "center-output", "right-output"],
1062+
label_prior_type: Optional[PriorType] = None,
1063+
label_prior: Optional[returnn.CodeWrapper] = None,
1064+
label_prior_estimation_axes: str = None,
1065+
extra_rasr_config: Optional[rasr.RasrConfig] = None,
1066+
extra_rasr_post_config: Optional[rasr.RasrConfig] = None,
1067+
) -> returnn.ReturnnConfig:
1068+
1069+
returnn_config.config["network"] = add_fast_bw_factored_layer_to_network(
1070+
crp=crp,
1071+
network=returnn_config.config["network"],
1072+
log_linear_scales=log_linear_scales,
1073+
loss_scales=loss_scales,
1074+
label_info=label_info,
1075+
reference_layers=reference_layers,
1076+
label_prior_type=label_prior_type,
1077+
label_prior=label_prior,
1078+
label_prior_estimation_axes=label_prior_estimation_axes,
1079+
extra_rasr_config=extra_rasr_config,
1080+
extra_rasr_post_config=extra_rasr_post_config,
1081+
)
1082+
1083+
if "chunking" in returnn_config.config:
1084+
del returnn_config.config["chunking"]
1085+
if "pretrain" in returnn_config.config and import_model is not None:
1086+
del returnn_config.config["pretrain"]
1087+
1088+
return returnn_config
1089+

‎users/raissi/setups/common/helpers/priors/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
from .flat import CreateFlatPriorsJob
88
from .smoothen import smoothen_priors, SmoothenPriorsJob
99
from .scale import scale_priors, ScalePriorsJob
10-
from .transcription import get_mono_transcription_priors
10+
from .transcription import get_prior_from_transcription
1111
from .tri_join import JoinRightContextPriorsJob, ReshapeCenterStatePriorsJob

‎users/raissi/setups/common/helpers/priors/estimate_povey_like_prior_fh.py

+236-136
Large diffs are not rendered by default.

‎users/raissi/setups/common/helpers/priors/factored_estimation.py

+144-105
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,157 @@
1-
2-
def get_diphone_priors(graphPath, model, dataPaths, datasetIndices,
3-
nStateClasses=141, nContexts=47, gpu=1, time=20, isSilMapped=True, name=None, nBatch=10000, tf_library=None, tm=None):
1+
import numpy as np
2+
from typing import List
3+
4+
from sisyphus import *
5+
6+
7+
from i6_experiments.users.raissi.setups.common.data.factored_label import LabelInfo
8+
from i6_experiments.users.raissi.setups.common.decoder.BASE_factored_hybrid_search import DecodingTensorMap
9+
from i6_experiments.users.raissi.setups.common.helpers.priors.estimate_povey_like_prior_fh import (
10+
EstimateFactoredTriphonePriorsJob,
11+
CombineMeansForTriphoneForward,
12+
DumpXmlForTriphoneForwardJob,
13+
)
14+
15+
from i6_experiments.users.raissi.setups.common.helpers.priors.util import PartitionDataSetup
16+
17+
Path = setup_path(__package__)
18+
RANDOM_SEED = 42
19+
20+
21+
def get_triphone_priors(
22+
name: str,
23+
graph_path: Path,
24+
model_path: Path,
25+
data_paths: List[Path],
26+
label_info: LabelInfo,
27+
tensor_map: DecodingTensorMap,
28+
partition_data_setup: PartitionDataSetup,
29+
tf_library=None,
30+
n_batch=10000,
31+
cpu: int = 2,
32+
gpu: int = 1,
33+
time: int = 1,
34+
):
35+
36+
triphone_files = []
37+
diphone_files = []
38+
context_files = []
39+
num_segments = []
40+
41+
np.random.seed(RANDOM_SEED)
42+
for i in np.random.choice(range(len(data_paths)//partition_data_setup.data_offset), partition_data_setup.n_data_indices, replace=False):
43+
start_ind = i * partition_data_setup.data_offset
44+
end_ind = (i + 1) * partition_data_setup.data_offset
45+
for j in range(partition_data_setup.n_segment_indices):
46+
start_ind_seg = j * partition_data_setup.segment_offset
47+
end_ind_seg = (j + 1) * partition_data_setup.segment_offset
48+
# if end_ind_seg > 1248: end_ind_seg = 1248
49+
data_indices = list(range(start_ind, end_ind))
50+
estimateJob = EstimateFactoredTriphonePriorsJob(
51+
graph_path=graph_path,
52+
model_path=model_path,
53+
tensor_map=tensor_map,
54+
data_paths=data_paths,
55+
data_indices=data_indices,
56+
start_ind_segment=start_ind_seg,
57+
end_ind_segment=end_ind_seg,
58+
label_info=label_info,
59+
tf_library_path=tf_library,
60+
n_batch=n_batch,
61+
cpu=cpu,
62+
gpu=gpu,
63+
time=time,
64+
)
65+
if name is not None:
66+
estimateJob.add_alias(f"priors/{name}-{data_indices}_{start_ind_seg}")
67+
triphone_files.extend(estimateJob.triphone_files)
68+
diphone_files.extend(estimateJob.diphone_files)
69+
context_files.extend(estimateJob.context_files)
70+
num_segments.extend(estimateJob.num_segments)
71+
72+
comb_jobs = []
73+
for spliter in range(0, len(triphone_files), partition_data_setup.split_step):
74+
start = spliter
75+
end = min(spliter + partition_data_setup.split_step, len(triphone_files))
76+
comb_jobs.append(
77+
CombineMeansForTriphoneForward(
78+
triphone_files=triphone_files[start:end],
79+
diphone_files=diphone_files[start:end],
80+
context_files=context_files[start:end],
81+
num_segment_files=num_segments[start:end],
82+
label_info=label_info,
83+
)
84+
)
85+
86+
comb_triphone_files = [c.triphone_files_out for c in comb_jobs]
87+
comb_diphone_files = [c.diphone_files_out for c in comb_jobs]
88+
comb_context_files = [c.context_files_out for c in comb_jobs]
89+
comb_num_segs = [c.num_segments_out for c in comb_jobs]
90+
xmlJob = DumpXmlForTriphoneForwardJob(
91+
triphone_files=comb_triphone_files,
92+
diphone_files=comb_diphone_files,
93+
context_files=comb_context_files,
94+
num_segment_files=comb_num_segs,
95+
label_info=label_info
96+
)
97+
98+
prior_files_triphone = [xmlJob.triphone_xml, xmlJob.diphone_xml, xmlJob.context_xml]
99+
xml_name = f"priors/{name}"
100+
tk.register_output(xml_name, prior_files_triphone[0])
101+
102+
return prior_files_triphone
103+
104+
105+
# needs refactoring
106+
def get_diphone_priors(
107+
graph_path,
108+
model_path,
109+
data_paths,
110+
data_indices,
111+
nStateClasses=141,
112+
nContexts=47,
113+
gpu=1,
114+
time=20,
115+
isSilMapped=True,
116+
name=None,
117+
n_batch=10000,
118+
tf_library=None,
119+
tensor_map=None,
120+
):
4121

5122
if tf_library is None:
6123
tf_library = libraryPath
7-
if tm is None:
8-
tm = defaultTfMap
9-
10-
estimateJob = EstimateSprintDiphoneAndContextPriors(graphPath,
11-
model,
12-
dataPaths,
13-
datasetIndices,
14-
tf_library,
15-
nContexts=nContexts,
16-
nStateClasses=nStateClasses,
17-
gpu=gpu,
18-
time=time,
19-
tensorMap=tm,
20-
nBatch=nBatch ,)
124+
if tensor_map is None:
125+
tensor_map = defaultTfMap
126+
127+
estimateJob = EstimateSprintDiphoneAndContextPriors(
128+
graph_path,
129+
model_path,
130+
data_paths,
131+
data_indices,
132+
tf_library,
133+
nContexts=nContexts,
134+
nStateClasses=nStateClasses,
135+
gpu=gpu,
136+
time=time,
137+
tensorMap=tensor_map,
138+
n_batch=n_batch,
139+
)
21140
if name is not None:
22141
estimateJob.add_alias(f"priors/{name}")
23142

24-
xmlJob = DumpXmlSprintForDiphone(estimateJob.diphoneFiles,
25-
estimateJob.contextFiles,
26-
estimateJob.numSegments,
27-
nContexts=nContexts,
28-
nStateClasses=nStateClasses,
29-
adjustSilence=isSilMapped)
143+
xmlJob = DumpXmlSprintForDiphone(
144+
estimateJob.diphone_files,
145+
estimateJob.context_files,
146+
estimateJob.num_segments,
147+
nContexts=nContexts,
148+
nStateClasses=nStateClasses,
149+
adjustSilence=isSilMapped,
150+
)
30151

31152
priorFiles = [xmlJob.diphoneXml, xmlJob.contextXml]
32153

33154
xmlName = f"priors/{name}"
34155
tk.register_output(xmlName, priorFiles[0])
35156

36157
return priorFiles
37-
38-
39-
40-
def get_triphone_priors(graphPath, model, dataPaths, nStateClasses=282, nContexts=47, nPhones=47, nStates=3,
41-
cpu=2, gpu=1, time=1, nBatch=18000, dNum=3, sNum=20, step=200, dataOffset=10, segmentOffset=10,
42-
name=None, tf_library=None, tm=None, isMulti=False):
43-
if tf_library is None:
44-
tf_library = libraryPath
45-
if tm is None:
46-
tm = defaultTfMap
47-
48-
triphoneFiles = []
49-
diphoneFiles = []
50-
contextFiles = []
51-
numSegments = []
52-
53-
54-
for i in range(2, dNum + 2):
55-
startInd = i * dataOffset
56-
endInd = (i + 1) * dataOffset
57-
for j in range(sNum):
58-
startSegInd = j * segmentOffset
59-
endSegInd = (j + 1) * segmentOffset
60-
if endSegInd > 1248: endSegInd = 1248
61-
62-
datasetIndices = list(range(startInd, endInd))
63-
estimateJob = EstimateSprintTriphonePriorsForward(graphPath,
64-
model,
65-
dataPaths,
66-
datasetIndices,
67-
startSegInd, endSegInd,
68-
tf_library,
69-
nContexts=nContexts,
70-
nStateClasses=nStateClasses,
71-
nStates=nStates,
72-
nPhones=nPhones,
73-
nBatch=nBatch,
74-
cpu=cpu,
75-
gpu=gpu,
76-
time=time,
77-
tensorMap=tm,
78-
isMultiEncoder=isMulti)
79-
if name is not None:
80-
estimateJob.add_alias(f"priors/{name}-startind{startSegInd}")
81-
triphoneFiles.extend(estimateJob.triphoneFiles)
82-
diphoneFiles.extend(estimateJob.diphoneFiles)
83-
contextFiles.extend(estimateJob.contextFiles)
84-
numSegments.extend(estimateJob.numSegments)
85-
86-
87-
88-
comJobs = []
89-
for spliter in range(0, len(triphoneFiles), step):
90-
start = spliter
91-
end = spliter + step
92-
if end > len(triphoneFiles):
93-
end = triphoneFiles
94-
comJobs.append(CombineMeansForTriphoneForward(triphoneFiles[start:end],
95-
diphoneFiles[start:end],
96-
contextFiles[start:end],
97-
numSegments[start:end],
98-
nContexts=nContexts,
99-
nStates=nStateClasses,
100-
))
101-
102-
combTriphoneFiles = [c.triphoneFilesOut for c in comJobs]
103-
combDiphoneFiles = [c.diphoneFilesOut for c in comJobs]
104-
combContextFiles = [c.contextFilesOut for c in comJobs]
105-
combNumSegs = [c.numSegmentsOut for c in comJobs]
106-
xmlJob = DumpXmlForTriphoneForward(combTriphoneFiles,
107-
combDiphoneFiles,
108-
combContextFiles,
109-
combNumSegs,
110-
nContexts=nContexts,
111-
nStates=nStateClasses)
112-
113-
priorFilesTriphone = [xmlJob.triphoneXml, xmlJob.diphoneXml, xmlJob.contextXml]
114-
xmlName = f"priors/{name}"
115-
tk.register_output(xmlName, priorFilesTriphone[0])
116-
117-
118-
return priorFilesTriphone
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,82 @@
1-
__all__ = ["get_mono_transcription_priors"]
1+
from sisyphus import *
2+
from sisyphus.tools import try_get
23

3-
import numpy as np
4-
from typing import Iterator, List
5-
import pickle
4+
import os
65

7-
from sisyphus import Job, Task
6+
from i6_core.corpus.transform import ApplyLexiconToCorpusJob
7+
from i6_core.lexicon.allophones import DumpStateTyingJob
8+
from i6_core.lexicon.modification import AddEowPhonemesToLexiconJob
89

9-
from i6_experiments.users.raissi.setups.common.decoder.config import PriorInfo, PriorConfig
10-
from i6_experiments.users.raissi.setups.common.helpers.priors.util import write_prior_xml
1110

11+
from i6_experiments.users.mann.experimental.statistics import AllophoneCounts
12+
from i6_experiments.users.mann.setups.prior import PriorFromTranscriptionCounts
1213

13-
pickles = {
14-
(
15-
1,
16-
False,
17-
): "/work/asr4/raissi/setups/librispeech/960-ls/dependencies/priors/daniel/monostate/monostate.pickle",
18-
(
19-
1,
20-
True,
21-
): "/work/asr4/raissi/setups/librispeech/960-ls/dependencies/priors/daniel/monostate/monostate.we.pickle",
22-
(
23-
3,
24-
False,
25-
): "/work/asr4/raissi/setups/librispeech/960-ls/dependencies/priors/daniel/threepartite/threepartite.pickle",
26-
(
27-
3,
28-
True,
29-
): "/work/asr4/raissi/setups/librispeech/960-ls/dependencies/priors/daniel/threepartite/threepartite.we.pickle",
30-
}
3114

15+
def output(name, value):
16+
opath = os.path.join(fname, name)
17+
if isinstance(value, dict):
18+
tk.register_report(opath, DescValueReport(value))
19+
return
20+
tk.register_report(opath, SimpleValueReport(value))
3221

33-
class LoadTranscriptionPriorsJob(Job):
34-
def __init__(self, n: int, eow: bool):
35-
assert n in [1, 3]
22+
from sisyphus.delayed_ops import DelayedBase
3623

37-
self.n = n
38-
self.eow = eow
24+
class DelayedGetDefault(DelayedBase):
25+
def __init__(self, a, b, default=None):
26+
super().__init__(a, b)
27+
self.default = default
3928

40-
self.out_priors = self.output_path("priors.xml")
29+
def get(self):
30+
try:
31+
return try_get(self.a)[try_get(self.b)]
32+
except KeyError:
33+
return self.default
4134

42-
def tasks(self) -> Iterator[Task]:
43-
yield Task("run", mini_task=True)
4435

45-
def run(self):
46-
file = pickles[(self.n, self.eow)]
36+
def get_prior_from_transcription(
37+
crp,
38+
total_frames,
39+
average_phoneme_frames,
40+
epsilon=1e-12,
41+
lemma_end_probability=0.0,
4742

48-
with open(file, "rb") as f:
49-
priors: List[float] = pickle.load(f)
43+
):
5044

51-
write_prior_xml(log_priors=np.log(priors), path=self.out_priors)
45+
lexicon_w_we = AddEowPhonemesToLexiconJob(
46+
crp.lexicon_config.file,
47+
boundary_marker=" #", # the prepended space is important
48+
)
5249

50+
corpus = crp.corpus_config.file
51+
if not isinstance(crp.corpus_config.file, tk.Path):
52+
corpus = tk.Path(crp.corpus_config.file)
5353

54-
def get_mono_transcription_priors(states_per_phone: int, with_word_end: bool) -> PriorInfo:
55-
load_j = LoadTranscriptionPriorsJob(states_per_phone, with_word_end)
56-
return PriorInfo(center_state_prior=PriorConfig(file=load_j.out_priors, scale=0.0))
54+
55+
transcribe_job = ApplyLexiconToCorpusJob(
56+
corpus,
57+
lexicon_w_we.out_lexicon,
58+
)
59+
60+
count_phonemes = AllophoneCounts(
61+
transcribe_job.out_corpus,
62+
lemma_end_probability=lemma_end_probability,
63+
)
64+
65+
state_tying_file = DumpStateTyingJob(crp).out_state_tying
66+
67+
68+
69+
prior_job = PriorFromTranscriptionCounts(
70+
allophone_counts=count_phonemes.counts,
71+
total_count=count_phonemes.total,
72+
state_tying=state_tying_file,
73+
average_phoneme_frames=average_phoneme_frames,
74+
num_frames=total_frames,
75+
eps=epsilon,
76+
)
77+
78+
return {
79+
"txt": prior_job.out_prior_txt_file,
80+
"xml": prior_job.out_prior_xml_file,
81+
"png": prior_job.out_prior_png_file
82+
}

‎users/raissi/setups/common/helpers/priors/util.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,21 @@
22

33
from dataclasses import dataclass
44
import numpy as np
5-
from typing import List, Tuple, Union
5+
from typing import List, Tuple, Union
66
import xml.etree.ElementTree as ET
77

88
from sisyphus import Path
99

1010

11+
@dataclass(frozen=True, eq=True)
12+
class PartitionDataSetup:
13+
n_segment_indices: int = 20
14+
n_data_indices: int = 3
15+
segment_offset: int = 10
16+
data_offset: int = 10
17+
split_step: int = 200
18+
19+
1120
@dataclass(frozen=True, eq=True)
1221
class ParsedPriors:
1322
priors_log: List[float]
@@ -81,4 +90,4 @@ def get_batch_from_segments(segments: List, batchSize=10000):
8190
yield segments[index * batchSize : (index + 1) * batchSize]
8291
index += 1
8392
except IndexError:
84-
index = 0
93+
index = 0

‎users/raissi/setups/common/util/tdp.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Union, Tuple
44

55
from sisyphus import tk
6-
from sisyphus.delayed_ops import DelayedBase
6+
from sisyphus.delayed_ops import DelayedBase, DelayedGetItem
77

88
from i6_experiments.common.setups.rasr.config.am_config import Tdp
99
from i6_experiments.users.raissi.setups.common.data.typings import TDP
@@ -14,6 +14,8 @@ def to_tdp(tdp_tuple: Tuple[TDP, TDP, TDP, TDP]) -> Tdp:
1414

1515

1616
def format_tdp_val(val) -> str:
17+
if isinstance(val, DelayedGetItem):
18+
val = val.get()
1719
return "inf" if val == "infinity" else f"{val}"
1820

1921

‎users/raissi/setups/librispeech/decoder/LBS_factored_hybrid_search.py

+59-2
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __init__(
9797
lm_gc_simple_hash=lm_gc_simple_hash,
9898
gpu=gpu,
9999
)
100+
self.trafo_lm_config = self.get_eugen_trafo_with_quant_and_compress_config()
100101

101102
def get_ls_kazuki_lstm_lm_config(
102103
self,
@@ -115,7 +116,7 @@ def get_ls_kazuki_lstm_lm_config(
115116
state_manager="lstm",
116117
).get()
117118

118-
def get_eugen_trafo_config(
119+
def get_eugen_trafo_with_quant_and_compress_config(
119120
self,
120121
min_batch_size: int = 0,
121122
opt_batch_size: int = 64,
@@ -229,6 +230,62 @@ def get_eugen_trafo_config(
229230

230231
return trafo_config
231232

233+
def get_eugen_trafo_config(
234+
self,
235+
min_batch_size: int = 0,
236+
opt_batch_size: int = 64,
237+
max_batch_size: int = 64,
238+
scale: Optional[float] = None,
239+
) -> rasr.RasrConfig:
240+
# assert self.library_path is not None
241+
242+
243+
trafo_config = rasr.RasrConfig()
244+
245+
trafo_config.min_batch_size = min_batch_size
246+
trafo_config.opt_batch_size = opt_batch_size
247+
trafo_config.max_batch_size = max_batch_size
248+
trafo_config.allow_reduced_history = True
249+
if scale is not None:
250+
trafo_config.scale = scale
251+
trafo_config.type = "tfrnn"
252+
trafo_config.vocab_file = tk.Path("/work/asr3/raissi/shared_workspaces/gunz/dependencies/ls-eugen-trafo-lm/vocabulary", cached=True)
253+
trafo_config.transform_output_negate = True
254+
trafo_config.vocab_unknown_word = "<UNK>"
255+
256+
trafo_config.input_map.info_0.param_name = "word"
257+
trafo_config.input_map.info_0.tensor_name = "extern_data/placeholders/delayed/delayed"
258+
trafo_config.input_map.info_0.seq_length_tensor_name = "extern_data/placeholders/delayed/delayed_dim0_size"
259+
260+
trafo_config.input_map.info_1.param_name = "state-lengths"
261+
trafo_config.input_map.info_1.tensor_name = "output/rec/dec_0_self_att_att/state_lengths"
262+
263+
trafo_config.loader.type = "meta"
264+
trafo_config.loader.meta_graph_file = (
265+
"/work/asr4/raissi/setups/librispeech/960-ls/dependencies/trafo-lm_eugen/integrated_fixup_graph_no_cp_no_quant.meta"
266+
)
267+
model_path = "/work/asr3/raissi/shared_workspaces/gunz/dependencies/ls-eugen-trafo-lm/epoch.030"
268+
trafo_config.loader.saved_model_file = rasr.StringWrapper(model_path, f"{model_path}.index")
269+
trafo_config.loader.required_libraries = self.library_path
270+
271+
trafo_config.output_map.info_0.param_name = "softmax"
272+
trafo_config.output_map.info_0.tensor_name = "output/rec/decoder/add"
273+
274+
trafo_config.output_map.info_1.param_name = "weights"
275+
trafo_config.output_map.info_1.tensor_name = "output/rec/output/W/read"
276+
277+
trafo_config.output_map.info_2.param_name = "bias"
278+
trafo_config.output_map.info_2.tensor_name = "output/rec/output/b/read"
279+
280+
281+
trafo_config.state_manager.cache_prefix = True
282+
trafo_config.state_manager.min_batch_size = min_batch_size
283+
trafo_config.state_manager.min_common_prefix_length = 0
284+
trafo_config.state_manager.type = "transformer"
285+
trafo_config.softmax_adapter.type = "blas-nce"
286+
287+
return trafo_config
288+
232289
def recognize_ls_trafo_lm(
233290
self,
234291
*,
@@ -265,7 +322,7 @@ def recognize_ls_trafo_lm(
265322
is_nn_lm=True,
266323
keep_value=keep_value,
267324
label_info=label_info,
268-
lm_config=self.get_eugen_trafo_config(),
325+
lm_config=self.trafo_lm_config,
269326
name_override=name_override,
270327
name_prefix=name_prefix,
271328
num_encoder_output=num_encoder_output,

‎users/raissi/utils/default_tools.py

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def get_rasr_binary_path(rasr_path):
9393
hash_overwrite="CONFORMER_RETURNN_Len_FIX",
9494
)
9595
RETURNN_ROOT_TORCH = tk.Path("/work/tools/users/raissi/returnn_versions/torch", hash_overwrite="TORCH_RETURNN_ROOT")
96+
RETURNN_ROOT_BW_FACTORED = tk.Path("/work/tools/users/raissi/returnn_versions/bw-factored", hash_overwrite="BW_RETURNN_ROOT")
9697

9798
SCTK_BINARY_PATH = compile_sctk(branch="v2.4.12") # use last published version
9899
SCTK_BINARY_PATH.hash_overwrite = "DEFAULT_SCTK_BINARY_PATH"

0 commit comments

Comments
 (0)
Please sign in to comment.