-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathload_checkpoint_to_validate.py
72 lines (43 loc) · 1.72 KB
/
load_checkpoint_to_validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import adversarialNeuralCryptography as ad
# Checkpoint file that model_load_checkpoint() hands to torch.load().
LOAD_PATH = "./adversarial_neural_cryptography_model_and_optimizer"
def random_generate_ptext_and_key(ptext_size, key_size):
    """
    Draw one random plaintext and one random key to validate with.

    Both results are float tensors of shape (1, size) placed on ad.DEVICE,
    and every element is either -1.0 or 1.0.

    :param ptext_size: number of elements in the plaintext row
    :param key_size: number of elements in the key row
    :return: a tuple: (ptext, key)
    """
    def _random_signs(width):
        # randint(0, 2) yields 0/1; the map 2 * b - 1 turns that into -1/+1.
        bits = torch.randint(0, 2, (1, width), dtype=torch.float)
        return bits.to(ad.DEVICE) * 2 - 1

    # The plaintext is sampled before the key.
    ptext = _random_signs(ptext_size)
    key = _random_signs(key_size)
    return ptext, key
def model_load_checkpoint():
    """
    Rebuild alice, bob and eve and restore their weights from LOAD_PATH.

    The checkpoint is expected to contain the keys 'Alice_state_dict',
    'Bob_state_dict' and 'Eve_state_dict'.

    :return: a tuple: (alice, bob, eve), each moved to ad.DEVICE
    """
    # map_location remaps the stored tensors onto ad.DEVICE while loading, so
    # a checkpoint written on a GPU can still be read on a host without one.
    # NOTE(review): assumes ad.DEVICE is a torch.device or a device string,
    # which is what map_location accepts -- confirm against the ad module.
    checkpoint = torch.load(LOAD_PATH, map_location=ad.DEVICE)

    # alice and bob are built with both sizes, eve with the plaintext size only.
    alice = ad.Model(ad.PTEXT_SIZE, ad.KEY_SIZE)
    bob = ad.Model(ad.PTEXT_SIZE, ad.KEY_SIZE)
    eve = ad.Model(ad.PTEXT_SIZE)

    alice.load_state_dict(checkpoint['Alice_state_dict'])
    bob.load_state_dict(checkpoint['Bob_state_dict'])
    eve.load_state_dict(checkpoint['Eve_state_dict'])

    alice.to(ad.DEVICE)
    bob.to(ad.DEVICE)
    eve.to(ad.DEVICE)

    return alice, bob, eve
def validate():
    """
    Run one random plaintext/key pair through the restored models and print
    the plaintext next to the outputs of bob and eve.

    alice receives the plaintext concatenated with the key, bob receives
    alice's output concatenated with the key, and eve receives alice's output
    alone.

    :return: None
    """
    ptext, key = random_generate_ptext_and_key(ad.PTEXT_SIZE, ad.KEY_SIZE)
    alice, bob, eve = model_load_checkpoint()

    # Inference only: switch the models to evaluation mode so mode-dependent
    # layers (dropout, batch norm), if any, behave deterministically.
    # NOTE(review): assumes ad.Model is a torch.nn.Module (it already exposes
    # load_state_dict() and to()) -- confirm.
    for model in (alice, bob, eve):
        model.eval()

    # No gradients are needed here, so skip building the autograd graph.
    with torch.no_grad():
        ctext = alice(torch.cat((ptext, key), 1).float())
        predict_ptext_bob = bob(torch.cat((ctext, key), 1).float())
        predict_ptext_eve = eve(ctext)

    # for better print
    ptext = ptext.cpu().detach().numpy()
    predict_ptext_bob = predict_ptext_bob.cpu().detach().numpy()
    predict_ptext_eve = predict_ptext_eve.cpu().detach().numpy()

    print('Real ptext:\n{}\n\nptext bob:\n{}\n\nptext eve:\n{}'.format(ptext, predict_ptext_bob, predict_ptext_eve))
# Script entry point: run a single validation pass when executed directly.
if __name__ == "__main__":
    validate()