Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 2 errors #155

Open
wants to merge 15 commits into
base: develop
Choose a base branch
from
4 changes: 2 additions & 2 deletions example/demo/collective/start_job_client.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ export PADDLE_POD_ID="not set"
BASEDIR=$(dirname $(readlink -f $0))
echo $BASEDIR

nohup python -u paddle_edl.demo.collective.job_client_demo \
python -m paddle_edl.demo.collective.job_client_demo \
--log_level 20 \
--package_sh ./resnet50/package.sh \
--pod_path ./resnet50_pod \
./train_pretrain.sh > job_client.log 2>&1 &
./train_pretrain.sh
4 changes: 2 additions & 2 deletions example/demo/collective/start_job_server.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ echo "node_ips:${node_ips}"
BASEDIR=$(dirname $(readlink -f $0))
echo "${BASEDIR}"

nohup python -u paddle_edl.demo.collective.job_server_demo \
python -m paddle_edl.demo.collective.job_server_demo \
--node_ips ${node_ips} \
--pod_num_of_node 8 \
--time_interval_to_change 900 \
--gpu_num_of_node 8 > job_server.log 2>&1 &
--gpu_num_of_node 8
29 changes: 15 additions & 14 deletions python/edl/collective/distribute_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
import multiprocessing
import sys
import threading
from edl.uitls import reader as edl_reader
from edl.utils import reader as edl_reader
from edl.utils import env as edl_env
from edl.utils import state as edl_state

