diff --git a/example/demo/collective/start_job_client.sh b/example/demo/collective/start_job_client.sh index 7cf0c2e2..0a96a2d3 100755 --- a/example/demo/collective/start_job_client.sh +++ b/example/demo/collective/start_job_client.sh @@ -30,8 +30,8 @@ export PADDLE_POD_ID="not set" BASEDIR=$(dirname $(readlink -f $0)) echo $BASEDIR -nohup python -u paddle_edl.demo.collective.job_client_demo \ +python -m paddle_edl.demo.collective.job_client_demo \ --log_level 20 \ --package_sh ./resnet50/package.sh \ --pod_path ./resnet50_pod \ - ./train_pretrain.sh > job_client.log 2>&1 & + ./train_pretrain.sh diff --git a/example/demo/collective/start_job_server.sh b/example/demo/collective/start_job_server.sh index ebbdb8fd..933ace00 100755 --- a/example/demo/collective/start_job_server.sh +++ b/example/demo/collective/start_job_server.sh @@ -23,8 +23,8 @@ echo "node_ips:${node_ips}" BASEDIR=$(dirname $(readlink -f $0)) echo "${BASEDIR}" -nohup python -u paddle_edl.demo.collective.job_server_demo \ +python -m paddle_edl.demo.collective.job_server_demo \ --node_ips ${node_ips} \ --pod_num_of_node 8 \ --time_interval_to_change 900 \ - --gpu_num_of_node 8 > job_server.log 2>&1 & + --gpu_num_of_node 8 diff --git a/python/edl/collective/distribute_reader.py b/python/edl/collective/distribute_reader.py index a0a748cf..fe85a882 100644 --- a/python/edl/collective/distribute_reader.py +++ b/python/edl/collective/distribute_reader.py @@ -16,13 +16,13 @@ import multiprocessing import sys import threading -from edl.uitls import reader as edl_reader +from edl.utils import reader as edl_reader from edl.utils import env as edl_env from edl.utils import state as edl_state from edl.utils import data_server from edl.utils import data_server_pb2 -from edl.utils import edl_process +from edl.utils import process as edl_process from edl.utils import data_server_client from edl.utils import etcd_db from edl.utils.log_utils import logger @@ -60,9 +60,9 @@ def __init__( self._data_queue = out_queue def 
_get_file_list(self, timeout=60): - client = data_server_client.DataServerClient() + client = data_server_client.Client() return client.get_file_list( - leader_endpoint=self._leader_endpoint, + reader_leader_endpoint=self._leader_endpoint, reader_name=self._reader_name, pod_id=self._pod_id, file_list=self._file_list, @@ -150,10 +150,11 @@ def __init__( self._t_generater = threading.Thread(target=self.generate) self._t_accesser = threading.Thread(target=self.access) - self._client = data_server_client.DataServerClient() + self._client = data_server_client.Client() + self._lock = threading.Lock() def start(self): - self._client.connect(self._reader_leader_endpoint) + self._client._connect(self._reader_leader_endpoint) self._t_reporter.start() self._t_generater.start() self._t_accesser.start() @@ -177,7 +178,7 @@ def _report(self, report_size=10): self._client.report_batch_data_meta( reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, + reader_name=self._reader_name, pod_id=self._trainer_env.pod_id, dataserver_endpoint=self._data_server.endpoint, batch_data_ids=batch_data_ids, @@ -188,7 +189,7 @@ def _report(self, report_size=10): while not self._stop.set() and len(batch_data_ids) > 0: self._client.report_batch_data_meta( reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, + reader_name=self._reader_name, pod_id=self._trainer_env.pod_id, dataserver_endpoint=self._data_server.endpoint, batch_data_ids=batch_data_ids, @@ -196,15 +197,15 @@ def _report(self, report_size=10): self._client.reach_data_end( reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, + reader_name=self._reader_name, pod_id=self._trainer_env.pod_id, ) def _access(self): while not self._stop.set(): - res = self._client.get_balanced_batch_data( + res = self._client.get_batch_data_meta( reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, + reader_name=self._reader_name, 
pod_id=self._trainer_env.pod_id, ) @@ -219,7 +220,7 @@ def _get_batch_data(self, req): Read BatchData from local or remote by BatchDataRequest """ if self._trainer_env.pod_id != req.producer_pod_id: - return (req, self._client.get_batch_data(req)) + return (req, self._client.get_batch_data(self._reader_leader_endpoint, req)) return (req, self.get_local_batch_data(req)) @@ -322,7 +323,7 @@ def __init__(self, file_list, file_splitter_cls, batch_size, cache_capcity=100): self._trainer_env.endpoints, self._trainer_env.job_id ) # reader meta - self._reader_leader = edl_reader.load_from_ectd( + self._reader_leader = edl_reader.load_from_etcd( self._etcd, self._trainer_env.pod_leader_id, timeout=60 ) @@ -342,7 +343,7 @@ def stop(self): self._accesser.join() self._accesser = None - def __exit__(self): + def __exit__(self, exc_type, exc_value, traceback): self.stop() def _check_accesser(self): diff --git a/python/edl/discovery/consistent_hash.py b/python/edl/discovery/consistent_hash.py index edd0a3ef..1d1eff2e 100644 --- a/python/edl/discovery/consistent_hash.py +++ b/python/edl/discovery/consistent_hash.py @@ -89,9 +89,6 @@ def get_node(self, key): def get_node_nodes(self, key): # return node, nodes, version - if len(self._nodes) == 0: - return None, self._nodes, self._version - node = self.get_node(key) return node, self._nodes, self._version diff --git a/python/edl/distill/balance_table.py b/python/edl/distill/balance_table.py index 834cddf0..3ebb16f6 100644 --- a/python/edl/distill/balance_table.py +++ b/python/edl/distill/balance_table.py @@ -434,7 +434,7 @@ def call_back(add_servers, rm_servers): self._db.refresh(service_name, self._discovery_server) # before watch, refresh # NOTE. 
start from revision + 1, that is after get_service - watch_id = self._db.watch_service( # noqa: F841 + self._db.watch_service( service_name, call_back, start_revision=revision + 1 ) diff --git a/python/edl/utils/cluster_generator.py b/python/edl/utils/cluster_generator.py index 03dfa5fe..21fc5a00 100644 --- a/python/edl/utils/cluster_generator.py +++ b/python/edl/utils/cluster_generator.py @@ -89,7 +89,7 @@ def is_stopped(self): with self._lock: return self._t_register is None - def __exit__(self): + def __exit__(self, exc_type, exc_value, traceback): self.stop() def _generate_cluster_from_resource(self, resource_pods): @@ -203,7 +203,7 @@ def _generate_cluster_once(self): len(inited) > 0 and current_cluster.get_pods_nranks() < self._job_env.max_nodes ): - train_status = edl_train_status.load_from_etcd(self._etcd, timeout=30) + train_status = edl_train_status.load_from_etcd(self._etcd, self._pod_id, timeout=30) if ( train_status == edl_train_status.TrainStatus.INITIAL or train_status == edl_train_status.TrainStatus.RUNNING diff --git a/python/edl/utils/cluster_watcher.py b/python/edl/utils/cluster_watcher.py index 58223637..bb9bd3a1 100644 --- a/python/edl/utils/cluster_watcher.py +++ b/python/edl/utils/cluster_watcher.py @@ -116,5 +116,5 @@ def is_stopped(self): with self._lock: return self._t_watcher is None - def __exit__(self): + def __exit__(self, exc_type, exc_value, traceback): self.stop() diff --git a/python/edl/utils/data_server.py b/python/edl/utils/data_server.py index 1a37ff8a..380bdbf6 100644 --- a/python/edl/utils/data_server.py +++ b/python/edl/utils/data_server.py @@ -62,12 +62,9 @@ def get_size(self): def pop(self, num): a = [] - while len(self._queue) > 0: - if (num > 0 and len(a) < num) or num <= 0: - batch_data_id = self._queue.popleft() - a.append(batch_data_id) - else: - break + while len(self._queue) > 0 and ((num > 0 and len(a) < num) or num <= 0): + batch_data_id = self._queue.popleft() + a.append(batch_data_id) logger.debug(
"batch_data_ids:{}, queue:{}".format( diff --git a/python/edl/utils/data_server_client.py b/python/edl/utils/data_server_client.py index a05a25af..928d09c3 100644 --- a/python/edl/utils/data_server_client.py +++ b/python/edl/utils/data_server_client.py @@ -97,7 +97,7 @@ def report_batch_data_meta( @error_utils.handle_errors_until_timeout def reach_data_end(self, reader_leader_endpoint, reader_name, pod_id, timeout=60): - conn = self.connect(reader_leader_endpoint, timeout=30) + conn = self._connect(reader_leader_endpoint, timeout=30) req = data_server_pb2.ReachDataEndRequest() req.reader_name = reader_name diff --git a/python/edl/utils/exceptions.py b/python/edl/utils/exceptions.py index 981c516e..88854627 100644 --- a/python/edl/utils/exceptions.py +++ b/python/edl/utils/exceptions.py @@ -89,6 +89,14 @@ class EdlNotLeaderError(EdlException): pass +class EdlNotFoundLeader(EdlException): + pass + + +class EdlAccessDataError(EdlException): + pass + + def deserialize(pb_status): thismodule = sys.modules[__name__] try: diff --git a/python/edl/utils/launcher.py b/python/edl/utils/launcher.py index 01f3c361..28a26ddb 100644 --- a/python/edl/utils/launcher.py +++ b/python/edl/utils/launcher.py @@ -245,7 +245,7 @@ def _launch(self): time.sleep(3) - def __exit__(self): + def __exit__(self, exc_type, exc_value, traceback): if self._leader_register is not None: self._leader_register.stop() diff --git a/python/edl/utils/leader_pod.py b/python/edl/utils/leader_pod.py index 88740b6f..68102b14 100644 --- a/python/edl/utils/leader_pod.py +++ b/python/edl/utils/leader_pod.py @@ -130,7 +130,7 @@ def stop(self): self._generate_cluster.stop() logger.info("pod:{} leader_register stopped".format(self._pod_id)) - def __exit__(self): + def __exit__(self, exc_type, exc_value, traceback): self.stop() def is_stopped(self): diff --git a/python/edl/utils/process.py b/python/edl/utils/process.py index 45981324..19f0a547 100644 --- a/python/edl/utils/process.py +++ b/python/edl/utils/process.py 
@@ -47,5 +47,5 @@ def is_stopped(self): with self._lock: return self._worker is None - def __exit__(self): + def __exit__(self, exc_type, exc_value, traceback): self.stop() diff --git a/python/edl/utils/reader.py b/python/edl/utils/reader.py index 3ee04b6d..b5fbc93e 100644 --- a/python/edl/utils/reader.py +++ b/python/edl/utils/reader.py @@ -33,14 +33,13 @@ def to_json(self): } return json.dumps(d) - def from_json(self, s): + @classmethod + def from_json(cls, s): d = json.loads(s) - self._name = d["name"] - self._pod_id = d["pod_id"] - self._endpoint = d["endpoint"] + return cls(d["name"], d["pod_id"], d["endpoint"]) def __str_(self): - return self._to_json() + return self.to_json() @error_utils.handle_errors_until_timeout @@ -61,8 +60,7 @@ def load_from_etcd(self, etcd, reader_name, pod_id, timeout=60): "path:{}".format(etcd.get_full_path(path, pod_id)) ) - meta = ReaderMeta() - meta.from_json(value) + meta = ReaderMeta.from_json(value) logger.debug("get reader:{}".format(meta)) return meta @@ -77,12 +75,11 @@ def check_dist_readers(etcd): readers = {} for s in servers: - r = ReaderMeta() - r.from_json(s.value) + r = ReaderMeta.from_json(s.value) - readers[r.key] = r + readers[r._pod_id] = r - cluster = edl_cluster.get_cluster(etcd) + cluster = edl_cluster.wait_to_load_from_etcd(etcd) if cluster is None: raise exceptions.EdlTableError( "table:{} has no readers".format(constants.ETCD_CLUSTER) diff --git a/python/edl/utils/register.py b/python/edl/utils/register.py index a11c3e47..42e6e030 100644 --- a/python/edl/utils/register.py +++ b/python/edl/utils/register.py @@ -82,5 +82,5 @@ def is_stopped(self): with self._lock: return self._t_register is None - def __exit__(self): + def __exit__(self, exc_type, exc_value, traceback): self.stop() diff --git a/python/edl/utils/state.py b/python/edl/utils/state.py index bb23002c..f9139ebd 100644 --- a/python/edl/utils/state.py +++ b/python/edl/utils/state.py @@ -108,7 +108,7 @@ def get_current_epoch_attr(self): return 
self.get_epoch_attr(self._epoch_no) def update_current_epoch_attr(self, epoch_attr): - return self._update_epoch_attr(self._epoch_no, epoch_attr) + return self.update_epoch_attr(self._epoch_no, epoch_attr) class State(json_serializable.Serializable): @@ -160,11 +160,11 @@ def global_step_no(self): @property def total_batch_size(self): - return self._defaults["total_batch_size"] + return self._default["total_batch_size"] @total_batch_size.setter def total_batch_size(self, size): - self._defaults["total_batch_size"] = size + self._default["total_batch_size"] = size @error_utils.handle_errors_until_timeout