Dynamic algo fix (#167)
* little bit of typehinting

* dynamic algo and config bug fix and type hint in comm_utils
tremblerz authored Mar 6, 2025
1 parent ded2ba5 commit 3e221d0
Showing 4 changed files with 19 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/algos/base_class.py
@@ -527,8 +527,8 @@ def set_data_parameters(self, config: ConfigType) -> None:
             self.classes_of_interest = classes
             self.train_indices = train_indices
             self.train_dset = train_dset
-            self.dloader = DataLoader(train_dset, batch_size=len(train_dset), shuffle=False)
-            self._test_loader = DataLoader(test_dset, batch_size=len(test_dset), shuffle=False)
+            self.dloader: DataLoader[Any] = DataLoader(train_dset, batch_size=len(train_dset), shuffle=False)
+            self._test_loader: DataLoader[Any] = DataLoader(test_dset, batch_size=len(test_dset), shuffle=False)
             print("Using GIA data setup")
             print(self.labels)
         else:
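As a reference for the annotation pattern in the hunk above, here is a minimal, runnable sketch (not taken from the repo): torch.utils.data.DataLoader is generic, so a DataLoader[Any] annotation keeps strict type checkers quiet when the element type is not pinned down, and batch_size=len(dset) yields the whole dataset as a single batch, as in the GIA setup touched by this change.

from typing import Any

import torch
from torch.utils.data import DataLoader, TensorDataset

# Tiny stand-in dataset; shapes are arbitrary.
dset = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
# Full-batch loader, annotated the same way as in the diff above.
dloader: DataLoader[Any] = DataLoader(dset, batch_size=len(dset), shuffle=False)
xb, yb = next(iter(dloader))
print(xb.shape, yb.shape)  # torch.Size([8, 3]) torch.Size([8])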
23 changes: 10 additions & 13 deletions src/algos/fl_dynamic.py
@@ -74,7 +74,7 @@ def get_neighbor_model_wts(self) -> List[Dict[str, TorchModelType]]:
         from all the neighbors because that's how
         most dynamic topologies work.
         """
-        neighbor_models = self.comm_utils.all_gather(ignore_super_node=True)
+        neighbor_models: List[Dict[str, TorchModelType]] = self.comm_utils.all_gather(ignore_super_node=True)
         return neighbor_models

     def get_neighbor_similarity(self, others_wts: List[Dict[str, TorchModelType]]) -> List[float]:
@@ -95,11 +95,12 @@ def get_neighbor_similarity(self, others_wts: List[Dict[str, TorchModelType]]) -> List[float]:
             raise ValueError("Similarity metric {} not implemented".format(self.similarity))
         return similarity_wts

-    def sample_neighbours(self, k: int) -> List[int]:
+    def sample_neighbours(self, k: int, mode: str|None = None) -> List[int]:
         """
         We perform neighbor sampling after
         we have the similarity weights of all the neighbors.
         """
+        assert mode is None or mode == "pull", "Only pull mode is supported for dynamic topology"
         if self.sampling == "closest":
             return select_smallest_k(self.similarity_wts, k)
         else:
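For context on the sample_neighbours change in the hunk above, here is a small, self-contained sketch of the "closest" sampling path. The cosine-distance metric and this select_smallest_k body are assumptions modeled on the names in the diff, not the repository's actual implementations.

from typing import Dict, List

import torch


def cosine_distance(a: Dict[str, torch.Tensor], b: Dict[str, torch.Tensor]) -> float:
    # Flatten both state dicts and compare them as single vectors.
    va = torch.cat([t.flatten().float() for t in a.values()])
    vb = torch.cat([t.flatten().float() for t in b.values()])
    return 1.0 - torch.nn.functional.cosine_similarity(va, vb, dim=0).item()


def select_smallest_k(similarity_wts: List[float], k: int) -> List[int]:
    # Indices of the k neighbors whose weights are closest to our own.
    return sorted(range(len(similarity_wts)), key=lambda i: similarity_wts[i])[:k]


own = {"layer.weight": torch.ones(4)}
neighbors = [{"layer.weight": torch.ones(4) * s} for s in (1.0, 0.5, -1.0)]
similarity_wts = [cosine_distance(own, nb) for nb in neighbors]
print(select_smallest_k(similarity_wts, k=2))  # [0, 1]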
@@ -154,7 +155,7 @@ def __init__(
         self.topology = DynamicTopology(config, comm_utils, self)
         self.topology.initialize()

-    def get_representation(self, **kwargs: Any) -> TorchModelType:
+    def get_representation(self, **kwargs: Any) -> Dict[str, int|Dict[str, Any]]:
         """
         Returns the model weights as representation.
         """
@@ -172,24 +173,20 @@ def run_protocol(self) -> None:
         epochs_per_round = self.config.get("epochs_per_round", 1)

         for it in range(start_round, total_rounds):
+            self.round_init()

             # Train locally and send the representation to the server
             stats["train_loss"], stats["train_acc"], stats["train_time"] = self.local_train(
                 it, epochs_per_round
             )
+            self.local_round_done()

-            # Collect the representations from all other nodes from the server
-            neighbors = self.topology.recv_and_agg(self.num_collaborators)
-            # TODO: Log the neighbors
-            stats["neighbors"] = neighbors

-            stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()
+            collabs = self.topology.recv_and_agg(self.num_collaborators)

-            # evaluate the model on the test data
-            # Inside FedStaticNode.run_protocol()
-            stats["test_loss"], stats["test_acc"] = self.local_test()
-            stats.update(self.get_memory_metrics())
-            self.log_metrics(stats=stats, iteration=it)
+            self.stats["neighbors"] = collabs
+            self.local_test()
+            self.round_finalize()


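The run_protocol rewrite above replaces hand-rolled stats bookkeeping and logging with the node's round-lifecycle hooks. The stub below is purely illustrative (placeholder bodies, not repository code); it only shows the call order the new loop relies on.

from typing import Any, Dict, List, Tuple


class RoundLifecycleStub:
    """Placeholder node exposing the hooks used by the rewritten loop."""

    def __init__(self) -> None:
        self.stats: Dict[str, Any] = {}

    def round_init(self) -> None:
        pass  # reset per-round state, start timers

    def local_train(self, it: int, epochs: int) -> Tuple[float, float, float]:
        return 0.0, 0.0, 0.0  # train loss, accuracy, wall-clock time

    def local_round_done(self) -> None:
        pass  # publish this round's representation to peers

    def recv_and_agg(self, k: int) -> List[int]:
        return []  # ids of collaborators whose models were aggregated

    def local_test(self) -> None:
        pass  # evaluate and record test metrics internally

    def round_finalize(self) -> None:
        pass  # log collected stats and advance the round counter


node = RoundLifecycleStub()
for it in range(2):
    node.round_init()
    node.stats["train_loss"], node.stats["train_acc"], node.stats["train_time"] = node.local_train(it, epochs=1)
    node.local_round_done()
    node.stats["neighbors"] = node.recv_and_agg(k=2)
    node.local_test()
    node.round_finalize()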
13 changes: 6 additions & 7 deletions src/configs/sys_config.py
@@ -159,10 +159,10 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
 CIFAR10_DSET = "cifar10"
 CIAR10_DPATH = "./datasets/imgs/cifar10/"

-NUM_COLLABORATORS = 1
-DUMP_DIR = "/tmp/"
+NUM_COLLABORATORS = 3
+DUMP_DIR = "/tmp/new_sonar/"

-num_users = 3
+num_users = 9
 mpi_system_config: ConfigType = {
     "exp_id": "",
     "comm": {"type": "MPI"},
@@ -318,8 +318,6 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
     "exp_keys": [],
 }

-num_users = 4
-
 dropout_dict: Any = {
     "distribution_dict": { # leave dict empty to disable dropout
         "method": "uniform", # "uniform", "normal"
@@ -346,9 +344,10 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
     "dpath": CIAR10_DPATH,
     "seed": 2,
     "device_ids": get_device_ids(num_users, gpu_ids),
+    "assign_based_on_host": True,
     # "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore
-    "algos": get_algo_configs(num_users=num_users, algo_configs=[fed_dynamic_loss]), # type: ignore
-    "samples_per_user": 10000 // num_users, # distributed equally
+    "algos": get_algo_configs(num_users=num_users, algo_configs=[fed_dynamic_weights]), # type: ignore
+    "samples_per_user": 500, # distributed equally
     "train_label_distribution": "non_iid",
     "alpha_data": 0.1,
     "test_label_distribution": "iid",
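The config above moves to "train_label_distribution": "non_iid" with "alpha_data": 0.1 for 9 users at 500 samples each. The sketch below shows the standard Dirichlet label-skew construction that an alpha parameter of this kind usually controls; it is an illustration under that assumption, not the repository's partitioner.

import numpy as np


def dirichlet_label_split(labels: np.ndarray, num_users: int, alpha: float, seed: int = 2) -> list:
    # Smaller alpha -> each user's share of every class is more skewed.
    rng = np.random.default_rng(seed)
    user_indices: list = [[] for _ in range(num_users)]
    for c in np.unique(labels):
        idx = rng.permutation(np.where(labels == c)[0])
        props = rng.dirichlet(alpha * np.ones(num_users))  # per-user share of class c
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for u, part in enumerate(np.split(idx, cuts)):
            user_indices[u].extend(part.tolist())
    return user_indices


labels = np.random.default_rng(0).integers(0, 10, size=9 * 500)
parts = dirichlet_label_split(labels, num_users=9, alpha=0.1)
print([len(p) for p in parts])  # highly uneven label mixes per user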
2 changes: 1 addition & 1 deletion src/utils/communication/comm_utils.py
@@ -72,7 +72,7 @@ def receive(self, node_ids: List[int]) -> Any:
     def broadcast(self, data: Any, tag: int = 0):
         self.comm.broadcast(data)

-    def all_gather(self, tag: int = 0, ignore_super_node: bool = False):
+    def all_gather(self, tag: int = 0, ignore_super_node: bool = False) -> List[Dict[str, Any]]:
         return self.comm.all_gather(ignore_super_node=ignore_super_node)

     def send_quorum(self):
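For the return annotation added above, a toy stand-in (not the MPI-backed implementation) of the contract implied by the new signature: one dict per node, with the coordinating super node optionally dropped; the super node sitting at rank 0 is an assumption here.

from typing import Any, Dict, List


def all_gather_stub(per_node_state: List[Dict[str, Any]], ignore_super_node: bool = False) -> List[Dict[str, Any]]:
    # Every node contributes one dict; optionally drop the super node (assumed rank 0).
    return list(per_node_state[1:]) if ignore_super_node else list(per_node_state)


states = [{"rank": r, "model": {}} for r in range(4)]
print(len(all_gather_stub(states, ignore_super_node=True)))  # 3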
