Dynamic topology and a few other minor fixes (#166)
* push experiment generating code

* minor bug fix

* fix async experiment bug

* uncomment static algos

* enable two dynamic algorithms with the new APIs

* add sl primitives

* merge conflict prevention

* remove configs
tremblerz authored Mar 5, 2025
1 parent 99c8b51 commit 20b8840
Showing 25 changed files with 621 additions and 3,352 deletions.
41 changes: 22 additions & 19 deletions src/algos/base_class.py
@@ -36,7 +36,7 @@
     get_dset_balanced_communities,
     get_dset_communities,
 )
-from utils.types import ConfigType
+from utils.types import ConfigType, TorchModelType
 from utils.dropout_utils import NodeDropout
 
 import torchvision.transforms as T  # type: ignore
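
The recurring type change in this file swaps Dict[str, ConfigType] for plain ConfigType in method signatures and imports a new TorchModelType for model weights. The alias definitions live in utils/types.py and are not part of this diff; the following is a minimal sketch of plausible definitions, offered as an assumption rather than the repository's actual code.

# Plausible contents of utils/types.py -- an assumption, not shown in this diff.
from typing import Any, Dict, OrderedDict

from torch import Tensor

# A flat experiment configuration, e.g. {"num_users": 10, "dset": "cifar10"}.
ConfigType = Dict[str, Any]

# A model state_dict: parameter name -> weight tensor.
TorchModelType = OrderedDict[str, Tensor]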
@@ -95,6 +95,7 @@ def __init__(
         self, config: Dict[str, Any], comm_utils: CommunicationManager
     ) -> None:
         self.set_constants()
+        self.config = config
         self.comm_utils = comm_utils
         self.node_id = self.comm_utils.get_rank()
         self.comm_utils.register_node(self)
Expand Down Expand Up @@ -135,7 +136,7 @@ def set_constants(self) -> None:
self.round = 0
self.EMPTY_MODEL_TAG = "EMPTY_MODEL"

def setup_logging(self, config: Dict[str, ConfigType]) -> None:
def setup_logging(self, config: ConfigType) -> None:
"""
Sets up logging for the node by creating necessary directories and initializing logging utilities.
@@ -162,17 +163,11 @@ def setup_logging(self, config: Dict[str, ConfigType]) -> None:
             print(f"Exiting to prevent accidental overwrite{reset_code}")
             sys.exit(1)
 
-        # TODO: Check if the plot directory should be unique to each node
-        try:
-            self.plot_utils = PlotUtils(config)
-        except FileExistsError:
-            print(f"Plot directory for the node {self.node_id} already exists")
-
         self.log_utils = LogUtils(config)
         if self.node_id == 0:
             self.log_utils.log_console("Config: {}".format(config))
 
-    def setup_cuda(self, config: Dict[str, ConfigType]) -> None:
+    def setup_cuda(self, config: ConfigType) -> None:
         """add docstring here"""
         # Need a mapping from rank to device id
         if (config.get("assign_based_on_host", False)):
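
The setup_cuda hunk is truncated before the device-assignment logic. As orientation only, here is a minimal sketch of the kind of rank-to-device mapping the comment above describes, assuming round-robin assignment; the repository's actual logic (including the assign_based_on_host branch) is not shown in this diff.

# Hypothetical rank -> device mapping; not the repository's implementation.
import torch

def assign_device(rank: int) -> torch.device:
    # Spread nodes across the visible GPUs round-robin; fall back to CPU.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    return torch.device("cpu")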
@@ -232,7 +227,7 @@ def set_model_parameters(self, config: Dict[str, Any]) -> None:
         else:
             self.loss_fn = torch.nn.CrossEntropyLoss()
 
-    def set_shared_exp_parameters(self, config: Dict[str, ConfigType]) -> None:
+    def set_shared_exp_parameters(self, config: ConfigType) -> None:
         self.num_collaborators: int = config["num_collaborators"]  # type: ignore
         if self.node_id != 0:
             community_type, number_of_communities = config.get(
@@ -244,13 +239,13 @@ def set_shared_exp_parameters(self, config: Dict[str, ConfigType]) -> None:
                 else len(set(config["dset"].values()))
             )
             if community_type is not None and community_type == "dataset":
-                self.communities = get_dset_communities(config["num_users"], num_dset)
+                self.communities = get_dset_communities(config["num_users"], num_dset)  # type: ignore
             elif community_type is None or number_of_communities == 1:
-                all_users = list(range(1, config["num_users"] + 1))
+                all_users = list(range(1, config["num_users"] + 1))  # type: ignore
                 self.communities = {user: all_users for user in all_users}
             elif community_type == "random":
                 self.communities = get_random_communities(
-                    config["num_users"], number_of_communities
+                    config["num_users"], number_of_communities  # type: ignore
                 )
             elif community_type == "balanced":
                 num_dset = (
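
For orientation: self.communities maps each user id to the list of user ids it collaborates with, so the "random" strategy above partitions users into groups. A minimal sketch under the assumption that get_random_communities shuffles users into roughly equal groups (the repository's helper may differ):

# Hypothetical sketch of a random community split; get_random_communities
# in the repository may differ in details.
import random
from typing import Dict, List

def random_communities(num_users: int, num_communities: int) -> Dict[int, List[int]]:
    users = list(range(1, num_users + 1))  # users are 1-indexed, as in the diff
    random.shuffle(users)
    # Deal users round-robin into num_communities groups.
    groups = [users[i::num_communities] for i in range(num_communities)]
    # Map every user to the full membership list of its own group.
    return {user: group for group in groups for user in group}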
@@ -271,21 +266,29 @@ def set_shared_exp_parameters(self, config: Dict[str, ConfigType]) -> None:
     def local_round_done(self) -> None:
         self.round += 1
 
-    def get_model_weights(self) -> Dict[str, Tensor]:
+    def get_model_weights(self, chop_model: bool = False) -> Dict[str, int | Dict[str, Any]]:
         """
         Share the model weights
+        params:
+        @chop_model: bool, if True, only the client part of the model is sent. Used only by Split Learning.
         """
-        message = {"sender": self.node_id, "round": self.round, "model": self.model.state_dict()}
+        if chop_model:
+            model, _ = self.model_utils.get_split_model(self.model, self.config["split_layer"])
+            model = model.state_dict()
+        else:
+            model = self.model.state_dict()
+        message: Dict[str, int | Dict[str, Any]] = {"sender": self.node_id, "round": self.round, "model": model}
 
         if "gia" in self.config and hasattr(self, 'images') and hasattr(self, 'labels'):
             # also stream image and labels
             message["images"] = self.images
             message["labels"] = self.labels
 
         # Move to CPU before sending
-        for key in message["model"].keys():
-            message["model"][key] = message["model"][key].to("cpu")
+        if isinstance(message["model"], dict):
+            for key in message["model"].keys():
+                message["model"][key] = message["model"][key].to("cpu")
 
         return message
 
     def get_local_rounds(self) -> int:
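
A usage sketch for the new chop_model flag: in Split Learning, a client transmits only the layers up to the configured split point. The node object below is illustrative, assuming the API exactly as shown in the diff.

# Illustrative usage; `node` stands for any object exposing
# get_model_weights as defined in the diff, with config["split_layer"] set.
message = node.get_model_weights(chop_model=True)

# message["model"] now holds only the client-side layers, already moved
# to CPU for transport by the isinstance branch above.
client_state = message["model"]
assert all(t.device.type == "cpu" for t in client_state.values())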
@@ -393,7 +396,7 @@ def get_and_set_working(self, round: Optional[int] = None) -> bool:
         return is_working
 
     def set_model_weights(
-        self, model_wts: OrderedDict[str, Tensor], keys_to_ignore: List[str] = []
+        self, model_wts: TorchModelType, keys_to_ignore: List[str] = []
     ) -> None:
         """
         Set the model weights
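
The set_model_weights body is outside this hunk. For context, here is a sketch of what a keys_to_ignore filter typically does when loading received weights; this is an assumption about the behavior, not the repository's implementation.

# Hypothetical body for illustration only.
from typing import List

class Node:  # stand-in for the class in base_class.py
    def set_model_weights(self, model_wts, keys_to_ignore: List[str] = []) -> None:
        # Drop ignored keys (e.g. a personalized head or BatchNorm statistics).
        filtered = {k: v for k, v in model_wts.items() if k not in keys_to_ignore}
        # strict=False because the ignored keys leave gaps in the state dict.
        self.model.load_state_dict(filtered, strict=False)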
190 changes: 0 additions & 190 deletions src/algos/fl_assigned.py

This file was deleted.
