-
Notifications
You must be signed in to change notification settings - Fork 1
/
prog.py
63 lines (47 loc) · 1.4 KB
/
prog.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch as th
import pytorch_scripts as ps
import timeit, os.path
# ---------------------------------------------------------------------------
# Run configuration (read as globals by train()).
# ---------------------------------------------------------------------------

# Use the first CUDA device when one is available, otherwise the CPU.
use_cuda = th.cuda.is_available()
device = th.device('cuda:0') if use_cuda else th.device('cpu')

# Target word: passed to ps.get_target and also names the cached TSV file.
word = 'fruit'
targets = ps.get_target(word)
path = 'data/%s.tsv' % word

# Training hyperparameters.
max_epochs: int = 100
num_workers: int = 6
dim: int = 2            # passed to PoincareModule (presumably embedding dimension)
batch_size: int = 5
eps: float = 1e-5       # passed to PoincareModule and ps.plot_graph
neg: int = 10           # passed to PoincareDataset (presumably negatives per sample)
scale: float = 0.01     # passed to PoincareModule (presumably init scale)
lr: float = 0.001       # learning rate handed to PoincareModule

# Keyword arguments for th.utils.data.DataLoader.
params = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
def train():
    """Train Poincaré embeddings on the data at ``path``, print and plot them.

    Reads the module-level configuration (``path``, ``targets``, ``neg``,
    ``dim``, ``scale``, ``lr``, ``eps``, ``max_epochs``, ``params``,
    ``device``). If the TSV file at ``path`` does not exist yet, it is first
    produced via ``ps.downloadNLTK()`` and ``ps.generate_synsets``.
    """
    # Generate the dataset file on the first run only.
    if not os.path.isfile(path):
        ps.downloadNLTK()
        ps.generate_synsets(targets, path)

    ids, objects, relations = ps.get_data(path)
    print('data: objects={0}, edges={1}'.format(len(objects), len(ids)))

    data = ps.PoincareDataset(ids, objects, relations, neg)
    model = ps.PoincareModule(len(objects), dim, scale, lr, eps)
    # Batches are moved to ``device`` below, so the parameters must live there
    # too; otherwise CUDA inputs would meet CPU weights in the forward pass.
    # NOTE(review): assumes PoincareModule is a torch.nn.Module — confirm.
    model = model.to(device)
    loader = th.utils.data.DataLoader(data, **params)

    for _ in range(max_epochs):
        epoch_loss = []
        # The loop variable is named ``labels`` rather than ``targets``.
        # Rebinding ``targets`` anywhere in this function would make it local
        # to the whole body and break the ``generate_synsets(targets, ...)``
        # call above with UnboundLocalError.
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs)
            loss = model.loss(preds, labels)
            loss.backward()
            # NOTE(review): gradients are not zeroed here; presumably
            # model.optimize() applies the update and clears them — confirm.
            model.optimize()
            epoch_loss.append(loss.item())
        print('epoch loss: {0}'.format(th.tensor(epoch_loss).mean()))

    # Bring parameters back to the CPU for printing/plotting (no-op on CPU
    # runs); plot_graph presumably converts the weights to host arrays.
    model = model.to('cpu')
    print('\nobjects:\n{0}'.format(objects))
    print('\nids:\n{0}'.format(ids))
    print('\nembedding vectors:\n{0}'.format(model.embeds.weight))
    ps.plot_graph(objects, model.embeds, eps)
def _main():
    """Script entry point: enable frozen-executable support, then train."""
    # freeze_support() must run before any worker processes (e.g. the
    # DataLoader's num_workers) are spawned when the script is frozen.
    th.multiprocessing.freeze_support()
    train()


if __name__ == '__main__':
    _main()