From 50e268686829491626d5a3d1511b7e4192b72645 Mon Sep 17 00:00:00 2001
From: Sunghyun Park
Date: Mon, 29 Apr 2024 16:27:23 -0600
Subject: [PATCH] Merge with `mlc-ai/main` (`d3d264d4b05d73e9757375013b842254f052c6ed`, April 29th 2024) (#265)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [Serving][Grammar] BNF grammar simplifier and matcher (#1801)

* [Serving] LogProbs support (#1832) This PR introduces logprobs support with OpenAI API compatibility. It enhances the sampler with a function to get the top-probability tokens (supporting 5 tokens at most as of now). To make it easy to pass logprob results back from the serving engine to the frontend, we choose to pass logprob results as a JSON string following the OpenAI API spec. Unit tests are added to ensure the correctness of logprobs. The logprobs support also works with speculative decoding.

* [Serving] Support Mixtral in MLC Serve (#1840) This PR supports Mixtral in MLC Serve. The main change is introducing the Mistral conversation template to the Python registry so that MLC Serve can use it. Besides that, this PR updates the KV cache capacity analysis to make it more accurate in terms of usage calculation, while remaining conservative since there is a known issue regarding batch-prefill embedding memory usage which may lead to OOM. We will follow up on the issue with a fix in the future and then let the estimation use more GPU vRAM.

* [Fix] Fix `u_char` for Windows build (#1848) Prior to this PR, `u_char` was used while it is not a standard type in C++, which caused a Windows build failure. This PR fixes it by using `unsigned char`.

* Auto updated submodule references

* [Fix] Add phi lm head name to is_final_fc, add q4f16_ft to CI (#1849) [Fix] Add phi lm head name to is_final_fc

* [Build] Replace mod_transform_before_build with IRModule pass (#1852) Instead of a Python function that returns an updated `IRModule`, the new `optimize_mod_pipeline` function returns a `tvm.ir.transform.Pass` which can be applied to an `IRModule`.

* [SLM] Add support for InternLM architecture (#1835) * Create __init__.py * Add files via upload * Update model.py * Update model_preset.py * Update conv_templates.cc * Update internlm_loader.py * Update internlm_quantization.py * fix name of notes * Update model.py * Migration * fix pylint issue * fix pylint issue * fix pylint error * Update internlm_loader.py * Update __init__.py * Update __init__.py * Delete python/mlc_chat/model/internlm/__init__.py * Add files via upload

* [Bugfix] Handle model names with multiple path components (#1851) Prior to this commit, a model name with multiple path components (e.g. `dist/models/group_name/model_name`) would have duplicated path components (e.g. `dist/group_name/artifact_path/group_name/libname.so`). This commit resolves the duplication.

* [KVCache] Add max num threads awareness to KVCache kernels (#1822) * [KVCache] Add max num threads to KVCache kernels, fix WebGPU * Read max_num_threads_per_block when available * Change merge state in place kernel * Make attention decode aware of max num threads, not just webgpu Co-authored-by: Egor Churaev * Change util function name --------- Co-authored-by: Egor Churaev

* [KVCache] Migrate Baichuan model to PagedKVCache (#1854)

* [Python] Lazy import of transformers for tiktoken conversion (#1860) This PR moves the import of transformers into the function body of the tiktoken tokenizer conversion, so we do not have a hard dependency on transformers.
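As an illustration of the logprobs support described above, here is a hedged sketch of an OpenAI-style request against the serving endpoint. The host, port, and model path are placeholders; the field names (`logprobs`, `top_logprobs`) follow the OpenAI spec the PR aims to be compatible with, not a verified MLC schema.

```python
# Hypothetical request sketch for the logprobs feature above.
# Host/port and model path are placeholders; field names follow the OpenAI API spec.
import requests

payload = {
    "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "logprobs": True,
    "top_logprobs": 5,  # at most 5 top-probability tokens are supported for now
    "stream": False,
}
r = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
# Each choice is expected to carry a "logprobs" entry with the chosen token and its top alternatives.
print(r.json()["choices"][0].get("logprobs"))
```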
* [SLM] RWKV5 World Support (#1787) This PR adds RWKV5 support with RNNState, a similar interface to PagedAttention. Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>

* [Serving] Register the ChatML conversation template (#1862) Following #1854, this PR registers the ChatML conversation template.

* [Utils][Transform] Added SetEntryFuncs transform (#1855) Sets the entry functions for a module. This utility is intended for cases where a module contains several externally-exposed functions and only one is desired for use (e.g. separating out a `transform_params` function from an `IRModule` that also contains inference functions). This commit only updates the external visibility, after which `relax.transform.DeadCodeElimination()` can be applied.

* [Build] Update transform_params_for_each_rank to IRModule pass (#1856) This allows it to be used as part of an optimization pipeline specified as a `tvm.ir.transform.Sequential`.

* [Serving][Grammar] Integrate JSON grammar into the generation pipeline (#1867) This PR is the 3rd part of the grammar-guided generation. It integrates the grammar framework into the generation process and supports JSON output for now. The API this PR provides is compatible with the OpenAI API.

### APIs

#### Python API

```
@dataclass
class ResponseFormat:
    type: Literal["text", "json_object"] = "text"
    json_schema: Optional[str] = None

@dataclass
class GenerationConfig:
    response_format: ResponseFormat = ResponseFormat(type="text")
```

#### Rest API

```
response_format: { "type": "text" }                          # text generation, by default
response_format: { "type": "json_object" }                   # json generation
response_format: { "type": "json_object", json_schema="..."} # json generation with schema
```

JSON generation with schema is not supported yet, but is planned for the future.

### Performance

#### Without JSON

```
Single token prefill latency: 891.2234 ms/tok
Single token decode latency: 31.3399 ms/tok
Prefill token throughput: 4693.3077 tok/s
Decode token throughput: 226.4406 tok/s
Overall token throughput: 470.3180 tok/s
```

#### With JSON

```
Single token prefill latency: 219.2287 ms/tok
Single token decode latency: 29.1399 ms/tok
Prefill token throughput: 7392.1555 tok/s
Decode token throughput: 179.2296 tok/s
Overall token throughput: 1052.1996 tok/s
```

We observed a slight decrease in performance under JSON mode. This will be further optimized in the future.

* [Serving] Support "n" for parallel generation (#1868) This PR brings the field `n` to the generation config and thereby supports parallel generation. This parallel generation effectively leverages the "fork" functionality of the paged KV cache. This PR supports specifying the number of parallel generations `n` in the standard OpenAI ChatCompletion API. This is the last feature needed for OpenAI API feature completeness.

* [CI] Add retry to scm checkout (#1869) Sometimes scm checkout can time out; this PR adds a retry for that.

* [Attn] Use float32 accumulation in attention kernel (#1870) Prior to this PR, the TIR attention kernels did not cast matmul operands to fp32 before multiplying. For models like Phi-2, which may have large Q/K/V values (at the level of a few hundred), the fp16 multiplication exceeds the range of fp16 and sometimes leads to the attention result being NaN. This PR fixes this issue.

* [Utils] Allow ReorderTransformFunc to be used without param manager (#1857) Prior to this commit, the `ReorderTransformFunc` required several components of the `ParamManager` in order to be used.
The functionality it provides, reordering dataflow blocks to minimize the liveset, is useful outside of the context of the `ParamManager`. This commit makes the following changes, allowing it to be used independently of the `ParamManager`.
- Generate the `pidx2binname` dictionary outside of `ReorderTransformFunc`
- Allow parameters to be separate `func.params`, rather than a single bundled tuple parameter.

* [SLM] Migrate Phi-2 to paged KV Cache #1871 (#1872) This PR migrates Phi-2 to paged KV cache attention as a part of the model definition migration according to #1749. Co-authored-by: Shrey Gupta

* [Fix] Fix the use of "call_inplace_packed" and "call_pure_packed" (#1874) The use of `call_inplace_packed` and `call_pure_packed` in the old flow is outdated due to signature changes. This PR fixes the issue.

* [Fix] Add the missing BundleModelParams pass (#1875) PR #1852 failed to apply the BundleModelParams pass and thus made the compiled models not runnable through ChatModule (#1864). This PR fixes the issue.

* [Docs] Update Android APK download link (#1876) As pointed out by #1830, this PR fixes the Android app download link in the docs.

* Fix MLC-LLM website link weight convert not accessible (#1877) Fix website link not accessible

* [Serving][Grammar] Support termination state in GrammarStateMatcher (#1884)

* [Serving] Make RequestState as a standalone object class (#1878) This PR adopts suggestions from the support of OpenAI API parallel generation `n` in #1868. The main update in this PR is to make RequestState a standalone object class, which was previously a typedef of `std::vector`. This PR also fixes a bug in prefill that would cause engine failure when `n` is large.

* [SLM] Update StableLM model and migrate it to paged KV Cache (#1882)

* [KVCache] Qwen 1.0 Model PagedKV Support (#1887) Support Qwen1.0 Paged KV Cache

* [Serving] Estimate KV cache memory usage with metadata (#1888) Prior to this PR, the serving engine memory usage estimation reads the model config for fields such as `num_key_value_heads`, `num_hidden_layers`, etc. However, since not every model shares the same set of config names (#1854), the estimation fails for models that do not have this set of config field names. This PR makes the following changes. First, it attaches these field values to the model's metadata, which effectively unifies the field names across different models. Then, when estimating the memory usage, we read these fields from the metadata rather than the model config, so we are safe from the name inconsistency.

* [KVCache] Migrate bigcode arch to PagedKVCache (#1891) Compilation and runtime are smooth. I will open follow-up PRs to enable StarCoder2 support in the same model definition file.

* [Serving] Add Phi-2 conv template to mlc serve (#1890) This PR adds the phi-2 model template to MLC serve. For testing: 1. Start the server: ```python -m mlc_chat.serve.server --model ./dist/phi-2-q4f16_1-MLC/ --model-lib-path ./dist/phi-2-q4f16_1-MLC/phi-2-q4f16_1-cuda.so --device auto --max-batch-size 2 --enable-tracing --host 127.0.0.1 --port 8000 --max-total-seq-length 8000``` 2.
Send the request: ```python test_server_rest_api.py```

```python
# test_server_rest_api.py
import requests
import json

model = "./dist/phi-2-q4f16_1-MLC/"
port = 8000
payload = {
    "model": f"{model}",
    "messages": [{"role": "user", "content": "Tell me about Machine Learning in 200 words."}],
    "stream": False,
}
r = requests.post(f"http://127.0.0.1:{port}/v1/chat/completions", json=payload)
if r.status_code != 200:
    print(r.json())
else:
    print(r.json()["choices"][0]["message"]["content"])
```

* [Attn] Fix attention kernel for head dim not divisible by 32 (#1889) Prior to this PR, our TIR prefill attention kernel assumed the head dim to be a multiple of 32. As reported by #1826, this assumption does not always hold. This PR fixes the issue so that models with different head dims can also compile.

* [Python] Enable "thrust" for CUDA by default (#1866) This PR enables thrust for CUDA targets so that we can dispatch some operators (e.g., cumsum) to thrust.

* [Serving] Fix loading presharded weights (#1894)

* [Serving] Address embedding lookup OOM issue (#1899) This PR addresses the OOM issue that may be caused by embedding lookup when the batch size of a prefill action is large. Prior to this PR, a large embedding tensor would be created for each sequence in the prefilled batch, which may take unexpectedly large memory when the batch size is large.

* [Model] Remove redundant `batch_forward` and move broadcast (#1900) This PR contains four changes:
1. It removes the duplicate `batch_forward` defined in model definitions. This function was widely used prior to our migration to PagedKVCache, since before the migration the attention codepaths of single-sequence forward and batch forward differed. Since our migration, the codepaths are unified into one, and therefore we can safely remove most `batch_forward` functions.
2. It moves `op.ccl_broadcast_from_worker0` from the model main forward (which is called at the beginning of prefill/decode) to embedding. This change has two benefits. Firstly, the token ids taken by `embed` were not broadcast across workers, and it is possible for workers other than 0 to have illegal token ids which are not in the range of the vocab size; moving the broadcasting to `embed` addresses this issue. Secondly, broadcasting token ids in `embed` is more lightweight than broadcasting embeddings in `prefill`/`decode`, since the tensor size of token ids is much smaller.
3. It adds `max_batch_size` to the config class of models, so that they are potentially compatible with batching and MLC serve.
4. It removes the `k_cache` and `v_cache` effects from the models that have switched to PagedKVCache support.
Randomly picked a few models (as below) to run the engine test, and all of them passed: * phi-2 with tp=2, * RedPajama with tp=2, * stablelm with tp=2 (since stablelm does not support TP right now).

* [KVCache] Migrate Qwen2 model to PagedKVCache (#1903)

* [CI] Skip not supported quantization in model compilation test (#1904) This PR updates the model compilation test so that it now skips a quantization when the model does not support it.

* [Serving] Add missing header for `std::iota` (#1905) The header `<numeric>` was missing, which may have caused build failures on Windows. This PR adds the header.

* [Serving] Fix Model TokenEmbed function with TP (#1906) This PR fixes a severe bug introduced by #1899. Since #1899, we no longer copy the embedding back from worker 0 when using tensor parallelism. However, we did not synchronize with worker 0.
This causes the following issue: in batch prefill, we call TokenEmbed multiple times in a row. Each time, we copy the token ids to the `token_ids` NDArray on worker 0. If we do not synchronize with worker 0, it is possible that the local token ids have been updated multiple times before the first `CopyToWorker0` actually starts to execute on the worker 0 side. As a result, at the time of executing the token ids copy to worker 0, the local token ids might be wrong (by "wrong", say we are executing the copy of seq 0's token ids; the actual local token ids array might already hold seq 3's token ids). This makes batch prefill behave completely wrong. This PR explicitly adds a synchronization with worker 0.

* [SLM] Add support for Orion architecture. (#1883) This is a PR for supporting [OrionStarAI/Orion-14B-Chat](https://huggingface.co/OrionStarAI/Orion-14B-Chat).

* [Model] Eliminate the reshape in embedding func (#1908) Prior to this PR, there was a trailing reshape kernel at the end of the embedding func. The reshape does not need to be a kernel, and as a kernel it consumes extra time during execution. This PR eliminates the reshape in the embedding function by updating the signature of the embedding func, so that it now only takes the plain 1D token ids as input.

* [Pass] Low batch GEMM using GEMV-like schedule (#1769) When the batch size is small, GEMM in the MLP of the decode stage can be dispatched to a specialized GEMV-like schedule to improve efficiency. GEMM with a dynamic var in the spatial axis will now be lowered into

```python
if dyn_var <= 8:
    low_batch_gemv()
else:
    normal_gemm()
```

* Auto updated submodule references

* [Serving] Avoid unnecessary worker sync in Model (#1909) Following up on #1906, this PR removes the synchronization given that it is avoidable. We use another approach to avoid the write-after-write issue. The key is to make sure the addresses to be copied to worker 0 are not rewritten before the copy actually happens. So we pre-allocate a large host array to hold all the token ids, and for each sequence we copy its token ids to the offset given when calling TokenEmbed, so that an address will not be written twice before the copy happens.

* [Serving][Grammar] Enhance GrammarStateMatcher to support general grammar (#1917)

* [Android] Improve perf of TIR PagedAttn kernel on Android (#1915) * android perf * Update kv_cache.py

* Deprecate old flow (#1928) * Deprecate old flow This PR deprecates the old flow. As of today most of the efforts are centralized around the new flow with SLM compilation. Additionally, we are bringing model definitions through the unified KV interface so we can have a single model definition across all backends, for both server and local settings. We kept the old flow around for a while, but it is a good time to do the transition. All the documents are updated to point to the new flow. We also created a backup branch https://github.com/mlc-ai/mlc-llm/tree/backup-before-old-flow-deprecation for people who would like to check out some of the old flow references. * Remove deprecated prebuilts

* [Serving] Register the StableLM3B conversation template (#1920) Update conversation_template.py

* Remove deprecated build.py

* [Fix] KVCache creation with call_pure_packed (#1930) With https://github.com/apache/tvm/pull/16684 merged, KV cache creation fails when compiling models. This PR fixes the problem by using `call_pure_packed`.
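The write-after-write avoidance in #1909 above can be illustrated with a small sketch. This is a minimal, hypothetical Python illustration (names are not the actual MLC implementation): each sequence's token ids are staged at a distinct offset of one pre-allocated host buffer, so a later write never clobbers data that an in-flight asynchronous copy to worker 0 still has to read.

```python
# Minimal sketch of the #1909 staging-buffer idea; names and sizes are illustrative.
import numpy as np

MAX_TOTAL_TOKENS = 1 << 16
staging = np.empty(MAX_TOTAL_TOKENS, dtype=np.int32)  # pre-allocated host array
write_cursor = 0

def stage_token_ids(token_ids):
    """Copy one sequence's token ids into its own slice and return (offset, length)."""
    global write_cursor
    n = len(token_ids)
    offset = write_cursor
    staging[offset:offset + n] = token_ids
    write_cursor += n  # never reuse a region within one batch
    return offset, n

# Every TokenEmbed-style call gets its own region, so no synchronization is needed
# before an asynchronous device copy reads staging[offset:offset + n].
off, n = stage_token_ids([101, 2023, 2003, 1037, 3231])
```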
* [KVCache] Update FlashInfer PackedFunc names (#1931) This PR updates the FlashInfer names given https://github.com/apache/tvm/pull/16692 has been merged.

* [REFACTOR] remove tests/legacy-python (#1933) This PR removes the folder tests/legacy-python as a follow-up cleanup step of the old flow deprecation. Some of the files (like compare lib) are useful and we should recover them later in the mlc_llm.testing.DebugChat flow.

* [REFACTOR] rename mlc_chat => mlc_llm (#1932) This PR renames the mlc_chat package to the mlc_llm package now that this is the new official flow. We also update the necessary locations that might touch the package.

* Auto updated submodule references

* [Docs] Deprecating CUDA 11.7/11.8 support (#1939) We have deprecated the wheel support for CUDA 11.7/11.8 due to TVM thrust compatibility issues with old CUDA versions.

* [Fix] Fix KV cache call in mistral (#1938) The latest TVM introduces a well-formedness check of the IR. The Mistral model definition breaks well-formedness due to purity. This PR fixes the issue.

* [ChatModule] Remove eos_token_ids (#1940) This PR removes the eos_token_ids from the ChatModule given it is actually not used anywhere.

* [SLM] Weight conversion with generator (#1916) This PR enhances weight conversion so that it passes a generator to `tvmjs.dump_ndarray_cache`. This effectively reduces the CPU memory pressure when converting weights, especially when the total converted weight size is close to or larger than the CPU memory size.

* [Serve] Introducing GPU sampler for CUDA (#1934) This PR introduces the GPU sampler for CUDA only. The GPU sampler makes use of the GPU sampling ops introduced in apache/tvm#16575. We will follow up to benchmark the performance of the GPU sampler over the CPU sampler.

* [Serve] Constrain KV cache capacity on Metal (#1943) This PR constrains the KV cache capacity for Metal devices to 32768, in order to avoid large tensors in the KV cache. This is because the Metal runtime currently has a performance issue when running a kernel where some input buffer is very large, even if little of the large buffer is accessed in the kernel.

* [CI] Add windows ci (#1942) This PR adds Windows CI.

* Auto updated submodule references

* [Fix] Fix embedding shape check in ChatModule (#1953) This PR is a fix to address #1952.

* [Fix] Fetching the Git-LFS tokenizer files (#1954) Prior to this PR, when running commands like

```shell
python3 -m mlc_chat chat HF://mlc-ai/gemma-7b-it-q4f16_2-MLC
```

only the binary weight files are downloaded among all the Git LFS files. For models like Gemma, whose tokenizer is large and also stored as a Git LFS file, the tokenizer files are not automatically downloaded. For example, the cloned Gemma `tokenizer.json` file has the content

```
version https://git-lfs.github.com/spec/v1
oid sha256:05e97791a5e007260de1db7e1692e53150e08cea481e2bf25435553380c147ee
size 17477929
```

and this content is never resolved to the actual tokenizer. This leads to issue #1913. This PR fixes the issue by pulling all the Git LFS files that are not binary files.

* [LogitProcessor] Add max thread awareness to logit processing kernels (#1955) Make the kernels in `AttachLogitProcessFunc` aware of the maximum thread count, fixing https://github.com/mlc-ai/mlc-llm/issues/1951.
Most of the code change is due to indentation; the main change is changing `1024` to `tx`, where `tx` is

```
tx = 1024  # default
max_num_threads_per_block = get_max_num_threads_per_block(target)
if max_num_threads_per_block < tx:
    tx = max_num_threads_per_block
check_thread_limits(target, bdx=tx, bdy=1, bdz=1, gdz=1)
```

* [Model] Use static hidden size in mixtral scatter_output (#1959)

* Auto updated submodule references

* [CompilerFlag] Detect if FlashInfer is enabled from libinfo (#1941) This PR supports detecting whether FlashInfer was enabled when building TVM, so that FlashInfer won't be enabled when TVM is not built with FlashInfer enabled.

* [Serving][Grammar] Add grammar termination as a stop condition (#1964)

* Unify schema for conversation template and embed into mlc-chat-config.json (#1965)

* [SLM] Small correction on Stablelm and Qwen2. (#1958) * small fix * small fix * Update stablelm_model.py

* [Serving][Fix] Fix JSON output check in test_server.py (#1966) `test_server::is_json_or_json_prefix` is used to check whether the output is JSON or a prefix of JSON. It uses json.loads internally. However, json.loads (i.e. json.decode) is token-based instead of character-based. If half a token is left at the end of the string, it cannot be matched. This PR adds another check for the remaining "half token" if it exists.

* [Model] Migrate Mistral to use PagedKVCache (#1967) This PR migrates the Mistral model to the PagedKVCache interface, which supports sliding window attention with a paged attention kernel written in TensorIR. We thereby introduce a `support_sliding_window` mode for the KV cache, which leaves space for supporting sliding windows for any model at runtime. This PR tests Mistral with both chat and serve. The chat performance of Mistral 7B improves over before, benefiting from the paged attention implementation.

* Auto updated submodule references

* [REST] Update Rest API docs for the latest serve flow (#1972) * [Docs][Upd] Server launch, examples for endpoints for MLC Serve * remove v1/completions * add api docs to rest --------- Co-authored-by: Shrey Gupta

* [Conv] Add bos_token to llama and mistral in ConvTemplateRegistry (#1970) Since we don't have the `add_bos` field in the new Conversation template, we should add the bos token into the system_prefix_token_ids, so that it will be added to the tokenized prompt.

* [Model][Serve] Add support for LLaVa model in serving engine (#1974) This PR adds support for the LLaVa-v1.5 model on the serving engine. Use the HF weights and config from https://huggingface.co/llava-hf/llava-1.5-7b-hf. Passing image input is supported as a URL (reference: https://platform.openai.com/docs/guides/vision). Example:

```python
data = {
    "model": "dist/llava-1.5-7b-hf-q4f16_1-MLC/params/",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": "https://llava-vl.github.io/static/images/view.jpg",
                },
                {"type": "text", "text": "What does this image represent?"},
            ],
        }
    ]
}
response = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=data)
print("Response body:", response.text)
```

* [Serve] Hot fix for the mixtral serving (#1975) [Fix] hotfix for the mixtral serving Co-authored-by: Yong Wu

* [REST] REST API Deprecated (#1973) Deleted the old REST API - Removed rest.py - Removed old interface/openai_api.py - Updated ChatModule to use the new OpenAI API protocol Co-authored-by: Kartik Khandelwal

* [Fix] Fix handling of non-numerical cuda arch (#1976) On the latest GPUs, the CUDA arch may not be an integer, e.g. `sm_90a`.
This fixes a few places that rely on integer parsing.

* [Serving][Grammar] Support specifying the main rule in grammar (#1982) finish

* [Fix] Fix `MLC_MULTI_ARCH` with arch `sm_90a` (#1984) This PR fixes the missing patch for targets with the `sm_90a` arch, as a follow-up PR of #1976.

* Fix Llama-2 and Mistral conversation template. Update ConvTemplateRegistry (#1981) The current prompt format for Llama-2 and Mistral is not completely correct. This PR updates the code to strictly follow the official prompt format for the two models. It also adds the missing conv templates to ConvTemplateRegistry.

* [SpecDecode] Fix sampler selection. (#1971) This PR temporarily fixes the sampler selection logic for speculative decoding. As GPU sampler support for speculative decoding is not ready, speculative decoding will use the CPU sampler.

* [Serving][Grammar] Utility to convert json schema to EBNF grammar (#1983) This PR adds a generic utility to convert a JSON schema, especially one generated from pydantic, to EBNF grammar. This helps grammar-guided generation when we provide a JSON schema as the restriction. The converter supports the JSON standard indent style in the output grammar. API:

```
def json_schema_to_ebnf(
    json_schema: str,
    *,
    indent: Optional[int] = None,
    separators: Optional[Tuple[str, str]] = None,
    strict_mode: bool = True,
) -> str:
    """Convert JSON schema string to EBNF grammar string.

    Parameters
    ----------
    json_schema : str
        The JSON schema string.

    indent : Optional[int]
        The number of spaces for each indent. If it is None, there will be no indent
        or newline. The indent and separators parameters follow the same convention
        as `json.dumps()`.

    separators : Optional[Tuple[str, str]]
        The separator between different elements in json. Examples include "," and ", ".

    strict_mode : bool
        Whether to use strict mode. In strict mode, the generated grammar will not allow
        unevaluatedProperties and unevaluatedItems, i.e. these will be set to false by
        default. This helps LLM to generate accurate output in the grammar-guided
        generation with JSON schema.
    """
    pass
```

* Auto updated submodule references

* [Fix] Fix serve model to adapt the latest Allocator signature (#1989) PR apache/tvm#16738 updated the Allocator signature. This PR updates the caller side accordingly.

* [Model] Use optimized group gemm for Mixtral (#1988)

* [Attn] Fix the construction of attn result merge kernel (#1995) This PR fixes the mistake of passing the wrong number of heads to the attention result merge kernel.

* [iOS][Android] Add validation of library file for iOS and Android build (#1993) This PR adds validation of symbols in iOS and Android builds. During a static library build, we need the right model_lib to point to the packaged model executables. Not doing so correctly results in a "vm_load_executable not found" error, which is not informative. In this PR we validate the compiled model lib by dumping the global symbols and ensuring the lists of model libs match. In the future we should perhaps lift the validation into the mlc_llm package.

* Auto updated submodule references

* [Serve] add allocator in Storage as the upstream change (#1997) The changes in https://github.com/apache/tvm/pull/16750 modified the signature of Storage; this pull request updates the caller code in mlc-llm to accommodate the new Storage class signature. Ran into a build error w/o the change.
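A possible usage sketch for the `json_schema_to_ebnf` utility documented above: its signature is quoted from the PR, but the import path and the pydantic model here are illustrative assumptions only.

```python
# Usage sketch for json_schema_to_ebnf (#1983). The pydantic model and the import
# path are hypothetical; only the function signature comes from the PR description.
import json
from pydantic import BaseModel

class Product(BaseModel):
    name: str
    price: float
    in_stock: bool = True

schema_str = json.dumps(Product.model_json_schema())

from mlc_llm.serve import json_schema_to_ebnf  # assumed import path

ebnf = json_schema_to_ebnf(schema_str, indent=2, strict_mode=True)
print(ebnf)  # EBNF rules that constrain generation to JSON matching the schema
```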
* [Compiler] Support IPC memory and customized all-reduce kernels (#1990) This PR introduces IPC memory and customized all-reduce kernel dispatches for tensor parallelism. We add a new compiler flag `--allreduce-strategy`, which supports `"ring"`, `"one-shot"` and `"two-shot"`. The flag defaults to `"ring"`, which means this PR makes no difference if people do not manually change the all-reduce strategy. As of now the IPC-memory-backed customized all-reduce kernels are only available on CUDA. To enable all-reduce strategies other than "ring", here are some example compile commands:

```shell
python -m mlc_llm compile model/mlc-chat-config.json --device cuda --opt "allreduce-strategy=one-shot" -o model/lib.so
python -m mlc_llm compile model/mlc-chat-config.json --device cuda --opt "allreduce-strategy=two-shot" -o model/lib.so
```

Please be aware that you probably also need to specify other compiler flags, for example `--opt "cublas_gemm=1;allreduce-strategy=one-shot"`.

* Auto updated submodule references

* [Model] Fix the top-k TIR script for well-formedness (#2002) This PR fixes the malformed MoE TIR scripts.

* Fix invalid use of dataflow var in sampler output (#2003)

* [Fix] Fix KV cache creation pass after nn.Module changes (#2011) This PR corrects the assertion after the latest changes in apache/tvm that update some nn.Module behavior.

* [iOS] Fix typo in prepare_model_lib.py (#2013) Fix typo in prepare_model_lib.py. The line tar_list.append(valid_paths[ls0]) was introduced by mistake in https://github.com/mlc-ai/mlc-llm/pull/1993.

* Remove unstable assertion in KV cache creation dispatch (#2017) This particular assertion has recently been unstable given the back-and-forth upstream TVM nn.Module exporter behavior.

* Auto updated submodule references

* [SLM] Qwen2 Multi-GPU support (#1985) * Update qwen2_model.py * fix lint issue * fix lint issue * fix lint issue

* more info for preshard (#2027) * When the pre-sharded version of a certain model is not available, the program will default back to the normal workflow without issuing any alert. Now, when someone attempts to convert to a pre-sharded model but cannot, the program will throw a warning message to inform users that it will revert to the standard model conversion process. * format fix. * black reformatted, I did not see any diff. * black reformatted..

* Register stablelm-2 conversation template (#2029)

* [Serving][Fix] Fix problems in PopenServer (#2032) This PR fixes several problems in the PopenServer: - Add a check for when the server has not started and the request returns a failure code, e.g. 502, and change the retry interval to 0.1s. - Add `__enter__` and `__exit__` methods to PopenServer. When the program is interrupted, using a with clause (`__enter__` and `__exit__`) ensures the server always terminates. When using `start()` and `terminate()`, the server may stay in the background even though the parent process ends.

* [Quantization] Skip MoE gate layer (#2012) This PR skips quantizing the MoE gate layer.

* [Serving][Grammar] Integration of JSON schema generation (#2030) The previous PR #1983 introduced a transformation from JSON schema to BNF grammar. This PR further integrates the grammar from JSON schema into the generation pipeline, so that the engine now supports JSON schema output. GrammarStateInitContexts are stored in a cache, so they will not be created again for the same schema.
Interface:

- Python

```
@dataclass
class ResponseFormat:
    type: Literal["text", "json_object"] = "text"
    schema: Optional[str] = None
```

- Rest API

```
class RequestResponseFormat(BaseModel):
    type: Literal["text", "json_object"] = "text"
    json_schema: Optional[str] = Field(default=None, alias="schema")

class CompletionRequest(BaseModel):
    ...
    response_format: RequestResponseFormat = Field(default_factory=RequestResponseFormat)

class ChatCompletionRequest(BaseModel):
    ...
    response_format: RequestResponseFormat = Field(default_factory=RequestResponseFormat)
```

Performance: We only test single-batch performance for now to show the latency overhead.
- Model: `Llama-2-7b-chat-hf-q4f16_1`
- GPU: `NVIDIA GeForce RTX 3080`
- CPU: `AMD Ryzen 9 5900X 12-Core Processor`

```
JSON ON  Batch=1
Average prefill tokens: 651.0000 tok/req
Average decode tokens: 499.0000 tok/req
Single token prefill latency: 0.3140 ms/tok
Single token decode latency: 8.6831 ms/tok
Prefill token throughput: 3184.8002 tok/s
Decode token throughput: 116.6039 tok/s

JSON OFF Batch=1
Average prefill tokens: 651.0000 tok/req
Average decode tokens: 499.0000 tok/req
Single token prefill latency: 0.3098 ms/tok
Single token decode latency: 8.6823 ms/tok
Prefill token throughput: 3227.8141 tok/s
Decode token throughput: 116.9251 tok/s
```

This PR also makes these bug fixes / changes: - Changed the structure of the grammar converted from the schema to avoid a large number of uncertain tokens, which caused a performance degradation.

* [Compiler] Support AUTO mode for all-reduce strategy (#2034) This PR supports the auto mode for the IPC all-reduce strategy. It renames the strategy from `allreduce-strategy` to `ipc-allreduce-strategy` in the compiler optimization flags. The default RING mode is renamed to NONE mode, which, when specified, uses NCCL all-reduce without any IPC memory rewrite. So right now the ideal way to enable IPC all-reduce is `ipc-allreduce-strategy=auto`.

* [LLaVa] Follow-up for TODOs in LLaVa model (#2010) Llava: 1. Added base64 image support. 2. Merged as_prompt and as_prompt_list. 3. get_image_from_url uses config.

* [Pipeline] Defer GPU IPC memory lowering (#2038) This PR moves the position of the GPU IPC memory lowering pass in the pipeline so that it applies after the CUDA graph rewrite, to enable CUDA graph with the customized all-reduce kernels.

* [Model] Add missing broadcast of logit_position for multigpu (#2040) This commit adds the broadcasting of `logit_pos` in batch prefill for all models to avoid the logit position out-of-bound issue.

* [Preshard] apply presharding after quantization (#2039) This changes the behavior of presharding by applying presharding after quantization. This makes the behavior consistent with or without presharding.

* [SLM] Baichuan Multi-GPU support (#2037) This PR enables the TP function of the Baichuan2 model.

* Auto updated submodule references

* [Model] Skip TVMSynchronize when tracing is not enabled (#2041) This PR removes the synchronization in `Model` when Chrome tracing is not enabled. It helps some logit processing kernels launch earlier.

* [Serving] Support NVTX for benchmarking (#2043) This PR adds NVTX support to MLC serve, which helps analyze benchmarking results. **Note.** To enable NVTX, please add `set(USE_NVTX ON)` to the file `build/config.cmake`.
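To make the `response_format` interface above concrete, here is a hedged request sketch against the OpenAI-compatible endpoint; the server address, port, and model path are placeholders, while the `"schema"` field name follows the alias shown in the REST interface above.

```python
# Illustrative JSON-schema-constrained request using the response_format field above.
# Host/port and model path are placeholders.
import json
import requests

item_schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "price": {"type": "number"}},
    "required": ["name", "price"],
}
payload = {
    "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/",
    "messages": [{"role": "user", "content": "Describe one grocery item as JSON."}],
    "response_format": {"type": "json_object", "schema": json.dumps(item_schema)},
    "stream": False,
}
r = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
# The returned content is expected to parse as JSON conforming to item_schema.
print(r.json()["choices"][0]["message"]["content"])
```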
* Update huggingface_loader.py

* [Serve] Separate callback invocation to another thread in AsyncEngine (#2046) This PR enhances the AsyncThreadEngine by separating the callback invocation to another thread, in order to reduce the CPU time overhead of invoking Python callbacks.

* [LLaVa] Fix random token output after first sentence (#2048) Fix Llava random token after first '.' token Co-authored-by: Animesh Bohara

* Auto updated submodule references

* [Pass] Fix LiftGlobalBufferAlloc for proper GlobalVar struct info (#2053) This PR fixes the GlobalVar struct info mismatch issue caused by the LiftGlobalBufferAlloc pass after a recent TVM commit.

* Auto updated submodule references

* [Serving] CLI Support for SERVE (#2014) This PR adds CLI support for serve. Usage: `mlc_llm serve [Model]`; refer to `mlc_llm serve -h` for more options. Notes: - Supports JIT compilation of the model lib - Added a context manager to the `ServerContext` class. Co-authored-by: Ruihang Lai Co-authored-by: Shrey Gupta

* [Pipeline] Insert hints to enable cuda graph symbolic capture (#2050) * [Pipeline] Add pass to insert hints to enable cuda graph symbolic capture

* [Loader] Print message when multi-GPU loader is finished (#2051) * [Loader] Print message when multi-GPU loader is finished * Update multi_gpu_loader.cc * fix

* [KVCache] Support matching arbitrary element offset for aux data (#2057) This PR enhances the TIR attention-related functions to support matching arbitrary element offsets. This makes room for the KV cache to allocate one large array for all the auxiliary data and slice into it. This PR should affect nothing in the current codebase, given all the element offsets are zero as of now.

* [Serving] Support copy stream in LogitProcessor and GPUSampler (#2058) This PR introduces a copy stream to LogitProcessor and GPUSampler for CUDA, so that auxiliary data can be copied on a separate stream and overlapped with computation.

* [SLM] Stablelm Multi-GPU support (#2052) This PR enables the TP function of the StableLM model.

* [KVCache] Introducing single page copy func for KV cache fork (#2060) This PR introduces the single page copy TIR function for the KV cache. This function is helpful for forking sequences at specified positions. NOTE: this PR is a breaking change, so you will need to re-compile your model and update TVM or the MLC-AI pip package to the latest. Related PR: apache/tvm#16813 Co-authored-by: Yaxing Cai

* [Python] Implement testing.DebugChat for end-to-end model debugging (#2056)

* [Docs] Fix docs for python server and rest call (#2066) This PR updates the MLC serve documentation for server launching.

* [CI] Enable submodule clone for WASM model compilation (#2068) The incoming WASM runtime requires 3rdparty for builds. This PR enables the submodule clone for WASM model compilation in CI.

* [Serve] Fork sequence at specified positions (#2067) With PagedKVCache supporting fork at a specified position, this PR updates the `Model` interface accordingly. The fork position defaults to -1, which means the last position.

* [SLM] Add support for RWKV6 model (#1977) * [SLM]: Support for rwkv tokenizer * [SLM] RWKV6 World Support

* [Quantization] Reorganize utils code in group_quantization (#2055)

* [Serving] Bugfix for empty stop string (#2070) add check for empty stop string; fix Vanilla LM conversation template

* [SLM] Internlm Multi-GPU support (#2072) This PR enables tensor parallelism support for the InternLM model.
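The allocation pattern that the aux-data change (#2057) above makes room for can be sketched conceptually. This is an illustrative NumPy sketch only, not the KV cache implementation: one large backing buffer holds all auxiliary data, and each logical array is a view at an element offset rather than a separate allocation.

```python
# Conceptual sketch of the single-backing-buffer layout enabled by #2057.
# Names, offsets, and sizes are illustrative.
import numpy as np

backing = np.empty(1 << 20, dtype=np.int32)  # one large array for all aux data

def carve(offset, length):
    """Return a zero-copy view into the backing buffer at a given element offset."""
    return backing[offset:offset + length]

page_table      = carve(0,    4096)
seq_lengths     = carve(4096, 256)
position_deltas = carve(4352, 256)
# The TIR attention functions must then match buffers with arbitrary element offsets,
# which is exactly the capability the PR adds.
```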
* [WebGPU] Add mlc wasm runtime, support grammar in web (#2061) * [WebGPU] Add mlc wasm runtime, support grammar in web * Make in web for wasm ci * Fix wasm ci * Fix wasm ci * Change export library arg name * Move macro to cc instead of makefile

* [Build] Use TVM_HOME environment variable (#2073) Prior to this commit, the `CMakeLists.txt` file checked a cmake `TVM_HOME` variable, but did not check the usual `TVM_HOME` environment variable. If this variable is set, it should be used.

* [Serving] Support input chunking (#2069) This PR supports input chunking with regard to the customized "prefill chunk size" (field `prefill_chunk_size` in `mlc-chat-config.json`). With this PR, we can now chunk a long input into multiple pieces when there is an upper limit on the prefill chunk size. Only `TokenData` is supported for now.

* [Docs] API Code Completion Guide (#2054)

* Allow "mlc_llm --host" option to override host triple the model compi… (#2074) Allow the "mlc_llm --host" option to override the host triple the model is compiled to.

* [Web] Move prep emcc deps script to web folder (#2077)

* [SLM] Qwen Multi-GPU support (#2075)

* Fix mismatch of metadata func and global symbol (#2078) * Fix mismatch of metadata func and global symbol * Update estimate_memory_usage.py

* [Disco] Set worker CPU affinity with env variable (#2042) This PR enables setting the CPU affinity of disco workers in MLC, following the support in apache/tvm#16807. The purpose is to try to reduce the CPU core switching overhead for disco workers, which may cause extra bubble time before/during tasks. We use an environment variable `MLC_DISCO_WORKER_CPU_BINDING` to specify the CPU affinities of workers. It is not used by default. To enable it, you can run a command like

```shell
MLC_DISCO_WORKER_CPU_BINDING=64,65,66,67 python some_mlc_app.py
```

to specify the four CPU core ids for the four workers.

* [Quantization] Introduce PerTensor and F8 quantization (#2079) * [Quantization] Introduce PerTensor and F8 quantization * address comments

* [Serving][Refactor] Rename AsyncThreadedEngine to ThreadedEngine (#2081) This PR renames the AsyncThreadedEngine to ThreadedEngine to prepare for follow-up refactors of the Python interface. Meanwhile, this PR exposes a creation function for AsyncThreadedEngine so that it can be further used by others, such as JSONFFIEngine.

* [Serving] Add cuda profiling in benchmark test (#2084) * [Serving] Add cuda profiling in benchmark test

* [Grammar] Fix broken grammar tests (#2083) This PR fixes some grammar parser tests that were broken.

* [Serving][Fix] Fix chunked prefill condition (#2082) This PR fixes a bug when trying to chunk an input and do prefill. The stats prior to this PR were wrong.

* [Conversation] Fix RedPajama conversation template (#2087) As reported and discussed in #2086, this PR fixes the RedPajama template.

* [Serving][Refactor] Python interface refactor (#2085) This PR is an initial major Python interface refactor of MLC Serve. With this PR, `mlc_llm.serve` in Python now exposes two engine classes: `AsyncEngine` and `Engine`. Both classes have two entrypoints, `chat_completion` and `completion`, which conform to the OpenAI Python API (reference: https://github.com/openai/openai-python). As the names suggest, `AsyncEngine` works asynchronously, and `Engine` works synchronously. It is worth noting that the `Engine` introduced in this PR is different from the previous `Engine`. The new `Engine` does not provide interfaces for batch generation.
For robustness and correctness, the old `Engine` in Python is moved to `mlc_llm.serve.sync_engine.SyncEngine`. We do not directly expose this SyncEngine; it now mainly serves testing and debugging purposes. It is useful for checking the correctness of new features because of its simplicity. It keeps the low-level interface to directly invoke the engine's `step()` function, and also keeps the low-level batch generation interface. Our REST API entry points defined under `mlc_llm/serve/entrypoints/` are also refactored accordingly to adapt to the latest Python API in MLC Serve. In short, most of the logic in the OpenAI API entry points is moved to the Python API, which simplifies the implementation of the entry points. Please note that this is the first (also the largest) planned refactor. We will follow up with some other refactors, which have smaller scopes compared with this PR. The planned refactors include: * provide a submodule interface to align with the OpenAI Python package at https://github.com/openai/openai-python * refactor the constructor interface of `Engine`/`AsyncEngine` to align with the MLC serve CLI interface.

* [Serving] Separating ThreadedEngine creation and initialization (#2090) This PR separates the creation and initialization of ThreadedEngine for multi-threading use cases, so we can make sure that the ThreadedEngine instance is created before any other operations (such as initialization, running the background loop, etc.).

* [Serving] Enhance robustness with small KV capacity (#2091) This PR enhances the robustness, which had issues when the KV capacity is small.

* [REST] Update REST API docs (#2092) This updates the REST docs to use `mlc_llm serve` and also adds a quick start section.

* [DOCS] Clarify vulkan loader dependency (#2095) This PR clarifies the Vulkan loader dependency. Some systems may not have the right Vulkan loader, and we need to install it via conda.

* [SLM] Add support for Chatglm3 architecture (#2096) This PR enables the ChatGLM3 model.

* [Quantization] Add OpenCL device (#2097) This PR adds the OpenCL device for weight conversion.

* [Serving] Support stream=True for Python API (#2098) The previous refactoring PR formalizes the MLC serve Python API but does not respect the `stream` flag properly: no matter whether `stream` is True or False, the functions always work in a streaming style. This PR supports the non-streaming case.

* [Serving][Refactor] OpenAI API Python interface alignment (#2099) This PR aligns the Python API of chat completions and completions in MLC serve with the OpenAI Python package https://github.com/openai/openai-python. Specifically, once we create an engine or async engine, we can use the entry point `engine.chat.completions.create(...)` for chat completions. We will add more usage examples in the codebase after a few more refactors.

* [DOC] fix small python env install error (#2102) Fixed one slight issue of the TVM install: it requires specifying python=3.11 on the platform, otherwise a python-not-found error may occur.

* [JSONFFIEngine] Initial implementation of JSONFFIEngine (#2101) This PR introduces initial support for the JSONFFIEngine. The request is supposed to be a JSON string in the [Chat completion request body format](https://platform.openai.com/docs/api-reference/chat/create). The output (input to the callback function provided) is a list of JSON strings in the [Chat completion chunk object format](https://platform.openai.com/docs/api-reference/chat/streaming). There is still functionality to be added, which will come in follow-up PRs:
1. Support for other input datatypes (image, etc.)
2. Applying conversation template to input
3. Function calling and tools support
4. Generation config parameters support
5. Independent text streamers for each request
6. logprobs support
--- Co-authored-by: Ruihang Lai

* [Model] Use tanh approximation of GeLU in Gemma MLP (#2106) This is in line with the implementation in the [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L183) library, and also with the [gemma-1.1](https://huggingface.co/google/gemma-1.1-2b-it/blob/main/config.json#L10) model config.

* Auto updated submodule references

* [Quantization] Stricter checks for MoE gate (#2109) This PR strengthens the MoE gate checks to include checking the number of experts, given the real MoE gate router layer's output feature number is the number of experts and is usually very small. This PR comes from a regression where a layer in RWKV6 whose name ends with "gate" is not for MoE at all.

* Auto updated submodule references

* [LLaVa] Fix allowed text model value in config (#2062) * Llava support vicuna and mistral text models * Support f32 quantization * Lint fix * Use preset if transformers not installed * Rebase on main --------- Co-authored-by: Animesh Bohara

* Auto updated submodule references

* Revert "Allow "mlc_llm --host" option to override host triple the model compi…" (#2115) This reverts commit 12ca8fdbe2a24f43bbc72241a76735dbad8c2026. Co-authored-by: Mengshiun Yu

* Revert "Auto updated submodule references" (#2117) This reverts commit c4169d8c8a4afedd06bc9d9b99c3aa65eee4a89e, which broke CI.

* [Metadata] Include picojson rather than forward declaring (#2118) This PR fixes the picojson uses in MLC that conflict with the latest changes on the picojson side.

* Auto updated submodule references

* Auto updated submodule references

* [Serving][Grammar] Porting the json schema converter from python to C++ (#2112) [Serve][Grammar] Porting the json schema converter from python to C++ This PR ports the JSON schema converter from Python to C++. It defines the interface:

```
std::string JSONSchemaToEBNF(
    std::string schema,
    std::optional<int> indent = std::nullopt,
    std::optional<std::pair<std::string, std::string>> separators = std::nullopt,
    bool strict_mode = true);
```

And uses it in BNFGrammar::FromSchema. This helps cases where Python cannot be deployed.

* [Model] Use R.topk/cumsum for mixtral (#2107)

* Enable flashinfer when group_size == 6 (#2124)

* [SpecDecode] Support Eagle in speculative decoding (#2080) 1. Add Eagle-Llama-7b-chat model support. 2. Add speculative decoding support with Eagle.

* [Pass] Attach non-negative TIR var attributes (#2125) This PR attaches the attributes of `tir.non_negative_var` for memory planning.

* [Serving][Refactor] Engine constructor interface refactor (#2126) This PR is a refactor of the engine's constructor interface and the serve CLI interface. This PR introduces the "mode" argument for the engine, which has options "local", "interactive" and "server". The choice of mode affects the automatically inferred values of `max_batch_size`, `max_total_sequence_length` and `prefill_chunk_size` (only effective when the arguments are not specified; once an argument is specified, we will not override it). For a detailed specification of the mode, please check out the CLI help messages in `mlc_llm/help.py` or the engine constructor in `mlc_llm/serve/engine.py`.
No matter which mode is chosen, we print out the current mode and the values of these arguments, for people to understand the settings of the engine. We also provide hints on how to adjust the mode. For example,

```
[2024-04-12 16:12:26] INFO chat_module.py:379: Using model folder: /home/ruihang/Workspace/mlc-llm/dist/Llama-2-7b-chat-hf-q0f16-MLC
[2024-04-12 16:12:26] INFO chat_module.py:380: Using mlc chat config: /home/ruihang/Workspace/mlc-llm/dist/Llama-2-7b-chat-hf-q0f16-MLC/mlc-chat-config.json
[2024-04-12 16:12:26] INFO chat_module.py:529: Using library model: dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so
[2024-04-12 16:12:26] INFO chat_module.py:379: Using model folder: /home/ruihang/Workspace/mlc-llm/dist/Llama-2-7b-chat-hf-q4f16_1-MLC
[2024-04-12 16:12:26] INFO chat_module.py:380: Using mlc chat config: /home/ruihang/Workspace/mlc-llm/dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json
[2024-04-12 16:12:26] INFO chat_module.py:529: Using library model: dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so
[2024-04-12 16:12:29] INFO engine_base.py:382: Engine mode is "local". Max batch size is set to 4. Max KV cache token capacity is set to 4096. Prefill chunk size is set to 4096.
[2024-04-12 16:12:29] INFO engine_base.py:387: Estimated total single GPU memory usage: 21543.74 MB (Parameters: 16467.64 MB. KVCache: 4450.07 MB. Temporary buffer: 626.03 MB). The actual usage might be slightly larger than the estimated number.
[2024-04-12 16:12:29] INFO engine_base.py:398: Please switch to mode "server" if you want to use more GPU memory and support more concurrent requests.
```

After the refactor, we bring speculative decoding to the serve CLI so that people can use multiple models and run speculative decoding with the server launched in the CLI (which was not doable before).

* [Serving] Revamp engine mode selection logging info (#2128) This PR revamps the logging info for engine mode selection to provide more detailed information and the rationale for the different modes.

* [SLM] Chatglm3 Multi-GPU support (#2123) This PR enables TP for the ChatGLM3 model.

* [Serving] Fix support of large `n` under low max batch size (#2136) Prior to this PR, due to the improper prefill policy on `n` (parallel generation), the engine would loop forever when a request has `n` larger than the maximum batch size the engine can support. This PR fixes the issue by updating the prefill action, and with this PR even the "interactive" engine mode can support multiple parallel generations well. After this fix, it is possible that a request requires 10 parallel generations while the max batch size is 1. Given the shapes of temporary NDArrays in the GPU sampler are determined by the max batch size, the GPU sampler does not natively support sampling 10 tokens at a time. To address this issue, this PR introduces chunking to the GPU sampler. Therefore, in this particular case, the GPU sampler will have chunk size 1, and the 10 required samples will be processed by the GPU sampler one by one in order. Chunking is the minimal change we can make to support large `n`.

* [Docs] Revamp landing page with Engine Python API and server (#2137) This PR revamps the landing documentation page. * The Python API panel is changed from showing ChatModule to showing Engine. * A new panel "REST Server" is added to show a quick start example of launching the REST server and sending a request. * A "what to do next" section is introduced at the bottom of the landing page.
Todo items for a future PR: * add the Python API page with Engine. * revamp the weight conversion page. * revamp the model library compilation page.

* [Target] Update Target tags (#2141) The commit updates the target tags in order to identify the different SoC hardware targets for further target-specific optimizations. Meanwhile, it updates the Vulkan support for int64.

* [Util] Support debug debug_compare (#2142)

* [Minor][SpecInfer] Fix Optional FC Bias for Mixtral Eagle Model (#2146) * Add optional fc bias for mixtral. * Fix lint.

* [Serving] fix hardcoded host and port in popen_server (#2147)

* [Docs] Introductory tutorial (#2145) This PR updates the documentation with an introductory tutorial. The landing page now directs to the quick start page and the tutorial.

* [Serving] Support `DebugCallFuncOnAllAllWorker` and CUDA profiler (#2148) This PR adds a new function `DebugCallFuncOnAllAllWorker` which calls a global function of signature `[] -> None` on all distributed workers when tensor parallelism is enabled (or on the local session itself if not enabled). As the name suggests, this function is only for debug purposes, and we will not expose any public interface to invoke it. This PR also introduces the global functions `"mlc.debug_cuda_profiler_start"` and `"mlc.debug_cuda_profiler_stop"`, which enable CUDA profiling when using PopenServer.

* [DOCS] Update introduction (#2151) * [DOCS] Update introduction Some minor tweaks on the introduction doc * Update docs/get_started/introduction.rst Co-authored-by: Ruihang Lai --------- Co-authored-by: Ruihang Lai

* [Serving][Python] Rename Engine to LLMEngine (#2152) We rename the public Python serve interface from `Engine` to `LLMEngine` (and from `AsyncEngine` to `AsyncLLMEngine` accordingly) for better class name clarity. This is because in cases where people do wildcard imports, the name `Engine` itself does not convey enough meaning.

* Auto updated submodule references

* [Quantization] Add e4m3 mode and enable fp8 storage type (#2154) * [Quantization] Add e4m3 mode and enable fp8 storage type * add quantize linear flag

* Revert "[Quantization] Add e4m3 mode and enable fp8 storage type" (#2158) Revert "[Quantization] Add e4m3 mode and enable fp8 storage type (#2154)" This reverts commit e9a4a0bf719a7c4fd42b438cf9e159a1e8d72590.

* [Serving] EngineConfig refactor (#2159) This PR refactors EngineConfig for a cleaner interface of the internal Engine constructor in MLC serve. This is a preparation step towards engine reload/unload, which will be introduced in follow-up PRs for JSONFFIEngine functionality on mobile and other platforms.

* [Llama3] Support Llama 3 (#2163) * Add conv template and model preset * Fix conv template * Trivial

* [Fix] Fix llama 3 conv template (#2164) Fix llama 3 conv template

* Auto updated submodule references

* [Serving][HotFix] No `std::move()` for disco CallPacked (#2166) The disco `CallPacked` function cannot handle `std::move()` very well. A previous engine refactor PR introduced a regression that broke our tensor parallelism support. This commit fixes the issue.

* [Docs] Update example for Llama3 (#2169) This PR updates the huggingface repo examples to use Llama3.

* [README] Fix broken link to Python API (#2168)

* [Docs] Update README (#2170) This PR updates the README for Llama3 quick start examples.

* [Docs] Documentation of LLMEngine in Python API (#2172) This PR completes the documentation page of LLMEngine and AsyncLLMEngine in our Python API.
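A hedged usage sketch of the OpenAI-aligned Python API referenced in the refactors above: the class is named `LLMEngine` at this point (later renamed `MLCEngine`), and chat completions follow the `engine.chat.completions.create(...)` entry point quoted earlier. The constructor argument, model path, and the `terminate()` call shown here are assumptions for illustration, not a verified API surface.

```python
# Hedged sketch of the LLMEngine Python API (later renamed MLCEngine).
# The constructor argument, model path, and terminate() are illustrative assumptions.
from mlc_llm.serve import LLMEngine

engine = LLMEngine("./dist/Llama-3-8B-Instruct-q4f16_1-MLC")

# Chat completions follow the openai-python style: engine.chat.completions.create(...)
response = engine.chat.completions.create(
    messages=[{"role": "user", "content": "Write a haiku about paged KV caches."}],
    stream=False,
)
print(response.choices[0].message.content)

engine.terminate()  # assumed cleanup hook, shown only to indicate explicit shutdown
```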
* [Docs] Update project website (#2175) This PR mainly updates the project website, and also updates some minor points in other docs.

* [Docs][Fix] Update index.md for jekyll failure (#2176) This PR fixes the Jekyll failure of the project website by removing the citation section (having it in the README is sufficient).

* [Quantization] Add e4m3 mode and enable fp8 storage type (reland #2154) (#2161) * [Quantization] Add e4m3 mode and enable fp8 storage type * add quantize linear flag

* [Docs] Fix API reference not displayed (#2177) This PR fixes the issue of the API reference not being displayed in the documentation.

* [Docs] Update project website (#2180) This PR updates the project landing website to remove some information.

* [Misc] Pass env along when calling `subprocess.run` (#2179) The uses of `subprocess.run` in the codebase did not pass the environment along, which may cause issues in some cases.

* Change OpenAI protocol default value to None and supply using model config (#2178) * Change OpenAI protocol default value to None and supply using model config * Fix lint

* [Serving][Spec] Fix the output inconsistent bug of q0f32 spec decoding (#2184) - According to https://github.com/mlc-ai/mlc-llm/issues/2167, the problem that the output of spec decoding in q0f32 is inconsistent with the single q0f32 model has been fixed. - Modified the test_engine_generate function located in `tests/python/serve/test_serve_engine_spec.py` to support comparing the output of a single model and the output of spec decoding. - The accuracy comparison with Hugging Face is left for later (because the current q0f32 version of llama-2-7b cannot match the output of the Hugging Face model). - The output of spec decoding for q0f16 cannot match the output of a single q0f16 model, but this may be due to floating point errors. Co-authored-by: DearFishi

* [Serving] Support ThreadedEngine Reload/Unload/Reset (#2185) This PR brings support for reload (reload the engine with a new model), unload (unload the currently running model) and reset (reset the engine to the initial state without unloading) to ThreadedEngine and JSONFFIEngine. These functions are useful for app bindings for iOS/Android.

* [WASM] Support grammar schema in wasm (#2187)

* [Serving] Support loading system library (#2189) This PR introduces support for loading system libraries. Now in engine reload, when the given library path starts with `"system://"`, we recognize it as a system library and will try to load the library from the path after the `"system://"` prefix. This PR also decouples the InitBackgroundEngine of ThreadedEngine into two parts, where the reload is now called explicitly when initializing the engine. This can also be done for the JSONFFIEngine; however, we need to move the construction of streamers in JSONFFIEngine before doing the same thing there, so this is marked as a TODO item.

* [Op] Batch verify for speculative decoding (#2186) This PR adds batch verification for spec decoding. ---- Co-authored-by: Wuwei Lin

* [JIT] Better organize JIT and AOT handling (#2191) * [JIT] Better organize JIT and AOT handling Previously we did JIT when the AOT lib lookup failed. The error message can become cryptic when JIT also fails; it shows up as "cannot find None-vulkan.dll", which is not informative. This PR changes the behavior to only look up when model_lib_path is provided, or to only JIT when it is not. This leads to a cleaner error message overall.
* Windows compact * More windows instructions * Fix prefill and context flag names in doc (#2192) * Update compile_models.rst Fix flag names for prefill chunk size and context window size. * Update compile_models.rst * [Docs] Update quick start to mention Llama 3 8B (#2196) This commit updates the quick start to mention Llama 3 8B instead of Llama 2 7B. The code blocks were already updated. * [SERVING] Add Conv Template and Function Calling support to JSON FFI (#2190) This PR adds conv template support to the JSON FFI Engine. It also adds function calling and passes stop strings to the generation config. Co-authored-by: Shrey Gupta * [Serving] Paged Radix Tree for Prefix Caching (#2183) This PR introduces the Paged Radix Tree data structure, as the foundation and prerequisite of prefix caching. * [Serving] Remove mandatory model check in server (#2195) This PR removes the mandatory model check in the server since, as of now, we serve at most one engine, which means there is always a unique engine being served. As issue #2155 points out, the model check in the server can be a bad experience when the model string mismatches. * [Sampler] Enable GPU sampler for draft verification (#2198) * [Eagle] Attach gpu verifier to model * WIP * WIP * fix * Enable GPU verifier * lint * lint * [Eagle] Make eagle disco compatible (#2197) * [Eagle] Make BatchSelectLastHidden able to run on the controller * [Serving][Spec] Fix normal mode verification for extra draft token (#2206) This PR updates the draft verification of the normal-mode speculative decoding. Prior to this PR, we did not effectively leverage all the draft tokens, and this PR fixes the issue. * [Sampler] Prob renormalization with top p for spec decoding (#2201) This PR introduces a renormalization interface with regard to top-p values for speculative decoding. This helps simplify the logic of the speculative decoding verification stage, as all probs have already been updated with the top-p values and no top-p needs to be taken into consideration. So for speculative decoding, we always renormalize the probability distribution before sampling/verifying. For the non-speculative decoding mode, we keep using the previous flow, which applies top-p together with sampling. Co-authored-by: Wuwei Lin * [Python] Rename LLMEngine to MLCEngine (#2210) This commit renames LLMEngine to MLCEngine. * [Fix] CUDA architecture detection bug fix (#2211) This commit returns a list of integers and adds an assert to check that the CUDA architecture string contains numbers only. Co-authored-by: msyu * [Android ] Enable OpenCL host pointer usage (#2215) Take advantage of the OpenCL host pointer, which improves copy performance. * [PYTHON][KVCACHE] Enhance the thread limit for opencl (#2216) It gives a 2x improvement for TIR-based paged attention on OpenCL Adreno. * [Serving] Support RWKV for serving (#2111) feat: support serving for rwkv * [Serving] Remove `cli.model_metadata` import from engine base (#2226) This PR removes the imports of functions in `cli.model_metadata` from engine_base.py. The file `cli.model_metadata` is not designed to be imported directly, and when importing functions from the file, it repeatedly reports warnings of ``` RuntimeWarning: 'mlc_llm.cli.model_metadata' found in sys.modules after import of package 'mlc_llm.cli', but prior to execution of 'mlc_llm.cli.model_metadata'; this may result in unpredictable behaviour ``` * [JSONFFIEngine] Support generation config in JSONFFIEngine.
Default config values to NOT_GIVEN (#2225) * Change OpenAI protocol default value to None in JSON FFI engine * [JSONFFIEngine] Support generation config in JSONFFIEngine. Default config values to NOT_GIVEN * [Sampler] Fix GPU sampler behavior when batch size is 0 (#2234) This PR adds an early exit for the GPU sampler, which, prior to this commit, ran GPU kernels even when the batch size is 0. The zero-batch-size case can happen when parallel generation of a request coexists with engine preemption. In this case, the GPU sampler should just synchronize and return, without launching any GPU kernel. * [Pass] Support two-stage softmax (#2220) This PR introduces the compiler pass that rewrites the normal softmax to a two-stage softmax. This is based on our finding that when the vocabulary size is large, the normal softmax cannot have high enough parallelism on the GPU. So we partition the workload into two stages for better parallelism and better performance (see the sketch after this change log). * Auto updated submodule references * [Docs] Update deploy/ios#bring-your-own-model-library (#2235) remove model metadata step (#1) * remove model metadata step and make minor fixes * [Op] Top-p cutoff pivot (#2221) This commit introduces the GPU top-p cutoff operator for efficient probability renormalization under top-p. * [Op] Batch Verify: accept proposal when p and q are close enough (#2236) * dev * dev * [Serving] Creating EngineConfig from JSON (#2237) This PR supports creating EngineConfig from a JSON string, which is useful for JSONFFIEngine and its API bindings. This commit also removes the device from the EngineConfig for better clarity. * [Bugfix] layer_norm_eps in GPT2Config should be float (#2240) * [REFACTOR] Migrate JSONFFIEngine to formal namespace (#2241) This PR migrates JSONFFIEngine to a formal namespace. It also lists TODOs to further simplify the JSONFFIEngine. * [Serving] Share disco sessions among multiple model function tables (#2242) * [DOC] Improve Install via environment variable (#2245) improve Install via environment variable * [Sampler] FlashInfer sampling func integration (#2224) This PR integrates the sampling function from FlashInfer. We integrate the one without top-p for now.
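As a reference for the two-stage softmax rewrite in #2220 above, here is a minimal NumPy sketch of the underlying math. It is a plain CPU stand-in, not the TIR pass itself; the chunk count is arbitrary and stands in for the number of GPU thread blocks.

```python
import numpy as np


def two_stage_softmax(logits: np.ndarray, num_chunks: int = 8) -> np.ndarray:
    """Numerically stable softmax computed in two passes over vocabulary chunks.

    Stage 1 processes each chunk independently (on a GPU, one thread block per
    chunk), producing a per-chunk max and exp-sum. Stage 2 combines the partial
    results into the global max and normalizer.
    """
    chunks = np.array_split(logits, num_chunks)
    # Stage 1: per-chunk max and sum of exponentials shifted by the chunk max.
    chunk_max = np.array([c.max() for c in chunks])
    chunk_sum = np.array([np.exp(c - m).sum() for c, m in zip(chunks, chunk_max)])
    # Stage 2: combine the partial results and normalize.
    global_max = chunk_max.max()
    denom = (chunk_sum * np.exp(chunk_max - global_max)).sum()
    return np.exp(logits - global_max) / denom


probs = two_stage_softmax(np.random.randn(128_000).astype(np.float32))
assert np.isclose(probs.sum(), 1.0, atol=1e-3)
```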
* Model Library Delivery (#2139) * add model lib delivery * fix lint * fixed --------- Co-authored-by: Yixin Dong Co-authored-by: Ruihang Lai Co-authored-by: Git bot Co-authored-by: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Co-authored-by: Eric Lunderberg Co-authored-by: Shushi Hong <820958424@qq.com> Co-authored-by: Egor Churaev Co-authored-by: Siyuan Feng Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Co-authored-by: Tianqi Chen Co-authored-by: Kartik Khandelwal Co-authored-by: Shrey Gupta Co-authored-by: Diego Cao <50705298+DiegoCao@users.noreply.github.com> Co-authored-by: David Pissarra <61968959+davidpissarra@users.noreply.github.com> Co-authored-by: Wuwei Lin Co-authored-by: Ricardo Lu <37237570+gesanqiu@users.noreply.github.com> Co-authored-by: Hongyi Jin Co-authored-by: Bohan Hou Co-authored-by: tqchen Co-authored-by: Rick Zhou Co-authored-by: Animesh Bohara Co-authored-by: Yong Wu Co-authored-by: Yong Wu Co-authored-by: Shrey Gupta <51860471+shreygupta2809@users.noreply.github.com> Co-authored-by: Yaxing Cai Co-authored-by: ZCHNO Co-authored-by: Andrew Co-authored-by: na20215 <78482004+na20215@users.noreply.github.com> Co-authored-by: Animesh Bohara Co-authored-by: Yogesh Garg Co-authored-by: Linyu Wu <95223577+Celve@users.noreply.github.com> Co-authored-by: Yu Xuanchi Co-authored-by: Mengshiun Yu Co-authored-by: Jeethu Rao Co-authored-by: Xiyou Zhou Co-authored-by: Simon Willison Co-authored-by: DearFishi <89983913+DearFishi@users.noreply.github.com> Co-authored-by: DearFishi Co-authored-by: Oleh Shliazhko Co-authored-by: Ewout ter Hoeven Co-authored-by: msyu Co-authored-by: Siva Co-authored-by: krishnaraj36 Co-authored-by: Kimura (Yamakado) Nobuhiro <37305503+nobuhiroYamakado@users.noreply.github.com> Co-authored-by: Wei Tao <1136862851@qq.com> --- README.md | 180 +- android/library/prepare_libs.sh | 1 + cpp/json_ffi/config.cc | 357 +++ cpp/json_ffi/config.h | 172 ++ cpp/json_ffi/json_ffi_engine.cc | 80 +- cpp/json_ffi/json_ffi_engine.h | 3 + cpp/json_ffi/openai_api_protocol.cc | 278 +- cpp/json_ffi/openai_api_protocol.h | 47 +- cpp/metadata/json_parser.h | 16 + cpp/serve/config.cc | 102 +- cpp/serve/config.h | 35 +- cpp/serve/engine.cc | 136 +- cpp/serve/engine.h | 6 +- cpp/serve/engine_actions/action_commons.h | 2 +- cpp/serve/engine_actions/batch_decode.cc | 2 +- cpp/serve/engine_actions/batch_draft.cc | 6 +- cpp/serve/engine_actions/batch_verify.cc | 60 +- cpp/serve/engine_actions/eagle_batch_draft.cc | 6 +- .../engine_actions/eagle_batch_verify.cc | 19 +- .../eagle_new_request_prefill.cc | 8 +- .../engine_actions/new_request_prefill.cc | 9 +- cpp/serve/event_trace_recorder.h | 2 +- cpp/serve/function_table.cc | 56 +- cpp/serve/function_table.h | 5 +- cpp/serve/grammar/grammar_serializer.h | 2 +- cpp/serve/grammar/grammar_state_matcher.cc | 5 +- cpp/serve/grammar/json_schema_converter.cc | 8 + cpp/serve/model.cc | 151 +- cpp/serve/model.h | 21 +- cpp/serve/radix_tree.cc | 718 ++++++ cpp/serve/radix_tree.h | 110 + cpp/serve/sampler/cpu_sampler.cc | 281 +- cpp/serve/sampler/gpu_sampler.cc | 343 ++- cpp/serve/sampler/sampler.h | 58 +- cpp/serve/threaded_engine.cc | 121 +- cpp/serve/threaded_engine.h | 17 +- cpp/support/utils.h | 17 + docs/compilation/compile_models.rst | 6 +- docs/deploy/cli.rst | 20 +- docs/deploy/ios.rst | 32 +- docs/deploy/python_engine.rst | 261 +- docs/deploy/rest.rst | 205 +- docs/get_started/introduction.rst | 51 +- docs/get_started/quick_start.rst | 18 +- docs/index.rst | 1 - docs/install/mlc_llm.rst | 11 +- 
docs/install/tvm.rst | 9 +- docs/prebuilt_models.rst | 2 +- docs/requirements.txt | 4 + examples/python/sample_mlc_engine.py | 6 +- python/mlc_llm/chat_module.py | 8 +- python/mlc_llm/cli/delivery.py | 10 +- python/mlc_llm/cli/lib_delivery.py | 200 ++ python/mlc_llm/cli/model_metadata.py | 4 +- python/mlc_llm/cli/serve.py | 4 + .../mlc_llm/compiler_pass/attach_sampler.py | 109 +- .../compiler_pass/estimate_memory_usage.py | 2 + python/mlc_llm/compiler_pass/pipeline.py | 2 + .../mlc_llm/compiler_pass/rewrite_softmax.py | 190 ++ python/mlc_llm/conversation_template.py | 23 +- python/mlc_llm/help.py | 9 +- python/mlc_llm/interface/compiler_flags.py | 1 - python/mlc_llm/interface/convert_weight.py | 3 +- python/mlc_llm/interface/gen_config.py | 1 + python/mlc_llm/interface/jit.py | 6 +- python/mlc_llm/interface/serve.py | 4 +- python/mlc_llm/json_ffi/__init__.py | 8 + python/mlc_llm/json_ffi/engine.py | 310 +++ python/mlc_llm/model/gpt2/gpt2_model.py | 2 +- python/mlc_llm/model/llama/llama_model.py | 133 +- python/mlc_llm/model/model.py | 2 +- python/mlc_llm/model/model_preset.py | 50 + python/mlc_llm/model/rwkv5/rwkv5_model.py | 70 +- python/mlc_llm/model/rwkv6/rwkv6_model.py | 68 +- python/mlc_llm/nn/kv_cache.py | 2 +- python/mlc_llm/op/__init__.py | 3 + python/mlc_llm/op/batch_spec_verify.py | 177 ++ python/mlc_llm/op/moe_matmul.py | 3 +- python/mlc_llm/op/top_p_pivot.py | 315 +++ python/mlc_llm/protocol/protocol_utils.py | 3 +- python/mlc_llm/serve/__init__.py | 7 +- python/mlc_llm/serve/config.py | 134 +- python/mlc_llm/serve/engine.py | 2280 +++++++++++++---- python/mlc_llm/serve/engine_base.py | 391 ++- .../serve/entrypoints/debug_entrypoints.py | 41 +- .../serve/entrypoints/openai_entrypoints.py | 566 +--- python/mlc_llm/serve/event_trace_recorder.py | 2 +- python/mlc_llm/serve/grammar.py | 2 +- python/mlc_llm/serve/radix_tree.py | 150 ++ python/mlc_llm/serve/server/server_context.py | 13 +- python/mlc_llm/serve/sync_engine.py | 10 +- python/mlc_llm/support/auto_config.py | 2 +- python/mlc_llm/support/auto_device.py | 3 + python/mlc_llm/support/auto_target.py | 10 +- python/mlc_llm/support/download.py | 4 +- python/mlc_llm/support/max_thread_check.py | 2 +- python/mlc_llm/testing/debug_chat.py | 2 +- scripts/build_mlc_for_docs.sh | 8 + scripts/build_site.sh | 1 + scripts/gh_deploy_site.sh | 1 + site/index.md | 59 +- tests/python/json_ffi/test_json_ffi_engine.py | 322 +-- tests/python/op/test_batch_spec_verify.py | 160 ++ tests/python/op/test_top_p_pivot.py | 83 + tests/python/op/test_two_stage_softmax.py | 47 + tests/python/serve/evaluate_engine.py | 4 +- tests/python/serve/server/test_server.py | 63 - tests/python/serve/test_radix_tree.py | 79 + tests/python/serve/test_serve_async_engine.py | 14 +- .../serve/test_serve_async_engine_spec.py | 20 +- tests/python/serve/test_serve_engine.py | 108 +- .../python/serve/test_serve_engine_grammar.py | 12 +- tests/python/serve/test_serve_engine_image.py | 4 +- tests/python/serve/test_serve_engine_spec.py | 71 +- tests/python/serve/test_serve_sync_engine.py | 12 +- web/emcc/mlc_wasm_runtime.cc | 3 + 116 files changed, 8125 insertions(+), 2353 deletions(-) create mode 100644 cpp/json_ffi/config.cc create mode 100644 cpp/json_ffi/config.h create mode 100644 cpp/serve/radix_tree.cc create mode 100644 cpp/serve/radix_tree.h create mode 100644 python/mlc_llm/cli/lib_delivery.py create mode 100644 python/mlc_llm/compiler_pass/rewrite_softmax.py create mode 100644 python/mlc_llm/json_ffi/__init__.py create mode 100644 python/mlc_llm/json_ffi/engine.py 
create mode 100644 python/mlc_llm/op/batch_spec_verify.py create mode 100644 python/mlc_llm/op/top_p_pivot.py create mode 100644 python/mlc_llm/serve/radix_tree.py create mode 100755 scripts/build_mlc_for_docs.sh create mode 100644 tests/python/op/test_batch_spec_verify.py create mode 100644 tests/python/op/test_top_p_pivot.py create mode 100644 tests/python/op/test_two_stage_softmax.py create mode 100644 tests/python/serve/test_radix_tree.py diff --git a/README.md b/README.md index 9bea5ccc0e..88e3abd07d 100644 --- a/README.md +++ b/README.md @@ -50,98 +50,118 @@ -**Scalable.** MLC LLM scales universally on NVIDIA and AMD GPUs, cloud and gaming GPUs. Below -showcases our single batch decoding performance with prefilling = 1 and decoding = 256. +## Quick Start -Performance of 4-bit CodeLlama-34B and Llama2-70B on two NVIDIA RTX 4090 and two AMD Radeon 7900 XTX: -

- - -

+We introduce the quick start examples of the chat CLI, Python API and REST server here to help you use MLC LLM. +We use the 4-bit quantized 8B Llama-3 model for demonstration purposes. +The pre-quantized Llama-3 weights are available at https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC. +You can also try out the unquantized Llama-3 model by replacing `q4f16_1` with `q0f16` in the examples below. +Please visit our [documentation](https://llm.mlc.ai/docs/index.html) for a detailed quick start and introduction. -Scaling of fp16 and 4-bit CodeLlama-34 and Llama2-70B on A100-80G-PCIe and A10G-24G-PCIe, up to 8 GPUs: -

- -

+### Installation -## News +MLC LLM is available via [pip](https://llm.mlc.ai/docs/install/mlc_llm.html#install-mlc-packages). +It is always recommended to install it in an isolated conda virtual environment. -* [10/18/2023] [[Post]](https://blog.mlc.ai/2023/10/19/Scalable-Language-Model-Inference-on-Multiple-NVDIA-AMD-GPUs) Scalable multi-GPU support for CUDA and ROCm are official. -* [09/02/2023] Prebuilt ROCm 5.7 and CUDA 12.2 package is [available](https://llm.mlc.ai/docs/install/tvm.html#option-1-prebuilt-package). -* [08/25/2023] CodeLlama support is up. -* [08/14/2023] [[Post]](https://blog.mlc.ai/2023/08/09/GPU-Accelerated-LLM-on-Orange-Pi) Mali GPU support is up on Orange Pi. -* [08/09/2023] [[Post]](https://blog.mlc.ai/2023/08/09/Making-AMD-GPUs-competitive-for-LLM-inference) ROCm backend is mature to use. -* [08/02/2023] [Dockerfile](https://github.com/mlc-ai/llm-perf-bench/) is released for CUDA performance benchmarking. -* [07/19/2023] Support for Llama2-7B/13B/70B is up. -* [05/22/2023] [[Post]](https://blog.mlc.ai/2023/05/22/bringing-open-large-language-models-to-consumer-devices) RedPajama support is up. -* [05/08/2023] [[Post]](https://blog.mlc.ai/2023/05/08/bringing-hardware-accelerated-language-models-to-android-devices) MLC LLM is now available on Android. -* [05/01/2023] [[Post]](https://blog.mlc.ai/2023/05/01/bringing-accelerated-llm-to-consumer-hardware) MLC LLM is released with Metal, Vulkan and CUDA backends. -* [04/14/2023] [WebLLM](https://github.com/mlc-ai/web-llm) is released prior to MLC LLM with WebGPU and WebAssembly backend. +To verify the installation, activate your virtual environment, run -## Getting Started +```bash +python -c "import mlc_llm; print(mlc_llm.__path__)" +``` -Please visit our [documentation](https://llm.mlc.ai/docs/index.html#getting-started) for detailed instructions. +You are expected to see the installation path of MLC LLM Python package. -## Model Support +### Chat CLI -MLC LLM supports a wide range of model architectures and variants. We have the following prebuilts which you can -use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_models.html) to see the full list, and [Compile Models via MLC](https://llm.mlc.ai/docs/compilation/compile_models.html) to see how to use models not on this list. +We can try out the chat CLI in MLC LLM with 4-bit quantized 8B Llama-3 model. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
ArchitecturePrebuilt Model Variants
LlamaLlama-2, Code Llama, Vicuna, WizardLM, WizardMath, OpenOrca Platypus2, FlagAlpha Llama-2 Chinese, georgesung Llama-2 Uncensored
GPT-NeoXRedPajama
GPT-J
RWKVRWKV-raven
MiniGPT
GPTBigCodeWizardCoder
ChatGLM
StableLM
Mistral
Phi
+```bash +mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC +``` + +It may take 1-2 minutes for the first time running this command. +After waiting, this command launch a chat interface where you can enter your prompt and chat with the model. + +``` +You can use the following special commands: +/help print the special commands +/exit quit the cli +/stats print out the latest stats (token/sec) +/reset restart a fresh chat +/set [overrides] override settings in the generation config. For example, + `/set temperature=0.5;max_gen_len=100;stop=end,stop` + Note: Separate stop words in the `stop` option with commas (,). +Multi-line input: Use escape+enter to start a new line. + +user: What's the meaning of life +assistant: +What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life. + +The concept of the meaning of life has been debated and... +``` + +### Python API + +We can run the Llama-3 model with the chat completion Python API of MLC LLM. +You can save the code below into a Python file and run it. + +```python +from mlc_llm import MLCEngine + +# Create engine +model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" +engine = MLCEngine(model) + +# Run chat completion in OpenAI API. +for response in engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, +): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) +print("\n") + +engine.terminate() +``` + +**The Python API of `mlc_llm.MLCEngine` fully aligns with OpenAI API**. +You can use MLCEngine in the same way of using +[OpenAI's Python package](https://github.com/openai/openai-python?tab=readme-ov-file#usage) +for both synchronous and asynchronous generation. + +If you would like to do concurrent asynchronous generation, you can use `mlc_llm.AsyncMLCEngine` instead. + +### REST Server + +We can launch a REST server to serve the 4-bit quantized Llama-3 model for OpenAI chat completion requests. +The server has fully OpenAI API completeness. + +```bash +mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC +``` + +The server is hooked at `http://127.0.0.1:8000` by default, and you can use `--host` and `--port` +to set a different host and port. +When the server is ready (showing `INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)`), +we can open a new shell and send a cURL request via the following command: + +```bash +curl -X POST \ + -H "Content-Type: application/json" \ + -d '{ + "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC", + "messages": [ + {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} + ] + }' \ + http://127.0.0.1:8000/v1/chat/completions +``` ## Universal Deployment APIs MLC LLM provides multiple sets of APIs across platforms and environments. These include -* [Python API](https://llm.mlc.ai/docs/deploy/python.html) +* [Python API](https://llm.mlc.ai/docs/deploy/python_engine.html) * [OpenAI-compatible Rest-API](https://llm.mlc.ai/docs/deploy/rest.html) * [C++ API](https://llm.mlc.ai/docs/deploy/cli.html) * [JavaScript API](https://llm.mlc.ai/docs/deploy/javascript.html) and [Web LLM](https://github.com/mlc-ai/web-llm) @@ -165,7 +185,7 @@ The underlying techniques of MLC LLM include:
References (Click to expand) - + ```bibtex @inproceedings{tensorir, author = {Feng, Siyuan and Hou, Bohan and Jin, Hongyi and Lin, Wuwei and Shao, Junru and Lai, Ruihang and Ye, Zihao and Zheng, Lianmin and Yu, Cody Hao and Yu, Yong and Chen, Tianqi}, diff --git a/android/library/prepare_libs.sh b/android/library/prepare_libs.sh index a06e9f067d..c089927d09 100755 --- a/android/library/prepare_libs.sh +++ b/android/library/prepare_libs.sh @@ -27,6 +27,7 @@ cmake .. \ -DMLC_LLM_INSTALL_STATIC_LIB=ON \ -DCMAKE_SKIP_INSTALL_ALL_DEPENDENCY=ON \ -DUSE_OPENCL=ON \ + -DUSE_OPENCL_ENABLE_HOST_PTR=ON \ -DUSE_CUSTOM_LOGGING=ON \ cmake --build . --target tvm4j_runtime_packed --config release diff --git a/cpp/json_ffi/config.cc b/cpp/json_ffi/config.cc new file mode 100644 index 0000000000..8f5c0e1062 --- /dev/null +++ b/cpp/json_ffi/config.cc @@ -0,0 +1,357 @@ +#include "config.h" + +#include + +#include "../metadata/json_parser.h" + +namespace mlc { +namespace llm { +namespace json_ffi { + +using namespace mlc::llm; + +/****************** Model-defined generation config ******************/ + +TVM_REGISTER_OBJECT_TYPE(ModelDefinedGenerationConfigNode); + +ModelDefinedGenerationConfig::ModelDefinedGenerationConfig(double temperature, double top_p, + double frequency_penalty, + double presence_penalty) { + ObjectPtr n = make_object(); + n->temperature = temperature; + n->top_p = top_p; + n->frequency_penalty = frequency_penalty; + n->presence_penalty = presence_penalty; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("mlc.json_ffi.ModelDefinedGenerationConfig") + .set_body_typed([](double temperature, double top_p, double frequency_penalty, + double presence_penalty) { + return ModelDefinedGenerationConfig(temperature, top_p, frequency_penalty, presence_penalty); + }); + +/****************** Conversation template ******************/ + +std::map PLACEHOLDERS = { + {MessagePlaceholders::SYSTEM, "{system_message}"}, + {MessagePlaceholders::USER, "{user_message}"}, + {MessagePlaceholders::ASSISTANT, "{assistant_message}"}, + {MessagePlaceholders::TOOL, "{tool_message}"}, + {MessagePlaceholders::FUNCTION, "{function_string}"}}; + +MessagePlaceholders MessagePlaceholderFromString(const std::string& role) { + static const std::unordered_map enum_map = { + {"system", MessagePlaceholders::SYSTEM}, {"user", MessagePlaceholders::USER}, + {"assistant", MessagePlaceholders::ASSISTANT}, {"tool", MessagePlaceholders::TOOL}, + {"function", MessagePlaceholders::FUNCTION}, + }; + + return enum_map.at(role); +} + +Conversation::Conversation() + : role_templates({{"user", PLACEHOLDERS[MessagePlaceholders::USER]}, + {"assistant", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]}, + {"tool", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {} + +std::vector Conversation::CheckMessageSeps(std::vector& seps) { + if (seps.size() == 0 || seps.size() > 2) { + throw std::invalid_argument("seps should have size 1 or 2."); + } + return seps; +} + +std::optional> Conversation::AsPrompt(std::string* err) { + // Get the system message + std::string system_msg = system_template; + size_t pos = system_msg.find(PLACEHOLDERS[MessagePlaceholders::SYSTEM]); + if (pos != std::string::npos) { + system_msg.replace(pos, PLACEHOLDERS[MessagePlaceholders::SYSTEM].length(), + this->system_message); + } + + // Get the message strings + std::vector message_list; + std::vector separators = seps; + if (separators.size() == 1) { + separators.push_back(separators[0]); + } + + if (!system_msg.empty()) { + system_msg += separators[0]; + 
message_list.push_back(TextData(system_message)); + } + + for (int i = 0; i < messages.size(); i++) { + std::string role = messages[i].role; + std::optional>> content = + messages[i].content; + if (roles.find(role) == roles.end()) { + *err += "\nRole " + role + " is not supported. "; + return std::nullopt; + } + + std::string separator = separators[role == "assistant"]; // check assistant role + + // If content is empty, add the role and separator + // assistant's turn to generate text + if (!content.has_value()) { + message_list.push_back(TextData(roles[role] + role_empty_sep)); + continue; + } + + std::string message = ""; + std::string role_prefix = ""; + // Do not append role prefix if this is the first message and there + // is already a system message + if (add_role_after_system_message || system_msg.empty() || i != 0) { + role_prefix = roles[role] + role_content_sep; + } + + message += role_prefix; + + for (auto& item : content.value()) { + if (item.find("type") == item.end()) { + *err += "Content item should have a type field"; + return std::nullopt; + } + if (item["type"] == "text") { + if (item.find("text") == item.end()) { + *err += "Content item should have a text field"; + return std::nullopt; + } + // replace placeholder[ROLE] with input message from role + std::string role_text = role_templates[role]; + std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)]; + size_t pos = role_text.find(placeholder); + if (pos != std::string::npos) { + role_text.replace(pos, placeholder.length(), item["text"]); + } + if (use_function_calling.has_value() && use_function_calling.value()) { + // replace placeholder[FUNCTION] with function_string + // this assumes function calling is used for a single request scenario only + if (!function_string.has_value()) { + *err += "Function string is required for function calling"; + return std::nullopt; + } + pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]); + if (pos != std::string::npos) { + role_text.replace(pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length(), + function_string.value()); + } + } + message += role_text; + } else { + *err += "Unsupported content type: " + item["type"]; + return std::nullopt; + } + } + + message += separator; + message_list.push_back(TextData(message)); + } + + return message_list; +} + +std::optional Conversation::FromJSON(const picojson::object& json, std::string* err) { + Conversation conv; + + // name + std::string name; + if (json::ParseJSONField(json, "name", name, err, false)) { + conv.name = name; + } + + std::string system_template; + if (!json::ParseJSONField(json, "system_template", system_template, err, true)) { + return std::nullopt; + } + conv.system_template = system_template; + + std::string system_message; + if (!json::ParseJSONField(json, "system_message", system_message, err, true)) { + return std::nullopt; + } + conv.system_message = system_message; + + picojson::array system_prefix_token_ids_arr; + if (json::ParseJSONField(json, "system_prefix_token_ids", system_prefix_token_ids_arr, err, + false)) { + std::vector system_prefix_token_ids; + for (const auto& token_id : system_prefix_token_ids_arr) { + if (!token_id.is()) { + *err += "system_prefix_token_ids should be an array of integers."; + return std::nullopt; + } + system_prefix_token_ids.push_back(token_id.get()); + } + conv.system_prefix_token_ids = system_prefix_token_ids; + } + + bool add_role_after_system_message; + if (!json::ParseJSONField(json, "add_role_after_system_message", 
add_role_after_system_message, + err, true)) { + return std::nullopt; + } + conv.add_role_after_system_message = add_role_after_system_message; + + picojson::object roles_object; + if (!json::ParseJSONField(json, "roles", roles_object, err, true)) { + return std::nullopt; + } + std::unordered_map roles; + for (const auto& role : roles_object) { + if (!role.second.is()) { + *err += "roles should be a map of string to string."; + return std::nullopt; + } + roles[role.first] = role.second.get(); + } + conv.roles = roles; + + picojson::object role_templates_object; + if (json::ParseJSONField(json, "role_templates", role_templates_object, err, false)) { + for (const auto& role : role_templates_object) { + if (!role.second.is()) { + *err += "role_templates should be a map of string to string."; + return std::nullopt; + } + conv.role_templates[role.first] = role.second.get(); + } + } + + picojson::array messages_arr; + if (!json::ParseJSONField(json, "messages", messages_arr, err, true)) { + return std::nullopt; + } + std::vector messages; + for (const auto& message : messages_arr) { + if (!message.is()) { + *err += "messages should be an array of objects."; + return std::nullopt; + } + picojson::object message_obj = message.get(); + std::string role; + if (!json::ParseJSONField(message_obj, "role", role, err, true)) { + *err += "role field is required in messages."; + return std::nullopt; + } + picojson::array content_arr; + std::vector> content; + if (json::ParseJSONField(message_obj, "content", content_arr, err, false)) { + for (const auto& item : content_arr) { + if (!item.is()) { + *err += "Content item is not an object"; + return std::nullopt; + } + std::unordered_map item_map; + picojson::object item_obj = item.get(); + for (picojson::value::object::const_iterator i = item_obj.begin(); i != item_obj.end(); + ++i) { + item_map[i->first] = i->second.to_str(); + } + content.push_back(item_map); + } + } + messages.push_back({role, content}); + } + conv.messages = messages; + + picojson::array seps_arr; + if (!json::ParseJSONField(json, "seps", seps_arr, err, true)) { + return std::nullopt; + } + std::vector seps; + for (const auto& sep : seps_arr) { + if (!sep.is()) { + *err += "seps should be an array of strings."; + return std::nullopt; + } + seps.push_back(sep.get()); + } + conv.seps = seps; + + std::string role_content_sep; + if (!json::ParseJSONField(json, "role_content_sep", role_content_sep, err, true)) { + return std::nullopt; + } + conv.role_content_sep = role_content_sep; + + std::string role_empty_sep; + if (!json::ParseJSONField(json, "role_empty_sep", role_empty_sep, err, true)) { + return std::nullopt; + } + conv.role_empty_sep = role_empty_sep; + + picojson::array stop_str_arr; + if (!json::ParseJSONField(json, "stop_str", stop_str_arr, err, true)) { + return std::nullopt; + } + std::vector stop_str; + for (const auto& stop : stop_str_arr) { + if (!stop.is()) { + *err += "stop_str should be an array of strings."; + return std::nullopt; + } + stop_str.push_back(stop.get()); + } + conv.stop_str = stop_str; + + picojson::array stop_token_ids_arr; + if (!json::ParseJSONField(json, "stop_token_ids", stop_token_ids_arr, err, true)) { + return std::nullopt; + } + std::vector stop_token_ids; + for (const auto& stop : stop_token_ids_arr) { + if (!stop.is()) { + *err += "stop_token_ids should be an array of integers."; + return std::nullopt; + } + stop_token_ids.push_back(stop.get()); + } + conv.stop_token_ids = stop_token_ids; + + std::string function_string; + if 
(!json::ParseJSONField(json, "function_string", function_string, err, false)) { + conv.function_string = function_string; + } + + bool use_function_calling; + if (json::ParseJSONField(json, "use_function_calling", use_function_calling, err, false)) { + conv.use_function_calling = use_function_calling; + } + + return conv; +} + +std::optional Conversation::FromJSON(const std::string& json_str, std::string* err) { + std::optional json_obj = json::LoadJSONFromString(json_str, err); + if (!json_obj.has_value()) { + return std::nullopt; + } + return Conversation::FromJSON(json_obj.value(), err); +} + +/****************** JSON FFI engine config ******************/ + +TVM_REGISTER_OBJECT_TYPE(JSONFFIEngineConfigNode); + +JSONFFIEngineConfig::JSONFFIEngineConfig( + String conv_template, Map model_generation_cfgs) { + ObjectPtr n = make_object(); + n->conv_template = conv_template; + n->model_generation_cfgs = model_generation_cfgs; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("mlc.json_ffi.JSONFFIEngineConfig") + .set_body_typed([](String conv_template, + Map model_generation_cfgs) { + return JSONFFIEngineConfig(std::move(conv_template), std::move(model_generation_cfgs)); + }); + +} // namespace json_ffi +} // namespace llm +} // namespace mlc diff --git a/cpp/json_ffi/config.h b/cpp/json_ffi/config.h new file mode 100644 index 0000000000..fe5e4e42e2 --- /dev/null +++ b/cpp/json_ffi/config.h @@ -0,0 +1,172 @@ +#ifndef MLC_LLM_JSON_FFI_CONFIG_H +#define MLC_LLM_JSON_FFI_CONFIG_H + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../serve/data.h" +#include "picojson.h" + +using namespace mlc::llm::serve; + +namespace mlc { +namespace llm { +namespace json_ffi { + +/****************** Model-defined generation config ******************/ + +class ModelDefinedGenerationConfigNode : public Object { + public: + double temperature; + double top_p; + double frequency_penalty; + double presence_penalty; + + static constexpr const char* _type_key = "mlc.json_ffi.ModelDefinedGenerationConfig"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(ModelDefinedGenerationConfigNode, Object); +}; + +class ModelDefinedGenerationConfig : public ObjectRef { + public: + explicit ModelDefinedGenerationConfig(double temperature, double top_p, double frequency_penalty, + double presence_penalty); + + TVM_DEFINE_OBJECT_REF_METHODS(ModelDefinedGenerationConfig, ObjectRef, + ModelDefinedGenerationConfigNode); +}; + +/****************** Conversation template ******************/ + +enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION }; + +MessagePlaceholders messagePlaceholderFromString(const std::string& role); + +class Message { + public: + std::string role; + std::optional>> content = std::nullopt; +}; + +/** + * @brief A struct that specifies the convention template of conversation + * and contains the conversation history. + */ +struct Conversation { + // Optional name of the template. + std::optional name = std::nullopt; + + // The system prompt template, it optionally contains the system + // message placeholder, and the placeholder will be replaced with + // the system message below. + std::string system_template; + + // The content of the system prompt (without the template format). + std::string system_message; + + // The system token ids to be prepended at the beginning of tokenized + // generated prompt. 
+ std::optional> system_prefix_token_ids = std::nullopt; + + // Whether or not to append user role and separator after the system message. + // This is mainly for [INST] [/INST] style prompt format + bool add_role_after_system_message = true; + + // The conversation roles + std::unordered_map roles; + + // The roles prompt template, it optionally contains the defaults + // message placeholders and will be replaced by actual content + std::unordered_map role_templates; + + // The conversation history messages. + // Each message is a pair of strings, denoting "(role, content)". + // The content can be None. + std::vector messages; + + // The separators between messages when concatenating into a single prompt. + // List size should be either 1 or 2. + // - When size is 1, the separator will be used between adjacent messages. + // - When size is 2, seps[0] is used after user message, and + // seps[1] is used after assistant message. + std::vector seps; + + // The separator between the role and the content in a message. + std::string role_content_sep; + + // The separator between the role and empty contents. + std::string role_empty_sep; + + // The stop criteria + std::vector stop_str; + std::vector stop_token_ids; + + // Function call fields + // whether using function calling or not, helps check for output message format in API call + std::optional function_string = std::nullopt; + std::optional use_function_calling = false; + + Conversation(); + + /** + * @brief Checks the size of the separators vector. + * This function checks if the size of the separators vector is either 1 or 2. + * If the size is not 1 or 2, it throws an invalid_argument exception. + */ + static std::vector CheckMessageSeps(std::vector& seps); + + /*! + * \brief Create the list of prompts from the messages based on the conversation template. + * When creation fails, errors are dumped to the input error string, and nullopt is returned. + */ + std::optional> AsPrompt(std::string* err); + + /*! + * \brief Create a Conversation instance from the given JSON object. + * When creation fails, errors are dumped to the input error string, and nullopt is returned. + */ + static std::optional FromJSON(const picojson::object& json, std::string* err); + + /*! + * \brief Parse and create a Conversation instance from the given JSON string. + * When creation fails, errors are dumped to the input error string, and nullopt is returned. 
+ */ + static std::optional FromJSON(const std::string& json_str, std::string* err); +}; + +/****************** JSON FFI engine config ******************/ + +class JSONFFIEngineConfigNode : public Object { + public: + String conv_template; + Map model_generation_cfgs; + + static constexpr const char* _type_key = "mlc.json_ffi.JSONFFIEngineConfig"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(JSONFFIEngineConfigNode, Object); +}; + +class JSONFFIEngineConfig : public ObjectRef { + public: + explicit JSONFFIEngineConfig(String conv_template, + Map model_generation_cfgs); + + TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineConfig, ObjectRef, JSONFFIEngineConfigNode); +}; + +} // namespace json_ffi +} // namespace llm +} // namespace mlc + +#endif /* MLC_LLM_JSON_FFI_CONV_TEMPLATE_H */ diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index b02a28ca89..d5fc53b8fa 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -51,33 +51,40 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request // TODO: Check if request_id is present already // inputs - // TODO: Apply conv template - Array inputs; + Conversation conv_template = this->conv_template_; + std::vector messages; for (const auto& message : request.messages) { - if (message.content.has_value()) { - for (const auto& content : message.content.value()) { - if (content.find("type") == content.end()) { - err_ += "Content should have a type field"; - return false; - } - std::string type = content.at("type"); - if (type == "text") { - if (content.find("text") == content.end()) { - err_ += "Content should have a text field"; - return false; - } - std::string text = content.at("text"); - inputs.push_back(TextData(text)); - } else { - err_ += "Content type not supported"; - return false; - } - } + std::string role; + if (message.role == Role::user) { + role = "user"; + } else if (message.role == Role::assistant) { + role = "assistant"; + } else if (message.role == Role::tool) { + role = "tool"; + } else { + role = "system"; } + messages.push_back({role, message.content}); + } + messages.push_back({"assistant", std::nullopt}); + conv_template.messages = messages; + + // check function calling + bool success_check = request.CheckFunctionCalling(conv_template, &err_); + if (!success_check) { + return false; + } + + // get prompt + std::optional> inputs_obj = conv_template.AsPrompt(&err_); + if (!inputs_obj.has_value()) { + return false; } + Array inputs = inputs_obj.value(); // generation_cfg - Optional generation_cfg = GenerationConfig::FromJSON(request_json_str, &err_); + Optional generation_cfg = GenerationConfig::Create( + request_json_str, &err_, conv_template, this->model_generation_cfgs[request.model]); if (!generation_cfg.defined()) { return false; } @@ -103,6 +110,9 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { public: TVM_MODULE_VTABLE_BEGIN("mlc.json_ffi"); TVM_MODULE_VTABLE_ENTRY("init_background_engine", &JSONFFIEngineImpl::InitBackgroundEngine); + TVM_MODULE_VTABLE_ENTRY("reload", &JSONFFIEngineImpl::Reload); + TVM_MODULE_VTABLE_ENTRY("unload", &JSONFFIEngineImpl::Unload); + TVM_MODULE_VTABLE_ENTRY("reset", &JSONFFIEngineImpl::Reset); TVM_MODULE_VTABLE_ENTRY("chat_completion", &JSONFFIEngineImpl::ChatCompletion); TVM_MODULE_VTABLE_ENTRY("abort", &JSONFFIEngineImpl::Abort); TVM_MODULE_VTABLE_ENTRY("get_last_error", 
&JSONFFIEngineImpl::GetLastError); @@ -112,9 +122,20 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); TVM_MODULE_VTABLE_END(); - void InitBackgroundEngine(EngineConfig engine_config, - Optional request_stream_callback, + void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config, + Device device, Optional request_stream_callback, Optional trace_recorder) { + std::optional conv_template = + Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_); + if (!conv_template.has_value()) { + LOG(FATAL) << "Invalid conversation template JSON: " << err_; + } + this->conv_template_ = conv_template.value(); + this->model_generation_cfgs = json_ffi_engine_config->model_generation_cfgs; + + // Todo(mlc-team): decouple InitBackgroundEngine into two functions + // by removing `engine_config` from arguments, after properly handling + // streamers. this->streamer_ = TextStreamer(Tokenizer::FromPath(engine_config->model)); CHECK(request_stream_callback.defined()) @@ -129,10 +150,17 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { }; request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - this->engine_->InitBackgroundEngine( - std::move(engine_config), std::move(request_stream_callback), std::move(trace_recorder)); + this->engine_->InitBackgroundEngine(device, std::move(request_stream_callback), + std::move(trace_recorder)); + this->engine_->Reload(std::move(engine_config)); } + void Reload(EngineConfig engine_config) { this->engine_->Reload(std::move(engine_config)); } + + void Unload() { this->engine_->Unload(); } + + void Reset() { this->engine_->Reset(); } + void RunBackgroundLoop() { this->engine_->RunBackgroundLoop(); } void RunBackgroundStreamBackLoop() { this->engine_->RunBackgroundStreamBackLoop(); } diff --git a/cpp/json_ffi/json_ffi_engine.h b/cpp/json_ffi/json_ffi_engine.h index 83013b5876..d57384abb5 100644 --- a/cpp/json_ffi/json_ffi_engine.h +++ b/cpp/json_ffi/json_ffi_engine.h @@ -12,6 +12,7 @@ #include "../serve/threaded_engine.h" #include "../streamer.h" +#include "config.h" #include "openai_api_protocol.h" namespace mlc { @@ -47,6 +48,8 @@ class JSONFFIEngine { std::string err_; PackedFunc request_stream_callback_; TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request + Conversation conv_template_; + Map model_generation_cfgs; }; } // namespace json_ffi diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc index 41378fc3e0..13f4b140ce 100644 --- a/cpp/json_ffi/openai_api_protocol.cc +++ b/cpp/json_ffi/openai_api_protocol.cc @@ -11,14 +11,166 @@ namespace mlc { namespace llm { namespace json_ffi { -std::optional ChatCompletionMessage::FromJSON(const picojson::value& json, - std::string* err) { - if (!json.is()) { - *err += "Input is not a valid JSON object"; +std::string generate_uuid_string(size_t length) { + auto randchar = []() -> char { + const char charset[] = + "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[rand() % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; +} + +std::optional ChatFunction::FromJSON(const picojson::object& json_obj, + std::string* err) { + ChatFunction chatFunc; + + // description (optional) + std::string description; + if 
(json::ParseJSONField(json_obj, "description", description, err, false)) { + chatFunc.description = description; + } + + // name + std::string name; + if (!json::ParseJSONField(json_obj, "name", name, err, true)) { + return std::nullopt; + } + chatFunc.name = name; + + // parameters + picojson::object parameters_obj; + if (!json::ParseJSONField(json_obj, "parameters", parameters_obj, err, true)) { + return std::nullopt; + } + std::unordered_map parameters; + for (picojson::value::object::const_iterator i = parameters_obj.begin(); + i != parameters_obj.end(); ++i) { + parameters[i->first] = i->second.to_str(); + } + chatFunc.parameters = parameters; + + return chatFunc; +} + +picojson::object ChatFunction::ToJSON() const { + picojson::object obj; + if (this->description.has_value()) { + obj["description"] = picojson::value(this->description.value()); + } + obj["name"] = picojson::value(this->name); + picojson::object parameters_obj; + for (const auto& pair : this->parameters) { + parameters_obj[pair.first] = picojson::value(pair.second); + } + obj["parameters"] = picojson::value(parameters_obj); + return obj; +} + +std::optional ChatTool::FromJSON(const picojson::object& json_obj, std::string* err) { + ChatTool chatTool; + + // function + picojson::object function_obj; + if (!json::ParseJSONField(json_obj, "function", function_obj, err, true)) { + return std::nullopt; + } + + std::optional function = ChatFunction::FromJSON(function_obj, err); + if (!function.has_value()) { return std::nullopt; } - picojson::object json_obj = json.get(); + chatTool.function = function.value(); + + return chatTool; +} +picojson::object ChatTool::ToJSON() const { + picojson::object obj; + obj["type"] = picojson::value("function"); + obj["function"] = picojson::value(this->function.ToJSON()); + return obj; +} + +std::optional ChatFunctionCall::FromJSON(const picojson::object& json_obj, + std::string* err) { + ChatFunctionCall chatFuncCall; + + // name + std::string name; + if (!json::ParseJSONField(json_obj, "name", name, err, true)) { + return std::nullopt; + } + chatFuncCall.name = name; + + // arguments + picojson::object arguments_obj; + if (json::ParseJSONField(json_obj, "arguments", arguments_obj, err, false)) { + std::unordered_map arguments; + for (picojson::value::object::const_iterator i = arguments_obj.begin(); + i != arguments_obj.end(); ++i) { + arguments[i->first] = i->second.to_str(); + } + chatFuncCall.arguments = arguments; + } + + return chatFuncCall; +} + +picojson::object ChatFunctionCall::ToJSON() const { + picojson::object obj; + picojson::object arguments_obj; + if (this->arguments.has_value()) { + for (const auto& pair : this->arguments.value()) { + arguments_obj[pair.first] = picojson::value(pair.second); + } + obj["arguments"] = picojson::value(arguments_obj); + } + + obj["name"] = picojson::value(this->name); + return obj; +} + +std::optional ChatToolCall::FromJSON(const picojson::object& json_obj, + std::string* err) { + ChatToolCall chatToolCall; + + // function + picojson::object function_obj; + if (!json::ParseJSONField(json_obj, "function", function_obj, err, true)) { + return std::nullopt; + } + + std::optional function = ChatFunctionCall::FromJSON(function_obj, err); + if (!function.has_value()) { + return std::nullopt; + }; + chatToolCall.function = function.value(); + + // overwrite default id + std::string id; + if (!json::ParseJSONField(json_obj, "id", id, err, false)) { + return std::nullopt; + } + chatToolCall.id = id; + + return chatToolCall; +} + +picojson::object 
ChatToolCall::ToJSON() const { + picojson::object obj; + obj["id"] = picojson::value(this->id); + obj["function"] = picojson::value(this->function.ToJSON()); + obj["type"] = picojson::value("function"); + return obj; +} + +std::optional ChatCompletionMessage::FromJSON( + const picojson::object& json_obj, std::string* err) { ChatCompletionMessage message; // content @@ -65,7 +217,30 @@ std::optional ChatCompletionMessage::FromJSON(const picoj message.name = name; } - // TODO: tool_calls and tool_call_id + // tool calls + picojson::array tool_calls_arr; + if (json::ParseJSONField(json_obj, "tool_calls", tool_calls_arr, err, false)) { + std::vector tool_calls; + for (const auto& item : tool_calls_arr) { + if (!item.is()) { + *err += "Chat Tool Call item is not an object"; + return std::nullopt; + } + picojson::object item_obj = item.get(); + std::optional tool_call = ChatToolCall::FromJSON(item_obj, err); + if (!tool_call.has_value()) { + return std::nullopt; + }; + tool_calls.push_back(tool_call.value()); + } + message.tool_calls = tool_calls; + } + + // tool call id + std::string tool_call_id; + if (json::ParseJSONField(json_obj, "tool_call_id", tool_call_id, err, false)) { + message.tool_call_id = tool_call_id; + } return message; } @@ -81,7 +256,8 @@ std::optional ChatCompletionRequest::FromJSON( } std::vector messages; for (const auto& item : messages_arr) { - std::optional message = ChatCompletionMessage::FromJSON(item, err); + picojson::object item_obj = item.get(); + std::optional message = ChatCompletionMessage::FromJSON(item_obj, err); if (!message.has_value()) { return std::nullopt; } @@ -108,6 +284,32 @@ std::optional ChatCompletionRequest::FromJSON( request.presence_penalty = presence_penalty; } + // tool_choice + std::string tool_choice = "auto"; + request.tool_choice = tool_choice; + if (json::ParseJSONField(json_obj, "tool_choice", tool_choice, err, false)) { + request.tool_choice = tool_choice; + } + + // tools + picojson::array tools_arr; + if (json::ParseJSONField(json_obj, "tools", tools_arr, err, false)) { + std::vector tools; + for (const auto& item : tools_arr) { + if (!item.is()) { + *err += "Chat Tool item is not an object"; + return std::nullopt; + } + picojson::object item_obj = item.get(); + std::optional tool = ChatTool::FromJSON(item_obj, err); + if (!tool.has_value()) { + return std::nullopt; + }; + tools.push_back(tool.value()); + } + request.tools = tools; + } + // TODO: Other parameters return request; @@ -122,7 +324,7 @@ std::optional ChatCompletionRequest::FromJSON(const std:: return ChatCompletionRequest::FromJSON(json_obj.value(), err); } -picojson::object ChatCompletionMessage::ToJSON() { +picojson::object ChatCompletionMessage::ToJSON() const { picojson::object obj; picojson::array content_arr; for (const auto& item : this->content.value()) { @@ -142,13 +344,57 @@ picojson::object ChatCompletionMessage::ToJSON() { } else if (this->role == Role::tool) { obj["role"] = picojson::value("tool"); } - if (name.has_value()) { - obj["name"] = picojson::value(name.value()); + if (this->name.has_value()) { + obj["name"] = picojson::value(this->name.value()); + } + if (this->tool_call_id.has_value()) { + obj["tool_call_id"] = picojson::value(this->tool_call_id.value()); + } + if (this->tool_calls.has_value()) { + picojson::array tool_calls_arr; + for (const auto& tool_call : this->tool_calls.value()) { + tool_calls_arr.push_back(picojson::value(tool_call.ToJSON())); + } + obj["tool_calls"] = picojson::value(tool_calls_arr); } return obj; } -picojson::object 
ChatCompletionResponseChoice::ToJSON() { +bool ChatCompletionRequest::CheckFunctionCalling(Conversation& conv_template, std::string* err) { + if (!tools.has_value() || (tool_choice.has_value() && tool_choice.value() == "none")) { + conv_template.use_function_calling = false; + return true; + } + std::vector tools_ = tools.value(); + std::string tool_choice_ = tool_choice.value(); + + // TODO: support with tool choice as dict + for (const auto& tool : tools_) { + if (tool.function.name == tool_choice_) { + conv_template.use_function_calling = true; + picojson::value function_str(tool.function.ToJSON()); + conv_template.function_string = function_str.serialize(); + return true; + } + } + + if (tool_choice_ != "auto") { + *err += "Invalid tool_choice value: " + tool_choice_; + return false; + } + + picojson::array function_list; + for (const auto& tool : tools_) { + function_list.push_back(picojson::value(tool.function.ToJSON())); + } + + conv_template.use_function_calling = true; + picojson::value function_list_json(function_list); + conv_template.function_string = function_list_json.serialize(); + return true; +}; + +picojson::object ChatCompletionResponseChoice::ToJSON() const { picojson::object obj; if (!this->finish_reason.has_value()) { obj["finish_reason"] = picojson::value(); @@ -168,7 +414,7 @@ picojson::object ChatCompletionResponseChoice::ToJSON() { return obj; } -picojson::object ChatCompletionStreamResponseChoice::ToJSON() { +picojson::object ChatCompletionStreamResponseChoice::ToJSON() const { picojson::object obj; if (!this->finish_reason.has_value()) { obj["finish_reason"] = picojson::value(); @@ -189,11 +435,11 @@ picojson::object ChatCompletionStreamResponseChoice::ToJSON() { return obj; } -picojson::object ChatCompletionResponse::ToJSON() { +picojson::object ChatCompletionResponse::ToJSON() const { picojson::object obj; obj["id"] = picojson::value(this->id); picojson::array choices_arr; - for (auto& choice : this->choices) { + for (const auto& choice : this->choices) { choices_arr.push_back(picojson::value(choice.ToJSON())); } obj["choices"] = picojson::value(choices_arr); @@ -204,11 +450,11 @@ picojson::object ChatCompletionResponse::ToJSON() { return obj; } -picojson::object ChatCompletionStreamResponse::ToJSON() { +picojson::object ChatCompletionStreamResponse::ToJSON() const { picojson::object obj; obj["id"] = picojson::value(this->id); picojson::array choices_arr; - for (auto& choice : this->choices) { + for (const auto& choice : this->choices) { choices_arr.push_back(picojson::value(choice.ToJSON())); } obj["choices"] = picojson::value(choices_arr); diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h index 1579b5f337..429050da3c 100644 --- a/cpp/json_ffi/openai_api_protocol.h +++ b/cpp/json_ffi/openai_api_protocol.h @@ -8,10 +8,12 @@ #include #include +#include #include #include #include +#include "config.h" #include "picojson.h" namespace mlc { @@ -22,7 +24,8 @@ enum class Role { system, user, assistant, tool }; enum class Type { text, json_object, function }; enum class FinishReason { stop, length, tool_calls, error }; -// TODO: Implement the following class +std::string generate_uuid_string(size_t length); + class ChatFunction { public: std::optional description = std::nullopt; @@ -30,32 +33,37 @@ class ChatFunction { std::unordered_map parameters; // Assuming parameters are string key-value pairs - static std::optional FromJSON(const picojson::value& json, std::string* err); + static std::optional FromJSON(const picojson::object& 
json, std::string* err); + picojson::object ToJSON() const; }; -// TODO: Implement the following class class ChatTool { public: Type type = Type::function; ChatFunction function; - static std::optional FromJSON(const picojson::value& json, std::string* err); + static std::optional FromJSON(const picojson::object& json, std::string* err); + picojson::object ToJSON() const; }; -// TODO: Implement the following class class ChatFunctionCall { public: std::string name; std::optional> arguments = std::nullopt; // Assuming arguments are string key-value pairs + + static std::optional FromJSON(const picojson::object& json, std::string* err); + picojson::object ToJSON() const; }; -// TODO: Implement the following class class ChatToolCall { public: - std::string id; // TODO: python code initializes this to an random string + std::string id = "call_" + generate_uuid_string(8); Type type = Type::function; ChatFunctionCall function; + + static std::optional FromJSON(const picojson::object& json, std::string* err); + picojson::object ToJSON() const; }; class ChatCompletionMessage { @@ -64,12 +72,12 @@ class ChatCompletionMessage { std::nullopt; // Assuming content is a list of string key-value pairs Role role; std::optional name = std::nullopt; - std::optional> tool_calls = std::nullopt; // TODO: Implement this - std::optional tool_call_id = std::nullopt; // TODO: Implement this + std::optional> tool_calls = std::nullopt; + std::optional tool_call_id = std::nullopt; - static std::optional FromJSON(const picojson::value& json, + static std::optional FromJSON(const picojson::object& json, std::string* err); - picojson::object ToJSON(); + picojson::object ToJSON() const; }; class RequestResponseFormat { @@ -82,8 +90,8 @@ class ChatCompletionRequest { public: std::vector messages; std::string model; - double frequency_penalty = 0.0; - double presence_penalty = 0.0; + std::optional frequency_penalty = std::nullopt; + std::optional presence_penalty = std::nullopt; bool logprobs = false; int top_logprobs = 0; std::optional> logit_bias = std::nullopt; @@ -92,8 +100,8 @@ class ChatCompletionRequest { std::optional seed = std::nullopt; std::optional> stop = std::nullopt; bool stream = false; - double temperature = 1.0; - double top_p = 1.0; + std::optional temperature = std::nullopt; + std::optional top_p = std::nullopt; std::optional> tools = std::nullopt; std::optional tool_choice = std::nullopt; std::optional user = std::nullopt; @@ -113,6 +121,7 @@ class ChatCompletionRequest { static std::optional FromJSON(const std::string& json_str, std::string* err); + bool CheckFunctionCalling(Conversation& conv_template, std::string* err); // TODO: check_penalty_range, check_logit_bias, check_logprobs }; @@ -123,7 +132,7 @@ class ChatCompletionResponseChoice { ChatCompletionMessage message; // TODO: logprobs - picojson::object ToJSON(); + picojson::object ToJSON() const; }; class ChatCompletionStreamResponseChoice { @@ -133,7 +142,7 @@ class ChatCompletionStreamResponseChoice { ChatCompletionMessage delta; // TODO: logprobs - picojson::object ToJSON(); + picojson::object ToJSON() const; }; class ChatCompletionResponse { @@ -146,7 +155,7 @@ class ChatCompletionResponse { std::string object = "chat.completion"; // TODO: usage_info - picojson::object ToJSON(); + picojson::object ToJSON() const; }; class ChatCompletionStreamResponse { @@ -158,7 +167,7 @@ class ChatCompletionStreamResponse { std::string system_fingerprint; std::string object = "chat.completion.chunk"; - picojson::object ToJSON(); + picojson::object ToJSON() 
const; }; } // namespace json_ffi diff --git a/cpp/metadata/json_parser.h b/cpp/metadata/json_parser.h index f6ff10e1ac..99a284fc42 100644 --- a/cpp/metadata/json_parser.h +++ b/cpp/metadata/json_parser.h @@ -149,6 +149,22 @@ inline ValueType Lookup(const picojson::object& json, const std::string& key) { return it->second.get(); } +template +inline ValueType LookupOrDefault(const picojson::object& json, const std::string& key, + const ValueType& default_value) { + auto it = json.find(key); + if (it == json.end()) { + return default_value; + } + + if (it->second.is()) { + return default_value; + } + + CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; + return it->second.get(); +} + template inline ValueType Lookup(const picojson::array& json, int index) { CHECK(index < json.size()) << "IndexError: json::array index out of range"; diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 5d647ec532..3bb809ad67 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -161,19 +161,35 @@ GenerationConfig::GenerationConfig(String config_json_str) { data_ = std::move(n); } -Optional GenerationConfig::FromJSON(const std::string& json_str, - std::string* err) { - std::optional json_obj = json::LoadJSONFromString(json_str, err); - if (!err->empty() || !json_obj.has_value()) { +Optional GenerationConfig::Create( + const std::string& json_str, std::string* err, const Conversation& conv_template, + const ModelDefinedGenerationConfig& model_defined_gen_config) { + std::optional optional_json_obj = json::LoadJSONFromString(json_str, err); + if (!err->empty() || !optional_json_obj.has_value()) { return NullOpt; } + picojson::object& json_obj = optional_json_obj.value(); ObjectPtr n = make_object(); - // TODO(mlc-team): Pass the parameters from `json_obj` to `n`. 
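[Editor's note] The `LookupOrDefault` helper introduced in `cpp/metadata/json_parser.h` above is what the generation-config construction below leans on. As a rough, standalone sketch of its contract (simplified: the type `CHECK` is folded away and picojson is assumed to be available), not the patch code itself:

```cpp
#include <iostream>
#include <string>
#include <picojson.h>

// Simplified re-sketch of the LookupOrDefault contract: return json[key] when the
// key is present and non-null, otherwise fall back to the caller-supplied default.
template <typename ValueType>
ValueType LookupOrDefault(const picojson::object& json, const std::string& key,
                          const ValueType& default_value) {
  auto it = json.find(key);
  if (it == json.end() || it->second.is<picojson::null>()) return default_value;
  return it->second.get<ValueType>();
}

int main() {
  picojson::object obj;
  obj["top_p"] = picojson::value(0.95);
  std::cout << LookupOrDefault(obj, "top_p", 1.0) << "\n";        // 0.95: key present
  std::cout << LookupOrDefault(obj, "temperature", 0.7) << "\n";  // 0.7: key missing
}
```

In the `GenerationConfig::Create` body that follows, the defaults are taken from the model-defined generation config rather than literals, so request fields only override what the model config already specifies.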
+ n->temperature = + json::LookupOrDefault(json_obj, "temperature", model_defined_gen_config->temperature); + n->top_p = json::LookupOrDefault(json_obj, "top_p", model_defined_gen_config->top_p); + n->frequency_penalty = json::LookupOrDefault(json_obj, "frequency_penalty", + model_defined_gen_config->frequency_penalty); + n->presence_penalty = json::LookupOrDefault(json_obj, "presence_penalty", + model_defined_gen_config->presence_penalty); + n->logprobs = json::LookupOrDefault(json_obj, "logprobs", false); + n->top_logprobs = static_cast(json::LookupOrDefault(json_obj, "top_logprobs", 0)); + n->ignore_eos = json::LookupOrDefault(json_obj, "ignore_eos", false); - if (!err->empty()) { - return NullOpt; + // Copy stop str from conversation template to generation config + for (auto& stop_str : conv_template.stop_str) { + n->stop_strs.push_back(stop_str); + } + for (auto& stop_token_id : conv_template.stop_token_ids) { + n->stop_token_ids.push_back(stop_token_id); } + GenerationConfig gen_config; gen_config.data_ = std::move(n); return gen_config; @@ -228,37 +244,85 @@ String GenerationConfigNode::AsJSONString() const { TVM_REGISTER_OBJECT_TYPE(EngineConfigNode); EngineConfig::EngineConfig(String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, DLDevice device, - int kv_cache_page_size, int max_num_sequence, - int max_total_sequence_length, int max_single_sequence_length, - int prefill_chunk_size, SpeculativeMode speculative_mode, - int spec_draft_length) { + Array additional_model_lib_paths, int kv_cache_page_size, + int max_num_sequence, int max_total_sequence_length, + int max_single_sequence_length, int prefill_chunk_size, + int max_history_size, KVStateKind kv_state_kind, + SpeculativeMode speculative_mode, int spec_draft_length) { ObjectPtr n = make_object(); n->model = std::move(model); n->model_lib_path = std::move(model_lib_path); n->additional_models = std::move(additional_models); n->additional_model_lib_paths = std::move(additional_model_lib_paths); - n->device = device; n->kv_cache_page_size = kv_cache_page_size; n->max_num_sequence = max_num_sequence; n->max_total_sequence_length = max_total_sequence_length; n->max_single_sequence_length = max_single_sequence_length; n->prefill_chunk_size = prefill_chunk_size; + n->max_history_size = max_history_size; + n->kv_state_kind = kv_state_kind; n->spec_draft_length = spec_draft_length; n->speculative_mode = speculative_mode; data_ = std::move(n); } +EngineConfig EngineConfig::FromJSONString(const std::string& json_str) { + picojson::value config_json; + std::string err = picojson::parse(config_json, json_str); + if (!err.empty()) { + LOG(FATAL) << err; + } + + // Get json fields. 
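[Editor's note] For orientation before the field lookups below, here is a hypothetical JSON payload of the shape `EngineConfig::FromJSONString` expects; every path and value is a placeholder for illustration, not a default shipped by this patch. `kv_state_kind` and `speculative_mode` are the integer encodings of the enums added in `config.h`.

```cpp
#include <string>

// Hypothetical input for EngineConfig::FromJSONString (placeholder values only).
const std::string example_engine_config = R"({
  "model": "dist/model-q4f16_1-MLC",
  "model_lib_path": "dist/libs/model-q4f16_1.so",
  "additional_models": [],
  "additional_model_lib_paths": [],
  "kv_cache_page_size": 16,
  "max_num_sequence": 4,
  "max_total_sequence_length": 4096,
  "max_single_sequence_length": 4096,
  "prefill_chunk_size": 1024,
  "max_history_size": 0,
  "kv_state_kind": 0,
  "speculative_mode": 0,
  "spec_draft_length": 4
})";
// EngineConfig config = EngineConfig::FromJSONString(example_engine_config);
```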
+ picojson::object config = config_json.get(); + String model = json::Lookup(config, "model"); + String model_lib_path = json::Lookup(config, "model_lib_path"); + std::vector additional_models; + std::vector additional_model_lib_paths; + int kv_cache_page_size = json::Lookup(config, "kv_cache_page_size"); + int max_num_sequence = json::Lookup(config, "max_num_sequence"); + int max_total_sequence_length = json::Lookup(config, "max_total_sequence_length"); + int max_single_sequence_length = json::Lookup(config, "max_single_sequence_length"); + int prefill_chunk_size = json::Lookup(config, "prefill_chunk_size"); + int max_history_size = json::Lookup(config, "max_history_size"); + KVStateKind kv_state_kind = + static_cast(json::Lookup(config, "kv_state_kind")); + SpeculativeMode speculative_mode = + static_cast(json::Lookup(config, "speculative_mode")); + int spec_draft_length = json::Lookup(config, "spec_draft_length"); + + picojson::array additional_models_arr = + json::Lookup(config, "additional_models"); + picojson::array additional_model_lib_paths_arr = + json::Lookup(config, "additional_model_lib_paths"); + CHECK_EQ(additional_models_arr.size(), additional_model_lib_paths_arr.size()) + << "The number of additional model lib paths does not match the number of additional models"; + int num_additional_models = additional_models_arr.size(); + additional_models.reserve(num_additional_models); + additional_model_lib_paths.reserve(num_additional_models); + for (int i = 0; i < num_additional_models; ++i) { + additional_models.push_back(json::Lookup(additional_models_arr, i)); + additional_model_lib_paths.push_back( + json::Lookup(additional_model_lib_paths_arr, i)); + } + + return EngineConfig(std::move(model), std::move(model_lib_path), additional_models, + additional_model_lib_paths, kv_cache_page_size, max_num_sequence, + max_total_sequence_length, max_single_sequence_length, prefill_chunk_size, + max_history_size, kv_state_kind, speculative_mode, spec_draft_length); +} + TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig") .set_body_typed([](String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, DLDevice device, - int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, - int max_single_sequence_length, int prefill_chunk_size, int speculative_mode, - int spec_draft_length) { + Array additional_model_lib_paths, int kv_cache_page_size, + int max_num_sequence, int max_total_sequence_length, + int max_single_sequence_length, int prefill_chunk_size, int max_history_size, + int kv_state_kind, int speculative_mode, int spec_draft_length) { return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models), - std::move(additional_model_lib_paths), device, kv_cache_page_size, + std::move(additional_model_lib_paths), kv_cache_page_size, max_num_sequence, max_total_sequence_length, max_single_sequence_length, - prefill_chunk_size, SpeculativeMode(speculative_mode), spec_draft_length); + prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind), + SpeculativeMode(speculative_mode), spec_draft_length); }); } // namespace serve diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 404566fe2c..fd76dd49f0 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -11,12 +11,15 @@ #include +#include "../json_ffi/config.h" + namespace mlc { namespace llm { namespace serve { using namespace tvm; using namespace tvm::runtime; +using namespace mlc::llm::json_ffi; /****************** GenerationConfig 
******************/ @@ -60,10 +63,13 @@ class GenerationConfig : public ObjectRef { explicit GenerationConfig(String config_json_str); /*! - * \brief Parse the generation config from the given JSON string. - * When parsing fails, errors are dumped to the input error string, and NullOpt is returned. + * \brief Create a generation config from a ChatCompletionRequest. + * If the request does not contain a generation config, the model-defined + * generation config will be used. */ - static Optional FromJSON(const std::string& json_str, std::string* err); + static Optional Create( + const std::string& json_str, std::string* err, const Conversation& conv_template, + const ModelDefinedGenerationConfig& model_defined_gen_config); TVM_DEFINE_OBJECT_REF_METHODS(GenerationConfig, ObjectRef, GenerationConfigNode); }; @@ -80,6 +86,12 @@ enum class SpeculativeMode : int { kEagle = 2, }; +/*! \brief The kind of cache. */ +enum KVStateKind { + kAttention = 0, + kRNNState = 1, +}; + /*! \brief The configuration of engine execution config. */ class EngineConfigNode : public Object { public: @@ -94,11 +106,6 @@ class EngineConfigNode : public Object { /*! \brief The path to the additional models' libraries. */ Array additional_model_lib_paths; - /*************** Device ***************/ - - /*! \brief The device where the models run. */ - DLDevice device; - /*************** KV cache config and engine capacities ***************/ /*! \brief The number of consecutive tokens handled in each page in paged KV cache. */ @@ -117,6 +124,10 @@ class EngineConfigNode : public Object { int max_single_sequence_length; /*! \brief The maximum total sequence length in a prefill. */ int prefill_chunk_size; + /*! \brief The maximum history size for RNN state. KV cache does not need this. */ + int max_history_size; + /*! \brief The kind of cache. Whether it's KV cache or RNN state. */ + KVStateKind kv_state_kind; /*************** Speculative decoding ***************/ @@ -136,11 +147,15 @@ class EngineConfigNode : public Object { class EngineConfig : public ObjectRef { public: explicit EngineConfig(String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, DLDevice device, - int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, + Array additional_model_lib_paths, int kv_cache_page_size, + int max_num_sequence, int max_total_sequence_length, int max_single_sequence_length, int prefill_chunk_size, + int max_history_size, KVStateKind kv_state_kind, SpeculativeMode speculative_mode, int spec_draft_length); + /*! \brief Create EngineConfig from JSON string. */ + static EngineConfig FromJSONString(const std::string& json_str); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode); }; diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 85d1c66c2d..d82c886355 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -44,7 +45,8 @@ class EngineImpl : public Engine { public: /********************** Engine Management **********************/ - explicit EngineImpl(EngineConfig engine_config, Optional request_stream_callback, + explicit EngineImpl(EngineConfig engine_config, DLDevice device, + Optional request_stream_callback, Optional trace_recorder) { // Step 1. 
Initialize metadata and singleton states inside the engine this->estate_->Reset(); @@ -62,14 +64,24 @@ class EngineImpl : public Engine { this->models_.clear(); this->model_workspaces_.clear(); - auto f_create_model = [this, &engine_config, &trace_recorder](const String& model_path, - const String& model_lib_path) { - Model model = Model::Create(model_lib_path, std::move(model_path), engine_config->device, - engine_config->max_num_sequence, + std::vector model_configs; + model_configs.push_back(Model::LoadModelConfig(engine_config->model)); + for (const auto& model_path : engine_config->additional_models) { + model_configs.push_back(Model::LoadModelConfig(model_path)); + } + + Optional session = CreateDiscoSession(model_configs, device); + + auto f_create_model = [this, &engine_config, &device, &trace_recorder, &model_configs, + &session](const String& model_path, const String& model_lib_path, + int model_index) { + Model model = Model::Create(model_lib_path, std::move(model_path), model_configs[model_index], + device, engine_config->max_num_sequence, session, /*trace_enabled=*/trace_recorder.defined()); model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence, engine_config->max_total_sequence_length, - engine_config->prefill_chunk_size); + engine_config->prefill_chunk_size, engine_config->max_history_size, + engine_config->kv_state_kind); CHECK_GE(model->GetMaxWindowSize(), engine_config->max_single_sequence_length) << "The window size of the model, " << model->GetMaxWindowSize() << ", is smaller than the pre-defined max single sequence length, " @@ -79,18 +91,18 @@ class EngineImpl : public Engine { ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); }; - f_create_model(engine_config->model, engine_config->model_lib_path); + f_create_model(engine_config->model, engine_config->model_lib_path, /*model_index=*/0); CHECK_EQ(engine_config->additional_models.size(), engine_config->additional_model_lib_paths.size()) << "The additional model and lib path list has mismatched size."; for (int i = 0; i < static_cast(engine_config->additional_models.size()); ++i) { f_create_model(engine_config->additional_models[i], - engine_config->additional_model_lib_paths[i]); + engine_config->additional_model_lib_paths[i], /*model_index=*/i + 1); } int max_num_tokens = engine_config->max_num_sequence; if (engine_config->speculative_mode != SpeculativeMode::kDisable) { - max_num_tokens *= engine_config->spec_draft_length; + max_num_tokens *= engine_config->spec_draft_length + 1; } LogitProcessor logit_processor = this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); @@ -102,18 +114,18 @@ class EngineImpl : public Engine { ICHECK_GT(this->models_.size(), 1U); switch (engine_config->speculative_mode) { case SpeculativeMode::kEagle: - this->actions_ = { - EngineAction::EagleNewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - engine_config, // - this->trace_recorder_), - EngineAction::EagleBatchDraft(this->models_, logit_processor, sampler, - this->model_workspaces_, this->trace_recorder_), - EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, - this->model_workspaces_, engine_config, - this->trace_recorder_)}; + this->actions_ = {EngineAction::EagleNewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + engine_config, // + this->trace_recorder_), + EngineAction::EagleBatchDraft( + this->models_, logit_processor, 
sampler, this->model_workspaces_, + this->trace_recorder_, engine_config->spec_draft_length), + EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, + this->model_workspaces_, engine_config, + this->trace_recorder_)}; break; default: this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // @@ -143,6 +155,7 @@ class EngineImpl : public Engine { } void Reset() final { + AbortAllRequests(); estate_->Reset(); for (Model model : models_) { model->Reset(); @@ -167,7 +180,8 @@ class EngineImpl : public Engine { request = Request::FromUntokenized(request, tokenizer_); ICHECK_NE(request->input_total_length, -1); - if (request->input_total_length >= engine_config_->max_single_sequence_length) { + if (request->input_total_length >= engine_config_->max_single_sequence_length && + request_stream_callback_.defined()) { // If the request input length exceeds the maximum allowed single sequence length, // invoke callback and do not process the request. Array output{RequestStreamOutput( @@ -240,6 +254,28 @@ class EngineImpl : public Engine { // The request to abort is in waiting queue estate_->waiting_queue.erase(it_waiting); } + + // Send a callback to notice the abortion. + if (request_stream_callback_.defined()) { + Array output{RequestStreamOutput( + request_id, std::vector(request->generation_cfg->n), + Optional>>(), + std::vector>(request->generation_cfg->n, String("abort")))}; + request_stream_callback_.value()(std::move(output)); + } + } + + void AbortAllRequests() final { + // - Collect all the request ids. + std::vector request_ids; + request_ids.reserve(estate_->request_states.size()); + for (const auto& kv : estate_->request_states) { + request_ids.push_back(kv.first); + } + // - Abort all the requests. + for (const String& request_id : request_ids) { + AbortRequest(request_id); + } } /*********************** Engine Action ***********************/ @@ -261,6 +297,51 @@ class EngineImpl : public Engine { "action (e.g. prefill, decode, etc.) but it does not."; } + /************** Utility Functions **************/ + Optional CreateDiscoSession(std::vector model_configs, Device device) { + const auto& base_model_config = model_configs[0]; + + auto f_get_num_shards = [](const picojson::object& model_config) -> int { + constexpr auto kNumShardsKey = "tensor_parallel_shards"; + if (model_config.count(kNumShardsKey)) { + const auto& val = model_config.at(kNumShardsKey); + CHECK(val.is()); + return static_cast(val.get()); + } else { + LOG(FATAL) << "Key \"tensor_parallel_shards\" not found."; + } + throw; + }; + + int num_shards = std::transform_reduce( + model_configs.begin(), model_configs.end(), 1, [](int a, int b) { return std::max(a, b); }, + f_get_num_shards); + Optional session = NullOpt; + if (num_shards > 1) { + constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool"; + if (Registry::Get(f_create_process_pool) == nullptr) { + LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. " + << "Multi-GPU inference depends on MLC LLM Python API to launch process."; + } + std::string ccl; + if (device.device_type == kDLCUDA) { + ccl = "nccl"; + } else if (device.device_type == kDLROCM) { + ccl = "rccl"; + } else { + LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type) + << " is not supported. 
Currently, only NCCL and RCCL are integrated."; + } + std::vector device_ids(num_shards); + for (int i = 0; i < num_shards; ++i) { + device_ids[i] = i; + } + session = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); + session.value()->InitCCL(ccl, ShapeTuple(device_ids)); + } + return session; + } + /************** Debug/Profile **************/ void DebugCallFuncOnAllAllWorker(const String& func_name) final { @@ -314,10 +395,11 @@ class EngineImpl : public Engine { Optional trace_recorder_; }; -std::unique_ptr Engine::Create(EngineConfig engine_config, +std::unique_ptr Engine::Create(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder) { - return std::make_unique(std::move(engine_config), std::move(request_stream_callback), + return std::make_unique(std::move(engine_config), device, + std::move(request_stream_callback), std::move(trace_recorder)); } @@ -343,10 +425,10 @@ class EngineModule : public ModuleNode { TVM_MODULE_VTABLE_END(); /*! \brief Initialize the engine with config and other fields. */ - void Init(EngineConfig engine_config, Optional request_stream_callback, + void Init(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder) { - this->engine_ = Engine::Create(std::move(engine_config), std::move(request_stream_callback), - std::move(trace_recorder)); + this->engine_ = Engine::Create(std::move(engine_config), device, + std::move(request_stream_callback), std::move(trace_recorder)); } /*! \brief Construct an EngineModule. */ static tvm::runtime::Module Create() { return Module(make_object()); } diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index fc5e4205ae..2fc0a4d730 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -51,11 +51,12 @@ class Engine { /*! * \brief Create an engine in unique pointer. * \param engine_config The engine config. + * \param device The device where the run models. * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. * \return The created Engine in pointer. */ - static std::unique_ptr Create(EngineConfig engine_config, + static std::unique_ptr Create(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder); @@ -82,6 +83,9 @@ class Engine { /*! \brief Abort the input request (specified by id string) from engine. */ virtual void AbortRequest(const String& request_id) = 0; + /*! \brief Abort all requests from the engine. */ + virtual void AbortAllRequests() = 0; + /*********************** Engine Action ***********************/ /*! 
diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h index aea455a1be..78e3937d0b 100644 --- a/cpp/serve/engine_actions/action_commons.h +++ b/cpp/serve/engine_actions/action_commons.h @@ -47,7 +47,7 @@ void ActionStepPostProcess(Array requests, EngineState estate, Array sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); - std::vector sample_results = sampler_->BatchSampleTokens( + std::vector sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP( probs_on_device, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), num_rsentries); diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index b56f7fa9b6..c1ddeb6e4e 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -116,8 +116,10 @@ class BatchDraftActionObj : public EngineActionObj { std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); std::vector prob_dist; - std::vector sample_results = sampler_->BatchSampleTokens( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 6f38292ba3..42c9bbe018 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -7,6 +7,7 @@ #include #include +#include #include "../../random.h" #include "../config.h" @@ -42,8 +43,8 @@ class BatchVerifyActionObj : public EngineActionObj { return {}; } - const auto& [rsentries, draft_lengths, total_draft_length] = GetDraftsToVerify(estate); - ICHECK_EQ(rsentries.size(), draft_lengths.size()); + const auto& [rsentries, verify_lengths, total_verify_length] = GetDraftsToVerify(estate); + ICHECK_EQ(rsentries.size(), verify_lengths.size()); if (rsentries.empty()) { return {}; } @@ -62,7 +63,7 @@ class BatchVerifyActionObj : public EngineActionObj { std::vector> draft_output_tokens; std::vector> draft_output_prob_dist; request_internal_ids.reserve(num_rsentries); - all_tokens_to_verify.reserve(total_draft_length); + all_tokens_to_verify.reserve(total_verify_length); verify_request_mstates.reserve(num_rsentries); rngs.reserve(num_rsentries); generation_cfg.reserve(num_rsentries); @@ -73,12 +74,12 @@ class BatchVerifyActionObj : public EngineActionObj { RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_]; RequestModelState draft_mstate = rsentries[i]->mstates[draft_model_id_]; request_internal_ids.push_back(verify_mstate->internal_id); - ICHECK(!draft_lengths.empty()); - ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_tokens.size()); - ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_prob_dist.size()); - // the last committed token + all the draft tokens but the last one. + ICHECK(!verify_lengths.empty()); + ICHECK_EQ(verify_lengths[i], draft_mstate->draft_output_tokens.size() + 1); + ICHECK_EQ(verify_lengths[i], draft_mstate->draft_output_prob_dist.size() + 1); + // the last committed token + all the draft tokens. 
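[Editor's note] A hypothetical walk-through of the renamed bookkeeping (numbers are illustrative, not taken from this patch): with a draft length of 3, verification now covers the last committed token plus all 3 draft tokens, so the per-request verify length is 4. This is also why `spec_factor` and `max_num_tokens` elsewhere in this patch use `spec_draft_length + 1`, and why the rollback no longer subtracts one.

```cpp
#include <algorithm>

// Illustrative arithmetic only; all values are hypothetical.
int draft_length = 3;                  // tokens proposed by the draft model
int verify_length = draft_length + 1;  // plus the last committed token -> 4
int accept_length = 2;                 // tokens committed after verification
int rollback_length =                  // 2 KV-cache entries rolled back
    std::max(verify_length - accept_length, 0);
```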
all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().sampled_token_id.first); - for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()) - 1; ++j) { + for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()); ++j) { all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].sampled_token_id.first); } verify_request_mstates.push_back(verify_mstate); @@ -95,19 +96,19 @@ class BatchVerifyActionObj : public EngineActionObj { RECORD_EVENT(trace_recorder_, request_ids, "start verify"); NDArray logits = - models_[verify_model_id_]->BatchVerify(embeddings, request_internal_ids, draft_lengths); + models_[verify_model_id_]->BatchVerify(embeddings, request_internal_ids, verify_lengths); RECORD_EVENT(trace_recorder_, request_ids, "finish verify"); ICHECK_EQ(logits->ndim, 3); ICHECK_EQ(logits->shape[0], 1); - ICHECK_EQ(logits->shape[1], total_draft_length); + ICHECK_EQ(logits->shape[1], total_verify_length); // - Update logits. std::vector cum_verify_lengths = {0}; cum_verify_lengths.reserve(num_rsentries + 1); for (int i = 0; i < num_rsentries; ++i) { - cum_verify_lengths.push_back(cum_verify_lengths.back() + draft_lengths[i]); + cum_verify_lengths.push_back(cum_verify_lengths.back() + verify_lengths[i]); } - logits = logits.CreateView({total_draft_length, logits->shape[2]}, logits->dtype); + logits = logits.CreateView({total_verify_length, logits->shape[2]}, logits->dtype); logit_processor_->InplaceUpdateLogits(logits, generation_cfg, verify_request_mstates, request_ids, &cum_verify_lengths, &draft_output_tokens); @@ -115,9 +116,14 @@ class BatchVerifyActionObj : public EngineActionObj { NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits( logits, generation_cfg, request_ids, &cum_verify_lengths); - std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokens( - probs_on_device, request_ids, cum_verify_lengths, generation_cfg, rngs, draft_output_tokens, - draft_output_prob_dist); + std::vector sample_indices(num_rsentries); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector> sample_results_arr = + sampler_->BatchVerifyDraftTokensWithProbAfterTopP( + renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs, + draft_output_tokens, draft_output_prob_dist); ICHECK_EQ(sample_results_arr.size(), num_rsentries); for (int i = 0; i < num_rsentries; ++i) { @@ -128,10 +134,8 @@ class BatchVerifyActionObj : public EngineActionObj { rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result); } estate->stats.total_accepted_length += accept_length; - // - Minus one because the last draft token has no kv cache entry - // - Take max with 0 in case of all accepted. int rollback_length = - std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length - 1, 0); + std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0); // rollback kv cache // NOTE: when number of small models is more than 1 (in the future), // it is possible to re-compute prefill for the small models. @@ -158,10 +162,10 @@ class BatchVerifyActionObj : public EngineActionObj { struct DraftRequestStateEntries { /*! \brief The request state entries to verify. */ Array draft_rsentries; - /*! \brief The draft length of each request state. */ - std::vector draft_lengths; + /*! \brief The length to verify for each request state. */ + std::vector verify_lengths; /*! 
\brief The total draft length. */ - int total_draft_length; + int total_verify_length; }; /*! @@ -171,8 +175,8 @@ class BatchVerifyActionObj : public EngineActionObj { * state and input length. */ DraftRequestStateEntries GetDraftsToVerify(EngineState estate) { - std::vector draft_lengths; - int total_draft_length = 0; + std::vector verify_lengths; + int total_verify_length = 0; int total_required_pages = 0; int num_available_pages = models_[verify_model_id_]->GetNumAvailablePages(); @@ -184,24 +188,24 @@ class BatchVerifyActionObj : public EngineActionObj { int draft_length = rsentry->mstates[draft_model_id_]->draft_output_tokens.size(); int num_require_pages = (draft_length + engine_config_->kv_cache_page_size - 1) / engine_config_->kv_cache_page_size; - draft_lengths.push_back(draft_length); + verify_lengths.push_back(draft_length + 1); num_page_requirement.push_back(num_require_pages); - total_draft_length += draft_length; + total_verify_length += draft_length + 1; total_required_pages += num_require_pages; } while (!CanVerify(total_required_pages)) { RequestStateEntry preempted = PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { - total_draft_length -= draft_lengths.back(); + total_verify_length -= verify_lengths.back(); total_required_pages -= num_page_requirement.back(); - draft_lengths.pop_back(); + verify_lengths.pop_back(); num_page_requirement.pop_back(); running_rsentries.pop_back(); } } - return {running_rsentries, draft_lengths, total_draft_length}; + return {running_rsentries, verify_lengths, total_verify_length}; } bool CanVerify(int num_required_pages) { diff --git a/cpp/serve/engine_actions/eagle_batch_draft.cc b/cpp/serve/engine_actions/eagle_batch_draft.cc index 50393c38a2..fde314a5c5 100644 --- a/cpp/serve/engine_actions/eagle_batch_draft.cc +++ b/cpp/serve/engine_actions/eagle_batch_draft.cc @@ -145,8 +145,10 @@ class EagleBatchDraftActionObj : public EngineActionObj { std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); std::vector prob_dist; - std::vector sample_results = sampler_->BatchSampleTokens( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index 043f68b9c2..b259417050 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -88,7 +88,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj { generation_cfg.push_back(rsentries[i]->request->generation_cfg); rngs.push_back(&rsentries[i]->rng); draft_output_tokens.push_back(draft_mstate->draft_output_tokens); - CHECK(draft_mstate->draft_output_prob_dist[0]->device.device_type == kDLCPU); draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist); } @@ -129,10 +128,14 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // - Compute probability distributions. 
NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits( logits, generation_cfg, request_ids, &cum_verify_lengths); - - std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokens( - probs_on_device, request_ids, cum_verify_lengths, generation_cfg, rngs, draft_output_tokens, - draft_output_prob_dist); + std::vector sample_indices(num_rsentries); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector> sample_results_arr = + sampler_->BatchVerifyDraftTokensWithProbAfterTopP( + renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs, + draft_output_tokens, draft_output_prob_dist); ICHECK_EQ(sample_results_arr.size(), num_rsentries); std::vector last_hidden_states; @@ -230,8 +233,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj { std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); std::vector prob_dist; - std::vector sample_results = sampler_->BatchSampleTokens( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index 133c23e8a1..a687e7eb7f 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -277,8 +277,10 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { } } std::vector prob_dist; - std::vector sample_results = sampler_->BatchSampleTokens( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); // - Update the committed tokens of states. @@ -459,7 +461,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // No exceeding of the maximum allowed requests that can // run simultaneously. int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable - ? engine_config_->spec_draft_length + ? 
(engine_config_->spec_draft_length + 1) : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index c3f7491960..b4192a04f1 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -229,7 +229,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { rsentry_activated.push_back(true); } } - std::vector sample_results = sampler_->BatchSampleTokens( + std::vector sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP( probs_on_device, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); @@ -396,10 +396,15 @@ class NewRequestPrefillActionObj : public EngineActionObj { int num_running_rsentries) { ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); + // For RNN State, it can prefill as long as it can be instantiated. + if (engine_config_->kv_state_kind == KVStateKind::kRNNState) { + return true; + } + // No exceeding of the maximum allowed requests that can // run simultaneously. int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable - ? engine_config_->spec_draft_length + ? (engine_config_->spec_draft_length + 1) : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { diff --git a/cpp/serve/event_trace_recorder.h b/cpp/serve/event_trace_recorder.h index fd98cc844a..76e87ca710 100644 --- a/cpp/serve/event_trace_recorder.h +++ b/cpp/serve/event_trace_recorder.h @@ -22,7 +22,7 @@ using namespace tvm::runtime; class EventTraceRecorderObj : public Object { public: /*! - * \brief Record a event for the the input request in the trace recorder. + * \brief Record a event for the input request in the trace recorder. * \param request_id The subject request of the event. * \param event The event in a string name. * It can have one of the following patterns: diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index fa24828399..3267f1dd38 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -69,7 +69,8 @@ PackedFunc FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func, }); } -void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config) { +void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config, + Optional session) { local_gpu_device = device; Device null_device{DLDeviceType(0), 0}; int num_shards; @@ -85,29 +86,10 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object this->cached_buffers = Map(); if (num_shards > 1) { - constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool"; - if (Registry::Get(f_create_process_pool) == nullptr) { - LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. " - << "Multi-GPU inference depends on MLC LLM Python API to launch process."; - } - std::string ccl; - if (device.device_type == kDLCUDA) { - ccl = "nccl"; - } else if (device.device_type == kDLROCM) { - ccl = "rccl"; - } else { - LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type) - << " is not supported. 
Currently, only NCCL and RCCL are integrated."; - } - std::vector device_ids(num_shards); - for (int i = 0; i < num_shards; ++i) { - device_ids[i] = i; - } + this->sess = session.value(); this->use_disco = true; - this->sess = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); - this->sess->InitCCL(ccl, ShapeTuple(device_ids)); this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"), - std::move(reload_lib_path), null_device); + reload_lib_path, null_device); this->mod_get_func = [this, fmodule_get_function = sess->GetGlobalFunc("runtime.ModuleGetFunction")]( const std::string& name) -> PackedFunc { @@ -130,14 +112,23 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object this->_InitFunctions(); } else { Module executable{nullptr}; - if (false) { - // Todo(mlc-team): system lib reload // reload_lib_path starts with "system://" + PackedFunc fload_exec{nullptr}; + if (StartsWith(reload_lib_path, "system://")) { + const PackedFunc* f_load_system_lib = Registry::Get("runtime.SystemLib"); + ICHECK_NOTNULL(f_load_system_lib); + std::string system_lib_prefix = std::string(reload_lib_path).substr(9); + std::replace(system_lib_prefix.begin(), system_lib_prefix.end(), /*old=*/'-', /*new=*/'_'); + executable = (*f_load_system_lib)(system_lib_prefix + "_"); + fload_exec = executable->GetFunction("vm_load_executable"); + ICHECK(fload_exec.defined()) + << "Cannot find system lib with " << system_lib_prefix + << ", please make sure you set model_lib field consistently with the compilation "; } else { executable = tvm::runtime::Module::LoadFromFile(reload_lib_path); + fload_exec = executable->GetFunction("vm_load_executable"); + ICHECK(fload_exec.defined()) << "TVM runtime cannot find vm_load_executable"; } this->use_disco = false; - auto fload_exec = executable->GetFunction("vm_load_executable"); - ICHECK(fload_exec.defined()) << "TVM runtime cannot find vm_load_executable"; this->local_vm = fload_exec(); this->local_vm->GetFunction("vm_initialization")( static_cast(device.device_type), device.device_id, @@ -225,8 +216,8 @@ void FunctionTable::_InitFunctions() { this->verify_to_last_hidden_func_ = mod_get_func("batch_verify_to_last_hidden_states"); this->fuse_embed_hidden_func_ = mod_get_func("fuse_embed_hidden_states"); Module mod = this->use_disco ? 
this->disco_mod->DebugGetFromRemote(0) : this->local_vm; - this->get_logits_func_ = mod->GetFunction("get_logits", true); - this->batch_get_logits_func_ = mod->GetFunction("batch_get_logits", true); + this->get_logits_func_ = mod_get_func("get_logits"); + this->batch_get_logits_func_ = mod_get_func("batch_get_logits"); this->batch_select_last_hidden_func_ = mod->GetFunction("batch_select_last_hidden_states", true); this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true); this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true); @@ -235,7 +226,12 @@ void FunctionTable::_InitFunctions() { this->alloc_embedding_tensor_func_ = mod_get_func("alloc_embedding_tensor"); this->create_kv_cache_func_ = mod_get_func("create_flashinfer_paged_kv_cache"); if (!this->create_kv_cache_func_.defined()) { - this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); + PackedFunc f_create_rnn_state = mod_get_func("create_rnn_state"); + if (f_create_rnn_state.defined()) { + this->create_kv_cache_func_ = f_create_rnn_state; + } else { + this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); + } ICHECK(this->create_kv_cache_func_.defined()); } this->reset_kv_cache_func_ = get_global_func("vm.builtin.kv_state_clear"); @@ -256,6 +252,8 @@ void FunctionTable::_InitFunctions() { gpu_argsort_probs_func_ = mod->GetFunction("argsort_probs", true); gpu_sample_with_top_p_func_ = mod->GetFunction("sample_with_top_p", true); gpu_sampler_take_probs_func_ = mod->GetFunction("sampler_take_probs", true); + gpu_verify_draft_tokens_func_ = mod->GetFunction("sampler_verify_draft_tokens", true); + gpu_renormalize_by_top_p_func_ = mod->GetFunction("renormalize_by_top_p", true); } this->nd_view_func_ = get_global_func("vm.builtin.reshape"); this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of"); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index f6a156b8a3..bc2b4f21c8 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -41,7 +41,8 @@ using namespace tvm::runtime; struct FunctionTable { static PackedFunc SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name); - void Init(String reload_lib_path, Device device, picojson::object model_config); + void Init(String reload_lib_path, Device device, picojson::object model_config, + Optional session); ObjectRef LoadParams(const std::string& model_path, Device device); @@ -104,6 +105,8 @@ struct FunctionTable { PackedFunc gpu_argsort_probs_func_; PackedFunc gpu_sample_with_top_p_func_; PackedFunc gpu_sampler_take_probs_func_; + PackedFunc gpu_verify_draft_tokens_func_; + PackedFunc gpu_renormalize_by_top_p_func_; PackedFunc nd_view_func_; PackedFunc nd_get_shape_func_; PackedFunc nd_copy_embedding_to_offset_func_; diff --git a/cpp/serve/grammar/grammar_serializer.h b/cpp/serve/grammar/grammar_serializer.h index 8746b1f6ae..4ad5c2103b 100644 --- a/cpp/serve/grammar/grammar_serializer.h +++ b/cpp/serve/grammar/grammar_serializer.h @@ -77,7 +77,7 @@ class BNFGrammarPrinter : public BNFGrammarSerializer { }; /*! - * \brief Serialize the the raw representation of the BNF AST to a string with JSON format. + * \brief Serialize the raw representation of the BNF AST to a string with JSON format. * \sa BNFJSONParser::Parse for parsing the JSON string. 
* \details JSON format: * { diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index d9954f1e28..5c4ef98efe 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -469,9 +469,10 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenizer") TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenTable") .set_body([](TVMArgs args, TVMRetValue* rv) { BNFGrammar grammar = args[0]; + Array token_table_arr = args[1]; std::vector token_table; - for (int i = 1; i < args.size() - 1; ++i) { - token_table.push_back(args[i]); + for (int i = 0; i < token_table_arr.size(); ++i) { + token_table.push_back(token_table_arr[i]); } int max_rollback_steps = args[args.size() - 1]; auto init_ctx = GrammarStateMatcher::CreateInitContext(grammar, token_table); diff --git a/cpp/serve/grammar/json_schema_converter.cc b/cpp/serve/grammar/json_schema_converter.cc index 93d693f3c6..83be710cf5 100644 --- a/cpp/serve/grammar/json_schema_converter.cc +++ b/cpp/serve/grammar/json_schema_converter.cc @@ -23,6 +23,14 @@ namespace serve { using namespace tvm::runtime; +// EMCC somehow cannot pickup operator overload from picojson.h, so we copy here. +#ifdef COMPILE_MLC_WASM_RUNTIME +inline std::ostream& operator<<(std::ostream& os, const picojson::value& x) { + x.serialize(std::ostream_iterator(os)); + return os; +} +#endif + /*! * \brief Manage the indent and separator for the generation of EBNF grammar. * \param indent The number of spaces for each indent. If it is std::nullopt, there will be no diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 17121d8e28..6f34220219 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -13,6 +13,7 @@ #include +#include "config.h" #include "logit_processor.h" namespace mlc { @@ -25,10 +26,27 @@ class ModelImpl; TVM_REGISTER_OBJECT_TYPE(ModelObj); -Model Model::Create(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled) { - return Model( - make_object(reload_lib_path, model_path, device, max_num_sequence, trace_enabled)); +Model Model::Create(String reload_lib_path, String model_path, const picojson::object& model_config, + DLDevice device, int max_num_sequence, const Optional& session, + bool trace_enabled) { + return Model(make_object(reload_lib_path, model_path, model_config, device, + max_num_sequence, session, trace_enabled)); +} + +picojson::object Model::LoadModelConfig(const String& model_path) { + picojson::object model_config; + std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str()); + std::ostringstream config_ostream; + ICHECK(config_istream); + config_ostream << config_istream.rdbuf(); + std::string config_str = config_ostream.str(); + picojson::value config_json; + std::string err = picojson::parse(config_json, config_str); + if (!err.empty()) { + LOG(FATAL) << err; + } + picojson::object config = config_json.get(); + return config; } class ModelImpl : public ModelObj { @@ -37,23 +55,16 @@ class ModelImpl : public ModelObj { * \brief Constructor of ModelImpl. * \sa Model::Create */ - explicit ModelImpl(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled) + explicit ModelImpl(String reload_lib_path, String model_path, picojson::object model_config, + DLDevice device, int max_num_sequence, const Optional& session, + bool trace_enabled) : device_(device) { // Step 1. Process model config json string. 
- picojson::object model_config; - { - std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str()); - std::ostringstream config_ostream; - ICHECK(config_istream); - config_ostream << config_istream.rdbuf(); - std::string config_str = config_ostream.str(); - model_config = LoadModelConfigJSON(config_str); - } + LoadModelConfigJSON(model_config); // Step 2. Initialize vm, we use the packed function mechanism // so there is no explicit abi dependency on these extra // classes other than basic tvm runtime. - this->ft_.Init(reload_lib_path, device_, model_config); + this->ft_.Init(reload_lib_path, device_, model_config, session); // Step 3. Load params in nd-array cache. this->params_ = ft_.LoadParams(model_path, device_); // Step 4. Set max_num_sequence @@ -68,6 +79,12 @@ class ModelImpl : public ModelObj { token_ids_storage_ = memory::Storage( allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32)), allocator); this->logit_pos_arr_ = NDArray::Empty({max_num_sequence}, DataType::Int(32), device_host); + // Step 7. Set model type + if (model_config["model_type"].get().find("rwkv") != std::string::npos) { + this->kind = KVStateKind::kRNNState; + } else { + this->kind = KVStateKind::kAttention; + } } /*********************** Model Computation ***********************/ @@ -136,16 +153,23 @@ class ModelImpl : public ModelObj { ICHECK_EQ(hidden_states->device.device_type, device_.device_type); ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - hidden_states_dref_or_nd = + hidden_states = hidden_states.CreateView({batch_size * seq_len, hidden_size_}, hidden_states->dtype); + // This copy can be avoided by not copying the hidden states to engine. + hidden_states_dref_or_nd = ft_.CopyToWorker0( + hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); ObjectRef ret = ft_.get_logits_func_(hidden_states_dref_or_nd, params_); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } - NDArray logits; - logits = Downcast(ret); + NDArray logits{nullptr}; + if (ret->IsInstance()) { + logits = Downcast(ret)->DebugGetFromRemote(0); + } else { + logits = Downcast(ret); + } CHECK(logits.defined()); // logits: (b * s, v) ICHECK_EQ(logits->ndim, 2); @@ -185,8 +209,11 @@ class ModelImpl : public ModelObj { ICHECK_EQ(hidden_states->device.device_type, device_.device_type); ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - hidden_states_dref_or_nd = - hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); + hidden_states = hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); + + // This copy can be avoided by not copying the hidden states to engine. + hidden_states_dref_or_nd = ft_.CopyToWorker0( + hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); ObjectRef ret = ft_.batch_get_logits_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_); @@ -218,8 +245,15 @@ class ModelImpl : public ModelObj { p_logit_pos[i] = total_length - 1; } NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); + + // This step runs on the engine thread. + // By temporarily turning off the disco flag, this copies the logit_pos_nd to the cached device + // tensor without actually copying to the worker. 
+ bool use_disco = ft_.use_disco; + ft_.use_disco = false; ObjectRef logit_pos_dref_or_nd = ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); + ft_.use_disco = use_disco; CHECK(ft_.batch_select_last_hidden_func_.defined()) << "`batch_select_last_hidden_states` function is not found in the model."; @@ -240,7 +274,7 @@ class ModelImpl : public ModelObj { hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); ObjectRef ret = - ft_.batch_select_last_hidden_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_); + ft_.batch_select_last_hidden_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } @@ -265,10 +299,17 @@ class ModelImpl : public ModelObj { // No ICHECK_EQ(hidden->shape[0], hidden_size_) here to allow different hidden_sizes. hidden = hidden.CreateView({1, hidden_size_}, hidden->dtype); // Reuse the copy embedding function - ft_.nd_copy_embedding_to_offset_func_(hidden, *dst, cum_length); + ObjectRef hidden_dref_or_nd = + ft_.CopyToWorker0(hidden, "hidden_for_concat", {1, hidden_size_}); + ft_.nd_copy_embedding_to_offset_func_(hidden_dref_or_nd, *dst, cum_length); cum_length += 1; } - NDArray ret = Downcast(*dst); + NDArray ret{nullptr}; + if ((*dst)->IsInstance()) { + ret = Downcast(*dst)->DebugGetFromRemote(0); + } else { + ret = Downcast(*dst); + } ret = ret.CreateView({cum_length, hidden_size_}, hidden_states[0]->dtype); return ret; } @@ -295,7 +336,7 @@ class ModelImpl : public ModelObj { return embeddings_nd.CreateView({batch_size, seq_len, hidden_size_}, embeddings_nd->dtype); } } else { - ShapeTuple embedding_shape{batch_size, seq_len, hidden_size_}; + ShapeTuple embedding_shape{batch_size * seq_len, hidden_size_}; embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) { @@ -715,16 +756,26 @@ class ModelImpl : public ModelObj { /*********************** KV Cache Management ***********************/ void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size) final { - IntTuple max_num_sequence_tuple{max_num_sequence}; - IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; - IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; - IntTuple page_size_tuple{page_size}; - IntTuple support_sliding_window{sliding_window_size_ != -1}; - kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple, - prefill_chunk_size_tuple, page_size_tuple, - support_sliding_window); - local_kv_cache_ = ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + int prefill_chunk_size, int max_history_size, + KVStateKind kv_state_kind) final { + if (kv_state_kind == KVStateKind::kAttention) { + IntTuple max_num_sequence_tuple{max_num_sequence}; + IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; + IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; + IntTuple page_size_tuple{page_size}; + IntTuple support_sliding_window{sliding_window_size_ != -1}; + kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple, + prefill_chunk_size_tuple, page_size_tuple, + support_sliding_window); + local_kv_cache_ = + ft_.use_disco ? 
Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + } else { + IntTuple max_num_sequence_tuple{max_num_sequence}; + IntTuple max_history_size_tuple = {std::max(max_history_size, 1)}; + kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_history_size_tuple); + local_kv_cache_ = + ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + } } void AddNewSequence(int64_t seq_id) final { ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id); } @@ -751,11 +802,21 @@ class ModelImpl : public ModelObj { /************** Raw Info Query **************/ int GetNumAvailablePages() const final { - return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_); + if (this->kind == KVStateKind::kRNNState) { + // RNNState does not introduce new page at runtime + return std::numeric_limits::max(); + } else { + return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_); + } } int GetCurrentTotalSequenceLength() const final { - return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_); + if (this->kind == KVStateKind::kRNNState) { + // RNNState does not have a total sequence length limit + return 0; + } else { + return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_); + } } /*********************** Utilities ***********************/ @@ -768,9 +829,7 @@ class ModelImpl : public ModelObj { Sampler CreateSampler(int max_num_sample, int num_models, Optional trace_recorder) { - if (num_models > 1) { // speculative decoding uses cpu sampler - return Sampler::CreateCPUSampler(std::move(trace_recorder)); - } else if (Sampler::SupportGPUSampler(device_)) { + if (Sampler::SupportGPUSampler(device_)) { return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_, std::move(trace_recorder)); } else { @@ -842,15 +901,7 @@ class ModelImpl : public ModelObj { private: /*! \brief Load model configuration from JSON. */ - picojson::object LoadModelConfigJSON(const std::string& config_str) { - picojson::value config_json; - std::string err = picojson::parse(config_json, config_str); - if (!err.empty()) { - LOG(FATAL) << err; - } - - // Get json fields. - picojson::object config = config_json.get(); + picojson::object LoadModelConfigJSON(picojson::object config) { if (config.count("context_window_size")) { CHECK(config["context_window_size"].is()); this->max_window_size_ = config["context_window_size"].get(); @@ -924,6 +975,8 @@ class ModelImpl : public ModelObj { NDArray logit_pos_arr_{nullptr}; // A boolean indicating if tracing is enabled. bool trace_enabled_; + // An enum indicating whether it's RNN-based. + KVStateKind kind; }; TVM_REGISTER_GLOBAL("mlc.copy_embedding_to_offset") diff --git a/cpp/serve/model.h b/cpp/serve/model.h index da532f83e8..bc63840a74 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -234,9 +234,13 @@ class ModelObj : public Object { * in the engine. * \param prefill_chunk_size The maximum total number of tokens whose KV data * are allowed to exist in the KV cache at any time. + * \param max_history_size The maximum history size for RNN state to roll back. + * The KV cache does not need this. + * \param kv_state_kind The kind of cache. It can be KV cache or RNN state. */ virtual void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size) = 0; + int prefill_chunk_size, int max_history_size, + KVStateKind kv_state_kind) = 0; /*! \brief Add a new sequence with the given sequence id to the KV cache. 
*/ virtual void AddNewSequence(int64_t seq_id) = 0; @@ -315,13 +319,24 @@ class Model : public ObjectRef { * \brief Create the runtime module for LLM functions. * \param reload_lib_path The model library path. * \param model_path The path to the model weight parameters. + * \param model_config The model config json object. * \param device The device to run the model on. * \param max_num_sequence The maximum number of sequences to be processed + * \param session The session to run the model on. * \param trace_enabled A boolean indicating whether tracing is enabled. * \return The created runtime module. */ - TVM_DLL static Model Create(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled); + TVM_DLL static Model Create(String reload_lib_path, String model_path, + const picojson::object& model_config, DLDevice device, + int max_num_sequence, const Optional& session, + bool trace_enabled); + + /*! + * Load the model config from the given model path. + * \param model_path The path to the model weight parameters. + * \return The model config json object. + */ + static picojson::object LoadModelConfig(const String& model_path); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj); }; diff --git a/cpp/serve/radix_tree.cc b/cpp/serve/radix_tree.cc new file mode 100644 index 0000000000..5d5c311593 --- /dev/null +++ b/cpp/serve/radix_tree.cc @@ -0,0 +1,718 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/radix_tree.cc + */ +#include "radix_tree.h" + +#include + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! + * \brief The sequence ID linked list structure in paged radix tree node. + */ +struct SequenceIDNode { + /*! \brief The stored sequence ID. */ + int64_t id = 0; + /*! \brief The pointer to the next sequence ID. */ + SequenceIDNode* next = nullptr; +}; + +/*! + * \brief The sequence Id node pool. + * + * The sequence Id node pool allocates all sequence ID nodes when construction and frees when + * destruction, to avoid frequent memory operation. + */ +class SequenceIDNodePool { + public: + /*! \brief The constructor of sequence Id node pool, allocating memory for each node. */ + SequenceIDNodePool(size_t num_nodes) : num_nodes_(num_nodes) { + nodes_.reserve(num_nodes); + free_node_indicess_.reserve(num_nodes); + used_nodes_.clear(); + raw_pool_ = new SequenceIDNode[num_nodes_]; + for (size_t i = 0; i < num_nodes; ++i) { + nodes_.push_back(&raw_pool_[i]); + free_node_indicess_.push_back(i); + } + } + + /*! + * \brief Get a radix page from pool, and assign the fields. + * \param seq_id The assigned sequence ID of allocated sequence ID node. + * \param node The next sequence ID node pointer of allocated sequence ID node. + * \return The allocated radix page. + * \throw Error if no free radix page available in pool. + */ + SequenceIDNode* Allocate(int64_t seq_id, SequenceIDNode* next) { + CHECK(!free_node_indicess_.empty()) << "Sequence ID node pool has no free sequence ID nodes."; + size_t id = free_node_indicess_.back(); + free_node_indicess_.pop_back(); + SequenceIDNode* node = nodes_[id]; + used_nodes_[node] = id; + node->id = seq_id; + node->next = next; + return node; + } + + /*! + * \brief Free a sequence ID node to pool. + * \param node The sequence ID node to free. + */ + void Free(SequenceIDNode* node) { + CHECK(used_nodes_.find(node) != used_nodes_.end()); + free_node_indicess_.push_back(used_nodes_[node]); + used_nodes_.erase(node); + } + + /*! 
\brief The destructor of sequence Id node pool, freeing memory for each node. */ + ~SequenceIDNodePool() { delete[] raw_pool_; } + + private: + /*! \brief The number of nodes in sequence ID node pool. */ + size_t num_nodes_; + /*! \brief The raw sequence ID node pool. */ + SequenceIDNode* raw_pool_; + /*! \brief The sequence ID node pool. */ + std::vector nodes_; + /*! \brief The indices of free sequence ID node in node pool. */ + std::vector free_node_indicess_; + /*! \brief The map from used paged sequence ID node to its index in node pool. */ + std::unordered_map used_nodes_; +}; + +/*! + * \brief The paged radix tree node data structure. + * + * The paged radix tree node is similar to original radix tree node, but with the limited length for + * prefix in page, so that the memory usage in each page is the same and is fixed once allocated. + * Since the page only consists of pointers and int tokens, the page memory layout is int array + * indeed. The lower offset is the pointers and page information, while the higher offset is the + * stored prefix tokens. + * + * And since the vocabulary size may be very large, the paged Radix tree is represented + * as left-child, right-sibling binary tree. + * + * Also, due to possible pop/push front/back tokens in page, the page is designed as circular + * buffer, to make full use of each page. + * + * Each page records the sequence excatly ends with the prefix tokens stored in page. In other word, + * all sequences locate in the boundary of each page, or the end of each page. + */ +struct RedixPage { + /*! \brief The parent page. */ + RedixPage* parent; + /*! \brief The first child page. */ + RedixPage* first_child; + /*! \brief The sibling page shareing the same parent page. */ + RedixPage* next_sibiling; + /*! \brief The head of sequence ID linked list. */ + SequenceIDNode* seq_ids; + /*! \brief The capacity of maximum stored prefix tokens. */ + size_t capacity; + /*! \brief The start offset of stored prefix tokens. The legal value is of [0, capacity). */ + size_t offset; + /*! \brief The length of stored prefix tokens. The legal value is of [0, capacity). */ + size_t length; + /*! \brief The offset of first prefix token in memory layout. */ + static constexpr int DATA_OFFSET = (sizeof(RedixPage*) * 3 + sizeof(SequenceIDNode*) + + sizeof(size_t) * 3 + sizeof(int32_t) - 1) / + sizeof(int32_t); + + /*! + * \brief Overload opeartor [] to get the prefix tokens by index as simple int array. + * \param i The prefix token index. + * \return The value of i-th prefix token. + */ + int32_t& operator[](size_t i) { + return reinterpret_cast(this)[DATA_OFFSET + (i + offset) % capacity]; + } + + /*! + * \brief Extend or push back a suffix tokens in page. + * \param suffix The suffix tokens array. + * \param suffix_length The suffix length to extend. + * \throw Error if suffix length is larger than current vacant space. + */ + void Extend(const int64_t* suffix, size_t suffix_length) { + CHECK_LE(suffix_length + length, capacity); + for (int i = 0; i < suffix_length; ++i) { + (*this)[i + length] = (int32_t)suffix[i]; + } + length += suffix_length; + } + + /*! + * \brief Add a sequence ID in page. + * \param pool The sequence ID node pool to allocate new node. + * \param id The sequence ID to add. + */ + void AddSequence(SequenceIDNodePool* pool, int64_t id) { seq_ids = pool->Allocate(id, seq_ids); } + + /*! + * \brief Pop a sequence ID in page. + * \param pool The sequence ID node pool to free popped node. + * \param id The sequence ID to pop. 
+ * \throw Error if no such sequence ID in page. + */ + void PopSequence(SequenceIDNodePool* pool, int64_t id) { + if (seq_ids->id == id) { + // If the popped sequencs ID is the first node in linked list, + // directly skip from head and free it. + SequenceIDNode* next = seq_ids->next; + pool->Free(seq_ids); + seq_ids = next; + } else { + // If the popped sequencs ID is not the first node in linked list, + // skip it from previous node and free it. + SequenceIDNode* last = seq_ids; + SequenceIDNode* cur = seq_ids->next; + while (cur) { + if (cur->id == id) { + last->next = cur->next; + pool->Free(cur); + return; + } + } + LOG(FATAL) << "Sequence ID = " << id << " not found."; + } + } + + /*! + * \brief Get all sequence ID in page. + * \return The std::vector of sequence ID in page. + */ + std::vector GetLocalSequence() { + std::vector output; + for (SequenceIDNode* node = seq_ids; node; node = node->next) { + output.push_back(node->id); + } + return output; + } + + /*! + * \brief Get any sequence ID in current page or child pages. + * Since there is always a sequence in leaf pages, it only check first child if no sequence ID in + * current page. + * \return The any sequence ID in current page or child pages. + */ + int32_t FindAnyChildSequence() { + if (seq_ids) return seq_ids->id; + return first_child->FindAnyChildSequence(); + } + + /*! + * \brief Get all sequence ID in current page and child pages, using Iterate method with lambda + * expression as callback to avoid frequently memory allocation of std::vector. + * \return The std::vector of all sequence ID in current page and child pages. + */ + std::vector FindAllChildSequence() { + std::vector output = GetLocalSequence(); + if (first_child) { + first_child->Iterate([&output](const RedixPage* page) { + for (SequenceIDNode* node = page->seq_ids; node; node = node->next) { + output.push_back(node->id); + } + }); + } + return output; + } + + /*! + * \brief The iteration method for tree or sub-tree traverse. + * \param f The callback function to invoke at each radix page visited. + */ + template + void Iterate(CallbackFunc f) { + f(this); + if (next_sibiling) next_sibiling->Iterate(f); + if (first_child) first_child->Iterate(f); + } + + /*! + * \brief Get the last sibling of current page. + * \return The page whose next_sibling is current page, or nullptr if current is the fisrt_child + * of its parent page. + */ + RedixPage* GetLastSibling() { + if (parent == nullptr) return nullptr; + if (parent->first_child == this) return nullptr; + for (RedixPage* child = parent->first_child; child; child = child->next_sibiling) { + if (child->next_sibiling == this) return child; + } + return nullptr; + } + + /*! + * \brief Find the child indexed by first token. + * \return The child page started with first token, or nullptr if no such child page. + */ + RedixPage* FindChild(int64_t first_token) { + int32_t casted = first_token; + // Iterate all child radix pages, as the child radix pages are stored unorderly. + for (RedixPage* child = first_child; child; child = child->next_sibiling) { + if ((*child)[0] == casted) return child; + } + return nullptr; + } + + /*! \brief Insert a new child page. */ + void InsertChild(RedixPage* child) { + child->parent = this; + child->next_sibiling = first_child; + first_child = child; + } + + /*! + * \brief Remove a child page. + * \throw Error if page to be removed is not child page. 
+ */ + void RemoveChild(RedixPage* child) { + CHECK(child->parent == this); + if (first_child == child) { + first_child = child->next_sibiling; + } else { + child->GetLastSibling()->next_sibiling = child->next_sibiling; + } + } + + /*! + * \brief Check current page is mergable with its child page. + * The page is mergable if and only if + * 1. No sequence ID in current page, as sequence ID is not allowed to exist within page. + * 2. The current page has child page. + * 3. The current page has only one child page. + * 4. The current page perfix and the child page prefix can be concatenated into one page. + * \return True if current page is mergable, or false. + */ + bool Mergeable() { + if (seq_ids) return false; + if (!first_child) return false; + if (first_child->next_sibiling) return false; + if (length + first_child->length > capacity) return false; + return true; + } + + /*! + * \brief Match the given prefix within page. + * \param prefix The prefix token array. + * \param prefix_length The length of prefix token array. + * \return The matched prefix offset within page, or the first mismatched token position. The + * possible return value is [0, page->length], where page->length means the page is completely the + * prefix of given prefix. + */ + size_t MatchPrefix(const int64_t* prefix, size_t prefix_length) { + size_t n = std::min(length, prefix_length); + for (int i = 0; i < n; ++i) { + if ((*this)[i] != prefix[i]) return i; + } + return n; + } +}; + +/*! + * \brief The paged radix tree page pool. + * + * The paged radix tree page pool allocates all radix tree pages when construction and frees when + * destruction, to avoid frequent memory operation. + */ +class RadixPagePool { + public: + /*! \brief The constructor of paged radix tree page pool, allocating memory for each page. */ + RadixPagePool(size_t page_size, size_t num_pages) : page_size_(page_size), num_pages_(num_pages) { + pages_.reserve(num_pages); + free_page_indices_.reserve(num_pages); + raw_pool_ = new int32_t[num_pages * page_size / sizeof(int32_t)]; + int32_t num_int = page_size / sizeof(int32_t); + for (size_t i = 0; i < num_pages; ++i) { + pages_.push_back(reinterpret_cast(raw_pool_ + i * num_int)); + free_page_indices_.push_back(i); + } + } + + /*! + * \brief Get a radix page from pool. + * \return The allocated radix page. + * \throw Error if no free radix page available in pool. + */ + RedixPage* Allocate() { + CHECK(!free_page_indices_.empty()) << "Radix page pool has no free radix tree pages."; + int id = free_page_indices_.back(); + free_page_indices_.pop_back(); + RedixPage* page = pages_[id]; + used_pages_[page] = id; + page->parent = page->first_child = page->next_sibiling = nullptr; + page->capacity = page_size_ / sizeof(int32_t) - RedixPage::DATA_OFFSET; + page->offset = page->length = 0; + page->seq_ids = nullptr; + return page; + } + + /*! + * \brief Free a radix page to pool. + * \param page The radix page to free. + */ + void Free(RedixPage* page) { + CHECK_EQ(page->seq_ids, nullptr); + CHECK(used_pages_.find(page) != used_pages_.end()); + free_page_indices_.push_back(used_pages_[page]); + CHECK(used_pages_.erase(page)); + } + + /*! + * \brief Get the token capacity of free pages. + * \return The the token capacity of free pages. + */ + size_t FreeCapacity() { + return free_page_indices_.size() * (page_size_ / sizeof(int32_t) - RedixPage::DATA_OFFSET); + } + + /*! \brief The destructor of paged radix tree page pool, freeing memory for each page. 
*/ + ~RadixPagePool() { delete[] raw_pool_; } + + private: + /*! \brief The page size of each paged radix tree page. */ + size_t page_size_; + /*! \brief The number of pages in paged radix tree page pool. */ + size_t num_pages_; + /*! \brief The raw paged radix tree page pool. */ + int32_t* raw_pool_; + /*! \brief The paged radix tree page pool. */ + std::vector pages_; + /*! \brief The indices of free paged radix page in page pool. */ + std::vector free_page_indices_; + /*! \brief The map from used paged radix tree page to its index in page pool. */ + std::unordered_map used_pages_; +}; + +// PagedRadixTree + +/*! + * \brief The paged radix tree data structure. + */ +class PagedRadixTreeImpl : public PagedRadixTreeObj { + public: + /*! \brief The page size of each paged radix tree node. */ + size_t page_size; + /*! \brief The number of pages in paged radix tree page pool. */ + size_t num_pages; + /*! \brief The maximum number of sequence ID in paged radix tree page pool. */ + size_t num_seqs; + /*! \brief The map from sequence to paged radix tree node it is stored. */ + std::unordered_map seq2page; + /*! \brief The sequence ID node pool. */ + SequenceIDNodePool* seq_id_node_pool = nullptr; + /*! \brief The radix page pool. */ + RadixPagePool* radix_page_pool = nullptr; + /*! \brief The root page of paged radix tree. */ + RedixPage* root = nullptr; + + explicit PagedRadixTreeImpl(size_t num_pages, size_t page_size, size_t num_seqs) { + num_pages = num_pages; + page_size = page_size; + num_seqs = num_seqs; + + seq_id_node_pool = new SequenceIDNodePool(num_seqs); + radix_page_pool = new RadixPagePool(page_size, num_pages); + + root = reinterpret_cast(new int32_t[RedixPage::DATA_OFFSET]); + root->parent = root->first_child = root->next_sibiling = nullptr; + root->offset = root->length = root->capacity = 0; + root->seq_ids = nullptr; + } + + /*! + * \brief Get a sequence's all tokens. + * \param seq_id The sequence ID for index. + * \return The sequence tokens. + * \throw Error if sequence ID is not valid. + */ + IntTuple GetSequence(int64_t seq_id) { + CHECK(seq2page.find(seq_id) != seq2page.end()); + size_t length = GetSequenceLength(seq_id); + std::vector output(length); + size_t offset = length; + for (RedixPage* page = seq2page[seq_id]; page; page = page->parent) { + offset -= page->length; + for (int i = 0; i < page->length; ++i) { + output[offset + i] = (*page)[i]; + } + } + return IntTuple(output); + } + + /*! + * \brief Get all sequences with longest common prefix with give prefix tokens. + * \param tokens The prefix tokens for reference. + * \return The pair of matched prefix length and the array of matched sequences indices. + */ + std::pair> MatchPrefix(IntTuple tokens) { + const int64_t* prefix = tokens.data(); + size_t length = tokens.size(); + auto [page, offset, in_page_offset] = MatchSequence(root, prefix, length); + if (!offset) return std::make_pair(0, std::vector()); + return std::make_pair(offset, page->FindAllChildSequence()); + } + + /*! + * \brief Get a sequence's length. + * \param seq_id The sequence ID for index. + * \return The sequence length. + * \throw Error if sequence ID is not valid. + */ + size_t GetSequenceLength(int64_t seq_id) { + CHECK(seq2page.find(seq_id) != seq2page.end()); + size_t length = 0; + for (RedixPage* page = seq2page[seq_id]; page; page = page->parent) { + length += page->length; + } + return length; + } + + /*! + * \brief Fork a sequence from parent sequence at given position. + * \param seq_id The new sequence ID. 
+ * \param parent_seq_id The parent sequence ID to fork from. + * \param forked_offset The position of parent sequence to fork at. + * The valid value is [1, length of forked sequence]. If the position equals the length of forked + * sequence, the new sequence will copy the entire forked sequence. + * \throw Error if sequence ID or + * forked postion is not valid. + */ + void ForkSequence(int64_t seq_id, int64_t parent_seq_id, size_t forked_offset) { + CHECK(seq2page.find(seq_id) == seq2page.end()); + CHECK(seq2page.find(parent_seq_id) != seq2page.end()); + CHECK_GT(forked_offset, 0); + size_t length = GetSequenceLength(parent_seq_id); + CHECK_LE(forked_offset, length); + for (RedixPage* page = seq2page[parent_seq_id]; page; page = page->parent) { + if (forked_offset >= length - page->length) { + if (forked_offset < length) { + // Split radix page if forked position is within page + page = SplitPage(page, forked_offset + page->length - length); + } + page->AddSequence(seq_id_node_pool, seq_id); + seq2page[seq_id] = page; + return; + } + length -= page->length; + } + } + + /*! + * \brief Add an empty sequence at root. + * \param seq_id The new sequence ID. + * \throw Error if sequence ID is not valid. + */ + void AddSequence(int64_t seq_id) { + CHECK(seq2page.find(seq_id) == seq2page.end()); + root->AddSequence(seq_id_node_pool, seq_id); + seq2page[seq_id] = root; + } + + /*! + * \brief Extend a sequence with given tokens. + * \param seq_id The sequence ID for index. + * \param tokens The given tokens to extend. + * \throw Error if sequence ID is not valid. + */ + void ExtendSequence(int64_t seq_id, IntTuple tokens) { + CHECK(seq2page.find(seq_id) != seq2page.end()); + const int64_t* suffix = tokens.data(); + size_t length = tokens.size(); + RedixPage* original_page = seq2page[seq_id]; + original_page->PopSequence(seq_id_node_pool, seq_id); + auto [page, offset, in_page_offset] = MatchSequence(original_page, suffix, length); + if (in_page_offset < page->length) { + // Split page if extended sequence mismatches within page + page = SplitPage(page, in_page_offset); + } + if (offset < length && !page->seq_ids && !page->first_child && page->capacity > page->length) { + // Extend in the existing leaf page first if possible. + size_t suffix_length = std::min(page->capacity - page->length, length - offset); + page->Extend(suffix + offset, suffix_length); + offset += suffix_length; + } + while (offset < length) { + // Allocate new radix page and extend tokens + RedixPage* new_page = radix_page_pool->Allocate(); + page->InsertChild(new_page); + page = new_page; + size_t suffix_length = std::min(page->capacity - page->length, length - offset); + page->Extend(suffix + offset, suffix_length); + offset += suffix_length; + } + page->AddSequence(seq_id_node_pool, seq_id); + seq2page[seq_id] = page; + if (original_page->Mergeable()) { + // The original page may be mergeable, as the sequence ID changes + MergePage(original_page); + } + } + + /*! + * \brief Remove a sequence. + * \param seq_id The sequence ID to remove. + * \throw Error if sequence ID is not valid. 
+ */ + void RemoveSequence(int64_t seq_id) { + RedixPage* page = seq2page[seq_id]; + page->PopSequence(seq_id_node_pool, seq_id); + seq2page.erase(seq_id); + while (page->parent && !page->seq_ids && !page->first_child) { + RedixPage* parent = page->parent; + parent->RemoveChild(page); + radix_page_pool->Free(page); + page = parent; + } + if (page && page->Mergeable()) { + // The remaining page may be mergeable, as the sequence ID changes + MergePage(page); + } + } + + /*! + * \brief Get the remaining token capacity of the paged radix tree. + * \return The the remaining token capacity of the paged radix tree. + */ + size_t FreeCapacity() { return radix_page_pool->FreeCapacity(); } + + /*! \brief The destructor to free root page. */ + ~PagedRadixTreeImpl() { + delete[] reinterpret_cast(root); + delete seq_id_node_pool; + delete radix_page_pool; + } + + private: + /*! + * \brief Merge a radix tree page with its child radix tree page, to save radix tree page. + * e.g. MergePage([1, 2, _, _, _] -> [3, 4, 5, _, _]) = [1, 2, 3, 4, 5]. + * And the page to be merged should be page->Mergeable(). + * \param page The parent radix tree page. + */ + void MergePage(RedixPage* page) { + CHECK(page->Mergeable()); + RedixPage* child = page->first_child; + for (int i = 0; i < child->length; ++i) { + (*page)[i + page->length] = (*child)[i]; + } + page->length += child->length; + page->first_child = child->first_child; + for (RedixPage* p = child->first_child; p; p = p->next_sibiling) { + p->parent = page; + } + page->seq_ids = child->seq_ids; + std::vector seq_ids = page->GetLocalSequence(); + for (int64_t id : seq_ids) seq2page[id] = page; + child->seq_ids = nullptr; + radix_page_pool->Free(child); + } + + /*! + * \brief Split a radix tree page at given postition, to accept new sequence. + * e.g. SplitPage([1, 2, 3, 4, 5], 2) = [1, 2, _, _, _] -> [3, 4, 5, _, _]. + * \param page The radix tree page to split. + * \param offset The position to split the radix tree page. + * \return The splitted radix tree page. It can be different from the input radix tree page, as + * there may be implicit radix tree page merge. + */ + RedixPage* SplitPage(RedixPage* page, size_t offset) { + CHECK_LT(offset, page->length); + RedixPage* child = radix_page_pool->Allocate(); + child->parent = page; + child->first_child = page->first_child; + for (RedixPage* p = page->first_child; p; p = p->next_sibiling) { + p->parent = child; + } + page->first_child = child; + for (int i = offset; i < page->length; ++i) { + (*child)[i - offset] = (*page)[i]; + } + child->length = page->length - offset; + page->length = offset; + if (child->Mergeable()) { + // The child page may be mergeable + MergePage(child); + } + if (page->parent && page->parent->Mergeable()) { + // The parent page may be mergeable + page = page->parent; + MergePage(page); + } + return page; + } + + /*! + * \brief Match with given token from a radix tree page, stopping at first mismatch. + * \param page The radix tree page to start matching. + * \param tokens The given tokens to match. + * \param length The length of given tokens. 
+ */ + std::tuple MatchSequence(RedixPage* page, const int64_t* tokens, + size_t length) { + size_t offset = 0; + while (offset < length) { + if (RedixPage* child = page->FindChild(tokens[offset])) { + // If child page starts with offset-th token, common prefix at least ends with child page + size_t matched_offset = child->MatchPrefix(tokens + offset, length - offset); + offset += matched_offset; + if (matched_offset < child->length) { + // Common prefix ends within child page + return std::make_tuple(child, offset, matched_offset); + } + page = child; + } else { + // No child page starts with offset-th token, common prefix ends with current page + return std::make_tuple(page, offset, page->length); + } + } + return std::make_tuple(page, length, page->length); + } +}; + +TVM_REGISTER_OBJECT_TYPE(PagedRadixTreeImpl); + +PagedRadixTree::PagedRadixTree(size_t num_pages, size_t page_size, size_t num_seqs) { + data_ = std::move(make_object(num_pages, page_size, num_pages)); +} + +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTree") + .set_body_typed([](uint64_t num_pages, uint64_t page_size, uint64_t num_seqs) { + return PagedRadixTree(num_pages, page_size, num_seqs); + }); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeMatchPrefix") + .set_body_typed([](PagedRadixTree paged_radix_tree, IntTuple tokens) { + auto [offset, seq_ids] = paged_radix_tree->MatchPrefix(tokens); + seq_ids.insert(seq_ids.begin(), offset); + return IntTuple(seq_ids); + }); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeExtendSequence") + .set_body_method(&PagedRadixTreeObj::ExtendSequence); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeForkSequence") + .set_body_typed([](PagedRadixTree paged_radix_tree, int64_t seq_id, int64_t parent_seq_id, + uint64_t forked_offset) { + paged_radix_tree->ForkSequence(seq_id, parent_seq_id, forked_offset); + }); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeAddSequence") + .set_body_method(&PagedRadixTreeObj::AddSequence); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeRemoveSequence") + .set_body_method(&PagedRadixTreeObj::RemoveSequence); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeGetSequence") + .set_body_method(&PagedRadixTreeObj::GetSequence); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeGetSequenceLength") + .set_body_typed([](PagedRadixTree paged_radix_tree, int64_t seq_id) { + return (int64_t)paged_radix_tree->GetSequenceLength(seq_id); + }); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeFreeCapacity") + .set_body_typed([](PagedRadixTree paged_radix_tree) { + return (int64_t)paged_radix_tree->FreeCapacity(); + }); +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/radix_tree.h b/cpp/serve/radix_tree.h new file mode 100644 index 0000000000..ed831c17b1 --- /dev/null +++ b/cpp/serve/radix_tree.h @@ -0,0 +1,110 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/radix_tree.h + */ +#ifndef MLC_LLM_SERVE_RADIX_TREE_H_ +#define MLC_LLM_SERVE_RADIX_TREE_H_ +#include +#include + +#include +#include + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! + * \brief The paged radix tree data structure. + */ +class PagedRadixTreeObj : public Object { + public: + /*! + * \brief Get a sequence's all tokens. + * \param seq_id The sequence ID for index. + * \return The sequence tokens. + * \throw Error if sequence ID is not valid. + */ + virtual IntTuple GetSequence(int64_t seq_id) = 0; + + /*! + * \brief Get all sequences with longest common prefix with give prefix tokens. 
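
Editor's note: a hedged usage sketch of the `PagedRadixTree` interface added by this patch. It assumes the TVM runtime headers and `cpp/serve/radix_tree.h` are on the include path; the sizes and token values are illustrative only and not taken from the serving code.

```
#include <cstdint>
#include <vector>

#include "radix_tree.h"  // cpp/serve/radix_tree.h from this patch

using mlc::llm::serve::PagedRadixTree;
using tvm::runtime::IntTuple;

void PagedRadixTreeExample() {
  // 16 pages of 256 bytes each, with room for up to 16 sequence IDs.
  PagedRadixTree tree(/*num_pages=*/16, /*page_size=*/256, /*num_seqs=*/16);

  // Register sequence 0 and record its committed tokens.
  tree->AddSequence(0);
  tree->ExtendSequence(0, IntTuple(std::vector<int64_t>{1, 2, 3, 4, 5}));

  // Fork sequence 1 from the first three tokens of sequence 0; the shared
  // prefix [1, 2, 3] is stored only once in the tree.
  tree->ForkSequence(/*seq_id=*/1, /*parent_seq_id=*/0, /*forked_offset=*/3);
  tree->ExtendSequence(1, IntTuple(std::vector<int64_t>{6, 7}));

  // Prefix matching returns the matched length and the sequences that share
  // the matched prefix. Here `matched` would be 4 and `seq_ids` would hold 1.
  auto [matched, seq_ids] = tree->MatchPrefix(IntTuple(std::vector<int64_t>{1, 2, 3, 6}));

  tree->RemoveSequence(1);
  tree->RemoveSequence(0);
}
```

The same operations are also exposed through the `mlc.serve.PagedRadixTree*` global functions registered above, which is how the Python side is expected to drive the tree.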
+ * \param tokens The prefix tokens for reference. + * \return The pair of matched prefix length and the array of matched sequences indices. + */ + virtual std::pair> MatchPrefix(IntTuple tokens) = 0; + + /*! + * \brief Get a sequence's length. + * \param seq_id The sequence ID for index. + * \return The sequence length. + * \throw Error if sequence ID is not valid. + */ + virtual size_t GetSequenceLength(int64_t seq_id) = 0; + + /*! + * \brief Fork a sequence from parent sequence at given position. + * \param seq_id The new sequence ID. + * \param parent_seq_id The parent sequence ID to fork from. + * \param forked_offset The position of parent sequence to fork at. + * The valid value is [1, length of forked sequence]. If the position equals the length of forked + * sequence, the new sequence will copy the entire forked sequence. + * \throw Error if sequence ID or + * forked postion is not valid. + */ + virtual void ForkSequence(int64_t seq_id, int64_t parent_seq_id, size_t forked_offset) = 0; + + /*! + * \brief Add an empty sequence at root. + * \param seq_id The new sequence ID. + * \throw Error if sequence ID is not valid. + */ + virtual void AddSequence(int64_t seq_id) = 0; + + /*! + * \brief Extend a sequence with given tokens. + * \param seq_id The sequence ID for index. + * \param tokens The given tokens to extend. + * \throw Error if sequence ID is not valid. + */ + virtual void ExtendSequence(int64_t seq_id, IntTuple tokens) = 0; + + /*! + * \brief Remove a sequence. + * \param seq_id The sequence ID to remove. + * \throw Error if sequence ID is not valid. + */ + virtual void RemoveSequence(int64_t seq_id) = 0; + + /*! + * \brief Get the remaining token capacity of the paged radix tree. + * \return The the remaining token capacity of the paged radix tree. + */ + virtual size_t FreeCapacity() = 0; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "mlc.serve.PagedRadixTree"; + TVM_DECLARE_BASE_OBJECT_INFO(PagedRadixTreeObj, Object) +}; + +TVM_REGISTER_OBJECT_TYPE(PagedRadixTreeObj); + +class PagedRadixTree : public ObjectRef { + public: + /*! + * \brief Constructor of paged radix tree. + * \param num_pages The number of radix tree pages. + * \param page_size The page size of each radix tree page. + * \param num_seqs The maximum number of sequence ID. + */ + PagedRadixTree(size_t num_pages, size_t page_size, size_t num_seqs); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PagedRadixTree, ObjectRef, PagedRadixTreeObj); +}; +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_RADIX_TREE_H_ diff --git a/cpp/serve/sampler/cpu_sampler.cc b/cpp/serve/sampler/cpu_sampler.cc index 02b7e2a81d..98080c979d 100644 --- a/cpp/serve/sampler/cpu_sampler.cc +++ b/cpp/serve/sampler/cpu_sampler.cc @@ -8,6 +8,7 @@ #include #include +#include #include #include "../../random.h" @@ -43,12 +44,7 @@ TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_o ICHECK(prob.IsContiguous()); ICHECK(prob.DataType() == DataType::Float(32)); - - if (prob->device.device_type != kDLCPU) { - prob = prob.CopyTo(DLDevice{kDLCPU, 0}); - } - - ICHECK(prob->device.device_type == kDLCPU); + ICHECK_EQ(prob->device.device_type, DLDeviceType::kDLCPU); int64_t ndata = prob->shape[prob->ndim - 1]; const float* __restrict p_prob = @@ -186,6 +182,98 @@ TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_o return {sampled_index.second, sampled_index.first}; } +/*! 
+ * \brief Renormalize the probability distribution by the top p value. + * \param prob The input batch of probability distributions. + * \param unit_offset The offset specifying which distribution to output + * \param top_p The top p value for renormalization. + * \param eps A small epsilon value for comparison stability. + */ +void RenormalizeProbByTopP(NDArray prob, int unit_offset, double top_p, double eps) { + // prob: (*, v) + // The prob array may have arbitrary ndim and shape. + // The last dimension corresponds to the prob distribution size. + // We use the `unit_offset` parameter to determine which slice + // of the prob array we will renormalize. + ICHECK(prob.IsContiguous()); + ICHECK(prob.DataType() == DataType::Float(32)); + ICHECK_EQ(prob->device.device_type, DLDeviceType::kDLCPU); + + int vocab_size = prob->shape[prob->ndim - 1]; + float* __restrict p_prob = + static_cast(__builtin_assume_aligned(prob->data, 4)) + (unit_offset * vocab_size); + + // We manually choice the cutoff values of "top_p / 256" and "top_p / 8192". + // In most of the cases, only one round is needed. + std::vector cutoff_values{top_p / 256, top_p / 8192, 0.0f}; + + // Create the upper partition vector and the lower partition rolling vectors. + std::vector upper_partition; + std::vector lower_partitions[2]; + upper_partition.reserve(vocab_size); + lower_partitions[0].reserve(vocab_size); + lower_partitions[1].reserve(vocab_size); + float upper_partition_sum = 0.0; + for (int round = 0; round < static_cast(cutoff_values.size()); ++round) { + const float* lower_partition_begin; + const float* lower_partition_end; + if (round == 0) { + lower_partition_begin = p_prob; + lower_partition_end = p_prob + vocab_size; + } else { + int idx = (round - 1) & 1; + lower_partition_begin = lower_partitions[idx].data(); + lower_partition_end = lower_partitions[idx].data() + lower_partitions[idx].size(); + } + + // - Partition the last round lower partition into upper and lower + // based on the new cutoff value. + std::vector& lower_partition = lower_partitions[round & 1]; + lower_partition.clear(); + for (const float* ptr = lower_partition_begin; ptr != lower_partition_end; ++ptr) { + if (*ptr >= cutoff_values[round]) { + upper_partition.push_back(*ptr); + upper_partition_sum += *ptr; + } else { + lower_partition.push_back(*ptr); + } + } + // - If the upper partition sum is at least top p, exit the loop. + if (upper_partition_sum >= top_p - eps) { + break; + } + } + + // - Sort the upper partition in descending order. + std::sort(upper_partition.begin(), upper_partition.end(), std::greater<>()); + // - Find the top p boundary prob value. + float boundary_value = -1.0; + upper_partition_sum = 0.0; + for (float upper_value : upper_partition) { + upper_partition_sum += upper_value; + if (upper_partition_sum >= top_p - eps) { + boundary_value = upper_value; + break; + } + } + // - Mask all values smaller than the boundary to 0. + float renormalize_sum = 0.0; + std::vector upper_partition_indices; + upper_partition_indices.reserve(vocab_size); + for (int i = 0; i < vocab_size; ++i) { + if (p_prob[i] >= boundary_value) { + upper_partition_indices.push_back(i); + renormalize_sum += p_prob[i]; + } else { + p_prob[i] = 0.0; + } + } + // - Renormalize. + for (int idx : upper_partition_indices) { + p_prob[idx] /= renormalize_sum; + } +} + namespace detail { /*! \brief Implementation of getting top probs on CPU. 
*/ @@ -266,68 +354,87 @@ class CPUSampler : public SamplerObj { } } - std::vector BatchSampleTokens(NDArray probs_on_device, // - const std::vector& sample_indices, // - const Array& request_ids, // - const Array& generation_cfg, // - const std::vector& rngs, // - std::vector* output_prob_dist) final { + NDArray BatchRenormalizeProbsByTopP(NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg) final { // probs_on_device: (n, v) - RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); CHECK_EQ(probs_on_device->ndim, 2); // - Copy probs to CPU RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); - NDArray probs_host = CopyProbsToCPU(probs_on_device); + NDArray probs_on_host = CopyProbsToCPU(probs_on_device); RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); - - // - Sample tokens from probabilities. - int n = request_ids.size(); - ICHECK_EQ(generation_cfg.size(), n); - ICHECK_EQ(rngs.size(), n); - - std::vector sample_results; - sample_results.resize(n); - if (output_prob_dist) { - output_prob_dist->resize(n); + int num_samples = sample_indices.size(); + int num_probs = probs_on_device->shape[0]; + int vocab_size = probs_on_device->shape[1]; + ICHECK_EQ(request_ids.size(), num_samples); + ICHECK_EQ(generation_cfg.size(), num_samples); + + std::vector top_p_indices; + std::vector top_p_values; + for (int i = 0; i < num_samples; ++i) { + if (top_p_indices.empty() || top_p_indices.back() != sample_indices[i]) { + top_p_indices.push_back(sample_indices[i]); + top_p_values.push_back(generation_cfg[i]->top_p); + } else { + CHECK(fabs(top_p_values.back() - generation_cfg[i]->top_p) < eps_) + << "Sampler requires the top_p values for each prob distribution are the same."; + } + } + if (top_p_indices.empty()) { + // Return if no top p needs to apply. + return probs_on_host; } tvm::runtime::parallel_for_with_threading_backend( - [this, &sample_results, &probs_host, &generation_cfg, &rngs, &request_ids, sample_indices, - output_prob_dist](int i) { - RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); - // Sample top p from probability. - sample_results[i].sampled_token_id = SampleTopPFromProb( - probs_host, i, sample_indices[i], - generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, - rngs[i]->GetRandomNumber(), output_prob_dist); - if (output_prob_dist == nullptr) { - // When `output_prob_dist` is not nullptr, it means right now - // we are sampling for a small model in speculation, in which - // case we do not need to get the top probs. 
- sample_results[i].top_prob_tokens = - ComputeTopProbs(probs_host, i, generation_cfg[i]->top_logprobs); - } - RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish sample token"); + [this, &probs_on_host, &request_ids, &top_p_indices, &top_p_values](int i) { + RECORD_EVENT(this->trace_recorder_, request_ids[i], "start renormalize by top p"); + RenormalizeProbByTopP(probs_on_host, top_p_indices[i], top_p_values[i], eps_); + RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish renormalize by top p"); }, - 0, n); - RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); - return sample_results; + 0, static_cast(top_p_indices.size())); + + return probs_on_host; } - std::vector> BatchVerifyDraftTokens( - NDArray probs_on_device, const Array& request_ids, - const std::vector& cum_verify_lengths, const Array& generation_cfg, - const std::vector& rngs, - const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) final { + std::vector BatchSampleTokensWithProbBeforeTopP( + NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs) final { // probs_on_device: (n, v) - RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); CHECK_EQ(probs_on_device->ndim, 2); // - Copy probs to CPU RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); - NDArray probs_host = CopyProbsToCPU(probs_on_device); + NDArray probs_on_host = CopyProbsToCPU(probs_on_device); RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); + return BatchSampleTokensImpl(probs_on_host, sample_indices, request_ids, generation_cfg, rngs, + /*top_p_applied=*/false); + } + + std::vector BatchSampleTokensWithProbAfterTopP( + NDArray probs_on_host, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + std::vector* output_prob_dist) final { + return BatchSampleTokensImpl(probs_on_host, sample_indices, request_ids, generation_cfg, rngs, + /*top_p_applied=*/true, output_prob_dist); + } + + std::vector> BatchVerifyDraftTokensWithProbAfterTopP( + NDArray probs_on_host, const Array& request_ids, + const std::vector& cum_verify_lengths, const Array& generation_cfg, + const std::vector& rngs, + const std::vector>& draft_output_tokens, + const std::vector>& draft_output_prob_dist) final { + // probs_on_host: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); + CHECK_EQ(probs_on_host->ndim, 2); + int num_sequence = static_cast(cum_verify_lengths.size()) - 1; CHECK_EQ(rngs.size(), num_sequence); CHECK_EQ(draft_output_tokens.size(), num_sequence); @@ -337,8 +444,8 @@ class CPUSampler : public SamplerObj { sample_results.resize(num_sequence); float* __restrict global_p_probs = - static_cast(__builtin_assume_aligned(probs_host->data, 4)); - int vocab_size = probs_host->shape[1]; + static_cast(__builtin_assume_aligned(probs_on_host->data, 4)); + int vocab_size = probs_on_host->shape[1]; tvm::runtime::parallel_for_with_threading_backend( [&](int i) { @@ -355,7 +462,7 @@ class CPUSampler : public SamplerObj { if (p_value >= q_value) { sample_results[i].push_back( SampleResult{{cur_token, p_value}, - ComputeTopProbs(probs_host, verify_start + cur_token_idx, + ComputeTopProbs(probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs)}); continue; } @@ -363,7 +470,7 @@ class CPUSampler : public SamplerObj { if (r < p_value / (q_value + 
eps_)) { sample_results[i].push_back( SampleResult{{cur_token, p_value}, - ComputeTopProbs(probs_host, verify_start + cur_token_idx, + ComputeTopProbs(probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs)}); continue; } @@ -388,11 +495,10 @@ class CPUSampler : public SamplerObj { // sample a new token from the new distribution SampleResult sample_result; sample_result.sampled_token_id = SampleTopPFromProb( - probs_host, verify_start + cur_token_idx, verify_start + cur_token_idx, - generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, - rngs[i]->GetRandomNumber()); + probs_on_host, verify_start + cur_token_idx, verify_start + cur_token_idx, + /*top_p=*/1.0f, rngs[i]->GetRandomNumber()); sample_result.top_prob_tokens = ComputeTopProbs( - probs_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); + probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); sample_results[i].push_back(sample_result); break; } @@ -403,11 +509,10 @@ class CPUSampler : public SamplerObj { SampleResult sample_result; // sample a new token from the original distribution sample_result.sampled_token_id = SampleTopPFromProb( - probs_host, verify_start + cur_token_idx, verify_start + cur_token_idx, - generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, - rngs[i]->GetRandomNumber()); + probs_on_host, verify_start + cur_token_idx, verify_start + cur_token_idx, + /*top_p=*/1.0f, rngs[i]->GetRandomNumber()); sample_result.top_prob_tokens = ComputeTopProbs( - probs_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); + probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); sample_results[i].push_back(sample_result); } }, @@ -417,6 +522,56 @@ class CPUSampler : public SamplerObj { } private: + std::vector BatchSampleTokensImpl( + NDArray probs_on_host, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + bool top_p_applied, // + std::vector* output_prob_dist = nullptr) { + // probs_on_host: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); + ICHECK_EQ(probs_on_host->ndim, 2); + ICHECK_EQ(probs_on_host->device.device_type, DLDeviceType::kDLCPU); + + // - Sample tokens from probabilities. + int n = request_ids.size(); + ICHECK_EQ(generation_cfg.size(), n); + ICHECK_EQ(rngs.size(), n); + + std::vector sample_results; + sample_results.resize(n); + if (output_prob_dist) { + output_prob_dist->resize(n); + } + + tvm::runtime::parallel_for_with_threading_backend( + [this, &sample_results, &probs_on_host, &generation_cfg, &rngs, &request_ids, top_p_applied, + sample_indices, output_prob_dist](int i) { + RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); + // Sample top p from probability. + double top_p = + top_p_applied + ? 1.0f + : (generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p); + sample_results[i].sampled_token_id = + SampleTopPFromProb(probs_on_host, i, sample_indices[i], top_p, + rngs[i]->GetRandomNumber(), output_prob_dist); + if (output_prob_dist == nullptr) { + // When `output_prob_dist` is not nullptr, it means right now + // we are sampling for a small model in speculation, in which + // case we do not need to get the top probs. 
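
Editor's note: the verification loop above applies the standard rejection-sampling rule for speculative decoding. The following standalone sketch states that rule for a single draft token; the helper name and the plain `std::vector` distributions are assumptions, not the serving code's types.

```
#include <algorithm>
#include <random>
#include <vector>

// Accepts or rejects one draft token given the target-model distribution `p`
// and the draft-model distribution `q`. Returns true when the token is kept.
// On rejection, `residual` receives the normalized distribution max(p - q, 0)
// from which the verifier resamples the replacement token.
bool VerifyDraftToken(const std::vector<float>& p, const std::vector<float>& q,
                      int draft_token, std::mt19937& rng,
                      std::vector<float>* residual) {
  std::uniform_real_distribution<float> unif(0.0f, 1.0f);
  float p_value = p[draft_token];
  float q_value = q[draft_token];
  if (p_value >= q_value || unif(rng) < p_value / q_value) {
    return true;  // draft token accepted
  }
  // Rejected: build max(p - q, 0) and renormalize it. The residual always has
  // positive mass here, because rejection implies p and q differ.
  residual->resize(p.size());
  float sum = 0.0f;
  for (size_t i = 0; i < p.size(); ++i) {
    (*residual)[i] = std::max(p[i] - q[i], 0.0f);
    sum += (*residual)[i];
  }
  for (float& v : *residual) v /= sum;
  return false;
}
```

In the CPU sampler above, the comparison additionally adds `eps_` to the denominator for numerical safety, and accepted tokens carry their top-logprob information when the request asks for it.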
+ sample_results[i].top_prob_tokens = + ComputeTopProbs(probs_on_host, i, generation_cfg[i]->top_logprobs); + } + RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish sample token"); + }, + 0, n); + RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); + return sample_results; + } + /*! \brief Copy prob distributions from device to CPU. */ NDArray CopyProbsToCPU(NDArray probs_on_device) { // probs_on_device: (n, v) diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index b376523dac..58a27c24f7 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -43,12 +43,17 @@ class GPUSampler : public SamplerObj { gpu_argsort_probs_func_(ft->gpu_argsort_probs_func_), gpu_sample_with_top_p_func_(ft->gpu_sample_with_top_p_func_), gpu_sampler_take_probs_func_(ft->gpu_sampler_take_probs_func_), + gpu_verify_draft_tokens_func_(ft->gpu_verify_draft_tokens_func_), + gpu_renormalize_by_top_p_func_(ft->gpu_renormalize_by_top_p_func_), trace_recorder_(std::move(trace_recorder)) { ICHECK(gpu_multinomial_from_uniform_func_.defined()); ICHECK(gpu_argsort_probs_func_.defined()); ICHECK(gpu_sample_with_top_p_func_.defined()); ICHECK(gpu_sampler_take_probs_func_.defined()); + flashinfer_multinomial_sample_func_ = + Registry::Get("flashinfer.sampling.parallel_sampling_from_prob"); + DLDevice device_cpu{DLDeviceType::kDLCPU, /*device_id=*/0}; // We support at most 5 top prob results for each sequence. // Initialize auxiliary arrays on CPU. @@ -56,6 +61,10 @@ class GPUSampler : public SamplerObj { sample_indices_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); top_p_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu); top_prob_offsets_host_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device_cpu); + draft_tokens_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); + token_tree_first_child_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); + token_tree_next_sibling_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); + token_tree_parent_ptr_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); sampled_token_ids_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); sampled_probs_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu); top_prob_probs_host_ = NDArray::Empty({max_num_sample * 5}, dtype_f32_, device_cpu); @@ -65,6 +74,12 @@ class GPUSampler : public SamplerObj { sample_indices_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); top_p_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device); top_prob_offsets_device_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device); + draft_probs_device_ = NDArray::Empty({max_num_sample, vocab_size}, dtype_f32_, device); + draft_tokens_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + token_tree_first_child_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + token_tree_next_sibling_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + token_tree_parent_ptr_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + sampled_token_ids_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); // If the device is CUDA/ROCm, we create a standalone copy stream, in // purpose to hide the latency of auxiliary stream copy. 
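
Editor's note: the GPU sampler constructor above looks up an optional FlashInfer sampling kernel in the TVM global registry and keeps the module's own TIR kernel as a fallback. A minimal sketch of that lookup-and-fallback pattern follows; the registry name is taken from the patch, while the function and variable names here are illustrative.

```
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

using tvm::runtime::NDArray;
using tvm::runtime::PackedFunc;
using tvm::runtime::Registry;

NDArray SampleTokens(const PackedFunc& gpu_multinomial_from_uniform_func,
                     NDArray probs, NDArray uniform_samples, NDArray sample_indices,
                     NDArray out_token_ids) {
  // Registry::Get returns nullptr when the global function is not registered,
  // e.g. when the runtime was built without FlashInfer support.
  const PackedFunc* flashinfer =
      Registry::Get("flashinfer.sampling.parallel_sampling_from_prob");
  if (flashinfer != nullptr) {
    // The FlashInfer kernel writes sampled token ids into a preallocated buffer.
    (*flashinfer)(probs, uniform_samples, sample_indices, out_token_ids);
    return out_token_ids;
  }
  // Fallback: the TIR kernel returns a fresh NDArray of sampled token ids.
  return gpu_multinomial_from_uniform_func(probs, uniform_samples, sample_indices);
}
```

This mirrors why `sampled_token_ids_device_` is preallocated in the constructor: the FlashInfer path needs an output buffer, while the fallback allocates its own result.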
@@ -83,20 +98,237 @@ class GPUSampler : public SamplerObj { } } - std::vector BatchSampleTokens(NDArray probs_on_device, // - const std::vector& sample_indices, // - const Array& request_ids, // - const Array& generation_cfg, // - const std::vector& rngs, // - std::vector* output_prob_dist) final { - NVTXScopedRange nvtx_scope("BatchSampleTokens"); + NDArray BatchRenormalizeProbsByTopP(NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg) final { + NVTXScopedRange nvtx_scope("BatchRenormalizeProbsByTopP"); + // probs_on_device: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start renormalization by top p"); + CHECK_EQ(probs_on_device->ndim, 2); + int num_samples = sample_indices.size(); + int num_probs = probs_on_device->shape[0]; + int vocab_size = probs_on_device->shape[1]; + ICHECK_LE(num_probs, max_num_sample_); + ICHECK_EQ(request_ids.size(), num_samples); + ICHECK_EQ(generation_cfg.size(), num_samples); + + // - Check if there is need for applying top p. + bool need_top_p = CheckTopP(generation_cfg, sample_indices, num_probs, num_samples, vocab_size); + if (!need_top_p) { + return probs_on_device; + } + + // - Argsort the probability. + Array argsort_results = gpu_argsort_probs_func_(probs_on_device); + ICHECK_EQ(argsort_results.size(), 2); + NDArray sorted_probs_on_device = argsort_results[0]; + NDArray sorted_indices_on_device = argsort_results[1]; + + // - Copy auxiliary array for top-p. + NDArray top_p_host = top_p_host_.CreateView({num_probs}, dtype_f32_); + NDArray top_p_device = top_p_device_.CreateView({num_probs}, dtype_f32_); + CopyArray(/*src=*/top_p_host, /*dst=*/top_p_device, copy_stream_); + SyncCopyStream(device_, compute_stream_, copy_stream_); + + // - Renormalize the prob with top p. 
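
Editor's note: both the CPU routine `RenormalizeProbByTopP` earlier in this patch and the GPU path here perform top-p (nucleus) renormalization. The semantics are easiest to see in a plain reference version; the sketch below is illustrative only. It sorts the whole distribution, whereas the CPU code avoids a full sort with the `top_p / 256` and `top_p / 8192` cutoff rounds and the GPU path relies on an argsort plus a dedicated kernel.

```
#include <algorithm>
#include <numeric>
#include <vector>

// Keep the smallest set of highest-probability entries whose mass reaches
// top_p, zero out the rest, and rescale the kept entries to sum to 1.
void RenormalizeByTopPReference(std::vector<float>* prob, double top_p) {
  // Order indices by probability, descending.
  std::vector<int> order(prob->size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](int a, int b) { return (*prob)[a] > (*prob)[b]; });

  // Find the boundary value: the smallest probability inside the nucleus.
  double cumulative = 0.0;
  float boundary = 0.0f;
  for (int idx : order) {
    cumulative += (*prob)[idx];
    boundary = (*prob)[idx];
    if (cumulative >= top_p) break;
  }

  // Mask everything below the boundary and renormalize what remains.
  double kept_sum = 0.0;
  for (float& v : *prob) {
    if (v < boundary) {
      v = 0.0f;
    } else {
      kept_sum += v;
    }
  }
  for (float& v : *prob) v = static_cast<float>(v / kept_sum);
}
```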
+ NDArray renormed_probs_on_device = + gpu_renormalize_by_top_p_func_(probs_on_device, sorted_probs_on_device, top_p_device); + + RECORD_EVENT(trace_recorder_, request_ids, "finish renormalization by top p"); + return renormed_probs_on_device; + } + + std::vector BatchSampleTokensWithProbBeforeTopP( + NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs) final { + NVTXScopedRange nvtx_scope("BatchSampleTokensWithProbBeforeTopP"); + return BatchSampleTokensImpl(std::move(probs_on_device), sample_indices, request_ids, + generation_cfg, rngs, /*top_p_applied=*/false); + } + + std::vector BatchSampleTokensWithProbAfterTopP( + NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + std::vector* output_prob_dist = nullptr) final { + NVTXScopedRange nvtx_scope("BatchSampleTokensWithProbAfterTopP"); + return BatchSampleTokensImpl(std::move(probs_on_device), sample_indices, request_ids, + generation_cfg, rngs, /*top_p_applied=*/true, output_prob_dist); + } + + std::vector> BatchVerifyDraftTokensWithProbAfterTopP( + NDArray probs_on_device, const Array& request_ids, + const std::vector& cum_verify_lengths, const Array& generation_cfg, + const std::vector& rngs, + const std::vector>& draft_output_tokens, + const std::vector>& draft_output_prob_dist) final { + NVTXScopedRange nvtx_scope("BatchVerifyDraftTokensWithProbAfterTopP"); + std::vector> sample_results; + // probs_on_device: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); + CHECK_EQ(probs_on_device->ndim, 2); + + int num_sequence = static_cast(cum_verify_lengths.size()) - 1; + CHECK_EQ(rngs.size(), num_sequence); + CHECK_EQ(draft_output_tokens.size(), num_sequence); + CHECK_EQ(draft_output_prob_dist.size(), num_sequence); + sample_results.resize(num_sequence); + + int num_nodes = cum_verify_lengths.back(); + NDArray uniform_samples_host = uniform_samples_host_.CreateView({num_nodes}, dtype_f32_); + NDArray uniform_samples_device = uniform_samples_device_.CreateView({num_nodes}, dtype_f32_); + NDArray draft_probs_device = + draft_probs_device_.CreateView({num_nodes, vocab_size_}, dtype_f32_); + NDArray draft_tokens_host = draft_tokens_host_.CreateView({num_nodes}, dtype_i32_); + NDArray draft_tokens_device = draft_tokens_device_.CreateView({num_nodes}, dtype_i32_); + + // Concat draft prob distributions to a ragged tensor (num_nodes, vocab_size) + for (int i = 0; i < num_sequence; i++) { + const std::vector& draft_output_tokens_i = draft_output_tokens[i]; + const std::vector& draft_output_prob_dist_i = draft_output_prob_dist[i]; + int start = cum_verify_lengths[i]; + int end = cum_verify_lengths[i + 1]; + // start/end is the range of the sequence i in probs_on_device, which includes the prob dist + // of the draft tokens and the last committed token + ICHECK_EQ(draft_output_tokens_i.size() + 1, end - start); + ICHECK_EQ(draft_output_prob_dist_i.size() + 1, end - start); + for (int j = 0; j < end - start - 1; j++) { + // Copy prob dist + ICHECK_EQ(draft_probs_device->dtype.bits, 32); + float* p_draft_probs = + static_cast(draft_probs_device->data) + + (j + start + 1) * + vocab_size_; // shift by one, q of the last committed token is undefined + // Copy sampled token id + draft_output_prob_dist_i[j].CopyToBytes(p_draft_probs, vocab_size_ * sizeof(float)); + *(static_cast(draft_tokens_host->data) + j + 
start + 1) = + draft_output_tokens_i[j].sampled_token_id.first; + } + } + CopyArray(draft_tokens_host, draft_tokens_device, copy_stream_); + + float* p_uniform_samples = static_cast(uniform_samples_host->data); + for (int i = 0; i < num_sequence; ++i) { + int start = cum_verify_lengths[i]; + int end = cum_verify_lengths[i + 1]; + for (int j = start; j < end; j++) { + p_uniform_samples[j] = rngs[i]->GetRandomNumber(); + } + } + CopyArray(uniform_samples_host, uniform_samples_device, copy_stream_); + + NDArray token_tree_first_child_host = + token_tree_first_child_host_.CreateView({num_nodes}, dtype_i32_); + NDArray token_tree_first_child_device = + token_tree_first_child_device_.CreateView({num_nodes}, dtype_i32_); + NDArray token_tree_next_sibling_host = + token_tree_next_sibling_host_.CreateView({num_nodes}, dtype_i32_); + NDArray token_tree_next_sibling_device = + token_tree_next_sibling_device_.CreateView({num_nodes}, dtype_i32_); + NDArray token_tree_parent_ptr_host = + token_tree_parent_ptr_host_.CreateView({num_sequence}, dtype_i32_); + NDArray token_tree_parent_ptr_device = + token_tree_parent_ptr_device_.CreateView({num_sequence}, dtype_i32_); + std::vector token_tree_child_to_parent(/*n=*/num_nodes); + + // Build the tree structure on CPU + for (int i = 0; i < num_sequence; i++) { + // Assuming no tree structure for now + int start = cum_verify_lengths[i]; + int end = cum_verify_lengths[i + 1]; + ICHECK_GE(end - start, 2); + token_tree_child_to_parent[start] = -1; // root has no parent + for (int j = 0; j < end - start; j++) { + int cur_node = j + start; + int child_node = j + 1 >= end - start ? -1 : cur_node + 1; + static_cast(token_tree_first_child_host->data)[cur_node] = child_node; + if (child_node != -1) { + token_tree_child_to_parent[child_node] = cur_node; + } + static_cast(token_tree_next_sibling_host->data)[cur_node] = -1; + } + static_cast(token_tree_parent_ptr_host->data)[i] = start; // point to the root + } + // Copy token tree structure to GPU + CopyArray(token_tree_first_child_host, token_tree_first_child_device, copy_stream_); + CopyArray(token_tree_next_sibling_host, token_tree_next_sibling_device, copy_stream_); + CopyArray(token_tree_parent_ptr_host, token_tree_parent_ptr_device, copy_stream_); + + SyncCopyStream(device_, compute_stream_, copy_stream_); + + gpu_verify_draft_tokens_func_(draft_probs_device, draft_tokens_device, probs_on_device, + token_tree_first_child_device, token_tree_next_sibling_device, + uniform_samples_device, token_tree_parent_ptr_device); + + CopyArray(token_tree_parent_ptr_device, token_tree_parent_ptr_host, compute_stream_); + TVMSynchronize(device_.device_type, device_.device_id, compute_stream_); + + std::vector sample_indices; + + for (int i = 0; i < num_sequence; i++) { + int start = cum_verify_lengths[i]; + int end = cum_verify_lengths[i + 1]; + int last_accepted = static_cast(token_tree_parent_ptr_host->data)[i]; + int num_accepted = 0; + for (int cur_node = last_accepted; cur_node != start; + cur_node = token_tree_child_to_parent[cur_node]) { + sample_results[i].push_back(draft_output_tokens[i][cur_node - start - 1]); + num_accepted++; + } + std::reverse(sample_results[i].rbegin(), sample_results[i].rbegin() + num_accepted); + sample_indices.push_back(last_accepted); + } + std::vector additional_sample_result; + additional_sample_result = this->BatchSampleTokensWithProbAfterTopP( + probs_on_device, sample_indices, request_ids, generation_cfg, rngs); + ICHECK_EQ(additional_sample_result.size(), num_sequence); + for (int i = 0; i < 
num_sequence; i++) { + sample_results[i].push_back(additional_sample_result[i]); + } + + RECORD_EVENT(trace_recorder_, request_ids, "finish draft verification"); + return sample_results; + } + + private: + std::vector BatchSampleTokensImpl( + NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + bool top_p_applied, // + std::vector* output_prob_dist = nullptr) { // probs_on_device: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); - CHECK(output_prob_dist == nullptr) << "GPU sampler does not support collecting output probs."; CHECK_EQ(probs_on_device->ndim, 2); + CHECK_EQ(probs_on_device->device.device_id, device_.device_id); + CHECK_EQ(probs_on_device->device.device_type, device_.device_type); int num_samples = sample_indices.size(); int num_probs = probs_on_device->shape[0]; int vocab_size = probs_on_device->shape[1]; + if (output_prob_dist != nullptr) { + ICHECK(output_prob_dist->empty()); + output_prob_dist->reserve(num_samples); + for (int i = 0; i < num_samples; ++i) { + NDArray prob_dist = NDArray::Empty({vocab_size}, dtype_f32_, device_); + float* p_prob = static_cast(probs_on_device->data) + sample_indices[i] * vocab_size; + prob_dist.CopyFromBytes(p_prob, vocab_size * sizeof(float)); + output_prob_dist->push_back(std::move(prob_dist)); + } + } + if (num_samples == 0) { + // This synchronization is necessary for making sure that this round + // of model forward is finished. + TVMSynchronize(device_.device_type, device_.device_id, compute_stream_); + return {}; + } ICHECK_EQ(request_ids.size(), num_samples); ICHECK_EQ(generation_cfg.size(), num_samples); ICHECK_EQ(rngs.size(), num_samples); @@ -105,7 +337,8 @@ class GPUSampler : public SamplerObj { // we apply chunking to support large `num_samples`. 
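
Editor's note: the chunking mentioned in the comment above splits an oversized batch into slices no larger than the sampler's preallocated capacity and concatenates the per-chunk results. A generic sketch of that control flow, with assumed names, is shown here.

```
#include <algorithm>
#include <vector>

// `sample_chunk(begin, end)` stands in for ChunkSampleTokensImpl applied to
// the half-open sample range [begin, end).
template <typename Result, typename ChunkFn>
std::vector<Result> SampleInChunks(int num_samples, int max_num_sample,
                                   ChunkFn sample_chunk) {
  std::vector<Result> results;
  results.reserve(num_samples);
  for (int begin = 0; begin < num_samples; begin += max_num_sample) {
    int end = std::min(begin + max_num_sample, num_samples);
    std::vector<Result> chunk = sample_chunk(begin, end);
    results.insert(results.end(), chunk.begin(), chunk.end());
  }
  return results;
}
```

A call could look like `SampleInChunks<SampleResult>(num_samples, max_num_sample_, ...)` with a lambda that slices `sample_indices`, `generation_cfg`, and `rngs` for each chunk, which is what the loop in the serving code does inline.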
std::vector sample_results; if (num_samples <= max_num_sample_) { - sample_results = ChunkSampleTokensImpl(probs_on_device, sample_indices, generation_cfg, rngs); + sample_results = ChunkSampleTokensImpl(probs_on_device, sample_indices, generation_cfg, rngs, + top_p_applied); } else { for (int chunk_start = 0; chunk_start < num_samples; chunk_start += max_num_sample_) { int chunk_end = std::min(chunk_start + max_num_sample_, num_samples); @@ -116,7 +349,7 @@ class GPUSampler : public SamplerObj { std::vector rngs_chunk(rngs.begin() + chunk_start, rngs.begin() + chunk_end); std::vector sample_results_chunk = ChunkSampleTokensImpl( - probs_on_device, sample_indices_chunk, generation_cfg_chunk, rngs_chunk); + probs_on_device, sample_indices_chunk, generation_cfg_chunk, rngs_chunk, top_p_applied); sample_results.insert(sample_results.end(), sample_results_chunk.begin(), sample_results_chunk.end()); } @@ -126,20 +359,11 @@ class GPUSampler : public SamplerObj { return sample_results; } - std::vector> BatchVerifyDraftTokens( - NDArray probs_on_device, const Array& request_ids, - const std::vector& cum_verify_lengths, const Array& generation_cfg, - const std::vector& rngs, - const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) final { - LOG(FATAL) << "GPU sampler does not support batch verification for now."; - } - - private: std::vector ChunkSampleTokensImpl(NDArray probs_on_device, // const std::vector& sample_indices, // const Array& generation_cfg, // - const std::vector& rngs) { + const std::vector& rngs, // + bool top_p_applied) { // probs_on_device: (n, v) int num_samples = sample_indices.size(); int num_probs = probs_on_device->shape[0]; @@ -153,11 +377,13 @@ class GPUSampler : public SamplerObj { // - Check if there is need for applying top p or prob values, // so that argsort is needed. bool need_top_p = false; - bool need_prob_values = false; + if (!top_p_applied) { + need_top_p = CheckTopP(generation_cfg, sample_indices, num_probs, num_samples, vocab_size); + } // The indptr array of the number of top probs for each sample. std::vector top_prob_offset_indptr; - CheckTopPAndProbValues(generation_cfg, sample_indices, num_probs, num_samples, vocab_size, - &need_top_p, &need_prob_values, &top_prob_offset_indptr); + bool need_prob_values = CheckProbValues(generation_cfg, sample_indices, num_probs, num_samples, + vocab_size, &top_prob_offset_indptr); // - Sample tokens on GPU, and take out the probability values if needed. std::vector device_arrays = @@ -217,30 +443,39 @@ class GPUSampler : public SamplerObj { return {uniform_samples_device, sample_indices_device}; } - /*! \brief Check if top p and prob values are needed, and collect info when necessary. */ - void CheckTopPAndProbValues(const Array& generation_cfg, - const std::vector& sample_indices, int num_probs, - int num_samples, int vocab_size, bool* need_top_p, - bool* need_prob_values, std::vector* top_prob_offset_indptr) { - top_prob_offset_indptr->reserve(num_samples + 1); - top_prob_offset_indptr->push_back(0); + /*! \brief Check if top p is needed. Update host top p array in place. */ + bool CheckTopP(const Array& generation_cfg, + const std::vector& sample_indices, int num_probs, int num_samples, + int vocab_size) { // Initialize top p values with -1. 
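// -1 marks a probability row whose top_p is still unset; each row is assigned at most once,
// and later samples that map to the same row must supply a matching top_p value.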
float* p_top_p = static_cast(top_p_host_->data); for (int i = 0; i < num_probs; ++i) { p_top_p[i] = -1.0; } - int* p_top_prob_offsets = static_cast(top_prob_offsets_host_->data); - int num_top_probs = 0; + bool need_top_p = false; for (int i = 0; i < num_samples; ++i) { if (p_top_p[sample_indices[i]] == -1.0) { p_top_p[sample_indices[i]] = generation_cfg[i]->top_p; - *need_top_p |= generation_cfg[i]->top_p != 1.0; + need_top_p |= generation_cfg[i]->top_p != 1.0; } else { CHECK(fabs(p_top_p[sample_indices[i]] - generation_cfg[i]->top_p) < eps_) << "GPU sampler requires the top_p values for each prob distribution are the same."; } + } + return need_top_p; + } - *need_prob_values |= generation_cfg[i]->logprobs; + /*! \brief Check whether prob values are needed, and collect info when necessary. */ + bool CheckProbValues(const Array& generation_cfg, + const std::vector& sample_indices, int num_probs, int num_samples, + int vocab_size, std::vector* top_prob_offset_indptr) { + top_prob_offset_indptr->reserve(num_samples + 1); + top_prob_offset_indptr->push_back(0); + int* p_top_prob_offsets = static_cast(top_prob_offsets_host_->data); + int num_top_probs = 0; + bool need_prob_values = false; + for (int i = 0; i < num_samples; ++i) { + need_prob_values |= generation_cfg[i]->logprobs; for (int j = 0; j < generation_cfg[i]->top_logprobs; ++j) { p_top_prob_offsets[num_top_probs++] = sample_indices[i] * vocab_size + j; } @@ -248,6 +483,7 @@ class GPUSampler : public SamplerObj { generation_cfg[i]->top_logprobs); } ICHECK_EQ(num_top_probs, top_prob_offset_indptr->back()); + return need_prob_values; } /*! \brief Sample tokens on GPU. Take out the probability values when needed. */ @@ -263,8 +499,15 @@ class GPUSampler : public SamplerObj { if (!need_top_p && !need_prob_values) { // - Short path: If top_p and prob values are not needed, we directly sample from multinomial. SyncCopyStream(device_, compute_stream_, copy_stream_); - sampled_token_ids_device = gpu_multinomial_from_uniform_func_( - probs_on_device, uniform_samples_device, sample_indices_device); + if (flashinfer_multinomial_sample_func_ != nullptr) { + sampled_token_ids_device = + sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_); + (*flashinfer_multinomial_sample_func_)(probs_on_device, uniform_samples_device, + sample_indices_device, sampled_token_ids_device); + } else { + sampled_token_ids_device = gpu_multinomial_from_uniform_func_( + probs_on_device, uniform_samples_device, sample_indices_device); + } return {sampled_token_ids_device, sampled_probs_device, top_prob_probs_device, top_prob_indices_device}; } @@ -299,8 +542,15 @@ class GPUSampler : public SamplerObj { uniform_samples_device, sample_indices_device, top_p_device); } else { // - Sample without top_p. - sampled_token_ids_device = gpu_multinomial_from_uniform_func_( - probs_on_device, uniform_samples_device, sample_indices_device); + if (flashinfer_multinomial_sample_func_ != nullptr) { + sampled_token_ids_device = + sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_); + (*flashinfer_multinomial_sample_func_)(probs_on_device, uniform_samples_device, + sample_indices_device, sampled_token_ids_device); + } else { + sampled_token_ids_device = gpu_multinomial_from_uniform_func_( + probs_on_device, uniform_samples_device, sample_indices_device); + } } if (need_prob_values) { @@ -354,7 +604,7 @@ class GPUSampler : public SamplerObj { } // Synchronize for CPU to get the correct array results. 
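// Waiting on compute_stream_ (instead of the default stream) guarantees that all work queued
// on that stream has completed before the host reads the result arrays below.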
- TVMSynchronize(device_.device_type, device_.device_id, nullptr); + TVMSynchronize(device_.device_type, device_.device_id, compute_stream_); return {sampled_token_ids_host, sampled_probs_host, top_prob_probs_host, top_prob_indices_host}; } @@ -370,11 +620,18 @@ class GPUSampler : public SamplerObj { PackedFunc gpu_argsort_probs_func_; PackedFunc gpu_sample_with_top_p_func_; PackedFunc gpu_sampler_take_probs_func_; + PackedFunc gpu_verify_draft_tokens_func_; + PackedFunc gpu_renormalize_by_top_p_func_; + const PackedFunc* flashinfer_multinomial_sample_func_; // Auxiliary NDArrays on CPU NDArray uniform_samples_host_; NDArray sample_indices_host_; NDArray top_p_host_; NDArray top_prob_offsets_host_; + NDArray draft_tokens_host_; + NDArray token_tree_first_child_host_; + NDArray token_tree_next_sibling_host_; + NDArray token_tree_parent_ptr_host_; NDArray sampled_token_ids_host_; NDArray sampled_probs_host_; NDArray top_prob_probs_host_; @@ -384,6 +641,12 @@ class GPUSampler : public SamplerObj { NDArray sample_indices_device_; NDArray top_p_device_; NDArray top_prob_offsets_device_; + NDArray draft_probs_device_; + NDArray draft_tokens_device_; + NDArray token_tree_first_child_device_; + NDArray token_tree_next_sibling_device_; + NDArray token_tree_parent_ptr_device_; + NDArray sampled_token_ids_device_; // The event trace recorder for requests. */ Optional trace_recorder_; // The device stream for the default computation operations. diff --git a/cpp/serve/sampler/sampler.h b/cpp/serve/sampler/sampler.h index 03d031bdb7..7943231e55 100644 --- a/cpp/serve/sampler/sampler.h +++ b/cpp/serve/sampler/sampler.h @@ -26,14 +26,33 @@ using namespace tvm::runtime; /*! * \brief The base class of runtime sampler. - * Its main function is `BatchSampleTokens`, which takes a batch of + * Its main function is `BatchSampleTokensWithProbBeforeTopP`, which takes a batch of * logits and corresponding configuration, and sample one token * for each instance of the batch. */ class SamplerObj : public Object { public: + /*! + * \brief Renormalize the input batch of probability distributions with top p values. + * \param probs_on_device The batch of prob distributions before normalization. + * \param sample_indices Specifying which request we will sample for + * in i-th output for the sampling later on. + * The output result of the sampling will be as follow: + * result[i] = sample_from(prob_on_device[sample_indices[i],:], generation_config[i])); + * For renormalization, the sample indices are used for determine the top-p grouping. + * \param request_ids The id of each request. + * \param generation_cfg The generation config of each request in the input batch. + * \return The renormalized probability distributions, residing on device + * if the sampler is GPU sampler, or on host if the sampler is CPU sampler. + */ + virtual NDArray BatchRenormalizeProbsByTopP(NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg) = 0; + /*! * \brief Sample tokens from the input batch of prob distribution on device. + * The input prob distributions are not yet applied with top-p. * \param probs_on_device The prob distributions on GPU to sample tokens from. * \param sample_indices Specifying which request we should sample for * in i-th output. The output result is sample as follow: @@ -42,22 +61,46 @@ class SamplerObj : public Object { * \param generation_cfg The generation config of each request * in the input batch. 
* \param rngs The random number generator of each sequence. - * \param output_prob_dist The output probability distribution * \return The batch of sampling results, which contain the sampled token id * and other probability info. */ - virtual std::vector BatchSampleTokens( + virtual std::vector BatchSampleTokensWithProbBeforeTopP( NDArray probs_on_device, // const std::vector& sample_indices, // const Array& request_ids, // const Array& generation_cfg, // + const std::vector& rngs) = 0; + + /*! + * \brief Sample tokens from the input batch of prob distribution on device. + * The input prob distributions are already applied with top-p. + * \param probs The prob distributions. + * It resides on GPU if the sampler is GPU sampler, or on host if hte sampler is CPU sampler. + * \param sample_indices Specifying which request we should sample for + * in i-th output. The output result is sample as follow: + * result[i] = sample_from(prob_on_device[sample_indices[i],:], generation_config[i])); + * \param request_ids The id of each request. + * \param generation_cfg The generation config of each request + * in the input batch. + * \param rngs The random number generator of each sequence. + * \param output_prob_dist The output probability distribution + * \return The batch of sampling results, which contain the sampled token id + * and other probability info. + */ + virtual std::vector BatchSampleTokensWithProbAfterTopP( + NDArray probs, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // const std::vector& rngs, // std::vector* output_prob_dist = nullptr) = 0; /*! * \brief Verify draft tokens generated by small models in the large model * in speculative decoding. The input corresponds to a batch of sequences. - * \param probs_on_device The prob distributions on GPU to sample tokens from. + * The input prob distributions are already applied with top-p. + * \param probs The prob distributions on GPU to sample tokens from. + * It resides on GPU if the sampler is GPU sampler, or on host if hte sampler is CPU sampler. * \param request_ids The id of each request. * \param cum_verify_lengths The cumulative draft lengths to verify of all sequences. * \param generation_cfg The generation config of each request @@ -69,10 +112,9 @@ class SamplerObj : public Object { * small model for each sequence. * \return The list of accepted tokens for each request. */ - virtual std::vector> BatchVerifyDraftTokens( - NDArray probs_on_device, const Array& request_ids, - const std::vector& cum_verify_lengths, const Array& generation_cfg, - const std::vector& rngs, + virtual std::vector> BatchVerifyDraftTokensWithProbAfterTopP( + NDArray probs, const Array& request_ids, const std::vector& cum_verify_lengths, + const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, const std::vector>& draft_output_prob_dist) = 0; diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index 458d2ae5d7..2f6f77a3a0 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -29,37 +29,59 @@ enum class InstructionKind : int { kAbortRequest = 1, kUnloadEngine = 2, kReloadEngine = 3, - kDebugCallFuncOnAllAllWorker = 4, + kResetEngine = 4, + kDebugCallFuncOnAllAllWorker = 5, }; /*! \brief The implementation of ThreadedEngine. 
*/ class ThreadedEngineImpl : public ThreadedEngine { public: - void InitBackgroundEngine(EngineConfig engine_config, - Optional request_stream_callback, + void InitBackgroundEngine(Device device, Optional request_stream_callback, Optional trace_recorder) final { + device_ = device; CHECK(request_stream_callback.defined()) << "ThreadedEngine requires request stream callback function, but it is not given."; request_stream_callback_ = request_stream_callback.value(); + trace_recorder_ = trace_recorder; + } - auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args.size(), 1); - Array delta_outputs = args[0]; - bool need_notify = false; - { - std::lock_guard lock(request_stream_callback_mutex_); - request_stream_callback_inputs_.push_back(std::move(delta_outputs)); - ++pending_request_stream_callback_cnt_; - need_notify = stream_callback_waiting_; - } - if (need_notify) { - request_stream_callback_cv_.notify_one(); - } - }; + void Reload(EngineConfig engine_config) final { + bool need_notify = false; + { + std::lock_guard lock(background_loop_mutex_); + instruction_queue_.emplace_back(InstructionKind::kReloadEngine, std::move(engine_config)); + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); + } + } + + void Unload() final { + bool need_notify = false; + { + std::lock_guard lock(background_loop_mutex_); + instruction_queue_.emplace_back(InstructionKind::kUnloadEngine, ObjectRef(nullptr)); + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); + } + } - request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - background_engine_ = Engine::Create( - std::move(engine_config), std::move(request_stream_callback), std::move(trace_recorder)); + void Reset() final { + bool need_notify = false; + { + std::lock_guard lock(background_loop_mutex_); + instruction_queue_.emplace_back(InstructionKind::kResetEngine, ObjectRef(nullptr)); + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); + } } void AddRequest(Request request) final { @@ -97,7 +119,8 @@ class ThreadedEngineImpl : public ThreadedEngine { std::unique_lock lock(background_loop_mutex_); engine_waiting_ = true; background_loop_cv_.wait(lock, [this] { - return !background_engine_->Empty() || pending_request_operation_cnt_.load() > 0 || + return (background_engine_ != nullptr && !background_engine_->Empty()) || + pending_request_operation_cnt_.load() > 0 || exit_now_.load(std::memory_order_relaxed); }); engine_waiting_ = false; @@ -108,22 +131,30 @@ class ThreadedEngineImpl : public ThreadedEngine { } for (const auto& [kind, arg] : local_instruction_queue) { if (kind == InstructionKind::kAddRequest) { + CHECK(background_engine_ != nullptr) << "Background engine is not loaded."; background_engine_->AddRequest(Downcast(arg)); } else if (kind == InstructionKind::kAbortRequest) { + CHECK(background_engine_ != nullptr) << "Background engine is not loaded."; background_engine_->AbortRequest(Downcast(arg)); } else if (kind == InstructionKind::kUnloadEngine) { - // Todo(mlc-team): implement engine unload - LOG(FATAL) << "Not implemented yet."; + EngineUnloadImpl(); } else if (kind == InstructionKind::kReloadEngine) { - // Todo(mlc-team): implement engine reload - LOG(FATAL) << "Not implemented yet."; + EngineUnloadImpl(); + EngineReloadImpl(Downcast(arg)); + } else if 
(kind == InstructionKind::kResetEngine) { + if (background_engine_ != nullptr) { + background_engine_->Reset(); + } } else if (kind == InstructionKind::kDebugCallFuncOnAllAllWorker) { + CHECK(background_engine_ != nullptr) << "Background engine is not loaded."; background_engine_->DebugCallFuncOnAllAllWorker(Downcast(arg)); } else { LOG(FATAL) << "Cannot reach here"; } } - background_engine_->Step(); + if (background_engine_ != nullptr) { + background_engine_->Step(); + } } } @@ -184,10 +215,47 @@ class ThreadedEngineImpl : public ThreadedEngine { } private: + void EngineReloadImpl(EngineConfig engine_config) { + auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.size(), 1); + Array delta_outputs = args[0]; + bool need_notify = false; + { + std::lock_guard lock(request_stream_callback_mutex_); + request_stream_callback_inputs_.push_back(std::move(delta_outputs)); + ++pending_request_stream_callback_cnt_; + need_notify = stream_callback_waiting_; + } + if (need_notify) { + request_stream_callback_cv_.notify_one(); + } + }; + + Optional request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); + background_engine_ = Engine::Create(std::move(engine_config), device_, + std::move(request_stream_callback), trace_recorder_); + } + + void EngineUnloadImpl() { + if (background_engine_ != nullptr) { + background_engine_->AbortAllRequests(); + background_engine_ = nullptr; + // Clear the allocated memory in cached memory pool. + const PackedFunc* fclear_memory_manager = + tvm::runtime::Registry::Get("vm.builtin.memory_manager.clear"); + ICHECK(fclear_memory_manager) << "Cannot find env function vm.builtin.memory_manager.clear"; + (*fclear_memory_manager)(); + } + } + + /*! \brief The device to run models on. */ + Device device_; /*! \brief The background normal engine for request processing. */ std::unique_ptr background_engine_; /*! \brief The request stream callback. */ PackedFunc request_stream_callback_; + /*! \brief Event trace recorder. */ + Optional trace_recorder_; /*! \brief The mutex ensuring only one thread can access critical regions. */ std::mutex background_loop_mutex_; @@ -237,6 +305,7 @@ class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { public: TVM_MODULE_VTABLE_BEGIN("mlc.serve.async_threaded_engine"); TVM_MODULE_VTABLE_ENTRY("init_background_engine", &ThreadedEngineImpl::InitBackgroundEngine); + TVM_MODULE_VTABLE_ENTRY("reload", &ThreadedEngineImpl::Reload); TVM_MODULE_VTABLE_ENTRY("add_request", &ThreadedEngineImpl::AddRequest); TVM_MODULE_VTABLE_ENTRY("abort_request", &ThreadedEngineImpl::AbortRequest); TVM_MODULE_VTABLE_ENTRY("run_background_loop", &ThreadedEngineImpl::RunBackgroundLoop); diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h index 3d11ba36f1..49ba8f2175 100644 --- a/cpp/serve/threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -35,14 +35,25 @@ class ThreadedEngine { /*! * \brief Initialize the threaded engine from packed arguments in TVMArgs. - * \param engine_config The engine config. + * \param device The device where to run models. * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. */ - virtual void InitBackgroundEngine(EngineConfig engine_config, - Optional request_stream_callback, + virtual void InitBackgroundEngine(Device device, Optional request_stream_callback, Optional trace_recorder) = 0; + /*! + * \brief Reload the engine with the new engine config. 
+ * \param engine_config The engine config. + */ + virtual void Reload(EngineConfig engine_config) = 0; + + /*! \brief Unload the background engine. */ + virtual void Unload() = 0; + + /*! \brief Reset the engine to the initial state. */ + virtual void Reset() = 0; + /*! \brief Starts the background request processing loop. */ virtual void RunBackgroundLoop() = 0; diff --git a/cpp/support/utils.h b/cpp/support/utils.h index 5360f0496c..6c53e35715 100644 --- a/cpp/support/utils.h +++ b/cpp/support/utils.h @@ -10,6 +10,7 @@ namespace mlc { namespace llm { +/*! \brief Split the input string by the given delimiter character. */ inline std::vector Split(const std::string& str, char delim) { std::string item; std::istringstream is(str); @@ -20,5 +21,21 @@ inline std::vector Split(const std::string& str, char delim) { return ret; } +/*! + * \brief Check whether the string starts with a given prefix. + * \param str The given string. + * \param prefix The given prefix. + * \return Whether the prefix matched. + */ +inline bool StartsWith(const std::string& str, const char* prefix) { + size_t n = str.length(); + for (size_t i = 0; i < n; i++) { + if (prefix[i] == '\0') return true; + if (str.data()[i] != prefix[i]) return false; + } + // return true if the str is equal to the prefix + return prefix[n] == '\0'; +} + } // namespace llm } // namespace mlc diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index 00beb5cc4d..4706e09811 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -235,7 +235,7 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con RuntimeError: Cannot find libraries: wasm_runtime.bc .. note:: - For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or lower ``context_window_size`` to decrease memory usage. + For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage. Otherwise, you may run into issues like: .. code:: text @@ -664,7 +664,7 @@ generalized to any model variant, as long as mlc-llm supports the architecture. RuntimeError: Cannot find libraries: wasm_runtime.bc .. note:: - For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or lower ``context_window_size`` to decrease memory usage. + For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage. Otherwise, you may run into issues like: .. code:: text @@ -793,7 +793,7 @@ generalized to any model variant, as long as mlc-llm supports the architecture. RuntimeError: Cannot find libraries: wasm_runtime.bc .. note:: - For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or lower ``context_window_size`` to decrease memory usage. + For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage. Otherwise, you may run into issues like: .. 
code:: text diff --git a/docs/deploy/cli.rst b/docs/deploy/cli.rst index f341e31e71..a7ebe28d6d 100644 --- a/docs/deploy/cli.rst +++ b/docs/deploy/cli.rst @@ -54,15 +54,15 @@ To run a model with MLC LLM in any platform, you can either: **Option 1: Use model prebuilts** To run ``mlc_llm``, you can specify the Huggingface MLC prebuilt model repo path with the prefix ``HF://``. -For example, to run the MLC Llama 2 7B Q4F16_1 model (`Repo link `_), -simply use ``HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC``. The model weights and library will be downloaded +For example, to run the MLC Llama 3 8B Q4F16_1 model (`Repo link `_), +simply use ``HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC``. The model weights and library will be downloaded automatically from Huggingface. .. code:: shell - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC --device "cuda:0" --overrides context_window_size=1024 + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --device "cuda:0" --overrides context_window_size=1024 -.. code:: shell +.. code:: You can use the following special commands: /help print the special commands @@ -74,13 +74,11 @@ automatically from Huggingface. Note: Separate stop words in the `stop` option with commas (,). Multi-line input: Use escape+enter to start a new line. - [INST]: What's the meaning of life - [/INST]: - Ah, a question that has puzzled philosophers and theologians for centuries! The meaning - of life is a deeply personal and subjective topic, and there are many different - perspectives on what it might be. However, here are some possible answers that have been - proposed by various thinkers and cultures: - ... + user: What's the meaning of life + assistant: + What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life. + + The concept of the meaning of life has been debated and... **Option 2: Use locally compiled model weights and libraries** diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst index c0217db9e9..75a5cdbdc7 100644 --- a/docs/deploy/ios.rst +++ b/docs/deploy/ios.rst @@ -341,10 +341,24 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con mlc_llm gen_config ./dist/models/phi-2/ \ --quantization q4f16_1 --conv-template phi-2 \ -o dist/phi-2-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json + # 2. mkdir: create a directory to store the compiled model library + mkdir -p dist/libs + # 3. compile: compile model library with specification in mlc-chat-config.json mlc_llm compile ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json \ --device iphone -o dist/libs/phi-2-q4f16_1-iphone.tar +Given the compiled library, it is possible to calculate an upper bound for the VRAM +usage during runtime. This useful to better understand if a model is able to fit particular +hardware. +That information will be displayed at the end of the console log when the ``compile`` is executed. +It might look something like this: + +.. code:: shell + + [2024-04-25 03:19:56] INFO model_metadata.py:96: Total memory usage: 1625.73 MB (Parameters: 1492.45 MB. KVCache: 0.00 MB. Temporary buffer: 133.28 MB) + [2024-04-25 03:19:56] INFO model_metadata.py:105: To reduce memory usage, tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size` + [2024-04-25 03:19:56] INFO compile.py:198: Generated: dist/libs/phi-2-q4f16_1-iphone.tar + .. 
note:: When compiling larger models like ``Llama-2-7B``, you may want to add a lower chunk size while prefilling prompts ``--prefill_chunk_size 128`` or even lower ``context_window_size``\ @@ -388,21 +402,7 @@ This would result in something like `phi-2-q4f16_1-MLC `_. -**Step 4. Calculate estimated VRAM usage** - -Given the compiled library, it is possible to calculate an upper bound for the VRAM -usage during runtime. This useful to better understand if a model is able to fit particular -hardware. We can calculate this estimate using the following command: - -.. code:: shell - - ~/mlc-llm > python -m mlc_llm.cli.model_metadata ./dist/libs/phi-2-q4f16_1-iphone.tar \ - > --memory-only --mlc-chat-config ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json - INFO model_metadata.py:90: Total memory usage: 3042.96 MB (Parameters: 1492.45 MB. KVCache: 640.00 MB. Temporary buffer: 910.51 MB) - INFO model_metadata.py:99: To reduce memory usage, tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size` - - -**Step 5. Register as a ModelRecord** +**Step 4. Register as a ModelRecord** Finally, we update the code snippet for `app-config.json `__ diff --git a/docs/deploy/python_engine.rst b/docs/deploy/python_engine.rst index c5d9a072a7..89c60ac422 100644 --- a/docs/deploy/python_engine.rst +++ b/docs/deploy/python_engine.rst @@ -4,12 +4,261 @@ Python API ========== .. note:: - This page introduces the Python API with LLMEngine in MLC LLM. - If you want to check out the old Python API which uses :class:`mlc_llm.ChatModule`, - please go to :ref:`deploy-python-chat-module` + This page introduces the Python API with MLCEngine in MLC LLM. + If you want to check out the old Python API which uses :class:`mlc_llm.ChatModule`, + please go to :ref:`deploy-python-chat-module` .. contents:: Table of Contents - :local: - :depth: 2 + :local: + :depth: 2 -🚧 Under construction... + +MLC LLM provides Python API through classes :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine` +which **support full OpenAI API completeness** for easy integration into other Python projects. + +This page introduces how to use the engines in MLC LLM. +The Python API is a part of the MLC-LLM package, for which we have prepared pre-built pip wheels; see +the :ref:`installation page `. + + +Verify Installation +------------------- + +.. code:: bash + + python -c "from mlc_llm import MLCEngine; print(MLCEngine)" + +You are expected to see the output of ````. + +If the command above results in an error, follow :ref:`install-mlc-packages` to install prebuilt pip +packages or build MLC LLM from source. + + +Run MLCEngine +------------- + +:class:`mlc_llm.MLCEngine` provides the interface of OpenAI chat completion synchronously. +:class:`mlc_llm.MLCEngine` does not batch concurrent requests due to its synchronous design; +please use :ref:`AsyncMLCEngine ` to batch requests. + +**Stream Response.** In :ref:`quick-start` and :ref:`introduction-to-mlc-llm`, +we introduced the basic use of :class:`mlc_llm.MLCEngine`. + +.. code:: python + + from mlc_llm import MLCEngine + + # Create engine + model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" + engine = MLCEngine(model) + + # Run chat completion in OpenAI API.
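+    # With `stream=True`, each iteration below yields a chunk whose choices carry incremental delta content.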
+ for response in engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, + ): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) + print("\n") + + engine.terminate() + +This code example first creates an :class:`mlc_llm.MLCEngine` instance with the 8B Llama-3 model. +**We design the Python API** :class:`mlc_llm.MLCEngine` **to align with OpenAI API**, +which means you can use :class:`mlc_llm.MLCEngine` in the same way as using +`OpenAI's Python package `_ +for both synchronous and asynchronous generation. + +**Non-stream Response.** The code example above uses the synchronous chat completion +interface and iterates over all the stream responses. +If you want to run without streaming, you can run + +.. code:: python + + response = engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=False, + ) + print(response) + +Please refer to `OpenAI's Python package `_ +and `OpenAI chat completion API `_ +for the complete chat completion interface. + + +.. _python-engine-async-llm-engine: + +Run AsyncMLCEngine +------------------ + +:class:`mlc_llm.AsyncMLCEngine` provides the interface of OpenAI chat completion with +asynchronous features. +**We recommend using** :class:`mlc_llm.AsyncMLCEngine` **to batch concurrent requests for better throughput.** + +**Stream Response.** The core use of :class:`mlc_llm.AsyncMLCEngine` for stream responses is as follows. + +.. code:: python + + async for response in await engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, + ): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) + +.. collapse:: A complete runnable example of AsyncMLCEngine in Python. + + .. code:: python + + import asyncio + from typing import Dict + + from mlc_llm.serve import AsyncMLCEngine + + model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" + prompts = [ + "Write a three-day travel plan to Pittsburgh.", + "What is the meaning of life?", + ] + + + async def test_completion(): + # Create engine + async_engine = AsyncMLCEngine(model=model) + + num_requests = len(prompts) + output_texts: Dict[str, str] = {} + + async def generate_task(prompt: str): + async for response in await async_engine.chat.completions.create( + messages=[{"role": "user", "content": prompt}], + model=model, + stream=True, + ): + if response.id not in output_texts: + output_texts[response.id] = "" + output_texts[response.id] += response.choices[0].delta.content + + tasks = [asyncio.create_task(generate_task(prompts[i])) for i in range(num_requests)] + await asyncio.gather(*tasks) + + # Print output. + for request_id, output in output_texts.items(): + print(f"Output of request {request_id}:\n{output}\n") + + async_engine.terminate() + + + asyncio.run(test_completion()) + +| + +**Non-stream Response.** Similarly, :class:`mlc_llm.AsyncMLCEngine` provides the non-stream response +interface. + +.. code:: python + + response = await engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=False, + ) + print(response) + +Please refer to `OpenAI's Python package `_ +and `OpenAI chat completion API `_ +for the complete chat completion interface.
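+
+For reference, a minimal end-to-end sketch of the non-stream path driven from a synchronous
+entry point could look like the following (reusing the model path and API calls shown above;
+the wrapper function name is illustrative only):
+
+.. code:: python
+
+   import asyncio
+
+   from mlc_llm.serve import AsyncMLCEngine
+
+   model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"
+
+   async def main():
+       # Create the async engine and issue a single non-stream request.
+       engine = AsyncMLCEngine(model=model)
+       response = await engine.chat.completions.create(
+           messages=[{"role": "user", "content": "What is the meaning of life?"}],
+           model=model,
+           stream=False,
+       )
+       print(response)
+       engine.terminate()
+
+   asyncio.run(main())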
+ + +Engine Mode +----------- + +To ease the engine configuration, the constructors of :class:`mlc_llm.MLCEngine` and +:class:`mlc_llm.AsyncMLCEngine` have an optional argument ``mode``, +which falls into one of the three options ``"local"``, ``"interactive"`` or ``"server"``. +The default mode is ``"local"``. + +Each mode denotes a pre-defined configuration of the engine to satisfy different use cases. +The choice of the mode controls the request concurrency of the engine, +as well as engine's KV cache token capacity (or in other words, the maximum +number of tokens that the engine's KV cache can hold), +and further affects the GPU memory usage of the engine. + +In short, + +- mode ``"local"`` uses low request concurrency and low KV cache capacity, which is suitable for cases where **concurrent requests are not too many, and the user wants to save GPU memory usage**. +- mode ``"interactive"`` uses 1 as the request concurrency and low KV cache capacity, which is designed for **interactive use cases** such as chats and conversations. +- mode ``"server"`` uses as much request concurrency and KV cache capacity as possible. This mode aims to **fully utilize the GPU memory for large server scenarios** where concurrent requests may be many. + +**For system benchmark, please select mode** ``"server"``. +Please refer to :ref:`python-engine-api-reference` for detailed documentation of the engine mode. + + +Deploy Your Own Model with Python API +------------------------------------- + +The :ref:`introduction page ` introduces how we can deploy our +own models with MLC LLM. +This section introduces how you can use the model weights you convert and the model library you build +in :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine`. + +We use the `Phi-2 `_ as the example model. + +**Specify Model Weight Path.** Assume you have converted the model weights for your own model, +you can construct a :class:`mlc_llm.MLCEngine` as follows: + +.. code:: python + + from mlc_llm import MLCEngine + + model = "models/phi-2" # Assuming the converted phi-2 model weights are under "models/phi-2" + engine = MLCEngine(model) + + +**Specify Model Library Path.** Further, if you build the model library on your own, +you can use it in :class:`mlc_llm.MLCEngine` by passing the library path through argument ``model_lib_path``. + +.. code:: python + + from mlc_llm import MLCEngine + + model = "models/phi-2" + model_lib_path = "models/phi-2/lib.so" # Assuming the phi-2 model library is built at "models/phi-2/lib.so" + engine = MLCEngine(model, model_lib_path=model_lib_path) + + +The same applies to :class:`mlc_llm.AsyncMLCEngine`. + + +.. _python-engine-api-reference: + +API Reference +------------- + +The :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine` classes provide the following constructors. + +The MLCEngine and AsyncMLCEngine have full OpenAI API completeness. +Please refer to `OpenAI's Python package `_ +and `OpenAI chat completion API `_ +for the complete chat completion interface. + +.. currentmodule:: mlc_llm + +.. autoclass:: MLCEngine + :members: + :exclude-members: evaluate + :undoc-members: + :show-inheritance: + + .. automethod:: __init__ + +.. autoclass:: AsyncMLCEngine + :members: + :exclude-members: evaluate + :undoc-members: + :show-inheritance: + + .. automethod:: __init__ diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index e59abc1257..07d39dbfad 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -1,10 +1,6 @@ .. 
_deploy-rest-api: -<<<<<<< HEAD -Rest API -======= REST API ->>>>>>> upstream/main ======== .. contents:: Table of Contents @@ -17,18 +13,6 @@ for a user to interact with MLC-LLM in their own programs. Install MLC-LLM Package ------------------------ -<<<<<<< HEAD -SERVE is a part of the MLC-Chat package, installation instruction for which we be found here :doc:`<../install/mlc_llm>`. - -Verify Installation -^^^^^^^^^^^^^^^^^^^ - -.. code:: bash - - python -m mlc_llm.serve.server --help - -You are expected to see the help information of the MLC SERVE. -======= SERVE is a part of the MLC-LLM package, installation instruction for which can be found :ref:`here `. Once you have install the MLC-LLM package, you can run the following command to check if the installation was successful: .. code:: bash @@ -36,13 +20,10 @@ SERVE is a part of the MLC-LLM package, installation instruction for which can b mlc_llm serve --help You should see serve help message if the installation was successful. ->>>>>>> upstream/main Quick start ------------ -<<<<<<< HEAD -======= This section provides a quick start guide to work with MLC-LLM REST API. To launch a server, run the following command: .. code:: bash @@ -77,24 +58,15 @@ Once you have launched the Server, you can use the API in your own program to se .. _rest_launch_server: ->>>>>>> upstream/main Launch the Server ----------------- -<<<<<<< HEAD -To launch the MLC Server for MLC-Chat, run the following command in your terminal. - -.. code:: bash - - python -m mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] -======= To launch the MLC Server for MLC-LLM, run the following command in your terminal. .. code:: bash mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] ->>>>>>> upstream/main MODEL The model folder after compiling with MLC-LLM build process. The parameter can either be the model name with its quantization scheme @@ -131,11 +103,7 @@ The REST API provides the following endpoints: ------------------------------------------------ -<<<<<<< HEAD - Get a list of models available for MLC-Chat. -======= Get a list of models available for MLC-LLM. ->>>>>>> upstream/main **Example** @@ -153,118 +121,8 @@ The REST API provides the following endpoints: print(response.json()) else: print("Error:", response.status_code) -<<<<<<< HEAD -.. http:post:: /v1/chat/completions -======= ->>>>>>> upstream/main - - -<<<<<<< HEAD - Get a response from MLC-Chat using a prompt, either with or without streaming. - -**Chat Completion Request Object** - -- **messages** (*List[ChatCompletionMessage]*, required): A sequence of messages that have been exchanged in the conversation so far. 
Each message in the conversation is represented by a `ChatCompletionMessage` object, which includes the following fields: - - **content** (*Optional[Union[str, List[Dict[str, str]]]]*): The text content of the message or structured data in case of tool-generated messages. - - **role** (*Literal["system", "user", "assistant", "tool"]*): The role of the message sender, indicating whether the message is from the system, user, assistant, or a tool. - - **name** (*Optional[str]*): An optional name for the sender of the message. - - **tool_calls** (*Optional[List[ChatToolCall]]*): A list of calls to external tools or functions made within this message, applicable when the role is `tool`. - - **tool_call_id** (*Optional[str]*): A unique identifier for the tool call, relevant when integrating external tools or services. - -- **model** (*str*, required): The model to be used for generating responses. - -- **frequency_penalty** (*float*, optional, default=0.0): Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model’s likelihood to repeat tokens. - -- **presence_penalty** (*float*, optional, default=0.0): Positive values penalize new tokens if they are already present in the text so far, decreasing the model’s likelihood to repeat tokens. - -- **logprobs** (*bool*, optional, default=False): Indicates whether to include log probabilities for each token in the response. - -- **top_logprobs** (*int*, optional, default=0): An integer ranging from 0 to 5. It determines the number of tokens, most likely to appear at each position, to be returned. Each token is accompanied by a log probability. If this parameter is used, 'logprobs' must be set to true. - -- **logit_bias** (*Optional[Dict[int, float]]*): Allows specifying biases for or against specific tokens during generation. - -- **max_tokens** (*Optional[int]*): The maximum number of tokens to generate in the response(s). - -- **n** (*int*, optional, default=1): Number of responses to generate for the given prompt. - -- **seed** (*Optional[int]*): A seed for deterministic generation. Using the same seed and inputs will produce the same output. - -- **stop** (*Optional[Union[str, List[str]]]*): One or more strings that, if encountered, will cause generation to stop. - -- **stream** (*bool*, optional, default=False): If `True`, responses are streamed back as they are generated. - -- **temperature** (*float*, optional, default=1.0): Controls the randomness of the generation. Lower values lead to less random completions. - -- **top_p** (*float*, optional, default=1.0): Nucleus sampling parameter that controls the diversity of the generated responses. - -- **tools** (*Optional[List[ChatTool]]*): Specifies external tools or functions that can be called as part of the chat. - -- **tool_choice** (*Optional[Union[Literal["none", "auto"], Dict]]*): Controls how tools are selected for use in responses. - -- **user** (*Optional[str]*): An optional identifier for the user initiating the request. - -- **ignore_eos** (*bool*, optional, default=False): If `True`, the model will ignore the end-of-sequence token for generating responses. - -- **response_format** (*RequestResponseFormat*, optional): Specifies the format of the response. Can be either "text" or "json_object", with optional schema definition for JSON responses. - -**Returns** - -- If `stream` is `False`, a `ChatCompletionResponse` object containing the generated response(s). 
-- If `stream` is `True`, a stream of `ChatCompletionStreamResponse` objects, providing a real-time feed of generated responses. - - -**ChatCompletionResponseChoice** - -- **finish_reason** (*Optional[Literal["stop", "length", "tool_calls", "error"]]*, optional): The reason the completion process was terminated. It can be due to reaching a stop condition, the maximum length, output of tool calls, or an error. - -- **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices. - -- **message** (*ChatCompletionMessage*, required): The message part of the chat completion, containing the content of the chat response. - -- **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token - -**ChatCompletionStreamResponseChoice** - -- **finish_reason** (*Optional[Literal["stop", "length", "tool_calls"]]*, optional): Specifies why the streaming completion process ended. Valid reasons are "stop", "length", and "tool_calls". - -- **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices. - -- **delta** (*ChatCompletionMessage*, required): Represents the incremental update or addition to the chat completion message in the stream. - -- **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token - -**ChatCompletionResponse** - -- **id** (*str*, required): A unique identifier for the chat completion session. - -- **choices** (*List[ChatCompletionResponseChoice]*, required): A collection of `ChatCompletionResponseChoice` objects, representing the potential responses generated by the model. - -- **created** (*int*, required, default=current time): The UNIX timestamp representing when the response was generated. - -- **model** (*str*, required): The name of the model used to generate the chat completions. - -- **system_fingerprint** (*str*, required): A system-generated fingerprint that uniquely identifies the computational environment. - -- **object** (*Literal["chat.completion"]*, required, default="chat.completion"): A string literal indicating the type of object, here always "chat.completion". - -- **usage** (*UsageInfo*, required, default=empty `UsageInfo` object): Contains information about the API usage for this specific request. - -**ChatCompletionStreamResponse** - -- **id** (*str*, required): A unique identifier for the streaming chat completion session. - -- **choices** (*List[ChatCompletionStreamResponseChoice]*, required): A list of `ChatCompletionStreamResponseChoice` objects, each representing a part of the streaming chat response. - -- **created** (*int*, required, default=current time): The creation time of the streaming response, represented as a UNIX timestamp. - -- **model** (*str*, required): Specifies the model that was used for generating the streaming chat completions. - -- **system_fingerprint** (*str*, required): A unique identifier for the system generating the streaming completions. - -- **object** (*Literal["chat.completion.chunk"]*, required, default="chat.completion.chunk"): A literal indicating that this object represents a chunk of a streaming chat completion. -======= .. 
http:post:: /v1/chat/completions ------------------------------------------------ @@ -343,7 +201,6 @@ The REST API provides the following endpoints: - **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token **ChatCompletionResponse** ->>>>>>> upstream/main - **id** (*str*, required): A unique identifier for the chat completion session. @@ -359,10 +216,8 @@ The REST API provides the following endpoints: - **usage** (*UsageInfo*, required, default=empty `UsageInfo` object): Contains information about the API usage for this specific request. +**ChatCompletionStreamResponse** -<<<<<<< HEAD -**Example** -======= - **id** (*str*, required): A unique identifier for the streaming chat completion session. - **choices** (*List[ChatCompletionStreamResponseChoice]*, required): A list of `ChatCompletionStreamResponseChoice` objects, each representing a part of the streaming chat response. @@ -374,69 +229,14 @@ The REST API provides the following endpoints: - **system_fingerprint** (*str*, required): A unique identifier for the system generating the streaming completions. - **object** (*Literal["chat.completion.chunk"]*, required, default="chat.completion.chunk"): A literal indicating that this object represents a chunk of a streaming chat completion. ->>>>>>> upstream/main - -Once you have launched the Server, you can use the API in your own program. Below is an example of using the API to interact with MLC-Chat in Python without Streaming (suppose the server is running on ``http://127.0.0.1:8080/``): -.. code:: bash +------------------------------------------------ -<<<<<<< HEAD - import requests - - # Get a response using a prompt without streaming - payload = { - "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/", - "messages": [ - {"role": "user", "content": "Hello! Our project is MLC LLM."}, - { - "role": "assistant", - "content": "Hello! It's great to hear about your project, MLC LLM.", - }, - {"role": "user", "content": "What is the name of our project?"}, - ], - "stream": False, - # "n": 1, - "max_tokens": 300, - } - r = requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload) - choices = r.json()["choices"] - for choice in choices: - print(f"{choice['message']['content']}\n") -======= **Example** ->>>>>>> upstream/main Below is an example of using the API to interact with MLC-LLM in Python with Streaming. -<<<<<<< HEAD -Below is an example of using the API to interact with MLC-Chat in Python with Streaming. - -.. code:: bash - - import requests - import json - - # Get a response using a prompt with streaming - payload = { - "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/", - "messages": [{"role": "user", "content": "Write a haiku"}], - "stream": True, - } - with requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload, stream=True) as r: - for chunk in r.iter_content(chunk_size=None): - chunk = chunk.decode("utf-8") - if "[DONE]" in chunk[6:]: - break - response = json.loads(chunk[6:]) - content = response["choices"][0]["delta"].get("content", "") - print(content, end="", flush=True) - print("\n") - ------------------------------------------------- - - -======= .. code:: bash import requests @@ -460,7 +260,6 @@ Below is an example of using the API to interact with MLC-Chat in Python with St ------------------------------------------------ ->>>>>>> upstream/main There is also support for function calling similar to OpenAI (https://platform.openai.com/docs/guides/function-calling). 
Below is an example on how to use function calling in Python. .. code:: bash diff --git a/docs/get_started/introduction.rst b/docs/get_started/introduction.rst index 282b4764c2..29060d5a60 100644 --- a/docs/get_started/introduction.rst +++ b/docs/get_started/introduction.rst @@ -32,12 +32,12 @@ You are expected to see the installation path of MLC LLM Python package. Chat CLI -------- -As the first example, we try out the chat CLI in MLC LLM with 4-bit quantized 7B Llama-2 model. +As the first example, we try out the chat CLI in MLC LLM with 4-bit quantized 8B Llama-3 model. You can run MLC chat through a one-liner command: .. code:: bash - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC It may take 1-2 minutes for the first time running this command. After waiting, this command launch a chat interface where you can enter your prompt and chat with the model. @@ -54,17 +54,19 @@ After waiting, this command launch a chat interface where you can enter your pro Note: Separate stop words in the `stop` option with commas (,). Multi-line input: Use escape+enter to start a new line. - [INST]: What's the meaning of life? - [/INST]: - Ah, a question that has puzzled philosophers and theologians for centuries! ... + user: What's the meaning of life + assistant: + What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life. + + The concept of the meaning of life has been debated and... The figure below shows what run under the hood of this chat CLI command. For the first time running the command, there are three major phases. -- **Phase 1. Pre-quantized weight download.** This phase automatically downloads pre-quantized Llama-2 model from `Hugging Face `_ and saves it to your local cache directory. -- **Phase 2. Model compilation.** This phase automatically optimizes the Llama-2 model to accelerate model inference on GPU with techniques of machine learning compilation in `Apache TVM `_ compiler, and generate the binary model library that enables the execution language models on your local GPU. -- **Phase 3. Chat runtime.** This phase consumes the model library built in phase 2 and the model weights downloaded in phase 1, launches a platform-native chat runtime to drive the execution of Llama-2 model. +- **Phase 1. Pre-quantized weight download.** This phase automatically downloads pre-quantized Llama-3 model from `Hugging Face `_ and saves it to your local cache directory. +- **Phase 2. Model compilation.** This phase automatically optimizes the Llama-3 model to accelerate model inference on GPU with techniques of machine learning compilation in `Apache TVM `_ compiler, and generate the binary model library that enables the execution language models on your local GPU. +- **Phase 3. Chat runtime.** This phase consumes the model library built in phase 2 and the model weights downloaded in phase 1, launches a platform-native chat runtime to drive the execution of Llama-3 model. We cache the pre-quantized model weights and compiled model library locally. Therefore, phase 1 and 2 will only execute **once** over multiple runs. @@ -83,16 +85,16 @@ Therefore, phase 1 and 2 will only execute **once** over multiple runs. Python API ---------- -In the second example, we run the Llama-2 model with the chat completion Python API of MLC LLM. +In the second example, we run the Llama-3 model with the chat completion Python API of MLC LLM. 
You can save the code below into a Python file and run it. .. code:: python - from mlc_llm import LLMEngine + from mlc_llm import MLCEngine # Create engine - model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" - engine = LLMEngine(model) + model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" + engine = MLCEngine(model) # Run chat completion in OpenAI API. for response in engine.chat.completions.create( @@ -112,9 +114,9 @@ You can save the code below into a Python file and run it. MLC LLM Python API -This code example first creates an :class:`mlc_llm.LLMEngine` instance with the the 4-bit quantized Llama-2 model. -**We design the Python API** :class:`mlc_llm.LLMEngine` **to align with OpenAI API**, -which means you can use :class:`mlc_llm.LLMEngine` in the same way of using +This code example first creates an :class:`mlc_llm.MLCEngine` instance with the 4-bit quantized Llama-3 model. +**We design the Python API** :class:`mlc_llm.MLCEngine` **to align with OpenAI API**, +which means you can use :class:`mlc_llm.MLCEngine` in the same way of using `OpenAI's Python package `_ for both synchronous and asynchronous generation. @@ -132,17 +134,17 @@ If you want to run without streaming, you can run print(response) You can also try different arguments supported in `OpenAI chat completion API `_. -If you would like to do concurrent asynchronous generation, you can use :class:`mlc_llm.AsyncLLMEngine` instead. +If you would like to do concurrent asynchronous generation, you can use :class:`mlc_llm.AsyncMLCEngine` instead. REST Server ----------- -For the third example, we launch a REST server to serve the 4-bit quantized Llama-2 model +For the third example, we launch a REST server to serve the 4-bit quantized Llama-3 model for OpenAI chat completion requests. The server can be launched in command line with .. code:: bash - mlc_llm serve HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC The server is hooked at ``http://127.0.0.1:8000`` by default, and you can use ``--host`` and ``--port`` to set a different host and port. @@ -154,7 +156,7 @@ we can open a new shell and send a cURL request via the following command: curl -X POST \ -H "Content-Type: application/json" \ -d '{ - "model": "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC", + "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC", "messages": [ {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} ] @@ -165,6 +167,7 @@ The server will process this request and send back the response. Similar to :ref:`introduction-to-mlc-llm-python-api`, you can pass argument ``"stream": true`` to request for stream responses. +.. _introduction-deploy-your-own-model: Deploy Your Own Model --------------------- @@ -226,7 +229,7 @@ You can also use this model in Python API, MLC serve and other use scenarios. (Optional) Compile Model Library ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -In previous sections, model libraries are compiled when the :class:`mlc_llm.LLMEngine` launches, +In previous sections, model libraries are compiled when the :class:`mlc_llm.MLCEngine` launches, which is what we call "JIT (Just-in-Time) model compilation". In some cases, it is beneficial to explicitly compile the model libraries. We can deploy LLMs with reduced dependencies by shipping the library for deployment without going through compilation. @@ -254,12 +257,12 @@ At runtime, we need to specify this model library path to use it. For example, .. 
code:: python - from mlc_llm import LLMEngine + from mlc_llm import MLCEngine # For Python API model = "models/phi-2" model_lib_path = "models/phi-2/lib.so" - engine = LLMEngine(model, model_lib_path=model_lib_path) + engine = MLCEngine(model, model_lib_path=model_lib_path) :ref:`compile-model-libraries` introduces the model compilation command in detail, where you can find instructions and example commands to compile model to different @@ -280,7 +283,7 @@ environments (e.g. SteamDeck). .. code:: bash - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC --device vulkan + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --device vulkan The same core LLM runtime engine powers all the backends, enabling the same model to be deployed across backends as long as they fit within the memory and computing budget of the corresponding hardware backend. @@ -298,7 +301,7 @@ To briefly summarize this page, - We went through three examples (chat CLI, Python API, and REST server) of MLC LLM, - we introduced how to convert model weights for your own models to run with MLC LLM, and (optionally) how to compile your models. -- We also discussed the the universal deployment capability of MLC LLM. +- We also discussed the universal deployment capability of MLC LLM. Next, please feel free to check out the pages below for quick start examples and more detailed information on specific platforms diff --git a/docs/get_started/quick_start.rst b/docs/get_started/quick_start.rst index bd3b41218e..8349197eda 100644 --- a/docs/get_started/quick_start.rst +++ b/docs/get_started/quick_start.rst @@ -6,7 +6,7 @@ Quick Start Examples -------- -To begin with, try out MLC LLM support for int4-quantized Llama2 7B. +To begin with, try out MLC LLM support for int4-quantized Llama3 8B. It is recommended to have at least 6GB free VRAM to run it. .. tabs:: @@ -20,11 +20,11 @@ It is recommended to have at least 6GB free VRAM to run it. .. code:: python - from mlc_llm import LLMEngine + from mlc_llm import MLCEngine # Create engine - model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" - engine = LLMEngine(model) + model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" + engine = MLCEngine(model) # Run chat completion in OpenAI API. for response in engine.chat.completions.create( @@ -57,7 +57,7 @@ It is recommended to have at least 6GB free VRAM to run it. .. code:: shell - mlc_llm serve HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC **Send requests to server.** When the server is ready (showing ``INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)``), open a new shell and send a request via the following command: @@ -67,7 +67,7 @@ It is recommended to have at least 6GB free VRAM to run it. curl -X POST \ -H "Content-Type: application/json" \ -d '{ - "model": "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC", + "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC", "messages": [ {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} ] @@ -94,7 +94,7 @@ It is recommended to have at least 6GB free VRAM to run it. .. code:: bash - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC If you are using windows/linux/steamdeck and would like to use vulkan, @@ -133,7 +133,7 @@ It is recommended to have at least 6GB free VRAM to run it. | - **Requirement**. 
Llama2-7B model needs an iOS device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. + **Requirement**. Llama3-8B model needs an iOS device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. **Tutorial and source code**. The source code of the iOS app is fully `open source `__, and a :ref:`tutorial ` is included in documentation. @@ -154,7 +154,7 @@ It is recommended to have at least 6GB free VRAM to run it. | - **Requirement**. Llama2-7B model needs a device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. + **Requirement**. Llama3-8B model needs a device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. The demo is tested on - Samsung S23 with Snapdragon 8 Gen 2 chip diff --git a/docs/index.rst b/docs/index.rst index e9835e152d..2d5597d18e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -46,7 +46,6 @@ Check out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a compilation/convert_weights.rst compilation/compile_models.rst compilation/define_new_models.rst - compilation/configure_quantization.rst .. toctree:: :maxdepth: 1 diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst index c6602559ae..ce15616957 100644 --- a/docs/install/mlc_llm.rst +++ b/docs/install/mlc_llm.rst @@ -118,6 +118,13 @@ Select your operating system/compute platform and run the command in your termin python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly .. note:: + Make sure you also install vulkan loader and clang to avoid vulkan + not found error or clang not found(needed for jit compile) + + .. code-block:: bash + + conda install -c conda-forge clang libvulkan-loader + If encountering the error below: .. code-block:: bash @@ -207,7 +214,9 @@ There are two ways to do so: .. code-tab :: bash Install via environment variable - export PYTHONPATH=/path-to-mlc-llm/python:$PYTHONPATH + export MLC_LLM_HOME=/path-to-mlc-llm + export PYTHONPATH=$MLC_LLM_HOME/python:$PYTHONPATH + alias mlc_llm="python -m mlc_llm" .. code-tab :: bash Install via pip local project diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst index 849152cce6..ed4977e5e3 100644 --- a/docs/install/tvm.rst +++ b/docs/install/tvm.rst @@ -112,6 +112,13 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly .. note:: + Make sure you also install vulkan loader and clang to avoid vulkan + not found error or clang not found(needed for jit compile) + + .. code-block:: bash + + conda install -c conda-forge clang libvulkan-loader + If encountering the error below: .. code-block:: bash @@ -213,7 +220,7 @@ While it is generally recommended to always use the prebuilt TVM Unity, if you r If you are using CUDA and your compute capability is above 80, then it is require to build with ``set(USE_FLASHINFER ON)``. Otherwise, you may run into ``Cannot find PackedFunc`` issue during runtime. - + To check your CUDA compute capability, you can use ``nvidia-smi --query-gpu=compute_cap --format=csv``. Once ``config.cmake`` is edited accordingly, kick off build with the commands below: diff --git a/docs/prebuilt_models.rst b/docs/prebuilt_models.rst index f97909a515..2f772a5d7e 100644 --- a/docs/prebuilt_models.rst +++ b/docs/prebuilt_models.rst @@ -68,7 +68,7 @@ For more, please see :ref:`the CLI page `, and the :ref:`the Python .. 
code:: shell - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC To run the model with Python API, see :ref:`the Python page ` (all other downloading steps are the same as CLI). diff --git a/docs/requirements.txt b/docs/requirements.txt index bc020bc662..0156a180b0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,5 +6,9 @@ tlcpack-sphinx-addon==0.2.2 sphinxcontrib_httpdomain==1.8.1 sphinxcontrib-napoleon==0.7 sphinx-reredirects==0.1.2 +shortuuid +pydantic +uvicorn +fastapi --find-links https://mlc.ai/wheels mlc-ai-nightly diff --git a/examples/python/sample_mlc_engine.py b/examples/python/sample_mlc_engine.py index e26e17f1e2..e4f869930f 100644 --- a/examples/python/sample_mlc_engine.py +++ b/examples/python/sample_mlc_engine.py @@ -1,8 +1,8 @@ -from mlc_llm import LLMEngine +from mlc_llm import MLCEngine # Create engine -model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" -engine = LLMEngine(model) +model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" +engine = MLCEngine(model) # Run chat completion in OpenAI API. for response in engine.chat.completions.create( diff --git a/python/mlc_llm/chat_module.py b/python/mlc_llm/chat_module.py index 943f98c7e2..24ad8faecf 100644 --- a/python/mlc_llm/chat_module.py +++ b/python/mlc_llm/chat_module.py @@ -664,7 +664,7 @@ def _inspect_model_lib_metadata_memory_usage(model_lib_path, config_file_path): "--mlc-chat-config", config_file_path, ] - subprocess.run(cmd, check=False) + subprocess.run(cmd, check=False, env=os.environ) class ChatModule: # pylint: disable=too-many-instance-attributes @@ -768,7 +768,7 @@ def __init__( # pylint: disable=too-many-arguments self.chat_config = _get_chat_config(self.config_file_path, chat_config) # 4. Look up model library - try: + if model_lib_path is not None: self.model_lib_path = _get_lib_module_path( model, self.model_path, @@ -777,8 +777,8 @@ def __init__( # pylint: disable=too-many-arguments self.device.MASK2STR[self.device.device_type], self.config_file_path, ) - except FileNotFoundError: - logger.info("Model lib not found. 
Now compiling model lib on device...") + else: + logger.info("Now compiling model lib on device...") from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel self.model_lib_path = str( diff --git a/python/mlc_llm/cli/delivery.py b/python/mlc_llm/cli/delivery.py index 50b9c7e170..a7dd6408b0 100644 --- a/python/mlc_llm/cli/delivery.py +++ b/python/mlc_llm/cli/delivery.py @@ -1,7 +1,9 @@ """Continuous model delivery for MLC LLM models.""" + import argparse import dataclasses import json +import os import shutil import subprocess import sys @@ -131,7 +133,9 @@ def _run_quantization( cmd += ["--" + optional_arg.replace("_", "-"), str(optional_arg_val)] print(" ".join(cmd), file=log_file, flush=True) - subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT) + subprocess.run( + cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ + ) cmd = [ sys.executable, "-m", @@ -146,7 +150,9 @@ def _run_quantization( output_dir, ] print(" ".join(cmd), file=log_file, flush=True) - subprocess.run(cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT) + subprocess.run( + cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ + ) logger.info("[MLC] Complete!") if not (Path(output_dir) / "ndarray-cache.json").exists(): logger.error( diff --git a/python/mlc_llm/cli/lib_delivery.py b/python/mlc_llm/cli/lib_delivery.py new file mode 100644 index 0000000000..a5d678fbe2 --- /dev/null +++ b/python/mlc_llm/cli/lib_delivery.py @@ -0,0 +1,200 @@ +"""Continuous model delivery for MLC LLM models.""" + +import argparse +import dataclasses +import json +import os +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Any, Callable, Dict, List + +from mlc_llm.support import logging +from mlc_llm.support.argparse import ArgumentParser +from mlc_llm.support.constants import MLC_TEMP_DIR +from mlc_llm.support.style import bold, green, red + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ModelInfo: # pylint: disable=too-many-instance-attributes + """Necessary information for the model delivery""" + + model_id: str + model: Path + quantization: str + device: str + # overrides the `context_window_size`, `prefill_chunk_size`, + # `sliding_window_size`, `attention_sink_size`, `max_batch_size` + # and `tensor_parallel_shards in mlc-chat-config.json + overrides: Dict[str, int] + + +class DeferredScope: + """A context manager that defers execution of functions until exiting the scope.""" + + def __init__(self): + self.deferred_functions = [] + + def add(self, func: Callable[[], None]): + """Add a function to be executed when exiting the scope.""" + self.deferred_functions.append(func) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + for func in reversed(self.deferred_functions): + func() + return False + + def create_temp_dir(self) -> Path: + """Create a temporary directory that will be deleted when exiting the scope.""" + temp_dir = tempfile.mkdtemp(dir=MLC_TEMP_DIR) + self.add(lambda: shutil.rmtree(temp_dir, ignore_errors=True)) + return Path(temp_dir) + + +def _run_compilation(model_info: ModelInfo, repo_dir: Path) -> bool: + """Run the compilation of the model library.""" + + def get_lib_ext(device: str) -> str: + if device in ["cuda", "vulkan", "metal"]: + return ".so" + if device in ["android", "ios"]: + return ".tar" + if device in ["webgpu"]: + return ".wasm" + + return "" + + succeeded = True 
+ with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as temp_dir: + log_path = Path(temp_dir) / "logs.txt" + model_lib_name = f"{model_info.model_id}-{model_info.quantization}-{model_info.device}" + lib_ext = get_lib_ext(model_info.device) + if lib_ext == "": + raise ValueError(f"Unsupported device: {model_info.device}") + model_lib_name += lib_ext + with log_path.open("a", encoding="utf-8") as log_file: + overrides = ";".join(f"{key}={value}" for key, value in model_info.overrides.items()) + cmd = [ + sys.executable, + "-m", + "mlc_llm", + "compile", + str(model_info.model), + "--device", + model_info.device, + "--quantization", + model_info.quantization, + "--overrides", + overrides, + "--output", + os.path.join(temp_dir, model_lib_name), + ] + print(" ".join(cmd), file=log_file, flush=True) + subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT) + logger.info("[MLC] Compilation Complete!") + if not (Path(temp_dir) / model_lib_name).exists(): + logger.error( + "[%s] Model %s. Device %s. No compiled library found.", + red("FAILED"), + model_info.model_id, + model_info.device, + ) + succeeded = False + return succeeded + + # overwrite git repo file with the compiled library + repo_filepath = repo_dir / model_info.model_id / model_lib_name + if not repo_filepath.parent.exists(): + repo_filepath.parent.mkdir(parents=True, exist_ok=True) + # copy lib from Path(temp_dir) / model_lib_name to repo_filepath + shutil.copy(Path(temp_dir) / model_lib_name, repo_filepath) + logger.info("Saved library %s at %s", model_lib_name, repo_filepath) + return succeeded + + +def _main( # pylint: disable=too-many-locals + spec: Dict[str, Any], +): + """Compile the model libs in the spec and save them to the binary_libs_dir.""" + failed_cases: List[Any] = [] + for task_index, task in enumerate(spec["tasks"], 1): + logger.info( + bold("[{task_index}/{total_tasks}] Processing model: ").format( + task_index=task_index, + total_tasks=len(spec["tasks"]), + ) + + green(task["model_id"]) + ) + model_info = { + "model_id": task["model_id"], + "model": task["model"], + } + for compile_opt in spec["default_compile_options"] + task.get("compile_options", []): + for quantization in spec["default_quantization"] + task.get("quantization", []): + model_info["quantization"] = quantization + model_info["device"] = compile_opt["device"] + model_info["overrides"] = compile_opt.get("overrides", {}) + logger.info( + "[Config] " + + bold("model_id: ") + + model_info["model_id"] + + bold(", quantization: ") + + model_info["quantization"] + + bold(", device: ") + + model_info["device"] + + bold(", overrides: ") + + json.dumps(model_info["overrides"]) + ) + + result = _run_compilation( + ModelInfo(**model_info), + repo_dir=Path(spec["binary_libs_dir"]), + ) + if not result: + failed_cases.append(model_info) + + if failed_cases: + logger.info("Total %s %s:", len(failed_cases), red("failures")) + for case in failed_cases: + logger.info( + "model_id %s, quantization %s, device %s, overrides %s", + case["model_id"], + case["quantization"], + case["device"], + json.dumps(case["overrides"]), + ) + + +def main(): + """Entry point.""" + + def _load_spec(path_spec: str) -> Dict[str, Any]: + path = Path(path_spec) + if not path.exists(): + raise argparse.ArgumentTypeError(f"Spec file does not exist: {path}") + with path.open("r", encoding="utf-8") as i_f: + return json.load(i_f) + + parser = ArgumentParser("MLC LLM continuous library delivery") + parser.add_argument( + "--spec", + type=_load_spec, + required=True, + 
help="Path to the spec file", + ) + parsed = parser.parse_args() + _main( + spec=parsed.spec, + ) + + +if __name__ == "__main__": + main() diff --git a/python/mlc_llm/cli/model_metadata.py b/python/mlc_llm/cli/model_metadata.py index 9b45561665..81473b1ec7 100644 --- a/python/mlc_llm/cli/model_metadata.py +++ b/python/mlc_llm/cli/model_metadata.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Any, Dict, List, Union -import numpy as np +from tvm.runtime import DataType from mlc_llm.support import logging from mlc_llm.support.argparse import ArgumentParser @@ -81,7 +81,7 @@ def _compute_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBa else: # Contains dynamic shape; use config to look up concrete values param_shape = _read_dynamic_shape(param["shape"], config) - params_bytes += math.prod(param_shape) * np.dtype(param["dtype"]).itemsize + params_bytes += math.prod(param_shape) * DataType(param["dtype"]).itemsize() temp_func_bytes = 0.0 for _func_name, func_bytes in metadata["memory_usage"].items(): temp_func_bytes = max(temp_func_bytes, func_bytes) diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index 9f7c1c3580..6663a0c230 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -44,6 +44,9 @@ def main(argv): "--max-total-seq-length", type=int, help=HELP["max_total_sequence_length_serve"] ) parser.add_argument("--prefill-chunk-size", type=int, help=HELP["prefill_chunk_size_serve"]) + parser.add_argument( + "--max-history-size", type=int, default=1, help=HELP["max_history_size_serve"] + ) parser.add_argument( "--gpu-memory-utilization", type=float, help=HELP["gpu_memory_utilization_serve"] ) @@ -100,6 +103,7 @@ def main(argv): max_batch_size=parsed.max_batch_size, max_total_sequence_length=parsed.max_total_seq_length, prefill_chunk_size=parsed.prefill_chunk_size, + max_history_size=parsed.max_history_size, gpu_memory_utilization=parsed.gpu_memory_utilization, speculative_mode=SpeculativeMode[parsed.speculative_mode], spec_draft_length=parsed.spec_draft_length, diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py b/python/mlc_llm/compiler_pass/attach_sampler.py index 1b7b0328a9..46dc40c106 100644 --- a/python/mlc_llm/compiler_pass/attach_sampler.py +++ b/python/mlc_llm/compiler_pass/attach_sampler.py @@ -7,6 +7,8 @@ from tvm.relax.frontend import nn from tvm.script import tir as T +from ..op.batch_spec_verify import batch_spec_verify + @tvm.transform.module_pass(opt_level=0, name="AttachGPUSamplingFunc") class AttachGPUSamplingFunc: # pylint: disable=too-few-public-methods @@ -46,6 +48,8 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR _attach_argsort_func(bb, vocab_size), _attach_sample_with_top_p(bb, vocab_size), _attach_take_probs_func(bb, vocab_size), + _attach_batch_verifier(bb, vocab_size), + _attach_renormalize_by_top_p(bb, vocab_size), ] ] @@ -126,6 +130,17 @@ def _attach_argsort_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): return gv +@T.prim_func +def full(var_result: T.handle, value: T.int32): + """The filling function for top k.""" + batch_size = T.int32(is_size_var=True) + result = T.match_buffer(var_result, (batch_size, 1), "int32") + for i in T.serial(batch_size): + with T.block("block"): + vi = T.axis.spatial(batch_size, i) + result[vi, 0] = value + + def _attach_sample_with_top_p( # pylint: disable=too-many-locals bb: relax.BlockBuilder, vocab_size: tir.PrimExpr ): @@ -143,15 +158,6 @@ def _attach_sample_with_top_p( # pylint: disable=too-many-locals 
sample_indices = relax.Var("sample_indices", relax.TensorStructInfo((num_samples,), "int32")) top_p = relax.Var("top_p", relax.TensorStructInfo((batch_size,), "float32")) - @T.prim_func - def full(var_result: T.handle, value: T.int32): - batch_size = T.int32(is_size_var=True) - result = T.match_buffer(var_result, (batch_size, 1), "int32") - for i in T.serial(batch_size): - with T.block("block"): - vi = T.axis.spatial(batch_size, i) - result[vi, 0] = value - with bb.function( "sample_with_top_p", [sorted_probs, sorted_indices, uniform_samples, sample_indices, top_p], @@ -221,6 +227,44 @@ def full(var_result: T.handle, value: T.int32): return gv +def _attach_renormalize_by_top_p(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): + batch_size = tir.Var("batch_size", "int64") + probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")) + sorted_probs = relax.Var( + "sorted_probs", relax.TensorStructInfo((batch_size, vocab_size), "float32") + ) + top_p = relax.Var("top_p", relax.TensorStructInfo((batch_size,), "float32")) + with bb.function("renormalize_by_top_p", [probs, sorted_probs, top_p]): + with bb.dataflow(): + probs_tensor = nn.wrap_nested(probs, name="probs") + sorted_probs_tensor = nn.wrap_nested(sorted_probs, name="sorted_probs") + top_p_shape = relax.ShapeExpr([batch_size, 1]) + top_p_tensor = nn.wrap_nested( + relax.call_pure_packed( + "vm.builtin.reshape", + top_p, + top_p_shape, + sinfo_args=relax.TensorStructInfo(top_p_shape, "float32"), + ), + name="sample_indices", + ) + top_k_tensor = nn.tensor_ir_op( + full, + name_hint="full", + args=[vocab_size], + out=nn.Tensor.placeholder( + [batch_size, 1], + "int32", + ), + ) + renormalized_probs = nn.renormalize_top_p_top_k_prob( + probs_tensor, sorted_probs_tensor, top_p_tensor, top_k_tensor + ) + bb.emit_output(renormalized_probs._expr) # pylint: disable=protected-access + gv = bb.emit_func_output(renormalized_probs._expr) # pylint: disable=protected-access + return gv + + def _attach_take_probs_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): batch_size = tir.Var("batch_size", "int64") num_samples = tir.Var("num_samples", "int64") @@ -289,3 +333,50 @@ def sampler_take_probs_tir( # pylint: disable=too-many-locals,too-many-argument bb.emit_output(taken_probs_indices) gv = bb.emit_func_output(taken_probs_indices) return gv + + +def _attach_batch_verifier(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): + num_nodes = tir.Var("num_nodes", "int64") + nbatch = tir.Var("nbatch", "int64") + draft_probs = relax.Var( + "draft_probs", relax.TensorStructInfo((num_nodes, vocab_size), "float32") + ) + draft_tokens = relax.Var("draft_tokens", relax.TensorStructInfo((num_nodes,), "int32")) + model_probs = relax.Var( + "model_probs", relax.TensorStructInfo((num_nodes, vocab_size), "float32") + ) + token_tree_first_child = relax.Var( + "token_tree_first_child", relax.TensorStructInfo((num_nodes,), "int32") + ) + token_tree_next_sibling = relax.Var( + "token_tree_next_sibling", relax.TensorStructInfo((num_nodes,), "int32") + ) + uniform_samples = relax.Var("uniform_samples", relax.TensorStructInfo((num_nodes,), "float32")) + token_tree_parent_ptr = relax.Var( + "token_tree_parent_ptr", relax.TensorStructInfo((nbatch,), "int32") + ) + args = [ + draft_probs, + draft_tokens, + model_probs, + token_tree_first_child, + token_tree_next_sibling, + uniform_samples, + token_tree_parent_ptr, + ] + with bb.function("sampler_verify_draft_tokens", args): + with bb.dataflow(): + res = bb.emit( + relax.call_tir_inplace( + 
bb.add_func(batch_spec_verify(vocab_size), "batch_verify_on_gpu_single_kernel"), + args, + inplace_indices=[args.index(model_probs), args.index(token_tree_parent_ptr)], + out_sinfo=[ + model_probs.struct_info, # pylint: disable=no-member + token_tree_parent_ptr.struct_info, # pylint: disable=no-member + ], + ) + ) + bb.emit_output(res) + gv = bb.emit_func_output(res) + return gv diff --git a/python/mlc_llm/compiler_pass/estimate_memory_usage.py b/python/mlc_llm/compiler_pass/estimate_memory_usage.py index d69d99109d..83007fde66 100644 --- a/python/mlc_llm/compiler_pass/estimate_memory_usage.py +++ b/python/mlc_llm/compiler_pass/estimate_memory_usage.py @@ -25,6 +25,8 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR func_name = "_metadata" + func_name = "_metadata" + def _emit_metadata(metadata): bb = relax.BlockBuilder() # pylint: disable=invalid-name with bb.function(func_name, params=[]): diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index b85a6a2cf6..57b68f742d 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -33,6 +33,7 @@ from .fuse_transpose_matmul import FuseTransposeMatmul from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc from .low_batch_specialization import LowBatchGemvSpecialize +from .rewrite_softmax import RewriteTwoStageSoftmax from .scatter_tuple_get_item import ScatterTupleGetItem logger = logging.getLogger(__name__) @@ -117,6 +118,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline _LogProgress("Lowering to TVM TIR kernels"), tvm.relax.backend.DispatchSortScan(), + RewriteTwoStageSoftmax(target=target), tvm.relax.transform.LegalizeOps(), tvm.relax.transform.AnnotateTIROpPattern(), tvm.relax.transform.FoldConstant(), diff --git a/python/mlc_llm/compiler_pass/rewrite_softmax.py b/python/mlc_llm/compiler_pass/rewrite_softmax.py new file mode 100644 index 0000000000..1a6e41eafc --- /dev/null +++ b/python/mlc_llm/compiler_pass/rewrite_softmax.py @@ -0,0 +1,190 @@ +"""A compiler pass that rewrites one-shot softmax into two-stage softmax.""" + +import math + +import tvm +from tvm import relax +from tvm.ir.module import IRModule +from tvm.relax.expr import Expr +from tvm.relax.expr_functor import PyExprMutator, mutator +from tvm.script import tir as T + +from ..support.max_thread_check import get_max_num_threads_per_block + + +@tvm.transform.module_pass(opt_level=0, name="RewriteTwoStageSoftmax") +class RewriteTwoStageSoftmax: # pylint: disable=too-few-public-methods + """Rewrites one-shot softmax into two-stage softmax.""" + + def __init__(self, target: tvm.target.Target) -> None: + self.target = target + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + return _Rewriter(mod, self.target).transform() + + +@mutator +class _Rewriter(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod: IRModule, target: tvm.target.Target) -> None: + super().__init__(mod) + self.mod = mod + self.target = target + self.chunk_size = 4096 + + def transform(self) -> IRModule: + """Entry point""" + gv = self.mod.get_global_var("softmax_with_temperature") + updated_func = self.visit_expr(self.mod[gv]) + self.builder_.update_func(gv, updated_func) + return self.builder_.get() + + def visit_call_(self, call: relax.Call) -> Expr: # pylint: 
disable=arguments-renamed + if call.op != tvm.ir.Op.get("relax.nn.softmax"): + return call + x = call.args[0] + if call.attrs.axis not in [-1, x.struct_info.ndim - 1]: + return call + # Currently the softmax input is 3-dim, and dtype is float32. + assert x.struct_info.ndim == 3 + assert x.struct_info.dtype == "float32" + x_shape = x.struct_info.shape + new_shape = relax.ShapeExpr([x_shape[0] * x_shape[1], x_shape[2]]) + x_reshaped = relax.call_pure_packed( + "vm.builtin.reshape", + x, + new_shape, + sinfo_args=relax.TensorStructInfo(new_shape, x.struct_info.dtype), + ) + f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(self.target, self.chunk_size) + chunked_lse = relax.call_tir( + self.builder_.add_func(f_chunk_lse, "chunk_lse"), + args=[x_reshaped], + out_sinfo=relax.TensorStructInfo( + (new_shape[0], (new_shape[1] + self.chunk_size - 1) // self.chunk_size), + x.struct_info.dtype, + ), + ) + softmax = relax.call_tir( + self.builder_.add_func(f_softmax_with_lse, "softmax_with_chunked_lse"), + args=[x_reshaped, chunked_lse], + out_sinfo=relax.TensorStructInfo(new_shape, x.struct_info.dtype), + ) + return relax.call_pure_packed( + "vm.builtin.reshape", softmax, x_shape, sinfo_args=x.struct_info + ) + + +def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements + target: tvm.target.Target, chunk_size: int +): + log2e = math.log2(math.exp(1)) + + # pylint: disable=invalid-name + @T.prim_func + def chunk_lse(var_A: T.handle, var_chunked_lse: T.handle): # pylint: disable=too-many-locals + T.func_attr({"tir.noalias": T.bool(True)}) + batch_size = T.int64(is_size_var=True) + vocab_size = T.int64(is_size_var=True) + num_chunks = T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") + chunked_lse = T.match_buffer(var_chunked_lse, (batch_size, num_chunks), dtype="float32") + A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(chunk_size)), dtype="float32") + temp_max = T.alloc_buffer((batch_size, num_chunks), dtype="float32") + temp_sum = T.alloc_buffer((batch_size, num_chunks), dtype="float32") + + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("pad"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + A_pad[v0, v1, v2] = T.if_then_else( + v1 * T.int64(chunk_size) + v2 < vocab_size, + A[v0, v1 * T.int64(chunk_size) + v2], + T.min_value("float32"), + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("max"): + v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) + with T.init(): + temp_max[v0, v1] = T.min_value("float32") + temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2]) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("sum_exp"): + v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) + with T.init(): + temp_sum[v0, v1] = T.float32(0) + temp_sum[v0, v1] += T.if_then_else( + v1 * T.int64(chunk_size) + v2 < vocab_size, + T.exp2((A_pad[v0, v1, v2] - temp_max[v0, v1]) * log2e), + T.float32(0), + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)): + with T.block("log"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + chunked_lse[v0, v1] = T.log2(temp_sum[v0, v1]) + temp_max[v0, v1] * log2e + + @T.prim_func + def softmax_with_chunked_lse(var_A: T.handle, var_chunked_lse: T.handle, var_softmax: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + batch_size = T.int64(is_size_var=True) + vocab_size = T.int64(is_size_var=True) + num_chunks = 
T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") + chunked_lse = T.match_buffer(var_chunked_lse, (batch_size, num_chunks), dtype="float32") + softmax = T.match_buffer(var_softmax, (batch_size, vocab_size), dtype="float32") + temp_max = T.alloc_buffer((batch_size,), dtype="float32") + temp_sum = T.alloc_buffer((batch_size,), dtype="float32") + lse = T.alloc_buffer((batch_size,), dtype="float32") + for l0, l1 in T.grid(batch_size, num_chunks): + with T.block("max"): + v0, v1 = T.axis.remap("SR", [l0, l1]) + with T.init(): + temp_max[v0] = T.min_value("float32") + temp_max[v0] = T.max(temp_max[v0], chunked_lse[v0, v1]) + for l0, l1 in T.grid(batch_size, num_chunks): + with T.block("sum_exp"): + v0, v1 = T.axis.remap("SR", [l0, l1]) + with T.init(): + temp_sum[v0] = T.float32(0) + temp_sum[v0] += T.exp2(chunked_lse[v0, v1] - temp_max[v0]) + for l0 in T.serial(0, batch_size): + with T.block("log"): + v0 = T.axis.remap("S", [l0]) + lse[v0] = T.log2(temp_sum[v0]) + temp_max[v0] + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("pad"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + if v1 * T.int64(chunk_size) + v2 < vocab_size: + softmax[v0, v1 * T.int64(chunk_size) + v2] = T.exp2( + A[v0, v1 * T.int64(chunk_size) + v2] * log2e - lse[v0] + ) + + sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_lse": softmax_with_chunked_lse})) + max_threads = get_max_num_threads_per_block(target) + TX = 32 + TY = max_threads // TX + unroll_depth = 64 + # pylint: enable=invalid-name + + sch.work_on("softmax_with_chunked_lse") + sch.compute_inline("log") + l0, l1, l2 = sch.get_loops("pad") + bx = sch.fuse(l0, l1) + sch.bind(bx, "blockIdx.x") + unroll, ty, tx = sch.split(l2, [None, TY, TX]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1) + + for block_name in ["sum_exp", "max"]: + block = sch.get_block(block_name) + sch.set_scope(block, buffer_index=0, storage_scope="shared") + sch.compute_at(block, bx) + r_loop = sch.get_loops(block)[-1] + r_loop, tx = sch.split(r_loop, [None, TX]) + sch.reorder(tx, r_loop) + sch.bind(tx, "threadIdx.x") + sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1) + + return chunk_lse, sch.mod["softmax_with_chunked_lse"] diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index 1b2a06feab..1c599fa875 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -36,6 +36,27 @@ def get_conv_template(name: str) -> Optional[Conversation]: ############## Preset Conversation Templates ############## +# Llama3 +# See https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models +# and https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py +ConvTemplateRegistry.register_conv_template( + Conversation( + name="llama-3", + system_template=( + f"<|start_header_id|>system<|end_header_id|>\n\n{MessagePlaceholders.SYSTEM.value}" + ), + system_message="You are a helpful, respectful and honest assistant.", + roles={"user": "user", "assistant": "assistant"}, + seps=["<|eot_id|><|start_header_id|>"], + role_content_sep="<|end_header_id|>\n\n", + role_empty_sep="<|end_header_id|>\n\n", + stop_str=["<|end_of_text|>", "<|eot_id|>"], + 
stop_token_ids=[128001, 128009], # "<|end_of_text|>", "<|eot_id|>" + system_prefix_token_ids=[128000], # "<|begin_of_text|>" + add_role_after_system_message=True, + ) +) + # Llama2 ConvTemplateRegistry.register_conv_template( Conversation( @@ -344,7 +365,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: # RWKV World ConvTemplateRegistry.register_conv_template( Conversation( - name="rwkv-world", + name="rwkv_world", system_template=f"User: hi\n\nAssistant: {MessagePlaceholders.SYSTEM.value}", system_message=( "Hi. I am your assistant and I will provide expert full response " diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index b4321ebdec..86930fa5ea 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -152,6 +152,11 @@ The maximum number of tokens the model passes for prefill each time. It should not exceed the prefill chunk size in model config. If not specified, this defaults to the prefill chunk size in model config. +""".strip(), + "max_history_size_serve": """ +The maximum history length for rolling back the RNN state. +If unspecified, the default value is 1. +KV cache does not need this. """.strip(), "enable_tracing_serve": """ Enable Chrome Tracing for the server. @@ -188,7 +193,7 @@ "gpu_memory_utilization_serve": """ A number in (0, 1) denoting the fraction of GPU memory used by the server in total. It is used to infer to maximum possible KV cache capacity. -When it is unspecified, it defaults to 0.90. +When it is unspecified, it defaults to 0.85. Under mode "local" or "interactive", the actual memory usage may be significantly smaller than this number. Under mode "server", the actual memory usage may be slightly larger than this number. """, @@ -203,7 +208,7 @@ The number of draft tokens to generate in speculative proposal. The default values is 4. """, "engine_config_serve": """ -The LLMEngine execution configuration. +The MLCEngine execution configuration. Currently speculative decoding mode is specified via engine config. For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" to specify the eagle-style speculative decoding. 
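For reference, the serving flags touched in this patch can be combined in a single `mlc_llm serve` invocation. The sketch below is illustrative only: it reuses the Llama-3 model id that appears elsewhere in this patch, `--gpu-memory-utilization` simply restates the new 0.85 default, and the `--engine-config` string mirrors the example quoted in the help text above rather than introducing new options.

```
# Illustrative only: combines options documented in this patch.
# --gpu-memory-utilization restates the 0.85 default; --engine-config mirrors
# the EAGLE speculative-decoding example from the help text.
mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC \
    --gpu-memory-utilization 0.85 \
    --engine-config='spec_draft_length=4;speculative_mode=EAGLE'
```

Once the server is up, requests go to the default endpoint at `http://127.0.0.1:8000`, as in the cURL examples earlier in this patch.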
diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py index 2d0d668672..77b611d139 100644 --- a/python/mlc_llm/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -2,7 +2,6 @@ import dataclasses import enum -import re from io import StringIO from typing import Optional diff --git a/python/mlc_llm/interface/convert_weight.py b/python/mlc_llm/interface/convert_weight.py index b54318ef4c..179c872e50 100644 --- a/python/mlc_llm/interface/convert_weight.py +++ b/python/mlc_llm/interface/convert_weight.py @@ -7,10 +7,9 @@ from pathlib import Path from typing import Any, Dict, Iterator, Tuple -import numpy as np from tvm import tir from tvm.contrib import tvmjs -from tvm.runtime import Device, NDArray +from tvm.runtime import DataType, Device, NDArray from tvm.runtime import cpu as cpu_device from tvm.target import Target diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index d22aa7d231..8e617fc3d2 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -274,6 +274,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b # FIXME: Copy RWKV tokenizer file # pylint: disable=fixme CONV_TEMPLATES = { + "llama-3", "chatml", "open_hermes_mistral", "neural_hermes_mistral", diff --git a/python/mlc_llm/interface/jit.py b/python/mlc_llm/interface/jit.py index 25548e0e4a..e999a36468 100644 --- a/python/mlc_llm/interface/jit.py +++ b/python/mlc_llm/interface/jit.py @@ -93,7 +93,11 @@ def _run_jit(opt: str, overrides: str, device: str, dst: str): ] logger.info("Compiling using commands below:") logger.info("%s", blue(shlex.join(cmd))) - subprocess.run(cmd, check=True) + subprocess.run(cmd, check=False, env=os.environ) + # note on windows: compilation can succeed but return code is still nonzero + # check whether file exists instead + if not os.path.isfile(dso_path): + raise RuntimeError("Cannot find compilation output, compilation failed") shutil.move(dso_path, dst) logger.info("Using compiled model lib: %s", bold(dst)) diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index c5696ef473..40fa9fdda8 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -22,6 +22,7 @@ def serve( max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], + max_history_size: Optional[int], gpu_memory_utilization: Optional[float], speculative_mode: SpeculativeMode, spec_draft_length: int, @@ -35,7 +36,7 @@ def serve( ): # pylint: disable=too-many-arguments, too-many-locals """Serve the model with the specified configuration.""" # Create engine and start the background loop - async_engine = engine.AsyncLLMEngine( + async_engine = engine.AsyncMLCEngine( model=model, device=device, model_lib_path=model_lib_path, @@ -44,6 +45,7 @@ def serve( max_batch_size=max_batch_size, max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, gpu_memory_utilization=gpu_memory_utilization, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, diff --git a/python/mlc_llm/json_ffi/__init__.py b/python/mlc_llm/json_ffi/__init__.py new file mode 100644 index 0000000000..8a7059153d --- /dev/null +++ b/python/mlc_llm/json_ffi/__init__.py @@ -0,0 +1,8 @@ +"""JSON FFI is a pure string based interface of MLC LLM Engine. 
+ +We build interfacing with JSON FFI for both testing purposes +and internal use. For most python API usage, please use MLCEngine +and MLCAsyncEngine +""" + +from .engine import JSONFFIEngine diff --git a/python/mlc_llm/json_ffi/engine.py b/python/mlc_llm/json_ffi/engine.py new file mode 100644 index 0000000000..0c604a2ef3 --- /dev/null +++ b/python/mlc_llm/json_ffi/engine.py @@ -0,0 +1,310 @@ +# pylint: disable=chained-comparison,missing-docstring,too-few-public-methods,too-many-instance-attributes +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +import json +import queue +import threading +from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union + +import tvm + +from mlc_llm.protocol import openai_api_protocol +from mlc_llm.serve import engine_utils +from mlc_llm.serve.engine_base import ( + EngineConfig, + SpeculativeMode, + _infer_kv_cache_config, + _parse_models, + _process_model_args, + detect_device, +) +from mlc_llm.tokenizer import Tokenizer + + +# TODO(mlc-team): further minimize the JSONFFIEngine +# construction to not depend on any config and directly pass in JSON +# model defined generation config should be read from the JSONFFIEngine via Reload +def create_model_defined_generation_config( + temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float +) -> tvm.runtime.Object: + return tvm.get_global_func("mlc.json_ffi.ModelDefinedGenerationConfig")( + temperature, + top_p, + frequency_penalty, + presence_penalty, + ) + + +# TODO(mlc-team): further minimize the JSONFFIEngine +# Engine config should be passed as json str +# and backend should have good default +# only model and model_lib should be mandatory +def create_json_ffi_engine_config( + conv_template: str, model_generation_cfgs: Dict[str, tvm.runtime.Object] +) -> tvm.runtime.Object: + return tvm.get_global_func("mlc.json_ffi.JSONFFIEngineConfig")( + conv_template, model_generation_cfgs + ) + + +class EngineState: + sync_queue: queue.Queue + + def get_request_stream_callback(self) -> Callable[[List[str]], None]: + # ChatCompletionStreamResponse + + def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: + self._sync_request_stream_callback(chat_completion_stream_responses_json_str) + + return _callback + + def _sync_request_stream_callback( + self, chat_completion_stream_responses_json_str: List[str] + ) -> None: + # Put the delta outputs to the queue in the unblocking way. + self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) + + +class JSONFFIEngine: + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + model: str, + device: Union[str, tvm.runtime.Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + max_history_size: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, + gpu_memory_utilization: Optional[float] = None, + ) -> None: + # - Initialize model loading info. 
+ models = _parse_models(model, model_lib_path, additional_models) + if isinstance(device, str): + device = detect_device(device) + assert isinstance(device, tvm.runtime.Device) + ( + model_args, + model_config_paths, + self.conv_template, + ) = _process_model_args(models, device) + + # TODO(mlc-team) Remove the model config parsing, estimation below + # in favor of a simple direct passing of parameters into backend. + # JSONFFIEngine do not have to support automatic mode + # + # Instead, its config should default to interactive mode always + # and allow overrides of parameters through json config via reload + # + # This is to simplify the logic of users of JSONFFI + # since we won't have similar logics in android/iOS + # + # - Load the raw model config into dict + self.model_config_dicts = [] + for i, model_info in enumerate(models): + model_info.model_lib_path = model_args[i][1] + with open(model_config_paths[i], "r", encoding="utf-8") as file: + self.model_config_dicts.append(json.load(file)) + + # - Decide the KV cache config based on mode and user input. + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + max_history_size, + kv_state_kind, + ) = _infer_kv_cache_config( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_history_size, + gpu_memory_utilization, + models, + device, + self.model_config_dicts, + model_config_paths, + ) + self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + + # - Initialize engine state and engine. + self.state = EngineState() + module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() + self._ffi = { + key: module[key] + for key in [ + "init_background_engine", + "reload", + "unload", + "reset", + "chat_completion", + "abort", + "get_last_error", + "run_background_loop", + "run_background_stream_back_loop", + "exit_background_loop", + ] + } + self.tokenizer = Tokenizer(model_args[0][0]) + + self.engine_config = EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ) + + self.json_ffi_engine_config = create_json_ffi_engine_config( + conv_template=self.conv_template.model_dump_json(), + model_generation_cfgs={ + model.model: create_model_defined_generation_config( + temperature=model_config["temperature"], + top_p=model_config["top_p"], + frequency_penalty=model_config["frequency_penalty"], + presence_penalty=model_config["presence_penalty"], + ) + for model, model_config in zip(models, self.model_config_dicts) + }, + ) + + self._ffi["init_background_engine"]( + self.json_ffi_engine_config, + self.engine_config, + device, + self.state.get_request_stream_callback(), + None, + ) + + def _background_loop(): + self._ffi["run_background_loop"]() + + def _background_stream_back_loop(): + self._ffi["run_background_stream_back_loop"]() + + # Create the background engine-driving thread and start the loop. 
+ self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) + self._background_stream_back_loop_thread: threading.Thread = threading.Thread( + target=_background_stream_back_loop + ) + self._background_loop_thread.start() + self._background_stream_back_loop_thread.start() + self._terminated = False + + def terminate(self): + self._terminated = True + self._ffi["exit_background_loop"]() + self._background_loop_thread.join() + self._background_stream_back_loop_thread.join() + + def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ).model_dump_json(), + n=n, + request_id=request_id, + ) + for response in chatcmpl_generator: + yield response + + def _handle_chat_completion( + self, request_json_str: str, n: int, request_id: str + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + self.state.sync_queue = queue.Queue() + num_unfinished_requests = n + + success = bool(self._ffi["chat_completion"](request_json_str, request_id)) + + try: + while num_unfinished_requests > 0: + chat_completion_stream_responses_json_str = self.state.sync_queue.get() + for chat_completion_response_json_str in chat_completion_stream_responses_json_str: + chat_completion_response = ( + openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( + chat_completion_response_json_str + ) + ) + for choice in chat_completion_response.choices: + if choice.finish_reason is not None: + num_unfinished_requests -= 1 + yield chat_completion_response + except Exception as exception: # pylint: disable=broad-exception-caught + self._ffi["abort"](request_id) + raise exception + + def _test_reload(self): + self._ffi["reload"](self.engine_config) + + def _test_reset(self): + self._ffi["reset"]() + + def _test_unload(self): + self._ffi["unload"]() diff --git 
a/python/mlc_llm/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py index 28c34353e2..ede9dc350f 100644 --- a/python/mlc_llm/model/gpt2/gpt2_model.py +++ b/python/mlc_llm/model/gpt2/gpt2_model.py @@ -28,7 +28,7 @@ class GPT2Config(ConfigBase): # pylint: disable=too-many-instance-attributes n_embd: int n_layer: int n_head: int - layer_norm_epsilon: int + layer_norm_epsilon: float n_inner: int = -1 context_window_size: int = 0 prefill_chunk_size: int = 0 diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index 2ae5500c6d..18238f688e 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -224,15 +224,41 @@ def batch_forward( hidden_states = self.model(input_embeds, paged_kv_cache) if logit_positions is not None: hidden_states = op.take(hidden_states, logit_positions, axis=1) + return self.get_logits(hidden_states) + + def batch_forward_to_last_hidden_states( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + ): + op_ext.configure() + + hidden_states = self.model(input_embeds, paged_kv_cache) + return hidden_states + + def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + return self.model.embed_tokens(input_ids) + + def get_logits(self, hidden_states: Tensor): + op_ext.configure() logits = self.lm_head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") return logits - def embed(self, input_ids: Tensor): + def batch_get_logits(self, hidden_states: Tensor, logit_positions: Tensor): + op_ext.configure() if self.tensor_parallel_shards > 1: - input_ids = op.ccl_broadcast_from_worker0(input_ids) - return self.model.embed_tokens(input_ids) + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) + hidden_states = op.take(hidden_states, logit_positions, axis=0) + return self.get_logits(hidden_states) + + def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): + op_ext.configure() + hidden_states = op.take(hidden_states, logit_positions, axis=0) + return hidden_states def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): op_ext.configure() @@ -243,20 +269,28 @@ def _index(x: te.Tensor): # x[:-1,:] hidden_states = self.model(input_embed, paged_kv_cache) hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) - logits = self.lm_head(hidden_states) - if logits.dtype != "float32": - logits = logits.astype("float32") + logits = self.get_logits(hidden_states) return logits, paged_kv_cache def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) - logits = self.lm_head(hidden_states) - if logits.dtype != "float32": - logits = logits.astype("float32") + logits = self.get_logits(hidden_states) return logits, paged_kv_cache + def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + return hidden_states, paged_kv_cache + + def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + return hidden_states, paged_kv_cache + def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): @@ -273,6 +307,24 @@ def batch_verify(self, input_embeds: 
Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache + def batch_prefill_to_last_hidden_states( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): + hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) + return hidden_states, paged_kv_cache + + def batch_decode_to_last_hidden_states( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): + hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) + return hidden_states, paged_kv_cache + + def batch_verify_to_last_hidden_states( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): + hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) + return hidden_states, paged_kv_cache + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) @@ -309,6 +361,29 @@ def get_default_spec(self): "effect_mode": "none", }, }, + "get_logits": { + "hidden_states": nn.spec.Tensor(["batch_size", self.hidden_size], self.dtype), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_get_logits": { + "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_select_last_hidden_states": { + "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, "prefill": { "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), @@ -325,6 +400,22 @@ def get_default_spec(self): "effect_mode": "none", }, }, + "prefill_to_last_hidden_states": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode_to_last_hidden_states": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "batch_prefill": { "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), @@ -350,6 +441,30 @@ def get_default_spec(self): "effect_mode": "none", }, }, + "batch_prefill_to_last_hidden_states": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode_to_last_hidden_states": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify_to_last_hidden_states": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "softmax_with_temperature": { "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), 
"temperature": nn.spec.Tensor(["batch_size"], "float32"), diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 1c513e15d3..272cffdc80 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -85,7 +85,7 @@ class Model: "group-quant": llama_quantization.group_quant, "ft-quant": llama_quantization.ft_quant, "awq": llama_quantization.awq_quant, - "smoothquant": llama_quantization.smooth_quant, + "smoothquant": llama_quantization.smooth_quant }, ), "mistral": Model( diff --git a/python/mlc_llm/model/model_preset.py b/python/mlc_llm/model/model_preset.py index 3bfe1cb891..41abf0292c 100644 --- a/python/mlc_llm/model/model_preset.py +++ b/python/mlc_llm/model/model_preset.py @@ -660,4 +660,54 @@ "eos_token_id": 2, "pad_token_id": 0, }, + "llama3_8b": { + "architectures": ["LlamaForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 8192, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 500000.0, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0.dev0", + "use_cache": True, + "vocab_size": 128256, + }, + "llama3_70b": { + "architectures": ["LlamaForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 8192, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 500000.0, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0.dev0", + "use_cache": True, + "vocab_size": 128256, + }, } diff --git a/python/mlc_llm/model/rwkv5/rwkv5_model.py b/python/mlc_llm/model/rwkv5/rwkv5_model.py index 49386720da..81c9e9aa7f 100644 --- a/python/mlc_llm/model/rwkv5/rwkv5_model.py +++ b/python/mlc_llm/model/rwkv5/rwkv5_model.py @@ -40,6 +40,7 @@ class RWKV5Config(ConfigBase): # pylint: disable=too-many-instance-attributes context_window_size: int = -1 # RWKV does not have context window limitation. 
prefill_chunk_size: int = 4096 num_heads: int = 0 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -129,23 +130,18 @@ def wkv_func( def token_shift(state: Tensor, x: Tensor): - # x.shape = (batch, seq_len, hidden_size) - # state.shape = (batch, hidden_size) - seq_len = x.shape[1] - def _te_token_shift(state: te.Tensor, x: te.Tensor): return te.compute( x.shape, lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]), ) - return state if seq_len == 1 else op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) + return op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) def last_token(x: Tensor): # x.shape = (batch, seq_len, hidden_size) batch, seq_len, hidden_size = x.shape - assert batch == 1 def _te_last_token(x: te.Tensor): return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, x.shape[1] - 1, j]) @@ -350,10 +346,14 @@ def to(self, dtype: Optional[str] = None): def embed(self, input_ids: Tensor): return self.model.embeddings(input_ids) - def forward(self, input_embed: Tensor, state: RNNState): + def forward( + self, input_embed: Tensor, state: RNNState, logit_positions: Optional[Tensor] = None + ): """Forward pass.""" hidden_states, state = self.model(input_embed, state) hidden_states = last_token(hidden_states) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") @@ -367,11 +367,27 @@ def decode(self, input_embed: Tensor, state: RNNState): """Decoding step.""" return self.forward(input_embed, state) + def batch_prefill(self, input_embeds: Tensor, logit_positions: Tensor, state: RNNState): + """Prefilling the prompt.""" + return self.forward(input_embeds, state, logit_positions=logit_positions) + + def batch_decode(self, input_embeds: Tensor, state: RNNState): + """Decoding step.""" + return self.forward(input_embeds, state) + + def batch_verify(self, input_embeds: Tensor, state: RNNState): + """Verify step.""" + return self.forward(input_embeds, state) + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): """Softmax.""" - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Object: + def create_rnn_state( + self, + max_batch_size: tir.Var, + max_history: tir.Var, + ) -> Object: """Create RNN state.""" init_values = [ op.zeros((self.hidden_size,), dtype=self.dtype), # ATT_X @@ -386,7 +402,6 @@ def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Obj ) def get_default_spec(self): - batch_size = 1 mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), @@ -396,9 +411,7 @@ def get_default_spec(self): }, }, "prefill": { - "input_embed": nn.spec.Tensor( - [batch_size, "seq_len", self.hidden_size], self.dtype - ), + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -406,7 +419,32 @@ def get_default_spec(self): }, }, "decode": { - "input_embed": nn.spec.Tensor([batch_size, 1, self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + 
"batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -414,8 +452,8 @@ def get_default_spec(self): }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([batch_size, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/rwkv6/rwkv6_model.py b/python/mlc_llm/model/rwkv6/rwkv6_model.py index 0e1887310d..a8faf48a6b 100644 --- a/python/mlc_llm/model/rwkv6/rwkv6_model.py +++ b/python/mlc_llm/model/rwkv6/rwkv6_model.py @@ -40,6 +40,7 @@ class RWKV6Config(ConfigBase): # pylint: disable=too-many-instance-attributes context_window_size: int = -1 # RWKV does not have context window limitation. prefill_chunk_size: int = 4096 num_heads: int = 0 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -126,20 +127,17 @@ def wkv_func( def token_shift(state: Tensor, x: Tensor): - seq_len = x.shape[1] - def _te_token_shift(state: te.Tensor, x: te.Tensor): return te.compute( x.shape, lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]), ) - return state if seq_len == 1 else op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) + return op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) def last_token(x: Tensor): batch, seq_len, hidden_size = x.shape - assert batch == 1 def _te_last_token(x: te.Tensor): return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, x.shape[1] - 1, j]) @@ -390,10 +388,14 @@ def to(self, dtype: Optional[str] = None): def embed(self, input_ids: Tensor): return self.model.embeddings(input_ids) - def forward(self, input_embed: Tensor, state: RNNState): + def forward( + self, input_embed: Tensor, state: RNNState, logit_positions: Optional[Tensor] = None + ): """Forward pass.""" hidden_states, state = self.model(input_embed, state) hidden_states = last_token(hidden_states) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") @@ -407,11 +409,27 @@ def decode(self, input_embed: Tensor, state: RNNState): """Decoding step.""" return self.forward(input_embed, state) + def batch_prefill(self, input_embeds: Tensor, logit_positions: Tensor, state: RNNState): + """Prefilling the prompt.""" + return self.forward(input_embeds, state, logit_positions=logit_positions) + + def batch_decode(self, input_embeds: Tensor, state: RNNState): + """Decoding step.""" + return self.forward(input_embeds, state) + + def batch_verify(self, input_embeds: Tensor, state: RNNState): + """Verify step.""" + return self.forward(input_embeds, state) + def softmax_with_temperature(self, logits: Tensor, temperature: 
Tensor): """Softmax.""" - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Object: + def create_rnn_state( + self, + max_batch_size: tir.Var, + max_history: tir.Var, + ) -> Object: """Create RNN state.""" init_values = [ op.zeros((self.hidden_size,), dtype=self.dtype), # ATT_X @@ -426,7 +444,6 @@ def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Obj ) def get_default_spec(self): - batch_size = 1 mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), @@ -436,9 +453,7 @@ def get_default_spec(self): }, }, "prefill": { - "input_embed": nn.spec.Tensor( - [batch_size, "seq_len", self.hidden_size], self.dtype - ), + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -446,7 +461,32 @@ def get_default_spec(self): }, }, "decode": { - "input_embed": nn.spec.Tensor([batch_size, 1, self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -454,8 +494,8 @@ def get_default_spec(self): }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([batch_size, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py index 4a058c6e03..e4cbf1c047 100644 --- a/python/mlc_llm/nn/kv_cache.py +++ b/python/mlc_llm/nn/kv_cache.py @@ -887,7 +887,7 @@ def _attention_decode( THREAD_LIMIT = 512 TILE_SIZE_PER_BDX = 2 if target.kind.name == "opencl" and "android" in str(target.host): - THREAD_LIMIT = 64 + THREAD_LIMIT = 256 TILE_SIZE_PER_BDX = 1 max_num_threads_per_block = get_max_num_threads_per_block(target) thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) diff --git a/python/mlc_llm/op/__init__.py b/python/mlc_llm/op/__init__.py index 342568639d..850312a8a7 100644 --- a/python/mlc_llm/op/__init__.py +++ b/python/mlc_llm/op/__init__.py @@ -1,6 +1,9 @@ """Extern module for compiler.""" + from . 
import moe_matmul, moe_misc from .attention import attention +from .batch_spec_verify import batch_spec_verify from .extern import configure, enable, get_store from .ft_gemm import faster_transformer_dequantize_gemm from .position_embedding import llama_rope +from .top_p_pivot import top_p_pivot, top_p_renorm diff --git a/python/mlc_llm/op/batch_spec_verify.py b/python/mlc_llm/op/batch_spec_verify.py new file mode 100644 index 0000000000..d1a57fc71c --- /dev/null +++ b/python/mlc_llm/op/batch_spec_verify.py @@ -0,0 +1,177 @@ +"""Operators for batch verify in speculative decoding.""" + +from tvm.script import tir as T + +# mypy: disable-error-code="attr-defined,valid-type,name-defined" +# pylint: disable=too-many-locals,invalid-name,too-many-arguments, +# pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches + + +def batch_spec_verify(vocab_size): + """Batch draft verify function. This function verifies the token tree. + + Before calling the function + + - token_tree_parent_ptr[b] should store the root of the tree + + - draft_probs[node_id, :] stores the prob that samples the correspond tree node + - model_probs[node_id, :] stores the prob that should be used to sample its children + - Please note that the storage convention difference between model_probs and draft_probs + draft_probs was stored on the token node, while model_probs stores on the parent. + This is an intentional design since we can sample different child token with different + proposal draft probabilities, but the ground truth model_prob is unique per parent. + + After calling the function + - token_tree_parent_ptr[b] points to the last token accepted + - There should be a followup sample step that samples from model_probs[token_tree_parent_ptr[b], :] + This token will be appended to the token generated. + + This function will inplace update model_probs if a token was rejected and renormalization is needed. 
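The accept/renormalize walk described above can be sketched in plain NumPy. This is a reference for the semantics only (serial, no thread binding), using the same 1e-7 residual threshold as the kernel.

```
import numpy as np

def batch_spec_verify_reference(
    draft_probs, draft_tokens, model_probs,
    token_tree_first_child, token_tree_next_sibling,
    uniform_samples, token_tree_parent_ptr,
):
    """Reference semantics of the verification walk.

    Updates model_probs and token_tree_parent_ptr in place, like the kernel.
    """
    for b in range(len(token_tree_parent_ptr)):
        parent = int(token_tree_parent_ptr[b])
        child = int(token_tree_first_child[parent])
        while child != -1:
            token = int(draft_tokens[child])
            p = model_probs[parent, token]
            q = draft_probs[child, token]
            u = uniform_samples[child]
            if p >= u * q:  # accept: descend into the accepted child
                parent = child
                child = int(token_tree_first_child[child])
                continue
            # Reject: fold the draft distribution out of the parent distribution.
            residual = np.maximum(model_probs[parent] - draft_probs[child], 0.0)
            total = residual.sum()
            if total < 1e-7:  # numerically empty residual: accept the child anyway
                parent = child
                child = int(token_tree_first_child[child])
            else:
                model_probs[parent] = residual / total  # renormalize in place
                child = int(token_tree_next_sibling[child])
        token_tree_parent_ptr[b] = parent
```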
+ + Parameters + ---------- + draft_probs: + The draft probability attached to each tree node + + draft_tokens: + The draft token in each node + + model_probs: + The model proability attached to each parent + + token_tree_first_child: + The first child of each tree node, if there is no child, it should be -1 + + token_tree_next_sibling + The next sibling of each tree node, if there is no next sibling, it should be -1 + + uniform_samples + Per node uniform sample used to check rejection + + token_tree_parent_ptr: + Current parent ptr state + """ + TX = 1024 + + def _var(dtype="int32"): + return T.alloc_buffer((1,), dtype, scope="local") + + # fmt: off + @T.prim_func(private=True) + def _func( + var_draft_probs: T.handle, + var_draft_tokens: T.handle, + var_model_probs: T.handle, + var_token_tree_first_child: T.handle, + var_token_tree_next_sibling: T.handle, + var_uniform_samples: T.handle, + var_token_tree_parent_ptr: T.handle, + ): + """ + [ + blockIdx.x on batch, + threadIdx.x on vocab_size, + for loop over excessive amounts + ] + """ + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + num_nodes = T.int32(is_size_var=True) + nbatch = T.int32(is_size_var=True) + + draft_probs = T.match_buffer(var_draft_probs, (num_nodes, vocab_size), "float32") + draft_tokens = T.match_buffer(var_draft_tokens, (num_nodes,), "int32") + model_probs = T.match_buffer(var_model_probs, (num_nodes, vocab_size), "float32") + token_tree_first_child = T.match_buffer(var_token_tree_first_child, (num_nodes,), "int32") + token_tree_next_sibling = T.match_buffer(var_token_tree_next_sibling, (num_nodes,), "int32") + uniform_samples = T.match_buffer(var_uniform_samples, (num_nodes,), "float32") + token_tree_parent_ptr = T.match_buffer(var_token_tree_parent_ptr, (nbatch,), "int32") + + with T.block("kernel"): + child_ptr = _var() + parent_ptr = _var() + child_token = _var() + done = _var("bool") + psum = _var("float32") + t0 = _var("float32") + model_prob_local = _var("float32") + draft_prob_local = _var("float32") + p_child = _var("float32") + q_child = _var("float32") + uniform_sample = _var("float32") + + pred_shared = T.alloc_buffer((1,), "bool", scope="shared") + pred_local = T.alloc_buffer((1,), "bool", scope="local") + + for _bx in T.thread_binding(0, nbatch, thread="blockIdx.x"): + for _tx in T.thread_binding(0, TX, thread="threadIdx.x"): + with T.block("CTA"): + # batch size + b = T.axis.S(nbatch, _bx) + tx = T.axis.S(TX, _tx) + + parent_ptr[0] = token_tree_parent_ptr[b] + child_ptr[0] = token_tree_first_child[parent_ptr[0]] + done[0] = False + + while T.Not(done[0]): + T.tvm_storage_sync("shared") # ensure all effects last round are visible + if child_ptr[0] == -1: + done[0] = True + T.tvm_storage_sync("shared") # sync before exit + else: + # decide to validate current ptr + if tx == 0: + child_token[0] = draft_tokens[child_ptr[0]] + p_child[0] = model_probs[parent_ptr[0], child_token[0]] + q_child[0] = draft_probs[child_ptr[0], child_token[0]] + uniform_sample[0] = uniform_samples[child_ptr[0]] + pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0] # use multiplication to avoid division by zero + T.tvm_storage_sync("shared") # make sure all read of model_probs are done + pred_local[0] = pred_shared[0] + + # accept the proposal, we move to child + if pred_local[0]: + parent_ptr[0] = child_ptr[0] + child_ptr[0] = token_tree_first_child[child_ptr[0]] + else: + psum[0] = 0.0 + # renormalize probability, predicated by stopped_expansion[b]: + for i in T.serial(T.ceildiv(vocab_size, TX)): + k = 
T.meta_var(i * TX + tx) + if k < vocab_size: + model_prob_local[0] = model_probs[parent_ptr[0], k] + draft_prob_local[0] = draft_probs[child_ptr[0], k] + model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0) + psum[0] += model_prob_local[0] + + with T.block("block_cross_thread"): + T.reads(psum[0]) + T.writes(t0[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), psum[0], True, t0[0], tx, dtype="handle") + + if t0[0] < 1e-7: + # accept the proposal, we move to child + parent_ptr[0] = child_ptr[0] + child_ptr[0] = token_tree_first_child[child_ptr[0]] + else: + # renormalize + for i in T.serial(T.ceildiv(vocab_size, TX)): + k = T.meta_var(i * TX + tx) + if k < vocab_size: + model_prob_local[0] = model_probs[parent_ptr[0], k] + draft_prob_local[0] = draft_probs[child_ptr[0], k] + model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0) + model_probs[parent_ptr[0], k] = model_prob_local[0] / t0[0] + + child_ptr[0] = token_tree_next_sibling[child_ptr[0]] + + if tx == 0: + token_tree_parent_ptr[b] = parent_ptr[0] + # fmt: on + + return _func diff --git a/python/mlc_llm/op/moe_matmul.py b/python/mlc_llm/op/moe_matmul.py index 6978d8ba0e..b4ebb5b630 100644 --- a/python/mlc_llm/op/moe_matmul.py +++ b/python/mlc_llm/op/moe_matmul.py @@ -2,7 +2,7 @@ from typing import Literal, Optional -from tvm import DataType, tir +from tvm import DataType, DataTypeCode, tir from tvm.relax.frontend.nn import Tensor, op from tvm.script import tir as T @@ -335,6 +335,7 @@ def _dequantize(w, s, e, i, j): if num_elem_per_storage == 1: w = tir.reinterpret(quantize_dtype, w[e, i, j]) else: + assert DataType(storage_dtype).type_code == DataTypeCode.UINT tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype) w = w[e, i, j // num_elem_per_storage] shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) diff --git a/python/mlc_llm/op/top_p_pivot.py b/python/mlc_llm/op/top_p_pivot.py new file mode 100644 index 0000000000..9c97959bff --- /dev/null +++ b/python/mlc_llm/op/top_p_pivot.py @@ -0,0 +1,315 @@ +"""Operators for choosing the pivot to cut-off top-p percentile """ + +import tvm +from tvm.script import tir as T + +# mypy: disable-error-code="attr-defined,valid-type,name-defined" +# pylint: disable=too-many-locals,invalid-name,too-many-arguments,unnecessary-lambda +# pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches + + +def top_p_pivot(pN): + """Top-p pivot function. This function finds the pivot to cut-off top-p percentile. 
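The pivot this kernel searches for can also be computed exactly with a full descending sort, which the TIR implementation below avoids; the NumPy reference here is for intuition only (ties at the pivot value are kept, consistent with the validity conditions documented next).

```
import numpy as np

def top_p_pivot_reference(prob: np.ndarray, top_p: float):
    """Exact top-p pivot for one probability vector, via a full sort."""
    order = np.sort(prob)[::-1]
    csum = np.cumsum(order)
    k = int(np.searchsorted(csum, top_p))    # first prefix whose mass reaches top_p
    k = min(k, len(order) - 1)               # guard against top_p exceeding the total mass
    pivot = float(order[k])
    lsum = float(prob[prob >= pivot].sum())  # keep ties at the pivot value
    return pivot, lsum
```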
+ + A valide pivot should satisfy the following conditions: + - lsum >= top_p + - top_p > lsum - cmin * lmin + where lsum is the sum of elements that are larger or equal to the pivot, + lmin is the minimum elements that is larger or equal to the pivot, + cmin is the count of elements that are equal to lmin, + + Parameters + ---------- + prob: + The probability vector + + top_p_global: + The top-p threshold + + init_pivots: + The initial pivot candidates + + final_pivot: + The final pivot to cut-off top-p percentile + """ + TX = 1024 + K = 32 + eps_LR = 1e-7 + + def _var(dtype="int32"): + return T.alloc_buffer((1,), dtype, scope="local") + + def valid(lsum, lmin, cmin, top_p): + return tvm.tir.all(lsum >= top_p, top_p > lsum - cmin * lmin) + + # fmt: off + @T.prim_func(private=True) + def _func( + var_prob: T.handle, + top_p_global: T.buffer([1], dtype="float32"), + var_init_pivots: T.handle, + var_final_pivot: T.handle, + var_final_lsum: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + B = T.int32() + N = T.int32() + prob = T.match_buffer(var_prob, (B, N,), "float32") + init_pivots = T.match_buffer(var_init_pivots, (pN,), "float32") + final_pivot = T.match_buffer(var_final_pivot, (B,), "float32") + final_lsum = T.match_buffer(var_final_lsum, (B,), "float32") + + with T.block("kernel"): + pivot = T.alloc_buffer((pN,), "float32", scope="local") + top_p = _var("float32") + + L = T.alloc_buffer((1,), "float32", scope="shared") + R = T.alloc_buffer((1,), "float32", scope="shared") + L_local = _var("float32") + R_local = _var("float32") + + q = _var("float32") + lsum = T.alloc_buffer((pN,), "float32", scope="local") + lmin_broadcast = T.alloc_buffer((1), "float32", scope="shared") + lmin_broadcast_local = _var("float32") + lmin = T.alloc_buffer((pN,), "float32", scope="local") + cmin = T.alloc_buffer((pN,), "int32", scope="local") + total_sum = _var("float32") + + it = _var("int32") + es_local = _var("bool") + es = T.alloc_buffer((1,), "bool", scope="shared") + find_pivot_local = _var("bool") + find_pivot = T.alloc_buffer((1,), "bool", scope="shared") + + total_sum_reduce = _var("float32") + lsum_reduce = _var("float32") + lmin_reduce = _var("float32") + cmin_reduce = _var("int32") + + for _bx in T.thread_binding(0, B, thread="blockIdx.x"): + for _tx in T.thread_binding(0, TX, thread="threadIdx.x"): + with T.block("CTA"): + b, tx = T.axis.remap("SS", [_bx, _tx]) + + top_p[0] = top_p_global[0] + + if tx == 0: + # leader thread initializes L, R + L[0] = 1.0 - top_p[0] + R[0] = eps_LR + find_pivot[0] = False + T.tvm_storage_sync("shared") + + L_local[0] = L[0] + R_local[0] = R[0] + for i in T.unroll(0, pN): + # pivots are in descending order + pivot[i] = init_pivots[i] + find_pivot_local[0] = False + + while T.tvm_thread_invariant( + L_local[0] - R_local[0] > eps_LR + and T.Not(find_pivot_local[0]) + ): + # sync before each iteration + T.tvm_storage_sync("shared") + + ### get lsum, lmin, total_sum + for pidx in T.unroll(0, pN): + lsum[pidx] = 0.0 + lmin[pidx] = 1.0 + cmin[pidx] = 0 + total_sum[0] = 0.0 + it[0] = 0 + es_local[0] = False + while it[0] < T.ceildiv(N, TX) and T.Not(es_local[0]): + idx = T.meta_var(it[0] * TX + tx) + q[0] = T.if_then_else(idx < N, prob[b, idx], 0.0) + total_sum[0] += q[0] + for pidx in T.unroll(0, pN): + if q[0] >= pivot[pidx]: + lsum[pidx] += q[0] + if lmin[pidx] > q[0]: + lmin[pidx] = q[0] + cmin[pidx] = 1 + elif lmin[pidx] == q[0]: + cmin[pidx] += 1 + it[0] += 1 + + # early stop every K iterations + if it[0] % K == 0: + # reduce total_sum 
over tx + # T.tvm_storage_sync("shared") + with T.block("block_cross_thread"): + T.reads(total_sum[0]) + T.writes(total_sum_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), total_sum[0], True, total_sum_reduce[0], tx, dtype="handle") + # T.tvm_storage_sync("shared") + + if tx == 0: + # leader thread checks if we can stop early + es[0] = 1 - total_sum_reduce[0] < pivot[pN - 1] + T.tvm_storage_sync("shared") + es_local[0] = es[0] + + T.tvm_storage_sync("shared") + + # reduce lsum, lmin, cmin, over tx + for pidx in T.serial(0, pN): + # reduce lsum over tx for pivot[j] + with T.block("block_cross_thread"): + T.reads(lsum[pidx]) + T.writes(lsum_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], True, lsum_reduce[0], tx, dtype="handle") + + # reduce lmin over tx for pivot[j] + with T.block("block_cross_thread"): + T.reads(lmin[pidx]) + T.writes(lmin_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], True, lmin_reduce[0], tx, dtype="handle") + + if tx == 0: + # broadcast lmin to all threads + lmin_broadcast[0] = lmin_reduce[0] + T.tvm_storage_sync("shared") + lmin_broadcast_local[0] = lmin_broadcast[0] + if lmin[pidx] > lmin_broadcast_local[0]: + cmin[pidx] = 0 + if tx == 0: + # only the leader thread updates lsum, lmin + lsum[pidx] = lsum_reduce[0] + lmin[pidx] = lmin_reduce[0] + + # reduce cmin over tx for pivot[j] + with T.block("block_cross_thread"): + T.reads(cmin[pidx]) + T.writes(cmin_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.int32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], True, cmin_reduce[0], tx, dtype="handle") + + if tx == 0: + # only the leader thread updates cmin + cmin[pidx] = cmin_reduce[0] + + T.tvm_storage_sync("shared") + + if tx == 0: + # leader thread checks if we have found the pivot, or updates L, R + it[0] = 0 + while it[0] < pN and T.Not(find_pivot_local[0]): + pidx = T.meta_var(it[0]) + if valid(lsum[pidx], lmin[pidx], cmin[pidx], top_p[0]): + find_pivot[0] = True + find_pivot_local[0] = True + # write back the pivot and lsum + final_pivot[b] = pivot[pidx] + final_lsum[b] = lsum[pidx] + elif lsum[pidx] - lmin[pidx] * cmin[pidx] >= top_p[0]: + R[0] = pivot[pidx] + elif lsum[pidx] < top_p[0]: + L[0] = pivot[pidx] + it[0] += 1 + + T.tvm_storage_sync("shared") + + L_local[0] = L[0] + R_local[0] = R[0] + find_pivot_local[0] = find_pivot[0] + # new pivots for next iteration + # uniform spacing between L and R + for pidx in T.unroll(0, pN): + pivot[pidx] = L[0] - (pidx + 1) * (L_local[0] - R_local[0]) / (pN + 1) + + if tx == 0: + # leader thread writes back the pivot + if T.Not(find_pivot_local[0]): + final_pivot[b] = -1e5 + # fmt: on + + return _func + + +def top_p_renorm(): + """Top-p renormalization function. This function renormalizes the probability vector. 
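The renormalization rule spelled out in this docstring has a one-line NumPy equivalent, shown purely for reference: probabilities at or above the pivot are divided by their retained mass `lsum`, and everything else is zeroed.

```
import numpy as np

def top_p_renorm_reference(prob: np.ndarray, pivot: float, lsum: float) -> np.ndarray:
    # Zero out everything below the pivot and rescale the surviving mass to sum to one.
    return np.where(prob >= pivot, prob / lsum, 0.0)
```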
+ + Given the pivot, the probability vector is renormalized as follows: + - if prob >= pivot, renorm_prob = prob / lsum + - otherwise, renorm_prob = 0 + + Parameters + ---------- + prob: + The probability vector + + final_pivot: + The final pivot to cut-off top-p percentile + + final_lsum: + The sum of elements that are larger or equal to the pivot + + renorm_prob: + The renormalized probability vector + """ + TX = 1024 + CTA_COUNT = 512 + + def _var(dtype="int32"): + return T.alloc_buffer((1,), dtype, scope="local") + + # fmt: off + @T.prim_func(private=True) + def _func( + var_prob: T.handle, + var_final_pivot: T.handle, + var_final_lsum: T.handle, + var_renorm_prob: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + B = T.int32() + N = T.int32() + prob = T.match_buffer(var_prob, (B, N,), "float32") + final_pivot = T.match_buffer(var_final_pivot, (B,), "float32") + final_lsum = T.match_buffer(var_final_lsum, (B,), "float32") + renorm_prob = T.match_buffer(var_renorm_prob, (B, N,), "float32") + + with T.block("kernel"): + pivot = _var("float32") + lsum = _var("float32") + BX = T.meta_var(T.ceildiv(CTA_COUNT, B)) + + for _by in T.thread_binding(0, B, thread="blockIdx.y"): + for _bx in T.thread_binding(0, BX, thread="blockIdx.x"): + for _tx in T.thread_binding(0, TX, thread="threadIdx.x"): + with T.block("CTA"): + by, bx, tx = T.axis.remap("SSS", [_by, _bx, _tx]) + + pivot[0] = final_pivot[by] + lsum[0] = final_lsum[by] + + for i in T.serial(T.ceildiv(N, BX * TX)): + idx = T.meta_var(i * BX * TX + bx * TX + tx) + if idx < N: + renorm_prob[by, idx] = T.if_then_else(prob[by, idx] >= pivot[0], prob[by, idx] / lsum[0], 0.0) + # fmt: on + + return _func diff --git a/python/mlc_llm/protocol/protocol_utils.py b/python/mlc_llm/protocol/protocol_utils.py index f4273d0302..3005909bbd 100644 --- a/python/mlc_llm/protocol/protocol_utils.py +++ b/python/mlc_llm/protocol/protocol_utils.py @@ -23,13 +23,14 @@ def get_unsupported_fields(request: RequestProtocol) -> List[str]: def get_generation_config( request: RequestProtocol, + model_config: Dict[str, Any], extra_stop_token_ids: Optional[List[int]] = None, extra_stop_str: Optional[List[str]] = None, ) -> GenerationConfig: """Create the generation config in MLC LLM out from the input request protocol.""" kwargs: Dict[str, Any] if isinstance(request, (OpenAICompletionRequest, OpenAIChatCompletionRequest)): - kwargs = openai_api_get_generation_config(request) + kwargs = openai_api_get_generation_config(request, model_config) else: raise RuntimeError("Cannot reach here") diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index 8e06de7b54..59358c1646 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -2,11 +2,10 @@ # Load MLC LLM library by importing base from .. 
import base -from .async_engine import AsyncThreadedEngine -from .config import EngineMode, GenerationConfig, KVCacheConfig +from .config import EngineConfig, GenerationConfig, SpeculativeMode from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData -from .engine import Engine +from .engine import AsyncMLCEngine, MLCEngine from .grammar import BNFGrammar, GrammarStateMatcher -from .json_schema_converter import json_schema_to_ebnf +from .radix_tree import PagedRadixTree from .request import Request from .server import PopenServer diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index e539ec7e56..6b808ac37b 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -1,9 +1,14 @@ """Configuration dataclasses used in MLC LLM serving""" +import enum import json from dataclasses import asdict, dataclass, field from typing import Dict, List, Literal, Optional +import tvm + +from . import _ffi_api + @dataclass class ResponseFormat: @@ -123,66 +128,101 @@ def from_json(json_str: str) -> "GenerationConfig": return GenerationConfig(**json.loads(json_str)) -@dataclass -class KVCacheConfig: - """The KV cache initialization configuration. +class KVStateKind(enum.IntEnum): # pylint: disable=too-few-public-methods + """Possible kinds of KV state.""" + + ATTENTION = 0 + RNNSTATE = 1 + + +class SpeculativeMode(enum.IntEnum): + """The speculative mode.""" + + # Disable speculative decoding. + DISABLE = 0 + # The normal speculative decoding (small draft) mode. + SMALL_DRAFT = 1 + # The eagle-style speculative decoding. + EAGLE = 2 + + +@tvm._ffi.register_object("mlc.serve.EngineConfig") # pylint: disable=protected-access +class EngineConfig(tvm.runtime.Object): + """The class of MLCEngine execution configuration. Parameters ---------- - page_size : int - The number of consecutive tokens handled in each page in paged KV cache. + model : str + The path to the model directory. - max_num_sequence : int - The maximum number of sequences that are allowed to processed by the KV - cache at any time. + model_lib_path : str + The path to the model library. - max_total_sequence_length : Optional[int] - The maximum total number of tokens whose KV data are allowed to exist - in the KV cache at any time. - Set it to None to enable automatic computation of the max total - sequence length. + additional_models : List[str] + The path to the additional models' directories. - prefill_chunk_size : Optional[int] - The maximum total sequence length in a prefill. - If not specified, it will be automatically inferred from model config. - """ + additional_model_lib_paths : List[str] + The path to the additional models' libraries. - page_size: int = 16 - max_num_sequence: int = 32 - max_total_sequence_length: Optional[int] = None - prefill_chunk_size: Optional[int] = None + kv_cache_page_size : int + The number of consecutive tokens handled in each page in paged KV cache. - def asjson(self) -> str: - """Return the config in string of JSON format.""" - return json.dumps(asdict(self)) + max_num_sequence : int + The maximum number of sequences that are allowed to be + processed by the KV cache at any time. - @staticmethod - def from_json(json_str: str) -> "KVCacheConfig": - """Construct a config from JSON string.""" - return KVCacheConfig(**json.loads(json_str)) + max_total_sequence_length : int + The maximum length allowed for a single sequence in the engine. 
+ max_single_sequence_length : int + The maximum total number of tokens whose KV data are allowed + to exist in the KV cache at any time. -@dataclass -class EngineMode: - """The Engine execution mode. + prefill_chunk_size : int + The maximum total sequence length in a prefill. - Parameters - ---------- - enable_speculative : bool - Whether the speculative decoding mode is enabled, default False. + max_history_size: int + The maximum history size for RNN state to rool back. - spec_draft_length : int - The number of tokens to generate in speculative proposal (draft), default 4. - """ + kv_state_kind: KVStateKind + The kind of cache. - enable_speculative: bool = False - spec_draft_length: int = 4 + speculative_mode : SpeculativeMode + The speculative mode. - def asjson(self) -> str: - """Return the config in string of JSON format.""" - return json.dumps(asdict(self)) + spec_draft_length : int + The number of tokens to generate in speculative proposal (draft). + """ - @staticmethod - def from_json(json_str: str) -> "EngineMode": - """Construct a config from JSON string.""" - return EngineMode(**json.loads(json_str)) + def __init__( # pylint: disable=too-many-arguments + self, + model: str, + model_lib_path: str, + additional_models: List[str], + additional_model_lib_paths: List[str], + kv_cache_page_size: int, + max_num_sequence: int, + max_total_sequence_length: int, + max_single_sequence_length: int, + prefill_chunk_size: int, + max_history_size: int, + kv_state_kind: KVStateKind, + speculative_mode: SpeculativeMode, + spec_draft_length: int, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.EngineConfig, # type: ignore # pylint: disable=no-member + model, + model_lib_path, + additional_models, + additional_model_lib_paths, + kv_cache_page_size, + max_num_sequence, + max_total_sequence_length, + max_single_sequence_length, + prefill_chunk_size, + max_history_size, + kv_state_kind, + speculative_mode, + spec_draft_length, + ) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 607f970a1e..413c856db1 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -1,306 +1,838 @@ """The MLC LLM Serving Engine.""" -import json -import os -import subprocess +# pylint: disable=too-many-lines + +import asyncio +import queue import sys -from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import weakref +from typing import ( + Any, + AsyncGenerator, + Dict, + Iterator, + List, + Literal, + Optional, + Union, + overload, +) -import tvm from tvm.runtime import Device -from mlc_llm.protocol.conversation_protocol import Conversation -from mlc_llm.serve import data +from mlc_llm.protocol import openai_api_protocol +from mlc_llm.serve import data, engine_utils +from mlc_llm.serve.config import GenerationConfig, SpeculativeMode +from mlc_llm.serve.request import Request +from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging -from mlc_llm.support.auto_device import detect_device -from mlc_llm.support.style import green -from ..chat_module import _get_chat_config, _get_lib_module_path, _get_model_path -from ..streamer import TextStreamer -from ..tokenizer import Tokenizer -from . import data -from .config import EngineMode, GenerationConfig, KVCacheConfig -from .event_trace_recorder import EventTraceRecorder -from .request import Request +from . 
import engine_base logging.enable_logging() logger = logging.getLogger(__name__) -@dataclass -class ModelInfo: - """The model info dataclass. +class Chat: # pylint: disable=too-few-public-methods + """The proxy class to direct to chat completions.""" - Parameters - ---------- - model : str - The identifier of the input model. - It may be a compiled model's id (e.g., "Llama-2-7b-chat-hf-q4f16_1"), - or a full path to a model directory - (e.g., "dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1") - - device : str - The device where to run the model. - It can be "auto", "device_name" (e.g., "cuda") or - "device_name:device_id" (e.g., "cuda:1"). - - model_lib_path : str - The path to the compiled library of the model. - E.g., "dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - """ + def __init__(self, engine: weakref.ReferenceType) -> None: + assert isinstance(engine(), (AsyncMLCEngine, MLCEngine)) + self.completions = ( + AsyncChatCompletion(engine) # type: ignore + if isinstance(engine(), AsyncMLCEngine) + else ChatCompletion(engine) # type: ignore + ) - model: str - model_lib_path: str - device: Device = "auto" # type: ignore - - def __post_init__(self): - if isinstance(self.device, str): - self.device = detect_device(self.device) - assert isinstance(self.device, Device) - - -def _create_tvm_module( - creator: str, ffi_funcs: Sequence[str], creator_args: Optional[List[Any]] = None -) -> Dict[str, Callable]: - """Internal method to create a module.""" - if creator_args is None: - creator_args = [] - module = tvm.get_global_func(creator, allow_missing=False)(*creator_args) - return {key: module[key] for key in ffi_funcs} - - -def _process_model_args( - models: List[ModelInfo], -) -> Tuple[List[Any], List[str], str, int, int, Optional[str]]: - """Process the input ModelInfo to get the engine initialization arguments.""" - max_single_sequence_length = int(1e9) - prefill_chunk_size = int(1e9) - tokenizer_path: Optional[str] = None - conv_template_name: Optional[str] = None - config_file_paths: List[str] = [] - - def _convert_model_info(model: ModelInfo) -> List[Any]: - nonlocal max_single_sequence_length, prefill_chunk_size, tokenizer_path, conv_template_name - - device = model.device - model_path, config_file_path = _get_model_path(model.model) - config_file_paths.append(config_file_path) - chat_config = _get_chat_config(config_file_path, user_chat_config=None) - if chat_config.context_window_size and chat_config.context_window_size != -1: - max_single_sequence_length = min( - max_single_sequence_length, - chat_config.context_window_size, - ) - if chat_config.prefill_chunk_size: - prefill_chunk_size = min(prefill_chunk_size, chat_config.prefill_chunk_size) - if tokenizer_path is None: - tokenizer_path = model_path - if conv_template_name is None: - assert isinstance(chat_config.conv_template, Conversation) - conv_template_name = chat_config.conv_template.name - # Try look up model library, and do JIT compile if model library not found. 
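The `Chat` proxy above holds the engine through a `weakref.ReferenceType` and dereferences it with `engine()` on each call, presumably so the proxy does not keep the engine alive through a reference cycle. A minimal standalone sketch of that pattern follows; the class names are illustrative, not from the patch.

```
import weakref

class _Proxy:  # stands in for Chat / ChatCompletion
    def __init__(self, owner_ref: weakref.ReferenceType) -> None:
        self._owner_ref = owner_ref

    def call(self) -> str:
        owner = self._owner_ref()  # dereference on every use
        assert owner is not None, "owner was garbage collected"
        return f"dispatching to {type(owner).__name__}"

class _Owner:  # stands in for MLCEngine / AsyncMLCEngine
    def __init__(self) -> None:
        self.chat = _Proxy(weakref.ref(self))  # proxy holds only a weak reference back

owner = _Owner()
print(owner.chat.call())  # -> "dispatching to _Owner"
```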
- try: - model_lib_path = _get_lib_module_path( - model=model.model, - model_path=model_path, - chat_config=chat_config, - model_lib_path=model.model_lib_path, - device_name=device.MASK2STR[device.device_type], - config_file_path=config_file_path, - ) - except FileNotFoundError: - from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel - - model_lib_path = str( - jit.jit( - model_path=Path(model_path), - chat_config=asdict(chat_config), - device=device, - ) - ) - return [model_lib_path, model_path, device.device_type, device.device_id] - - model_args: List[Any] = sum( - (_convert_model_info(model) for model in models), - start=[], - ) - - assert prefill_chunk_size != int(1e9) - return ( - model_args, - config_file_paths, - tokenizer_path, - max_single_sequence_length, - prefill_chunk_size, - conv_template_name, - ) - - -def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals - models: List[ModelInfo], config_file_paths: List[str], max_num_sequence: int -) -> int: - """Estimate the max total sequence length (capacity) of the KV cache.""" - assert len(models) != 0 - - kv_bytes_per_token = 0 - kv_aux_workspace_bytes = 0 - model_workspace_bytes = 0 - logit_processor_workspace_bytes = 0 - params_bytes = 0 - temp_func_bytes = 0 - - for model, config_file_path in zip(models, config_file_paths): - # Read metadata for the parameter size and the temporary memory size. - cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-memory-usage-in-json", - "--mlc-chat-config", - config_file_path, - ] - usage_str = subprocess.check_output(cmd, universal_newlines=True) - usage_json = json.loads(usage_str) - params_bytes += usage_json["params_bytes"] - temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) - - cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-kv-cache-metadata-in-json", - ] - kv_cache_metadata_str = subprocess.check_output(cmd, universal_newlines=True) - kv_cache_metadata = json.loads(kv_cache_metadata_str) - - # Read model config and compute the kv size per token. 
- with open(config_file_path, mode="rt", encoding="utf-8") as file: - json_object = json.load(file) - model_config = json_object["model_config"] - vocab_size = model_config["vocab_size"] - prefill_chunk_size = model_config["prefill_chunk_size"] - num_layers = kv_cache_metadata["num_hidden_layers"] - head_dim = kv_cache_metadata["head_dim"] - num_qo_heads = kv_cache_metadata["num_attention_heads"] - num_kv_heads = kv_cache_metadata["num_key_value_heads"] - hidden_size = head_dim * num_qo_heads - kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25 - kv_aux_workspace_bytes += ( - (max_num_sequence + 1) * 88 - + prefill_chunk_size * (num_qo_heads + 1) * 8 - + prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 - + 48 * 1024 * 1024 + +class AsyncChatCompletion: # pylint: disable=too-few-public-methods + """The proxy class to direct to async chat completions.""" + + if sys.version_info >= (3, 9): + engine: weakref.ReferenceType["AsyncMLCEngine"] + else: + engine: weakref.ReferenceType + + def __init__(self, engine: weakref.ReferenceType) -> None: + self.engine = engine + + @overload + async def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + stream: Literal[True], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any]: + """Asynchronous streaming chat completion interface with OpenAI API compatibility. + The method is a coroutine that streams ChatCompletionStreamResponse + that conforms to OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Yields + ------ + stream_response : ChatCompletionStreamResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/streaming for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
+ """ + + @overload + async def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Literal[False] = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> openai_api_protocol.ChatCompletionResponse: + """Asynchronous non-streaming chat completion interface with OpenAI API compatibility. + The method is a coroutine that streams ChatCompletionStreamResponse + that conforms to OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Returns + ------ + response : ChatCompletionResponse + The chat completion response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + async def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any], + openai_api_protocol.ChatCompletionResponse, + ]: + """Asynchronous chat completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
+ """ + return await self.engine()._chat_completion( # pylint: disable=protected-access + messages=messages, + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=tools, + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=response_format, + request_id=request_id, ) - model_workspace_bytes += ( - prefill_chunk_size * 4 - + max_num_sequence * 4 - + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 + + +class ChatCompletion: # pylint: disable=too-few-public-methods + """The proxy class to direct to chat completions.""" + + if sys.version_info >= (3, 9): + engine: weakref.ReferenceType["MLCEngine"] + else: + engine: weakref.ReferenceType + + def __init__(self, engine: weakref.ReferenceType) -> None: + self.engine = engine + + @overload + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + stream: Literal[True], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + """Synchronous streaming chat completion interface with OpenAI API compatibility. + The method streams back ChatCompletionStreamResponse that conforms to + OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Yields + ------ + stream_response : ChatCompletionStreamResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/streaming for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
+ """ + + @overload + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Literal[False] = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> openai_api_protocol.ChatCompletionResponse: + """Synchronous non-streaming chat completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Returns + ------ + response : ChatCompletionResponse + The chat completion response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + Iterator[openai_api_protocol.ChatCompletionStreamResponse], + openai_api_protocol.ChatCompletionResponse, + ]: + """Synchronous chat completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
+ """ + return self.engine()._chat_completion( # pylint: disable=protected-access + messages=messages, + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=tools, + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=response_format, + request_id=request_id, ) - logit_processor_workspace_bytes += ( - max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 + + +class AsyncCompletion: # pylint: disable=too-few-public-methods + """The proxy class to direct to async completions.""" + + if sys.version_info >= (3, 9): + engine: weakref.ReferenceType["AsyncMLCEngine"] + else: + engine: weakref.ReferenceType + + def __init__(self, engine: weakref.ReferenceType) -> None: + self.engine = engine + + @overload + async def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + stream: Literal[True], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: + """Asynchronous streaming completion interface with OpenAI API compatibility. + The method is a coroutine that streams CompletionResponse + that conforms to OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + @overload + async def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Literal[False] = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> openai_api_protocol.CompletionResponse: + """Asynchronous non-streaming completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. 
+ + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Returns + ------ + response : CompletionResponse + The completion response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + async def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + AsyncGenerator[openai_api_protocol.CompletionResponse, Any], + openai_api_protocol.CompletionResponse, + ]: + """Asynchronous completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + return await self.engine()._completion( # pylint: disable=protected-access + model=model, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + suffix=suffix, + temperature=temperature, + top_p=top_p, + user=user, + ignore_eos=ignore_eos, + response_format=response_format, + request_id=request_id, ) - # Get single-card GPU size. - gpu_size_bytes = os.environ.get("MLC_GPU_SIZE_BYTES", default=None) - if gpu_size_bytes is None: - gpu_size_bytes = models[0].device.total_global_memory - if gpu_size_bytes is None: - raise ValueError( - "Cannot read total GPU global memory from device. " - 'Please the GPU memory size in bytes through "MLC_GPU_SIZE_BYTES" env variable.' 
- )
- max_total_sequence_length = int(
- (
- int(gpu_size_bytes) * 0.90
- - params_bytes
- - temp_func_bytes
- - kv_aux_workspace_bytes
- - model_workspace_bytes
- - logit_processor_workspace_bytes
+class Completion: # pylint: disable=too-few-public-methods
+ """The proxy class to direct to completions."""
+
+ if sys.version_info >= (3, 9):
+ engine: weakref.ReferenceType["MLCEngine"]
+ else:
+ engine: weakref.ReferenceType
+
+ def __init__(self, engine: weakref.ReferenceType) -> None:
+ self.engine = engine
+
+ @overload
+ def create( # pylint: disable=too-many-arguments,too-many-locals
+ self,
+ *,
+ prompt: Union[str, List[int]],
+ stream: Literal[True],
+ model: Optional[str] = None,
+ best_of: int = 1,
+ echo: bool = False,
+ frequency_penalty: float = 0.0,
+ presence_penalty: float = 0.0,
+ logprobs: bool = False,
+ top_logprobs: int = 0,
+ logit_bias: Optional[Dict[int, float]] = None,
+ max_tokens: int = 16,
+ n: int = 1,
+ seed: Optional[int] = None,
+ stop: Optional[Union[str, List[str]]] = None,
+ suffix: Optional[str] = None,
+ temperature: float = 1.0,
+ top_p: float = 1.0,
+ user: Optional[str] = None,
+ ignore_eos: bool = False,
+ response_format: Optional[Dict[str, Any]] = None,
+ request_id: Optional[str] = None,
+ ) -> Iterator[openai_api_protocol.CompletionResponse]:
+ """Synchronous streaming completion interface with OpenAI API compatibility.
+ The method streams back CompletionResponse that conforms to
+ OpenAI API one at a time via yield.
+
+ See https://platform.openai.com/docs/api-reference/completions/create for specification.
+
+ Parameters
+ ----------
+ request_id : Optional[str]
+ The optional request id.
+ A random one will be generated if it is not given.
+
+ Yields
+ ------
+ stream_response : CompletionResponse
+ The stream response conforming to OpenAI API.
+ See mlc_llm/protocol/openai_api_protocol.py or
+ https://platform.openai.com/docs/api-reference/completions/object for specification.
+
+ Raises
+ ------
+ e : BadRequestError
+ BadRequestError is raised when the request is invalid.
+ """
+
+ @overload
+ def create( # pylint: disable=too-many-arguments,too-many-locals
+ self,
+ *,
+ prompt: Union[str, List[int]],
+ model: Optional[str] = None,
+ best_of: int = 1,
+ echo: bool = False,
+ frequency_penalty: float = 0.0,
+ presence_penalty: float = 0.0,
+ logprobs: bool = False,
+ top_logprobs: int = 0,
+ logit_bias: Optional[Dict[int, float]] = None,
+ max_tokens: int = 16,
+ n: int = 1,
+ seed: Optional[int] = None,
+ stop: Optional[Union[str, List[str]]] = None,
+ stream: Literal[False] = False,
+ suffix: Optional[str] = None,
+ temperature: float = 1.0,
+ top_p: float = 1.0,
+ user: Optional[str] = None,
+ ignore_eos: bool = False,
+ response_format: Optional[Dict[str, Any]] = None,
+ request_id: Optional[str] = None,
+ ) -> openai_api_protocol.CompletionResponse:
+ """Synchronous non-streaming completion interface with OpenAI API compatibility.
+
+ See https://platform.openai.com/docs/api-reference/completions/create for specification.
+
+ Parameters
+ ----------
+ request_id : Optional[str]
+ The optional request id.
+ A random one will be generated if it is not given.
+
+ Returns
+ -------
+ response : CompletionResponse
+ The completion response conforming to OpenAI API.
+ See mlc_llm/protocol/openai_api_protocol.py or
+ https://platform.openai.com/docs/api-reference/completions/object for specification.
+
+ Raises
+ ------
+ e : BadRequestError
+ BadRequestError is raised when the request is invalid.
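A standalone sketch of the typing pattern used above: `@overload` combined with `Literal[True]` / `Literal[False]` lets a type checker infer from the `stream` flag whether the caller receives a single response object or an iterator of chunks. The `Response` class here is a stand-in for illustration, not part of mlc_llm.

```python
from typing import Iterator, Literal, Union, overload


class Response:
    # Stand-in response type with a single field, for illustration only.
    text: str = "hello"


@overload
def create(*, stream: Literal[True]) -> Iterator[Response]: ...
@overload
def create(*, stream: Literal[False] = False) -> Response: ...


def create(*, stream: bool = False) -> Union[Iterator[Response], Response]:
    # The runtime implementation branches on `stream`; the overloads above
    # only guide static type checkers.
    if stream:
        return iter([Response(), Response()])  # streamed chunks
    return Response()  # single aggregated response


single = create()              # type checkers see `Response`
chunks = create(stream=True)   # type checkers see `Iterator[Response]`
print(single.text, sum(1 for _ in chunks))
```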
+ """ + + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.CompletionResponse]: + """Synchronous completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + return self.engine()._completion( # pylint: disable=protected-access + model=model, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + suffix=suffix, + temperature=temperature, + top_p=top_p, + user=user, + ignore_eos=ignore_eos, + response_format=response_format, + request_id=request_id, ) - / kv_bytes_per_token - ) - assert max_total_sequence_length > 0, ( - "Cannot estimate KV cache capacity. " - f"The model weight size {params_bytes} may be larger than GPU memory size {gpu_size_bytes}" - ) - - if models[0].device.device_type == Device.kDLMetal: - # NOTE: Metal runtime has severe performance issues with large buffers. - # To work around the issue, we limit the KV cache capacity to 32768. - max_total_sequence_length = min(max_total_sequence_length, 32768) - - total_size = ( - params_bytes - + temp_func_bytes - + kv_aux_workspace_bytes - + model_workspace_bytes - + logit_processor_workspace_bytes - + kv_bytes_per_token * max_total_sequence_length - ) - logger.info( - "%s: %d.", - green('Estimated KVCacheConfig "max_total_sequence_length"'), - max_total_sequence_length, - ) - logger.info( - "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. Temporary buffer: %.2f MB)", - green("Estimated total single GPU memory usage"), - total_size / 1024 / 1024, - params_bytes / 1024 / 1024, - (kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes) / 1024 / 1024, - (model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes) / 1024 / 1024, - ) - return int(max_total_sequence_length) - - -class Engine: - """The Python interface of request serving engine for MLC LLM. - - The engine can run one or multiple LLM models internally for - text generation. Usually, when there are multiple models, - speculative inference will be activated, where the first model - (index 0) is the main "large model" that has better generation - quality, and all other models are "small" models that used for - speculation. - - The engine receives requests from the "add_request" method. For - an given request, the engine will keep generating new tokens for - the request until finish (under certain criterion). 
After finish,
- the engine will return the generation result through the callback
- function provided by the request.
+
+
+class AsyncMLCEngine(engine_base.MLCEngineBase):
+ """The AsyncMLCEngine in MLC LLM that provides the asynchronous
+ interfaces with regard to OpenAI API.

 Parameters
 ----------
- models : Union[ModelInfo, List[ModelInfo]]
- One or a list of model info (specifying which models to load and
- which device to load to) to launch the engine.
-
- kv_cache_config : KVCacheConfig
- The configuration of the paged KV cache.
-
- request_stream_callback : Optional[Callable[[str, data.TokenData, Optional[str]], None]]
- The provided callback function to handle the generation
- output. It has the signature of `(str, data.TokenData, bool) -> None`,
- where
- - the first string is the request id,
- - the TokenData contains the generated **delta** token ids since
- the last invocation of the callback on the specific request,
- - the optional string value denotes the finish reason if the
- generation of the request is finished, or None if it has not finished.
-
- The callback function is optional at construction, but it needs to
- be set before the engine executing requests. This can be done via
- the `set_request_stream_callback` method. Otherwise, the engine will raise
- exception.
-
- engine_mode : Optional[EngineMode]
- The Engine execution mode.
+ model : str
+ A path to ``mlc-chat-config.json``, or an MLC model directory that contains
+ `mlc-chat-config.json`.
+ It can also be a link to an HF repository pointing to an MLC compiled model.
+
+ device: Union[str, Device]
+ The device used to deploy the model such as "cuda" or "cuda:0".
+ Will default to "auto" and detect from locally available GPUs if not specified.
+
+ model_lib_path : Optional[str]
+ The full path to the model library file to use (e.g. a ``.so`` file).
+ If unspecified, we will use the provided ``model`` to search over possible paths.
+ If the model lib path is not found, it will be compiled in a JIT manner.
+
+ mode : Literal["local", "interactive", "server"]
+ The engine mode in MLC LLM.
+ We provide three preset modes: "local", "interactive" and "server".
+ The default mode is "local".
+ The choice of mode decides the values of "max_batch_size", "max_total_sequence_length"
+ and "prefill_chunk_size" when they are not explicitly specified.
+ 1. Mode "local" refers to the local server deployment which has low
+ request concurrency. So the max batch size will be set to 4, and max
+ total sequence length and prefill chunk size are set to the context
+ window size (or sliding window size) of the model.
+ 2. Mode "interactive" refers to the interactive use of the server, which
+ has at most 1 concurrent request. So the max batch size will be set to 1,
+ and max total sequence length and prefill chunk size are set to the context
+ window size (or sliding window size) of the model.
+ 3. Mode "server" refers to the large server use case which may handle
+ many concurrent requests and wants to use as much GPU memory as possible.
+ In this mode, we will automatically infer the largest possible max batch
+ size and max total sequence length.
+
+ You can manually specify arguments "max_batch_size", "max_total_sequence_length" and
+ "prefill_chunk_size" to override the automatically inferred values.
+
+ additional_models : Optional[List[str]]
+ The model paths and (optional) model library paths of additional models
+ (other than the main model).
+ When the engine is enabled with speculative decoding, additional models are needed.
+ Each string in the list is either in form "model_path" or "model_path:model_lib_path".
+ When the model lib path of a model is not given, JIT model compilation will
+ be activated to compile the model automatically.
+
+ max_batch_size : Optional[int]
+ The maximum allowed batch size set for the KV cache to concurrently support.
+
+ max_total_sequence_length : Optional[int]
+ The KV cache total token capacity, i.e., the maximum total number of tokens that
+ the KV cache supports. This decides the GPU memory size that the KV cache consumes.
+ If not specified, the system will automatically estimate the maximum capacity based
+ on the vRAM size on GPU.
+
+ prefill_chunk_size : Optional[int]
+ The maximum number of tokens the model passes for prefill each time.
+ It should not exceed the prefill chunk size in model config.
+ If not specified, this defaults to the prefill chunk size in model config.
+
+ max_history_size : Optional[int]
+ The maximum history size for RNN state.
+
+ gpu_memory_utilization : Optional[float]
+ A number in (0, 1) denoting the fraction of GPU memory used by the server in total.
+ It is used to infer the maximum possible KV cache capacity.
+ When it is unspecified, it defaults to 0.85.
+ Under mode "local" or "interactive", the actual memory usage may be
+ significantly smaller than this number. Under mode "server", the actual
+ memory usage may be slightly larger than this number.
+
+ engine_config : Optional[EngineConfig]
+ The MLCEngine execution configuration.
+ Currently speculative decoding mode is specified via engine config.
+ For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'"
+ to specify the eagle-style speculative decoding.
+ Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification.

 enable_tracing : bool
 A boolean indicating if to enable event logging for requests.
@@ -308,245 +840,1021 @@ class Engine:

 def __init__( # pylint: disable=too-many-arguments
 self,
- models: Union[ModelInfo, List[ModelInfo]],
- kv_cache_config: KVCacheConfig,
- engine_mode: Optional[EngineMode] = None,
- request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None,
+ model: str,
+ device: Union[str, Device] = "auto",
+ *,
+ model_lib_path: Optional[str] = None,
+ mode: Literal["local", "interactive", "server"] = "local",
+ additional_models: Optional[List[str]] = None,
+ max_batch_size: Optional[int] = None,
+ max_total_sequence_length: Optional[int] = None,
+ prefill_chunk_size: Optional[int] = None,
+ max_history_size: Optional[int] = None,
+ gpu_memory_utilization: Optional[float] = None,
+ speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE,
+ spec_draft_length: int = 4,
 enable_tracing: bool = False,
- ):
- if isinstance(models, ModelInfo):
- models = [models]
+ ) -> None:
+ super().__init__(
+ "async",
+ model=model,
+ device=device,
+ model_lib_path=model_lib_path,
+ mode=mode,
+ additional_models=additional_models,
+ max_batch_size=max_batch_size,
+ max_total_sequence_length=max_total_sequence_length,
+ prefill_chunk_size=prefill_chunk_size,
+ max_history_size=max_history_size,
+ gpu_memory_utilization=gpu_memory_utilization,
+ speculative_mode=speculative_mode,
+ spec_draft_length=spec_draft_length,
+ enable_tracing=enable_tracing,
+ )
+ self.chat = Chat(weakref.ref(self))
+ self.completions = AsyncCompletion(weakref.ref(self))
+
+ async def abort(self, request_id: str) -> None:
+ """Generation abortion interface.
+ + Parameter + --------- + request_id : str + The id of the request to abort. + """ + self._abort(request_id) + + async def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any], + openai_api_protocol.ChatCompletionResponse, + ]: + """Asynchronous chat completion internal interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ), + request_id=request_id, + ) + if stream: + # Stream response. + return chatcmpl_generator + # Normal response. 
+ num_prompt_tokens = 0 + num_completion_tokens = 0 + output_texts = ["" for _ in range(n)] + finish_reasons: List[Optional[str]] = [None for _ in range(n)] + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( + [[] for _ in range(n)] if logprobs else None + ) + async for response in chatcmpl_generator: + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + assert isinstance(choice.delta.content, str) + output_texts[choice.index] += choice.delta.content + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + + assert all(finish_reason is not None for finish_reason in finish_reasons) + use_function_calling, tool_calls_list = engine_base.process_function_call_output( + output_texts, finish_reasons + ) + return engine_base.wrap_chat_completion_response( + request_id=request_id, + model=model, + output_texts=output_texts, + finish_reasons=finish_reasons, + tool_calls_list=tool_calls_list, + logprob_results=logprob_results, + use_function_calling=use_function_calling, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) + + async def _completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + AsyncGenerator[openai_api_protocol.CompletionResponse, Any], + openai_api_protocol.CompletionResponse, + ]: + """Asynchronous completion internal interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + if request_id is None: + request_id = f"cmpl-{engine_utils.random_uuid()}" + cmpl_generator = self._handle_completion( + openai_api_protocol.CompletionRequest( + model=model, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + suffix=suffix, + temperature=temperature, + top_p=top_p, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ), + request_id, + ) + if stream: + # Stream response. + return cmpl_generator + # Normal response. 
+ num_prompt_tokens = 0
+ num_completion_tokens = 0
+ output_texts = ["" for _ in range(n)]
+ finish_reasons: List[Optional[str]] = [None for _ in range(n)]
+ logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = (
+ [[] for _ in range(n)] if logprobs else None
+ )
+
+ async for response in cmpl_generator:
+ num_prompt_tokens = response.usage.prompt_tokens
+ num_completion_tokens = response.usage.completion_tokens
+ for choice in response.choices:
+ output_texts[choice.index] += choice.text
+ if choice.finish_reason is not None and finish_reasons[choice.index] is None:
+ finish_reasons[choice.index] = choice.finish_reason
+ if choice.logprobs is not None:
+ assert logprob_results is not None
+ logprob_results[ # pylint: disable=unsupported-assignment-operation
+ choice.index
+ ] += choice.logprobs.content
+
+ assert all(finish_reason is not None for finish_reason in finish_reasons)
+ return engine_base.wrap_completion_response(
+ request_id=request_id,
+ model=model,
+ output_texts=output_texts,
+ finish_reasons=finish_reasons,
+ logprob_results=logprob_results,
+ num_prompt_tokens=num_prompt_tokens,
+ num_completion_tokens=num_completion_tokens,
+ )
+
+ async def _handle_chat_completion(
+ self, request: openai_api_protocol.ChatCompletionRequest, request_id: str
+ ) -> AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any]:
+ """The implementation of asynchronous ChatCompletionRequest handling.
+
+ Yields
+ ------
+ stream_response : ChatCompletionStreamResponse
+ The stream response conforming to OpenAI API.
+ See mlc_llm/protocol/openai_api_protocol.py or
+ https://platform.openai.com/docs/api-reference/chat/streaming for specification.
+
+ Raises
+ ------
+ e : BadRequestError
+ BadRequestError is raised when the request is invalid.
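For illustration, a minimal sketch of driving the asynchronous streaming chat API defined in this class. The model path is a placeholder; the import path is assumed to be the module edited here, and `terminate()` is assumed to be provided by the engine base class.

```python
# Minimal async sketch (not part of this patch).
import asyncio

from mlc_llm.serve.engine import AsyncMLCEngine


async def main() -> None:
    # Placeholder model path; any locally compiled MLC model would do.
    engine = AsyncMLCEngine(model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC")

    # `create(..., stream=True)` is a coroutine that resolves to an async
    # generator of ChatCompletionStreamResponse chunks.
    async for chunk in await engine.chat.completions.create(
        messages=[{"role": "user", "content": "Explain KV cache in one sentence."}],
        stream=True,
    ):
        for choice in chunk.choices:
            print(choice.delta.content, end="", flush=True)

    engine.terminate()  # assumed engine-base shutdown hook


asyncio.run(main())
```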
+ """ ( - model_args, - config_file_paths, - tokenizer_path, - max_single_sequence_length, - prefill_chunk_size, - self.conv_template_name, - ) = _process_model_args(models) - self._ffi = _create_tvm_module( - "mlc.serve.create_engine", - ffi_funcs=[ - "init", - "add_request", - "abort_request", - "step", - "stats", - "reset", - "get_request_stream_callback", - "set_request_stream_callback", - ], + prompts, + generation_cfg, + use_function_calling, + prompt_length, + ) = engine_base.process_chat_completion_request( + request, + request_id, + self.state, + self.model_config_dicts[0], + self.tokenizer.encode, + self.max_input_sequence_length, + self.conv_template.model_copy(deep=True), ) - self.trace_recorder = EventTraceRecorder() if enable_tracing else None - self.max_input_sequence_length = max_single_sequence_length - if kv_cache_config.max_total_sequence_length is None: - kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( - models, config_file_paths, kv_cache_config.max_num_sequence + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + num_completion_tokens = 0 + self.state.record_event(request_id, event="invoke generate") + async for delta_outputs in self._generate( + prompts, generation_cfg, request_id # type: ignore + ): + response, num_completion_tokens = engine_base.process_chat_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + use_function_calling, + prompt_length, + finish_reasons, + num_completion_tokens, ) - if kv_cache_config.prefill_chunk_size is None: - kv_cache_config.prefill_chunk_size = prefill_chunk_size - elif kv_cache_config.prefill_chunk_size > prefill_chunk_size: - raise ValueError( - f"The specified prefill chunk size {kv_cache_config.prefill_chunk_size} is " - f"larger than the maximum prefill chunk size {prefill_chunk_size} supported by " - "models. Please specify a smaller prefill chunk size." + if response is not None: + yield response + self.state.record_event(request_id, event="finish") + + async def _handle_completion( + self, request: openai_api_protocol.CompletionRequest, request_id: str + ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: + """The implementation fo asynchronous CompletionRequest handling. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
+ """ + ( + prompt, + generation_cfg, + prompt_length, + echo_response, + ) = engine_base.process_completion_request( + request, + request_id, + self.state, + self.model_config_dicts[0], + self.tokenizer, + self.max_input_sequence_length, + ) + if echo_response is not None: + yield echo_response + + num_completion_tokens = 0 + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + self.state.record_event(request_id, event="invoke generate") + async for delta_outputs in self._generate( + prompt, generation_cfg, request_id # type: ignore + ): + response, num_completion_tokens = engine_base.process_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + prompt_length, + finish_reasons, + num_completion_tokens, ) + if response is not None: + yield response - if engine_mode is None: - # The default engine mode: non-speculative - engine_mode = EngineMode() - - self._ffi["init"]( - max_single_sequence_length, - tokenizer_path, - kv_cache_config.asjson(), - engine_mode.asjson(), - request_stream_callback, - self.trace_recorder, - *model_args, + suffix_response = engine_base.create_completion_suffix_response( + request, request_id, prompt_length, finish_reasons, num_completion_tokens ) - self.tokenizer = Tokenizer(tokenizer_path) + if suffix_response is not None: + yield suffix_response + self.state.record_event(request_id, event="finish") - def generate( # pylint: disable=too-many-locals + async def _generate( self, - prompts: Union[str, List[str], List[int], List[List[int]], List[List[data.Data]]], - generation_config: Union[GenerationConfig, List[GenerationConfig]], - ) -> Tuple[List[List[str]], List[Optional[List[List[str]]]]]: - """Generate texts for a list of input prompts. - Each prompt can be a string or a list of token ids. - The generation for each prompt is independent. - Return the generation results, one for each prompt. + prompt: Union[str, List[int], List[Union[str, List[int], data.Data]]], + generation_config: GenerationConfig, + request_id: str, + ) -> AsyncGenerator[List[engine_base.CallbackStreamOutput], Any]: + """Internal asynchronous text generation interface of AsyncMLCEngine. + The method is a coroutine that streams a list of CallbackStreamOutput + at a time via yield. The returned list length is the number of + parallel generations specified by `generation_config.n`. Parameters ---------- - prompts : Union[str, List[str], List[int], List[List[int]]] - One or a list of input prompts for text generation. - Each prompt can be a string or a list of token ids. + prompt : Union[str, List[int], List[Union[str, List[int], data.Data]]] + The input prompt in forms of text strings, lists of token ids or data. - generation_config : Union[GenerationConfig, List[GenerationConfig]] - The generation config for each requests. - If the it is a single GenerationConfig instance, - this config will be shared by all the prompts. - Otherwise, one generation config is required for every - prompt. + generation_config : GenerationConfig + The generation config of the request. - Returns - ------- - output_text : List[List[str]] - The text generation results, one list of strings for each input prompt. - The length of each list is the parallel generation `n` in - generation config. - - output_logprobs_str : List[Optional[List[List[str]]]] - The logprob strings of each token for each input prompt, or None - if an input prompt does not require logprobs. 
+ request_id : str + The unique identifier (in string) or this generation request. + + Yields + ------ + request_output : List[engine_base.CallbackStreamOutput] + The delta generated outputs in a list. + The number of list elements equals to `generation_config.n`, + and each element corresponds to the delta output of a parallel + generation. """ - if isinstance(prompts, str): - # `prompts` is a single string. - prompts = [prompts] - else: - assert isinstance(prompts, list), ( - "Input `prompts` is expected to be a string, a list of " - "str, a list of token ids or multiple lists of token ids. " - ) - if len(prompts) == 0: - return [], [] - if isinstance(prompts[0], int): - # `prompts` is a list of token ids - prompts = [prompts] # type: ignore - - num_requests = len(prompts) - if not isinstance(generation_config, list): - generation_config = [generation_config] * num_requests - - assert ( - len(generation_config) == num_requests - ), "Number of generation config and number of prompts mismatch" - - num_finished_generations = 0 - output_texts: List[List[str]] = [] - output_logprobs_str: List[Optional[List[List[str]]]] = [] - text_streamers: List[List[TextStreamer]] = [] - for i in range(num_requests): - output_texts.append([]) - output_logprobs_str.append([] if generation_config[i].logprobs else None) - text_streamers.append([]) - for _ in range(generation_config[i].n): - output_texts[i].append("") - text_streamers[i].append(TextStreamer(self.tokenizer)) - if output_logprobs_str[i] is not None: - output_logprobs_str[i].append([]) - - num_total_generations = sum(cfg.n for cfg in generation_config) - - # Save a copy of the original function callback since `generate` - # overrides the callback function. - # The original callback will be set back later on. - original_callback = self._ffi["get_request_stream_callback"]() - - # Define the callback function for request generation results - def request_stream_callback(delta_outputs: List[data.RequestStreamOutput]): - nonlocal num_finished_generations - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - rid = int(request_id) - - assert len(stream_outputs) == generation_config[rid].n - for i, (stream_output, text_streamer) in enumerate( - zip(stream_outputs, text_streamers[rid]) - ): - if output_logprobs_str[rid] is not None: - assert stream_output.delta_logprob_json_strs is not None - output_logprobs_str[rid][i] += stream_output.delta_logprob_json_strs - - delta_text = ( - text_streamer.put(stream_output.delta_token_ids) - if len(stream_output.delta_token_ids) > 0 - else "" - ) - if stream_output.finish_reason is not None: - delta_text += text_streamer.finish() - - output_texts[rid][i] += delta_text - if stream_output.finish_reason is not None: - num_finished_generations += 1 - - # Override the callback function in engine. - self._ffi["set_request_stream_callback"](request_stream_callback) - - def convert_to_data(prompt: Union[str, List[int], List[data.Data]]) -> List[data.Data]: - if isinstance(prompt, str): - return [data.TextData(prompt)] - if isinstance(prompt[0], int): - return [data.TokenData(prompt)] # type: ignore - return prompt # type: ignore - - # Add requests to engine. 
- for req_id, (prompt, generation_cfg) in enumerate(zip(prompts, generation_config)): - input_data = convert_to_data(prompt) # type: ignore - self.add_request( - Request( - request_id=str(req_id), - inputs=input_data, - generation_config=generation_cfg, + if self._terminated: + raise ValueError("The AsyncThreadedEngine has terminated.") + self.state.async_lazy_init_event_loop() + + # Create the request with the given id, input data, generation + # config and the created callback. + input_data = engine_utils.convert_prompts_to_data(prompt) + request = Request(request_id, input_data, generation_config) + + # Create the unique async request stream of the request. + stream = engine_base.AsyncRequestStream() + if request_id in self.state.async_streamers: + # Report error in the stream if the request id already exists. + stream.push( + RuntimeError( + f'The request id "{request_id} already exists. ' + 'Please make sure the request id is unique."' ) ) + else: + # Record the stream in the tracker + self.state.async_streamers[request_id] = ( + stream, + [TextStreamer(self.tokenizer) for _ in range(generation_config.n)], + ) + self.state.async_num_unfinished_generations[request_id] = generation_config.n + self._ffi["add_request"](request) + + # Iterate the stream asynchronously and yield the output. + try: + async for request_output in stream: + yield request_output + except ( + Exception, + asyncio.CancelledError, + ) as exception: # pylint: disable=broad-exception-caught + await self.abort(request_id) + raise exception + + def _abort(self, request_id: str): + """Internal implementation of request abortion.""" + self.state.async_streamers.pop(request_id, None) + self.state.async_num_unfinished_generations.pop(request_id, None) + self._ffi["abort_request"](request_id) - while num_finished_generations != num_total_generations: - self.step() - # Restore the callback function in engine. - self._ffi["set_request_stream_callback"](original_callback) - return output_texts, output_logprobs_str +class MLCEngine(engine_base.MLCEngineBase): + """The MLCEngine in MLC LLM that provides the synchronous + interfaces with regard to OpenAI API. - def add_request(self, request: Request) -> None: - """Add a new request to the engine. + Parameters + ---------- + models : str + A path to ``mlc-chat-config.json``, or an MLC model directory that contains + `mlc-chat-config.json`. + It can also be a link to a HF repository pointing to an MLC compiled model. + + device: Union[str, Device] + The device used to deploy the model such as "cuda" or "cuda:0". + Will default to "auto" and detect from local available GPUs if not specified. + + model_lib_path : Optional[str] + The full path to the model library file to use (e.g. a ``.so`` file). + If unspecified, we will use the provided ``model`` to search over possible paths. + It the model lib path is not found, it will be compiled in a JIT manner. + + mode : Literal["local", "interactive", "server"] + The engine mode in MLC LLM. + We provide three preset modes: "local", "interactive" and "server". + The default mode is "local". + The choice of mode decides the values of "max_batch_size", "max_total_sequence_length" + and "prefill_chunk_size" when they are not explicitly specified. + 1. Mode "local" refers to the local server deployment which has low + request concurrency. So the max batch size will be set to 4, and max + total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 2. 
Mode "interactive" refers to the interactive use of the server, which
+ has at most 1 concurrent request. So the max batch size will be set to 1,
+ and max total sequence length and prefill chunk size are set to the context
+ window size (or sliding window size) of the model.
+ 3. Mode "server" refers to the large server use case which may handle
+ many concurrent requests and wants to use as much GPU memory as possible.
+ In this mode, we will automatically infer the largest possible max batch
+ size and max total sequence length.
+
+ You can manually specify arguments "max_batch_size", "max_total_sequence_length" and
+ "prefill_chunk_size" to override the automatically inferred values.
+
+ additional_models : Optional[List[str]]
+ The model paths and (optional) model library paths of additional models
+ (other than the main model).
+ When the engine is enabled with speculative decoding, additional models are needed.
+ Each string in the list is either in form "model_path" or "model_path:model_lib_path".
+ When the model lib path of a model is not given, JIT model compilation will
+ be activated to compile the model automatically.
+
+ max_batch_size : Optional[int]
+ The maximum allowed batch size set for the KV cache to concurrently support.
+
+ max_total_sequence_length : Optional[int]
+ The KV cache total token capacity, i.e., the maximum total number of tokens that
+ the KV cache supports. This decides the GPU memory size that the KV cache consumes.
+ If not specified, the system will automatically estimate the maximum capacity based
+ on the vRAM size on GPU.
+
+ prefill_chunk_size : Optional[int]
+ The maximum number of tokens the model passes for prefill each time.
+ It should not exceed the prefill chunk size in model config.
+ If not specified, this defaults to the prefill chunk size in model config.
+
+ gpu_memory_utilization : Optional[float]
+ A number in (0, 1) denoting the fraction of GPU memory used by the server in total.
+ It is used to infer the maximum possible KV cache capacity.
+ When it is unspecified, it defaults to 0.85.
+ Under mode "local" or "interactive", the actual memory usage may be
+ significantly smaller than this number. Under mode "server", the actual
+ memory usage may be slightly larger than this number.
+
+ engine_config : Optional[EngineConfig]
+ The MLCEngine execution configuration.
+ Currently speculative decoding mode is specified via engine config.
+ For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'"
+ to specify the eagle-style speculative decoding.
+ Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification.
+
+ enable_tracing : bool
+ A boolean indicating if to enable event logging for requests.
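For illustration, a minimal sketch of constructing the synchronous engine in "interactive" mode and using the prompt-style completions API documented above. The model path is a placeholder and the import location is an assumption; the parameter names (`mode`, `prompt`, `max_tokens`) come from this patch.

```python
# Minimal sync sketch (not part of this patch).
from mlc_llm.serve.engine import MLCEngine

engine = MLCEngine(
    model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC",  # placeholder path
    mode="interactive",  # max batch size 1; KV capacity = context window
)

# Non-streaming text completion: returns a single CompletionResponse.
response = engine.completions.create(
    prompt="The capital of France is",
    max_tokens=8,
    temperature=0.0,
)
print(response.choices[0].text)
```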
+ """ + + def __init__( # pylint: disable=too-many-arguments + self, + model: str, + device: Union[str, Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, + gpu_memory_utilization: Optional[float] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, + enable_tracing: bool = False, + ) -> None: + super().__init__( + "sync", + model=model, + device=device, + model_lib_path=model_lib_path, + mode=mode, + additional_models=additional_models, + max_batch_size=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + gpu_memory_utilization=gpu_memory_utilization, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + enable_tracing=enable_tracing, + ) + self.chat = Chat(weakref.ref(self)) + self.completions = Completion(weakref.ref(self)) + + def abort(self, request_id: str) -> None: + """Generation abortion interface. + + Parameter + --------- + request_id : str + The id of the request to abort. + """ + self._ffi["abort_request"](request_id) + + def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + Iterator[openai_api_protocol.ChatCompletionStreamResponse], + openai_api_protocol.ChatCompletionResponse, + ]: + """Synchronous chat completion internal interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. Parameters ---------- - request : Request - The request to add. + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
""" - self._ffi["add_request"](request) + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ), + request_id=request_id, + ) + if stream: + # Stream response. + return chatcmpl_generator + # Normal response. + num_prompt_tokens = 0 + num_completion_tokens = 0 + output_texts = ["" for _ in range(n)] + finish_reasons: List[Optional[str]] = [None for _ in range(n)] + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( + [[] for _ in range(n)] if logprobs else None + ) + for response in chatcmpl_generator: + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + assert isinstance(choice.delta.content, str) + output_texts[choice.index] += choice.delta.content + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + + assert all(finish_reason is not None for finish_reason in finish_reasons) + use_function_calling, tool_calls_list = engine_base.process_function_call_output( + output_texts, finish_reasons + ) + return engine_base.wrap_chat_completion_response( + request_id=request_id, + model=model, + output_texts=output_texts, + finish_reasons=finish_reasons, + tool_calls_list=tool_calls_list, + logprob_results=logprob_results, + use_function_calling=use_function_calling, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) - def abort_request(self, request_id: str) -> None: - """Abort the generation of the request corresponding to the input request id. + def _completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.CompletionResponse]: + """Synchronous completion internal interface with OpenAI API compatibility. 
+ + See https://platform.openai.com/docs/api-reference/completions/create for specification. Parameters ---------- - request_id : str - The unique id of the request to abort. + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. """ - self._ffi["abort_request"](request_id) + if request_id is None: + request_id = f"cmpl-{engine_utils.random_uuid()}" + cmpl_generator = self._handle_completion( + openai_api_protocol.CompletionRequest( + model=model, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + suffix=suffix, + temperature=temperature, + top_p=top_p, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ), + request_id, + ) + if stream: + # Stream response. + return cmpl_generator + # Normal response. + num_prompt_tokens = 0 + num_completion_tokens = 0 + output_texts = ["" for _ in range(n)] + finish_reasons: List[Optional[str]] = [None for _ in range(n)] + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( + [[] for _ in range(n)] if logprobs else None + ) - def step(self) -> None: - """The main function that the engine takes a step of action. + for response in cmpl_generator: + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + output_texts[choice.index] += choice.text + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + + assert all(finish_reason is not None for finish_reason in finish_reasons) + return engine_base.wrap_completion_response( + request_id=request_id, + model=model, + output_texts=output_texts, + finish_reasons=finish_reasons, + logprob_results=logprob_results, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) - At each step, the engine may decide to - - run prefill for one (or more) requests, - - run one-step decode for the all existing requests - ... + def _handle_chat_completion( + self, request: openai_api_protocol.ChatCompletionRequest, request_id: str + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + """The implementation fo synchronous ChatCompletionRequest handling. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/streaming for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
+ """ + ( + prompts, + generation_cfg, + use_function_calling, + prompt_length, + ) = engine_base.process_chat_completion_request( + request, + request_id, + self.state, + self.model_config_dicts[0], + self.tokenizer.encode, + self.max_input_sequence_length, + self.conv_template.model_copy(deep=True), + ) - In the end of certain actions (e.g., decode), the engine will - check if any request has finished, and will return the - generation results for those finished requests. + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + num_completion_tokens = 0 + self.state.record_event(request_id, event="invoke generate") + for delta_outputs in self._generate(prompts, generation_cfg, request_id): # type: ignore + response, num_completion_tokens = engine_base.process_chat_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + use_function_calling, + prompt_length, + finish_reasons, + num_completion_tokens, + ) + if response is not None: + yield response + self.state.record_event(request_id, event="finish") + + def _handle_completion( + self, request: openai_api_protocol.CompletionRequest, request_id: str + ) -> Iterator[openai_api_protocol.CompletionResponse]: + """The implementation fo synchronous CompletionRequest handling. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. """ - self._ffi["step"]() - - def reset(self) -> None: - """Reset the engine, clean up all running data and statistics.""" - self._ffi["reset"]() - - def stats(self) -> Dict[str, float]: - """The engine runtime statistics. - We collect the following entries: - - single token prefill latency (s/tok): avg latency of processing one token in prefill - - single token decode latency (s/tok): avg latency of processing one token in decode - - engine time for prefill (sec) - - engine time for decode (sec) - - total number of processed tokens in prefill. - - total number of processed tokens in decode. 
+ ( + prompt, + generation_cfg, + prompt_length, + echo_response, + ) = engine_base.process_completion_request( + request, + request_id, + self.state, + self.model_config_dicts[0], + self.tokenizer, + self.max_input_sequence_length, + ) + if echo_response is not None: + yield echo_response + + num_completion_tokens = 0 + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + self.state.record_event(request_id, event="invoke generate") + for delta_outputs in self._generate(prompt, generation_cfg, request_id): # type: ignore + response, num_completion_tokens = engine_base.process_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + prompt_length, + finish_reasons, + num_completion_tokens, + ) + if response is not None: + yield response + + suffix_response = engine_base.create_completion_suffix_response( + request, request_id, prompt_length, finish_reasons, num_completion_tokens + ) + if suffix_response is not None: + yield suffix_response + self.state.record_event(request_id, event="finish") + + def _generate( # pylint: disable=too-many-locals + self, + prompt: Union[str, List[int], List[Union[str, List[int], data.Data]]], + generation_config: GenerationConfig, + request_id: str, + ) -> Iterator[List[engine_base.CallbackStreamOutput]]: + """Internal synchronous text generation interface of AsyncMLCEngine. + The method is a coroutine that streams a list of CallbackStreamOutput + at a time via yield. The returned list length is the number of + parallel generations specified by `generation_config.n`. + + Parameters + ---------- + prompt : Union[str, List[int], List[Union[str, List[int], data.Data]]] + The input prompt in forms of text strings, lists of token ids or data. + + generation_config : GenerationConfig + The generation config of the request. + + request_id : str + The unique identifier (in string) or this generation request. + + Yields + ------ + request_output : List[engine_base.CallbackStreamOutput] + The delta generated outputs in a list. + The number of list elements equals to `generation_config.n`, + and each element corresponds to the delta output of a parallel + generation. """ - stats_json_str = self._ffi["stats"]() - return json.loads(stats_json_str) + if self._terminated: + raise ValueError("The engine has terminated.") + + # Create the request with the given id, input data, generation + # config and the created callback. + input_data = engine_utils.convert_prompts_to_data(prompt) + request = Request(request_id, input_data, generation_config) + + # Record the stream in the tracker + self.state.sync_output_queue = queue.Queue() + self.state.sync_text_streamers = [ + TextStreamer(self.tokenizer) for _ in range(generation_config.n) + ] + self.state.sync_num_unfinished_generations = generation_config.n + self._ffi["add_request"](request) + + # Iterate the stream asynchronously and yield the token. 
+ try: + while self.state.sync_num_unfinished_generations > 0: + delta_outputs = self.state.sync_output_queue.get() + request_outputs = self._request_stream_callback_impl(delta_outputs) + for request_output in request_outputs: + yield request_output + except Exception as exception: # pylint: disable=broad-exception-caught + self.abort(request_id) + raise exception + + def _request_stream_callback_impl( + self, delta_outputs: List[data.RequestStreamOutput] + ) -> List[List[engine_base.CallbackStreamOutput]]: + """The underlying implementation of request stream callback of MLCEngine.""" + batch_outputs: List[List[engine_base.CallbackStreamOutput]] = [] + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + self.state.record_event(request_id, event="start callback") + outputs: List[engine_base.CallbackStreamOutput] = [] + for stream_output, text_streamer in zip(stream_outputs, self.state.sync_text_streamers): + self.state.record_event(request_id, event="start detokenization") + delta_text = ( + text_streamer.put(stream_output.delta_token_ids) + if len(stream_output.delta_token_ids) > 0 + else "" + ) + if stream_output.finish_reason is not None: + delta_text += text_streamer.finish() + self.state.record_event(request_id, event="finish detokenization") + + outputs.append( + engine_base.CallbackStreamOutput( + delta_text=delta_text, + num_delta_tokens=len(stream_output.delta_token_ids), + delta_logprob_json_strs=stream_output.delta_logprob_json_strs, + finish_reason=stream_output.finish_reason, + ) + ) + if stream_output.finish_reason is not None: + self.state.sync_num_unfinished_generations -= 1 + batch_outputs.append(outputs) + self.state.record_event(request_id, event="finish callback") + return batch_outputs diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 4c95f6e612..65b41a66ac 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -20,7 +20,12 @@ from mlc_llm.protocol import openai_api_protocol, protocol_utils from mlc_llm.protocol.conversation_protocol import Conversation from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import EngineConfig, GenerationConfig, SpeculativeMode +from mlc_llm.serve.config import ( + EngineConfig, + GenerationConfig, + KVStateKind, + SpeculativeMode, +) from mlc_llm.serve.event_trace_recorder import EventTraceRecorder from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging @@ -89,8 +94,10 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: if conversation is None: assert isinstance(chat_config.conv_template, Conversation) conversation = chat_config.conv_template - # Try look up model library, and do JIT compile if model library not found. 
- try: + + if model.model_lib_path is not None: + # do model lib search if the model lib path is provided + # error out if file not found model_lib_path = _get_lib_module_path( model=model.model, model_path=model_path, @@ -99,7 +106,9 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: device_name=device.MASK2STR[device.device_type], config_file_path=config_file_path, ) - except FileNotFoundError: + else: + # TODO(mlc-team) add logging information + # Run jit if model_lib_path is not provided from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel model_lib_path = str( @@ -117,7 +126,7 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: return model_args, config_file_paths, conversation -def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-many-locals,too-many-arguments +def _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( # pylint: disable=too-many-locals,too-many-arguments models: List[ModelInfo], device: tvm.runtime.Device, model_config_paths: List[str], @@ -195,7 +204,7 @@ def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-ma if gpu_size_bytes is None: raise ValueError("Cannot read total GPU global memory from device.") if gpu_memory_utilization is None: - gpu_memory_utilization = 0.90 + gpu_memory_utilization = 0.85 model_max_total_sequence_length = int( ( @@ -236,6 +245,90 @@ def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-ma ) +def _estimate_mem_usage_and_max_history_size_for_rnn_state( # pylint: disable=too-many-arguments, too-many-locals, unused-argument + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_paths: List[str], + model_config_dicts: List[Dict[str, Any]], + max_num_sequence: int, + gpu_memory_utilization: Optional[float], +) -> Tuple[float, float, float, int]: + # Get single-card GPU size. + gpu_size_bytes = device.total_global_memory + if gpu_size_bytes is None: + raise ValueError("Cannot read total GPU global memory from device.") + if gpu_memory_utilization is None: + gpu_memory_utilization = 0.90 + + rnn_state_base_bytes = 0.0 # the memory usage for rnn state when history = 1 + param_bytes = 0.0 + temp_func_bytes = 0.0 + model_workspace_bytes = 0.0 + logit_processor_workspace_bytes = 0.0 + for model, model_config_path, model_config_dict in zip( + models, model_config_paths, model_config_dicts + ): + # Read metadata for the parameter size and the temporary memory size. 
+ cmd = [ + sys.executable, + "-m", + "mlc_llm.cli.model_metadata", + model.model_lib_path, + "--print-memory-usage-in-json", + "--mlc-chat-config", + model_config_path, + ] + usage_str = subprocess.check_output(cmd, universal_newlines=True) + usage_json = json.loads(usage_str) + param_bytes += usage_json["params_bytes"] + temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) + + model_config = model_config_dict["model_config"] + vocab_size = model_config_dict["vocab_size"] + head_size = model_config["head_size"] + num_heads = model_config["num_heads"] + num_layers = model_config["num_hidden_layers"] + hidden_size = model_config["hidden_size"] + prefill_chunk_size = model_config["prefill_chunk_size"] + logit_processor_workspace_bytes += ( + max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 + ) + + model_workspace_bytes += ( + prefill_chunk_size * 4 + + max_num_sequence * 4 + + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 + ) + + rnn_state_base_bytes += ( + max_num_sequence * hidden_size * num_layers * 2 * 2 + + max_num_sequence * num_heads * head_size * head_size * num_layers * 2 + ) + + max_history_size = int( + ( + gpu_size_bytes * gpu_memory_utilization + - logit_processor_workspace_bytes + - model_workspace_bytes + - param_bytes + - temp_func_bytes + ) + / rnn_state_base_bytes + ) + if max_history_size < 1: + raise ValueError( + f"Memory required by models may be larger than available GPU memory " + f"size {gpu_size_bytes * gpu_memory_utilization} bytes." + ) + + return ( + param_bytes, + model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes, + rnn_state_base_bytes, + max_history_size, + ) + + def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[int, int, int]: """Read the model config dictionaries, and return the maximum single sequence length the models can support, the maximum prefill chunk @@ -290,7 +383,7 @@ def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[i return model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size -def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements +def _infer_kv_cache_config_for_kv_cache( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements mode: Literal["local", "interactive", "server"], max_batch_size: Optional[int], max_total_sequence_length: Optional[int], @@ -300,12 +393,13 @@ def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-local device: tvm.runtime.Device, model_config_dicts: List[Dict[str, Any]], model_config_paths: List[str], -) -> Tuple[int, int, int, int]: +) -> Tuple[int, int, int, KVStateKind, int]: """Initialize the KV cache config with user input and GPU memory usage estimation. 
     The returned five values are:
     - max_batch_size
     - max_total_sequence_length
     - prefill_chunk_size
+    - kv_state_kind
     - model_max_single_sequence_length
     """
     (
@@ -319,7 +413,7 @@ def infer_args_under_mode(
         max_batch_size: Optional[int],
         max_total_sequence_length: Optional[int],
         prefill_chunk_size: Optional[int],
-    ) -> Tuple[Tuple[int, int, int], List[float]]:
+    ) -> Tuple[Tuple[int, int, int, KVStateKind], List[float]]:
         logging_msg = ""
         # - max_batch_size
         if max_batch_size is None:
@@ -339,7 +433,7 @@ def infer_args_under_mode(
             kv_aux_workspace_bytes,
             temp_workspace_bytes,
             model_max_total_sequence_length,
-        ) = _estimate_mem_usage_and_max_total_sequence_length(
+        ) = _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache(
             models,
             device,
             model_config_paths,
@@ -396,7 +490,12 @@ def infer_args_under_mode(
 
         # - Construct the KV cache config
         # - Estimate total GPU memory usage on single GPU.
-        return (max_batch_size, max_total_sequence_length, prefill_chunk_size), [
+        return (
+            max_batch_size,
+            max_total_sequence_length,
+            prefill_chunk_size,
+            KVStateKind.ATTENTION,
+        ), [
             total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token,
             model_params_bytes,
             kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes,
@@ -458,9 +557,192 @@ def infer_args_under_mode(
     return *kv_cache_config, model_max_single_sequence_length
 
 
+def _infer_kv_cache_config_for_rnn_state(  # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
+    mode: Literal["local", "interactive", "server"],
+    max_batch_size: Optional[int],
+    max_total_sequence_length: Optional[int],
+    prefill_chunk_size: Optional[int],
+    max_history_size: Optional[int],
+    gpu_memory_utilization: Optional[float],
+    models: List[ModelInfo],
+    device: tvm.runtime.Device,
+    model_config_dicts: List[Dict[str, Any]],
+    model_config_paths: List[str],
+) -> Tuple[int, int, int, KVStateKind, int]:
+    """Initialize the RNN state config with user input and GPU memory usage estimation.
+    The returned five values are:
+    - max_batch_size
+    - max_total_sequence_length
+    - prefill_chunk_size
+    - kv_state_kind
+    - max_history_size
+    """
+    logging_msg = ""
+    prefill_chunk_size = 0
+
+    if prefill_chunk_size is None:
+        prefill_chunk_size = min(
+            config["prefill_chunk_size"] if "prefill_chunk_size" in config else 4096
+            for config in model_config_dicts
+        )
+        logging_msg += f"prefill chunk size is set to {prefill_chunk_size}. "
+    else:
+        logging_msg += f"prefill chunk size {prefill_chunk_size} is specified by user. "
+    if max_batch_size is None:
+        max_batch_size = 1 if mode == "interactive" else 4
+        logging_msg += f"max batch size is set to {max_batch_size}, "
+    else:
+        logging_msg += f"max batch size {max_batch_size} is specified by user, "
+
+    if mode == "local":
+        logging_msg += (
+            "We choose small max batch size and RNN state capacity to use less GPU memory."
+        )
+    elif mode == "interactive":
+        logging_msg += "We fix max batch size to 1 for interactive single sequence use."
+    else:
+        logging_msg += (
+            "We use as much GPU memory as possible (within the" " limit of gpu_memory_utilization)."
+        )
+    logger.info('Under mode "%s", %s', mode, logging_msg)
+
+    (
+        model_param_bytes,
+        model_temp_bytes,
+        model_rnn_state_base_bytes,
+        model_max_history_size,
+    ) = _estimate_mem_usage_and_max_history_size_for_rnn_state(
+        models,
+        device,
+        model_config_paths,
+        model_config_dicts,
+        max_batch_size,
+        gpu_memory_utilization,
+    )
+    if max_history_size is None:
+        max_history_size = model_max_history_size
+    else:
+        max_history_size = min(max_history_size, model_max_history_size)
+    max_total_sequence_length = 32768
+    prefill_chunk_size = 0
+    kind = KVStateKind.RNNSTATE
+
+    logger.info(
+        "%s: %.2f MB (Parameters: %.2f MB. RNNState: %.2f MB. Temporary buffer: %.2f MB). "
+        "The actual usage might be slightly larger than the estimated number.",
+        green("Estimated total single GPU memory usage"),
+        (model_param_bytes + model_temp_bytes + model_rnn_state_base_bytes) / 1024 / 1024,
+        model_param_bytes / 1024 / 1024,
+        max_history_size * model_rnn_state_base_bytes / 1024 / 1024,
+        model_temp_bytes / 1024 / 1024,
+    )
+
+    return (
+        max_batch_size,
+        max_total_sequence_length,
+        prefill_chunk_size,
+        kind,
+        max_history_size,
+    )
+
+
+def _infer_kv_cache_config(  # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
+    mode: Literal["local", "interactive", "server"],
+    max_batch_size: Optional[int],
+    max_total_sequence_length: Optional[int],
+    prefill_chunk_size: Optional[int],
+    max_history_size: Optional[int],
+    gpu_memory_utilization: Optional[float],
+    models: List[ModelInfo],
+    device: tvm.runtime.Device,
+    model_config_dicts: List[Dict[str, Any]],
+    model_config_paths: List[str],
+) -> Tuple[int, int, int, int, int, KVStateKind]:
+    """Initialize the cache config with user input and GPU memory usage estimation.
+    The returned six values are:
+    - max_batch_size
+    - max_total_sequence_length
+    - prefill_chunk_size
+    - max_single_sequence_length
+    - max_history_size
+    - kv_state_kind
+    """
+    if all("rwkv" not in model.model for model in models):
+        (
+            max_batch_size,
+            max_total_sequence_length,
+            prefill_chunk_size,
+            kv_state_kind,
+            max_single_sequence_length,
+        ) = _infer_kv_cache_config_for_kv_cache(
+            mode,
+            max_batch_size,
+            max_total_sequence_length,
+            prefill_chunk_size,
+            gpu_memory_utilization,
+            models,
+            device,
+            model_config_dicts,
+            model_config_paths,
+        )
+        max_history_size = 0  # KV cache doesn't need this
+    elif all("rwkv" in model.model for model in models):
+        (
+            max_batch_size,
+            max_total_sequence_length,
+            prefill_chunk_size,
+            kv_state_kind,
+            max_history_size,
+        ) = _infer_kv_cache_config_for_rnn_state(
+            mode,
+            max_batch_size,
+            max_total_sequence_length,
+            prefill_chunk_size,
+            max_history_size,
+            gpu_memory_utilization,
+            models,
+            device,
+            model_config_dicts,
+            model_config_paths,
+        )
+        max_single_sequence_length = max_total_sequence_length  # RNN state doesn't need this
+    else:
+        raise ValueError("The models should be either all KV cache models or all RNN state models.")
+    return (
+        max_batch_size,
+        max_total_sequence_length,
+        prefill_chunk_size,
+        max_single_sequence_length,
+        max_history_size,
+        kv_state_kind,
+    )
+
+
+def _infer_generation_config(
+    model_config_dicts: List[Dict[str, Any]]
+) -> List[Tuple[float, float, float, float]]:
+    """Infer the generation config from the model config dictionaries.
+    The returned four floats are:
+    - temperature
+    - top_p
+    - frequency_penalty
+    - presence_penalty
+    """
+    generation_configs = []
+
+    for model_config in model_config_dicts:
+        temperature = model_config.get("temperature", 1.0)
+        top_p = model_config.get("top_p", 1.0)
+        frequency_penalty = model_config.get("frequency_penalty", 0.0)
+        presence_penalty = model_config.get("presence_penalty", 0.0)
+        generation_configs.append((temperature, top_p, frequency_penalty, presence_penalty))
+
+    return generation_configs
+
+
 @dataclass
 class CallbackStreamOutput:
-    """The output of LLMEngine._generate and AsyncLLMEngine._generate
+    """The output of MLCEngine._generate and AsyncMLCEngine._generate
 
     Attributes
     ----------
@@ -485,7 +767,7 @@ class CallbackStreamOutput:
 
 
 class AsyncRequestStream:
-    """The asynchronous stream for requests in AsyncLLMEngine.
+    """The asynchronous stream for requests in AsyncMLCEngine.
 
     Each request has its own unique stream.
     The stream exposes the method `push` for engine to push new generated
@@ -544,29 +826,29 @@ async def __anext__(self) -> List[CallbackStreamOutput]:
 
 class EngineState:
     """The engine states that the request stream callback function may use.
 
-    This class is used for both AsyncLLMEngine and LLMEngine.
-    AsyncLLMEngine uses the fields and methods starting with "async",
-    and LLMEngine uses the ones starting with "sync".
+    This class is used for both AsyncMLCEngine and MLCEngine.
+    AsyncMLCEngine uses the fields and methods starting with "async",
+    and MLCEngine uses the ones starting with "sync".
 
-    - For AsyncLLMEngine, the state contains an asynchronous event loop,
+    - For AsyncMLCEngine, the state contains an asynchronous event loop,
       the streamers and the number of unfinished generations for each request
       being processed.
-    - For LLMEngine, the state contains a callback output blocking queue,
+    - For MLCEngine, the state contains a callback output blocking queue,
      the text streamers and the number of unfinished requests.
 
     We use this state class to avoid the callback function from capturing
-    the AsyncLLMEngine.
+    the AsyncMLCEngine.
 
     The state also optionally maintains an event trace recorder, which can
     provide Chrome tracing when enabled.
     """
 
     trace_recorder = None
-    # States used for AsyncLLMEngine
+    # States used for AsyncMLCEngine
     async_event_loop: Optional[asyncio.AbstractEventLoop] = None
     async_streamers: Dict[str, Tuple[AsyncRequestStream, List[TextStreamer]]] = {}
     async_num_unfinished_generations: Dict[str, int] = {}
-    # States used for LLMEngine
+    # States used for MLCEngine
     sync_output_queue: queue.Queue = queue.Queue()
     sync_text_streamers: List[TextStreamer] = []
     sync_num_unfinished_generations: int = 0
@@ -577,7 +859,7 @@ def __init__(self, enable_tracing: bool) -> None:
             self.trace_recorder = EventTraceRecorder()
 
     def record_event(self, request_id: str, event: str) -> None:
-        """Record a event for the the input request in the trace
+        """Record an event for the input request in the trace
         recorder when the recorder exists.
 
         Parameters
@@ -628,7 +910,7 @@ def async_lazy_init_event_loop(self) -> None:
             self.async_event_loop = asyncio.get_event_loop()
 
     def _async_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None:
-        """The request stream callback function for AsyncLLMEngine to stream back
+        """The request stream callback function for AsyncMLCEngine to stream back
         the request generation results.
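The async-prefixed state above implies that deltas produced on the engine's background thread have to be handed over to an asyncio consumer. The snippet below is a minimal, self-contained sketch of one common way to do that with `loop.call_soon_threadsafe`; the names and the `None` end-of-stream convention are assumptions for illustration, not the engine's actual implementation.

```
import asyncio
import threading
from typing import Optional


async def main() -> None:
    loop = asyncio.get_running_loop()
    out_queue: "asyncio.Queue[Optional[str]]" = asyncio.Queue()

    def stream_callback(delta: Optional[str]) -> None:
        # Invoked from the worker thread: hand the delta to the event loop safely.
        loop.call_soon_threadsafe(out_queue.put_nowait, delta)

    def worker() -> None:
        for delta in ["Hello", ", ", "world", None]:  # None marks end of stream here
            stream_callback(delta)

    threading.Thread(target=worker, daemon=True).start()

    while True:
        delta = await out_queue.get()
        if delta is None:
            break
        print(delta, end="")
    print()


asyncio.run(main())
```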

         Note
@@ -648,7 +930,7 @@ def _async_request_stream_callback(self, delta_outputs: List[data.RequestStreamO
     def _async_request_stream_callback_impl(
         self, delta_outputs: List[data.RequestStreamOutput]
     ) -> None:
-        """The underlying implementation of request stream callback for AsyncLLMEngine."""
+        """The underlying implementation of request stream callback for AsyncMLCEngine."""
         for delta_output in delta_outputs:
             request_id, stream_outputs = delta_output.unpack()
             streamers = self.async_streamers.get(request_id, None)
@@ -689,28 +971,28 @@ def _async_request_stream_callback_impl(
         self.record_event(request_id, event="finish callback")
 
     def _sync_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None:
-        """The request stream callback function for LLMEngine to stream back
+        """The request stream callback function for MLCEngine to stream back
         the request generation results.
         """
         # Put the delta outputs to the queue in the unblocking way.
         self.sync_output_queue.put_nowait(delta_outputs)
 
 
-class LLMEngineBase:  # pylint: disable=too-many-instance-attributes,too-few-public-methods
+class MLCEngineBase:  # pylint: disable=too-many-instance-attributes,too-few-public-methods
     """The base engine class, which implements common functions that
-    are shared by LLMEngine and AsyncLLMEngine.
+    are shared by MLCEngine and AsyncMLCEngine.
 
     This class wraps a threaded engine that runs on a standalone thread
     inside and streams back the delta generated results via callback functions.
     The internal threaded engine keeps running a loop that drives the engine.
 
-    LLMEngine and AsyncLLMEngine inherits this LLMEngineBase class, and implements
+    MLCEngine and AsyncMLCEngine inherit this MLCEngineBase class, and implement
     their own methods to process the delta generated results received from callback
     functions and yield the processed delta results in the forms of standard API protocols.
 
-    Checkout subclasses AsyncLLMEngine/LLMEngine for the docstring of constructor parameters.
+    Check out subclasses AsyncMLCEngine/MLCEngine for the docstring of constructor parameters.
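As a rough analogue of the threaded design described in the base-class docstring above, the sketch below runs a toy background loop on a daemon thread and drains its output through a blocking `queue.Queue`, which is the shape of the synchronous streaming path. The sentinel value and function names are made up for illustration.

```
import queue
import threading


def toy_background_loop(out: "queue.Queue[str]") -> None:
    # Stand-in for the engine's background loop: push deltas, then a sentinel.
    for delta in ["fo", "o", "bar"]:
        out.put(delta)
    out.put("<eos>")


output_queue: "queue.Queue[str]" = queue.Queue()
threading.Thread(target=toy_background_loop, args=(output_queue,), daemon=True).start()

pieces = []
while True:
    delta = output_queue.get()  # blocks until the background thread pushes something
    if delta == "<eos>":
        break
    pieces.append(delta)
print("".join(pieces))  # foobar
```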
""" def __init__( # pylint: disable=too-many-arguments,too-many-locals @@ -724,6 +1006,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], + max_history_size: Optional[int], gpu_memory_utilization: Optional[float], speculative_mode: SpeculativeMode, spec_draft_length: int, @@ -753,11 +1036,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length, prefill_chunk_size, max_single_sequence_length, + max_history_size, + kv_state_kind, ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, prefill_chunk_size, + max_history_size, gpu_memory_utilization, models, device, @@ -776,32 +1062,37 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "abort_request", "run_background_loop", "run_background_stream_back_loop", + "reload", "init_background_engine", "exit_background_loop", "debug_call_func_on_all_worker", ] } self.tokenizer = Tokenizer(model_args[0][0]) + self._ffi["init_background_engine"]( + device, + self.state.get_request_stream_callback(kind), + self.state.trace_recorder, + ) + self._ffi["reload"]( + EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ) + ) def _background_loop(): - self._ffi["init_background_engine"]( - EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ), - self.state.get_request_stream_callback(kind), - self.state.trace_recorder, - ) self._ffi["run_background_loop"]() def _background_stream_back_loop(): @@ -919,6 +1210,7 @@ def process_chat_completion_request( # pylint: disable=too-many-arguments # Process generation config. Create request id. generation_cfg = protocol_utils.get_generation_config( request, + model_config, extra_stop_token_ids=conv_template.stop_token_ids, extra_stop_str=conv_template.stop_str, ) @@ -1039,10 +1331,11 @@ def process_chat_completion_stream_output( # pylint: disable=too-many-arguments return response, num_completion_tokens -def process_completion_request( +def process_completion_request( # pylint: disable=too-many-arguments request: openai_api_protocol.CompletionRequest, request_id: str, engine_state: EngineState, + model_config: Dict[str, Any], tokenizer: Tokenizer, max_input_sequence_length: int, ) -> Tuple[List[int], GenerationConfig, int, Optional[openai_api_protocol.CompletionResponse]]: @@ -1094,7 +1387,7 @@ def process_completion_request( assert isinstance(prompt, list) # Process generation config. Create request id. 
-    generation_cfg = protocol_utils.get_generation_config(request)
+    generation_cfg = protocol_utils.get_generation_config(request, model_config)
 
     # - Echo back the prompt.
     echo_response = None
diff --git a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py
index b95fd4faae..af1613c027 100644
--- a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py
+++ b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py
@@ -5,8 +5,8 @@
 
 import fastapi
 
-from ..server import ServerContext
-from . import entrypoint_utils
+from mlc_llm.protocol import error_protocol
+from mlc_llm.serve.server import ServerContext
 
 app = fastapi.APIRouter()
 
@@ -26,11 +26,11 @@ async def debug_dump_event_trace(request: fastapi.Request):
         # Parse the JSON string
         request_dict = json.loads(request_json_str)
     except json.JSONDecodeError:
-        return entrypoint_utils.create_error_response(
+        return error_protocol.create_error_response(
             HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}"
         )
     if "model" not in request_dict:
-        return entrypoint_utils.create_error_response(
+        return error_protocol.create_error_response(
             HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}"
         )
 
@@ -41,12 +41,41 @@ async def debug_dump_event_trace(request: fastapi.Request):
     async_engine = server_context.get_engine(model)
 
     if async_engine is None:
-        return entrypoint_utils.create_error_response(
+        return error_protocol.create_error_response(
             HTTPStatus.BAD_REQUEST, message=f'The requested model "{model}" is not served.'
         )
     if async_engine.state.trace_recorder is None:
-        return entrypoint_utils.create_error_response(
+        return error_protocol.create_error_response(
             HTTPStatus.BAD_REQUEST, message=f'The requested model "{model}" does not enable tracing'
         )
 
     return json.loads(async_engine.state.trace_recorder.dump_json())
+
+
+################ /debug/cuda_profiler_start/end ################
+
+
+@app.post("/debug/cuda_profiler_start")
+async def debug_cuda_profiler_start(_request: fastapi.Request):
+    """Start the CUDA profiler for the engine. Only for debugging purposes."""
+    server_context: ServerContext = ServerContext.current()
+    # Since the CUDA profiler is process-wide, calling the function for one model is sufficient.
+    for model in server_context.get_model_list():
+        async_engine = server_context.get_engine(model)
+        async_engine._debug_call_func_on_all_worker(  # pylint: disable=protected-access
+            "mlc.debug_cuda_profiler_start"
+        )
+        break
+
+
+@app.post("/debug/cuda_profiler_stop")
+async def debug_cuda_profiler_stop(_request: fastapi.Request):
+    """Stop the CUDA profiler for the engine. Only for debugging purposes."""
+    server_context: ServerContext = ServerContext.current()
+    # Since the CUDA profiler is process-wide, calling the function for one model is sufficient.
+    for model in server_context.get_model_list():
+        async_engine = server_context.get_engine(model)
+        async_engine._debug_call_func_on_all_worker(  # pylint: disable=protected-access
+            "mlc.debug_cuda_profiler_stop"
+        )
+        break
diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py
index ac8503d5df..23a279021f 100644
--- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py
+++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py
@@ -1,37 +1,21 @@
 """OpenAI API-compatible server entrypoints in MLC LLM"""
 
 # pylint: disable=too-many-locals,too-many-return-statements,too-many-statements
-import ast
-import json
 from http import HTTPStatus
-from typing import AsyncGenerator, Dict, List, Optional, Sequence, Union
+from typing import AsyncGenerator, List, Optional
 
 import fastapi
 
-from mlc_llm.serve import data
-
-from ...protocol import protocol_utils
-from ...protocol.conversation_protocol import Conversation
-from ...protocol.openai_api_protocol import (
-    ChatCompletionMessage,
+from mlc_llm.protocol import error_protocol
+from mlc_llm.protocol.openai_api_protocol import (
     ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatCompletionStreamResponse,
-    ChatCompletionStreamResponseChoice,
-    ChatFunctionCall,
-    ChatToolCall,
     CompletionRequest,
-    CompletionResponse,
-    CompletionResponseChoice,
     ListResponse,
-    LogProbs,
     LogProbsContent,
     ModelResponse,
-    UsageInfo,
 )
-from ..server import ServerContext
-from . import entrypoint_utils
+from mlc_llm.serve import engine_base, engine_utils
+from mlc_llm.serve.server import ServerContext
 
 app = fastapi.APIRouter()
 
@@ -59,130 +43,30 @@ async def request_completion(request: CompletionRequest, raw_request: fastapi.Re
     server_context: ServerContext = ServerContext.current()
     async_engine = server_context.get_engine(request.model)
     if async_engine is None:
-        return entrypoint_utils.create_error_response(
+        return error_protocol.create_error_response(
             HTTPStatus.BAD_REQUEST, message=f'The requested model "{request.model}" is not served.'
         )
-    request_id = f"cmpl-{entrypoint_utils.random_uuid()}"
-    async_engine.state.record_event(request_id, event="receive request")
-
-    # - Check if unsupported arguments are specified.
-    error = entrypoint_utils.check_unsupported_fields(request)
-    if error is not None:
-        return error
-
-    # - Process prompt and check validity.
-    async_engine.state.record_event(request_id, event="start tokenization")
-    prompts = entrypoint_utils.process_prompts(request.prompt, async_engine.tokenizer.encode)
-    async_engine.state.record_event(request_id, event="finish tokenization")
-    if isinstance(prompts, fastapi.responses.JSONResponse):
-        # Errored when processing the prompts
-        return prompts
-    if len(prompts) > 1:
-        return entrypoint_utils.create_error_response(
-            HTTPStatus.BAD_REQUEST,
-            message="Entrypoint /v1/completions only accept single prompt. "
-            f"However, {len(prompts)} prompts {prompts} are received.",
-        )
-    error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_input_sequence_length)
-    if error is not None:
-        return error
-    prompt = prompts[0]
-
-    # Process generation config. Create request id.
-    generation_cfg = protocol_utils.get_generation_config(request)
+    request_id = f"cmpl-{engine_utils.random_uuid()}"
 
     # Streaming response.
     if request.stream:
+        # We manually get the first response from the generator to
+        # capture potential exceptions in this scope, rather than in
+        # the StreamingResponse scope.
+ stream_generator = async_engine._handle_completion( # pylint: disable=protected-access + request, request_id + ) + first_response = await anext( # type: ignore # pylint: disable=undefined-variable + stream_generator + ) async def completion_stream_generator() -> AsyncGenerator[str, None]: - # - Echo back the prompt. - if request.echo: - text = async_engine.tokenizer.decode(prompt) - response = CompletionResponse( - id=request_id, - choices=[ - CompletionResponseChoice(index=i, text=text) - for i in range(generation_cfg.n) - ], - model=request.model, - usage=UsageInfo( - prompt_tokens=len(prompt), - completion_tokens=0, - ), - ) - yield f"data: {response.model_dump_json()}\n\n" - - # - Generate new tokens. - num_completion_tokens = 0 - finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - async_engine.state.record_event(request_id, event="invoke generate") - async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): - assert len(delta_outputs) == generation_cfg.n - choices = [] - for i, delta_output in enumerate(delta_outputs): - finish_reason_updated = False - if delta_output.finish_reason is not None and finish_reasons[i] is None: - finish_reasons[i] = delta_output.finish_reason - finish_reason_updated = True - num_completion_tokens += delta_output.num_delta_tokens - if not finish_reason_updated and delta_output.delta_text == "": - # Ignore empty delta text when finish reason is not updated. - continue - - choices.append( - CompletionResponseChoice( - index=i, - finish_reason=finish_reasons[i], - text=delta_output.delta_text, - logprobs=( - LogProbs( - content=[ - LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in delta_output.delta_logprob_json_strs - ] - ) - if delta_output.delta_logprob_json_strs is not None - else None - ), - ) - ) - - if len(choices) == 0: - # Skip yield when there is no delta output. - continue - response = CompletionResponse( - id=request_id, - choices=choices, - model=request.model, - usage=UsageInfo( - prompt_tokens=len(prompt), - completion_tokens=num_completion_tokens, - ), - ) - yield f"data: {response.model_dump_json()}\n\n" - async_engine.state.record_event(request_id, event="finish") - - # - Echo the suffix. - if request.suffix is not None: - assert all(finish_reason is not None for finish_reason in finish_reasons) - response = CompletionResponse( - id=request_id, - choices=[ - CompletionResponseChoice( - index=i, - finish_reason=finish_reason, - text=request.suffix, - ) - for i, finish_reason in enumerate(finish_reasons) - ], - model=request.model, - usage=UsageInfo( - prompt_tokens=len(prompt), - completion_tokens=num_completion_tokens, - ), - ) + if isinstance(first_response, StopAsyncIteration): + yield "data: [DONE]\n\n" + return + yield f"data: {first_response.model_dump_json()}\n\n" + async for response in stream_generator: yield f"data: {response.model_dump_json()}\n\n" - yield "data: [DONE]\n\n" return fastapi.responses.StreamingResponse( @@ -190,165 +74,51 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) # Normal response. 
- init_output_text = "" if not request.echo else async_engine.tokenizer.decode(prompt) - output_texts = [init_output_text for _ in range(generation_cfg.n)] + num_prompt_tokens = 0 num_completion_tokens = 0 - finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - logprob_json_strs_list: Optional[List[List[str]]] = ( - [[] for _ in range(generation_cfg.n)] if generation_cfg.logprobs else None + output_texts = ["" for _ in range(request.n)] + finish_reasons: List[Optional[str]] = [None for _ in range(request.n)] + logprob_results: Optional[List[List[LogProbsContent]]] = ( + [[] for _ in range(request.n)] if request.logprobs else None ) - async_engine.state.record_event(request_id, event="invoke generate") - async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): + + async for response in async_engine._handle_completion( # pylint: disable=protected-access + request, request_id + ): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified # when the request is disconnected. # Therefore, we check if it is disconnected each time, # and abort the request from engine if so. await async_engine.abort(request_id) - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message="The request has disconnected" ) + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + output_texts[choice.index] += choice.text + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[choice.index] += choice.logprobs.content - assert len(delta_outputs) == generation_cfg.n - for i, delta_output in enumerate(delta_outputs): - if delta_output.finish_reason is not None and finish_reasons[i] is None: - finish_reasons[i] = delta_output.finish_reason - output_texts[i] += delta_output.delta_text - num_completion_tokens += delta_output.num_delta_tokens - if logprob_json_strs_list is not None: - assert delta_output.delta_logprob_json_strs is not None - logprob_json_strs_list[i] += delta_output.delta_logprob_json_strs assert all(finish_reason is not None for finish_reason in finish_reasons) - suffix = request.suffix if request.suffix is not None else "" - async_engine.state.record_event(request_id, event="finish") - response = CompletionResponse( - id=request_id, - choices=[ - CompletionResponseChoice( - index=i, - finish_reason=finish_reason, - text=output_text + suffix, - logprobs=( - LogProbs( - content=[ - LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in logprob_json_strs_list[ # pylint: disable=unsubscriptable-object - i - ] - ] - ) - if logprob_json_strs_list is not None - else None - ), - ) - for i, (output_text, finish_reason) in enumerate(zip(output_texts, finish_reasons)) - ], + return engine_base.wrap_completion_response( + request_id=request_id, model=request.model, - usage=UsageInfo( - prompt_tokens=len(prompt), - completion_tokens=num_completion_tokens, - ), + output_texts=output_texts, + finish_reasons=finish_reasons, + logprob_results=logprob_results, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, ) - return response ################ v1/chat/completions ################ -def chat_completion_check_message_validity( - messages: 
List[ChatCompletionMessage], -) -> Optional[str]: - """Check if the given chat messages are valid. Return error message if invalid.""" - for i, message in enumerate(messages): - if message.role == "system" and i != 0: - return f"System prompt at position {i} in the message list is invalid." - if message.role == "tool": - return "Tool as the message author is not supported yet." - if message.tool_call_id is not None: - if message.role != "tool": - return "Non-tool message having `tool_call_id` is invalid." - if isinstance(message.content, list): - if message.role != "user": - return "Non-user message having a list of content is invalid." - if message.tool_calls is not None: - if message.role != "assistant": - return "Non-assistant message having `tool_calls` is invalid." - return "Assistant message having `tool_calls` is not supported yet." - return None - - -def check_function_call_usage( - request: ChatCompletionRequest, conv_template: Conversation -) -> Optional[str]: - """Check if function calling is used and update the conversation template. - Return error message if invalid request format for function calling. - """ - - # return if no tools are provided or tool_choice is set to none - if request.tools is None or ( - isinstance(request.tool_choice, str) and request.tool_choice == "none" - ): - conv_template.use_function_calling = False - return None - - # select the tool based on the tool_choice if specified - if isinstance(request.tool_choice, dict): - if request.tool_choice["type"] != "function": - return "Only 'function' tool choice is supported" - - if len(request.tool_choice["function"]) > 1: - return "Only one tool is supported when tool_choice is specified" - - for tool in request.tools: - if tool.function.name == request.tool_choice["function"]["name"]: - conv_template.use_function_calling = True - conv_template.function_string = tool.function.model_dump_json() - return None - - return ( - f"The tool_choice function {request.tool_choice['function']['name']}" - " is not found in the tools list" - ) - - if isinstance(request.tool_choice, str) and request.tool_choice != "auto": - return f"Invalid tool_choice value: {request.tool_choice}" - - function_list = [] - for tool in request.tools: - if tool.type != "function": - return "Only 'function' tool type is supported" - function_list.append(tool.function.model_dump()) - - conv_template.use_function_calling = True - conv_template.function_string = json.dumps(function_list) - return None - - -def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]: - """Convert a (possibly list) of function call string to a list of json objects. 
- Return None for invalid function call string.""" - - def parse_function_call(call_str: str): - node = ast.parse(call_str, mode="eval") - call_node = node.body - if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name): - name = call_node.func.id - arguments = {} - for keyword in call_node.keywords: - arguments[keyword.arg] = ast.literal_eval(keyword.value) - return {"name": name, "arguments": arguments} - return None - - if ( - stringified_calls[0] == "[" and stringified_calls[-1] == "]" - ): # hacky way to check if string list - calls = ast.literal_eval(stringified_calls) - else: - calls = [stringified_calls] - function_calls_json = [parse_function_call(call_str) for call_str in calls] - return function_calls_json - - @app.post("/v1/chat/completions") async def request_chat_completion( request: ChatCompletionRequest, raw_request: fastapi.Request @@ -360,132 +130,30 @@ async def request_chat_completion( server_context: ServerContext = ServerContext.current() async_engine = server_context.get_engine(request.model) if async_engine is None: - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{request.model}" is not served.' ) - request_id = f"chatcmpl-{entrypoint_utils.random_uuid()}" - async_engine.state.record_event(request_id, event="receive request") - - # - Check if the model supports chat conversation. - conv_template = server_context.get_conv_template(request.model) - if conv_template is None: - return entrypoint_utils.create_error_response( - HTTPStatus.BAD_REQUEST, - message=f'The requested model "{request.model}" does not support chat.', - ) - - # - Check if unsupported arguments are specified. - error = entrypoint_utils.check_unsupported_fields(request) - if error is not None: - return error - - # - Process messages and update the conversation template in three steps: - # i. Check the message validity. - # ii. Add the input messages to the conversation template. - # iii. Add the additional message for the assistant. - error_msg = chat_completion_check_message_validity(request.messages) - if error_msg is not None: - return entrypoint_utils.create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) - - # Check for function calling usage and update the conversation template - error_msg = check_function_call_usage(request, conv_template) - if error_msg is not None: - return entrypoint_utils.create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) - - for message in request.messages: - role = message.role - content = message.content - if role == "system": - assert isinstance(content, str) - conv_template.system_message = content if content is not None else "" - continue - - assert role != "tool", "Internal error: tool role." - conv_template.messages.append((role, content)) - conv_template.messages.append(("assistant", None)) - - # - Get the prompt from template, and encode to token ids. 
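For context on the function-calling helpers being removed from this file (their role moves to `engine_base`), the `ast`-based parsing technique used by `convert_function_str_to_json` above can be illustrated with a self-contained sketch; the example call string below is invented.

```
import ast


def parse_function_call(call_str: str):
    # Parse a single "name(arg=literal, ...)" style call string into a dict.
    node = ast.parse(call_str, mode="eval").body
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
        return {
            "name": node.func.id,
            "arguments": {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords},
        }
    return None


print(parse_function_call('get_weather(city="Paris", unit="celsius")'))
# {'name': 'get_weather', 'arguments': {'city': 'Paris', 'unit': 'celsius'}}
```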
-    # - Check prompt length
-    async_engine.state.record_event(request_id, event="start tokenization")
-
-    model_config = server_context.get_model_config(request.model)
-    prompts = entrypoint_utils.process_prompts(
-        conv_template.as_prompt(model_config),
-        async_engine.tokenizer.encode,
-    )
-
-    async_engine.state.record_event(request_id, event="finish tokenization")
-
-    if conv_template.system_prefix_token_ids is not None:
-        prompts[0] = conv_template.system_prefix_token_ids + prompts[0]
-    error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_input_sequence_length)
-    if error is not None:
-        return error
-
-    prompt: Sequence[Union[List[int], data.ImageData]] = prompts
-
-    # Process generation config. Create request id.
-    generation_cfg = protocol_utils.get_generation_config(
-        request,
-        extra_stop_token_ids=conv_template.stop_token_ids,
-        extra_stop_str=conv_template.stop_str,
-    )
+    request_id = f"chatcmpl-{engine_utils.random_uuid()}"
 
     # Streaming response.
     if request.stream:
+        # We manually get the first response from the generator to
+        # capture potential exceptions in this scope, rather than in
+        # the StreamingResponse scope.
+        stream_generator = async_engine._handle_chat_completion(  # pylint: disable=protected-access
+            request, request_id
+        )
+        first_response = await anext(  # type: ignore  # pylint: disable=undefined-variable
+            stream_generator
+        )
 
         async def completion_stream_generator() -> AsyncGenerator[str, None]:
-            async_engine.state.record_event(request_id, event="invoke generate")
-            finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)]
-            async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id):
-                assert len(delta_outputs) == generation_cfg.n
-                choices = []
-                for i, delta_output in enumerate(delta_outputs):
-                    finish_reason_updated = False
-                    if delta_output.finish_reason is not None and finish_reasons[i] is None:
-                        finish_reasons[i] = (
-                            delta_output.finish_reason
-                            if not conv_template.use_function_calling
-                            else "tool_calls"
-                        )
-                        finish_reason_updated = True
-                    if not finish_reason_updated and delta_output.delta_text == "":
-                        # Ignore empty delta text when finish reason is not updated.
-                        async_engine.state.record_event(request_id, event="skip empty delta text")
-                        continue
-
-                    choices.append(
-                        ChatCompletionStreamResponseChoice(
-                            index=i,
-                            finish_reason=finish_reasons[i],
-                            delta=ChatCompletionMessage(
-                                content=delta_output.delta_text, role="assistant"
-                            ),
-                            logprobs=(
-                                LogProbs(
-                                    content=[
-                                        LogProbsContent.model_validate_json(logprob_json_str)
-                                        for logprob_json_str in delta_output.delta_logprob_json_strs
-                                    ]
-                                )
-                                if delta_output.delta_logprob_json_strs is not None
-                                else None
-                            ),
-                        )
-                    )
-
-                if len(choices) == 0:
-                    # Skip yield when there is no delta output.
-                    continue
-                response = ChatCompletionStreamResponse(
-                    id=request_id,
-                    choices=choices,
-                    model=request.model,
-                    system_fingerprint="",
-                )
-                async_engine.state.record_event(request_id, event="yield delta output")
+            if isinstance(first_response, StopAsyncIteration):
+                yield "data: [DONE]\n\n"
+                return
+            yield f"data: {first_response.model_dump_json()}\n\n"
+            async for response in stream_generator:
                 yield f"data: {response.model_dump_json()}\n\n"
-            async_engine.state.record_event(request_id, event="finish")
             yield "data: [DONE]\n\n"
 
         return fastapi.responses.StreamingResponse(
@@ -493,93 +161,49 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
     )
 
     # Normal response.
- output_texts = ["" for _ in range(generation_cfg.n)] + num_prompt_tokens = 0 num_completion_tokens = 0 - finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - logprob_json_strs_list: Optional[List[List[str]]] = ( - [[] for _ in range(generation_cfg.n)] if generation_cfg.logprobs else None + output_texts = ["" for _ in range(request.n)] + finish_reasons: List[Optional[str]] = [None for _ in range(request.n)] + logprob_results: Optional[List[List[LogProbsContent]]] = ( + [[] for _ in range(request.n)] if request.logprobs else None ) - async_engine.state.record_event(request_id, event="invoke generate") - async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): + + async for response in async_engine._handle_chat_completion( # pylint: disable=protected-access + request, request_id + ): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified # when the request is disconnected. # Therefore, we check if it is disconnected each time, # and abort the request from engine if so. await async_engine.abort(request_id) - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message="The request has disconnected" ) + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + assert isinstance(choice.delta.content, str) + output_texts[choice.index] += choice.delta.content + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[choice.index] += choice.logprobs.content - assert len(delta_outputs) == generation_cfg.n - for i, delta_output in enumerate(delta_outputs): - if delta_output.finish_reason is not None and finish_reasons[i] is None: - finish_reasons[i] = delta_output.finish_reason - output_texts[i] += delta_output.delta_text - num_completion_tokens += delta_output.num_delta_tokens - if logprob_json_strs_list is not None: - assert delta_output.delta_logprob_json_strs is not None - logprob_json_strs_list[i] += delta_output.delta_logprob_json_strs assert all(finish_reason is not None for finish_reason in finish_reasons) - - async_engine.state.record_event(request_id, event="finish") - - tool_calls_list: List[List[ChatToolCall]] = [[] for _ in range(generation_cfg.n)] - if conv_template.use_function_calling: - for i, output_text in enumerate(output_texts): - try: - fn_json_list = convert_function_str_to_json(output_text) - except (SyntaxError, ValueError): - output_text = "Got an invalid function call output from model" - finish_reasons[i] = "error" - else: - tool_calls_list[i] = [ - ChatToolCall( - type="function", - function=ChatFunctionCall( - name=fn_json_obj["name"], arguments=fn_json_obj["arguments"] - ), - ) - for fn_json_obj in fn_json_list - if fn_json_obj is not None - ] - if len(tool_calls_list[i]) == 0: - output_texts[i] = "Got an invalid function call output from model" - finish_reasons[i] = "error" - else: - finish_reasons[i] = "tool_calls" - - return ChatCompletionResponse( - id=request_id, - choices=[ - ChatCompletionResponseChoice( - index=i, - finish_reason=finish_reasons[i], - message=( - ChatCompletionMessage(role="assistant", content=output_text) - if (not conv_template.use_function_calling or finish_reason == "error") - else ChatCompletionMessage(role="assistant", 
tool_calls=tool_calls)
-                ),
-                logprobs=(
-                    LogProbs(
-                        content=[
-                            LogProbsContent.model_validate_json(logprob_json_str)
-                            for logprob_json_str in logprob_json_strs_list[  # pylint: disable=unsubscriptable-object
-                                i
-                            ]
-                        ]
-                    )
-                    if logprob_json_strs_list is not None
-                    else None
-                ),
-            )
-            for i, (output_text, finish_reason, tool_calls) in enumerate(
-                zip(output_texts, finish_reasons, tool_calls_list)
-            )
-        ],
+    use_function_calling, tool_calls_list = engine_base.process_function_call_output(
+        output_texts, finish_reasons
+    )
+    return engine_base.wrap_chat_completion_response(
+        request_id=request_id,
         model=request.model,
-        system_fingerprint="",
-        usage=UsageInfo(
-            prompt_tokens=sum(len(item) for item in prompt), completion_tokens=num_completion_tokens
-        ),
+        output_texts=output_texts,
+        finish_reasons=finish_reasons,
+        tool_calls_list=tool_calls_list,
+        logprob_results=logprob_results,
+        use_function_calling=use_function_calling,
+        num_prompt_tokens=num_prompt_tokens,
+        num_completion_tokens=num_completion_tokens,
     )
diff --git a/python/mlc_llm/serve/event_trace_recorder.py b/python/mlc_llm/serve/event_trace_recorder.py
index 7a8a8177fe..457918d598 100644
--- a/python/mlc_llm/serve/event_trace_recorder.py
+++ b/python/mlc_llm/serve/event_trace_recorder.py
@@ -17,7 +17,7 @@ def __init__(self) -> None:
         )
 
     def add_event(self, request_id: str, event: str) -> None:
-        """Record a event for the the input request in the trace recorder.
+        """Record an event for the input request in the trace recorder.
 
         Parameters
         ----------
diff --git a/python/mlc_llm/serve/grammar.py b/python/mlc_llm/serve/grammar.py
index d5ad862a42..cf491884c2 100644
--- a/python/mlc_llm/serve/grammar.py
+++ b/python/mlc_llm/serve/grammar.py
@@ -247,7 +247,7 @@ def __init__(
             self.__init_handle_by_constructor__(
                 _ffi_api.GrammarStateMatcherFromTokenTable,  # type: ignore  # pylint: disable=no-member
                 grammar,
-                *tokenizer,
+                tokenizer,
                 max_rollback_steps,
             )
         else:
diff --git a/python/mlc_llm/serve/radix_tree.py b/python/mlc_llm/serve/radix_tree.py
new file mode 100644
index 0000000000..102cdac675
--- /dev/null
+++ b/python/mlc_llm/serve/radix_tree.py
@@ -0,0 +1,150 @@
+"""The Paged Radix Tree class."""
+
+from typing import List, Tuple, Union
+
+import tvm
+import tvm._ffi
+from tvm.runtime import Object, ShapeTuple
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("mlc.serve.PagedRadixTree")  # pylint: disable=protected-access
+class PagedRadixTree(Object):
+    """The paged radix tree to manage prefix and sequence."""
+
+    def __init__(self, num_pages: int, page_size: int, num_seqs: int):
+        """
+        Constructor of paged radix tree.
+
+        Parameters
+        ----------
+        num_pages : int
+            The number of radix tree pages.
+        page_size : int
+            The page size of each radix tree page.
+        num_seqs : int
+            The maximum number of sequence IDs.
+        """
+        self.__init_handle_by_constructor__(_ffi_api.PagedRadixTree, num_pages, page_size, num_seqs)  # type: ignore  # pylint: disable=no-member
+
+    def match(self, tokens: Union[ShapeTuple, List, Tuple]) -> Tuple[int, ShapeTuple]:
+        """
+        Get all sequences with the longest common prefix with the given prefix tokens.
+
+        Parameters
+        ----------
+        tokens : Union[ShapeTuple, List, Tuple]
+            The prefix tokens for reference.
+
+        Returns
+        -------
+        matched_offset : int
+            The matched prefix length.
+        seq_ids : ShapeTuple
+            The array of matched sequence indices.
+ """ + if isinstance(tokens, (list, tuple)): + tokens = ShapeTuple(tokens) + output = _ffi_api.PagedRadixTreeMatchPrefix(self, tokens) # type: ignore # pylint: disable=no-member + if len(output) == 1: + return output[0], [] + return output[0], output[1:] + + def add(self, seq_id: int) -> None: + """ + Get all sequences with longest common prefix with give prefix tokens. + + Parameters + ---------- + seq_id : int + The sequence ID for index. + """ + _ffi_api.PagedRadixTreeAddSequence(self, seq_id) # type: ignore # pylint: disable=no-member + + def remove(self, seq_id: int) -> None: + """ + Remove a sequence. + + Parameters + ---------- + seq_id : int + The sequence ID to remove. + """ + _ffi_api.PagedRadixTreeRemoveSequence(self, seq_id) # type: ignore # pylint: disable=no-member + + def extend(self, seq_id: int, tokens: Union[ShapeTuple, List, Tuple]) -> None: + """ + Get all sequences with longest common prefix with give prefix tokens. + + Parameters + ---------- + seq_id : int + The sequence ID for index. + tokens : Union[ShapeTuple, List, Tuple] + The given tokens to extend. + """ + if isinstance(tokens, (list, tuple)): + tokens = ShapeTuple(tokens) + _ffi_api.PagedRadixTreeExtendSequence(self, seq_id, tokens) # type: ignore # pylint: disable=no-member + + def fork(self, seq_id: int, parent_seq_id: int, forked_offset: int) -> None: + """ + Fork a sequence from parent sequence at given position. + + Parameters + ---------- + seq_id : int + The new sequence ID. + parent_seq_id : int + The parent sequence ID to fork from. + forked_offset : int + The position of parent sequence to fork at. + The valid value is [1, length of forked sequence]. + If the position equals the length of forked sequence, + the new sequence will copy the entire forked sequence. + """ + _ffi_api.PagedRadixTreeForkSequence(self, seq_id, parent_seq_id, forked_offset) # type: ignore # pylint: disable=no-member + + def get(self, seq_id: int) -> ShapeTuple: + """ + Get a sequence's all tokens. + + Parameters + ---------- + seq_id : int + The sequence ID for index. + + Returns + ------ + tokens : ShapeTuple + The sequence tokens. + """ + return _ffi_api.PagedRadixTreeGetSequence(self, seq_id) # type: ignore # pylint: disable=no-member + + def get_length(self, seq_id: int) -> int: + """ + Get a sequence's length. + + Parameters + ---------- + seq_id : int + The sequence ID for index. + + Returns + ------ + length : int + The sequence length. + """ + return _ffi_api.PagedRadixTreeGetSequenceLength(self, seq_id) # type: ignore # pylint: disable=no-member + + def free_capacity(self) -> int: + """ + Get the remaining token capacity of the paged radix tree. + + Returns + ------ + capacity : int + The remaining token capacity of the paged radix tree. 
+ """ + return _ffi_api.PagedRadixTreeFreeCapacity(self) # type: ignore # pylint: disable=no-member diff --git a/python/mlc_llm/serve/server/server_context.py b/python/mlc_llm/serve/server/server_context.py index 0a9a1b0b1f..d6acd4a2be 100644 --- a/python/mlc_llm/serve/server/server_context.py +++ b/python/mlc_llm/serve/server/server_context.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional -from ..engine import AsyncLLMEngine +from ..engine import AsyncMLCEngine class ServerContext: @@ -13,7 +13,7 @@ class ServerContext: server_context: Optional["ServerContext"] = None def __init__(self): - self._models: Dict[str, AsyncLLMEngine] = {} + self._models: Dict[str, AsyncMLCEngine] = {} def __enter__(self): if ServerContext.server_context is not None: @@ -31,14 +31,17 @@ def current(): """Returns the current ServerContext.""" return ServerContext.server_context - def add_model(self, hosted_model: str, engine: AsyncLLMEngine) -> None: + def add_model(self, hosted_model: str, engine: AsyncMLCEngine) -> None: """Add a new model to the server context together with the engine.""" if hosted_model in self._models: raise RuntimeError(f"Model {hosted_model} already running.") self._models[hosted_model] = engine - def get_engine(self, model: str) -> Optional[AsyncLLMEngine]: - """Get the async engine of the requested model.""" + def get_engine(self, model: Optional[str]) -> Optional[AsyncMLCEngine]: + """Get the async engine of the requested model, or the unique async engine + if only one engine is served.""" + if len(self._models) == 1: + return next(iter(self._models.values())) return self._models.get(model, None) def get_model_list(self) -> List[str]: diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 23b151d5c7..1be841cb08 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -41,7 +41,7 @@ def _create_tvm_module( return {key: module[key] for key in ffi_funcs} -class SyncLLMEngine: +class SyncMLCEngine: """The Python interface of synchronize request serving engine for MLC LLM. The engine receives requests from the "add_request" method. 
For @@ -98,6 +98,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, enable_tracing: bool = False, speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, @@ -128,11 +129,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length, prefill_chunk_size, max_single_sequence_length, + max_history_size, + kv_state_kind, ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, prefill_chunk_size, + max_history_size, gpu_memory_utilization, models, device, @@ -162,15 +166,17 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model_lib_path=model_args[0][1], additional_models=[model_arg[0] for model_arg in model_args[1:]], additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, kv_cache_page_size=16, max_num_sequence=max_batch_size, max_total_sequence_length=max_total_sequence_length, max_single_sequence_length=max_single_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ), + device, request_stream_callback, self.trace_recorder, ) diff --git a/python/mlc_llm/support/auto_config.py b/python/mlc_llm/support/auto_config.py index f0247a6ef9..be0ee8af98 100644 --- a/python/mlc_llm/support/auto_config.py +++ b/python/mlc_llm/support/auto_config.py @@ -62,7 +62,7 @@ def detect_mlc_chat_config(mlc_chat_config: str) -> Path: # search mlc-chat-config.json under path mlc_chat_config_json_path = mlc_chat_config_path / "mlc-chat-config.json" if not mlc_chat_config_json_path.exists(): - raise ValueError(f"Fail to find mlc_chat_config.json under {mlc_chat_config_path}.") + raise ValueError(f"Fail to find mlc-chat-config.json under {mlc_chat_config_path}.") else: mlc_chat_config_json_path = mlc_chat_config_path diff --git a/python/mlc_llm/support/auto_device.py b/python/mlc_llm/support/auto_device.py index cf6d09495a..bddb9954c6 100644 --- a/python/mlc_llm/support/auto_device.py +++ b/python/mlc_llm/support/auto_device.py @@ -1,4 +1,6 @@ """Automatic detection of the device available on the local machine.""" + +import os import subprocess import sys from typing import Dict, Optional @@ -65,6 +67,7 @@ def _device_exists(device: Device) -> bool: capture_output=True, text=True, check=False, + env=os.environ, ) .stdout.strip() .splitlines() diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index 5c61af6f07..4c32feb6ff 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -295,12 +295,18 @@ def build(mod: IRModule, args: "CompileArgs", pipeline=None): def detect_cuda_arch_list(target: Target) -> List[str]: """Detect the CUDA architecture list from the target.""" + + def convert_to_num(arch_str): + arch_num_str = "".join(filter(str.isdigit, arch_str)) + assert arch_num_str, f"'{arch_str}' does not contain any digits" + return int(arch_num_str) + assert target.kind.name == "cuda", f"Expect target to be CUDA, but got {target}" if MLC_MULTI_ARCH is not None: - multi_arch = [x.strip() for x in MLC_MULTI_ARCH.split(",")] + multi_arch = [convert_to_num(x) for x in MLC_MULTI_ARCH.split(",")] else: assert target.arch.startswith("sm_") - multi_arch = 
[target.arch[3:]] + multi_arch = [convert_to_num(target.arch[3:])] multi_arch = list(set(multi_arch)) return multi_arch diff --git a/python/mlc_llm/support/download.py b/python/mlc_llm/support/download.py index a109c967bc..770833e9af 100644 --- a/python/mlc_llm/support/download.py +++ b/python/mlc_llm/support/download.py @@ -36,11 +36,13 @@ def git_clone(url: str, destination: Path, ignore_lfs: bool) -> None: command = ["git", "clone", url, repo_name] _ensure_directory_not_exist(destination, force_redo=False) try: + env = os.environ.copy() + env["GIT_LFS_SKIP_SMUDGE"] = "1" with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir: logger.info("[Git] Cloning %s to %s", bold(url), destination) subprocess.run( command, - env={"GIT_LFS_SKIP_SMUDGE": "1"}, + env=env, cwd=tmp_dir, check=True, stdout=subprocess.DEVNULL, diff --git a/python/mlc_llm/support/max_thread_check.py b/python/mlc_llm/support/max_thread_check.py index 6c078c3bbf..6711fb5c55 100644 --- a/python/mlc_llm/support/max_thread_check.py +++ b/python/mlc_llm/support/max_thread_check.py @@ -3,7 +3,7 @@ from tvm.target import Target -def get_max_num_threads_per_block(target: Target): +def get_max_num_threads_per_block(target: Target) -> int: """ max(max_num_threads, max_threads_per_block); if latter does not exist, return max_num_threads. We add this method since some targets have both fields and `max_threads_per_block` is larger. diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py index 2a70154bba..4f1cfe103d 100644 --- a/python/mlc_llm/testing/debug_chat.py +++ b/python/mlc_llm/testing/debug_chat.py @@ -118,7 +118,7 @@ def __call__(self, func, name, before_run, ret_val, *args): print(f"{red(f'{func_name} has INF')}: {num_infs}") self.first_inf_occurred = True - # Save the the arguments to npz + # Save the arguments to npz arg_dict = {} for i, arg in enumerate(args): if isinstance(arg, tvm.nd.NDArray): diff --git a/scripts/build_mlc_for_docs.sh b/scripts/build_mlc_for_docs.sh new file mode 100755 index 0000000000..50eee3231a --- /dev/null +++ b/scripts/build_mlc_for_docs.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -euxo pipefail + +mkdir -p build +cd build +cmake .. +make -j$(nproc) +cd - diff --git a/scripts/build_site.sh b/scripts/build_site.sh index 6340ee838e..062f8094de 100755 --- a/scripts/build_site.sh +++ b/scripts/build_site.sh @@ -1,6 +1,7 @@ #!/bin/bash set -euxo pipefail +export PYTHONPATH=$PWD/python cd docs && make html && cd .. cd site && jekyll b && cd .. diff --git a/scripts/gh_deploy_site.sh b/scripts/gh_deploy_site.sh index 1b21c52d16..326c280484 100755 --- a/scripts/gh_deploy_site.sh +++ b/scripts/gh_deploy_site.sh @@ -4,6 +4,7 @@ set -euxo pipefail +scripts/build_mlc_for_docs.sh scripts/build_site.sh git fetch diff --git a/site/index.md b/site/index.md index 44befd4abc..ac0367cdb2 100644 --- a/site/index.md +++ b/site/index.md @@ -6,62 +6,41 @@ notitle: true # MLC LLM -MLC LLM is a universal solution that allows any language model to be deployed natively on a diverse set of hardware backends and native applications. - -Please visit [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) for detailed instructions. - -## Demos - -- [iOS](#ios) -- [Android](#android) -- [Windows Linux Mac](#windows-linux-mac) -- [Web browser](#web-browser) - -### iOS - -Our iOS app, MLCChat, is available on [App Store](https://apps.apple.com/us/app/mlc-chat/id6448482937) for iPhone and iPad. 
-You can try out the [Testflight app](https://testflight.apple.com/join/57zd7oxa) that sometimes contains beta release of latest models. -This app is tested on iPhone 15 Pro Max, iPhone 14 Pro Max, iPhone 14 Pro and iPhone 12 Pro. -Besides the [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) page, -[documentation](https://llm.mlc.ai/docs/deploy/ios.html) is available for building iOS apps with MLC LLM. +Documentation: [https://llm.mlc.ai/docs](https://llm.mlc.ai/docs) +**M**achine **L**earning **C**ompilation for **L**arge **L**anguage **M**odels (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language model with native APIs and compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques.
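(Editor's illustration, not part of the patch.) The "native APIs" mentioned above are exercised elsewhere in this diff through the renamed `MLCEngine` class (`from mlc_llm.serve import GenerationConfig, MLCEngine` in `tests/python/serve/test_serve_engine.py`). Below is a minimal sketch of the OpenAI-style chat entry point, assuming an already-compiled model under `dist/` (hypothetical path) and that `engine.chat.completions.create` behaves as described in the linked documentation:

```python
# Editorial sketch (not part of the patch): OpenAI-style chat via MLCEngine.
# Assumes an MLC-compiled model at the path below (hypothetical) and that the
# engine exposes the `chat.completions.create` entry point documented in the
# quick-start guide linked above.
from mlc_llm.serve import MLCEngine

model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC"  # hypothetical local model path
engine = MLCEngine(model=model)

# Stream one chat completion and print the text deltas as they arrive.
for response in engine.chat.completions.create(
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
    model=model,
    stream=True,
):
    for choice in response.choices:
        print(choice.delta.content or "", end="", flush=True)
print()

engine.terminate()  # shut down the engine's background threads
```

The streaming loop mirrors how the tests in this patch consume the `choice.delta` fields of each `ChatCompletionStreamResponse`, and `engine.terminate()` corresponds to the explicit shutdown the tests perform before exiting.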


-Note: Llama-7B takes 4GB of RAM and RedPajama-3B takes 2.2GB to run. We recommend a latest device with 6GB RAM for Llama-7B, or 4GB RAM for RedPajama-3B, to run the app. The text generation speed could vary from time to time, for example, slow in the beginning but recover to a normal speed then. +## Installation -### Android +MLC LLM is available via [pip](https://llm.mlc.ai/docs/install/mlc_llm.html#install-mlc-packages). +It is always recommended to install it in an isolated conda virtual environment. -The demo APK is available to [download](https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android/mlc-chat.apk). The demo is tested on Samsung S23 with Snapdragon 8 Gen 2 chip, Redmi Note 12 Pro with Snapdragon 685 and Google Pixel phones. -Besides the [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) page, -[documentation](https://llm.mlc.ai/docs/deploy/android.html) is available for building android apps with MLC LLM. +To verify the installation, activate your virtual environment, run -


+```bash +python -c "import mlc_llm; print(mlc_llm.__path__)" +``` -### Windows Linux Mac +You are expected to see the installation path of MLC LLM Python package. -Our cpp interface runs on AMD, Intel, Apple and NVIDIA GPUs. -Besides the [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) page, -[documentation](https://llm.mlc.ai/docs/deploy/cli.html) is available for building C++ apps with MLC LLM. +## Quick Start -


+Please check out our documentation for the [quick start](https://llm.mlc.ai/docs/get_started/quick_start.html). -### Web Browser +## Introduction -[WebLLM](https://webllm.mlc.ai/) is our companion project that deploys MLC LLM natively to browsers using WebGPU and WebAssembly. Still everything runs inside the browser without server resources, and accelerated by local GPUs (e.g. AMD, Intel, Apple or NVIDIA). +Please check out our documentation for the [introduction](https://llm.mlc.ai/docs/get_started/introduction.html). ## Links -* Our official [GitHub repo](https://github.com/mlc-ai/mlc-llm); -* Our companion project [WebLLM](https://webllm.mlc.ai/) that enables running LLMs purely in browser. -* [Web Stable Diffusion](https://websd.mlc.ai/) is another MLC-series that runs the diffusion models purely in the browser. -* [Machine Learning Compilation course](https://mlc.ai) is available for a systematic walkthrough of our approach to universal deployment. +- You might want to check out our online public [Machine Learning Compilation course](https://mlc.ai) for a systematic +walkthrough of our approaches. +- [WebLLM](https://webllm.mlc.ai/) is a companion project using MLC LLM's WebGPU and WebAssembly backend. +- [WebStableDiffusion](https://websd.mlc.ai/) is a companion project for diffusion models with the WebGPU backend. ## Disclaimer diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index b86fd423a9..c52571b522 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -1,25 +1,8 @@ -# pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable -import json -import queue -import threading from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union -import tvm +from mlc_llm.json_ffi import JSONFFIEngine -from mlc_llm.protocol import openai_api_protocol -from mlc_llm.serve import engine_utils -from mlc_llm.serve.engine_base import ( - EngineConfig, - SpeculativeMode, - _infer_kv_cache_config, - _parse_models, - _process_model_args, - detect_device, -) -from mlc_llm.tokenizer import Tokenizer - -prompts = [ +chat_completion_prompts = [ "What is the meaning of life?", "Introduce the history of Pittsburgh to me. Please elaborate in detail.", "Write a three-day Seattle travel plan. Please elaborate in detail.", @@ -32,227 +15,40 @@ "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", ] +function_calling_prompts = [ + "What is the temperature in Pittsburgh, PA?", + "What is the temperature in Tokyo, JP?", + "What is the temperature in Pittsburgh, PA and Tokyo, JP?", +] -class EngineState: - sync_queue: queue.Queue - - def get_request_stream_callback(self) -> Callable[[List[str]], None]: - # ChatCompletionStreamResponse - - def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: - self._sync_request_stream_callback(chat_completion_stream_responses_json_str) - - return _callback - - def _sync_request_stream_callback( - self, chat_completion_stream_responses_json_str: List[str] - ) -> None: - # Put the delta outputs to the queue in the unblocking way. 
- self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) - - -class JSONFFIEngine: - def __init__( # pylint: disable=too-many-arguments,too-many-locals - self, - model: str, - device: Union[str, tvm.runtime.Device] = "auto", - *, - model_lib_path: Optional[str] = None, - mode: Literal["local", "interactive", "server"] = "local", - additional_models: Optional[List[str]] = None, - max_batch_size: Optional[int] = None, - max_total_sequence_length: Optional[int] = None, - prefill_chunk_size: Optional[int] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, - spec_draft_length: int = 4, - gpu_memory_utilization: Optional[float] = None, - ) -> None: - # - Initialize model loading info. - models = _parse_models(model, model_lib_path, additional_models) - if isinstance(device, str): - device = detect_device(device) - assert isinstance(device, tvm.runtime.Device) - ( - model_args, - model_config_paths, - self.conv_template, - ) = _process_model_args(models, device) - - # - Load the raw model config into dict - self.model_config_dicts = [] - for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] - with open(model_config_paths[i], "r", encoding="utf-8") as file: - self.model_config_dicts.append(json.load(file)) - - # - Decide the KV cache config based on mode and user input. - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) - - # - Initialize engine state and engine. - self.state = EngineState() - module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() - self._ffi = { - key: module[key] - for key in [ - "init_background_engine", - "chat_completion", - "abort", - "get_last_error", - "run_background_loop", - "run_background_stream_back_loop", - "exit_background_loop", - ] - } - self.tokenizer = Tokenizer(model_args[0][0]) - - def _background_loop(): - self._ffi["init_background_engine"]( - EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ), - self.state.get_request_stream_callback(), - None, - ) - self._ffi["run_background_loop"]() - - def _background_stream_back_loop(): - self._ffi["run_background_stream_back_loop"]() - - # Create the background engine-driving thread and start the loop. 
- self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) - self._background_stream_back_loop_thread: threading.Thread = threading.Thread( - target=_background_stream_back_loop - ) - self._background_loop_thread.start() - self._background_stream_back_loop_thread.start() - self._terminated = False - - def terminate(self): - self._terminated = True - self._ffi["exit_background_loop"]() - self._background_loop_thread.join() - self._background_stream_back_loop_thread.join() - - def chat_completion( # pylint: disable=too-many-arguments,too-many-locals - self, - *, - messages: List[Dict[str, Any]], - model: str, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - logprobs: bool = False, - top_logprobs: int = 0, - logit_bias: Optional[Dict[int, float]] = None, - max_tokens: Optional[int] = None, - n: int = 1, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, - tools: Optional[List[Dict[str, Any]]] = None, - tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, - user: Optional[str] = None, - ignore_eos: bool = False, - response_format: Optional[Dict[str, Any]] = None, - request_id: Optional[str] = None, - ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: - if request_id is None: - request_id = f"chatcmpl-{engine_utils.random_uuid()}" - - chatcmpl_generator = self._handle_chat_completion( - openai_api_protocol.ChatCompletionRequest( - messages=[ - openai_api_protocol.ChatCompletionMessage.model_validate(message) - for message in messages - ], - model=model, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - logprobs=logprobs, - top_logprobs=top_logprobs, - logit_bias=logit_bias, - max_tokens=max_tokens, - n=n, - seed=seed, - stop=stop, - stream=stream, - temperature=temperature, - top_p=top_p, - tools=( - [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] - if tools is not None - else None - ), - tool_choice=tool_choice, - user=user, - ignore_eos=ignore_eos, - response_format=( - openai_api_protocol.RequestResponseFormat.model_validate(response_format) - if response_format is not None - else None - ), - ).model_dump_json(), - n=n, - request_id=request_id, - ) - for response in chatcmpl_generator: - yield response - - def _handle_chat_completion( - self, request_json_str: str, n: int, request_id: str - ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: - self.state.sync_queue = queue.Queue() - num_unfinished_requests = n - - success = bool(self._ffi["chat_completion"](request_json_str, request_id)) - - try: - while num_unfinished_requests > 0: - chat_completion_stream_responses_json_str = self.state.sync_queue.get() - for chat_completion_response_json_str in chat_completion_stream_responses_json_str: - chat_completion_response = ( - openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( - chat_completion_response_json_str - ) - ) - for choice in chat_completion_response.choices: - if choice.finish_reason is not None: - num_unfinished_requests -= 1 - yield chat_completion_response - except Exception as exception: # pylint: disable=broad-exception-caught - self._ffi["abort"](request_id) - raise exception +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + 
"description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } +] -def test_chat_completion(engine: JSONFFIEngine): +def run_chat_completion( + engine: JSONFFIEngine, + model: str, + prompts: List[str] = chat_completion_prompts, + tools: Optional[List[Dict]] = None, +): num_requests = 2 max_tokens = 64 n = 1 @@ -266,6 +62,7 @@ def test_chat_completion(engine: JSONFFIEngine): max_tokens=max_tokens, n=n, request_id=str(rid), + tools=tools, ): for choice in response.choices: assert choice.delta.role == "assistant" @@ -284,24 +81,61 @@ def test_chat_completion(engine: JSONFFIEngine): print(f"Output {req_id}({i}):{output}\n") -def test_malformed_request(engine: JSONFFIEngine): +def test_chat_completion(): + # Create engine. + model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + engine = JSONFFIEngine( + model, + max_total_sequence_length=1024, + ) + + run_chat_completion(engine, model) + + # Test malformed requests. for response in engine._handle_chat_completion("malformed_string", n=1, request_id="123"): assert len(response.choices) == 1 assert response.choices[0].finish_reason == "error" + engine.terminate() -if __name__ == "__main__": + +def test_reload_reset_unload(): # Create engine. - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + engine = JSONFFIEngine( + model, + max_total_sequence_length=1024, + ) + + # Run chat completion before and after reload/reset. + run_chat_completion(engine, model) + engine._test_reload() + run_chat_completion(engine, model) + engine._test_reset() + run_chat_completion(engine, model) + engine._test_unload() + + engine.terminate() + + +def test_function_calling(): + model = "dist/gorilla-openfunctions-v1-q4f16_1-MLC" + model_lib_path = ( + "dist/gorilla-openfunctions-v1-q4f16_1-MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so" + ) engine = JSONFFIEngine( model, model_lib_path=model_lib_path, max_total_sequence_length=1024, ) - test_chat_completion(engine) - test_malformed_request(engine) + # run function calling + run_chat_completion(engine, model, function_calling_prompts, tools) engine.terminate() - del engine + + +if __name__ == "__main__": + test_chat_completion() + test_reload_reset_unload() + test_function_calling() diff --git a/tests/python/op/test_batch_spec_verify.py b/tests/python/op/test_batch_spec_verify.py new file mode 100644 index 0000000000..f35a39d71e --- /dev/null +++ b/tests/python/op/test_batch_spec_verify.py @@ -0,0 +1,160 @@ +import numpy as np +import pytest +import tvm +import tvm.testing + +from mlc_llm.op.batch_spec_verify import batch_spec_verify + + +@pytest.mark.parametrize("nbatch", [32, 64]) +@pytest.mark.parametrize("vocab", [3, 32, 64, 32000, 33, 65, 32001, 128000]) +@pytest.mark.parametrize("plist", [[0.5, 0.5], [1, 0], [0, 1]]) +def test_batch_spec_verify(nbatch, vocab, plist): + def numpy_reference( + draft_probs, + draft_tokens, + model_probs, + token_tree_first_child, + token_tree_next_sibling, + uniform_samples, + token_tree_parent_ptr, + ): + nbatch = token_tree_parent_ptr.shape[0] + for b in range(nbatch): + parent_ptr = token_tree_parent_ptr[b] + child_ptr = token_tree_first_child[parent_ptr] + while child_ptr != -1: + child_token = draft_tokens[child_ptr] + p_child = model_probs[parent_ptr, child_token] + q_child = draft_probs[child_ptr, child_token] + uniform_sample = 
uniform_samples[child_ptr] + if p_child / q_child >= uniform_sample: + parent_ptr = child_ptr + child_ptr = token_tree_first_child[child_ptr] + else: + model_probs[parent_ptr, :] = np.maximum( + model_probs[parent_ptr, :] - draft_probs[child_ptr, :], 0.0 + ) + psum = np.sum(model_probs[parent_ptr, :]) + model_probs[parent_ptr, :] /= psum + child_ptr = token_tree_next_sibling[child_ptr] + token_tree_parent_ptr[b] = parent_ptr + + np.random.seed(0) + + def gen_chain(num_nodes, base): + token_tree_first_child = list() + token_tree_next_sibling = list() + for i in range(num_nodes): + token_tree_first_child.append(base + i + 1 if i + 1 < num_nodes else -1) + token_tree_next_sibling.append(-1) + return token_tree_first_child, token_tree_next_sibling, base, base + 1 + + def gen_full_binary_tree(height, base): + token_tree_first_child = list() + token_tree_next_sibling = list() + num_nodes = 2**height - 1 + for i in range(num_nodes): + token_tree_first_child.append(base + i * 2 + 1 if i * 2 + 1 < num_nodes else -1) + token_tree_next_sibling.append(base + i * 2 + 2 if i * 2 + 2 < num_nodes else -1) + return token_tree_first_child, token_tree_next_sibling, base, base + 1 + + ### Inputs + num_nodes = 0 + token_tree_first_child = list() + token_tree_next_sibling = list() + token_tree_parent_ptr = list() + + for _ in range(nbatch): + choice = np.random.choice(2, 1, p=plist) + if choice == 0: + nodes_batch = np.random.randint(3, 32) + res = gen_chain(nodes_batch, num_nodes) + num_nodes += nodes_batch + else: + height = np.random.randint(3, 5) + res = gen_full_binary_tree(height, num_nodes) + num_nodes += 2**height - 1 + token_tree_first_child.extend(res[0]) + token_tree_next_sibling.extend(res[1]) + token_tree_parent_ptr.append(res[2]) + + token_tree_first_child = np.array(token_tree_first_child).astype("int32") + token_tree_next_sibling = np.array(token_tree_next_sibling).astype("int32") + token_tree_parent_ptr = np.array(token_tree_parent_ptr).astype("int32") + + draft_probs = np.random.rand(num_nodes, vocab).astype("float32") + draft_probs /= np.sum(draft_probs, axis=1, keepdims=True) + draft_tokens = np.random.randint(0, vocab, num_nodes).astype("int32") + model_probs = np.random.rand(num_nodes, vocab).astype("float32") + model_probs /= np.sum(model_probs, axis=1, keepdims=True) + uniform_samples = np.random.rand(num_nodes).astype("float32") + + ### TVM Inputs + dev = tvm.cuda(0) + draft_probs_tvm = tvm.nd.array(draft_probs, dev) + draft_tokens_tvm = tvm.nd.array(draft_tokens, dev) + model_probs_tvm = tvm.nd.array(model_probs, dev) + token_tree_first_child_tvm = tvm.nd.array(token_tree_first_child, dev) + token_tree_next_sibling_tvm = tvm.nd.array(token_tree_next_sibling, dev) + uniform_samples_tvm = tvm.nd.array(uniform_samples, dev) + token_tree_parent_ptr_tvm = tvm.nd.array(token_tree_parent_ptr, dev) + + # print("draft_probs", draft_probs) + # print("draft_tokens", draft_tokens) + # print("model_probs", model_probs) + # print("token_tree_first_child", token_tree_first_child) + # print("token_tree_next_sibling", token_tree_next_sibling) + # print("uniform_samples", uniform_samples) + # print("token_tree_parent_ptr", token_tree_parent_ptr) + + ### Numpy reference + numpy_reference( + draft_probs, + draft_tokens, + model_probs, + token_tree_first_child, + token_tree_next_sibling, + uniform_samples, + token_tree_parent_ptr, + ) + # print("model_probs", model_probs) + # print("token_tree_parent_ptr", token_tree_parent_ptr) + + ### TVM + kernel = batch_spec_verify(vocab) + mod = tvm.build(kernel, 
target="cuda") + mod( + draft_probs_tvm, + draft_tokens_tvm, + model_probs_tvm, + token_tree_first_child_tvm, + token_tree_next_sibling_tvm, + uniform_samples_tvm, + token_tree_parent_ptr_tvm, + ) + # print("model_probs", model_probs_tvm.asnumpy()) + # print("token_tree_parent_ptr", token_tree_parent_ptr_tvm.asnumpy()) + + tvm.testing.assert_allclose(model_probs, model_probs_tvm.asnumpy()) + tvm.testing.assert_allclose( + token_tree_parent_ptr, token_tree_parent_ptr_tvm.asnumpy(), rtol=0, atol=0 + ) + + time_evaluator = mod.time_evaluator(mod.entry_name, dev, number=10, repeat=3) + print(f"batch_size: {nbatch}, vocab_size: {vocab}, tree_structure: {plist}") + print( + time_evaluator( + draft_probs_tvm, + draft_tokens_tvm, + model_probs_tvm, + token_tree_first_child_tvm, + token_tree_next_sibling_tvm, + uniform_samples_tvm, + token_tree_parent_ptr_tvm, + ) + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/op/test_top_p_pivot.py b/tests/python/op/test_top_p_pivot.py new file mode 100644 index 0000000000..7cfeb60e9c --- /dev/null +++ b/tests/python/op/test_top_p_pivot.py @@ -0,0 +1,83 @@ +import numpy as np +import pytest +import tvm +import tvm.testing + +from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm + +# mypy: disable-error-code="var-annotated" + + +@pytest.mark.parametrize("batch_size", [32, 64]) +@pytest.mark.parametrize("vocab", [3, 32, 64, 128]) +def test_top_p_renorm(batch_size, vocab): + top_p = 0.95 + init_pivots_np = np.array([1 - top_p, 0.02, 0.01]).astype(np.float32) + top_p_np = np.array([top_p]).astype(np.float32) + + p_np = np.random.exponential(3, size=(batch_size, vocab)).astype(np.float32) + p_np /= np.sum(p_np, axis=-1, keepdims=True) + final_pivot_np = np.zeros(batch_size).astype(np.float32) + final_lsum_np = np.zeros(batch_size).astype(np.float32) + + dev = tvm.cuda(0) + var_prob = tvm.nd.array(p_np, dev) + var_init_pivots = tvm.nd.array(init_pivots_np, dev) + top_p_global = tvm.nd.array(top_p_np, dev) + var_final_pivot = tvm.nd.array(final_pivot_np, dev) + var_final_lsum = tvm.nd.array(final_lsum_np, dev) + + kernel = top_p_pivot(init_pivots_np.shape[0]) + mod = tvm.build(kernel, target="cuda") + mod(var_prob, top_p_global, var_init_pivots, var_final_pivot, var_final_lsum) + + final_pivot = var_final_pivot.asnumpy() + final_lsum = var_final_lsum.asnumpy() + + renorm_np = p_np.copy() + var_renorm = tvm.nd.array(renorm_np, dev) + + kernel_renorm = top_p_renorm() + mod_renorm = tvm.build(kernel_renorm, target="cuda") + mod_renorm(var_prob, var_final_pivot, var_final_lsum, var_renorm) + + renorm = var_renorm.asnumpy() + + def verify_pivot(probs: np.ndarray, pivot: float, lsum: float, renorm: np.ndarray): + sorted_probs = np.sort(probs, axis=-1)[::-1] + num_larger_than_pivot = np.sum(sorted_probs >= pivot) + filtered_sorted_probs = sorted_probs[:num_larger_than_pivot] + min_larger_than_pivot = min(filtered_sorted_probs) + + sum_larger_than_pivot = np.sum(np.where(sorted_probs >= pivot, sorted_probs, 0)) + sum_larger_than_pivot_exclude_min = np.sum( + np.where(filtered_sorted_probs != min_larger_than_pivot, filtered_sorted_probs, 0) + ) + + probs[probs < pivot] = 0 + renorm_prob = probs / np.sum(probs, axis=-1, keepdims=True) + try: + assert sum_larger_than_pivot >= top_p + assert sum_larger_than_pivot_exclude_min < top_p + assert abs(lsum - sum_larger_than_pivot) < 1e-6 + assert np.allclose(renorm, renorm_prob, atol=1e-6, rtol=1e-6) + except AssertionError: + print("Failed") + print("probs:", repr(probs)) + print("pivot:", 
pivot) + print("sorted_probs:", sorted_probs) + print("num_larger_than_pivot:", num_larger_than_pivot) + print("filtered_sorted_probs:", filtered_sorted_probs) + print("min_larger_than_pivot:", min_larger_than_pivot) + print("sum_larger_than_pivot:", sum_larger_than_pivot) + print("sum_larger_than_pivot_exclude_min:", sum_larger_than_pivot_exclude_min) + print("renom_prob:", renorm_prob) + print("renorm:", renorm) + raise + + for i in range(batch_size): + verify_pivot(p_np[i], final_pivot[i], final_lsum[i], renorm[i]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/op/test_two_stage_softmax.py b/tests/python/op/test_two_stage_softmax.py new file mode 100644 index 0000000000..1d3d55d8e3 --- /dev/null +++ b/tests/python/op/test_two_stage_softmax.py @@ -0,0 +1,47 @@ +import numpy as np +import scipy.special +import tvm +from tvm import dlight + +from mlc_llm.compiler_pass.rewrite_softmax import _get_lse_and_softmax_func + + +def test_two_stage_softmax(): + chunk_size = 4096 + target = tvm.target.Target("cuda") + f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(target, chunk_size) + mod = tvm.IRModule({"chunk_lse": f_chunk_lse, "softmax_with_chunked_lse": f_softmax_with_lse}) + with target: + mod = dlight.ApplyDefaultSchedule(dlight.gpu.GeneralReduction())(mod) + + runtime_mod = tvm.build(mod, target=target) + device = tvm.cuda() + + num_runs = 5 + vocab_size = 128256 + for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]: + for _ in range(num_runs): + x_np = np.random.uniform(low=-10, high=10, size=(batch_size, vocab_size)).astype( + "float32" + ) + y_np = scipy.special.softmax(x_np, axis=-1) + + x_nd = tvm.nd.array(x_np, device=device) + r_nd = tvm.nd.empty( + (batch_size, (vocab_size + chunk_size - 1) // chunk_size), + x_np.dtype, + device=device, + ) + y_nd = tvm.nd.empty(x_np.shape, x_np.dtype, device=device) + + runtime_mod["chunk_lse"](x_nd, r_nd) + runtime_mod["softmax_with_chunked_lse"](x_nd, r_nd, y_nd) + + y_nd_arr = y_nd.numpy() + np.testing.assert_allclose(y_nd_arr, y_np, atol=1e-6, rtol=1e-6) + + print(f"pass batch size {batch_size}") + + +if __name__ == "__main__": + test_two_stage_softmax() diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py index 4e541b7437..c89a9e2c38 100644 --- a/tests/python/serve/evaluate_engine.py +++ b/tests/python/serve/evaluate_engine.py @@ -5,7 +5,7 @@ from typing import List, Tuple from mlc_llm.serve import GenerationConfig -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine def _parse_args(): @@ -41,7 +41,7 @@ def benchmark(args: argparse.Namespace): random.seed(args.seed) # Create engine - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=args.model, device=args.device, model_lib_path=args.model_lib_path, diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index ad4fa01a82..e4f64d2ce4 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -329,23 +329,6 @@ def test_openai_v1_completions_openai_package( ) -def test_openai_v1_completions_invalid_requested_model( - launch_server, # pylint: disable=unused-argument -): - # `launch_server` is a pytest fixture defined in conftest.py. 
- - model = "unserved_model" - payload = { - "model": model, - "prompt": "What is the meaning of life?", - "max_tokens": 10, - } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - expect_error( - response_str=response.json(), msg_prefix=f'The requested model "{model}" is not served.' - ) - - @pytest.mark.parametrize("stream", [False, True]) def test_openai_v1_completions_echo( served_model: Tuple[str, str], @@ -620,51 +603,6 @@ class Schema(BaseModel): "response_format": {"type": "json_object", "schema": schema_str}, } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions_json( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - prompt = "Response with a json object:" - max_tokens = 128 - payload = { - "model": served_model[0], - "prompt": prompt, - "max_tokens": max_tokens, - "stream": stream, - "response_format": {"type": "json_object"}, - } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) if not stream: check_openai_nonstream_response( @@ -1364,7 +1302,6 @@ def test_debug_dump_event_trace( test_openai_v1_completions(MODEL, None, stream=True) test_openai_v1_completions_openai_package(MODEL, None, stream=False) test_openai_v1_completions_openai_package(MODEL, None, stream=True) - test_openai_v1_completions_invalid_requested_model(None) test_openai_v1_completions_echo(MODEL, None, stream=False) test_openai_v1_completions_echo(MODEL, None, stream=True) test_openai_v1_completions_suffix(MODEL, None, stream=False) diff --git a/tests/python/serve/test_radix_tree.py b/tests/python/serve/test_radix_tree.py new file mode 100644 index 0000000000..cea421cd95 --- /dev/null +++ b/tests/python/serve/test_radix_tree.py @@ -0,0 +1,79 @@ +from tvm import TVMError +from tvm.runtime import ShapeTuple + +from mlc_llm.serve import PagedRadixTree + + +def test_add(): + prt = PagedRadixTree(16, 128, 16) + prt.add(0) + assert prt.get(0) == [] + + +def test_remove(): + prt = PagedRadixTree(32, 128, 16) + capacity = prt.free_capacity() + prt.add(0) + prt.remove(0) + prt.add(0) + prt.extend(0, [1 for _ in range(200)]) + prt.remove(0) + assert prt.free_capacity() == capacity + + prt.add(1) + prt.extend(1, [1 for _ in range(200)]) + capacity = prt.free_capacity() + prt.add(2) + prt.extend(2, [1 for _ in range(100)] + [2 for _ in range(100)]) + prt.remove(2) + assert prt.free_capacity() == capacity + + prt.add(3) + prt.extend(3, [1 for _ in range(200)]) + prt.remove(3) + assert prt.free_capacity() == capacity + + +def test_extend(): + prt = PagedRadixTree(1024, 256, 256) + L = prt.free_capacity() // 1024 + H = L // 2 + Q = L // 4 + seq_id = 0 + for start_pos in [0, H, L, L + H]: + for length in [Q, L - H, L, 2 * L - H, 2 * L]: + 
prt.add(seq_id) + if start_pos: + tokens_1 = [seq_id for _ in range(start_pos)] + prt.extend(seq_id, tokens_1) + assert prt.get(seq_id) == tokens_1 + else: + tokens_1 = [] + tokens_2 = [seq_id for _ in range(length)] + prt.extend(seq_id, tokens_2) + assert prt.get(seq_id) == tokens_1 + tokens_2 + seq_id += 1 + + +def test_fork(): + prt = PagedRadixTree(1024, 256, 256) + L = prt.free_capacity() // 1024 + H = L // 2 + Q = L // 4 + seq_id = 0 + length_list = [Q, H, L, L + Q, L + H, L * 2] + for p_idx in range(1, len(length_list)): + for c_idx in range(0, p_idx + 1): + prt.add(seq_id) + tokens = [seq_id for _ in range(length_list[p_idx])] + prt.extend(seq_id, tokens) + prt.fork(seq_id + 1, seq_id, length_list[c_idx]) + assert prt.get(seq_id + 1) == tokens[: length_list[c_idx]] + seq_id += 2 + + +if __name__ == "__main__": + test_add() + test_remove() + test_extend() + test_fork() diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index 9bece30578..6e3835238a 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -3,7 +3,7 @@ import asyncio from typing import List -from mlc_llm.serve import AsyncLLMEngine, GenerationConfig +from mlc_llm.serve import AsyncMLCEngine, GenerationConfig prompts = [ "What is the meaning of life?", @@ -23,7 +23,7 @@ async def test_engine_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -39,7 +39,7 @@ async def test_engine_generate(): ] async def generate_task( - async_engine: AsyncLLMEngine, + async_engine: AsyncMLCEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, @@ -80,7 +80,7 @@ async def test_chat_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -132,7 +132,7 @@ async def test_chat_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -183,7 +183,7 @@ async def test_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -235,7 +235,7 @@ async def test_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py index 6915224f81..c3963af613 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -3,17 +3,7 @@ import asyncio from typing import List -<<<<<<< HEAD 
-from mlc_llm.serve import ( - AsyncThreadedEngine, - EngineMode, - GenerationConfig, - KVCacheConfig, -) -from mlc_llm.serve.engine import ModelInfo -======= -from mlc_llm.serve import AsyncLLMEngine, GenerationConfig, SpeculativeMode ->>>>>>> upstream/main +from mlc_llm.serve import AsyncMLCEngine, GenerationConfig, SpeculativeMode prompts = [ "What is the meaning of life?", @@ -37,7 +27,7 @@ async def test_engine_generate(): small_model_lib_path = ( "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -54,18 +44,14 @@ async def test_engine_generate(): ] async def generate_task( - async_engine: AsyncLLMEngine, + async_engine: AsyncMLCEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, ): print(f"generate task for request {request_id}") rid = int(request_id) -<<<<<<< HEAD - async for delta_outputs in async_engine.generate( -======= async for delta_outputs in async_engine._generate( ->>>>>>> upstream/main prompt, generation_cfg, request_id=request_id ): assert len(delta_outputs) == generation_cfg.n diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index 330bd4cf82..37d1833b14 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -2,7 +2,9 @@ # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable from typing import List -from mlc_llm.serve import GenerationConfig, LLMEngine +import pytest + +from mlc_llm.serve import GenerationConfig, MLCEngine prompts = [ "What is the meaning of life?", @@ -17,17 +19,39 @@ "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? 
Please elaborate in detail.", ] +test_models = [ + ( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ), + ( + "dist/rwkv-6-world-1b6-q0f16-MLC", + "dist/rwkv-6-world-1b6-q0f16-MLC/rwkv-6-world-1b6-q0f16-MLC-cuda.so", + ), +] -def test_engine_generate(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) + +def create_engine(model: str, model_lib_path: str): + if "rwkv" in model: + return MLCEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_batch_size=8, + max_history_size=1, + ) + else: + return MLCEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) + + +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_engine_generate(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 10 max_tokens = 256 @@ -57,16 +81,10 @@ def test_engine_generate(): del engine -def test_chat_completion(): +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_chat_completion(model: str, model_lib_path: str): # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 64 @@ -101,16 +119,9 @@ def test_chat_completion(): del engine -def test_chat_completion_non_stream(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_chat_completion_non_stream(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 64 @@ -144,16 +155,9 @@ def test_chat_completion_non_stream(): del engine -def test_completion(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_completion(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 128 @@ -188,16 +192,9 @@ def test_completion(): del engine -def test_completion_non_stream(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_completion_non_stream(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 128 @@ -232,8 +229,9 @@ def test_completion_non_stream(): if __name__ == 
"__main__": - test_engine_generate() - test_chat_completion() - test_chat_completion_non_stream() - test_completion() - test_completion_non_stream() + for model, model_lib_path in test_models: + test_engine_generate(model, model_lib_path) + test_chat_completion(model, model_lib_path) + test_chat_completion_non_stream(model, model_lib_path) + test_completion(model, model_lib_path) + test_completion_non_stream(model, model_lib_path) diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index 7f2a33b230..b764c62cd2 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -7,9 +7,9 @@ import pytest from pydantic import BaseModel -from mlc_llm.serve import AsyncLLMEngine, GenerationConfig +from mlc_llm.serve import AsyncMLCEngine, GenerationConfig from mlc_llm.serve.config import ResponseFormat -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine prompts_list = [ "Generate a JSON string containing 20 objects:", @@ -22,7 +22,7 @@ def test_batch_generation_with_grammar(): # Create engine - engine = SyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + engine = SyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompt_len = len(prompts_list) prompts = prompts_list * 3 @@ -69,7 +69,7 @@ def test_batch_generation_with_grammar(): def test_batch_generation_with_schema(): # Create engine - engine = SyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + engine = SyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompt = ( "Generate a json containing three fields: an integer field named size, a " @@ -121,7 +121,7 @@ class Schema(BaseModel): async def run_async_engine(): # Create engine - async_engine = AsyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + async_engine = AsyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompts = prompts_list * 20 @@ -142,7 +142,7 @@ async def run_async_engine(): ] async def generate_task( - async_engine: AsyncLLMEngine, + async_engine: AsyncMLCEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, diff --git a/tests/python/serve/test_serve_engine_image.py b/tests/python/serve/test_serve_engine_image.py index ff64e7235b..59e8c97196 100644 --- a/tests/python/serve/test_serve_engine_image.py +++ b/tests/python/serve/test_serve_engine_image.py @@ -2,7 +2,7 @@ from pathlib import Path from mlc_llm.serve import GenerationConfig, data -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine def get_test_image(config) -> data.ImageData: @@ -13,7 +13,7 @@ def test_engine_generate(): # Create engine model = "dist/llava-1.5-7b-hf-q4f16_1-MLC/params" model_lib_path = "dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index 60be02ce1a..33c06b1c5e 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -11,7 +11,7 @@ SpeculativeMode, data, ) -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine prompts = [ "What is the meaning of life?", @@ -90,7 +90,7 @@ def 
fcallback(delta_outputs: List[RequestStreamOutput]): small_model_lib_path = ( "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -158,7 +158,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): small_model_lib_path = ( "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" ) - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -242,7 +242,7 @@ def step(self) -> None: "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) timer = CallbackTimer() - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -328,7 +328,7 @@ def step(self) -> None: "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" ) timer = CallbackTimer() - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -364,7 +364,19 @@ def step(self) -> None: # assert fin_time == request.generation_config.max_tokens - 1 -def test_engine_generate(): +def compare_output_text(output_text1, output_text2): + if isinstance(output_text1, list) and isinstance(output_text2, list): + for item1, item2 in zip(output_text1, output_text2): + if not compare_output_text(item1, item2): + return False + elif output_text1 != output_text2: + print(output_text1) + print(output_text2) + return False + return True + + +def test_engine_generate(compare_precision=False): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" @@ -372,7 +384,8 @@ def test_engine_generate(): small_model_lib_path = ( "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - engine = SyncLLMEngine( + + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -385,9 +398,31 @@ def test_engine_generate(): max_tokens = 256 # Generate output. 
- output_texts, _ = engine.generate( - prompts[:num_requests], GenerationConfig(max_tokens=max_tokens, n=3) - ) + if compare_precision: + print("compare precision") + generation_config = GenerationConfig( + temperature=0.0, top_p=0, max_tokens=1024, stop_token_ids=[2], n=1 + ) + engine_single_model = SyncMLCEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) + output_texts_single_model, _ = engine_single_model.generate( + prompts[:num_requests], generation_config + ) + for req_id, outputs in enumerate(output_texts_single_model): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + # TODO: Add pytorch precision + else: + generation_config = GenerationConfig(max_tokens=max_tokens, n=3) + output_texts, _ = engine.generate(prompts[:num_requests], generation_config) for req_id, outputs in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") if len(outputs) == 1: @@ -395,6 +430,12 @@ def test_engine_generate(): else: for i, output in enumerate(outputs): print(f"Output {req_id}({i}):{output}\n") + if compare_precision: + precision_flag = compare_output_text(output_texts, output_texts_single_model) + if precision_flag: + print(f"Accuracy verification succeed\n") + else: + print(f"Accuracy verification failed\n") def test_engine_eagle_generate(): @@ -405,7 +446,7 @@ def test_engine_eagle_generate(): small_model_lib_path = ( "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" ) - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -453,7 +494,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC" model_lib_path = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -525,7 +566,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # small_model_lib_path = ( # "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so" # ) - spec_engine = SyncLLMEngine( + spec_engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -595,7 +636,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): small_model_lib_path = ( "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" ) - spec_engine = SyncLLMEngine( + spec_engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -643,7 +684,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): test_engine_eagle_basic() test_engine_continuous_batching_1() test_engine_eagle_continuous_batching_1() - test_engine_generate() + test_engine_generate(compare_precision=True) test_engine_eagle_generate() test_engine_efficiency() test_engine_spec_efficiency() diff --git a/tests/python/serve/test_serve_sync_engine.py b/tests/python/serve/test_serve_sync_engine.py index c5d521b02d..f68f48b7c5 100644 --- a/tests/python/serve/test_serve_sync_engine.py +++ b/tests/python/serve/test_serve_sync_engine.py @@ -5,7 +5,7 @@ import numpy as np from mlc_llm.serve import GenerationConfig, Request, RequestStreamOutput, data -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine prompts = [ "What is the 
meaning of life?", @@ -80,7 +80,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -156,7 +156,7 @@ def step(self) -> None: timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -237,7 +237,7 @@ def step(self) -> None: timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -323,7 +323,7 @@ def all_finished(self) -> bool: timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -365,7 +365,7 @@ def test_engine_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/web/emcc/mlc_wasm_runtime.cc b/web/emcc/mlc_wasm_runtime.cc index 3f05eb259f..b9a7f55bfa 100644 --- a/web/emcc/mlc_wasm_runtime.cc +++ b/web/emcc/mlc_wasm_runtime.cc @@ -29,6 +29,8 @@ // Pass in COMPILE_MLC_WASM_RUNTIME so unsupported code would not be compiled in to the .bc file #define COMPILE_MLC_WASM_RUNTIME 1 +#define __STDC_FORMAT_MACROS 1 +#define PICOJSON_USE_INT64 #define DMLC_USE_LOGGING_LIBRARY @@ -38,4 +40,5 @@ #include "serve/grammar/grammar_serializer.cc" #include "serve/grammar/grammar_simplifier.cc" #include "serve/grammar/grammar_state_matcher.cc" +#include "serve/grammar/json_schema_converter.cc" #include "support/encoding.cc"
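(Editor's note, not part of the patch.) The new `tests/python/op/test_batch_spec_verify.py` above encodes each speculative-decoding draft tree with flat `token_tree_first_child` / `token_tree_next_sibling` arrays plus a per-sequence `token_tree_parent_ptr`. The sketch below builds that layout by hand for two tiny sequences and runs the same accept/reject walk as the test's `numpy_reference`, using NumPy only; the array names mirror the test, everything else is illustrative.

```python
# Editorial sketch: the first-child / next-sibling token-tree layout consumed
# by batch_spec_verify, plus the CPU accept/reject walk from numpy_reference.
import numpy as np

# Two draft sequences packed into one flat node array:
#   nodes 0-2: a 3-token chain 0 -> 1 -> 2
#   nodes 3-5: a root (3) with two children 4 and 5
token_tree_first_child = np.array([1, 2, -1, 4, -1, -1], dtype="int32")
token_tree_next_sibling = np.array([-1, -1, -1, 5, -1, -1], dtype="int32")
# One entry per sequence: the index of its root (already committed) node.
token_tree_parent_ptr = np.array([0, 3], dtype="int32")

vocab = 4
rng = np.random.default_rng(0)
draft_probs = rng.random((6, vocab)).astype("float32")
draft_probs /= draft_probs.sum(axis=1, keepdims=True)
model_probs = rng.random((6, vocab)).astype("float32")
model_probs /= model_probs.sum(axis=1, keepdims=True)
draft_tokens = rng.integers(0, vocab, 6).astype("int32")
uniform_samples = rng.random(6).astype("float32")

# Same walk as numpy_reference in the test: accept a child with probability
# min(1, p_model / p_draft); on rejection, renormalize the parent's residual
# distribution and try the next sibling.
for b in range(token_tree_parent_ptr.shape[0]):
    parent = token_tree_parent_ptr[b]
    child = token_tree_first_child[parent]
    while child != -1:
        tok = draft_tokens[child]
        if model_probs[parent, tok] / draft_probs[child, tok] >= uniform_samples[child]:
            parent = child
            child = token_tree_first_child[child]
        else:
            residual = np.maximum(model_probs[parent] - draft_probs[child], 0.0)
            model_probs[parent] = residual / residual.sum()
            child = token_tree_next_sibling[child]
    token_tree_parent_ptr[b] = parent  # last accepted node for this sequence

print("accepted node per sequence:", token_tree_parent_ptr)
```

Running this prints, for each sequence, the index of the last accepted draft node, which is what the test asserts against the `token_tree_parent_ptr` written back by the GPU kernel.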