diff --git a/helpers.py b/helpers.py new file mode 100644 index 0000000..2233115 --- /dev/null +++ b/helpers.py @@ -0,0 +1,15 @@ + +def normalize_weights(weights_dict: dict[int, float]) -> dict[int, float]: + """ + Normalize weights dictionary so that the sum equals 1.0. + + Args: + weights_dict: Dictionary mapping UID to weight value + Returns: + Normalized weights dictionary with sum = 1.0 + """ + + total = sum(weights_dict.values()) + if total <= 0: + return {} + return {uid: w / total for uid, w in weights_dict.items()} \ No newline at end of file diff --git a/validator.py b/validator.py index 530648d..35adaec 100644 --- a/validator.py +++ b/validator.py @@ -24,6 +24,7 @@ from cycle import get_miner_payloads from model import multi_salience as sal_fn from ledger import DataLog +from helpers import normalize_weights LOG_DIR = os.path.join(os.getcwd(), "logs") os.makedirs(LOG_DIR, exist_ok=True) @@ -54,10 +55,10 @@ SAVE_INTERVAL = 480 -def save_weights(weights_tensor: torch.Tensor, uids: list[int], block: int): +def save_weights(weights_tensor: torch.Tensor, hotkeys: list[str], block: int): weights_data = { "weights": weights_tensor, - "uids": uids, + "hotkeys": hotkeys, "block": block, } with open(WEIGHTS_PATH, "wb") as f: @@ -368,14 +369,12 @@ def calc_worker(dlog, block_snapshot, metagraph, cli_args): if not final_weights: weights_logger.warning("No weights to set after processing.") return - total_weight = sum(final_weights.values()) if total_weight <= 0: weights_logger.warning("Total calculated weight is zero or negative, skipping set.") return + normalized_weights = normalize_weights(final_weights) - normalized_weights = {uid: w / total_weight for uid, w in final_weights.items()} - w = torch.tensor([normalized_weights.get(uid, 0.0) for uid in uids], dtype=torch.float32) # Check for uniform weights (bug indicator) @@ -396,7 +395,7 @@ def calc_worker(dlog, block_snapshot, metagraph, cli_args): weights_logger.info(f"Normalized weights for block {block_snapshot}: {json.dumps(weights_to_log)}") weights_logger.info(f"Final tensor sum: {final_w.sum().item()}") - save_weights(final_w, uids, block_snapshot) + save_weights(final_w, metagraph.hotkeys, block_snapshot) weights_logger.info(f"Weights calculated and saved at block {block_snapshot} (max={final_w.max():.4f})") weight_thread = threading.Thread( @@ -412,22 +411,51 @@ def calc_worker(dlog, block_snapshot, metagraph, cli_args): logging.info(f"No saved weights found at block {current_block}, skipping weight setting.") else: calc_block = weights_data.get("block", "unknown") - final_w = weights_data["weights"] - saved_uids = weights_data["uids"] + saved_weights = weights_data["weights"] + saved_hotkeys = weights_data["hotkeys"] + current_hotkeys = mg.hotkeys weights_logger.info(f"Setting weights from saved array (calculated at block {calc_block})") - - if list(saved_uids) != mg.uids.tolist(): - weights_logger.warning("UID mismatch between saved weights and current metagraph, skipping.") - else: - sub.set_weights( - netuid=args.netuid, - wallet=wallet, - uids=mg.uids, - weights=final_w, - wait_for_inclusion=False, + + if saved_hotkeys != current_hotkeys: + # UIDs changed, need to remap weights + weights_logger.warning("UID mismatch between saved weights and current metagraph.") + saved_weight_map = { + hotkey: saved_weights[idx].item() + for idx, hotkey in enumerate(saved_hotkeys) + } + + remapped_weights = { + hotkey: saved_weight_map.get(hotkey, 0.0) + for hotkey in current_hotkeys + } + + normalized_weights = normalize_weights(remapped_weights) + + if not normalized_weights: + weights_logger.error("Normalization failed: all weights zero after remapping") + continue + + # Reconstruct tensor aligned with current metagraph + final_weights = torch.tensor( + normalized_weights, + dtype=torch.float32 ) - weights_logger.info(f"Weights set at block {current_block} (from block {calc_block}, max={final_w.max():.4f})") + + weights_logger.info(f"Remapped and renormalized weights (Σ={final_weights.sum():.6f}, max={final_weights.max():.6f})") + + else: + final_weights = saved_weights + weights_logger.info("UIDs match exactly, using saved weights as-is.") + + sub.set_weights( + netuid=args.netuid, + wallet=wallet, + uids=mg.uids, + weights=final_weights, + wait_for_inclusion=False, + ) + weights_logger.info(f"Weights set at block {current_block} (from block {calc_block}, max={final_weights.max():.4f})") except KeyboardInterrupt: stop_event.set()