
Commit 67ebdce: Update models
Parent: 35fa450
13 files changed: 170 additions, 85 deletions

LaSE/LaSE/utils.py (+1 -1)

@@ -69,7 +69,7 @@ def load_langid_model(cache_dir=None):
         "marathi": "mr",
         "spanish": "es",
         "scottish_gaelic": "gd",
-        "nepali": "np",
+        "nepali": "ne",
         "pashto": "ps",
         "persian": "fa",
         "pidgin": "pcm",

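This one-character change is a correctness fix: `np` is the ISO 3166 country code for Nepal, while `ne` is the ISO 639-1 code for the Nepali language, which is what language-identification models key on. A minimal sketch of how such a name-to-code mapping is typically consumed (the helper and its name are illustrative, not the repository's actual API):

```python
# Illustrative sketch only: a name-to-ISO mapping like the one inside
# load_langid_model(), plus a hypothetical lookup helper.
LANG_TO_ISO = {
    "marathi": "mr",
    "spanish": "es",
    "scottish_gaelic": "gd",
    "nepali": "ne",   # fixed: "np" is Nepal's country code, not a language code
    "pashto": "ps",
    "persian": "fa",
    "pidgin": "pcm",  # ISO 639-3; Nigerian Pidgin has no two-letter code
}

def to_iso(language_name):
    """Return the ISO code for a language name, or None if unmapped."""
    return LANG_TO_ISO.get(language_name.strip().lower())

assert to_iso("Nepali") == "ne"
```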
README.md (+68 -55)

(Large diff; not rendered by default.)

Binary files added:

figs/ar_tgt_lase.png (84 KB)
figs/en_tgt_rouge2.png (89.7 KB)
figs/hi_tgt_rouge2.png (89.2 KB)
figs/ru_tgt_rouge2.png (86.7 KB)

seq2seq/README.md (+4 -2)

@@ -3,13 +3,15 @@ We use a modified fork of [huggingface transformers](https://github.com/huggingf
 ## Setup
 
 ```bash
-$ git clone https://github.com/abhik1505040/crossum
+$ git clone https://github.com/csebuetnlp/crossum
 $ cd crossum/seq2seq
 $ conda create python==3.7.9 pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch -p ./env
 $ conda activate ./env # or source activate ./env (for older versions of anaconda)
 $ bash setup.sh
 ```
 
+- **Note**: For newer NVIDIA GPUs such as ***A100*** or ***3090***, use `cudatoolkit=11.1`.
+
 ## Downloading data
 
 This script downloads the metadata-stripped version of the dataset required for training.

@@ -30,7 +32,7 @@ Some sample commands for training on an 8 GPU node are given below.
 For multi-node usage with SLURM, refer to [job.sh]().
 
 ```bash
-bash trainer.sh --ngpus 8 --training_type m2m # trains the many-to-many model
+bash trainer.sh --ngpus 8 --training_type m2m --sampling multistage # trains the many-to-many model with multistage sampling
 bash trainer.sh --ngpus 8 --training_type m2o --pivot_lang arabic # trains the many-to-one model using arabic as the target language
 bash trainer.sh --ngpus 8 --training_type o2m --pivot_lang english # trains the one-to-many model using english as the source language
 ```

seq2seq/download_data.sh (+4 -4)

@@ -3,9 +3,9 @@
 FILE="dataset.tar.bz2"
 
 if [[ ! -d "dataset" ]]; then
-    id="1ywYJEEaFnXIWW5xBwp0cNuPinDwQjCxe"
-    curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${id}" > /dev/null
-    curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${id}" -o ${FILE}
-    rm ./cookie
+    id="1bwURjAyQT6OkGRd_f9mwkWg9FABa_c6S"
+    cert="https://docs.google.com/uc?export=download&id=${id}"
+    wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate ${cert} -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${id}" -O ${FILE}
+    rm -rf /tmp/cookies.txt
     tar -xvf ${FILE} && rm ${FILE}
 fi
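The rewritten download step points at a new Drive file id and swaps the curl cookie dance for wget, but both implement the same protocol: for large files, Google Drive first returns a virus-scan interstitial carrying a one-time `confirm` token, and a second request echoing that token together with the session cookies yields the actual file. A hedged Python equivalent, for illustration only (not part of the repository):

```python
# Illustration only: the Google Drive confirm-token flow from the script
# above, expressed with requests. File id and destination are placeholders.
import re
import requests

def drive_download(file_id, dest_path):
    url = "https://docs.google.com/uc"
    with requests.Session() as session:
        # First request: may return the virus-scan page with a confirm token.
        page = session.get(url, params={"export": "download", "id": file_id})
        match = re.search(r"confirm=([0-9A-Za-z_]+)", page.text)
        params = {"export": "download", "id": file_id}
        if match:
            params["confirm"] = match.group(1)
        # Second request: echo the token (cookies ride along in the session).
        response = session.get(url, params=params, stream=True)
        with open(dest_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
```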

seq2seq/evaluation_runner.sh (+29, new file)

@@ -0,0 +1,29 @@
+#!/bin/bash
+
+ROOT_DATASET_DIR="dataset"
+ROOT_MODEL_DIR="output"
+RESULTS_DIR="evaluation_results"
+
+for model_dir in $ROOT_MODEL_DIR/*/; do
+
+    suffix=$(basename $model_dir)
+    read training_type pivot_lang rest <<< $(IFS="_"; echo $suffix)
+
+    if [[ "$training_type" = "m2o" ]]; then
+        required_str="--required_tgt_lang ${pivot_lang}"
+    elif [[ "$training_type" = "o2m" ]]; then
+        required_str="--required_src_lang ${pivot_lang}"
+    else
+        required_str=" "
+    fi
+
+    for data_type in "val" "test"; do
+        python evaluator.py \
+            --dataset_dir "${ROOT_DATASET_DIR}" \
+            --output_dir "${RESULTS_DIR}/${suffix}" \
+            --evaluation_type xlingual \
+            --data_type ${data_type} \
+            --xlingual_summarization_model_name_or_path $model_dir \
+            $required_str
+    done
+done
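The runner derives its evaluation restrictions from the checkpoint directory name, which trainer.sh builds as `${TRAINING_TYPE}_${PIVOT_LANG}_${SUFFIX}` (or, after this commit, `m2m_${SAMPLING}_${SUFFIX}`). For m2m models the second field is a sampling type rather than a language, which the `else` branch above simply ignores. A small Python restatement of the same parse (the `_mt5` suffix is a made-up example; the real `SUFFIX` comes from trainer.sh):

```python
# Illustrative restatement of the shell parse in evaluation_runner.sh.
def required_flag(basename):
    training_type, second_field, *_ = basename.split("_")
    return {
        "m2o": f"--required_tgt_lang {second_field}",
        "o2m": f"--required_src_lang {second_field}",
    }.get(training_type, "")

assert required_flag("m2o_arabic_mt5") == "--required_tgt_lang arabic"
assert required_flag("m2m_multistage_mt5") == ""  # no restriction for m2m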

seq2seq/evaluator.py (+37 -19)

@@ -231,7 +231,7 @@ def summarize_xlingual(
     tgt_lang,
     args
 ):
-    if os.path.isfile(os.path.join(output_dir, "test_generations.txt")):
+    if os.path.isfile(os.path.join(output_dir, f"{args.data_type}_generations.txt")):
         return
 
     script_path = os.path.abspath("pipeline.py")

@@ -246,9 +246,10 @@ def summarize_xlingual(
         f"--no_repeat_ngram_size {args.no_repeat_ngram_size}",
         f"--eval_beams {args.beam_size}",
         f"--tgt_lang {tgt_lang}",
+        f"--rouge_lang {tgt_lang}",
         "--overwrite_output_dir",
         "--predict_with_generate",
-        "--do_predict" if args.data_type == "test" else "--do_eval",
+        "--do_predict",
         "--use_langid",
         "--seed 1234"
     ]

@@ -272,8 +273,8 @@ def calculate_lase(
 
 
 def run(args):
-    root_output_dir = os.path.join(args.output_dir, args.data_type, args.evaluation_type, "outputs")
-    root_log_dir = os.path.join(args.output_dir, args.data_type, args.evaluation_type, "logs")
+    root_output_dir = os.path.join(args.output_dir, args.data_type, "outputs")
+    root_log_dir = os.path.join(args.output_dir, args.data_type, "logs")
 
     os.makedirs(root_output_dir, exist_ok=True)
     os.makedirs(root_log_dir, exist_ok=True)

@@ -346,21 +347,38 @@ def evaluate(lase_key):
                 pipeline_target_path
             )
 
+            # specially handle validation files, since the
+            # output file is generated for test files only
+            if args.data_type == "val":
+                shutil.copy(
+                    pipeline_source_path,
+                    os.path.join(dir_name, "test.source")
+                )
+                shutil.copy(
+                    pipeline_source_path,
+                    os.path.join(dir_name, "test.target")
+                )
+
         if args.evaluation_type == "xlingual":
             summarize_xlingual(dir_name, dir_name, tgt_lang, args)
+
+            if args.data_type == "val":
+                shutil.move(
+                    os.path.join(dir_name, f"test_generations.txt"),
+                    os.path.join(dir_name, f"val_generations.txt")
+                )
+
+                os.remove(os.path.join(dir_name, "test.source"))
+                os.remove(os.path.join(dir_name, "test.target"))
+
             pred_lines = read_lines(
-                os.path.join(dir_name, "test_generations.txt")
+                os.path.join(dir_name, f"{args.data_type}_generations.txt")
             )
             ref_lines = read_lines(pipeline_target_path)
 
-            if lase_key == "LaSE_in_lang":
-                scores.update(
-                    calculate_rouge(pred_lines, ref_lines, rouge_lang=tgt_lang)
-                )
 
-            lase_scores = calculate_lase(pred_lines, ref_lines, tgt_lang)
-            scores[lase_key] = lase_scores["LaSE"]
-
+
         elif args.evaluation_type == "baseline":
             src_iso, tgt_iso = LANG2ISO.get(src_lang, None), LANG2ISO.get(tgt_lang, None)
             if (

@@ -386,13 +404,13 @@ def evaluate(lase_key):
             pred_lines = read_lines(translated_path)
             ref_lines = read_lines(pipeline_target_path)
 
-            if lase_key == "LaSE_in_lang":
-                scores.update(
-                    calculate_rouge(pred_lines, ref_lines, rouge_lang=tgt_lang)
-                )
+        if lase_key == "LaSE_in_lang":
+            scores.update(
+                calculate_rouge(pred_lines, ref_lines, rouge_lang=tgt_lang)
+            )
 
-            lase_scores = calculate_lase(pred_lines, ref_lines, tgt_lang)
-            scores[lase_key] = lase_scores["LaSE"]
+        lase_scores = calculate_lase(pred_lines, ref_lines, tgt_lang)
+        scores[lase_key] = lase_scores["LaSE"]

@@ -411,7 +429,7 @@ def evaluate(lase_key):
     gc.collect()
 
     # aggregate results
-    combined_results_path = os.path.join(args.output_dir, args.data_type, args.evaluation_type, "combined_results.log")
+    combined_results_path = os.path.join(args.output_dir, args.data_type, "combined_results.log")
     logging.info("Writing the combined results to " + combined_results_path)
 
     with open(combined_results_path, 'w') as outf:
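The main behavioral change: pipeline.py is now always invoked with `--do_predict`, which only writes `test_generations.txt`, so evaluating the validation split means masquerading it as a test split and renaming the output afterwards. Note that both `test.source` and `test.target` are copied from `pipeline_source_path`; the target copy appears to be a placeholder, since prediction does not read references. A condensed, illustrative restatement of that dance (helper names are hypothetical, not repo code):

```python
# Illustrative sketch of the val-split workaround above; run_pipeline
# stands in for the subprocess call to pipeline.py with --do_predict.
import os
import shutil

def generate_for_split(dir_name, source_path, data_type, run_pipeline):
    if data_type == "val":
        # pipeline.py only consumes/produces test.* files under --do_predict,
        # so present the val data as a test split (target is a placeholder).
        shutil.copy(source_path, os.path.join(dir_name, "test.source"))
        shutil.copy(source_path, os.path.join(dir_name, "test.target"))
    run_pipeline(dir_name)  # writes test_generations.txt
    if data_type == "val":
        shutil.move(os.path.join(dir_name, "test_generations.txt"),
                    os.path.join(dir_name, "val_generations.txt"))
        os.remove(os.path.join(dir_name, "test.source"))
        os.remove(os.path.join(dir_name, "test.target"))
    return os.path.join(dir_name, f"{data_type}_generations.txt")
```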

seq2seq/pipeline.py (+6 -2)

@@ -602,7 +602,11 @@ def main():
         logger.info("*** Evaluate ***")
 
         metrics = trainer.evaluate(
-            metric_key_prefix="val", max_length=data_args.val_max_target_length, num_beams=data_args.eval_beams
+            metric_key_prefix="val",
+            max_length=data_args.val_max_target_length,
+            num_beams=data_args.eval_beams,
+            length_penalty=data_args.length_penalty,
+            no_repeat_ngram_size=data_args.no_repeat_ngram_size,
         )
         metrics["val_n_objs"] = data_args.n_val
         metrics["val_loss"] = round(metrics["val_loss"], 4)

@@ -618,7 +622,7 @@ def main():
         test_output = trainer.predict(
             test_dataset=test_dataset,
             metric_key_prefix="test",
-            max_length=data_args.val_max_target_length,
+            max_length=data_args.test_max_target_length,
             num_beams=data_args.eval_beams,
             length_penalty=data_args.length_penalty,
             no_repeat_ngram_size=data_args.no_repeat_ngram_size,
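Two fixes here: `trainer.evaluate` now receives the same `length_penalty` and `no_repeat_ngram_size` that `trainer.predict` already used, so validation and test decoding match; and the test run now honors `test_max_target_length` instead of silently reusing the validation length. One way to keep the two call sites from drifting again, sketched under the assumption that the fork's `Seq2SeqTrainer` forwards extra keywords as generation kwargs (the helper is hypothetical, not repo code):

```python
# Hypothetical helper: one source of truth for the generation settings
# shared by trainer.evaluate() and trainer.predict() in the diff above.
def shared_generation_kwargs(data_args):
    return {
        "num_beams": data_args.eval_beams,
        "length_penalty": data_args.length_penalty,
        "no_repeat_ngram_size": data_args.no_repeat_ngram_size,
    }
```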

seq2seq/trainer.sh (+11 -2)

@@ -10,6 +10,8 @@ parser.add_argument('--training_type', type=str, choices=["m2m", "m2o", "o2m"],
                     required=True, help='Training type (many-to-many/many-to-one/one-to-many)')
 parser.add_argument('--pivot_lang', type=str, default="english",
                     help='Pivot language (Applicable for many-to-one and one-to-many)')
+parser.add_argument('--sampling', type=str, default="multistage", choices=["multistage", "unistage"],
+                    help='Sampling type (Applicable for many-to-many)')
 parser.add_argument('--exclude_native', action='store_true',
                     default=False, help='Exclude the native-to-native filepairs during training')
 EOF

@@ -22,10 +24,17 @@ export ROOT_OUTPUT_DIR="${BASE_DIR}/output"
 
 export PREFIX="${TRAINING_TYPE}_${PIVOT_LANG}"
 if [[ "$TRAINING_TYPE" = "m2m" ]]; then
-    PREFIX="${TRAINING_TYPE}"
+    PREFIX="${TRAINING_TYPE}_${SAMPLING}"
     OPTIONAL_ARGS=(
         "--multistage_upsampling_factors 0.5 0.75"
     )
+
+    if [[ "$SAMPLING" = "unistage" ]]; then
+        OPTIONAL_ARGS=(
+            "--upsampling_factor 0.25"
+        )
+    fi
+
 else
     OPTIONAL_ARGS=(
         "--upsampling_factor 0.75"

@@ -40,7 +49,7 @@ fi
 export BASENAME="${PREFIX}_${SUFFIX}"
 export INPUT_DIR="${ROOT_INPUT_DIR}/${BASENAME}"
 export OUTPUT_DIR="${ROOT_OUTPUT_DIR}/${BASENAME}"
-export MIN_EXAMPLE_COUNT=32
+export MIN_EXAMPLE_COUNT=30
 
 conda activate "${BASE_DIR}/env" || source activate "${BASE_DIR}/env"

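The new `--sampling` switch selects between the existing two-stage upsampling (`--multistage_upsampling_factors 0.5 0.75`) and a single-stage variant (`--upsampling_factor 0.25`), and the sampling type is now baked into the m2m output directory name. The flags suggest exponent-smoothed sampling, where a language pair with n examples is drawn with probability proportional to n^α; a hedged sketch of that idea, assumed from the flag names rather than read from the training code:

```python
import random

# Assumed semantics of --upsampling_factor alpha: draw a language pair
# with probability proportional to count ** alpha, so alpha < 1 boosts
# low-resource pairs. The two multistage factors presumably apply the
# same smoothing at two successive stages (e.g. target, then source).
def sampling_distribution(pair_counts, alpha):
    weights = {pair: count ** alpha for pair, count in pair_counts.items()}
    total = sum(weights.values())
    return {pair: w / total for pair, w in weights.items()}

counts = {"english-hindi": 10000, "english-nepali": 500}
dist = sampling_distribution(counts, alpha=0.25)
pair = random.choices(list(dist), weights=list(dist.values()))[0]
```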
seq2seq/training_runner.sh (+10, new file)

@@ -0,0 +1,10 @@
+#!/bin/bash
+
+bash trainer.sh --ngpus 8 --training_type m2o --pivot_lang english
+bash trainer.sh --ngpus 8 --training_type o2m --pivot_lang english
+bash trainer.sh --ngpus 8 --training_type m2o --pivot_lang hindi
+bash trainer.sh --ngpus 8 --training_type o2m --pivot_lang hindi
+bash trainer.sh --ngpus 8 --training_type m2o --pivot_lang russian
+bash trainer.sh --ngpus 8 --training_type o2m --pivot_lang russian
+bash trainer.sh --ngpus 8 --training_type m2o --pivot_lang arabic
+bash trainer.sh --ngpus 8 --training_type o2m --pivot_lang arabic
