Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

def normalize_weights(weights_dict: dict[int, float]) -> dict[int, float]:
"""
Normalize weights dictionary so that the sum equals 1.0.

Args:
weights_dict: Dictionary mapping UID to weight value
Returns:
Normalized weights dictionary with sum = 1.0
"""

total = sum(weights_dict.values())
if total <= 0:
return {}
return {uid: w / total for uid, w in weights_dict.items()}
66 changes: 47 additions & 19 deletions validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from cycle import get_miner_payloads
from model import multi_salience as sal_fn
from ledger import DataLog
from helpers import normalize_weights

LOG_DIR = os.path.join(os.getcwd(), "logs")
os.makedirs(LOG_DIR, exist_ok=True)
Expand Down Expand Up @@ -54,10 +55,10 @@
SAVE_INTERVAL = 480


def save_weights(weights_tensor: torch.Tensor, uids: list[int], block: int):
def save_weights(weights_tensor: torch.Tensor, hotkeys: list[str], block: int):
weights_data = {
"weights": weights_tensor,
"uids": uids,
"hotkeys": hotkeys,
"block": block,
}
with open(WEIGHTS_PATH, "wb") as f:
Expand Down Expand Up @@ -368,14 +369,12 @@ def calc_worker(dlog, block_snapshot, metagraph, cli_args):
if not final_weights:
weights_logger.warning("No weights to set after processing.")
return

total_weight = sum(final_weights.values())
if total_weight <= 0:
weights_logger.warning("Total calculated weight is zero or negative, skipping set.")
return
normalized_weights = normalize_weights(final_weights)

normalized_weights = {uid: w / total_weight for uid, w in final_weights.items()}

w = torch.tensor([normalized_weights.get(uid, 0.0) for uid in uids], dtype=torch.float32)

# Check for uniform weights (bug indicator)
Expand All @@ -396,7 +395,7 @@ def calc_worker(dlog, block_snapshot, metagraph, cli_args):
weights_logger.info(f"Normalized weights for block {block_snapshot}: {json.dumps(weights_to_log)}")
weights_logger.info(f"Final tensor sum: {final_w.sum().item()}")

save_weights(final_w, uids, block_snapshot)
save_weights(final_w, metagraph.hotkeys, block_snapshot)
weights_logger.info(f"Weights calculated and saved at block {block_snapshot} (max={final_w.max():.4f})")

weight_thread = threading.Thread(
Expand All @@ -412,22 +411,51 @@ def calc_worker(dlog, block_snapshot, metagraph, cli_args):
logging.info(f"No saved weights found at block {current_block}, skipping weight setting.")
else:
calc_block = weights_data.get("block", "unknown")
final_w = weights_data["weights"]
saved_uids = weights_data["uids"]
saved_weights = weights_data["weights"]
saved_hotkeys = weights_data["hotkeys"]
current_hotkeys = mg.hotkeys

weights_logger.info(f"Setting weights from saved array (calculated at block {calc_block})")

if list(saved_uids) != mg.uids.tolist():
weights_logger.warning("UID mismatch between saved weights and current metagraph, skipping.")
else:
sub.set_weights(
netuid=args.netuid,
wallet=wallet,
uids=mg.uids,
weights=final_w,
wait_for_inclusion=False,

if saved_hotkeys != current_hotkeys:
# UIDs changed, need to remap weights
weights_logger.warning("UID mismatch between saved weights and current metagraph.")
saved_weight_map = {
hotkey: saved_weights[idx].item()
for idx, hotkey in enumerate(saved_hotkeys)
}

remapped_weights = {
hotkey: saved_weight_map.get(hotkey, 0.0)
for hotkey in current_hotkeys
}

normalized_weights = normalize_weights(remapped_weights)

if not normalized_weights:
weights_logger.error("Normalization failed: all weights zero after remapping")
continue

# Reconstruct tensor aligned with current metagraph
final_weights = torch.tensor(
normalized_weights,
dtype=torch.float32
)
weights_logger.info(f"Weights set at block {current_block} (from block {calc_block}, max={final_w.max():.4f})")

weights_logger.info(f"Remapped and renormalized weights (Σ={final_weights.sum():.6f}, max={final_weights.max():.6f})")

else:
final_weights = saved_weights
weights_logger.info("UIDs match exactly, using saved weights as-is.")

sub.set_weights(
netuid=args.netuid,
wallet=wallet,
uids=mg.uids,
weights=final_weights,
wait_for_inclusion=False,
)
weights_logger.info(f"Weights set at block {current_block} (from block {calc_block}, max={final_weights.max():.4f})")

except KeyboardInterrupt:
stop_event.set()
Expand Down