|
class LogLinearScales:
    """Scales used in the log-linear combination of model scores.

    Attributes:
        label_posterior_scale: weight on the (center) label posterior score.
        transition_scale: weight on the transition (TDP) score.
        context_label_posterior_scale: weight on the context (left/right)
            label posterior scores.
        label_prior_scale: weight on the label prior; ``None`` means no prior
            is used.
    """

    label_posterior_scale: float
    transition_scale: float
    context_label_posterior_scale: float = 1.0
    label_prior_scale: Optional[float] = None

    @classmethod
    def default(cls) -> "LogLinearScales":
        """Return the commonly used default scale set (0.3/0.3, no prior)."""
        return cls(
            label_posterior_scale=0.3,
            transition_scale=0.3,
            label_prior_scale=None,
            context_label_posterior_scale=1.0,
        )
| 38 | + |
@dataclass(frozen=True, eq=True)
class LossScales:
    """Per-output loss scales for the factored left/center/right outputs.

    Attributes:
        center_scale: loss scale for center-phoneme output layers.
        right_scale: loss scale for right-context output layers.
        left_scale: loss scale for left-context output layers.
    """

    # Annotated as float (not int): the defaults are floats and fractional
    # scales are the normal use case.
    center_scale: float = 1.0
    right_scale: float = 1.0
    left_scale: float = 1.0

    def get_scale(self, label_name: str) -> float:
        """Return the loss scale for the output layer *label_name*.

        Selection is by substring match, checked in the order
        'center', 'right', 'left'.

        Raises:
            NotImplementedError: if *label_name* contains none of the
                recognized substrings.
        """
        if "center" in label_name:
            return self.center_scale
        elif "right" in label_name:
            return self.right_scale
        elif "left" in label_name:
            return self.left_scale
        else:
            # Bug fix: the original raised NotImplemented (a constant, not an
            # exception class), which itself fails with a TypeError.
            raise NotImplementedError("Not recognized label name for output loss scale")
