-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path run_distillation.py
63 lines (57 loc) · 2.18 KB
/
run_distillation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from diffusion.distillation import *
from diffusion.diffusion import *
from data.utils import *
from diffusion.eegwave import *
from torch.utils.data import random_split
import numpy as np
import wandb
import json
from pathlib import Path
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def load_checkpoint(savepath, device):
    """Rebuild a trained ``Diffusion`` model from a checkpoint file.

    Args:
        savepath: Path to a checkpoint dict previously written with
            ``torch.save``. It must contain the keys ``'epoch'``,
            ``'config'``, ``'n_class'``, ``'n_subject'``, ``'N'``, ``'n'``,
            ``'C'``, ``'E'``, ``'K'``, ``'T'`` and ``'model_state_dict'``.
        device: Device the stored tensors are mapped to while loading and
            that the rebuilt model is moved to (e.g. ``'cuda'`` or ``'cpu'``).

    Returns:
        Tuple ``(epoch, config, model)``. ``epoch`` and ``config`` are read
        unchanged from the checkpoint; ``model`` is the restored
        ``Diffusion`` instance, already on ``device``.
    """
    # map_location remaps the stored tensors onto `device` while
    # deserializing. Without it, torch.load tries to restore them onto the
    # device they were saved from, so a GPU-saved checkpoint cannot be
    # loaded on a CPU-only machine.
    checkpoint = torch.load(savepath, map_location=device)
    epoch = checkpoint['epoch']
    config = checkpoint['config']
    # Rebuild the network from the hyper-parameters stored alongside the
    # weights. They are passed positionally, so this assumes the order
    # matches EEGWave's constructor signature. TODO confirm if it changes.
    function_approximator = EEGWave(
        checkpoint['n_class'],
        checkpoint['n_subject'],
        checkpoint['N'],
        checkpoint['n'],
        checkpoint['C'],
        checkpoint['E'],
        checkpoint['K'],
    )
    model = Diffusion(function_approximator, checkpoint['T'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    return epoch, config, model
print("Initialize")
# Resolve every resource path against this script's own directory instead
# of the current working directory, so the config and the teacher
# checkpoint are found no matter where the script is launched from.
SCRIPT_DIR = Path(__file__).resolve().parent

with open(SCRIPT_DIR / "diffusion" / "distillation_conf.json", 'r', encoding="utf-8") as fconf:
    conf = json.load(fconf)
wandb.init(project="amal_diffusion", entity="amal_2223", config=conf)

# Draw a fresh seed for this run and store it in the wandb config so the run
# can be reproduced later. np.random.choice returns a NumPy integer, so cast
# it to a plain int before handing it to wandb and torch.
random_seed = int(np.random.choice(9999))
config = wandb.config
config.SEED = random_seed
torch.manual_seed(config.SEED)
np.random.seed(config.SEED)
torch.backends.cudnn.deterministic = True

print("Loading data")
# Any DATA value other than 'VEPESS' falls back to BCICIV2aDataset.
if config.DATA == 'VEPESS':
    dataset_cls, SIGNAL_LENGTH = VepessDataset, 512
else:
    dataset_cls, SIGNAL_LENGTH = BCICIV2aDataset, 448
# Build the three splits in train -> val -> test order, as before.
train_ds, val_ds, test_ds = (
    dataset_cls(config.N_SUBJECTS, True, partition=split)
    for split in ('train', 'val', 'test')
)

# The teacher is a previously trained diffusion model selected by
# config.TEACHER; it is loaded and handed to distill() together with the
# optimizer built over its parameters.
teacher_path = SCRIPT_DIR / "diffusion" / "checkpoints" / f"diffusion_{config.TEACHER}.pch"
_, _, model = load_checkpoint(teacher_path, device)
optimizer = torch.optim.Adam(model.parameters(), config.LEARNING_RATE)
wandb.watch(model, log="all")

train_dl = DataLoader(train_ds, batch_size=config.TRAIN_BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=config.EVAL_BATCH_SIZE, shuffle=False)
# NOTE(review): test_dl is built but not passed to distill() below.
test_dl = DataLoader(test_ds, batch_size=config.EVAL_BATCH_SIZE, shuffle=False)

print("Distilling")
distill(model, device, train_dl, val_dl, optimizer, config, wandb)