dump.py
import os
import sys
from glob import glob
import torch
import torch.nn as nn
from fire import Fire
from tqdm import tqdm


class Main:
    def load_checkpoint(self, filepath):
        # The aligner module lives under src/, so make it importable first.
        if "src" not in sys.path:
            sys.path.append("src")
        if "aligner" not in sys.modules:
            import aligner
        else:
            aligner = sys.modules["aligner"]
        return aligner.Aligner.load_from_checkpoint(filepath)

    def single(self, ckpt, dumpdir):
        # Resolve the checkpoint glob; it must match exactly one file.
        ckpt = glob(ckpt, recursive=True)
        assert len(ckpt) == 1, ckpt
        aligner = self.load_checkpoint(ckpt[0])
        # Export the model and tokenizer in HuggingFace save_pretrained format.
        os.makedirs(dumpdir, exist_ok=True)
        aligner.model.save_pretrained(dumpdir)
        aligner.tokenizer.save_pretrained(dumpdir)

    def linear(
        self,
        root_dir="/bigdata",
        data="opus",
        model="bert-base-multilingual-cased",
        name="linear-orth0.01",
    ):
        langs = "ar de es fr hi ru vi zh".split()
        _dir = f"{root_dir}/checkpoints/alignment/{data}"
        # Collect the per-language linear mapping modules into one ModuleDict
        # and save the bundle under mapping/<name>/<model>.pth.
        mapping = nn.ModuleDict()
        for lang in tqdm(langs):
            ckpt = f"{_dir}/en-{lang}-subset1/{model}-sim_linear/{name}/version_0/mapping.pth"
            mapping[lang] = torch.load(ckpt)
        os.makedirs(f"mapping/{name}", exist_ok=True)
        torch.save(mapping, f"mapping/{name}/{model}.pth")


if __name__ == "__main__":
    Fire(Main)
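
# Usage sketch: Fire exposes the `single` and `linear` methods as subcommands.
# The checkpoint paths below are illustrative and depend on your setup.
#
#   # Dump one fine-tuned aligner checkpoint to a HuggingFace-style directory
#   python dump.py single --ckpt "/bigdata/checkpoints/alignment/**/last.ckpt" --dumpdir dumped/
#
#   # Bundle the per-language linear mappings into mapping/linear-orth0.01/<model>.pth
#   python dump.py linear --root_dir /bigdata --data opus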