diff --git a/tensorrt_llm/executor/ipc.py b/tensorrt_llm/executor/ipc.py
index f178596f4..003938727 100644
--- a/tensorrt_llm/executor/ipc.py
+++ b/tensorrt_llm/executor/ipc.py
@@ -5,7 +5,7 @@
 import time
 import traceback
 from queue import Queue
-from typing import Any, Optional
+from typing import Any, Callable, Optional

 import zmq
 import zmq.asyncio
@@ -127,6 +127,8 @@ def put(self, obj: Any):

     async def put_async(self, obj: Any):
         self.setup_lazily()
+        if self.socket is None:
+            raise RuntimeError("Socket is not initialized or has been closed")
         try:
             if self.use_hmac_encryption:
                 # Send pickled data with HMAC appended
@@ -135,10 +137,14 @@ async def put_async(self, obj: Any):
                 await self.socket.send(signed_data)
             else:
                 # Send data without HMAC
-                await self.socket.send_pyobj(obj)
+                data = pickle.dumps(obj)  # nosec B301
+                await self.socket.send(data)
         except TypeError as e:
             logger.error(f"Cannot pickle {obj}")
             raise e
+        except zmq.ZMQError as e:
+            logger.error(f"ZMQ error while sending object: {e}")
+            raise e
         except Exception as e:
             logger.error(f"Error sending object: {e}")
             logger.error(traceback.format_exc())
@@ -146,6 +152,24 @@ async def put_async(self, obj: Any):

         nvtx_mark("ipc.send", color="blue", category="IPC")

+    def put_noblock(self, obj: Any, on_fail: Optional[Callable[[], None]] = None):
+        '''
+        Parameters:
+            obj: The object to send.
+            on_fail: A callable that will be called if the send fails.
+        '''
+        self.setup_lazily()
+        data = pickle.dumps(obj)  # nosec B301
+        if self.use_hmac_encryption:
+            data = self._sign_data(data)
+        try:
+            self.socket.send(data, flags=zmq.NOBLOCK)
+        except zmq.ZMQError as e:
+            if on_fail is not None:
+                on_fail()
+            else:
+                raise e
+
     def get(self) -> Any:
         self.setup_lazily()

@@ -188,6 +212,36 @@ async def get_async(self) -> Any:
             obj = await self.socket.recv_pyobj()
         return obj

+    def get_with_poll(self,
+                      on_fail: Optional[Callable[[], None]] = None,
+                      poll_timeout: float = 0.1) -> Any:
+        '''
+        Parameters:
+            on_fail: A callable that will be called each time polling times out.
+            poll_timeout: Timeout in seconds for each poll attempt.
+        '''
+        self.setup_lazily()
+        while True:
+            if self.poll(poll_timeout):
+                if self.use_hmac_encryption:
+                    # Receive signed data with HMAC
+                    signed_data = self.socket.recv()
+
+                    # Split data and HMAC
+                    data = signed_data[:-32]
+                    actual_hmac = signed_data[-32:]
+
+                    # Verify HMAC
+                    if not self._verify_hmac(data, actual_hmac):
+                        raise RuntimeError("HMAC verification failed")
+
+                    return pickle.loads(data)  # nosec B301
+                else:
+                    # Receive data without HMAC
+                    return self.socket.recv_pyobj()
+            elif on_fail is not None:
+                on_fail()
+
     def close(self):
         if self.socket:
             self.socket.close()
diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py
index 64607bada..1a4a8fe1c 100644
--- a/tensorrt_llm/executor/proxy.py
+++ b/tensorrt_llm/executor/proxy.py
@@ -18,7 +18,7 @@
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                             print_colored, print_colored_debug)
 from .executor import GenerationExecutor
-from .ipc import FusedIpcQueue, IpcQueue
+from .ipc import IpcQueue
 from .postproc_worker import PostprocWorkerConfig
 from .request import CancellingRequest, GenerationRequest
 from .result import GenerationResult, IterationResult
@@ -114,22 +114,16 @@ def _setup_queues(self) -> WorkerCommIpcAddrs:
                                       name="proxy_request_queue")
         self.request_error_queue = IpcQueue(is_server=True,
                                             name="proxy_request_error_queue")
-        # TODO[chunweiy]: Unify IpcQueue and FusedIpcQueue
         # Use PULL mode when enable_postprocess_parallel as there are
         # multiple senders from multiple processes.
-        self.result_queue = FusedIpcQueue(
+        self.result_queue = IpcQueue(
             is_server=True,
-            fuse_message=False,
             socket_type=zmq.PULL
             if self.enable_postprocess_parallel else zmq.PAIR,
             name="proxy_result_queue")
-        self.mp_stats_queue = FusedIpcQueue(is_server=True,
-                                            fuse_message=False,
-                                            name="proxy_stats_queue")
-        self.kv_cache_events_queue = FusedIpcQueue(
-            is_server=True,
-            fuse_message=False,
-            name="proxy_kv_cache_events_queue")
+        self.mp_stats_queue = IpcQueue(is_server=True, name="proxy_stats_queue")
+        self.kv_cache_events_queue = IpcQueue(
+            is_server=True, name="proxy_kv_cache_events_queue")
         return WorkerCommIpcAddrs(
             request_queue_addr=self.request_queue.address,
             request_error_queue_addr=self.request_error_queue.address,
@@ -149,11 +143,11 @@ def abort_request(self, request_id: int) -> None:
         # send back a finished result.
         self.request_queue.put(CancellingRequest(request_id))

-    def dispatch_result_task(self) -> bool:
+    def dispatch_result_task(self) -> None:
         # TODO[chunweiy]: convert the dispatch_result_task to async, that should
         # benefit from zmq.asyncio.Context
         if (res := self.result_queue.get()) is None:
-            return False  # shutdown the thread
+            raise ManagedThread.StopEvent()

         async_queues = []
         event_loop = None
@@ -181,17 +175,14 @@ def process_res(res):
         for i in res:
             global_tracer().log_instant("IPC.get")
             if i is None:
-                return False
+                raise ManagedThread.StopEvent()
             process_res(i)

         if async_queues:
             _SyncQueue.notify_many(event_loop, async_queues)

-        return True  # success
-
-    def _iteration_result_task(self, queue: Union[FusedIpcQueue,
-                                                  IntraProcessQueue],
-                               result_singleton: IterationResult) -> bool:
+    def _iteration_result_task(self, queue: Union[IpcQueue, IntraProcessQueue],
+                               result_singleton: IterationResult) -> None:
         # iteration result is not urgent, so we can sleep a bit
         time.sleep(0.2)

@@ -200,11 +191,11 @@ def _iteration_result_task(self, queue: Union[FusedIpcQueue,
         except:
             logger.debug(
                 "proxy.py: Error in _iteration_result_task: queue.get()")
-            return False
+            raise ManagedThread.StopEvent()

         if data is None:
             logger.debug("proxy.py: _iteration_result_task: data is None")
-            return False  # shutdown the thread
+            raise ManagedThread.StopEvent()

         data = data if isinstance(data, list) else [data]
         queue = result_singleton.queue
@@ -217,7 +208,7 @@ def _iteration_result_task(self, queue: Union[FusedIpcQueue,
         for d in data:
             if d is None:
                 logger.debug("proxy.py: _iteration_result_task: d is None")
-                return False
+                raise ManagedThread.StopEvent()

             if isinstance(queue, _SyncQueue):
                 queue.put_nowait(d)
@@ -237,15 +228,13 @@ def _iteration_result_task(self, queue: Union[FusedIpcQueue,
             logger.debug(f"proxy.py: Error in _iteration_result_task: {e}")
             raise e

-        return True  # success
-
-    def dispatch_stats_task(self) -> bool:
-        return self._iteration_result_task(self.mp_stats_queue,
-                                           self._iter_stats_result)
+    def dispatch_stats_task(self) -> None:
+        self._iteration_result_task(self.mp_stats_queue,
+                                    self._iter_stats_result)

-    def dispatch_kv_cache_events_task(self) -> bool:
-        return self._iteration_result_task(self.kv_cache_events_queue,
-                                           self._iter_kv_events_result)
+    def dispatch_kv_cache_events_task(self) -> None:
+        self._iteration_result_task(self.kv_cache_events_queue,
+                                    self._iter_kv_events_result)

     def _start_dispatch_threads(self):
         if self.dispatch_result_thread is None:
diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py
index fc7742d04..79fbc500c 100644
--- a/tensorrt_llm/executor/worker.py
+++ b/tensorrt_llm/executor/worker.py
@@ -30,16 +30,15 @@
 from ..runtime.model_runner import _engine_config_to_model_config
 from ..sampling_params import BatchedLogitsProcessor, SamplingParams
 from .executor import GenerationExecutor, IterationResultQueue
-from .ipc import FusedIpcQueue, IpcQueue
+from .ipc import IpcQueue
 from .postproc_worker import (PostprocParams, PostprocWorker,
                               PostprocWorkerConfig, postproc_worker_main)
 from .request import (CancellingRequest, GenerationRequest, LoRARequest,
                       PromptAdapterRequest)
 from .result import (GenerationResult, IterationResult, LogProbsResult,
                      ResponseWrapper, compute_logprobs)
-from .utils import (PERIODICAL_RESP_IN_AWAIT, ErrorResponse, IntraProcessQueue,
-                    RequestError, WorkerCommIpcAddrs, has_event_loop,
-                    is_llm_response)
+from .utils import (ErrorResponse, IntraProcessQueue, RequestError,
+                    WorkerCommIpcAddrs, has_event_loop, is_llm_response)

 __all__ = [
     "ExecutorBindingsWorker",
@@ -190,7 +189,7 @@ def _create_iteration_result_queue(self,
             it_result_queue.aqueue = None

     def _set_iteration_result_queue(self, it_result_queue: IterationResultQueue,
-                                    queue: Union[Queue, FusedIpcQueue,
+                                    queue: Union[Queue, IpcQueue,
                                                  IntraProcessQueue]):
         assert not it_result_queue.is_initialized, "Iteration result queue should not already be initialized."
         it_result_queue.is_initialized = True
@@ -239,7 +238,7 @@ def _create_error_response(self, response: tllm.Response) -> ErrorResponse:
     def _iteration_result_task(self, it_result_queue: IterationResultQueue,
                                engine_get_result_api: Callable,
                                result_singleton: IterationResult,
-                               result_serializer: Callable):
+                               result_serializer: Callable) -> None:
         time.sleep(0.2)
         async_queues = []
         queue = result_singleton.queue if self._is_llm_executor and result_singleton else it_result_queue.queue
@@ -269,8 +268,6 @@ def _iteration_result_task(self, it_result_queue: IterationResultQueue,
             logger.error(f"worker.py: Error in _iteration_result_task: {e}")
             raise e

-        return True  # success
-
     def dispatch_stats_task(self) -> bool:
         # Define a Callable to join iteration and request stats

@@ -301,7 +298,7 @@ def get_stats():
                                      self._iter_stats_result,
                                      stats_serializer)

-    def dispatch_kv_cache_events_task(self) -> bool:
+    def dispatch_kv_cache_events_task(self) -> None:
         if isinstance(self.engine, tllm.Executor):
             # Check if the engine has a kv cache event manager
             # If not, return an empty list for the events which will cause the thread to exit early.
@@ -618,32 +615,27 @@ def worker_main(
     request_error_queue = IpcQueue(worker_queues.request_error_queue_addr,
                                    is_server=False,
                                    name="worker_request_error_queue")
-    mp_stats_queue = FusedIpcQueue(worker_queues.stats_queue_addr,
-                                   is_server=False,
-                                   fuse_message=True,
-                                   name="worker_stats_queue")
-    kv_cache_events_queue = FusedIpcQueue(
+    mp_stats_queue = IpcQueue(worker_queues.stats_queue_addr,
+                              is_server=False,
+                              name="worker_stats_queue")
+    kv_cache_events_queue = IpcQueue(
         worker_queues.kv_cache_events_queue_addr,
         is_server=False,
-        fuse_message=False,
         name="worker_kv_cache_events_queue")

     if postproc_worker_config.enabled:
         # IPC queues for sending inputs to the postprocess parallel
         # processes, each one is a PAIR zmq socket
         result_queues = [
-            FusedIpcQueue(is_server=True,
-                          fuse_message=PERIODICAL_RESP_IN_AWAIT,
-                          name=f"postprocess_{i}_feedin_queue")
+            IpcQueue(is_server=True, name=f"postprocess_{i}_feedin_queue")
             for i in range(postproc_worker_config.num_postprocess_workers)
         ]
     else:
         # IPC queue for sending results back to the proxy, and let the
         # Proxy process to handle the postprocess
-        result_queue = FusedIpcQueue(worker_queues.result_queue_addr,
-                                     is_server=False,
-                                     fuse_message=PERIODICAL_RESP_IN_AWAIT,
-                                     name="worker_result_queue")
+        result_queue = IpcQueue(worker_queues.result_queue_addr,
+                                is_server=False,
+                                name="worker_result_queue")

     def notify_proxy_threads_to_quit():
         # Signal the dispatcher thread in the proxy to quit
@@ -758,8 +750,7 @@ class AwaitResponseHelper:
     class HandlerKind(enum.Enum):
         unknown = 0
         single_process_worker = 1
-        ipc_periodically = 2
-        ipc_batched = 3
+        ipc_batched = 2

     def __init__(self, worker: "ExecutorBindingsWorker"):
         # TODO: make worker weakref
@@ -783,10 +774,7 @@ def responses_handler(self, responses: List[tllm.Response]):
             # The ExecutorBindingProxy is used
             print_colored_debug(f"creating await_response helper for IPC\n",
                                 color="yellow")
-            if PERIODICAL_RESP_IN_AWAIT:
-                self.handler_kind = HandlerKind.ipc_periodically
-            else:
-                self.handler_kind = HandlerKind.ipc_batched
+            self.handler_kind = HandlerKind.ipc_batched
         else:
             raise NotImplementedError

@@ -795,12 +783,10 @@ def responses_handler(self, responses: List[tllm.Response]):
                 return self.handle_for_worker(responses)
             case HandlerKind.ipc_batched:
                 return self.handle_for_ipc_batched(responses)
-            case HandlerKind.ipc_periodically:
-                return self.handle_for_ipc_periodically(responses)
             case _:
                 raise NotImplementedError

-    def __call__(self) -> bool:
+    def __call__(self) -> None:
         ''' This method should be called by a ManagedThread. '''
         responses = self.worker.engine.await_responses(
             timeout=datetime.timedelta(milliseconds=100))
@@ -814,7 +800,6 @@ def __call__(self) -> bool:
                             color="red",
                             category="Worker"):
             self.responses_handler(responses)
-        return True

     def handle_for_worker(self, responses: List[tllm.Response]) -> None:
         ''' Return the responses to asyncio.event_loop. '''
@@ -847,34 +832,6 @@ def handle_for_worker(self, responses: List[tllm.Response]) -> None:
         if async_queues:
             _SyncQueue.notify_many(event_loop, async_queues)

-    def handle_for_ipc_periodically(self,
-                                    responses: List[tllm.Response]) -> None:
-        ''' Return the responses to Proxy via IPC. This will put Rsp to a Queue
-        in a FusedIpcQueue, and a background thread will batch them and invoke
-        IPC periodically. '''
-
-        with nvtx_range_debug(f"handle_for_ipc_periodically-{len(responses)}",
-                              color="red",
-                              category="Worker"):
-
-            for response in responses:
-
-                if self.worker._has_background_error():
-                    response = self.worker._create_error_response(response)
-                elif response.has_error():
-                    response = ErrorResponse(response.client_id,
-                                             response.error_msg,
-                                             response.request_id)
-                else:
-                    logprobs_result = _get_logprobs(
-                        self.worker, response, self.worker._is_pytorch_backend)
-                    if logprobs_result:
-                        response = ResponseWrapper(response, logprobs_result)
-
-                # TODO: To verify the performance of using ZMQ instead of SharedMemory
-                # to send the logits tensor back to the Proxy process.
-                _send_rsp(self.worker, response)
-
     def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None:
         ''' Perform the IPC in batch explicitly. '''
         postproc_batches = [
diff --git a/tensorrt_llm/llmapi/utils.py b/tensorrt_llm/llmapi/utils.py
index 39cfd2d57..b349475d9 100644
--- a/tensorrt_llm/llmapi/utils.py
+++ b/tensorrt_llm/llmapi/utils.py
@@ -239,6 +239,9 @@ class ManagedThread(threading.Thread):
         **kwargs: The arguments to pass to the task
     """

+    class StopEvent(Exception):
+        pass
+
     def __init__(self,
                  task: Callable[..., bool],
                  error_queue: Queue,
@@ -264,8 +267,10 @@ def run(self):
                 break

             try:
-                if not task(**self.kwargs):
-                    break
+                task(**self.kwargs)
+            except ManagedThread.StopEvent:
+                print_colored_debug(f"Thread {self.name} stopped.\n", "green")
+                break
             except Exception as e:
                 logger.error(
                     f"Error in thread {self.name}: {e}\n{traceback.format_exc()}"
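
Note on the new IpcQueue helpers: put_noblock sends without blocking and hands failure handling to the caller, while get_with_poll keeps the receiving thread responsive by invoking on_fail on every poll timeout instead of blocking in recv(). A minimal sketch of how the two might pair up, assuming a server/client IpcQueue pair over the default PAIR socket as set up elsewhere in this diff; the queue name, payload, and callbacks below are illustrative only:

    from tensorrt_llm.executor.ipc import IpcQueue

    # Server side binds; the client connects to the server's address.
    server = IpcQueue(is_server=True, name="example_queue")
    client = IpcQueue(server.address, is_server=False, name="example_queue")

    # Non-blocking send: if the peer cannot accept the message right now,
    # run the fallback instead of blocking (or raising) in the caller.
    client.put_noblock({"step": 1}, on_fail=lambda: print("send deferred"))

    # Polling receive: on_fail fires on each poll timeout, so a dispatcher
    # loop can check shutdown flags between attempts rather than hang.
    msg = server.get_with_poll(on_fail=lambda: None, poll_timeout=0.1)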
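
The ManagedThread contract also changes shape here: dispatcher tasks no longer return True to keep running or False to stop; they return None and raise ManagedThread.StopEvent when they see the shutdown sentinel, which run() now catches to exit the loop cleanly. A sketch of the new task shape, mirroring the sentinel handling in dispatch_result_task above; the function name and process() placeholder are illustrative, not part of this diff:

    from tensorrt_llm.llmapi.utils import ManagedThread

    def dispatch_task(queue) -> None:
        # Old contract: `return False` stopped the ManagedThread loop and
        # `return True` kept it running. New contract: return nothing and
        # raise StopEvent to stop; any other exception is reported through
        # the thread's error queue as before.
        data = queue.get()
        if data is None:  # None is the shutdown sentinel
            raise ManagedThread.StopEvent()
        process(data)  # illustrative placeholder for the real dispatch work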