chore: minor refactoring and code clean-up #4526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 3 commits into base: main
58 changes: 56 additions & 2 deletions tensorrt_llm/executor/ipc.py
@@ -5,7 +5,7 @@
 import time
 import traceback
 from queue import Queue
-from typing import Any, Optional
+from typing import Any, Callable, Optional

 import zmq
 import zmq.asyncio
@@ -127,6 +127,8 @@ def put(self, obj: Any):

     async def put_async(self, obj: Any):
         self.setup_lazily()
+        if self.socket is None:
+            raise RuntimeError("Socket is not initialized or has been closed")
         try:
             if self.use_hmac_encryption:
                 # Send pickled data with HMAC appended
@@ -135,17 +137,39 @@ async def put_async(self, obj: Any):
                 await self.socket.send(signed_data)
             else:
                 # Send data without HMAC
-                await self.socket.send_pyobj(obj)
+                data = pickle.dumps(obj)  # nosec B301
+                await self.socket.send(data)
         except TypeError as e:
             logger.error(f"Cannot pickle {obj}")
             raise e
+        except zmq.ZMQError as e:
+            logger.error(f"ZMQ error while sending object: {e}")
+            raise e
         except Exception as e:
             logger.error(f"Error sending object: {e}")
             logger.error(traceback.format_exc())
             raise e

         nvtx_mark("ipc.send", color="blue", category="IPC")

+    def put_noblock(self, obj: Any, on_fail: Callable[[], None] = None):
+        '''
+        Parameters:
+            obj: The object to send.
+            on_fail: A callable that will be called if the send fails.
+        '''
+        self.setup_lazily()
+        data = pickle.dumps(obj)  # nosec B301
+        if self.use_hmac_encryption:
+            data = self._sign_data(data)
+        try:
+            self.socket.send(data, flags=zmq.NOBLOCK)
+        except zmq.ZMQError as e:
+            if on_fail is not None:
+                on_fail()
+            else:
+                raise e
+
     def get(self) -> Any:
         self.setup_lazily()

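For reviewers, a minimal usage sketch of the new `put_noblock` path. The `queue` instance and the drop-counting policy below are hypothetical, not part of this diff:

# Hypothetical caller of IpcQueue.put_noblock (illustration only).
# Assumes `queue` is an already-constructed IpcQueue instance.
import logging

logger = logging.getLogger(__name__)
dropped = 0

def on_send_fail() -> None:
    # Invoked when socket.send(..., flags=zmq.NOBLOCK) raises zmq.ZMQError,
    # e.g. EAGAIN when the peer is slow and the send buffer is full.
    global dropped
    dropped += 1
    logger.warning("non-blocking IPC send failed (%d dropped so far)", dropped)

# Fire-and-forget send: never blocks the caller, drops on backpressure.
queue.put_noblock({"event": "heartbeat"}, on_fail=on_send_fail)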
@@ -188,6 +212,36 @@ async def get_async(self) -> Any:
             obj = await self.socket.recv_pyobj()
         return obj

+    def get_with_poll(self,
+                      on_fail: Callable[[], None] = None,
+                      poll_timeout: float = 0.1) -> Any:
+        '''
+        Parameters:
+            on_fail: A callable that will be called each time polling times out.
+            poll_timeout: Timeout in seconds for each poll attempt.
+        '''
+        self.setup_lazily()
+        while True:
+            if self.poll(poll_timeout):
+                if self.use_hmac_encryption:
+                    # Receive signed data with HMAC
+                    signed_data = self.socket.recv()
+
+                    # Split data and HMAC
+                    data = signed_data[:-32]
+                    actual_hmac = signed_data[-32:]
+
+                    # Verify HMAC
+                    if not self._verify_hmac(data, actual_hmac):
+                        raise RuntimeError("HMAC verification failed")
+
+                    return pickle.loads(data)  # nosec B301
+                else:
+                    # Receive data without HMAC
+                    return self.socket.recv_pyobj()
+            elif on_fail is not None:
+                on_fail()
+
     def close(self):
         if self.socket:
             self.socket.close()
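Both `put_noblock` and `get_with_poll` assume the same wire format: the pickled payload with a fixed 32-byte HMAC appended. Below is a self-contained sketch of matching sign/verify helpers, assuming HMAC-SHA256 (implied by the 32-byte tag) and a pre-shared key; the real `_sign_data`/`_verify_hmac` are defined elsewhere in this class and are not shown in this diff:

# Standalone sketch of the sign/verify helpers assumed by put_noblock and
# get_with_poll. Assumes HMAC-SHA256 (32-byte digest) and a pre-shared key.
import hashlib
import hmac
import pickle

HMAC_SIZE = 32  # SHA-256 digest length, matching signed_data[-32:] above

def sign_data(hmac_key: bytes, data: bytes) -> bytes:
    tag = hmac.new(hmac_key, data, hashlib.sha256).digest()
    return data + tag  # payload with the 32-byte tag appended

def verify_hmac(hmac_key: bytes, data: bytes, actual_hmac: bytes) -> bool:
    expected = hmac.new(hmac_key, data, hashlib.sha256).digest()
    # compare_digest avoids leaking the mismatch position via timing
    return hmac.compare_digest(expected, actual_hmac)

# Round trip: what put_noblock sends is what get_with_poll splits apart.
key = b"\x00" * 32
wire = sign_data(key, pickle.dumps({"ok": True}))
payload, tag = wire[:-HMAC_SIZE], wire[-HMAC_SIZE:]
assert verify_hmac(key, payload, tag)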
49 changes: 19 additions & 30 deletions tensorrt_llm/executor/proxy.py
@@ -18,7 +18,7 @@
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                             print_colored, print_colored_debug)
 from .executor import GenerationExecutor
-from .ipc import FusedIpcQueue, IpcQueue
+from .ipc import IpcQueue
 from .postproc_worker import PostprocWorkerConfig
 from .request import CancellingRequest, GenerationRequest
 from .result import GenerationResult, IterationResult
@@ -114,22 +114,16 @@ def _setup_queues(self) -> WorkerCommIpcAddrs:
                                       name="proxy_request_queue")
         self.request_error_queue = IpcQueue(is_server=True,
                                             name="proxy_request_error_queue")
-        # TODO[chunweiy]: Unify IpcQueue and FusedIpcQueue
         # Use PULL mode when enable_postprocess_parallel as there are
         # multiple senders from multiple processes.
-        self.result_queue = FusedIpcQueue(
+        self.result_queue = IpcQueue(
             is_server=True,
-            fuse_message=False,
             socket_type=zmq.PULL
             if self.enable_postprocess_parallel else zmq.PAIR,
             name="proxy_result_queue")
-        self.mp_stats_queue = FusedIpcQueue(is_server=True,
-                                            fuse_message=False,
-                                            name="proxy_stats_queue")
-        self.kv_cache_events_queue = FusedIpcQueue(
-            is_server=True,
-            fuse_message=False,
-            name="proxy_kv_cache_events_queue")
+        self.mp_stats_queue = IpcQueue(is_server=True, name="proxy_stats_queue")
+        self.kv_cache_events_queue = IpcQueue(
+            is_server=True, name="proxy_kv_cache_events_queue")
         return WorkerCommIpcAddrs(
             request_queue_addr=self.request_queue.address,
             request_error_queue_addr=self.request_error_queue.address,
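The PULL-versus-PAIR choice above survives the FusedIpcQueue removal because it is about topology, not message fusing: PAIR is strictly one-to-one, while a single PULL socket can fan in from many PUSH senders (the postprocess workers). A minimal pyzmq sketch, with an illustrative address:

# Minimal pyzmq fan-in sketch illustrating the PULL-vs-PAIR choice above.
# The address and process layout are illustrative only.
import zmq

ctx = zmq.Context.instance()

# One PULL socket on the proxy side can receive from many senders...
receiver = ctx.socket(zmq.PULL)
receiver.bind("ipc:///tmp/example_result_queue")

# ...while each postprocess worker connects its own PUSH socket.
sender = ctx.socket(zmq.PUSH)
sender.connect("ipc:///tmp/example_result_queue")
sender.send_pyobj({"client_id": 1, "text": "done"})

print(receiver.recv_pyobj())  # a PAIR socket would allow only one peer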
@@ -149,11 +143,11 @@ def abort_request(self, request_id: int) -> None:
         # send back a finished result.
         self.request_queue.put(CancellingRequest(request_id))

-    def dispatch_result_task(self) -> bool:
+    def dispatch_result_task(self) -> None:
         # TODO[chunweiy]: convert the dispatch_result_task to async, that should
         # benefit from zmq.asyncio.Context
         if (res := self.result_queue.get()) is None:
-            return False  # shutdown the thread
+            raise ManagedThread.StopEvent()

         async_queues = []
         event_loop = None
@@ -181,17 +175,14 @@ def process_res(res):
         for i in res:
             global_tracer().log_instant("IPC.get")
             if i is None:
-                return False
+                raise ManagedThread.StopEvent()
             process_res(i)

         if async_queues:
             _SyncQueue.notify_many(event_loop, async_queues)

-        return True  # success
-
-    def _iteration_result_task(self, queue: Union[FusedIpcQueue,
-                                                   IntraProcessQueue],
-                               result_singleton: IterationResult) -> bool:
+    def _iteration_result_task(self, queue: Union[IpcQueue, IntraProcessQueue],
+                               result_singleton: IterationResult) -> None:
         # iteration result is not urgent, so we can sleep a bit
         time.sleep(0.2)

@@ -200,11 +191,11 @@ def _iteration_result_task(self, queue: Union[FusedIpcQueue,
         except:
             logger.debug(
                 "proxy.py: Error in _iteration_result_task: queue.get()")
-            return False
+            raise ManagedThread.StopEvent()

         if data is None:
             logger.debug("proxy.py: _iteration_result_task: data is None")
-            return False  # shutdown the thread
+            raise ManagedThread.StopEvent()

         data = data if isinstance(data, list) else [data]
         queue = result_singleton.queue
@@ -217,7 +208,7 @@ def _iteration_result_task(self, queue: Union[FusedIpcQueue,
         for d in data:
             if d is None:
                 logger.debug("proxy.py: _iteration_result_task: d is None")
-                return False
+                raise ManagedThread.StopEvent()

             if isinstance(queue, _SyncQueue):
                 queue.put_nowait(d)
@@ -237,15 +228,13 @@ def _iteration_result_task(self, queue: Union[FusedIpcQueue,
             logger.debug(f"proxy.py: Error in _iteration_result_task: {e}")
             raise e

-        return True  # success
-
-    def dispatch_stats_task(self) -> bool:
-        return self._iteration_result_task(self.mp_stats_queue,
-                                           self._iter_stats_result)
+    def dispatch_stats_task(self) -> None:
+        self._iteration_result_task(self.mp_stats_queue,
+                                    self._iter_stats_result)

-    def dispatch_kv_cache_events_task(self) -> bool:
-        return self._iteration_result_task(self.kv_cache_events_queue,
-                                           self._iter_kv_events_result)
+    def dispatch_kv_cache_events_task(self) -> None:
+        self._iteration_result_task(self.kv_cache_events_queue,
+                                    self._iter_kv_events_result)

     def _start_dispatch_threads(self):
         if self.dispatch_result_thread is None:
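Note the recurring shape of these proxy.py changes: dispatch tasks used to return False to ask their owning thread to stop, and now raise ManagedThread.StopEvent instead, so the tasks' return values no longer matter. A control-flow sketch of the loop this implies — the real ManagedThread lives in tensorrt_llm/llmapi/utils.py and is not reproduced here:

# Control-flow sketch only: how a ManagedThread-style loop could honor a
# StopEvent raised by its task. This is an assumption about the pattern,
# not the actual ManagedThread implementation.
import threading

class ManagedThreadSketch(threading.Thread):

    class StopEvent(Exception):
        """Raised by a task to ask the owning thread to exit cleanly."""

    def __init__(self, task):
        super().__init__(daemon=True)
        self.task = task

    def run(self):
        while True:
            try:
                self.task()  # tasks return None; stopping is exceptional
            except self.StopEvent:
                break  # clean shutdown requested by the task itself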
77 changes: 17 additions & 60 deletions tensorrt_llm/executor/worker.py
@@ -30,16 +30,15 @@
 from ..runtime.model_runner import _engine_config_to_model_config
 from ..sampling_params import BatchedLogitsProcessor, SamplingParams
 from .executor import GenerationExecutor, IterationResultQueue
-from .ipc import FusedIpcQueue, IpcQueue
+from .ipc import IpcQueue
 from .postproc_worker import (PostprocParams, PostprocWorker,
                               PostprocWorkerConfig, postproc_worker_main)
 from .request import (CancellingRequest, GenerationRequest, LoRARequest,
                       PromptAdapterRequest)
 from .result import (GenerationResult, IterationResult, LogProbsResult,
                      ResponseWrapper, compute_logprobs)
-from .utils import (PERIODICAL_RESP_IN_AWAIT, ErrorResponse, IntraProcessQueue,
-                    RequestError, WorkerCommIpcAddrs, has_event_loop,
-                    is_llm_response)
+from .utils import (ErrorResponse, IntraProcessQueue, RequestError,
+                    WorkerCommIpcAddrs, has_event_loop, is_llm_response)

 __all__ = [
     "ExecutorBindingsWorker",
@@ -190,7 +189,7 @@ def _create_iteration_result_queue(self,
         it_result_queue.aqueue = None

     def _set_iteration_result_queue(self, it_result_queue: IterationResultQueue,
-                                    queue: Union[Queue, FusedIpcQueue,
+                                    queue: Union[Queue, IpcQueue,
                                                  IntraProcessQueue]):
         assert not it_result_queue.is_initialized, "Iteration result queue should not already be initialized."
         it_result_queue.is_initialized = True
@@ -239,7 +238,7 @@ def _create_error_response(self, response: tllm.Response) -> ErrorResponse:
     def _iteration_result_task(self, it_result_queue: IterationResultQueue,
                                engine_get_result_api: Callable,
                                result_singleton: IterationResult,
-                               result_serializer: Callable):
+                               result_serializer: Callable) -> None:
         time.sleep(0.2)
         async_queues = []
         queue = result_singleton.queue if self._is_llm_executor and result_singleton else it_result_queue.queue
@@ -269,8 +268,6 @@ def _iteration_result_task(self, it_result_queue: IterationResultQueue,
             logger.error(f"worker.py: Error in _iteration_result_task: {e}")
             raise e

-        return True  # success
-
     def dispatch_stats_task(self) -> bool:

         # Define a Callable to join iteration and request stats
@@ -301,7 +298,7 @@ def get_stats():
                                     self._iter_stats_result,
                                     stats_serializer)

-    def dispatch_kv_cache_events_task(self) -> bool:
+    def dispatch_kv_cache_events_task(self) -> None:
         if isinstance(self.engine, tllm.Executor):
             # Check if the engine has a kv cache event manager
             # If not, return an empty list for the events which will cause the thread to exit early.
@@ -618,32 +615,27 @@ def worker_main(
     request_error_queue = IpcQueue(worker_queues.request_error_queue_addr,
                                    is_server=False,
                                    name="worker_request_error_queue")
-    mp_stats_queue = FusedIpcQueue(worker_queues.stats_queue_addr,
-                                   is_server=False,
-                                   fuse_message=True,
-                                   name="worker_stats_queue")
-    kv_cache_events_queue = FusedIpcQueue(
+    mp_stats_queue = IpcQueue(worker_queues.stats_queue_addr,
+                              is_server=False,
+                              name="worker_stats_queue")
+    kv_cache_events_queue = IpcQueue(
         worker_queues.kv_cache_events_queue_addr,
         is_server=False,
-        fuse_message=False,
         name="worker_kv_cache_events_queue")

     if postproc_worker_config.enabled:
         # IPC queues for sending inputs to the postprocess parallel
         # processes, each one is a PAIR zmq socket
         result_queues = [
-            FusedIpcQueue(is_server=True,
-                          fuse_message=PERIODICAL_RESP_IN_AWAIT,
-                          name=f"postprocess_{i}_feedin_queue")
+            IpcQueue(is_server=True, name=f"postprocess_{i}_feedin_queue")
             for i in range(postproc_worker_config.num_postprocess_workers)
         ]
     else:
         # IPC queue for sending results back to the proxy, and let the
         # Proxy process to handle the postprocess
-        result_queue = FusedIpcQueue(worker_queues.result_queue_addr,
-                                     is_server=False,
-                                     fuse_message=PERIODICAL_RESP_IN_AWAIT,
-                                     name="worker_result_queue")
+        result_queue = IpcQueue(worker_queues.result_queue_addr,
+                                is_server=False,
+                                name="worker_result_queue")

     def notify_proxy_threads_to_quit():
         # Signal the dispatcher thread in the proxy to quit
@@ -758,8 +750,7 @@ class AwaitResponseHelper:
     class HandlerKind(enum.Enum):
         unknown = 0
         single_process_worker = 1
-        ipc_periodically = 2
-        ipc_batched = 3
+        ipc_batched = 2

     def __init__(self, worker: "ExecutorBindingsWorker"):
         # TODO: make worker weakref
@@ -783,10 +774,7 @@ def responses_handler(self, responses: List[tllm.Response]):
             # The ExecutorBindingProxy is used
             print_colored_debug(f"creating await_response helper for IPC\n",
                                 color="yellow")
-            if PERIODICAL_RESP_IN_AWAIT:
-                self.handler_kind = HandlerKind.ipc_periodically
-            else:
-                self.handler_kind = HandlerKind.ipc_batched
+            self.handler_kind = HandlerKind.ipc_batched
         else:
             raise NotImplementedError

@@ -795,12 +783,10 @@ def responses_handler(self, responses: List[tllm.Response]):
                 return self.handle_for_worker(responses)
             case HandlerKind.ipc_batched:
                 return self.handle_for_ipc_batched(responses)
-            case HandlerKind.ipc_periodically:
-                return self.handle_for_ipc_periodically(responses)
             case _:
                 raise NotImplementedError

-    def __call__(self) -> bool:
+    def __call__(self) -> None:
         ''' This method should be called by a ManagedThread. '''
         responses = self.worker.engine.await_responses(
             timeout=datetime.timedelta(milliseconds=100))
@@ -814,7 +800,6 @@ def __call__(self) -> bool:
                              color="red",
                              category="Worker"):
             self.responses_handler(responses)
-        return True

     def handle_for_worker(self, responses: List[tllm.Response]) -> None:
         ''' Return the responses to asyncio.event_loop. '''
@@ -847,34 +832,6 @@ def handle_for_worker(self, responses: List[tllm.Response]) -> None:
         if async_queues:
             _SyncQueue.notify_many(event_loop, async_queues)

-    def handle_for_ipc_periodically(self,
-                                    responses: List[tllm.Response]) -> None:
-        ''' Return the responses to Proxy via IPC. This will put Rsp to a Queue
-        in a FusedIpcQueue, and a background thread will batch them and invoke
-        IPC periodically. '''
-
-        with nvtx_range_debug(f"handle_for_ipc_periodically-{len(responses)}",
-                              color="red",
-                              category="Worker"):
-
-            for response in responses:
-
-                if self.worker._has_background_error():
-                    response = self.worker._create_error_response(response)
-                elif response.has_error():
-                    response = ErrorResponse(response.client_id,
-                                             response.error_msg,
-                                             response.request_id)
-                else:
-                    logprobs_result = _get_logprobs(
-                        self.worker, response, self.worker._is_pytorch_backend)
-                    if logprobs_result:
-                        response = ResponseWrapper(response, logprobs_result)
-
-                # TODO: To verify the performance of using ZMQ instead of SharedMemory
-                # to send the logits tensor back to the Proxy process.
-                _send_rsp(self.worker, response)
-
     def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None:
         ''' Perform the IPC in batch explicitly. '''
         postproc_batches = [
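With HandlerKind.ipc_periodically and the fusing queue removed, batching is now explicit: each await_responses() wakeup produces one batch that is sent as a single IPC message. A rough sketch of that shape with hypothetical names (`send_batch`, `result_queue`); the real handle_for_ipc_batched, which continues below the fold of this diff, also routes batches to postprocess workers when they are enabled:

# Rough sketch of explicit batched IPC with hypothetical names; per-response
# error wrapping and logprobs handling (shown removed above) would still
# happen per element before the single send.
from typing import Any, List

def send_batch(result_queue, responses: List[Any]) -> None:
    batch = []
    for response in responses:
        # Wrap errors / attach logprobs per response here, then collect.
        batch.append(response)
    if batch:
        # One IPC send per await_responses() wakeup instead of one send
        # per response; this is what "perform the IPC in batch" refers to.
        result_queue.put(batch)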