diff --git a/.gitignore b/.gitignore index 251a39d96..455b15b87 100755 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ third_party/libzmq* CIFAR10/output CIFAR10/*.txt ml/proto +ml/*.txt diff --git a/BUILD b/BUILD index 6728ab42f..0c30b604b 100755 --- a/BUILD +++ b/BUILD @@ -27,12 +27,13 @@ cc_library( cc_library( name = "src_files", srcs = glob(["src/**/*.cpp"]), - hdrs = glob(["src/**/*.h"]), + hdrs = glob(["src/**/*.h"], allow_empty = True), visibility = ["//visibility:public"], deps = [ ":include_files", "@third_party//:libzmq", "@third_party//:cppzmq", + "@boost//:program_options", ], defines = local_defines, ) @@ -62,9 +63,7 @@ cc_binary( name = "provider", srcs = ["main.cpp"], defines = ["PROVIDER=1"] + local_defines, - deps = [ - ":src_files", - ], + deps = [":src_files"], ) cc_binary( diff --git a/MODULE.bazel b/MODULE.bazel index bb719b7f3..032597a76 100755 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -13,5 +13,22 @@ local_repository( path = "third_party", ) -bazel_dep(name = "protobuf", version = "29.0-rc1") +bazel_dep(name = "protobuf", version = "29.0") bazel_dep(name = "googletest", version = "1.15.2") + +bazel_dep(name = "rules_boost", repo_name = "com_github_nelhage_rules_boost") +archive_override( + module_name = "rules_boost", + urls = ["https://github.com/nelhage/rules_boost/archive/refs/heads/master.tar.gz"], + strip_prefix = "rules_boost-master", + # It is recommended to edit the above URL and the below sha256 to point to a specific version of this repository. + # integrity = "sha256-...", +) +non_module_boost_repositories = use_extension("@com_github_nelhage_rules_boost//:boost/repositories.bzl", "non_module_dependencies") +use_repo( + non_module_boost_repositories, + "boost", +) +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "rules_proto", version = "7.0.2") +bazel_dep(name = "rules_cc", version = "0.0.16") diff --git a/README.md b/README.md index 90c7395ea..48ea439ee 100755 --- a/README.md +++ b/README.md @@ -7,6 +7,11 @@ Follow this to install bazel - https://bazel.build/install We can install ZeroMQ for cpp as follows. Run the following starting from the `CloudMesh/third_party/` folder. +#### Tailscale + +Tailscale is used to setup a P2P vpn network to connect the machines. Instructions to install can be found here: +https://tailscale.com/kb/1347/installation + #### Mac Dependencies May need to install the following when troubleshooting issues @@ -45,9 +50,15 @@ We install ZeroMQ for Python using `pip install pyzmq`. - BUILD file - Contains the build instructions for the targets. - MODULE.bazel file - Contains the module name and the dependencies. -## Compilation +## Compilation (Non Local - Multiple Machines) + +To compile **BOOTSTRAP**, **PROVIDER** and **REQUESTER**, ensure that the `BOOTSTRAP_HOST` env variable is set, which can be done using the following command: +``` +export BOOTSTRAP_HOST=___ +``` +The `BOOTSTRAP_PORT` env variable can also be set (unset is default to 8080). -To compile **BOOTSTRAP**, **PROVIDER** and **REQUESTER**, run the following commands: +Then run the following commands to compile the source code: ### MacOS ``` bazel build //... 
--experimental_google_legacy_api --config=macos @@ -76,23 +87,40 @@ To execute, run the following commands: ``` ./bazel-bin/bootstrap ``` -(8080 is reserved for bootstrap port so peers know where to connect) +(Uses port 8080) ### Provider ``` -./bazel-bin/provider [8080] -``` -(8080 is the default port, optional parameter) +./bazel-bin/provider -p +``` +Example: +``` +./bazel-bin/provider -p 8081 +``` +Program arguments can be viewed with +``` +./bazel-bin/provider -h +``` ### Requester ``` -./bazel-bin/requester [8080 [r | c]] +./bazel-bin/requester -w <# workers> -e <# epochs> -p -m +``` +Request to compute task example +``` +./bazel-bin/requester -w 3 -e 10 -p 8082 -m c ``` -`8080` is the default port, optional parameter\ -`r` is an optional parameter to request to receive the result of the computation (use same port as original request execution) -`c` is an optional parameter to request to provide the computation +Request to receive results example +``` +./bazel-bin/requester -p 8082 -m r +``` +The program execution for receiving results must use the same port as the execution of the compute request. +Program arguments can be viewed with +``` +./bazel-bin/requester -h +``` ### Resources diff --git a/include/Networking/client.h b/include/Networking/client.h index ff38a9531..d3236fd29 100644 --- a/include/Networking/client.h +++ b/include/Networking/client.h @@ -17,13 +17,13 @@ class Client { Sends all bytes from a buffer with resilience to partial sends with the option of num_retries (=-1 for blocking). */ - ssize_t send_all_bytes(const char* buffer, size_t length, int flags, + ssize_t sendAllBytes(const char* buffer, size_t length, int flags, int num_retries = 0); + void closeSocket(); public: Client(); ~Client(); - int setupConn(const char* HOST, const char* PORT, const char* CONNTYPE); int setupConn(const IpAddress& ipAddress, const char* CONNTYPE); int sendMsg(const std::string& data, int num_retries = 0); }; diff --git a/include/Networking/server.h b/include/Networking/server.h index 7a05e611e..dd51499e9 100644 --- a/include/Networking/server.h +++ b/include/Networking/server.h @@ -10,26 +10,28 @@ #include class Server { - const char* HOST; - const char* PORT; + IpAddress publicIp; const char* CONNTYPE; - IpAddress publicIP; - int server; // stores the current running server id - int activeConn; // stores the current active connection id + int server = -1; // stores the current running server id + int activeConn = -1; // stores the current active connection id /* Receives all bytes into a buffer with resilience to partial data sends. 
*/ - ssize_t recv_all_bytes(char* buffer, size_t length, int flags, int num_retries = 0); + ssize_t recvAllBytes(char* buffer, size_t length, int flags, int num_retries = 0); + void closeSocket(); public: - Server(const char* host, const char* port, const char* type); + Server(const IpAddress& addr, const char* type); ~Server(); void setupServer(); // prepare server for connection bool acceptConn(); // blocking + bool acceptConn(IpAddress& clientAddr); // blocking int receiveFromConn(std::string& msg, int num_retries = 0); // process the active conn void replyToConn(std::string message); // reply to the active conn - void getFileFTP(std::string filename); // retrieve remote file + void + getFileIntoDirFTP(std::string filename, + std::string directory); // retrieve remote file into directory void closeConn(); // close the active conn }; diff --git a/include/Peers/bootstrap_node.h b/include/Peers/bootstrap_node.h index 938765eeb..8290b36c9 100644 --- a/include/Peers/bootstrap_node.h +++ b/include/Peers/bootstrap_node.h @@ -8,13 +8,12 @@ class BootstrapNode : public Peer { public: - BootstrapNode(const char*, std::string); + BootstrapNode(std::string); // ----------------- FIX LATER ----------------- BootstrapNode() {} ~BootstrapNode(); - static const char* getServerIpAddress(); - static const char* getServerPort(); + static IpAddress getServerIpAddr(); void registerPeer(const std::string& peerUuid, const IpAddress& peerIpAddr); AddressTable discoverPeers(const std::string& peerUuid, diff --git a/include/Peers/peer.h b/include/Peers/peer.h index 8648cd63a..1c9b7c568 100644 --- a/include/Peers/peer.h +++ b/include/Peers/peer.h @@ -9,8 +9,7 @@ class Peer { protected: - const char* host; - const char* port; + IpAddress publicIp; std::string uuid; Server* server; @@ -22,7 +21,8 @@ class Peer { public: Peer(); Peer(const std::string& uuid); - void setupServer(const char* host, const char* port); + void setPublicIp(const IpAddress& addr); + void setupServer(const IpAddress& addr); virtual ~Peer(); }; diff --git a/include/Peers/provider.h b/include/Peers/provider.h index bb1086365..02aea68eb 100644 --- a/include/Peers/provider.h +++ b/include/Peers/provider.h @@ -9,6 +9,10 @@ #include "../RequestResponse/task_response.h" #include "peer.h" +#include +#include +#include +#include #include #include #include @@ -18,7 +22,11 @@ class Provider : public Peer { bool isLocalBootstrap; bool isLeader; std::unique_ptr taskRequest; - std::unique_ptr taskResponse; + std::shared_ptr taskResponse; + + std::string currentAggregatedModelStateDict; + + std::thread* workloadThread; ZMQSender ml_zmq_sender; ZMQReceiver ml_zmq_receiver; @@ -27,7 +35,7 @@ class Provider : public Peer { ZMQReceiver aggregator_zmq_receiver; public: - Provider(const char* port, std::string uuid); + Provider(unsigned short port, std::string uuid); ~Provider() noexcept; void registerWithBootstrap(); @@ -35,10 +43,15 @@ class Provider : public Peer { void leaderHandleTaskRequest(const IpAddress& requesterIpAddr); void followerHandleTaskRequest(); void processData(); - void processWorkload(); // worker function to manipulate the TaskRequest - std::string + void + initializeWorkloadToML(); // worker function to manipulate the TaskRequest + void processWorkload(); // + + void ingestTrainingData(); // worker function to load training data into memory - TaskResponse aggregateResults(std::vector followerData); + + TaskResponse + aggregateResults(std::vector> followerData); }; #endif diff --git a/include/Peers/requester.h b/include/Peers/requester.h 
index 6dc37767c..a8e988474 100644 --- a/include/Peers/requester.h +++ b/include/Peers/requester.h @@ -14,7 +14,7 @@ class Requester : protected Peer { void divideTask(); public: - Requester(const char* port); + Requester(unsigned short port); ~Requester() noexcept; void sendDiscoveryRequest(unsigned int numProviders); void waitForDiscoveryResponse(); diff --git a/include/RequestResponse/discovery_response.h b/include/RequestResponse/discovery_response.h index 106fed2ce..439e971e0 100644 --- a/include/RequestResponse/discovery_response.h +++ b/include/RequestResponse/discovery_response.h @@ -6,12 +6,14 @@ #include class DiscoveryResponse : public Payload { + IpAddress callerAddr; AddressTable availablePeers; public: DiscoveryResponse(); - DiscoveryResponse(const AddressTable& availablePeers); + DiscoveryResponse(const IpAddress& callerAddr, const AddressTable& availablePeers); + IpAddress getCallerPublicIpAddress() const; AddressTable getAvailablePeers() const; google::protobuf::Message* serializeToProto() const override; void deserializeFromProto( diff --git a/include/RequestResponse/model_state_dict_params.h b/include/RequestResponse/model_state_dict_params.h new file mode 100644 index 000000000..58d4fbe6a --- /dev/null +++ b/include/RequestResponse/model_state_dict_params.h @@ -0,0 +1,30 @@ +#ifndef _MODEL_STATE_DICT_PARAMS_H +#define _MODEL_STATE_DICT_PARAMS_H + +#include +#include + +#include "../utility.h" +#include "payload.h" + +class ModelStateDictParams : public Payload { + // bytes representing the training data + std::string modelStateDict; + bool trainingIsComplete; + + public: + ModelStateDictParams(); + ModelStateDictParams(const std::string& modelStateDict); + + std::string getTrainingData() const; + void setTrainingData(const std::string& modelStateDict); + + bool getTrainingIsComplete() const; + void setTrainingIsComplete(bool trainingIsComplete); + + google::protobuf::Message* serializeToProto() const override; + void deserializeFromProto( + const google::protobuf::Message& protoMessage) override; +}; + +#endif diff --git a/include/RequestResponse/payload.h b/include/RequestResponse/payload.h index ef0cd3400..31c47b224 100644 --- a/include/RequestResponse/payload.h +++ b/include/RequestResponse/payload.h @@ -13,10 +13,12 @@ class Payload { UNKNOWN, ACKNOWLEDGEMENT, REGISTRATION, + REGISTRATION_RESPONSE, DISCOVERY_REQUEST, DISCOVERY_RESPONSE, TASK_REQUEST, - TASK_RESPONSE + TASK_RESPONSE, + MODEL_STATE_DICT_PARAMS, }; Payload(Type type); diff --git a/include/RequestResponse/registration_response.h b/include/RequestResponse/registration_response.h new file mode 100644 index 000000000..1e7a42d4e --- /dev/null +++ b/include/RequestResponse/registration_response.h @@ -0,0 +1,21 @@ +#ifndef _REGISTRATION_RESPONSE_ +#define _REGISTRATION_RESPONSE_ + +#include "payload.h" +#include "../utility.h" +#include + +class RegistrationResponse : public Payload { + IpAddress callerAddr; + + public: + RegistrationResponse(); + RegistrationResponse(const IpAddress& callerAddr); + + IpAddress getCallerPublicIpAddress() const; + google::protobuf::Message* serializeToProto() const override; + void deserializeFromProto( + const google::protobuf::Message& protoMessage) override; +}; + +#endif // _REGISTRATION_RESPONSE_ diff --git a/include/RequestResponse/task_request.h b/include/RequestResponse/task_request.h index f43df1a15..96d844c7c 100644 --- a/include/RequestResponse/task_request.h +++ b/include/RequestResponse/task_request.h @@ -11,6 +11,7 @@ class TaskRequest : public Payload { unsigned 
int numWorkers; std::string leaderUuid; + unsigned int numEpochs; AddressTable assignedWorkers; // A globPattern for the names of the training // data files necessitates the creation of @@ -24,13 +25,15 @@ class TaskRequest : public Payload { // Tagged "union" (to-do: c++17 supports variant class) enum TaskRequestType { NONE, GLOB_PATTERN, INDEX_FILENAME } taskRequestType; TaskRequest(); - TaskRequest(const unsigned int numWorkers, const std::string& data, + TaskRequest(const unsigned int numWorkers, const std::string& data, const unsigned int numEpochs, TaskRequestType type); void setLeaderUuid(const std::string& leaderUuid); void setAssignedWorkers(const AddressTable& assignedWorkers); void setGlobPattern(const std::string& pattern); void setTrainingDataIndexFilename(const std::string& filename); + void setNumEpochs(const unsigned int numEpochs); + // Write the index file(s) to SOURCE_DATA_DIR void writeToTrainingDataIndexFile( const std::vector& trainingDataFiles) const; @@ -39,9 +42,11 @@ class TaskRequest : public Payload { AddressTable getAssignedWorkers() const; std::string getGlobPattern() const; std::string getTrainingDataIndexFilename() const; + unsigned int getNumEpochs() const; - // Retrieves all data files referenced in this task request - std::vector getTrainingDataFiles() const; + // Retrieves all data files referenced in this task request. Uses a dir to + // target a directory. + std::vector getTrainingDataFiles(std::string dir) const; google::protobuf::Message* serializeToProto() const override; void deserializeFromProto( const google::protobuf::Message& protoMessage) override; diff --git a/include/RequestResponse/task_response.h b/include/RequestResponse/task_response.h index f32c911aa..eae55fa8b 100644 --- a/include/RequestResponse/task_response.h +++ b/include/RequestResponse/task_response.h @@ -10,14 +10,18 @@ class TaskResponse : public Payload { // bytes representing the training data std::string modelStateDict; + bool trainingIsComplete; public: TaskResponse(); - TaskResponse(const std::string& modelStateDict); + TaskResponse(const std::string& modelStateDict, bool trainingIsComplete); std::string getTrainingData() const; void setTrainingData(const std::string& modelStateDict); + bool getTrainingIsComplete() const; + void setTrainingIsComplete(bool trainingIsComplete); + google::protobuf::Message* serializeToProto() const override; void deserializeFromProto( const google::protobuf::Message& protoMessage) override; diff --git a/include/utility.h b/include/utility.h index 4c45fdfed..a1e34ddec 100644 --- a/include/utility.h +++ b/include/utility.h @@ -32,9 +32,15 @@ namespace fs = std::filesystem; #define MAX_PORT_TRIES 10 /* - * Defines the data location of training files. + * Defines the data location of training files and index files for the requestor. */ -const std::string DATA_DIR = "CIFAR10/train"; +const std::string SOURCE_DATA_DIR = "CIFAR10/train"; + +/* + * Defines the data location of training files being used locally by a provider. 
+ */ +const std::string TARGET_DATA_DIR = "data/CIFAR10"; + struct IpAddress { std::string host; @@ -42,7 +48,8 @@ struct IpAddress { IpAddress() {} IpAddress(const std::string& host, const unsigned short port); - IpAddress(const char* host, const char* port); + + friend std::ostream& operator<<(std::ostream& os, const IpAddress& ip); }; utility::IpAddress* serializeIpAddressToProto(const IpAddress& ipAddress); @@ -62,8 +69,6 @@ static std::uniform_int_distribution<> dis2(8, 11); std::string generate_uuid_v4(); } // namespace uuid -std::string startNgrokForwarding(unsigned short port); - std::string vectorToString(std::vector v); int FTP_create_socket_client(int port, const char* addr); @@ -79,10 +84,17 @@ int FTP_accept_conn(int sock); fs::path resolveDataFile(const std::string filename); /* - * Verifies if a file is present in the data directory. Accepts - * a filename as input. + * Resolves the path of a file within a directory. + * Accepts a filename, directory and returns a relative path. + */ +fs::path resolveDataFileInDirectory(const std::string filename, + const std::string dir); + +/* + * Verifies if a file is present in a directory. Accepts + * a filename and directory as input. */ -bool isFileWithinDataDirectory(const std::string& filename); +bool isFileWithinDirectory(const std::string& filename, const std::string dir); /* * Generates a random port number that is available for use between MIN_PORT and diff --git a/main.cpp b/main.cpp index aef265396..4bb36783f 100644 --- a/main.cpp +++ b/main.cpp @@ -2,54 +2,117 @@ #include "include/Peers/provider.h" #include "include/Peers/requester.h" #include "include/utility.h" +#include +#include +#include #include #include +#define DEFAULT_PORT 8080 +#define DEFAULT_WORKERS 2 +#define DEFAULT_EPOCHS 10 + using namespace std; -// comment testing +namespace po = boost::program_options; int main(int argc, char* argv[]) { - const char *port = "8080"; + unsigned short port = DEFAULT_PORT; string uuid = uuid::generate_uuid_v4(); - if (argc >= 2) { - port = argv[1]; - } - #if defined(BOOTSTRAP) - cout << "Running as bootstrap node on port " << port << "." << endl; - BootstrapNode b = BootstrapNode(port, uuid); + unsigned int bootstrapPort = BootstrapNode::getServerIpAddr().port; + cout << "Running as bootstrap node on port " << bootstrapPort << "." + << endl; + BootstrapNode b = BootstrapNode(uuid); b.listen(); #elif defined(PROVIDER) + try { + // Define available program argument options + po::options_description desc("Allowed options"); + desc.add_options() + ("port,p", po::value()->default_value(DEFAULT_PORT), "set P2P server port") + ("help,h", "produce help message") + ; + + // Parse command-line arguments + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); + + //Handle options + if (vm.count("help")) { + cout << desc << endl; + return 0; + } + + port = vm["port"].as(); + } catch (const exception& e) { + cerr << "Error parsing program arguments: " << e.what() << endl; + return 1; + } + cout << "Running as provider on port " << port << "." << endl; Provider p = Provider(port, uuid); p.registerWithBootstrap(); p.listen(); #elif defined(REQUESTER) - cout << "Running as requester." 
<< endl; - Requester r = Requester(port); - int numRequestedWorkers = 2; + // Define available program argument options + po::options_description desc("Allowed options"); + desc.add_options() + ("port,p", po::value()->default_value(DEFAULT_PORT), + "set P2P server port") + ("mode,m", po::value(), + "set mode to run requester: 'c' for compute task or 'r' for receive results") + ("workers,w", po::value()->default_value(DEFAULT_WORKERS), + "set number of workers to assign training task") + ("epochs,e", po::value()->default_value(DEFAULT_EPOCHS), + "set number of epochs to run training task") + ("help,h", "produce help message") + ; - string requestType = "c"; - if (argc >= 3) { - requestType = argv[2]; + // Parse command-line arguments + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); + + //Handle options + if (vm.count("help")) { + cout << desc << endl; + return 0; } - if (requestType == "c") { - TaskRequest request = - TaskRequest(numRequestedWorkers, ".*\\.jpg$", - TaskRequest::GLOB_PATTERN); + port = vm["port"].as(); + string mode = vm["mode"].as(); + + cout << "Running as requester on port " << port << "." << endl; + Requester r = Requester(port); + + if (mode == "c") { + unsigned int numRequestedWorkers = vm["workers"].as(); + unsigned int numEpochs = vm["epochs"].as(); + cout << "Creating training task for " << numRequestedWorkers + << " workers and " << numEpochs << " epochs" << endl; + TaskRequest request = TaskRequest(numRequestedWorkers, ".*\\.jpg$", + numEpochs, TaskRequest::GLOB_PATTERN); r.setTaskRequest(request); // sends the task request to the leader and provider peers r.sendTaskRequest(); cout << "Sent task request." << endl; - } else if (requestType == "r") { + } else if (mode == "r") { TaskResponse response = r.getResults(); - auto result = response.getTrainingData(); + std::string result = response.getTrainingData(); cout << "Received training results" << endl; - } + // output result.modelStateDict to a file as a binary + std::filesystem::create_directories("output"); + std::ofstream file("output/modelStateDict.bin", std::ios::binary); + + file.write(reinterpret_cast(result.data()), result.size()); + file.close(); + + std::cout << "Successfully wrote to output/modelStateDict" << endl; + } #else cout << "Please specify either --provider or --requester flag." 
<< endl; #endif diff --git a/ml/aggregator.py b/ml/aggregator.py index 8bde15cfc..d6eaa2f97 100644 --- a/ml/aggregator.py +++ b/ml/aggregator.py @@ -10,7 +10,7 @@ from networks import SimpleCNN from dataloader import CIFAR10Dataset, get_data_loaders -from utils import train, val, test +from utils import train, val, test, network from proto import payload_pb2, utility_pb2 @@ -42,40 +42,51 @@ def main(): num_peers = int(input("Enter the number of peers: ")) context = zmq.Context() - responder = context.socket(zmq.REP) - responder.setsockopt(zmq.LINGER, 0) - responder.bind("tcp://*:" + str(port_rec)) + receiver = network.ZMQReciever(context, port_rec) + sender = network.ZMQSender(context, port_send) - sender = context.socket(zmq.REQ) - sender.setsockopt(zmq.LINGER, 0) - sender.connect("tcp://localhost:" + str(port_send)) + numCompletePeers = 0 + # completeStateDicts = [] + final_state_dict = None # recieve the models from fake_peer.py - print("Waiting for models...") - state_dicts = [] - for i in range(num_peers): - sd = responder.recv() - responder.send_string("ACK") - - agg_inp = payload_pb2.AggregatorInputData() - agg_inp.ParseFromString(sd) - agg_inp = pickle.loads(agg_inp.modelStateDict) - state_dicts.append(agg_inp) - - # average the models - print("Averaging models...") - avg_state_dict = nn_aggregator(state_dicts) - - # send the averaged model back to fake_peer.py - print("Sending averaged model...") - tr = payload_pb2.TaskResponse() - tr.modelStateDict = pickle.dumps(avg_state_dict) - sender.send(tr.SerializeToString()) - - print("Sent averaged model, waiting for acknowledgement") - - acknowledgement = sender.recv() - print("Acknowledgement received") + while True: + print("Waiting for models...") + # state_dicts = completeStateDicts.copy() + state_dicts = [] + for _ in range(num_peers - numCompletePeers): + payload = receiver.receive() + agg_inp = payload_pb2.ModelStateDictParams() + agg_inp.ParseFromString(payload) + agg_inp_model = pickle.loads(agg_inp.modelStateDict) + if agg_inp.trainingIsComplete: + numCompletePeers += 1 + # completeStateDicts.append(agg_inp_model) + state_dicts.append(agg_inp_model) + time.sleep(5) + + # average the models + print("Averaging models...") + avg_state_dict = nn_aggregator(state_dicts) + + if numCompletePeers > 0: + final_state_dict = avg_state_dict + break + + # send the averaged model back to fake_peer.py + print("Sending averaged model...") + tr = payload_pb2.TaskResponse() + tr.modelStateDict = pickle.dumps(avg_state_dict) + tr.trainingIsComplete = False + sender.send(tr.SerializeToString()) + + # send final response + pickled_weights = pickle.dumps(final_state_dict) + task_response = payload_pb2.TaskResponse() + task_response.modelStateDict = pickled_weights + task_response.trainingIsComplete = True + sender.send(task_response.SerializeToString()) + return diff --git a/ml/fake_peer.py b/ml/fake_peer.py index bf9a9e7ed..d0eed3e33 100644 --- a/ml/fake_peer.py +++ b/ml/fake_peer.py @@ -1,3 +1,4 @@ +# this is for building out the aggregator.py file. 
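The aggregation loop in `ml/aggregator.py` above calls `nn_aggregator(state_dicts)`, whose body is not included in this diff. As a reference point only, here is a minimal sketch of the element-wise, FedAvg-style state-dict averaging such a helper typically performs; the function name `average_state_dicts` and the treatment of non-float buffers are assumptions, not the project's actual implementation.
```
import torch

def average_state_dicts(state_dicts):
    """Element-wise mean of peers' model state dicts (FedAvg-style sketch)."""
    avg = {}
    for key, ref in state_dicts[0].items():
        if ref.is_floating_point():
            # Stack the same tensor from every peer and average across the peer dimension.
            avg[key] = torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
        else:
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) are carried over, not averaged.
            avg[key] = ref.clone()
    return avg
```
The averaged dict can then be pickled into `TaskResponse.modelStateDict` the same way the loop above does with the result of `nn_aggregator`.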
import os import torch import torch.nn as nn diff --git a/ml/test_main_simpleCNN.py b/ml/test_main_simpleCNN.py index 784277ff2..5f6289cdf 100644 --- a/ml/test_main_simpleCNN.py +++ b/ml/test_main_simpleCNN.py @@ -10,40 +10,38 @@ from networks import SimpleCNN from dataloader import CIFAR10Dataset, get_data_loaders -from utils import train, val, test +from utils import train, val, test, network from proto import payload_pb2, utility_pb2 def main(): - # Set up the context and responder socket + # Set up the context and receiver socket port_rec = int(input("Enter the ZMQ sender port number: ")) port_send = int(input("Enter the ZMQ receiver port number: ")) context = zmq.Context() - responder = context.socket(zmq.REP) - responder.setsockopt(zmq.LINGER, 0) - responder.bind("tcp://*:" + str(port_rec)) - - sender = context.socket(zmq.REQ) - sender.setsockopt(zmq.LINGER, 0) - sender.connect("tcp://localhost:" + str(port_send)) + receiver = network.ZMQReciever(context, port_rec) + sender = network.ZMQSender(context, port_send) # Hyperparameters (can use CLI) batch_size = 64 learning_rate = 0.001 - epochs = 1 # for now + # epochs = 2 # for now device = torch.device( "mps" if torch.backends.mps.is_available() else "cpu") print(f"Using device: {device}") # recieve a payload with the data_file_names - payload = responder.recv() - responder.send_string("ACK") - + payload = receiver.receive() training_payload = payload_pb2.TrainingData() training_payload.ParseFromString(payload) data_file_names = training_payload.training_data_index_filename + epochs = training_payload.numEpochs print("data_file_names: ", data_file_names) + # print hyperparameters + print("batch_size: ", batch_size) + print("learning_rate: ", learning_rate) + print("epochs: ", epochs) start_time = time.time() @@ -91,6 +89,23 @@ def main(): print("validating") val(model, device, val_loader, criterion, epoch, data_path) + # non compressed, non protobuf sending weights + pickled_weights = pickle.dumps(model.state_dict()) + task_response = payload_pb2.TaskResponse() + task_response.modelStateDict = pickled_weights + task_response.trainingIsComplete = False + sender.send(task_response.SerializeToString()) + + # receive the updated message + # TODO: fix this if we ignore the very last message + payload = receiver.receive() + agg_inp = payload_pb2.ModelStateDictParams() + agg_inp.ParseFromString(payload) + averaged_state_dict = pickle.loads(agg_inp.modelStateDict) + + # update current model with the averaged state dict + model.load_state_dict(averaged_state_dict) + # Test the model test(model, device, test_loader, criterion, data_path) @@ -103,16 +118,12 @@ def main(): print("Start Time: ", start_time) print("End Time: ", end_time) - # non compressed, non protobuf sending weights + # send final response pickled_weights = pickle.dumps(model.state_dict()) - task_response = payload_pb2.TaskResponse() task_response.modelStateDict = pickled_weights + task_response.trainingIsComplete = True sender.send(task_response.SerializeToString()) - print("Sent results, waiting for acknowledgement...") - - acknowledgement = sender.recv() - print("Acknowledgement received") return diff --git a/ml/testing_requester_model.py b/ml/testing_requester_model.py new file mode 100644 index 000000000..947280df9 --- /dev/null +++ b/ml/testing_requester_model.py @@ -0,0 +1,24 @@ +import torch +from networks import SimpleCNN +import pickle + + +def load_model(model_path='output/modelStateDict.bin'): + # Load the model directly from the binary file + model = SimpleCNN() + + 
# Read the binary file and unpickle it + with open(model_path, 'rb') as file: + unpickled_model = pickle.loads(file.read()) + + model.load_state_dict(unpickled_model) + return model + + +def main(): + model = load_model() + print(model) + + +if __name__ == '__main__': + main() diff --git a/ml/utils/__init__.py b/ml/utils/__init__.py index eb375e502..479c6204e 100644 --- a/ml/utils/__init__.py +++ b/ml/utils/__init__.py @@ -1,3 +1,3 @@ -from .train import train, train_distributed +from .train import train, train_multiple_aggregation from .test import test from .val import val diff --git a/ml/utils/network.py b/ml/utils/network.py new file mode 100644 index 000000000..1024ef14e --- /dev/null +++ b/ml/utils/network.py @@ -0,0 +1,32 @@ +import zmq + +# sender sends the message + + +class ZMQSender: + def __init__(self, context, port_send): + _sender = context.socket(zmq.REQ) + _sender.setsockopt(zmq.LINGER, 0) + _sender.connect("tcp://localhost:" + str(port_send)) + self.sender = _sender + # send message over zmq needs an acknowledgement as well + + def send(self, msg): + self.sender.send(msg) + ack = self.sender.recv() + return ack + + +# receiver receives the message +class ZMQReciever: + def __init__(self, context, port_rec): + _receiver = context.socket(zmq.REP) + _receiver.setsockopt(zmq.LINGER, 0) + _receiver.bind("tcp://*:" + str(port_rec)) + self.receiver = _receiver + + # receive message over zmq needs to send an acknowledgement as well + def receive(self): + msg = self.receiver.recv() + self.receiver.send_string("") + return msg diff --git a/ml/utils/train.py b/ml/utils/train.py index 0c1e63df1..a18021ed0 100644 --- a/ml/utils/train.py +++ b/ml/utils/train.py @@ -1,5 +1,7 @@ from tqdm import tqdm +import pickle +from proto import payload_pb2 def train(model, device, train_loader, optimizer, criterion, epoch): model.train() @@ -28,10 +30,13 @@ def train(model, device, train_loader, optimizer, criterion, epoch): print("[Epoch %d] loss: %.3f" % (epoch + 1, running_loss / len(train_loader))) -def train_distributed(model, device, train_loader, optimizer, criterion, epoch): +def train_multiple_aggregation(model, device, train_loader, optimizer, criterion, epoch, agg_cycle, receiver, sender): + """ + This function is equivalent to the train function above however the difference is that this + function will use ZMQSender and ZMQReceiver to send and receive model state dicts + """ model.train() running_loss = 0.0 - gradients = {} # Wrap your data loader with tqdm for a progress bar progress_bar = tqdm( enumerate(train_loader), @@ -47,15 +52,59 @@ def train_distributed(model, device, train_loader, optimizer, criterion, epoch): optimizer.step() running_loss += loss.item() - + # Update the progress bar with the latest loss information progress_bar.set_postfix( loss=running_loss / (batch_idx + 1), current_batch=batch_idx, refresh=True ) - for name, parameter in model.named_parameters(): - if parameter.grad is not None: - gradients[name] = parameter.grad.clone().to("cpu") + if batch_idx % agg_cycle == 0: + # non compressed, non protobuf sending weights + pickled_weights = pickle.dumps(model.state_dict()) + task_response = payload_pb2.TaskResponse() + task_response.modelStateDict = pickled_weights + sender.send(task_response.SerializeToString()) + + # recieve the updated message + payload = receiver.receive() + agg_inp = payload_pb2.ModelStateDictParams() + agg_inp.ParseFromString(payload) + averaged_state_dict = pickle.loads(agg_inp.modelStateDict) + + # update current model with the averaged 
state dict + model.load_state_dict(averaged_state_dict) + # At the end of the epoch, print the average loss print("[Epoch %d] loss: %.3f" % (epoch + 1, running_loss / len(train_loader))) - return gradients + +# def train_distributed(model, device, train_loader, optimizer, criterion, epoch): +# model.train() +# running_loss = 0.0 +# gradients = {} +# # Wrap your data loader with tqdm for a progress bar +# progress_bar = tqdm( +# enumerate(train_loader), +# total=len(train_loader), +# desc="Training Epoch {}".format(epoch + 1), +# ) +# for batch_idx, (data, target) in progress_bar: +# data, target = data.to(device), target.to(device) +# optimizer.zero_grad() +# output = model(data) +# loss = criterion(output, target) +# loss.backward() +# optimizer.step() + +# running_loss += loss.item() + +# progress_bar.set_postfix( +# loss=running_loss / (batch_idx + 1), current_batch=batch_idx, refresh=True +# ) + +# for name, parameter in model.named_parameters(): +# if parameter.grad is not None: +# gradients[name] = parameter.grad.clone().to("cpu") + +# print("[Epoch %d] loss: %.3f" % (epoch + 1, running_loss / len(train_loader))) + +# return gradients diff --git a/output/modelStateDict.bin b/output/modelStateDict.bin new file mode 100644 index 000000000..74248239b Binary files /dev/null and b/output/modelStateDict.bin differ diff --git a/proto/payload.proto b/proto/payload.proto index 13d2d8e05..47b2a56aa 100644 --- a/proto/payload.proto +++ b/proto/payload.proto @@ -7,10 +7,12 @@ package payload; enum PayloadType { ACKNOWLEDGEMENT = 0; REGISTRATION = 1; - DISCOVERY_REQUEST = 2; - DISCOVERY_RESPONSE = 3; - TASK_REQUEST = 4; - TASK_RESPONSE = 5; + REGISTRATION_RESPONSE = 2; + DISCOVERY_REQUEST = 3; + DISCOVERY_RESPONSE = 4; + TASK_REQUEST = 5; + TASK_RESPONSE = 6; + MODEL_STATE_DICT_PARAMS = 7; } message PayloadMessage { @@ -26,10 +28,15 @@ message PayloadMessage { TaskResponse taskResponse = 9; Acknowledgement acknowledgement = 10; Registration registration = 11; + RegistrationResponse registrationResponse = 12; + ModelStateDictParams modelStateDictParams = 13; } } -message DiscoveryResponse { utility.AddressTable availablePeers = 1; } +message DiscoveryResponse { + utility.IpAddress callerAddr = 1; + utility.AddressTable availablePeers = 2; +} message DiscoveryRequest { int32 peersRequested = 1; } @@ -37,18 +44,31 @@ message TaskRequest { int32 numWorkers = 1; string leaderUuid = 2; utility.AddressTable assignedWorkers = 3; + int32 numEpochs = 6; + oneof training_data_source { string glob_pattern = 4; string training_data_index_filename = 5; } } -message TrainingData { string training_data_index_filename = 1; } +message TrainingData { + string training_data_index_filename = 1; + int32 numEpochs = 2; +} -message AggregatorInputData { bytes modelStateDict = 1; } +message ModelStateDictParams { + bytes modelStateDict = 1; + bool trainingIsComplete = 2; +} -message TaskResponse { bytes modelStateDict = 1; } +message TaskResponse { + bytes modelStateDict = 1; + bool trainingIsComplete = 2; +} message Acknowledgement {} message Registration {} + +message RegistrationResponse { utility.IpAddress callerAddr = 1; } diff --git a/requirements.txt b/requirements.txt index 574a6a9c2..5da15d3c4 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,6 @@ certifi==2023.11.17 chardet==4.0.0 idna==2.10 -ngrok==0.12.0 -ngrok-api==0.10.0 -pyngrok==7.0.0 python-dotenv==1.0.0 PyYAML==6.0.1 requests==2.25.1 diff --git a/src/Networking/client.cpp b/src/Networking/client.cpp index 2608c3957..e61216780 100644 --- 
a/src/Networking/client.cpp +++ b/src/Networking/client.cpp @@ -6,8 +6,10 @@ Client::Client() : CONN{-1} {} Client::~Client() {} -int Client::setupConn(const char* HOST, const char* PORT, - const char* CONNTYPE) { +int Client::setupConn(const IpAddress& ipAddress, const char* CONNTYPE) { + const char* HOST = ipAddress.host.c_str(); + const char* PORT = to_string(ipAddress.port).c_str(); + addrinfo hints, *serverInfo; memset(&hints, 0, sizeof hints); @@ -28,8 +30,7 @@ int Client::setupConn(const char* HOST, const char* PORT, if (connect(CONN, addr->ai_addr, addr->ai_addrlen) == -1) { cerr << "Error connecting: " << strerror(errno) << endl; - close(CONN); - CONN = -1; + closeSocket(); continue; } @@ -46,12 +47,7 @@ int Client::setupConn(const char* HOST, const char* PORT, return 0; } -int Client::setupConn(const IpAddress& ipAddress, const char* CONNTYPE) { - return setupConn(ipAddress.host.c_str(), to_string(ipAddress.port).c_str(), - CONNTYPE); -} - -ssize_t Client::send_all_bytes(const char* buffer, size_t length, int flags, +ssize_t Client::sendAllBytes(const char* buffer, size_t length, int flags, int num_retries) { size_t total_sent = 0; int retries_used = 0; @@ -80,95 +76,130 @@ int Client::sendMsg(const string& data, int num_retries) { // Send message length first uint32_t data_size = htonl(data.size()); // Convert from host byte order to network byte order - if (send_all_bytes(reinterpret_cast(&data_size), sizeof(data_size), 0, num_retries) == -1) { + if (sendAllBytes(reinterpret_cast(&data_size), sizeof(data_size), 0, num_retries) == -1) { cerr << "Failed to send message length" << endl; return 1; } // Send message data - if (send_all_bytes(data.c_str(), data.size(), 0, num_retries) == -1) { + if (sendAllBytes(data.c_str(), data.size(), 0, num_retries) == -1) { cerr << "Failed to send message data" << endl; return 1; } cout << "Client successfully sent message" << endl; - char buffer[1024]; - ssize_t mLen = recv(CONN, buffer, sizeof(buffer), 0); - if (mLen < 0) { - cerr << "Error reading: " << strerror(errno) << endl; - close(CONN); - return 1; - } - buffer[mLen] = '\0'; - - /* - * If the request received contains the keyword "get", which is used to - * represent a file transfer request, the client will proceed to provide the - * file. 
- */ - // Using `std::istringstream` instead of strtok - std::istringstream iss(buffer); - std::string command, filename; - iss >> command >> filename; // Extract command and filename from message - - // process file descriptor - if (command == "get") { - char port[FTP_BUFFER_SIZE], buffer[FTP_BUFFER_SIZE], - char_num_blks[FTP_BUFFER_SIZE], char_num_last_blk[FTP_BUFFER_SIZE]; - int datasock, lSize, num_blks, num_last_blk, i; - FILE* fp; - cout << "FTP: Filename given is: " << filename << endl; - - if (!isFileWithinDataDirectory(filename)) { - cerr << "FTP: Requested file is not within the data directory" - << endl; - send(CONN, "0", FTP_BUFFER_SIZE, 0); + // Set socket timeout to 10 seconds + struct timeval tv; + tv.tv_sec = 10; + tv.tv_usec = 0; + setsockopt(CONN, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv); + + int dataSock, listenSock; + int listenPort = get_available_port(); + cout << "FTP: listen port is: " << listenPort << endl; + listenSock = FTP_create_socket_server( + listenPort); // creating socket for data connection + + while (true) { + char buffer[1024]; + ssize_t mLen = recv(CONN, buffer, sizeof(buffer), 0); + + if (mLen < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + // Timeout occurred, no more messages + cout << "Timeout occurred, no more messages" << endl; + break; + } + cerr << "Error reading: " << strerror(errno) << endl; close(CONN); return 1; } - int data_port = get_available_port(); - if (data_port == -1) { - cerr << "FTP: No available ports" << endl; - send(CONN, "0", FTP_BUFFER_SIZE, 0); + if (mLen == 0) { + cout << "Connection closed by server" << endl; close(CONN); return 1; } - cout << "FTP: Data port is: " << data_port << endl; - sprintf(port, "%d", data_port); - datasock = FTP_create_socket_server( - data_port); // creating socket for data connection - send(CONN, port, FTP_BUFFER_SIZE, 0); // sending port no. to client - datasock = FTP_accept_conn(datasock); // accepting connnection by client - if ((fp = fopen(resolveDataFile(filename).c_str(), "r")) != NULL) { - // size of file - send(CONN, "nxt", FTP_BUFFER_SIZE, 0); - fseek(fp, 0, SEEK_END); - lSize = ftell(fp); - rewind(fp); - num_blks = lSize / FTP_BUFFER_SIZE; - num_last_blk = lSize % FTP_BUFFER_SIZE; - sprintf(char_num_blks, "%d", num_blks); - send(CONN, char_num_blks, FTP_BUFFER_SIZE, 0); - - for (i = 0; i < num_blks; i++) { - fread(buffer, sizeof(char), FTP_BUFFER_SIZE, fp); - send(datasock, buffer, FTP_BUFFER_SIZE, 0); + + buffer[mLen] = '\0'; + cout << "FTP Server Received: " << string(buffer, mLen) << endl; + + /* + * If the request received contains the keyword "get", which is used to + * represent a file transfer request, the client will proceed to provide the + * file. 
+ */ + // Using `std::istringstream` instead of strtok + std::istringstream iss(buffer); + std::string command, filename; + iss >> command >> filename; // Extract command and filename from message + + // process file descriptor + if (command == "get") { + if (listenPort == -1) { + cerr << "FTP: No available ports" << endl; + send(CONN, "0", FTP_BUFFER_SIZE, 0); + close(CONN); + return 1; } - sprintf(char_num_last_blk, "%d", num_last_blk); - send(CONN, char_num_last_blk, FTP_BUFFER_SIZE, 0); - if (num_last_blk > 0) { - fread(buffer, sizeof(char), num_last_blk, fp); - send(datasock, buffer, FTP_BUFFER_SIZE, 0); + + char port[FTP_BUFFER_SIZE], buffer[FTP_BUFFER_SIZE], + char_num_blks[FTP_BUFFER_SIZE], char_num_last_blk[FTP_BUFFER_SIZE]; + int lSize, num_blks, num_last_blk, i; + FILE* fp; + cout << "FTP: Filename given is: " << filename << endl; + + if (!isFileWithinDirectory(filename, SOURCE_DATA_DIR)) { + cerr << "FTP: Requested file is not within the data directory" + << endl; + send(CONN, "0", FTP_BUFFER_SIZE, 0); + close(CONN); + return 1; + } + + sprintf(port, "%d", listenPort); + send(CONN, port, FTP_BUFFER_SIZE, 0); // sending port no. to client + dataSock = FTP_accept_conn(listenSock); // accepting connnection by client + + if ((fp = fopen(resolveDataFile(filename).c_str(), "r")) != NULL) { + // size of file + send(CONN, "nxt", FTP_BUFFER_SIZE, 0); + fseek(fp, 0, SEEK_END); + lSize = ftell(fp); + rewind(fp); + num_blks = lSize / FTP_BUFFER_SIZE; + num_last_blk = lSize % FTP_BUFFER_SIZE; + sprintf(char_num_blks, "%d", num_blks); + send(CONN, char_num_blks, FTP_BUFFER_SIZE, 0); + + for (i = 0; i < num_blks; i++) { + fread(buffer, sizeof(char), FTP_BUFFER_SIZE, fp); + send(dataSock, buffer, FTP_BUFFER_SIZE, 0); + } + sprintf(char_num_last_blk, "%d", num_last_blk); + send(CONN, char_num_last_blk, FTP_BUFFER_SIZE, 0); + if (num_last_blk > 0) { + fread(buffer, sizeof(char), num_last_blk, fp); + send(dataSock, buffer, FTP_BUFFER_SIZE, 0); + } + fclose(fp); + cout << "FTP: File upload done" << endl; + } else { + send(CONN, "0", FTP_BUFFER_SIZE, 0); } - fclose(fp); - cout << "FTP: File upload done" << endl; - } else { - send(CONN, "0", FTP_BUFFER_SIZE, 0); + close(dataSock); } } - close(CONN); + closeSocket(); + close(listenSock); return 0; } + +void Client::closeSocket() { + if (CONN != -1) { + close(CONN); + CONN = -1; + } +} diff --git a/src/Networking/ngrok_ip.py b/src/Networking/ngrok_ip.py deleted file mode 100755 index 8ccf6c341..000000000 --- a/src/Networking/ngrok_ip.py +++ /dev/null @@ -1,21 +0,0 @@ -# run ngrok http 8080 on command line -import ngrok -import sys -import os -from dotenv import load_dotenv - -load_dotenv() - -client = ngrok.Client(os.getenv('NGROK_TOKEN')) - -# check for command line arguments -port = 8080 -if len(sys.argv) == 2: - port = sys.argv[1] - -# get the open tunnels -while len(list(client.endpoints.list())) == 0: - pass - -for e in client.endpoints.list(): - print(e.public_url) \ No newline at end of file diff --git a/src/Networking/ngrok_restart.sh b/src/Networking/ngrok_restart.sh deleted file mode 100755 index acc7d787f..000000000 --- a/src/Networking/ngrok_restart.sh +++ /dev/null @@ -1,24 +0,0 @@ - #!/bin/bash - - # grabs the PID for the current running ngrok - ngrok_pid=$(pgrep ngrok) - - # kills ngrok if ngrok running - if [ $ngrok_pid ]; then - echo "Current ngrok PID = ${ngrok_pid}" - kill_ngrok_pid=$(kill -9 $ngrok_pid) - - # get exit status code for last command - check=$? 
- fi - - # Check if a port is provided as a command line argument - if [ "$#" -ne 1 ]; then - echo "Usage: $0 " - exit 1 - fi - - port=$1 - - # re-start ngrok - $(ngrok tcp $port &) diff --git a/src/Networking/server.cpp b/src/Networking/server.cpp index 4f7725c36..fa6779428 100644 --- a/src/Networking/server.cpp +++ b/src/Networking/server.cpp @@ -2,28 +2,16 @@ using namespace std; -Server::Server(const char* host, const char* port, const char* type) - : HOST{host}, PORT{port}, CONNTYPE{type}, server{-1} {} +Server::Server(const IpAddress& addr, const char* type) + : publicIp{addr}, CONNTYPE{type}, server{-1} {} void Server::setupServer() { -#if defined(NOLOCAL) - string response = startNgrokForwarding(stoi(PORT)); - // update public IP address - response = response.substr(6); - string ip = - response.substr(0, response.find(":")); // ignore "tcp://" prefix - unsigned short port = - static_cast(stoi(response.substr(ip.length() + 1))); - publicIP = IpAddress{ip, port}; -#elif defined(LOCAL) - publicIP = IpAddress{HOST, static_cast(stoi(PORT))}; -#else +#if !defined(NOLOCAL) && !defined(LOCAL) cerr << "Please specify either --local or --nolocal flag." << endl; exit(1); #endif - cout << "Initializing server on " << publicIP.host << ":" << publicIP.port - << endl; + cout << "Initializing server on " << publicIp << endl; server = socket(AF_INET, SOCK_STREAM, 0); if (server == -1) { @@ -33,19 +21,19 @@ void Server::setupServer() { sockaddr_in serverAddr; serverAddr.sin_family = AF_INET; - serverAddr.sin_port = htons(stoi(PORT)); + serverAddr.sin_port = htons(publicIp.port); serverAddr.sin_addr.s_addr = INADDR_ANY; if (::bind(server, (struct sockaddr*)&serverAddr, sizeof(serverAddr)) == -1) { cerr << "Error binding: " << strerror(errno) << endl; - close(server); + closeSocket(); exit(1); } if (listen(server, 5) == -1) { cerr << "Error listening: " << strerror(errno) << endl; - close(server); + closeSocket(); exit(1); } @@ -53,21 +41,31 @@ void Server::setupServer() { } bool Server::acceptConn() { - sockaddr_in clientAddr; - socklen_t clientAddrLen = sizeof(clientAddr); - activeConn = accept(server, (struct sockaddr*)&clientAddr, &clientAddrLen); + IpAddress addr; + return acceptConn(addr); +} + +bool Server::acceptConn(IpAddress& clientAddr) { + sockaddr_in addr; + socklen_t addrLen = sizeof(addr); + activeConn = accept(server, (struct sockaddr*)&addr, &addrLen); if (activeConn == -1) { cerr << "Error accepting: " << strerror(errno) << endl; - close(server); return false; } - cout << "Client connected" << endl; + // Get address of incoming connection + char addrBuffer[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &addr.sin_addr, addrBuffer, sizeof(addrBuffer)); + clientAddr.host = string(addrBuffer); + clientAddr.port = htons(addr.sin_port); + + cout << "Client connected from " << clientAddr << endl; return true; } -ssize_t Server::recv_all_bytes(char* buffer, size_t length, int flags, int num_retries) { +ssize_t Server::recvAllBytes(char* buffer, size_t length, int flags, int num_retries) { size_t total_received = 0; int retries_used = 0; while (total_received < length) { @@ -96,7 +94,7 @@ ssize_t Server::recv_all_bytes(char* buffer, size_t length, int flags, int num_r int Server::receiveFromConn(string& msg, int num_retries) { // Read message size uint32_t data_size; - if (recv_all_bytes(reinterpret_cast(&data_size), sizeof(data_size), 0, num_retries) == -1) { + if (recvAllBytes(reinterpret_cast(&data_size), sizeof(data_size), 0, num_retries) == -1) { cerr << "Failed to receive message length" << 
endl; return 1; } @@ -104,7 +102,7 @@ int Server::receiveFromConn(string& msg, int num_retries) { // Read message data string data(data_size, '\0'); - if (recv_all_bytes(data.data(), data_size, 0, num_retries) == -1) { + if (recvAllBytes(data.data(), data_size, 0, num_retries) == -1) { cerr << "Failed to receive message data" << endl; return 1; } @@ -119,7 +117,7 @@ void Server::replyToConn(string message) { send(activeConn, reply, strlen(reply), 0); } -void Server::getFileFTP(string filename) { +void Server::getFileIntoDirFTP(string filename, string directory) { std::string reply = "get " + filename; cout << "FTP: sending request \"" << reply << "\"" << endl; send(activeConn, reply.c_str(), strlen(reply.c_str()), 0); @@ -131,11 +129,14 @@ void Server::getFileFTP(string filename) { FILE* fp; recv(activeConn, port, FTP_BUFFER_SIZE, 0); data_port = atoi(port); - datasock = FTP_create_socket_client(data_port, PORT); + datasock = FTP_create_socket_client(data_port, to_string(publicIp.port).c_str()); recv(activeConn, msg, FTP_BUFFER_SIZE, 0); if (strcmp("nxt", msg) == 0) { - if ((fp = fopen(resolveDataFile(filename).c_str(), "w")) == NULL) - cout << "FTP: Error in creating file" << endl; + if ((fp = fopen( + resolveDataFileInDirectory(filename, TARGET_DATA_DIR) + .c_str(), + "w")) == NULL) + cout << "FTP: Error in creating file. errno: " << strerror(errno) << endl; else { recv(activeConn, char_num_blks, FTP_BUFFER_SIZE, 0); num_blks = atoi(char_num_blks); @@ -155,12 +156,24 @@ void Server::getFileFTP(string filename) { } else { cerr << "FTP: Error in opening file. Check filename" << endl; } + close(datasock); } -void Server::closeConn() { close(activeConn); } +void Server::closeConn() { + if (activeConn != -1) { + close(activeConn); + activeConn = -1; + } +} -Server::~Server() { +void Server::closeSocket() { if (server != -1) { + closeConn(); close(server); + server = -1; } } + +Server::~Server() { + closeSocket(); +} diff --git a/src/Peers/bootstrap_node.cpp b/src/Peers/bootstrap_node.cpp index 4e10a70f0..8994a7b9f 100644 --- a/src/Peers/bootstrap_node.cpp +++ b/src/Peers/bootstrap_node.cpp @@ -1,37 +1,40 @@ +#include + #include "../../include/Peers/bootstrap_node.h" #include "../../include/RequestResponse/discovery_request.h" #include "../../include/RequestResponse/discovery_response.h" #include "../../include/RequestResponse/message.h" #include "../../include/RequestResponse/registration.h" +#include "../../include/RequestResponse/registration_response.h" using namespace std; -BootstrapNode::BootstrapNode(const char* port, string uuid) : Peer(uuid) { - setupServer("127.0.0.1", port); +BootstrapNode::BootstrapNode(string uuid) : Peer(uuid) { + setupServer(getServerIpAddr()); } BootstrapNode::~BootstrapNode() {} -const char* BootstrapNode::getServerIpAddress() { -#if defined(NOLOCAL) - return "8.tcp.ngrok.io"; -#else - return "127.0.0.1"; -#endif -} - -const char* BootstrapNode::getServerPort() { +IpAddress BootstrapNode::getServerIpAddr() { #if defined(NOLOCAL) - return "12701"; + const char* host = getenv("BOOTSTRAP_HOST"); + const char* port = getenv("BOOTSTRAP_PORT"); + if (host == nullptr) { + cerr << "BOOTSTRAP_HOST not set in environment" << endl; + exit(1); + } + if (port == nullptr) { + port = "8080"; + } + return IpAddress(string(host), stoul(string(port))); #else - return "8080"; + return IpAddress("127.0.0.1", 8080); #endif } void BootstrapNode::registerPeer(const string& peerUuid, const IpAddress& peerIpAddr) { providerPeers[peerUuid] = peerIpAddr; - cout << "Registered peer " << 
peerUuid << " (" << peerIpAddr.host << ":" - << peerIpAddr.port << ")" << endl; + cout << "Registered peer " << peerUuid << " (" << peerIpAddr << ")" << endl; } AddressTable BootstrapNode::discoverPeers(const string& peerUuid, @@ -50,7 +53,9 @@ AddressTable BootstrapNode::discoverPeers(const string& peerUuid, void BootstrapNode::listen() { while (true) { cout << "Waiting for peer to connect..." << endl; - if (!server->acceptConn()) { + IpAddress senderClientIpAddr; + + if (!server->acceptConn(senderClientIpAddr)) { continue; } @@ -61,19 +66,30 @@ void BootstrapNode::listen() { continue; } - // process this request - string replyPrefix = "Bootstrap Node (" + uuid + ") - "; + // Deserialize request Message msg; msg.deserialize(serializedData); + + IpAddress senderServerIpAddr = senderClientIpAddr; + // Use server port specified in message (different from client port) + senderServerIpAddr.port = msg.getSenderIpAddr().port; + string senderUuid = msg.getSenderUuid(); - IpAddress senderIpAddr = msg.getSenderIpAddr(); shared_ptr payload = msg.getPayload(); + string replyPrefix = "Bootstrap Node (" + uuid + ") - "; switch (payload->getType()) { case Payload::Type::REGISTRATION: { server->replyToConn(replyPrefix + "received registration request"); - registerPeer(senderUuid, senderIpAddr); + cout << "received registration request from " << senderServerIpAddr.host << endl; + registerPeer(senderUuid, senderServerIpAddr); server->replyToConn("\nRegistration successful"); + // Create response + client->setupConn(senderServerIpAddr, "tcp"); + shared_ptr payload = + make_shared(senderServerIpAddr); + Message response(uuid, publicIp, payload);; + client->sendMsg(response.serialize()); break; } case Payload::Type::DISCOVERY_REQUEST: { @@ -86,11 +102,10 @@ void BootstrapNode::listen() { server->replyToConn("\nFound " + to_string(providers.size()) + " provider(s)"); // Create response - client->setupConn(senderIpAddr.host.c_str(), - to_string(senderIpAddr.port).c_str(), "tcp"); + client->setupConn(senderServerIpAddr, "tcp"); shared_ptr payload = - make_shared(providers); - Message response(uuid, IpAddress(host, port), payload); + make_shared(senderServerIpAddr, providers); + Message response(uuid, publicIp, payload); client->sendMsg(response.serialize()); break; } diff --git a/src/Peers/peer.cpp b/src/Peers/peer.cpp index b4fb1645e..68d535a21 100644 --- a/src/Peers/peer.cpp +++ b/src/Peers/peer.cpp @@ -4,7 +4,7 @@ using namespace std; -Peer::Peer() : host{nullptr}, port{nullptr}, server{nullptr}, client{nullptr} { +Peer::Peer() : server{nullptr}, client{nullptr} { // initialize client client = new Client(); @@ -13,22 +13,32 @@ Peer::Peer() : host{nullptr}, port{nullptr}, server{nullptr}, client{nullptr} { } Peer::Peer(const string& uuid) - : host{nullptr}, port{nullptr}, uuid{uuid}, server{nullptr}, - client{nullptr} { + : uuid{uuid}, server{nullptr}, client{nullptr} { // initialize client client = new Client(); } -void Peer::setupServer(const char* host, const char* port) { - this->host = host; - this->port = port; +void Peer::setPublicIp(const IpAddress& addr) { + publicIp = addr; +} + +void Peer::setupServer(const IpAddress& addr) { + setPublicIp(addr); const char* type = "tcp"; - // initialize server - server = new Server(host, port, type); + + // Tear down existing server + if (server != nullptr) { + delete server; + } + + // Initialize server + server = new Server(addr, type); server->setupServer(); } Peer::~Peer() { delete client; delete server; + client = nullptr; + server = nullptr; } diff --git 
a/src/Peers/provider.cpp b/src/Peers/provider.cpp index df0eaea08..dab9bca6d 100644 --- a/src/Peers/provider.cpp +++ b/src/Peers/provider.cpp @@ -3,19 +3,14 @@ #include "../../include/RequestResponse/acknowledgement.h" #include "../../include/RequestResponse/message.h" #include "../../include/RequestResponse/registration.h" +#include "../../include/RequestResponse/registration_response.h" #include "../../include/RequestResponse/task_request.h" #include "../../include/utility.h" #include "proto/payload.pb.h" -#include -#include -#include -#include -#include - using namespace std; -Provider::Provider(const char* port, string uuid) +Provider::Provider(unsigned short port, string uuid) : Peer(uuid), ml_zmq_sender(), ml_zmq_receiver(), aggregator_zmq_sender(), aggregator_zmq_receiver() { isBusy = false; @@ -26,24 +21,46 @@ Provider::Provider(const char* port, string uuid) cout << "Aggregator ZMQ: Sender: " << aggregator_zmq_sender.getAddress() << ", Receiver: " << aggregator_zmq_receiver.getAddress() << endl; - setupServer("127.0.0.1", port); + setupServer(IpAddress("127.0.0.1", port)); } Provider::~Provider() noexcept {} void Provider::registerWithBootstrap() { - const char* bootstrapHost = BootstrapNode::getServerIpAddress(); - const char* bootstrapPort = BootstrapNode::getServerPort(); - cout << "Connecting to bootstrap node at " << bootstrapHost << ":" - << bootstrapPort << endl; - if (client->setupConn(bootstrapHost, bootstrapPort, "tcp") == -1) { + IpAddress bootstrapIp = BootstrapNode::getServerIpAddr(); + cout << "Connecting to bootstrap node at " << bootstrapIp << endl; + if (client->setupConn(bootstrapIp, "tcp") == -1) { cerr << "Unable to connect to boostrap node" << endl; exit(1); } shared_ptr payload = make_shared(); - Message msg(uuid, IpAddress(host, port), payload); + Message msg(uuid, publicIp, payload); client->sendMsg(msg.serialize(), -1); + + // Get own public address from bootstrap node + while (!server->acceptConn()) + ; + + string registrationRespStr; + if (server->receiveFromConn(registrationRespStr) == 1) { + cerr << "Failed to receive registration response" << endl; + server->closeConn(); + exit(1); + } + + Message respMsg; + respMsg.deserialize(registrationRespStr); + shared_ptr respPayload = respMsg.getPayload(); + if (respPayload->getType() == Payload::Type::REGISTRATION_RESPONSE) { + shared_ptr rr = + static_pointer_cast(respPayload); + IpAddress publicIp = rr->getCallerPublicIpAddress(); + cout << "Public IP = " << publicIp << endl; + server->replyToConn("Obtained public ip address"); + setPublicIp(publicIp); + } + server->closeConn(); } void Provider::listen() { @@ -55,14 +72,15 @@ void Provider::listen() { // receive task request object from client string requesterData; - if(server->receiveFromConn(requesterData) == 1) { + if (server->receiveFromConn(requesterData) == 1) { server->closeConn(); continue; } Message requesterMsg; requesterMsg.deserialize(requesterData); - if (requesterMsg.getPayload()->getType() != Payload::Type::TASK_REQUEST) { + if (requesterMsg.getPayload()->getType() != + Payload::Type::TASK_REQUEST) { server->closeConn(); continue; } @@ -89,12 +107,18 @@ void Provider::listen() { // Download and parse training data index file from requester cout << "FTP: requesting index: " << taskRequest->getTrainingDataIndexFilename() << endl; + // Note: index files are downloaded to TARGET_DATA_DIR + server->getFileIntoDirFTP(taskRequest->getTrainingDataIndexFilename(), + TARGET_DATA_DIR); // bug here where we are saving the file to the same file - // 
server->getFileFTP(taskRequest->getTrainingDataIndexFilename()); - + // fix in PR. RN, this will not work if multiple machines. + ingestTrainingData(); server->closeConn(); + // initialize ML.py with metadata + workloadThread = new thread(&Provider::initializeWorkloadToML, this); + if (taskRequest->getLeaderUuid() == uuid) { leaderHandleTaskRequest(requesterIpAddr); } else { @@ -104,54 +128,102 @@ void Provider::listen() { } void Provider::leaderHandleTaskRequest(const IpAddress& requesterIpAddr) { - // Run processWorkload() in a separate thread - thread workloadThread(&Provider::processWorkload, this); - - vector followerData{}; - while (followerData.size() < taskRequest->getAssignedWorkers().size() - 1) { - cout << "\nWaiting for follower peer to connect..." << endl; - while (!server->acceptConn()); - - // get data from followers and aggregate - string followerMsgStr; - if (server->receiveFromConn(followerMsgStr) == 1) { - cerr << "Failed to receive data from follower" << endl; + TaskResponse aggregatedResults = TaskResponse(); + + for (unsigned int i = 0; i < taskRequest->getNumEpochs(); i++) { + vector> followerData{}; + while (followerData.size() < + taskRequest->getAssignedWorkers().size() - + 1 // TODO: this needs to be changed when we change to + // multiple aggr. per epoch. There is an edge case where + // different followers could have different aggr. cycles + // required. So we can't wait for all followers to convene. + ) { + + cout << "\nWaiting for follower peer to connect..." << endl; + while (!server->acceptConn()) + ; + + // get data from followers and aggregate + string followerMsgStr; + if (server->receiveFromConn(followerMsgStr) == 1) { + cerr << "Failed to receive data from follower" << endl; + server->closeConn(); + continue; + } + + Message followerMsg; + followerMsg.deserialize(followerMsgStr); + shared_ptr followerPayload = followerMsg.getPayload(); + + if (followerPayload->getType() != Payload::Type::TASK_RESPONSE) { + server->closeConn(); + continue; + } + + // append to followerData + shared_ptr taskResp = + static_pointer_cast(followerPayload); + followerData.push_back(taskResp); + + server->replyToConn("Received follower result."); server->closeConn(); - continue; } - Message followerMsg; - followerMsg.deserialize(followerMsgStr); - shared_ptr followerPayload = followerMsg.getPayload(); + cout << endl; - if (followerPayload->getType() != Payload::Type::TASK_RESPONSE) { - server->closeConn(); - continue; - } + workloadThread->join(); + delete workloadThread; - shared_ptr taskResp = - static_pointer_cast(followerPayload); + // Aggregate model parameters and send to aggregator script + aggregatedResults = aggregateResults(followerData); - // append to followerData - followerData.push_back(taskResp->getTrainingData()); - server->replyToConn("Received follower result."); - server->closeConn(); - } + // if final cycle, we send back to requester + if (aggregatedResults.getTrainingIsComplete()) { + break; + } - cout << endl; - workloadThread.join(); - TaskResponse aggregatedResults = aggregateResults(followerData); + // Send/Process aggregated results to follower peers + cout << "Sending aggregated results to followers..." 
<< endl; + for (const auto& follower : taskRequest->getAssignedWorkers()) { + if (follower.first == uuid) { + // if the follower is the leader, skip sending + + // TODO: send the aggregated state dict to ML.py + // Create payload containing a ModelStateDictParams message + currentAggregatedModelStateDict = + aggregatedResults.getTrainingData(); + workloadThread = new thread(&Provider::processWorkload, this); + continue; + } + + IpAddress followerIp = follower.second; + while (client->setupConn(followerIp, "tcp") != 0) { + cout << "Failed to connect to follower to send aggregated " + "results, trying again in 5s" + << endl; + sleep(5); + } + + shared_ptr payload = + make_shared(aggregatedResults); + Message msg(uuid, publicIp, payload); + int code = client->sendMsg(msg.serialize(), -1); + cout << "Leader sent data to follower with code " << code << endl; + } + } // Send results back to requester // TODO: requester IP address could change shared_ptr aggregatePayload = make_shared(aggregatedResults); - Message aggregateResultMsg(uuid, IpAddress(host, port), aggregatePayload); + Message aggregateResultMsg(uuid, publicIp, aggregatePayload); // busy wait until connection is established while (client->setupConn(requesterIpAddr, "tcp") != 0) { constexpr int retry = 5; - cout << "Failed to connect to requester server, trying again in " << retry << "s" << endl; + cout << "Failed to connect to requester server, trying again in " + << retry << "s" << endl; sleep(retry); } cout << "Connected to requester server" << endl; @@ -162,7 +234,8 @@ void Provider::leaderHandleTaskRequest(const IpAddress& requesterIpAddr) { } cout << "Sent results to requester. Waiting for acknowledgement" << endl; - while (!server->acceptConn()); + while (!server->acceptConn()) + ; while (true) { // receive response from requester @@ -175,55 +248,122 @@ void Provider::leaderHandleTaskRequest(const IpAddress& requesterIpAddr) { cout << "Acknowledgement received from requester" << endl; break; } + + server->replyToConn("Received acknowledgement."); } server->closeConn(); } void Provider::followerHandleTaskRequest() { - processWorkload(); - cout << "Waiting for connection back to leader" << endl; - IpAddress leaderIp = - taskRequest->getAssignedWorkers()[taskRequest->getLeaderUuid()]; - // busy wait until connection is established with the leader - while (client->setupConn(leaderIp, "tcp") == -1) { - sleep(5); - } + workloadThread->join(); + delete workloadThread; + workloadThread = nullptr; + + for (unsigned int i = 0; i < taskRequest->getNumEpochs(); i++) { + // run one aggregation cycle of training + cout << "Waiting for connection back to leader" << endl; + IpAddress leaderIp = + taskRequest->getAssignedWorkers()[taskRequest->getLeaderUuid()]; + // busy wait until connection is established with the leader + while (client->setupConn(leaderIp, "tcp") == -1) { + sleep(5); + } + + // send results back to leader + shared_ptr payload = std::move(taskResponse); + Message msg(uuid, publicIp, payload); + int code = client->sendMsg(msg.serialize(), -1); + cout << "Follower sent data to leader with code " << code << endl; + + if (i != taskRequest->getNumEpochs() - 1) { + // TODO: change when we do multiple aggr. 
cycles in 1 epoch + // wait for leader to send model state dict param + while (!server->acceptConn()) + ; + // receive aggregated TaskResponse object + string leaderMsgStr; + if (server->receiveFromConn(leaderMsgStr) == 1) { + cerr << "Failed to receive aggregated TaskResponse object from " + "leader" + << endl; + server->closeConn(); + continue; + } + + Message leaderMsg; + leaderMsg.deserialize(leaderMsgStr); + shared_ptr leaderPayload = leaderMsg.getPayload(); + + if (leaderPayload->getType() != Payload::Type::TASK_RESPONSE) { + server->closeConn(); + cerr << "Received payload is not of type TASK_RESPONSE" << endl; + continue; + } + + shared_ptr taskResp = + static_pointer_cast(leaderPayload); + taskResponse = make_shared(std::move(*taskResp)); + server->replyToConn("Received leader result."); + server->closeConn(); + + currentAggregatedModelStateDict = taskResponse->getTrainingData(); - // send results back to leader - shared_ptr payload = std::move(taskResponse); - Message msg(uuid, IpAddress(host, port), payload); - int code = client->sendMsg(msg.serialize(), -1); - cout << "Follower sent data to leader with code " << code << endl; + // forward model state dict param to ml + processWorkload(); + } + } } -string Provider::ingestTrainingData() { +void Provider::ingestTrainingData() { + server->getFileIntoDirFTP(taskRequest->getTrainingDataIndexFilename(), TARGET_DATA_DIR); + string trainingDataIndexFile = taskRequest->getTrainingDataIndexFilename(); - vector requiredTrainingFiles = taskRequest->getTrainingDataFiles(); + vector requiredTrainingFiles = + taskRequest->getTrainingDataFiles(TARGET_DATA_DIR); cout << "Task requires " << requiredTrainingFiles.size() << " training files" << endl; // Ensure all required training data files are present for (const string& filename : requiredTrainingFiles) { - if (!isFileWithinDataDirectory(filename)) { + if (!isFileWithinDirectory(filename, TARGET_DATA_DIR)) { cout << "FTP: requesting " << filename << endl; - server->getFileFTP(filename); + server->getFileIntoDirFTP(filename, TARGET_DATA_DIR); } } cout << "All training files are now present" << endl; - - // Temporarily just send index file to python worker - return trainingDataIndexFile; } void Provider::processWorkload() { - string indexFile = ingestTrainingData(); + // send the current model state dict to ML.py + payload::ModelStateDictParams modelStateDictParamsProto; + modelStateDictParamsProto.set_modelstatedict( + currentAggregatedModelStateDict); + string serialized_modelStateDictParamsProto; + modelStateDictParamsProto.SerializeToString( + &serialized_modelStateDictParamsProto); + // TODO: send model state dict param to ML.py + ml_zmq_sender.send(serialized_modelStateDictParamsProto); - // send training data index file to the worker - cout << "Working on training data index file: " << indexFile << endl; + cout << "Waiting for processed data..." << endl; + auto rcvdData = ml_zmq_receiver.receive(); + cout << "Received processed data" << endl; + + // Parse received task response proto and populate TaskResponse object + payload::TaskResponse task_response_proto; + task_response_proto.ParseFromString(rcvdData); + taskResponse = + make_shared(task_response_proto.modelstatedict(), + task_response_proto.trainingiscomplete()); + cout << "Completed assigned workload" << endl; +} +void Provider::initializeWorkloadToML() { + // send training data index file to the worker cout << "Sending training data index file to worker..." 
<< endl; payload::TrainingData training_data_proto; - training_data_proto.set_training_data_index_filename(indexFile); + training_data_proto.set_training_data_index_filename( + taskRequest->getTrainingDataIndexFilename()); + training_data_proto.set_numepochs(taskRequest->getNumEpochs()); string serialized_training_data; training_data_proto.SerializeToString(&serialized_training_data); ml_zmq_sender.send(serialized_training_data); @@ -235,27 +375,31 @@ // Parse received task response proto and populate TaskResponse object payload::TaskResponse task_response_proto; task_response_proto.ParseFromString(rcvdData); - taskResponse = make_unique(task_response_proto.modelstatedict()); + taskResponse = + make_shared(task_response_proto.modelstatedict(), + task_response_proto.trainingiscomplete()); cout << "Completed assigned workload" << endl; } -TaskResponse Provider::aggregateResults(vector followerData) { +TaskResponse +Provider::aggregateResults(vector> followerData) { cout << "Aggregating results..." << endl; - // append leader's data to followerData - string curr_data = taskResponse->getTrainingData(); - followerData.push_back(curr_data); + // append leader's data to back of followerData + followerData.push_back(taskResponse); // send each follower's data to the aggregator cout << "Sending data to aggregator..." << endl; - for (int i = 0; i < followerData.size(); i++) { + for (unsigned int i = 0; i < followerData.size(); i++) { cout << "Aggregator for loop: " << i << endl; - payload::AggregatorInputData* proto = - new payload::AggregatorInputData(); - proto->set_modelstatedict(followerData[i]); + payload::ModelStateDictParams* proto = + new payload::ModelStateDictParams(); + proto->set_modelstatedict(followerData[i]->getTrainingData()); + proto->set_trainingiscomplete(followerData[i]->getTrainingIsComplete()); string serialized; proto->SerializeToString(&serialized); aggregator_zmq_sender.send(serialized); + delete proto; } // receive aggregated data from the aggregator @@ -265,5 +409,8 @@ TaskResponse Provider::aggregateResults(vector followerData) { payload::TaskResponse aggTaskResponseProto; aggTaskResponseProto.ParseFromString(rcvdData); - return TaskResponse(aggTaskResponseProto.modelstatedict()); + TaskResponse aggTaskResponse = + TaskResponse(aggTaskResponseProto.modelstatedict(), + aggTaskResponseProto.trainingiscomplete()); + return aggTaskResponse; }
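The leaderHandleTaskRequest() changes above restructure the leader into a per-epoch loop: collect one TaskResponse from each follower, append the leader's own result, aggregate, and either stop when the aggregate reports trainingIsComplete or redistribute the aggregated state dict so the next epoch can start. A simplified, self-contained sketch of that control flow; every name here (EpochResult, collectFollowerResults, aggregate, redistribute) is an illustrative stand-in rather than the project's API, and the stubs replace the socket and ZMQ plumbing:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for TaskResponse.
struct EpochResult {
    std::string stateDict;
    bool trainingIsComplete = false;
};

// Stub: the real code blocks on server->acceptConn()/receiveFromConn() per follower.
std::vector<EpochResult> collectFollowerResults(unsigned int expected) {
    return std::vector<EpochResult>(expected, EpochResult{"follower-params", false});
}

// Stub: the real code ships the serialized params to the aggregator over ZMQ.
EpochResult aggregate(const std::vector<EpochResult>& results, bool lastEpoch) {
    return EpochResult{"aggregate of " + std::to_string(results.size()), lastEpoch};
}

// Stub: the real code sends a TaskResponse message to each follower peer.
void redistribute(const EpochResult& r) {
    std::cout << "redistributing " << r.stateDict << std::endl;
}

int main() {
    const unsigned int numEpochs = 3;
    const unsigned int numFollowers = 2;
    EpochResult aggregated;
    for (unsigned int epoch = 0; epoch < numEpochs; ++epoch) {
        std::vector<EpochResult> results = collectFollowerResults(numFollowers);
        results.push_back(EpochResult{"leader-params", false}); // leader's own result
        aggregated = aggregate(results, epoch == numEpochs - 1);
        if (aggregated.trainingIsComplete) {
            break; // final cycle: hand the result back to the requester
        }
        redistribute(aggregated); // otherwise start the next epoch
    }
    std::cout << "final: " << aggregated.stateDict << std::endl;
    return 0;
}
```

In the real flow the leader also skips itself during redistribution and feeds the aggregated state dict straight into its own ML worker.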
diff --git a/src/Peers/requester.cpp b/src/Peers/requester.cpp index 8fade3e19..34af250b3 100644 --- a/src/Peers/requester.cpp +++ b/src/Peers/requester.cpp @@ -12,30 +12,29 @@ using namespace std; -Requester::Requester(const char* port) : Peer() { - setupServer("127.0.0.1", port); +Requester::Requester(unsigned short port) : Peer() { + setupServer(IpAddress("127.0.0.1", port)); } Requester::~Requester() noexcept {} void Requester::sendDiscoveryRequest(unsigned int numProviders) { - const char* bootstrapHost = BootstrapNode::getServerIpAddress(); - const char* bootstrapPort = BootstrapNode::getServerPort(); - cout << "Connecting to bootstrap node at " << bootstrapHost << ":" - << bootstrapPort << endl; - if (client->setupConn(bootstrapHost, bootstrapPort, "tcp") == -1) { + IpAddress bootstrapIp = BootstrapNode::getServerIpAddr(); + cout << "Connecting to bootstrap node at " << bootstrapIp << endl; + if (client->setupConn(bootstrapIp, "tcp") == -1) { cerr << "Unable to connect to bootstrap node" << endl; exit(1); } shared_ptr payload = make_shared(numProviders); - Message msg(uuid, IpAddress(host, port), payload); + Message msg(uuid, publicIp, payload); client->sendMsg(msg.serialize(), -1); } void Requester::waitForDiscoveryResponse() { cout << "Waiting for discovery response..." << endl; - while (!server->acceptConn()); + while (!server->acceptConn()) + ; // receive response from bootstrap (or possibly another peer) string serializedData; @@ -54,6 +53,10 @@ void Requester::waitForDiscoveryResponse() { if (payload->getType() == Payload::Type::DISCOVERY_RESPONSE) { shared_ptr dr = static_pointer_cast(payload); + IpAddress publicIp = dr->getCallerPublicIpAddress(); + cout << "Public IP = " << publicIp << endl; + setPublicIp(publicIp); + AddressTable availablePeers = dr->getAvailablePeers(); for (auto& it : availablePeers) { providerPeers[it.first] = it.second; @@ -75,7 +78,8 @@ void Requester::divideTask() { TaskRequest queuedTask = taskRequests.front(); // obtain list of training files - vector trainingFiles = queuedTask.getTrainingDataFiles(); + vector trainingFiles = + queuedTask.getTrainingDataFiles(SOURCE_DATA_DIR); cout << "Found " << trainingFiles.size() << " training files" << endl; // // shuffle training files for even distribution after division @@ -85,6 +89,7 @@ // divide the vector into subvectors int numSubtasks = queuedTask.getNumWorkers(); + int numEpochs = queuedTask.getNumEpochs(); int subtaskSize = trainingFiles.size() / numSubtasks; int remainder = trainingFiles.size() % numSubtasks; @@ -109,9 +114,11 @@ subtaskTrainingFiles.push_back(trainingFiles[i * subtaskSize + j]); } - // Build a path by combining the filename and DATA_DIR using path joins + // Build a path by combining the filename and SOURCE_DATA_DIR + // using path joins string filename = "subtaskIndex_" + std::to_string(i) + ".txt"; - TaskRequest subtaskRequest(1, filename, TaskRequest::INDEX_FILENAME); + TaskRequest subtaskRequest(1, filename, numEpochs, + TaskRequest::INDEX_FILENAME); subtaskRequest.writeToTrainingDataIndexFile(subtaskTrainingFiles); cout << "FTP: Created index file " << subtaskRequest.getTrainingDataIndexFilename() << endl; @@ -154,12 +161,11 @@ void Requester::sendTaskRequest() { // package and serialize the requests shared_ptr payload = make_shared(taskRequests[ctr]); - Message msg(uuid, IpAddress(host, port), payload); + Message msg(uuid, publicIp, payload); // set up the client - const char* host = worker.second.host.c_str(); - const char* port = to_string(worker.second.port).c_str(); - client->setupConn(host, port, "tcp"); + IpAddress workerIp = worker.second; + client->setupConn(workerIp, "tcp"); // send the request client->sendMsg(msg.serialize(), -1); @@ -173,7 +179,8 @@ TaskResponse Requester::getResults() { TaskResponse taskResult; // busy wait until connection is established cout << "Waiting for leader peer to connect" << endl; - while (!server->acceptConn()); + while (!server->acceptConn()) + ; // get data from workers and aggregate cout << "Waiting for leader peer to send results" << endl; @@ -192,9 +199,22 @@ // send success acknowledgement to provider shared_ptr payload = make_shared(); - Message response(uuid, IpAddress(host, port), payload); - client->setupConn(leaderIpAddr, "tcp"); - client->sendMsg(response.serialize()); + Message response(uuid, publicIp, payload); + cout << "setting up connection" << endl; + cout << "leaderIpAddr: " << leaderIpAddr << endl; + + while (client->setupConn(leaderIpAddr, "tcp") != 0) { + constexpr int retry = 5; + cout << "Failed to connect to
leader peer, trying again in " << retry << "s" << endl; + sleep(retry); + } + + + if (client->sendMsg(response.serialize(), 5) == 1) { + cerr << "Failed to send acknowledgement to leader peer" << endl; + exit(1); + } return taskResult; } diff --git a/src/RequestResponse/discovery_response.cpp b/src/RequestResponse/discovery_response.cpp index c3132e9d2..1d7335bed 100644 --- a/src/RequestResponse/discovery_response.cpp +++ b/src/RequestResponse/discovery_response.cpp @@ -3,8 +3,14 @@ using namespace std; DiscoveryResponse::DiscoveryResponse() : Payload(Type::DISCOVERY_RESPONSE) {} -DiscoveryResponse::DiscoveryResponse(const AddressTable& availablePeers) - : Payload(Type::DISCOVERY_RESPONSE), availablePeers{availablePeers} {} +DiscoveryResponse::DiscoveryResponse(const IpAddress& callerAddr, + const AddressTable& availablePeers) - : Payload(Type::DISCOVERY_RESPONSE), callerAddr{callerAddr}, + availablePeers{availablePeers} {} + +IpAddress DiscoveryResponse::getCallerPublicIpAddress() const { + return callerAddr; +} AddressTable DiscoveryResponse::getAvailablePeers() const { return availablePeers; @@ -12,6 +18,10 @@ AddressTable DiscoveryResponse::getAvailablePeers() const { google::protobuf::Message* DiscoveryResponse::serializeToProto() const { payload::DiscoveryResponse* proto = new payload::DiscoveryResponse(); + utility::IpAddress* protoIp = new utility::IpAddress(); + protoIp->set_ip(callerAddr.host); + protoIp->set_port(callerAddr.port); + proto->set_allocated_calleraddr(protoIp); utility::AddressTable* atProto = serializeAddressTable(availablePeers); proto->set_allocated_availablepeers(atProto); @@ -20,9 +30,13 @@ google::protobuf::Message* DiscoveryResponse::serializeToProto() const { void DiscoveryResponse::deserializeFromProto( const google::protobuf::Message& protoMessage) { - const payload::DiscoveryResponse& discoveryResponseProto = dynamic_cast(protoMessage); + + const utility::IpAddress& protoIp = discoveryResponseProto.calleraddr(); + callerAddr.host = protoIp.ip(); + callerAddr.port = protoIp.port(); + const utility::AddressTable& addressTableProto = discoveryResponseProto.availablepeers(); diff --git a/src/RequestResponse/message.cpp b/src/RequestResponse/message.cpp index 032a5e7b1..9b4253c7d 100644 --- a/src/RequestResponse/message.cpp +++ b/src/RequestResponse/message.cpp @@ -5,8 +5,10 @@ #include "../../include/RequestResponse/discovery_response.h" #include "../../include/RequestResponse/message.h" #include "../../include/RequestResponse/registration.h" +#include "../../include/RequestResponse/registration_response.h" #include "../../include/RequestResponse/task_request.h" #include "../../include/RequestResponse/task_response.h" +#include "../../include/RequestResponse/model_state_dict_params.h" #include "../../include/utility.h" using namespace std; @@ -38,6 +40,8 @@ void Message::initializePayload(const string& payloadTypeStr) { payload = make_shared(); } else if (payloadTypeStr == "REGISTRATION") { payload = make_shared(); + } else if (payloadTypeStr == "REGISTRATION_RESPONSE") { + payload = make_shared(); } else if (payloadTypeStr == "DISCOVERY_REQUEST") { payload = make_shared(); } else if (payloadTypeStr == "DISCOVERY_RESPONSE") { @@ -46,6 +50,8 @@ payload = make_shared(); } else if (payloadTypeStr == "TASK_RESPONSE") { payload = make_shared(); + } else if (payloadTypeStr == "MODEL_STATE_DICT_PARAMS") { + payload = make_shared(); } else { cerr << "Unknown type " << payloadTypeStr << endl; } @@
-90,6 +96,16 @@ string Message::serialize() const { messageProto.set_allocated_registration(registrationProto); break; } + case Payload::Type::REGISTRATION_RESPONSE: { + messageProto.set_payloadtype(payload::PayloadType::REGISTRATION_RESPONSE); + shared_ptr registrationResponse = + getPayloadAs(); + payload::RegistrationResponse* rrProto = + static_cast( + registrationResponse->serializeToProto()); + messageProto.set_allocated_registrationresponse(rrProto); + break; + } case Payload::Type::DISCOVERY_REQUEST: { messageProto.set_payloadtype(payload::PayloadType::DISCOVERY_REQUEST); shared_ptr discoveryReq = @@ -124,6 +140,14 @@ string Message::serialize() const { messageProto.set_allocated_taskresponse(taskRespProto); break; } + case Payload::Type::MODEL_STATE_DICT_PARAMS: { + messageProto.set_payloadtype(payload::PayloadType::MODEL_STATE_DICT_PARAMS); + shared_ptr modelParams = getPayloadAs(); + auto* modelParamsProto = + static_cast(modelParams->serializeToProto()); + messageProto.set_allocated_modelstatedictparams(modelParamsProto); + break; + } default: cerr << "Unknown type" << endl; exit(EXIT_FAILURE); @@ -148,7 +172,7 @@ void Message::deserialize(const string& serializedData) { deserializeIpAddressFromProto(messageProto.senderipaddress()); cout << "Deserializing message with id: " << uuid << " from " << senderUuid - << " at " << senderIpAddr.host << ":" << senderIpAddr.port << endl; + << " at " << senderIpAddr << endl; string payloadType = payload::PayloadType_Name(messageProto.payloadtype()); initializePayload(payloadType); @@ -175,10 +199,19 @@ void Message::deserialize(const string& serializedData) { payload::PayloadType::REGISTRATION) { const payload::Registration& registration = messageProto.registration(); payload->deserializeFromProto(registration); + } else if (messageProto.payloadtype() == + payload::PayloadType::REGISTRATION_RESPONSE) { + const payload::RegistrationResponse& registrationResponse = + messageProto.registrationresponse(); + payload->deserializeFromProto(registrationResponse); } else if (messageProto.payloadtype() == payload::PayloadType::ACKNOWLEDGEMENT) { const payload::Acknowledgement& ack = messageProto.acknowledgement(); payload->deserializeFromProto(ack); + } else if (messageProto.payloadtype() == + payload::PayloadType::MODEL_STATE_DICT_PARAMS) { + const payload::ModelStateDictParams& modelParams = messageProto.modelstatedictparams(); + payload->deserializeFromProto(modelParams); } else { cerr << "Unknown or unsupported payload type for deserialization" << endl; diff --git a/src/RequestResponse/model_state_dict_params.cpp b/src/RequestResponse/model_state_dict_params.cpp new file mode 100644 index 000000000..f3f636790 --- /dev/null +++ b/src/RequestResponse/model_state_dict_params.cpp @@ -0,0 +1,40 @@ +#include "../../include/RequestResponse/model_state_dict_params.h" + +using namespace std; + +ModelStateDictParams::ModelStateDictParams() + : Payload(Type::MODEL_STATE_DICT_PARAMS) {} + +ModelStateDictParams::ModelStateDictParams(const string& modelStateDict) + : Payload(Type::MODEL_STATE_DICT_PARAMS), modelStateDict{modelStateDict} {} + +string ModelStateDictParams::getTrainingData() const { return modelStateDict; } + +void ModelStateDictParams::setTrainingData(const string& modelStateDict) { + this->modelStateDict = modelStateDict; +} + +bool ModelStateDictParams::getTrainingIsComplete() const { + return trainingIsComplete; +} + +void ModelStateDictParams::setTrainingIsComplete(bool trainingIsComplete) { + this->trainingIsComplete = trainingIsComplete; 
+} + +google::protobuf::Message* ModelStateDictParams::serializeToProto() const { + payload::ModelStateDictParams* proto = new payload::ModelStateDictParams(); + proto->set_modelstatedict(modelStateDict); + proto->set_trainingiscomplete(trainingIsComplete); + return proto; +} + +void ModelStateDictParams::deserializeFromProto( + const google::protobuf::Message& protoMessage) { + + const payload::ModelStateDictParams& proto = + dynamic_cast(protoMessage); + + modelStateDict = proto.modelstatedict(); + trainingIsComplete = proto.trainingiscomplete(); +} diff --git a/src/RequestResponse/registration_response.cpp b/src/RequestResponse/registration_response.cpp new file mode 100644 index 000000000..acb869308 --- /dev/null +++ b/src/RequestResponse/registration_response.cpp @@ -0,0 +1,30 @@ +#include "../../include/RequestResponse/registration_response.h" + +using namespace std; + +RegistrationResponse::RegistrationResponse() : Payload(Type::REGISTRATION_RESPONSE) {} +RegistrationResponse::RegistrationResponse(const IpAddress& callerAddr) + : Payload(Type::REGISTRATION_RESPONSE), callerAddr{callerAddr} {} + +IpAddress RegistrationResponse::getCallerPublicIpAddress() const { + return callerAddr; +} + +google::protobuf::Message* RegistrationResponse::serializeToProto() const { + payload::RegistrationResponse* proto = new payload::RegistrationResponse(); + utility::IpAddress* protoIp = new utility::IpAddress(); + protoIp->set_ip(callerAddr.host); + protoIp->set_port(callerAddr.port); + proto->set_allocated_calleraddr(protoIp); + + return proto; +} + +void RegistrationResponse::deserializeFromProto( + const google::protobuf::Message& protoMessage) { + const payload::RegistrationResponse& registrationResponseProto = + dynamic_cast(protoMessage); + const utility::IpAddress& protoIp = registrationResponseProto.calleraddr(); + callerAddr.host = protoIp.ip(); + callerAddr.port = protoIp.port(); +} diff --git a/src/RequestResponse/task_request.cpp b/src/RequestResponse/task_request.cpp index 3a5c59278..a90efa713 100644 --- a/src/RequestResponse/task_request.cpp +++ b/src/RequestResponse/task_request.cpp @@ -6,8 +6,8 @@ TaskRequest::TaskRequest() : Payload(Type::TASK_REQUEST), taskRequestType{NONE} {} TaskRequest::TaskRequest(const unsigned int numWorkers, const std::string& data, - TaskRequestType type) - : Payload(Type::TASK_REQUEST), numWorkers{numWorkers}, + const unsigned int numEpochs, TaskRequestType type) + : Payload(Type::TASK_REQUEST), numWorkers{numWorkers}, numEpochs{numEpochs}, taskRequestType{type} { if (type == GLOB_PATTERN) { globPattern = data; @@ -36,9 +36,16 @@ void TaskRequest::setTrainingDataIndexFilename(const std::string& filename) { globPattern.clear(); } +void TaskRequest::setNumEpochs(const unsigned int numEpochs) { + this->numEpochs = numEpochs; +} + void TaskRequest::writeToTrainingDataIndexFile( const vector& trainingDataFiles) const { - fs::path indexFilePath = resolveDataFile(trainingDataIndexFilename); + fs::path indexFilePath = resolveDataFileInDirectory(trainingDataIndexFilename, SOURCE_DATA_DIR); + + // Create directory if it doesn't exist + fs::create_directories(indexFilePath.parent_path()); ofstream indexFile(indexFilePath, std::ios::out | std::ios::trunc); if (!indexFile) { @@ -89,14 +96,15 @@ std::string TaskRequest::getTrainingDataIndexFilename() const { return (taskRequestType == INDEX_FILENAME) ? 
trainingDataIndexFilename : ""; } -vector TaskRequest::getTrainingDataFiles() const { +vector TaskRequest::getTrainingDataFiles(string dir) const { vector trainingDataFiles; if (taskRequestType == GLOB_PATTERN) { regex pattern = convertToRegexPattern(globPattern); - trainingDataFiles = getMatchingDataFiles(pattern, DATA_DIR); + trainingDataFiles = getMatchingDataFiles(pattern, dir); } else if (taskRequestType == INDEX_FILENAME) { - ifstream indexFile(resolveDataFile(trainingDataIndexFilename)); + ifstream indexFile( + resolveDataFileInDirectory(trainingDataIndexFilename, dir)); if (indexFile.is_open()) { string line; while (getline(indexFile, line)) { @@ -109,10 +117,13 @@ vector TaskRequest::getTrainingDataFiles() const { return trainingDataFiles; } +unsigned int TaskRequest::getNumEpochs() const { return numEpochs; } + google::protobuf::Message* TaskRequest::serializeToProto() const { payload::TaskRequest* proto = new payload::TaskRequest(); proto->set_numworkers(numWorkers); proto->set_leaderuuid(leaderUuid); + proto->set_numepochs(numEpochs); if (taskRequestType == GLOB_PATTERN) { proto->set_glob_pattern(globPattern); @@ -134,6 +145,7 @@ void TaskRequest::deserializeFromProto( numWorkers = proto.numworkers(); leaderUuid = proto.leaderuuid(); + numEpochs = proto.numepochs(); if (proto.has_glob_pattern()) { setGlobPattern(proto.glob_pattern()); diff --git a/src/RequestResponse/task_response.cpp b/src/RequestResponse/task_response.cpp index f47714853..d8a14d7b2 100644 --- a/src/RequestResponse/task_response.cpp +++ b/src/RequestResponse/task_response.cpp @@ -4,8 +4,10 @@ using namespace std; TaskResponse::TaskResponse() : Payload(Type::TASK_RESPONSE) {} -TaskResponse::TaskResponse(const string& modelStateDict) - : Payload(Type::TASK_RESPONSE), modelStateDict{modelStateDict} {} +TaskResponse::TaskResponse(const string& modelStateDict, + bool trainingIsComplete) + : Payload(Type::TASK_RESPONSE), modelStateDict{modelStateDict}, + trainingIsComplete{trainingIsComplete} {} string TaskResponse::getTrainingData() const { return modelStateDict; } @@ -13,9 +15,16 @@ void TaskResponse::setTrainingData(const string& modelStateDict) { this->modelStateDict = modelStateDict; } +bool TaskResponse::getTrainingIsComplete() const { return trainingIsComplete; } + +void TaskResponse::setTrainingIsComplete(bool trainingIsComplete) { + this->trainingIsComplete = trainingIsComplete; +} + google::protobuf::Message* TaskResponse::serializeToProto() const { payload::TaskResponse* proto = new payload::TaskResponse(); proto->set_modelstatedict(modelStateDict); + proto->set_trainingiscomplete(trainingIsComplete); return proto; } @@ -26,4 +35,5 @@ void TaskResponse::deserializeFromProto( dynamic_cast(protoMessage); modelStateDict = proto.modelstatedict(); + trainingIsComplete = proto.trainingiscomplete(); } diff --git a/src/utility.cpp b/src/utility.cpp index b6d999e06..17823f097 100644 --- a/src/utility.cpp +++ b/src/utility.cpp @@ -5,12 +5,6 @@ using namespace std; IpAddress::IpAddress(const string& host, const unsigned short port) : host(host), port(port) {} -IpAddress::IpAddress(const char* host, const char* port) { - this->host = host; - string portStr(port); - this->port = stoi(portStr); -} - // converts an IpAddress object to a utility::IpAddress proto object utility::IpAddress* serializeIpAddressToProto(const IpAddress& ipAddress) { utility::IpAddress* j = new utility::IpAddress(); @@ -29,6 +23,11 @@ IpAddress deserializeIpAddressFromProto(const utility::IpAddress& proto) { return ip; } +// overloaded output 
operator for IpAddress +std::ostream& operator<<(std::ostream& os, const IpAddress& ip) { + return os << ip.host << ":" << ip.port; +} + // converts a AddressTable object to a utility::AddressTable proto object utility::AddressTable* serializeAddressTable(const AddressTable& addressTable) { utility::AddressTable* at = new utility::AddressTable(); @@ -81,44 +80,6 @@ string uuid::generate_uuid_v4() { return ss.str(); } -string startNgrokForwarding(unsigned short port) { - const string command = - "python3 ./src/Networking/ngrok_ip.py " + to_string(port); - - // start ngrok - string ngrok_restart = - "./src/Networking/ngrok_restart.sh " + to_string(port) + " &"; - system(ngrok_restart.c_str()); - - // Open a pipe to capture the output - FILE* pipe = popen(command.c_str(), "r"); - if (!pipe) { - cerr << "Error opening pipe." << endl; - return ""; - } - - // read - char buffer[256]; - string result = ""; - // while (fgets(buffer, 256, pipe) != nullptr) { - // result += buffer; - // } - - // // Close the pipe - // pclose(pipe); - // return result; - while (!feof(pipe)) { - if (fgets(buffer, 256, pipe) != nullptr) - result += buffer; - } - - // Close the pipe - pclose(pipe); - return result; -} - -void close_ngrok_forwarding() { system("pkill ngrok"); } - string vectorToString(vector vec) { stringstream ss; for (unsigned int i = 0; i < vec.size(); i++) { @@ -208,22 +169,31 @@ int FTP_accept_conn(int sock) { } fs::path resolveDataFile(const std::string filename) { - std::string resolvedFilename = DATA_DIR + "/" + filename; + return resolveDataFileInDirectory(filename, SOURCE_DATA_DIR); +} + +fs::path resolveDataFileInDirectory(const std::string filename, + const std::string dir) { + + fs::path dirPath(dir); + if (!fs::exists(dirPath)) { + fs::create_directories(dirPath); + } + + std::string resolvedFilename = dir + "/" + filename; return fs::path(resolvedFilename); } -bool isFileWithinDataDirectory(const std::string& filename) { - std::regex cloudmeshDataPattern(".*cloudmesh/" + std::string(DATA_DIR) + - ".*", +bool isFileWithinDirectory(const std::string& filename, const std::string dir) { + std::regex cloudmeshDataPattern(".*cloudmesh/" + dir + ".*", std::regex_constants::icase); try { - fs::path requestedPath = resolveDataFile(filename); + fs::path requestedPath = resolveDataFileInDirectory(filename, dir); std::string canonicalPathStr = fs::canonical(fs::absolute(requestedPath)).string(); return std::regex_search(canonicalPathStr, cloudmeshDataPattern); } catch (const std::exception& e) { - std::cerr << "Caught Error: " << e.what() << std::endl; return false; // Invalid path format or not within the data directory } } @@ -251,6 +221,9 @@ int get_available_port() { close(sock); } + + std::cerr << "Unable to find an available port after " << MAX_PORT_TRIES + << " tries" << std::endl; return -1; }
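For reference, the directory-scoped helpers above (resolveDataFileInDirectory, isFileWithinDirectory) boil down to joining a filename onto a target directory, creating the directory on demand, and making sure the resolved path cannot escape it. A standalone std::filesystem sketch of the same idea; the names and example paths are illustrative, and it uses a path-prefix check where the code above uses a regex against the canonical path:

```cpp
#include <algorithm>
#include <filesystem>
#include <iostream>
#include <string>

namespace fs = std::filesystem;

// Join filename onto dir, creating dir if it does not exist yet.
fs::path resolveInDir(const std::string& filename, const std::string& dir) {
    fs::path dirPath(dir);
    fs::create_directories(dirPath);
    return dirPath / filename;
}

// True iff the resolved path stays inside dir (".." cannot escape it).
bool isWithinDir(const std::string& filename, const std::string& dir) {
    // weakly_canonical resolves ".." components without requiring the file to exist
    fs::path resolved = fs::weakly_canonical(fs::absolute(resolveInDir(filename, dir)));
    fs::path base = fs::weakly_canonical(fs::absolute(dir));
    auto diff = std::mismatch(base.begin(), base.end(), resolved.begin(), resolved.end());
    return diff.first == base.end(); // base is a prefix of resolved
}

int main() {
    // Example directory; the real code uses its configured data directories.
    std::cout << resolveInDir("subtaskIndex_0.txt", "data/source") << std::endl;
    std::cout << std::boolalpha
              << isWithinDir("subtaskIndex_0.txt", "data/source") << " "
              << isWithinDir("../../etc/passwd", "data/source") << std::endl; // true false
    return 0;
}
```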