Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions scaler/client/agent/client_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from scaler.client.serializer.mixins import Serializer
from scaler.io.async_connector import AsyncConnector
from scaler.protocol.python.common import ObjectStorageAddress
from scaler.io.ymq.ymq import IOContext, IOSocketType
from scaler.protocol.python.message import (
ClientDisconnect,
ClientHeartbeatEcho,
Expand All @@ -31,16 +32,16 @@
from scaler.utility.event_loop import create_async_loop_routine
from scaler.utility.exceptions import ClientCancelledException, ClientQuitException, ClientShutdownException
from scaler.utility.identifiers import ClientID
from scaler.utility.zmq_config import ZMQConfig
from scaler.utility.ymq_config import YMQConfig


class ClientAgent(threading.Thread):
def __init__(
self,
identity: ClientID,
client_agent_address: ZMQConfig,
scheduler_address: ZMQConfig,
context: zmq.Context,
client_agent_address: YMQConfig,
scheduler_address: YMQConfig,
context: IOContext,
future_manager: ClientFutureManager,
stop_event: threading.Event,
timeout_seconds: int,
Expand All @@ -63,20 +64,17 @@ def __init__(
self._future_manager = future_manager

self._connector_internal = AsyncConnector(
context=zmq.asyncio.Context.shadow(self._context),
context=self._context,
name="client_agent_internal",
socket_type=zmq.PAIR,
bind_or_connect="bind",
address=self._client_agent_address,
callback=self.__on_receive_from_client,
identity=None,
)

self._connector_external = AsyncConnector(
context=zmq.asyncio.Context.shadow(self._context),
context=self._context,
name="client_agent_external",
socket_type=zmq.DEALER,
address=self._scheduler_address,
bind_or_connect="connect",
callback=self.__on_receive_from_scheduler,
identity=self._identity,
)
Expand All @@ -86,6 +84,9 @@ def __init__(
self._task_manager: Optional[ClientTaskManager] = None

def __initialize(self):
self._connector_internal.init_sync(bind_or_connect="bind", socket_type=IOSocketType.Binder)
self._connector_external.init_sync(bind_or_connect="connect", socket_type=IOSocketType.Connector)

self._disconnect_manager = ClientDisconnectManager()
self._heartbeat_manager = ClientHeartbeatManager(
death_timeout_seconds=self._timeout_seconds, storage_address_future=self._storage_address
Expand Down
23 changes: 12 additions & 11 deletions scaler/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import logging
import threading
import uuid
from random import randint
from collections import Counter
from inspect import signature
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import zmq

from scaler.client.agent.client_agent import ClientAgent
from scaler.client.agent.future_manager import ClientFutureManager
from scaler.client.future import ScalerFuture
Expand All @@ -19,14 +18,15 @@
from scaler.io.config import DEFAULT_CLIENT_TIMEOUT_SECONDS, DEFAULT_HEARTBEAT_INTERVAL_SECONDS
from scaler.io.sync_connector import SyncConnector
from scaler.io.sync_object_storage_connector import SyncObjectStorageConnector
from scaler.io.ymq.ymq import IOContext, IOSocketType
from scaler.protocol.python.message import ClientDisconnect, ClientShutdownResponse, GraphTask, Task
from scaler.utility.exceptions import ClientQuitException, MissingObjects
from scaler.utility.graph.optimization import cull_graph
from scaler.utility.graph.topological_sorter import TopologicalSorter
from scaler.utility.identifiers import ClientID, ObjectID, TaskID
from scaler.utility.metadata.profile_result import ProfileResult
from scaler.utility.metadata.task_flags import TaskFlags, retrieve_task_flags_from_task
from scaler.utility.zmq_config import ZMQConfig, ZMQType
from scaler.utility.ymq_config import YMQConfig
from scaler.worker.agent.processor.processor import Processor


Expand Down Expand Up @@ -88,22 +88,20 @@ def __initialize__(
self._stream_output = stream_output
self._identity = ClientID.generate_client_id()

self._client_agent_address = ZMQConfig(ZMQType.inproc, host=f"scaler_client_{uuid.uuid4().hex}")
self._scheduler_address = ZMQConfig.from_string(address)
self._client_agent_address = YMQConfig("127.0.0.1", randint(10000, 20000))
self._scheduler_address = YMQConfig.from_string(address)
self._timeout_seconds = timeout_seconds
self._heartbeat_interval_seconds = heartbeat_interval_seconds

self._stop_event = threading.Event()
self._context = zmq.Context()
self._connector_agent = SyncConnector(
context=self._context, socket_type=zmq.PAIR, address=self._client_agent_address, identity=self._identity
)
self._context = IOContext()

self._future_manager = ClientFutureManager(self._serializer)

self._agent = ClientAgent(
identity=self._identity,
client_agent_address=self._client_agent_address,
scheduler_address=ZMQConfig.from_string(address),
scheduler_address=self._scheduler_address,
context=self._context,
future_manager=self._future_manager,
stop_event=self._stop_event,
Expand All @@ -113,6 +111,10 @@ def __initialize__(
)
self._agent.start()

self._connector_agent = SyncConnector(
context=self._context, socket_type=IOSocketType.Connector, address=self._client_agent_address, identity=self._identity.extend("|agent")
)

logging.info(f"ScalerClient: connect to scheduler at {self._scheduler_address}")

# Blocks until the agent receives the object storage address
Expand Down Expand Up @@ -607,7 +609,6 @@ def __assert_client_not_stopped(self):

def __destroy(self):
self._agent.join()
self._context.destroy(linger=1)

@staticmethod
def __get_parent_task_priority() -> Optional[int]:
Expand Down
4 changes: 2 additions & 2 deletions scaler/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

from scaler.utility.logging.utility import setup_logger
from scaler.utility.object_storage_config import ObjectStorageConfig
from scaler.utility.zmq_config import ZMQConfig
from scaler.utility.ymq_config import YMQConfig
from scaler.worker.worker import Worker


class Cluster(multiprocessing.get_context("spawn").Process): # type: ignore[misc]

def __init__(
self,
address: ZMQConfig,
address: YMQConfig,
storage_address: Optional[ObjectStorageConfig],
worker_io_threads: int,
worker_names: List[str],
Expand Down
5 changes: 3 additions & 2 deletions scaler/cluster/combo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from scaler.scheduler.allocate_policy.allocate_policy import AllocatePolicy
from scaler.utility.network_util import get_available_tcp_port
from scaler.utility.object_storage_config import ObjectStorageConfig
from scaler.utility.ymq_config import YMQConfig
from scaler.utility.zmq_config import ZMQConfig


Expand Down Expand Up @@ -59,9 +60,9 @@ def __init__(
logging_config_file: Optional[str] = None,
):
if address is None:
self._address = ZMQConfig.from_string(f"tcp://127.0.0.1:{get_available_tcp_port()}")
self._address = YMQConfig.from_string(f"tcp://127.0.0.1:{get_available_tcp_port()}")
else:
self._address = ZMQConfig.from_string(address)
self._address = YMQConfig.from_string(address)

if storage_address is None:
self._storage_address = ObjectStorageConfig(self._address.host, get_available_tcp_port())
Expand Down
3 changes: 2 additions & 1 deletion scaler/cluster/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from scaler.utility.event_loop import register_event_loop
from scaler.utility.logging.utility import setup_logger
from scaler.utility.object_storage_config import ObjectStorageConfig
from scaler.utility.ymq_config import YMQConfig
from scaler.utility.zmq_config import ZMQConfig


class SchedulerProcess(multiprocessing.get_context("spawn").Process): # type: ignore[misc]
def __init__(
self,
address: ZMQConfig,
address: YMQConfig,
storage_address: Optional[ObjectStorageConfig],
monitor_address: Optional[ZMQConfig],
io_threads: int,
Expand Down
6 changes: 3 additions & 3 deletions scaler/entry_points/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from scaler.utility.object_storage_config import ObjectStorageConfig
from scaler.utility.event_loop import EventLoopType, register_event_loop
from scaler.utility.zmq_config import ZMQConfig
from scaler.utility.ymq_config import YMQConfig


def get_args():
Expand Down Expand Up @@ -94,7 +94,7 @@ def get_args():
),
)
parser.add_argument(
"--log-hub-address", "-la", default=None, type=ZMQConfig.from_string, help="address for Worker send logs"
"--log-hub-address", "-la", default=None, type=YMQConfig.from_string, help="address for Worker send logs"
)
parser.add_argument(
"--logging-paths",
Expand Down Expand Up @@ -128,7 +128,7 @@ def get_args():
help="specify the object storage server address, e.g. tcp://localhost:2346. If not specified, use the address "
"provided by the scheduler",
)
parser.add_argument("address", type=ZMQConfig.from_string, help="scheduler address to connect to")
parser.add_argument("address", type=YMQConfig.from_string, help="scheduler address to connect to")
return parser.parse_args()


Expand Down
8 changes: 5 additions & 3 deletions scaler/entry_points/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from scaler.utility.event_loop import EventLoopType
from scaler.utility.network_util import get_available_tcp_port
from scaler.utility.object_storage_config import ObjectStorageConfig
from scaler.utility.zmq_config import ZMQConfig
from scaler.utility.ymq_config import YMQConfig


def get_args():
Expand Down Expand Up @@ -119,14 +119,16 @@ def get_args():
parser.add_argument(
"--monitor-address",
"-ma",
type=ZMQConfig.from_string,
type=YMQConfig.from_string,
default=None,
help="specify monitoring address, if not specified, the monitoring address is scheduler address with port "
"number plus 2, e.g.: if scheduler address is tcp://localhost:2345, then monitoring address is "
"tcp://localhost:2347",
)
parser.add_argument(
"address", type=ZMQConfig.from_string, help="scheduler address to connect to, e.g.: `tcp://localhost:6378`"
"address",
type=YMQConfig.from_string,
help="scheduler address to connect to, e.g.: `tcp://localhost:6378`"
)
return parser.parse_args()

Expand Down
3 changes: 2 additions & 1 deletion scaler/entry_points/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, List, Literal, Union

from scaler.io.sync_subscriber import SyncSubscriber
from scaler.io.sync_subscriber_zmq import SyncSubscriberZMQ
from scaler.protocol.python.message import StateScheduler
from scaler.protocol.python.mixins import Message
from scaler.utility.formatter import (
Expand Down Expand Up @@ -50,7 +51,7 @@ def poke(screen, args):
screen.nodelay(1)

try:
subscriber = SyncSubscriber(
subscriber = SyncSubscriberZMQ(
address=ZMQConfig.from_string(args.address),
callback=functools.partial(show_status, screen=screen),
topic=b"",
Expand Down
62 changes: 28 additions & 34 deletions scaler/io/async_binder.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
import logging
import os
import uuid
from collections import defaultdict
from typing import Awaitable, Callable, Dict, List, Optional
from typing import Awaitable, Callable, Dict, Optional

import zmq.asyncio
from zmq import Frame
from scaler.io.ymq.ymq import IOContext, IOSocketType
from scaler.io.ymq import ymq

from scaler.io.utility import deserialize, serialize
from scaler.protocol.python.mixins import Message
from scaler.protocol.python.status import BinderStatus
from scaler.utility.mixins import Looper, Reporter
from scaler.utility.zmq_config import ZMQConfig
from scaler.utility.identifiers import ClientID, Identifier
from scaler.utility.ymq_config import YMQConfig


class AsyncBinder(Looper, Reporter):
def __init__(self, context: zmq.asyncio.Context, name: str, address: ZMQConfig, identity: Optional[bytes] = None):
def __init__(self, context: IOContext, name: str, address: YMQConfig, identity: Optional[Identifier] = None):
self._address = address

if identity is None:
identity = f"{os.getpid()}|{name}|{uuid.uuid4()}".encode()
identity = ClientID.generate_client_id(name)
self._identity = identity

self._context = context
self._socket = self._context.socket(zmq.ROUTER)
self.__set_socket_options()
self._socket.bind(self._address.to_address())

self._callback: Optional[Callable[[bytes, Message], Awaitable[None]]] = None

Expand All @@ -36,50 +32,48 @@ def __init__(self, context: zmq.asyncio.Context, name: str, address: ZMQConfig,
def identity(self):
return self._identity

async def init(self):
self._socket = await self._context.createIOSocket(self.identity.decode(), IOSocketType.Binder)
await self._socket.bind(self._address.to_address())

def init_sync(self):
self._socket = self._context.createIOSocket_sync(self.identity.decode(), IOSocketType.Binder)
self._socket.bind_sync(self._address.to_address())

def destroy(self):
self._context.destroy(linger=0)
pass

def register(self, callback: Callable[[bytes, Message], Awaitable[None]]):
self._callback = callback

async def routine(self):
frames: List[Frame] = await self._socket.recv_multipart(copy=False)
if not self.__is_valid_message(frames):
return
message: ymq.Message = await self._socket.recv()

source, payload = frames
message: Optional[Message] = deserialize(payload.bytes)
if message is None:
logging.error(f"received unknown message from {source.bytes!r}: {payload!r}")
# TODO: zero-copy
deseralized: Optional[Message] = deserialize(message.payload.data)
if deseralized is None:
logging.error(f"received unknown message from {message.address.data!r}: {message.payload.data!r}")
return

self.__count_received(message.__class__.__name__)
await self._callback(source.bytes, message)

if self._callback is None:
raise RuntimeError(f"{self.__get_prefix()}: no callback registered")

await self._callback(message.address.data, deseralized)

async def send(self, to: bytes, message: Message):
self.__count_sent(message.__class__.__name__)
await self._socket.send_multipart([to, serialize(message)], copy=False)
await self._socket.send(ymq.Message(to, serialize(message)))

def get_status(self) -> BinderStatus:
return BinderStatus.new_msg(received=self._received, sent=self._sent)

def __set_socket_options(self):
self._socket.setsockopt(zmq.IDENTITY, self._identity)
self._socket.setsockopt(zmq.SNDHWM, 0)
self._socket.setsockopt(zmq.RCVHWM, 0)

def __is_valid_message(self, frames: List[Frame]) -> bool:
if len(frames) < 2:
logging.error(f"{self.__get_prefix()} received unexpected frames {frames}")
return False

return True

def __count_received(self, message_type: str):
self._received[message_type] += 1

def __count_sent(self, message_type: str):
self._sent[message_type] += 1

def __get_prefix(self):
return f"{self.__class__.__name__}[{self._identity.decode()}]:"
return f"{self.__class__.__name__}[{self._identity}]:"
Loading
Loading