# load_balancer.py
import heapq
import random
import threading
from contextlib import contextmanager
from typing import Dict, List, Tuple

from hivemind import RemoteExpert, TimedStorage, PeerID
from hivemind.dht import DHT
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertPrefix, ExpertUID, ExpertInfo
from hivemind.utils.performance_ema import PerformanceEMA
from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger

logger = get_logger(__name__)

class LoadBalancer:
    """Balances requests across remote experts announced in the DHT, preferring the lowest expected load."""

    def __init__(self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0,
                 **kwargs):
        self.dht, self.key = dht, key
        self.initial_throughput, self.ema_kwargs = initial_throughput, kwargs
        self.experts = TimedStorage[ExpertUID, PeerID]()
        self.blacklist = TimedStorage[ExpertUID, type(None)]()
        self.throughputs: Dict[ExpertUID, PerformanceEMA] = {}
        self.queue: List[Tuple[float, float, ExpertUID]] = []
        self.uid_to_queue: Dict[ExpertUID, Tuple[float, float, ExpertUID]] = {}
        self.lock = threading.Lock()
        self.is_alive = threading.Event()
        self.is_alive.set()
        self.update_trigger, self.update_finished = threading.Event(), threading.Event()
        self.update_period, self.last_update = update_period, get_dht_time()
        self.update_thread = threading.Thread(target=self.update_experts_in_background, daemon=True)
        self.update_thread.start()
        self._p2p = RemoteExpertWorker.run_coroutine(self.dht.replicate_p2p())

    def update_experts_in_background(self):
        while self.is_alive.is_set():
            time_to_next_update = max(0.0, self.last_update + self.update_period - get_dht_time())
            # Event.wait returns True if the main thread triggered an update and False if update_period
            # elapsed (it does not raise TimeoutError); in either case we refresh the expert list below.
            self.update_trigger.wait(timeout=time_to_next_update)
            self.update_trigger.clear()

            response = self.dht.get(self.key, latest=True)
            if isinstance(response, ValueWithExpiration) and isinstance(response.value, dict):
                for index, expert_info in response.value.items():
                    try:
                        (expert_uid, peer_id), expiration_time = expert_info
                        maybe_banned = self.blacklist.get(expert_uid)
                        if maybe_banned is None or expiration_time > maybe_banned.expiration_time:
                            self._add_expert(expert_uid, peer_id, expiration_time)
                        else:
                            logger.debug(f"Not adding expert {expert_uid} (blacklisted).")
                    except Exception as e:
                        logger.warning(f"Skipping malformed expert info {expert_info} (exc={e})")
            else:
                logger.warning(f"Could not refresh experts, dht info key contains {response}, "
                               f"will retry in {time_to_next_update}s")

            if len(self.queue) == 0:
                logger.warning("Update routine finished, but still no experts available.")

            self.last_update = get_dht_time()
            self.update_finished.set()

    def _trigger_updating_experts(self):
        self.update_finished.clear()
        self.update_trigger.set()
        self.update_finished.wait()

    @property
    def n_active_experts(self) -> int:
        if len(self.uid_to_queue) == 0:
            # Maybe it did not do the first update yet
            self._trigger_updating_experts()
        return len(self.uid_to_queue)

    def _add_expert(self, uid: ExpertUID, peer_id: PeerID, expiration_time: DHTExpiration):
        with self.lock:
            self.experts.store(uid, peer_id, expiration_time)
            if uid not in self.uid_to_queue:
                logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
                # ema_kwargs is a dict of keyword arguments, so it must be unpacked with ** (not *)
                self.throughputs[uid] = PerformanceEMA(**self.ema_kwargs, paused=True)
                # start the newcomer at the current minimum load so it gets tried soon
                base_load = self.queue[0][0] if len(self.queue) > 0 else 0.0
                heap_entry = (base_load, random.random(), uid)
                heapq.heappush(self.queue, heap_entry)
                self.uid_to_queue[uid] = heap_entry
            else:
                logger.debug(f"Refreshing existing module: {uid}, new expiration time = {expiration_time:.3f}.")

    def _ban_expert(self, uid: ExpertUID):
        with self.lock:
            maybe_expert = self.experts.get(uid)
            expiration_time = maybe_expert.expiration_time if maybe_expert else get_dht_time()
            self.blacklist.store(uid, None, expiration_time)
            self.uid_to_queue.pop(uid, None)
            self.throughputs.pop(uid, None)
            del self.experts[uid]
            logger.debug(f"Banned expert {uid} with expiration time = {expiration_time:.2f}.")

    @contextmanager
    def use_another_expert(self, task_size: float, max_tries: int = 3) -> RemoteExpert:
        """Pop the expert with the lowest expected load, yield it, and record its measured throughput on exit."""
        n_tries = 0
        while True:
            if len(self.queue) == 0:
                if n_tries == max_tries:
                    raise NoModulesFound('No modules found in the network')

                n_tries += 1
                self._trigger_updating_experts()
                continue

            with self.lock:
                current_runtime, _, uid = heap_entry = heapq.heappop(self.queue)
                maybe_peer_id = self.experts.get(uid)
                if maybe_peer_id is None:
                    # remove expired expert from queue
                    self.uid_to_queue.pop(uid, None)
                    self.throughputs.pop(uid, None)
                if self.uid_to_queue.get(uid) != heap_entry:
                    continue  # skip uids that are banned or expired

                if self.throughputs[uid].num_updates != 0:
                    expected_time_taken = task_size / self.throughputs[uid].samples_per_second
                else:
                    expected_time_taken = self.initial_throughput * task_size
                new_heap_entry = (current_runtime + expected_time_taken, random.random(), uid)
                heapq.heappush(self.queue, new_heap_entry)
                self.uid_to_queue[uid] = new_heap_entry
                break
        try:
            with self.throughputs[uid].update_threadsafe(task_size):
                logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
                yield RemoteExpert(ExpertInfo(uid, PeerID.from_base58(maybe_peer_id.value)), self._p2p)
        except BaseException:
            self._ban_expert(uid)
            raise

    def shutdown(self):
        self.is_alive.clear()
        self._trigger_updating_experts()


class NoModulesFound(RuntimeError):
    pass
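

# Usage sketch (illustrative; not part of the original file). It shows how a client might drive the
# balancer. The initial_peers multiaddr, the "expert." prefix key, and the input tensor shape are
# placeholders/assumptions: they presume a reachable hivemind DHT where servers announce
# (expert_uid, peer_id) records under the same key.
if __name__ == "__main__":
    import torch

    dht = DHT(initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/QmPlaceholderPeerID"], start=True)
    balancer = LoadBalancer(dht, key="expert.", update_period=30.0)
    try:
        # task_size is any unit proportional to the per-request work (e.g. batch size); the context
        # manager picks the least-loaded expert and updates its throughput EMA when the block exits
        with balancer.use_another_expert(task_size=4.0) as expert:
            outputs = expert(torch.randn(4, 1024))  # input schema depends on the remote expert
    finally:
        balancer.shutdown()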