#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
import pickle
import numpy as np
import pandas as pd
import torch

from utils.options import args_parser
from utils.train_utils import get_data, get_model
from models.Update import LocalUpdate
from models.test import test_img

import os
import pdb
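
# Federated averaging (FedAvg) driver: in each communication round a random
# fraction of clients trains the current global model locally, and the server
# replaces the global weights with the average of the returned local weights.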
if __name__ == '__main__':
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    # output directory for checkpoints and results
    base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/{}/'.format(
        args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save)
    if not os.path.exists(os.path.join(base_dir, 'fed')):
        os.makedirs(os.path.join(base_dir, 'fed'), exist_ok=True)

    # load the dataset and the per-client sample partition, and save the partition for later reuse
    dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
    dict_save_path = os.path.join(base_dir, 'dict_users.pkl')
    with open(dict_save_path, 'wb') as handle:
        pickle.dump((dict_users_train, dict_users_test), handle)

    # build model
    net_glob = get_model(args)
    net_glob.train()

    # training
    results_save_path = os.path.join(base_dir, 'fed/results.csv')

    loss_train = []
    net_best = None
    best_loss = None
    best_acc = None
    best_epoch = None

    lr = args.lr
    results = []
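
    # main federated training loop: one iteration per communication round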
    for iter in range(args.epochs):
        w_glob = None
        loss_locals = []

        # sample a random fraction of clients for this round
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        print("Round {}, lr: {:.6f}, {}".format(iter, lr, idxs_users))

        for idx in idxs_users:
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users_train[idx])
            net_local = copy.deepcopy(net_glob)

            w_local, loss = local.train(net=net_local.to(args.device))
            loss_locals.append(copy.deepcopy(loss))

            # accumulate local weights; the running sum is averaged after the client loop
            if w_glob is None:
                w_glob = copy.deepcopy(w_local)
            else:
                for k in w_glob.keys():
                    w_glob[k] += w_local[k]

        lr *= args.lr_decay

        # update global weights: average the accumulated sums over the m participating clients
        for k in w_glob.keys():
            w_glob[k] = torch.div(w_glob[k], m)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        loss_train.append(loss_avg)
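
        # evaluate the global model every test_freq rounds and keep the best model seen so far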
        if (iter + 1) % args.test_freq == 0:
            net_glob.eval()
            acc_test, loss_test = test_img(net_glob, dataset_test, args)

            print('Round {:3d}, Average loss {:.3f}, Test loss {:.3f}, Test accuracy: {:.2f}'.format(
                iter, loss_avg, loss_test, acc_test))

            if best_acc is None or acc_test > best_acc:
                net_best = copy.deepcopy(net_glob)
                best_acc = acc_test
                best_epoch = iter

            # if (iter + 1) > args.start_saving:
            #     model_save_path = os.path.join(base_dir, 'fed/model_{}.pt'.format(iter + 1))
            #     torch.save(net_glob.state_dict(), model_save_path)

            results.append(np.array([iter, loss_avg, loss_test, acc_test, best_acc]))
            final_results = np.array(results)
            final_results = pd.DataFrame(final_results, columns=['epoch', 'loss_avg', 'loss_test', 'acc_test', 'best_acc'])
            final_results.to_csv(results_save_path, index=False)

        # checkpoint the best and the current global model every 50 rounds
        if (iter + 1) % 50 == 0:
            best_save_path = os.path.join(base_dir, 'fed/best_{}.pt'.format(iter + 1))
            model_save_path = os.path.join(base_dir, 'fed/model_{}.pt'.format(iter + 1))
            torch.save(net_best.state_dict(), best_save_path)
            torch.save(net_glob.state_dict(), model_save_path)

    print('Best model, iter: {}, acc: {}'.format(best_epoch, best_acc))
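
# Example invocation (a sketch only: the authoritative flag names are defined in
# utils/options.py via args_parser(); the names below are inferred from the args.*
# attributes used above and may differ from the actual definitions):
#   python main_fed.py --dataset cifar10 --model cnn --epochs 200 --num_users 100 \
#       --frac 0.1 --local_ep 1 --shard_per_user 2 --results_save run1 --gpu 0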