diff --git a/stanza/models/lemma_classifier/train_many.py b/stanza/models/lemma_classifier/train_many.py index acfb1f6100..1960bf30da 100644 --- a/stanza/models/lemma_classifier/train_many.py +++ b/stanza/models/lemma_classifier/train_many.py @@ -74,6 +74,12 @@ def train_n_models(num_models: int, base_path: str, args): args.save_name = new_save_name train_lstm_main(predefined_args=args) + if args.change_param == "attn_model": + for i in range(num_models): + new_save_name = os.path.join(base_path, f"attn_model_{args.num_heads}_heads_{i}.pt") + args.save_name = new_save_name + train_lstm_main(predefined_args=args) + def train_n_tfmrs(num_models: int, base_path: str, args): if args.multi_train_type == "tfmr":