Commit 52cd683

Add centralized.py
1 parent d08faef commit 52cd683

File tree

1 file changed: +133 -0 lines changed

centralized.py

@@ -0,0 +1,133 @@
import time
import numpy as np

from data_reader.data_reader import get_data
from models.get_model import get_model
from statistic.collect_stat import CollectStatistics
from util.sampling import MinibatchSampling

# Configurations are in a separate config.py file
from config import *

if use_min_loss:
    raise Exception('use_min_loss should be disabled in the centralized case.')

model = get_model(model_name)
if hasattr(model, 'create_graph'):
    model.create_graph(learning_rate=step_size)

if time_gen is not None:
    use_fixed_averaging_slots = True
else:
    use_fixed_averaging_slots = False

if single_run:
    stat = CollectStatistics(results_file_name=single_run_results_file_path, is_single_run=True)
else:
    stat = CollectStatistics(results_file_name=multi_run_results_file_path, is_single_run=False)

for sim in sim_runs:

    if batch_size < total_data:  # Read all data once when using stochastic gradient descent
        train_image, train_label, test_image, test_label, train_label_orig = get_data(dataset, total_data,
                                                                                       dataset_file_path)
        sampler = MinibatchSampling(np.array(range(0, len(train_label))), batch_size, sim)
    else:
        sampler = None

    if batch_size >= total_data:  # Read data again for each sim. round when using deterministic gradient descent
        train_image, train_label, test_image, test_label, train_label_orig = get_data(dataset, total_data,
                                                                                       dataset_file_path, sim_round=sim)
        train_indices = np.array(range(0, len(train_label)))

    stat.init_stat_new_global_round()

    dim_w = model.get_weight_dimension(train_image, train_label)
    w_init = model.get_init_weight(dim_w, rand_seed=sim)
    w = w_init

    w_min_loss = None
    loss_min = np.inf

    print('Start learning')

    total_time = 0  # Actual total time, where use_fixed_averaging_slots has no effect
    total_time_recomputed = 0  # Recomputed total time using the estimated time for each local and global update,
                               # using predefined values when use_fixed_averaging_slots = True
    it_each_local = None

    # Loop for multiple rounds of local iterations + global aggregation
    while True:
        time_total_all_start = time.time()
        w_prev = w

        if batch_size < total_data:
            train_indices = sampler.get_next_batch()

        grad = model.gradient(train_image, train_label, w, train_indices)

        w = w - step_size * grad

        if True in np.isnan(w):
            print('*** w_global is NaN, using previous value')
            w = w_prev  # If the current w_global contains NaN values, use the previous w_global

            if use_min_loss:
                loss_latest = model.loss(train_image, train_label, w, train_indices)
                print('*** Loss computed from data')
        else:
            if use_min_loss:
                try:
                    # Note: This has to follow the gradient computation line above
                    loss_latest = model.loss_from_prev_gradient_computation()
                    print('*** Loss computed from previous gradient computation')
                except:
                    # Will get an exception if the model does not support computing loss
                    # from the previous gradient computation
                    loss_latest = model.loss(train_image, train_label, w, train_indices)
                    print('*** Loss computed from data')

        if use_min_loss:
            if (batch_size < total_data) and (w_min_loss is not None):
                # Recompute loss_min on w_min_loss so that the batch remains the same
                loss_min = model.loss(train_image, train_label, w_min_loss, train_indices)

            if loss_latest < loss_min:
                loss_min = loss_latest
                w_min_loss = w

            print("Loss of latest weight value: " + str(loss_latest))
            print("Minimum loss: " + str(loss_min))

        # Calculate time
        time_total_all_end = time.time()
        time_total_all = time_total_all_end - time_total_all_start
        time_one_iteration_all = max(0.0, time_total_all)

        print('Time for one local iteration:', time_one_iteration_all)

        if use_fixed_averaging_slots:
            it_each_local = max(0.00000001, time_gen.get_local(1)[0])
        else:
            it_each_local = max(0.00000001, time_one_iteration_all)

        # Accumulate the estimated time for the current slot
        total_time_recomputed += it_each_local

        # Accumulate the actual time for the current slot
        total_time += time_total_all

        stat.collect_stat_end_local_round(None, np.nan, it_each_local, np.nan, None, model, train_image, train_label,
                                          test_image, test_label, w, total_time_recomputed)

        # Check the remaining resource budget, stop if the resource budget is exceeded
        if total_time_recomputed >= max_time:
            break

    if use_min_loss:
        w_eval = w_min_loss
    else:
        w_eval = w

    stat.collect_stat_end_global_round(sim, None, np.nan, total_time, model, train_image, train_label,
                                       test_image, test_label, w_eval, total_time_recomputed)
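
The script reads all of its settings via from config import *, and config.py is not part of this commit. The sketch below is a hypothetical minimal configuration, listing only the names this script actually references; every value shown is a placeholder chosen for illustration, not a value taken from the repository.

    # config.py -- hypothetical minimal configuration for centralized.py.
    # Only names referenced by the script are listed; all values are placeholders.

    dataset = 'MNIST'                   # dataset identifier passed to get_data()
    dataset_file_path = 'datasets/'     # location of the raw dataset files
    model_name = 'svm_smooth'           # model identifier passed to get_model()

    total_data = 60000                  # number of training samples to load
    batch_size = 100                    # < total_data: minibatch SGD; >= total_data: deterministic (full-batch) GD
    step_size = 0.01                    # learning rate

    sim_runs = range(0, 3)              # simulation runs to iterate over
    max_time = 15                       # resource (time) budget per run; the while loop stops once it is reached

    use_min_loss = False                # must be False here, otherwise the script raises an exception
    time_gen = None                     # optional generator of simulated per-iteration times (enables use_fixed_averaging_slots)
    single_run = True
    single_run_results_file_path = 'results/single_run.csv'
    multi_run_results_file_path = 'results/multi_run.csv'

With a module like this in place, the script can be launched directly (e.g. python centralized.py) and writes its statistics through CollectStatistics to the configured results file.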