| 54 | + |
37 | 55 |
|
38 | 56 |
|
# Type alias: a RETURNN network layer is a plain dict of config keys/values.
Layer = Dict[str, Any]
|
@@ -889,3 +907,183 @@ def add_fast_bw_layer_to_returnn_config(
|
889 | 907 | # ToDo: handel the import model part
|
890 | 908 |
|
891 | 909 | return returnn_config
|
| 910 | + |
def add_fast_bw_factored_layer_to_network(
    crp: rasr.CommonRasrParameters,
    network: Network,
    log_linear_scales: LogLinearScales,
    loss_scales: LossScales,
    label_info: LabelInfo,
    reference_layers: Optional[list] = None,
    label_prior_type: Optional[PriorType] = None,
    label_prior: Optional[returnn.CodeWrapper] = None,
    label_prior_estimation_axes: Optional[str] = None,
    extra_rasr_config: Optional[rasr.RasrConfig] = None,
    extra_rasr_post_config: Optional[rasr.RasrConfig] = None,
) -> Network:
    """Attach a factored fast Baum-Welch (full-sum) training loss to *network*.

    For each reference output layer, the existing framewise loss attributes are
    removed, the layer's (optionally prior-corrected) scaled log score is fed
    into a shared "fast_bw_factored" alignment layer, and a "via_layer" loss is
    added training the layer against the resulting Baum-Welch soft alignment.
    A RASR config file for the training automaton is generated and registered.

    :param crp: RASR parameters used to build the automaton config.
    :param network: RETURNN network dict; modified in place and returned.
    :param log_linear_scales: posterior/prior/transition scales.
    :param loss_scales: per-output (left/center/right) loss scales.
    :param label_info: factored label topology (n_contexts, state classes).
    :param reference_layers: output layers to train; defaults to
        ["left-output", "center-output", "right-output"].
    :param label_prior_type: how the label prior is obtained; None = no prior.
    :param label_prior: constant prior values, required for PriorType.TRANSCRIPT.
    :param label_prior_estimation_axes: axes averaged over for
        PriorType.ONTHEFLY, e.g. "bt".
    :param extra_rasr_config: merged into the generated automaton config.
    :param extra_rasr_post_config: merged into the generated post config.
    :return: the modified network.
    """
    if reference_layers is None:
        # Bug fix: the original default was missing a comma, so implicit string
        # concatenation produced ["left-output", "center-outputright-output"].
        # Also avoids the shared mutable default argument.
        reference_layers = ["left-output", "center-output", "right-output"]

    crp = correct_rasr_FSA_bug(crp)

    if label_prior_type is not None:
        assert log_linear_scales.label_prior_scale is not None, "If you plan to use the prior, please set the scale for it"
        if label_prior_type == PriorType.TRANSCRIPT:
            assert label_prior is not None, "You forgot to set the label prior file"

    inputs = []
    for reference_layer in reference_layers:
        # Drop any existing framewise loss; the BW loss replaces it.
        if reference_layer in network:
            for attribute in ["loss", "loss_opts", "target"]:
                network[reference_layer].pop(attribute, None)

        out_denot = reference_layer.split("-")[0]
        # Center output uses the main posterior scale; left/right context
        # outputs use their own scale.
        am_scale = (
            log_linear_scales.label_posterior_scale
            if "center" in reference_layer
            else log_linear_scales.context_label_posterior_scale
        )

        if label_prior_type is not None:
            # Prior-corrected score: am_scale * (log p(s|x) - prior_scale * log p(s))
            prior_name = "_".join(["label_prior", out_denot])
            comb_name = "_".join(["comb-prior", out_denot])
            prior_eval_string = "(safe_log(source(1)) * prior_scale)"
            inputs.append(comb_name)
            if label_prior_type == PriorType.TRANSCRIPT:
                network[prior_name] = {"class": "constant", "dtype": "float32", "value": label_prior}
            elif label_prior_type == PriorType.AVERAGE:
                network[prior_name] = {
                    "class": "accumulate_mean",
                    "exp_average": 0.001,
                    "from": reference_layer,
                    "is_prob_distribution": True,
                }
            elif label_prior_type == PriorType.ONTHEFLY:
                assert label_prior_estimation_axes is not None, "You forgot to set one which axis you want to average the prior, eg. bt"
                network[prior_name] = {
                    "class": "reduce",
                    "mode": "mean",
                    "from": reference_layer,
                    "axis": label_prior_estimation_axes,
                }
                # Do not backpropagate through the on-the-fly prior estimate.
                prior_eval_string = "tf.stop_gradient((safe_log(source(1)) * prior_scale))"
            else:
                raise NotImplementedError("Unknown PriorType")

            network[comb_name] = {
                "class": "combine",
                "kind": "eval",
                "eval": f"am_scale*(safe_log(source(0)) - {prior_eval_string})",
                "eval_locals": {
                    "am_scale": am_scale,
                    "prior_scale": log_linear_scales.label_prior_scale,
                },
                "from": [reference_layer, prior_name],
            }
        else:
            # No prior: just the scaled log posterior.
            comb_name = "_".join(["multiply-scale", out_denot])
            inputs.append(comb_name)
            network[comb_name] = {
                "class": "combine",
                "kind": "eval",
                "eval": "am_scale*(safe_log(source(0)))",
                "eval_locals": {"am_scale": am_scale},
                "from": [reference_layer],
            }

        # Train the reference layer against the BW soft alignment of its
        # factored sub-output, e.g. "fast_bw/center".
        bw_out = "_".join(["output-bw", out_denot])
        network[bw_out] = {
            "class": "copy",
            "from": reference_layer,
            "loss": "via_layer",
            "loss_opts": {
                "align_layer": "/".join(["fast_bw", out_denot]),
                "loss_wrt_to_act_in": "softmax",
            },
            "loss_scale": loss_scales.get_scale(reference_layer),
        }

    network["fast_bw"] = {
        "class": "fast_bw_factored",
        "align_target": "hmm-monophone",
        "hmm_opts": {"num_contexts": label_info.n_contexts},
        "from": inputs,
        "tdp_scale": log_linear_scales.transition_scale,
        # left contexts + right contexts + center state classes
        "n_out": label_info.n_contexts * 2 + label_info.get_n_state_classes(),
    }

    # Create an additional RASR config file describing the training automaton.
    mapping = {
        "corpus": "neural-network-trainer.corpus",
        "lexicon": ["neural-network-trainer.alignment-fsa-exporter.model-combination.lexicon"],
        "acoustic_model": ["neural-network-trainer.alignment-fsa-exporter.model-combination.acoustic-model"],
    }
    config, post_config = rasr.build_config_from_mapping(crp, mapping)
    post_config["*"].output_channel.file = "fastbw.log"

    # Define action
    config.neural_network_trainer.action = "python-control"
    # neural_network_trainer.alignment_fsa_exporter.allophone_state_graph_builder
    config.neural_network_trainer.alignment_fsa_exporter.allophone_state_graph_builder.orthographic_parser.allow_for_silence_repetitions = (
        False
    )
    config.neural_network_trainer.alignment_fsa_exporter.allophone_state_graph_builder.orthographic_parser.normalize_lemma_sequence_scores = (
        False
    )
    # neural_network_trainer.alignment_fsa_exporter
    config.neural_network_trainer.alignment_fsa_exporter.model_combination.acoustic_model.fix_allophone_context_at_word_boundaries = (
        True
    )
    config.neural_network_trainer.alignment_fsa_exporter.model_combination.acoustic_model.transducer_builder_filter_out_invalid_allophones = (
        True
    )

    # Merge user-supplied overrides into the generated configs.
    config._update(extra_rasr_config)
    post_config._update(extra_rasr_post_config)

    automaton_config = rasr.WriteRasrConfigJob(config, post_config).out_config
    tk.register_output("train/bw.config", automaton_config)

    # Point the BW layer at RASR, which supplies the training FSA.
    network["fast_bw"]["sprint_opts"] = {
        "sprintExecPath": rasr.RasrCommand.select_exe(crp.nn_trainer_exe, "nn-trainer"),
        "sprintConfigStr": DelayedFormat("--config={}", automaton_config),
        "sprintControlConfig": {"verbose": True},
        "usePythonSegmentOrder": False,
        "numInstances": 1,
    }

    return network
| 1052 | + |
| 1053 | + |
def add_fast_bw_factored_layer_to_returnn_config(
    crp: rasr.CommonRasrParameters,
    returnn_config: returnn.ReturnnConfig,
    log_linear_scales: LogLinearScales,
    loss_scales: LossScales,
    label_info: LabelInfo,
    import_model: "Optional[Union[tk.Path, str]]" = None,
    reference_layers: Optional[list] = None,
    label_prior_type: Optional[PriorType] = None,
    label_prior: Optional[returnn.CodeWrapper] = None,
    label_prior_estimation_axes: Optional[str] = None,
    extra_rasr_config: Optional[rasr.RasrConfig] = None,
    extra_rasr_post_config: Optional[rasr.RasrConfig] = None,
) -> returnn.ReturnnConfig:
    """Wire the factored fast-BW loss into a ReturnnConfig.

    Builds the loss network via add_fast_bw_factored_layer_to_network and
    removes config options that conflict with full-sum training: "chunking"
    always, and "pretrain" when an import model is given.

    :param crp: RASR parameters used to build the automaton config.
    :param returnn_config: config whose "network" entry is rewritten in place.
    :param import_model: checkpoint to import; only consulted to decide
        whether "pretrain" must be dropped.
    :param reference_layers: output layers to train; defaults to
        ["left-output", "center-output", "right-output"].
    (remaining parameters are passed through unchanged)
    :return: the modified returnn_config.
    """
    if reference_layers is None:
        # Avoid a shared mutable default argument; build the list per call.
        reference_layers = ["left-output", "center-output", "right-output"]

    returnn_config.config["network"] = add_fast_bw_factored_layer_to_network(
        crp=crp,
        network=returnn_config.config["network"],
        log_linear_scales=log_linear_scales,
        loss_scales=loss_scales,
        label_info=label_info,
        reference_layers=reference_layers,
        label_prior_type=label_prior_type,
        label_prior=label_prior,
        label_prior_estimation_axes=label_prior_estimation_axes,
        extra_rasr_config=extra_rasr_config,
        extra_rasr_post_config=extra_rasr_post_config,
    )

    # Chunked training is incompatible with the full-sum (sequence) BW loss.
    if "chunking" in returnn_config.config:
        del returnn_config.config["chunking"]
    # When continuing from an imported model, pretraining must not restart.
    if "pretrain" in returnn_config.config and import_model is not None:
        del returnn_config.config["pretrain"]

    return returnn_config
| 1089 | + |
0 commit comments