Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 90 additions & 9 deletions tensorrt_llm/executor/ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import hmac
import os
import pickle # nosec B403
import threading
import time
import traceback
from queue import Queue
Expand Down Expand Up @@ -65,6 +66,13 @@ def __init__(self,
self.hmac_key = address[1] if address is not None else None
self.use_hmac_encryption = use_hmac_encryption

self._setup_lock = threading.Lock()

# Thread safety debugging
self._zmq_thread_id = None
self._zmq_debug_enabled = os.environ.get('TLLM_LLMAPI_ZMQ_DEBUG',
'0') != '0'

# Check HMAC key condition
if self.use_hmac_encryption and not self.is_server and self.hmac_key is None:
raise ValueError(
Expand Down Expand Up @@ -93,25 +101,52 @@ def __init__(self,
self.address = (self.address_endpoint, self.hmac_key)

def setup_lazily(self):
# Early return if setup is already done
if self._setup_done:
return
self._setup_done = True

if not self.is_server:
logger_debug(
f"Client [{self.name}] connecting to {self.address_endpoint} in {self.socket_type_str[self.socket_type]}\n",
"green")
self.socket.connect(self.address_endpoint)
with self._setup_lock:
if self._setup_done:
return
self._setup_done = True

if not self.is_server:
logger_debug(
f"Client [{self.name}] connecting to {self.address_endpoint} in {self.socket_type_str[self.socket_type]}\n",
"green")
self.socket.connect(self.address_endpoint)

self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)

def _check_thread_safety(self):
"""Check if the current thread is the same as the thread that first used the socket."""
if not self._zmq_debug_enabled:
return

current_thread_id = threading.get_ident()

if self._zmq_thread_id is None:
# First call - capture the thread ID
self._zmq_thread_id = current_thread_id
logger_debug(
f"ZMQ socket [{self.name}] initialized on thread {current_thread_id}",
"cyan")
elif self._zmq_thread_id != current_thread_id:
# Thread mismatch - raise error
raise RuntimeError(
f"ZMQ thread safety violation detected in [{self.name}]: "
f"Socket created on thread {self._zmq_thread_id}, "
f"but accessed from thread {current_thread_id}. "
f"ZMQ sockets are not thread-safe!")

def poll(self, timeout: int) -> bool:
"""
Parameters:
timeout (int): Timeout in seconds
"""
self.setup_lazily()
self._check_thread_safety()

events = dict(self.poller.poll(timeout=timeout * 1000))
if self.socket in events and events[self.socket] == zmq.POLLIN:
Expand All @@ -121,6 +156,7 @@ def poll(self, timeout: int) -> bool:

def put(self, obj: Any):
self.setup_lazily()
self._check_thread_safety()
with nvtx_range_debug("send", color="blue", category="IPC"):
if self.use_hmac_encryption or self.socket_type == zmq.ROUTER:
# Need manual serialization for encryption or ROUTER multipart
Expand Down Expand Up @@ -148,6 +184,7 @@ def put_noblock(self,
assert retry >= 0 and retry <= 10, "Retry must be between 0 and 10, adjust the wait_time if needed"

self.setup_lazily()
self._check_thread_safety()
with nvtx_range_debug("send", color="blue", category="IPC"):

data = self._prepare_data(obj)
Expand All @@ -162,6 +199,7 @@ def put_noblock(self,

async def put_async(self, obj: Any):
self.setup_lazily()
self._check_thread_safety()
try:
if self.use_hmac_encryption or self.socket_type == zmq.ROUTER:
# Need manual serialization for encryption or ROUTER multipart
Expand All @@ -182,6 +220,7 @@ async def put_async(self, obj: Any):

async def put_async_noblock(self, obj: Any):
self.setup_lazily()
self._check_thread_safety()
try:
if self.use_hmac_encryption:
data = pickle.dumps(obj) # nosec B301
Expand All @@ -196,14 +235,55 @@ async def put_async_noblock(self, obj: Any):

def get(self) -> Any:
self.setup_lazily()
self._check_thread_safety()
return self._recv_data()

async def get_async(self) -> Any:
self.setup_lazily()
self._check_thread_safety()
return await self._recv_data_async()

async def get_async_noblock(self, timeout: float = 0.5) -> Any:
return await asyncio.wait_for(self.get_async(), timeout)
"""Get data with timeout using polling to avoid message drops.

This method uses ZMQ's NOBLOCK flag with polling instead of asyncio.wait_for
to prevent cancelling recv operations which can cause message drops.

Args:
timeout: Timeout in seconds

Returns:
The received object

Raises:
asyncio.TimeoutError: If timeout is reached without receiving data
"""
self.setup_lazily()
self._check_thread_safety()

# Use polling loop instead of asyncio.wait_for to avoid cancelling recv
# which can cause message drops
deadline = asyncio.get_event_loop().time() + timeout
while True:
try:
# Try non-blocking receive
if self.socket_type == zmq.ROUTER:
identity, data = await self.socket.recv_multipart(
flags=zmq.NOBLOCK)
self._last_identity = identity
return self._parse_data(data)
else:
if self.use_hmac_encryption:
data = await self.socket.recv(flags=zmq.NOBLOCK)
return self._parse_data(data)
else:
return await self.socket.recv_pyobj(flags=zmq.NOBLOCK)
except zmq.Again:
# No message available yet
if asyncio.get_event_loop().time() >= deadline:
raise asyncio.TimeoutError()
# Short sleep to avoid busy-waiting
await asyncio.sleep(0.01)

def close(self):
if self.socket:
Expand Down Expand Up @@ -311,6 +391,7 @@ def notify_with_retry(self, message, max_retries=5, timeout=1):
raise ValueError(
"notify_with_retry is only supported for DEALER socket for now")

self._check_thread_safety()
retry_count = 0

while retry_count < max_retries:
Expand Down
Loading