from edl.utils import data_server
from edl.utils import data_server_pb2
from edl.utils import edl_process
from edl.utils import process as edl_process
from edl.utils import data_server_client
from edl.utils import etcd_db
from edl.utils.log_utils import logger
Expand Down Expand Up @@ -60,9 +60,9 @@ def __init__(
self._data_queue = out_queue

def _get_file_list(self, timeout=60):
client = data_server_client.DataServerClient()
client = data_server_client.Client()
return client.get_file_list(
leader_endpoint=self._leader_endpoint,
reader_leader_endpoint=self._leader_endpoint,
reader_name=self._reader_name,
pod_id=self._pod_id,
file_list=self._file_list,
Expand Down Expand Up @@ -150,10 +150,11 @@ def __init__(
self._t_generater = threading.Thread(target=self.generate)
self._t_accesser = threading.Thread(target=self.access)

self._client = data_server_client.DataServerClient()
self._client = data_server_client.Client()
self._lock = threading.Lock()

def start(self):
self._client.connect(self._reader_leader_endpoint)
self._client._connect(self._reader_leader_endpoint)
self._t_reporter.start()
self._t_generater.start()
self._t_accesser.start()
Expand All @@ -177,7 +178,7 @@ def _report(self, report_size=10):

self._client.report_batch_data_meta(
reader_leader_endpoint=self._reader_leader_endpoint,
reader_name=self._name,
reader_name=self._reader_name,
pod_id=self._trainer_env.pod_id,
dataserver_endpoint=self._data_server.endpoint,
batch_data_ids=batch_data_ids,
Expand All @@ -188,23 +189,23 @@ def _report(self, report_size=10):
            while not self._stop.is_set() and len(batch_data_ids) > 0:
self._client.report_batch_data_meta(
reader_leader_endpoint=self._reader_leader_endpoint,
reader_name=self._name,
reader_name=self._reader_name,
pod_id=self._trainer_env.pod_id,
dataserver_endpoint=self._data_server.endpoint,
batch_data_ids=batch_data_ids,
)

self._client.reach_data_end(
reader_leader_endpoint=self._reader_leader_endpoint,
reader_name=self._name,
reader_name=self._reader_name,
pod_id=self._trainer_env.pod_id,
)

def _access(self):
        while not self._stop.is_set():
res = self._client.get_balanced_batch_data(
res = self._client.get_batch_data_meta(
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may be the wrong function

reader_leader_endpoint=self._reader_leader_endpoint,
reader_name=self._name,
reader_name=self._reader_name,
pod_id=self._trainer_env.pod_id,
)

Expand All @@ -219,7 +220,7 @@ def _get_batch_data(self, req):
Read BatchData from local or remote by BatchDataRequest
"""
if self._trainer_env.pod_id != req.producer_pod_id:
return (req, self._client.get_batch_data(req))
return (req, self._client.get_batch_data(self._reader_leader_endpoint, req))

return (req, self.get_local_batch_data(req))

Expand Down Expand Up @@ -322,7 +323,7 @@ def __init__(self, file_list, file_splitter_cls, batch_size, cache_capcity=100):
self._trainer_env.endpoints, self._trainer_env.job_id
)
# reader meta
self._reader_leader = edl_reader.load_from_ectd(
self._reader_leader = edl_reader.load_from_etcd(
self._etcd, self._trainer_env.pod_leader_id, timeout=60
)

Expand All @@ -342,7 +343,7 @@ def stop(self):
self._accesser.join()
self._accesser = None

def __exit__(self):
def __exit__(self, exc_type, exc_value, traceback):
self.stop()

def _check_accesser(self):
Expand Down
3 changes: 0 additions & 3 deletions python/edl/discovery/consistent_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,6 @@ def get_node(self, key):

def get_node_nodes(self, key):
# return node, nodes, version
if len(self._nodes) == 0:
return None, self._nodes, self._version

node = self.get_node(key)
return node, self._nodes, self._version

Expand Down
2 changes: 1 addition & 1 deletion python/edl/distill/balance_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def call_back(add_servers, rm_servers):

self._db.refresh(service_name, self._discovery_server) # before watch, refresh
# NOTE. start from revision + 1, that is after get_service
watch_id = self._db.watch_service( # noqa: F841
self._db.watch_service( # noqa: F841
service_name, call_back, start_revision=revision + 1
)

Expand Down
4 changes: 2 additions & 2 deletions python/edl/utils/cluster_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def is_stopped(self):
with self._lock:
return self._t_register is None

def __exit__(self):
def __exit__(self, exc_type, exc_value, traceback):
self.stop()

def _generate_cluster_from_resource(self, resource_pods):
Expand Down Expand Up @@ -203,7 +203,7 @@ def _generate_cluster_once(self):
len(inited) > 0
and current_cluster.get_pods_nranks() < self._job_env.max_nodes
):
train_status = edl_train_status.load_from_etcd(self._etcd, timeout=30)
train_status = edl_train_status.load_from_etcd(self._etcd, self._pod_id, timeout=30)
if (
train_status == edl_train_status.TrainStatus.INITIAL
or train_status == edl_train_status.TrainStatus.RUNNING
Expand Down
2 changes: 1 addition & 1 deletion python/edl/utils/cluster_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,5 +116,5 @@ def is_stopped(self):
with self._lock:
return self._t_watcher is None

def __exit__(self):
def __exit__(self, exc_type, exc_value, traceback):
self.stop()
9 changes: 3 additions & 6 deletions python/edl/utils/data_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,9 @@ def get_size(self):

def pop(self, num):
a = []
while len(self._queue) > 0:
if (num > 0 and len(a) < num) or num <= 0:
batch_data_id = self._queue.popleft()
a.append(batch_data_id)
else:
break
while len(self._queue) > 0 and ((num > 0 and len(a) < num) or num <= 0):
batch_data_id = self._queue.popleft()
a.append(batch_data_id)

logger.debug(
"batch_data_ids:{}, queue:{}".format(
Expand Down
2 changes: 1 addition & 1 deletion python/edl/utils/data_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def report_batch_data_meta(

@error_utils.handle_errors_until_timeout
def reach_data_end(self, reader_leader_endpoint, reader_name, pod_id, timeout=60):
conn = self.connect(reader_leader_endpoint, timeout=30)
conn = self._connect(reader_leader_endpoint, timeout=30)

req = data_server_pb2.ReachDataEndRequest()
req.reader_name = reader_name
Expand Down
8 changes: 8 additions & 0 deletions python/edl/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ class EdlNotLeaderError(EdlException):
pass


class EdlNotFoundLeader(EdlException):
pass


class EdlAccessDataError(EdlException):
pass


def deserialize(pb_status):
thismodule = sys.modules[__name__]
try:
Expand Down
2 changes: 1 addition & 1 deletion python/edl/utils/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def _launch(self):

time.sleep(3)

def __exit__(self):
def __exit__(self, exc_type, exc_value, traceback):
if self._leader_register is not None:
self._leader_register.stop()

Expand Down
2 changes: 1 addition & 1 deletion python/edl/utils/leader_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def stop(self):
self._generate_cluster.stop()
logger.info("pod:{} leader_register stopped".format(self._pod_id))

def __exit__(self):
def __exit__(self, exc_type, exc_value, traceback):
self.stop()

def is_stopped(self):
Expand Down
2 changes: 1 addition & 1 deletion python/edl/utils/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ def is_stopped(self):
with self._lock:
return self._worker is None

def __exit__(self):
def __exit__(self, exc_type, exc_value, traceback):
self.stop()
19 changes: 8 additions & 11 deletions python/edl/utils/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,13 @@ def to_json(self):
}
return json.dumps(d)

def from_json(self, s):
@classmethod
def from_json(cls, s):
d = json.loads(s)
self._name = d["name"]
self._pod_id = d["pod_id"]
self._endpoint = d["endpoint"]
return cls(d["name"], d["pod_id"], d["endpoint"])

    def __str__(self):
return self._to_json()
return self.to_json()


@error_utils.handle_errors_until_timeout
Expand All @@ -61,8 +60,7 @@ def load_from_etcd(self, etcd, reader_name, pod_id, timeout=60):
"path:{}".format(etcd.get_full_path(path, pod_id))
)

meta = ReaderMeta()
meta.from_json(value)
meta = ReaderMeta.from_json(value)
logger.debug("get reader:{}".format(meta))
return meta

Expand All @@ -77,12 +75,11 @@ def check_dist_readers(etcd):

readers = {}
for s in servers:
r = ReaderMeta()
r.from_json(s.value)
r = ReaderMeta.from_json(s.value)

readers[r.key] = r
readers[r._pod_id] = r

cluster = edl_cluster.get_cluster(etcd)
cluster = edl_cluster.wait_to_load_from_etcd(etcd)
if cluster is None:
raise exceptions.EdlTableError(
"table:{} has no readers".format(constants.ETCD_CLUSTER)
Expand Down
2 changes: 1 addition & 1 deletion python/edl/utils/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,5 @@ def is_stopped(self):
with self._lock:
return self._t_register is None

def __exit__(self):
def __exit__(self, exc_type, exc_value, traceback):
self.stop()
6 changes: 3 additions & 3 deletions python/edl/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_current_epoch_attr(self):
return self.get_epoch_attr(self._epoch_no)

def update_current_epoch_attr(self, epoch_attr):
return self._update_epoch_attr(self._epoch_no, epoch_attr)
return self.update_epoch_attr(self._epoch_no, epoch_attr)


class State(json_serializable.Serializable):
Expand Down Expand Up @@ -160,11 +160,11 @@ def global_step_no(self):

@property
def total_batch_size(self):
return self._defaults["total_batch_size"]
return self._default["total_batch_size"]

@total_batch_size.setter
def total_batch_size(self, size):
self._defaults["total_batch_size"] = size
self._default["total_batch_size"] = size


@error_utils.handle_errors_until_timeout
Expand Down