Skip to content

Commit caef1ac

Browse files
lunamidori5Midori AI Agent
andauthored
[FIX] Harden epoch RPC retry and snapshot resilience (#22)
Co-authored-by: Midori AI Agent <contact-us@midori-ai.xyz>
1 parent b367e8c commit caef1ac

2 files changed

Lines changed: 271 additions & 8 deletions

File tree

metahash/validator/epoch_validator.py

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from __future__ import annotations
33

44
import asyncio
5+
import random
6+
import time
57
from datetime import datetime
68
from typing import Optional, Tuple, Any
79

@@ -41,6 +43,44 @@ def __init__(self, *args, log_interval_blocks: int = 2, **kwargs):
4143
self.current_strategy_out: Any = None # can be dict (subnet bps) or list
4244

4345
# ----------------------- helpers ---------------------------------- #
46+
def _retry_rpc_call(self, fn, *, op_name: str):
47+
"""
48+
Retry RPC calls forever (until shutdown) with exponential backoff.
49+
This keeps the validator alive through transient RPC/node drift issues.
50+
"""
51+
attempt = 0
52+
delay_s = 1.0
53+
cap_s = 30.0
54+
last_error: Exception | None = None
55+
56+
while not self.should_exit:
57+
try:
58+
result = fn()
59+
if attempt > 0:
60+
bt.logging.info(
61+
f"[epoch] RPC recovered for {op_name} after {attempt} retries."
62+
)
63+
return result
64+
except Exception as err:
65+
last_error = err
66+
attempt += 1
67+
68+
if attempt == 1 or attempt % 5 == 0:
69+
bt.logging.warning(
70+
f"[epoch] RPC error during {op_name} (attempt {attempt}): {err}. "
71+
f"Retrying in ~{int(min(cap_s, delay_s))}s."
72+
)
73+
74+
sleep_s = min(cap_s, delay_s)
75+
sleep_s += random.uniform(0.0, min(1.0, sleep_s * 0.25))
76+
time.sleep(sleep_s)
77+
delay_s = min(cap_s, delay_s * 2.0)
78+
79+
raise RuntimeError(
80+
f"[epoch] Stopping RPC retries for {op_name} due to shutdown "
81+
f"(last_error={last_error!r})."
82+
)
83+
4484
def _discover_epoch_length(self) -> int:
4585
try:
4686
override = int(EPOCH_LENGTH_OVERRIDE or 0)
@@ -59,10 +99,19 @@ def _discover_epoch_length(self) -> int:
5999
return length
60100

61101
self._override_active = False
62-
tempo = self.subtensor.tempo(self.config.netuid) or 360
102+
tempo = self._retry_rpc_call(
103+
lambda: self.subtensor.tempo(self.config.netuid),
104+
op_name="tempo",
105+
) or 360
63106
try:
64-
head = self.subtensor.get_current_block()
65-
next_head = self.subtensor.get_next_epoch_start_block(self.config.netuid)
107+
head = self._retry_rpc_call(
108+
self.subtensor.get_current_block,
109+
op_name="get_current_block",
110+
)
111+
next_head = self._retry_rpc_call(
112+
lambda: self.subtensor.get_next_epoch_start_block(self.config.netuid),
113+
op_name="get_next_epoch_start_block",
114+
)
66115
if next_head is None:
67116
raise ValueError("RPC returned None")
68117
derived = next_head - (head - head % tempo)
@@ -77,7 +126,10 @@ def _discover_epoch_length(self) -> int:
77126
return length
78127

79128
def _epoch_snapshot(self) -> Tuple[int, int, int, int, int]:
80-
blk = self.subtensor.get_current_block()
129+
blk = self._retry_rpc_call(
130+
self.subtensor.get_current_block,
131+
op_name="get_current_block",
132+
)
81133
ep_l = self._epoch_len or self._discover_epoch_length()
82134
start = blk - (blk % ep_l)
83135
end = start + ep_l - 1
@@ -97,13 +149,19 @@ def _apply_epoch_state(self, blk: int, start: int, end: int, idx: int, ep_len: i
97149
)
98150

99151
async def _wait_for_next_head(self):
100-
head_block = self.subtensor.get_current_block()
152+
head_block = self._retry_rpc_call(
153+
self.subtensor.get_current_block,
154+
op_name="get_current_block",
155+
)
101156
ep_l = self._epoch_len or self._discover_epoch_length()
102157
target_head = head_block - (head_block % ep_l) + ep_l
103158
label = "override" if self._override_active else "chain"
104159

105160
while not self.should_exit:
106-
blk = self.subtensor.get_current_block()
161+
blk = self._retry_rpc_call(
162+
self.subtensor.get_current_block,
163+
op_name="get_current_block",
164+
)
107165
if blk >= target_head:
108166
return
109167
remain = max(0, target_head - blk)
@@ -157,7 +215,13 @@ def run(self): # noqa: D401
157215

158216
async def _loop():
159217
while not self.should_exit:
160-
blk, start, end, idx, ep_len = self._epoch_snapshot()
218+
try:
219+
blk, start, end, idx, ep_len = self._epoch_snapshot()
220+
except Exception as err:
221+
bt.logging.error(f"[epoch] snapshot failed before wait: {err}")
222+
bt.logging.debug("[epoch] snapshot traceback:", exc_info=True)
223+
await asyncio.sleep(max(1.0, float(BLOCKTIME)))
224+
continue
161225

162226
next_head = start + ep_len
163227
into = blk - start
@@ -207,7 +271,13 @@ async def _loop():
207271

208272
# head snapshot
209273
self._epoch_len = None
210-
blk2, start2, end2, idx2, ep_len2 = self._epoch_snapshot()
274+
try:
275+
blk2, start2, end2, idx2, ep_len2 = self._epoch_snapshot()
276+
except Exception as err:
277+
bt.logging.error(f"[epoch] snapshot failed at epoch boundary: {err}")
278+
bt.logging.debug("[epoch] snapshot traceback:", exc_info=True)
279+
await asyncio.sleep(max(1.0, float(BLOCKTIME)))
280+
continue
211281
self._apply_epoch_state(blk2, start2, end2, idx2, ep_len2)
212282

213283
try:

tests/test_epoch_validator.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
from __future__ import annotations
2+
3+
import importlib
4+
import sys
5+
import types
6+
7+
import pytest
8+
9+
10+
def _install_epoch_validator_test_stubs():
11+
if "bittensor" not in sys.modules:
12+
bittensor_mod = types.ModuleType("bittensor")
13+
14+
class _Logging:
15+
def info(self, *_args, **_kwargs):
16+
return None
17+
18+
def warning(self, *_args, **_kwargs):
19+
return None
20+
21+
def success(self, *_args, **_kwargs):
22+
return None
23+
24+
def error(self, *_args, **_kwargs):
25+
return None
26+
27+
def debug(self, *_args, **_kwargs):
28+
return None
29+
30+
bittensor_mod.logging = _Logging()
31+
bittensor_mod.BLOCKTIME = 12
32+
sys.modules["bittensor"] = bittensor_mod
33+
34+
if "metahash.base.validator" not in sys.modules:
35+
base_validator_mod = types.ModuleType("metahash.base.validator")
36+
37+
class _BaseValidatorNeuron:
38+
pass
39+
40+
base_validator_mod.BaseValidatorNeuron = _BaseValidatorNeuron
41+
sys.modules["metahash.base.validator"] = base_validator_mod
42+
43+
if "metahash.validator.strategy" not in sys.modules:
44+
strategy_mod = types.ModuleType("metahash.validator.strategy")
45+
46+
class _Strategy:
47+
def __init__(self, *args, **kwargs):
48+
return None
49+
50+
def compute_weights_bps(self, **_kwargs):
51+
return {}
52+
53+
strategy_mod.Strategy = _Strategy
54+
sys.modules["metahash.validator.strategy"] = strategy_mod
55+
56+
57+
_install_epoch_validator_test_stubs()
58+
sys.modules.pop("metahash.validator.epoch_validator", None)
59+
epoch_validator_mod = importlib.import_module("metahash.validator.epoch_validator")
60+
61+
62+
def _make_neuron():
63+
neuron = object.__new__(epoch_validator_mod.EpochValidatorNeuron)
64+
neuron.should_exit = False
65+
neuron.config = types.SimpleNamespace(netuid=73, no_epoch=True, force_epoch=False)
66+
neuron._epoch_len = None
67+
neuron._override_active = False
68+
neuron._bootstrapped = False
69+
neuron._running = False
70+
neuron.block = 0
71+
neuron.step = 0
72+
neuron.sync = lambda: None
73+
neuron.concurrent_forward = _noop_forward
74+
neuron.axon = None
75+
return neuron
76+
77+
78+
async def _noop_forward():
79+
return None
80+
81+
82+
def test_discover_epoch_length_retries_tempo_until_recovery(monkeypatch):
83+
neuron = _make_neuron()
84+
monkeypatch.setattr(epoch_validator_mod, "EPOCH_LENGTH_OVERRIDE", 0)
85+
monkeypatch.setattr(epoch_validator_mod.random, "uniform", lambda _a, _b: 0.0)
86+
monkeypatch.setattr(epoch_validator_mod.time, "sleep", lambda _s: None)
87+
88+
class _Subtensor:
89+
def __init__(self):
90+
self.tempo_calls = 0
91+
92+
def tempo(self, _netuid):
93+
self.tempo_calls += 1
94+
if self.tempo_calls < 3:
95+
raise RuntimeError("Header was not found in the database")
96+
return 360
97+
98+
def get_current_block(self):
99+
return 721
100+
101+
def get_next_epoch_start_block(self, _netuid):
102+
return 1081
103+
104+
st = _Subtensor()
105+
neuron.subtensor = st
106+
107+
epoch_len = neuron._discover_epoch_length()
108+
assert epoch_len == 361
109+
assert st.tempo_calls == 3
110+
111+
112+
def test_retry_rpc_call_uses_exponential_backoff_capped_at_30s(monkeypatch):
113+
neuron = _make_neuron()
114+
sleep_calls: list[float] = []
115+
116+
monkeypatch.setattr(epoch_validator_mod.random, "uniform", lambda _a, _b: 0.0)
117+
118+
def _fake_sleep(seconds: float):
119+
sleep_calls.append(seconds)
120+
if len(sleep_calls) >= 6:
121+
neuron.should_exit = True
122+
123+
monkeypatch.setattr(epoch_validator_mod.time, "sleep", _fake_sleep)
124+
125+
def _always_fail():
126+
raise RuntimeError("rpc down")
127+
128+
with pytest.raises(RuntimeError, match="Stopping RPC retries"):
129+
neuron._retry_rpc_call(_always_fail, op_name="tempo")
130+
131+
assert sleep_calls == [1.0, 2.0, 4.0, 8.0, 16.0, 30.0]
132+
133+
134+
def test_discover_epoch_length_falls_back_when_next_head_is_none(monkeypatch):
135+
neuron = _make_neuron()
136+
monkeypatch.setattr(epoch_validator_mod, "EPOCH_LENGTH_OVERRIDE", 0)
137+
monkeypatch.setattr(epoch_validator_mod.random, "uniform", lambda _a, _b: 0.0)
138+
monkeypatch.setattr(epoch_validator_mod.time, "sleep", lambda _s: None)
139+
140+
class _Subtensor:
141+
def tempo(self, _netuid):
142+
return 360
143+
144+
def get_current_block(self):
145+
return 720
146+
147+
def get_next_epoch_start_block(self, _netuid):
148+
return None
149+
150+
neuron.subtensor = _Subtensor()
151+
epoch_len = neuron._discover_epoch_length()
152+
assert epoch_len == 361
153+
154+
155+
def test_discover_epoch_length_uses_override_without_rpc(monkeypatch):
156+
neuron = _make_neuron()
157+
monkeypatch.setattr(epoch_validator_mod, "EPOCH_LENGTH_OVERRIDE", 7)
158+
159+
class _Subtensor:
160+
def tempo(self, _netuid): # pragma: no cover - should not be called
161+
raise AssertionError("tempo() should not be called when override is set")
162+
163+
def get_current_block(self): # pragma: no cover - should not be called
164+
raise AssertionError("get_current_block() should not be called when override is set")
165+
166+
def get_next_epoch_start_block(self, _netuid): # pragma: no cover - should not be called
167+
raise AssertionError("get_next_epoch_start_block() should not be called when override is set")
168+
169+
neuron.subtensor = _Subtensor()
170+
assert neuron._discover_epoch_length() == 7
171+
172+
173+
def test_run_continues_after_snapshot_error(monkeypatch):
174+
neuron = _make_neuron()
175+
calls = {"snap": 0, "sleep": 0}
176+
177+
def _snapshot():
178+
calls["snap"] += 1
179+
neuron.should_exit = True
180+
raise RuntimeError("boom")
181+
182+
async def _fake_async_sleep(_seconds: float):
183+
calls["sleep"] += 1
184+
return None
185+
186+
monkeypatch.setattr(epoch_validator_mod.asyncio, "sleep", _fake_async_sleep)
187+
neuron._epoch_snapshot = _snapshot
188+
189+
neuron.run()
190+
191+
assert calls["snap"] == 1
192+
assert calls["sleep"] == 1
193+
assert neuron._running is False

0 commit comments

Comments
 (0)