diff --git a/README.md b/README.md
index 9bea5ccc0e..88e3abd07d 100644
--- a/README.md
+++ b/README.md
@@ -50,98 +50,118 @@
-**Scalable.** MLC LLM scales universally on NVIDIA and AMD GPUs, cloud and gaming GPUs. Below
-showcases our single batch decoding performance with prefilling = 1 and decoding = 256.
+## Quick Start
-Performance of 4-bit CodeLlama-34B and Llama2-70B on two NVIDIA RTX 4090 and two AMD Radeon 7900 XTX:
+Here we introduce quick start examples for using MLC LLM through the chat CLI, the Python API, and the REST server.
+We use the 4-bit quantized Llama-3 8B model for demonstration purposes.
+The pre-quantized Llama-3 weights are available at https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC.
+You can also try out the unquantized Llama-3 model by replacing `q4f16_1` with `q0f16` in the examples below.
+Please visit our [documentation](https://llm.mlc.ai/docs/index.html) for a detailed quick start guide and introduction.
-Scaling of fp16 and 4-bit CodeLlama-34 and Llama2-70B on A100-80G-PCIe and A10G-24G-PCIe, up to 8 GPUs:
+### Installation

-## News

+MLC LLM is available via [pip](https://llm.mlc.ai/docs/install/mlc_llm.html#install-mlc-packages).
+We recommend installing it in an isolated conda virtual environment.

-* [10/18/2023] [[Post]](https://blog.mlc.ai/2023/10/19/Scalable-Language-Model-Inference-on-Multiple-NVDIA-AMD-GPUs) Scalable multi-GPU support for CUDA and ROCm are official.
-* [09/02/2023] Prebuilt ROCm 5.7 and CUDA 12.2 package is [available](https://llm.mlc.ai/docs/install/tvm.html#option-1-prebuilt-package).
-* [08/25/2023] CodeLlama support is up.
-* [08/14/2023] [[Post]](https://blog.mlc.ai/2023/08/09/GPU-Accelerated-LLM-on-Orange-Pi) Mali GPU support is up on Orange Pi.
-* [08/09/2023] [[Post]](https://blog.mlc.ai/2023/08/09/Making-AMD-GPUs-competitive-for-LLM-inference) ROCm backend is mature to use.
-* [08/02/2023] [Dockerfile](https://github.com/mlc-ai/llm-perf-bench/) is released for CUDA performance benchmarking.
-* [07/19/2023] Support for Llama2-7B/13B/70B is up.
-* [05/22/2023] [[Post]](https://blog.mlc.ai/2023/05/22/bringing-open-large-language-models-to-consumer-devices) RedPajama support is up.
-* [05/08/2023] [[Post]](https://blog.mlc.ai/2023/05/08/bringing-hardware-accelerated-language-models-to-android-devices) MLC LLM is now available on Android.
-* [05/01/2023] [[Post]](https://blog.mlc.ai/2023/05/01/bringing-accelerated-llm-to-consumer-hardware) MLC LLM is released with Metal, Vulkan and CUDA backends.
-* [04/14/2023] [WebLLM](https://github.com/mlc-ai/web-llm) is released prior to MLC LLM with WebGPU and WebAssembly backend.

+To verify the installation, activate your virtual environment and run

-## Getting Started

+```bash
+python -c "import mlc_llm; print(mlc_llm.__path__)"
+```

-Please visit our [documentation](https://llm.mlc.ai/docs/index.html#getting-started) for detailed instructions.

+You should see the installation path of the MLC LLM Python package.

-## Model Support

+### Chat CLI

-MLC LLM supports a wide range of model architectures and variants. We have the following prebuilts which you can
-use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_models.html) to see the full list, and [Compile Models via MLC](https://llm.mlc.ai/docs/compilation/compile_models.html) to see how to use models not on this list.

+We can try out the chat CLI in MLC LLM with the 4-bit quantized Llama-3 8B model.
-| Architecture | Prebuilt Model Variants |
-| ------------ | ----------------------- |
-| Llama        | Llama-2, Code Llama, Vicuna, WizardLM, WizardMath, OpenOrca Platypus2, FlagAlpha Llama-2 Chinese, georgesung Llama-2 Uncensored |
-| GPT-NeoX     | RedPajama |
-| GPT-J        | |
-| RWKV         | RWKV-raven |
-| MiniGPT      | |
-| GPTBigCode   | WizardCoder |
-| ChatGLM      | |
-| StableLM     | |
-| Mistral      | |
-| Phi          | |
+```bash
+mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC
+```
+
+Running this command for the first time may take 1-2 minutes.
+After that, it launches a chat interface where you can enter your prompt and chat with the model.
+
+```
+You can use the following special commands:
+/help               print the special commands
+/exit               quit the cli
+/stats              print out the latest stats (token/sec)
+/reset              restart a fresh chat
+/set [overrides]    override settings in the generation config. For example,
+                    `/set temperature=0.5;max_gen_len=100;stop=end,stop`
+                    Note: Separate stop words in the `stop` option with commas (,).
+Multi-line input: Use escape+enter to start a new line.
+
+user: What's the meaning of life
+assistant:
+What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life.
+
+The concept of the meaning of life has been debated and...
+```
+
+### Python API
+
+We can run the Llama-3 model through the chat completion Python API of MLC LLM.
+You can save the code below into a Python file and run it.
+
+```python
+from mlc_llm import MLCEngine
+
+# Create engine
+model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"
+engine = MLCEngine(model)
+
+# Run chat completion with the OpenAI-compatible API.
+for response in engine.chat.completions.create(
+    messages=[{"role": "user", "content": "What is the meaning of life?"}],
+    model=model,
+    stream=True,
+):
+    for choice in response.choices:
+        print(choice.delta.content, end="", flush=True)
+print("\n")
+
+engine.terminate()
+```
+
+**The Python API of `mlc_llm.MLCEngine` fully aligns with the OpenAI API**.
+You can use MLCEngine in the same way as
+[OpenAI's Python package](https://github.com/openai/openai-python?tab=readme-ov-file#usage)
+for both synchronous and asynchronous generation.
+
+If you would like to do concurrent asynchronous generation, you can use `mlc_llm.AsyncMLCEngine` instead; a usage sketch follows below.
+
+### REST Server
+
+We can launch a REST server to serve the 4-bit quantized Llama-3 model for OpenAI chat completion requests.
+The server provides full OpenAI API compatibility.
+
+```bash
+mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC
+```
+
+The server is hosted at `http://127.0.0.1:8000` by default, and you can use `--host` and `--port`
+to set a different host and port.
+When the server is ready (showing `INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)`),
+we can open a new shell and send a cURL request with the following command:
+
+```bash
+curl -X POST \
+  -H "Content-Type: application/json" \
+  -d '{
+        "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC",
+        "messages": [
+            {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"}
+        ]
+  }' \
+  http://127.0.0.1:8000/v1/chat/completions
+```

 ## Universal Deployment APIs

 MLC LLM provides multiple sets of APIs across platforms and environments. These include
-* [Python API](https://llm.mlc.ai/docs/deploy/python.html)
+* [Python API](https://llm.mlc.ai/docs/deploy/python_engine.html)
 * [OpenAI-compatible Rest-API](https://llm.mlc.ai/docs/deploy/rest.html)
 * [C++ API](https://llm.mlc.ai/docs/deploy/cli.html)
 * [JavaScript API](https://llm.mlc.ai/docs/deploy/javascript.html) and [Web LLM](https://github.com/mlc-ai/web-llm)
@@ -165,7 +185,7 @@ The underlying techniques of MLC LLM include:
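The Python API section above mentions `mlc_llm.AsyncMLCEngine` for concurrent asynchronous generation without showing it in action. The sketch below is a minimal illustration, assuming the async engine mirrors `MLCEngine`'s OpenAI-style `chat.completions` interface (an awaited call that yields an async iterator of chunks when `stream=True`) and exposes the same `terminate()` method; the exact signatures may differ from the released API.

```python
# Minimal sketch (assumed interface): concurrent streaming requests with AsyncMLCEngine.
import asyncio

from mlc_llm import AsyncMLCEngine

model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"


async def ask(engine: AsyncMLCEngine, question: str) -> str:
    # Collect streamed delta chunks into a single answer string.
    chunks = []
    async for response in await engine.chat.completions.create(
        messages=[{"role": "user", "content": question}],
        model=model,
        stream=True,
    ):
        for choice in response.choices:
            if choice.delta.content is not None:
                chunks.append(choice.delta.content)
    return "".join(chunks)


async def main() -> None:
    engine = AsyncMLCEngine(model)
    # Run two requests concurrently; answers are printed once both complete.
    answers = await asyncio.gather(
        ask(engine, "What is the meaning of life?"),
        ask(engine, "Name three uses of GPUs."),
    )
    for answer in answers:
        print(answer, "\n")
    engine.terminate()


asyncio.run(main())
```

The benefit of the async engine is that several requests can be in flight at once from a single process, rather than being issued one after another.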
References (Click to expand) - + ```bibtex @inproceedings{tensorir, author = {Feng, Siyuan and Hou, Bohan and Jin, Hongyi and Lin, Wuwei and Shao, Junru and Lai, Ruihang and Ye, Zihao and Zheng, Lianmin and Yu, Cody Hao and Yu, Yong and Chen, Tianqi}, diff --git a/android/library/prepare_libs.sh b/android/library/prepare_libs.sh index a06e9f067d..c089927d09 100755 --- a/android/library/prepare_libs.sh +++ b/android/library/prepare_libs.sh @@ -27,6 +27,7 @@ cmake .. \ -DMLC_LLM_INSTALL_STATIC_LIB=ON \ -DCMAKE_SKIP_INSTALL_ALL_DEPENDENCY=ON \ -DUSE_OPENCL=ON \ + -DUSE_OPENCL_ENABLE_HOST_PTR=ON \ -DUSE_CUSTOM_LOGGING=ON \ cmake --build . --target tvm4j_runtime_packed --config release diff --git a/cpp/json_ffi/config.cc b/cpp/json_ffi/config.cc new file mode 100644 index 0000000000..8f5c0e1062 --- /dev/null +++ b/cpp/json_ffi/config.cc @@ -0,0 +1,357 @@ +#include "config.h" + +#include + +#include "../metadata/json_parser.h" + +namespace mlc { +namespace llm { +namespace json_ffi { + +using namespace mlc::llm; + +/****************** Model-defined generation config ******************/ + +TVM_REGISTER_OBJECT_TYPE(ModelDefinedGenerationConfigNode); + +ModelDefinedGenerationConfig::ModelDefinedGenerationConfig(double temperature, double top_p, + double frequency_penalty, + double presence_penalty) { + ObjectPtr n = make_object(); + n->temperature = temperature; + n->top_p = top_p; + n->frequency_penalty = frequency_penalty; + n->presence_penalty = presence_penalty; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("mlc.json_ffi.ModelDefinedGenerationConfig") + .set_body_typed([](double temperature, double top_p, double frequency_penalty, + double presence_penalty) { + return ModelDefinedGenerationConfig(temperature, top_p, frequency_penalty, presence_penalty); + }); + +/****************** Conversation template ******************/ + +std::map PLACEHOLDERS = { + {MessagePlaceholders::SYSTEM, "{system_message}"}, + {MessagePlaceholders::USER, "{user_message}"}, + {MessagePlaceholders::ASSISTANT, "{assistant_message}"}, + {MessagePlaceholders::TOOL, "{tool_message}"}, + {MessagePlaceholders::FUNCTION, "{function_string}"}}; + +MessagePlaceholders MessagePlaceholderFromString(const std::string& role) { + static const std::unordered_map enum_map = { + {"system", MessagePlaceholders::SYSTEM}, {"user", MessagePlaceholders::USER}, + {"assistant", MessagePlaceholders::ASSISTANT}, {"tool", MessagePlaceholders::TOOL}, + {"function", MessagePlaceholders::FUNCTION}, + }; + + return enum_map.at(role); +} + +Conversation::Conversation() + : role_templates({{"user", PLACEHOLDERS[MessagePlaceholders::USER]}, + {"assistant", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]}, + {"tool", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {} + +std::vector Conversation::CheckMessageSeps(std::vector& seps) { + if (seps.size() == 0 || seps.size() > 2) { + throw std::invalid_argument("seps should have size 1 or 2."); + } + return seps; +} + +std::optional> Conversation::AsPrompt(std::string* err) { + // Get the system message + std::string system_msg = system_template; + size_t pos = system_msg.find(PLACEHOLDERS[MessagePlaceholders::SYSTEM]); + if (pos != std::string::npos) { + system_msg.replace(pos, PLACEHOLDERS[MessagePlaceholders::SYSTEM].length(), + this->system_message); + } + + // Get the message strings + std::vector message_list; + std::vector separators = seps; + if (separators.size() == 1) { + separators.push_back(separators[0]); + } + + if (!system_msg.empty()) { + system_msg += separators[0]; + 
message_list.push_back(TextData(system_message)); + } + + for (int i = 0; i < messages.size(); i++) { + std::string role = messages[i].role; + std::optional>> content = + messages[i].content; + if (roles.find(role) == roles.end()) { + *err += "\nRole " + role + " is not supported. "; + return std::nullopt; + } + + std::string separator = separators[role == "assistant"]; // check assistant role + + // If content is empty, add the role and separator + // assistant's turn to generate text + if (!content.has_value()) { + message_list.push_back(TextData(roles[role] + role_empty_sep)); + continue; + } + + std::string message = ""; + std::string role_prefix = ""; + // Do not append role prefix if this is the first message and there + // is already a system message + if (add_role_after_system_message || system_msg.empty() || i != 0) { + role_prefix = roles[role] + role_content_sep; + } + + message += role_prefix; + + for (auto& item : content.value()) { + if (item.find("type") == item.end()) { + *err += "Content item should have a type field"; + return std::nullopt; + } + if (item["type"] == "text") { + if (item.find("text") == item.end()) { + *err += "Content item should have a text field"; + return std::nullopt; + } + // replace placeholder[ROLE] with input message from role + std::string role_text = role_templates[role]; + std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)]; + size_t pos = role_text.find(placeholder); + if (pos != std::string::npos) { + role_text.replace(pos, placeholder.length(), item["text"]); + } + if (use_function_calling.has_value() && use_function_calling.value()) { + // replace placeholder[FUNCTION] with function_string + // this assumes function calling is used for a single request scenario only + if (!function_string.has_value()) { + *err += "Function string is required for function calling"; + return std::nullopt; + } + pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]); + if (pos != std::string::npos) { + role_text.replace(pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length(), + function_string.value()); + } + } + message += role_text; + } else { + *err += "Unsupported content type: " + item["type"]; + return std::nullopt; + } + } + + message += separator; + message_list.push_back(TextData(message)); + } + + return message_list; +} + +std::optional Conversation::FromJSON(const picojson::object& json, std::string* err) { + Conversation conv; + + // name + std::string name; + if (json::ParseJSONField(json, "name", name, err, false)) { + conv.name = name; + } + + std::string system_template; + if (!json::ParseJSONField(json, "system_template", system_template, err, true)) { + return std::nullopt; + } + conv.system_template = system_template; + + std::string system_message; + if (!json::ParseJSONField(json, "system_message", system_message, err, true)) { + return std::nullopt; + } + conv.system_message = system_message; + + picojson::array system_prefix_token_ids_arr; + if (json::ParseJSONField(json, "system_prefix_token_ids", system_prefix_token_ids_arr, err, + false)) { + std::vector system_prefix_token_ids; + for (const auto& token_id : system_prefix_token_ids_arr) { + if (!token_id.is()) { + *err += "system_prefix_token_ids should be an array of integers."; + return std::nullopt; + } + system_prefix_token_ids.push_back(token_id.get()); + } + conv.system_prefix_token_ids = system_prefix_token_ids; + } + + bool add_role_after_system_message; + if (!json::ParseJSONField(json, "add_role_after_system_message", 
add_role_after_system_message, + err, true)) { + return std::nullopt; + } + conv.add_role_after_system_message = add_role_after_system_message; + + picojson::object roles_object; + if (!json::ParseJSONField(json, "roles", roles_object, err, true)) { + return std::nullopt; + } + std::unordered_map roles; + for (const auto& role : roles_object) { + if (!role.second.is()) { + *err += "roles should be a map of string to string."; + return std::nullopt; + } + roles[role.first] = role.second.get(); + } + conv.roles = roles; + + picojson::object role_templates_object; + if (json::ParseJSONField(json, "role_templates", role_templates_object, err, false)) { + for (const auto& role : role_templates_object) { + if (!role.second.is()) { + *err += "role_templates should be a map of string to string."; + return std::nullopt; + } + conv.role_templates[role.first] = role.second.get(); + } + } + + picojson::array messages_arr; + if (!json::ParseJSONField(json, "messages", messages_arr, err, true)) { + return std::nullopt; + } + std::vector messages; + for (const auto& message : messages_arr) { + if (!message.is()) { + *err += "messages should be an array of objects."; + return std::nullopt; + } + picojson::object message_obj = message.get(); + std::string role; + if (!json::ParseJSONField(message_obj, "role", role, err, true)) { + *err += "role field is required in messages."; + return std::nullopt; + } + picojson::array content_arr; + std::vector> content; + if (json::ParseJSONField(message_obj, "content", content_arr, err, false)) { + for (const auto& item : content_arr) { + if (!item.is()) { + *err += "Content item is not an object"; + return std::nullopt; + } + std::unordered_map item_map; + picojson::object item_obj = item.get(); + for (picojson::value::object::const_iterator i = item_obj.begin(); i != item_obj.end(); + ++i) { + item_map[i->first] = i->second.to_str(); + } + content.push_back(item_map); + } + } + messages.push_back({role, content}); + } + conv.messages = messages; + + picojson::array seps_arr; + if (!json::ParseJSONField(json, "seps", seps_arr, err, true)) { + return std::nullopt; + } + std::vector seps; + for (const auto& sep : seps_arr) { + if (!sep.is()) { + *err += "seps should be an array of strings."; + return std::nullopt; + } + seps.push_back(sep.get()); + } + conv.seps = seps; + + std::string role_content_sep; + if (!json::ParseJSONField(json, "role_content_sep", role_content_sep, err, true)) { + return std::nullopt; + } + conv.role_content_sep = role_content_sep; + + std::string role_empty_sep; + if (!json::ParseJSONField(json, "role_empty_sep", role_empty_sep, err, true)) { + return std::nullopt; + } + conv.role_empty_sep = role_empty_sep; + + picojson::array stop_str_arr; + if (!json::ParseJSONField(json, "stop_str", stop_str_arr, err, true)) { + return std::nullopt; + } + std::vector stop_str; + for (const auto& stop : stop_str_arr) { + if (!stop.is()) { + *err += "stop_str should be an array of strings."; + return std::nullopt; + } + stop_str.push_back(stop.get()); + } + conv.stop_str = stop_str; + + picojson::array stop_token_ids_arr; + if (!json::ParseJSONField(json, "stop_token_ids", stop_token_ids_arr, err, true)) { + return std::nullopt; + } + std::vector stop_token_ids; + for (const auto& stop : stop_token_ids_arr) { + if (!stop.is()) { + *err += "stop_token_ids should be an array of integers."; + return std::nullopt; + } + stop_token_ids.push_back(stop.get()); + } + conv.stop_token_ids = stop_token_ids; + + std::string function_string; + if 
(!json::ParseJSONField(json, "function_string", function_string, err, false)) { + conv.function_string = function_string; + } + + bool use_function_calling; + if (json::ParseJSONField(json, "use_function_calling", use_function_calling, err, false)) { + conv.use_function_calling = use_function_calling; + } + + return conv; +} + +std::optional Conversation::FromJSON(const std::string& json_str, std::string* err) { + std::optional json_obj = json::LoadJSONFromString(json_str, err); + if (!json_obj.has_value()) { + return std::nullopt; + } + return Conversation::FromJSON(json_obj.value(), err); +} + +/****************** JSON FFI engine config ******************/ + +TVM_REGISTER_OBJECT_TYPE(JSONFFIEngineConfigNode); + +JSONFFIEngineConfig::JSONFFIEngineConfig( + String conv_template, Map model_generation_cfgs) { + ObjectPtr n = make_object(); + n->conv_template = conv_template; + n->model_generation_cfgs = model_generation_cfgs; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("mlc.json_ffi.JSONFFIEngineConfig") + .set_body_typed([](String conv_template, + Map model_generation_cfgs) { + return JSONFFIEngineConfig(std::move(conv_template), std::move(model_generation_cfgs)); + }); + +} // namespace json_ffi +} // namespace llm +} // namespace mlc diff --git a/cpp/json_ffi/config.h b/cpp/json_ffi/config.h new file mode 100644 index 0000000000..fe5e4e42e2 --- /dev/null +++ b/cpp/json_ffi/config.h @@ -0,0 +1,172 @@ +#ifndef MLC_LLM_JSON_FFI_CONFIG_H +#define MLC_LLM_JSON_FFI_CONFIG_H + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../serve/data.h" +#include "picojson.h" + +using namespace mlc::llm::serve; + +namespace mlc { +namespace llm { +namespace json_ffi { + +/****************** Model-defined generation config ******************/ + +class ModelDefinedGenerationConfigNode : public Object { + public: + double temperature; + double top_p; + double frequency_penalty; + double presence_penalty; + + static constexpr const char* _type_key = "mlc.json_ffi.ModelDefinedGenerationConfig"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(ModelDefinedGenerationConfigNode, Object); +}; + +class ModelDefinedGenerationConfig : public ObjectRef { + public: + explicit ModelDefinedGenerationConfig(double temperature, double top_p, double frequency_penalty, + double presence_penalty); + + TVM_DEFINE_OBJECT_REF_METHODS(ModelDefinedGenerationConfig, ObjectRef, + ModelDefinedGenerationConfigNode); +}; + +/****************** Conversation template ******************/ + +enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION }; + +MessagePlaceholders messagePlaceholderFromString(const std::string& role); + +class Message { + public: + std::string role; + std::optional>> content = std::nullopt; +}; + +/** + * @brief A struct that specifies the convention template of conversation + * and contains the conversation history. + */ +struct Conversation { + // Optional name of the template. + std::optional name = std::nullopt; + + // The system prompt template, it optionally contains the system + // message placeholder, and the placeholder will be replaced with + // the system message below. + std::string system_template; + + // The content of the system prompt (without the template format). + std::string system_message; + + // The system token ids to be prepended at the beginning of tokenized + // generated prompt. 
+ std::optional> system_prefix_token_ids = std::nullopt; + + // Whether or not to append user role and separator after the system message. + // This is mainly for [INST] [/INST] style prompt format + bool add_role_after_system_message = true; + + // The conversation roles + std::unordered_map roles; + + // The roles prompt template, it optionally contains the defaults + // message placeholders and will be replaced by actual content + std::unordered_map role_templates; + + // The conversation history messages. + // Each message is a pair of strings, denoting "(role, content)". + // The content can be None. + std::vector messages; + + // The separators between messages when concatenating into a single prompt. + // List size should be either 1 or 2. + // - When size is 1, the separator will be used between adjacent messages. + // - When size is 2, seps[0] is used after user message, and + // seps[1] is used after assistant message. + std::vector seps; + + // The separator between the role and the content in a message. + std::string role_content_sep; + + // The separator between the role and empty contents. + std::string role_empty_sep; + + // The stop criteria + std::vector stop_str; + std::vector stop_token_ids; + + // Function call fields + // whether using function calling or not, helps check for output message format in API call + std::optional function_string = std::nullopt; + std::optional use_function_calling = false; + + Conversation(); + + /** + * @brief Checks the size of the separators vector. + * This function checks if the size of the separators vector is either 1 or 2. + * If the size is not 1 or 2, it throws an invalid_argument exception. + */ + static std::vector CheckMessageSeps(std::vector& seps); + + /*! + * \brief Create the list of prompts from the messages based on the conversation template. + * When creation fails, errors are dumped to the input error string, and nullopt is returned. + */ + std::optional> AsPrompt(std::string* err); + + /*! + * \brief Create a Conversation instance from the given JSON object. + * When creation fails, errors are dumped to the input error string, and nullopt is returned. + */ + static std::optional FromJSON(const picojson::object& json, std::string* err); + + /*! + * \brief Parse and create a Conversation instance from the given JSON string. + * When creation fails, errors are dumped to the input error string, and nullopt is returned. 
+ */ + static std::optional FromJSON(const std::string& json_str, std::string* err); +}; + +/****************** JSON FFI engine config ******************/ + +class JSONFFIEngineConfigNode : public Object { + public: + String conv_template; + Map model_generation_cfgs; + + static constexpr const char* _type_key = "mlc.json_ffi.JSONFFIEngineConfig"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(JSONFFIEngineConfigNode, Object); +}; + +class JSONFFIEngineConfig : public ObjectRef { + public: + explicit JSONFFIEngineConfig(String conv_template, + Map model_generation_cfgs); + + TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineConfig, ObjectRef, JSONFFIEngineConfigNode); +}; + +} // namespace json_ffi +} // namespace llm +} // namespace mlc + +#endif /* MLC_LLM_JSON_FFI_CONV_TEMPLATE_H */ diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index b02a28ca89..d5fc53b8fa 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -51,33 +51,40 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request // TODO: Check if request_id is present already // inputs - // TODO: Apply conv template - Array inputs; + Conversation conv_template = this->conv_template_; + std::vector messages; for (const auto& message : request.messages) { - if (message.content.has_value()) { - for (const auto& content : message.content.value()) { - if (content.find("type") == content.end()) { - err_ += "Content should have a type field"; - return false; - } - std::string type = content.at("type"); - if (type == "text") { - if (content.find("text") == content.end()) { - err_ += "Content should have a text field"; - return false; - } - std::string text = content.at("text"); - inputs.push_back(TextData(text)); - } else { - err_ += "Content type not supported"; - return false; - } - } + std::string role; + if (message.role == Role::user) { + role = "user"; + } else if (message.role == Role::assistant) { + role = "assistant"; + } else if (message.role == Role::tool) { + role = "tool"; + } else { + role = "system"; } + messages.push_back({role, message.content}); + } + messages.push_back({"assistant", std::nullopt}); + conv_template.messages = messages; + + // check function calling + bool success_check = request.CheckFunctionCalling(conv_template, &err_); + if (!success_check) { + return false; + } + + // get prompt + std::optional> inputs_obj = conv_template.AsPrompt(&err_); + if (!inputs_obj.has_value()) { + return false; } + Array inputs = inputs_obj.value(); // generation_cfg - Optional generation_cfg = GenerationConfig::FromJSON(request_json_str, &err_); + Optional generation_cfg = GenerationConfig::Create( + request_json_str, &err_, conv_template, this->model_generation_cfgs[request.model]); if (!generation_cfg.defined()) { return false; } @@ -103,6 +110,9 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { public: TVM_MODULE_VTABLE_BEGIN("mlc.json_ffi"); TVM_MODULE_VTABLE_ENTRY("init_background_engine", &JSONFFIEngineImpl::InitBackgroundEngine); + TVM_MODULE_VTABLE_ENTRY("reload", &JSONFFIEngineImpl::Reload); + TVM_MODULE_VTABLE_ENTRY("unload", &JSONFFIEngineImpl::Unload); + TVM_MODULE_VTABLE_ENTRY("reset", &JSONFFIEngineImpl::Reset); TVM_MODULE_VTABLE_ENTRY("chat_completion", &JSONFFIEngineImpl::ChatCompletion); TVM_MODULE_VTABLE_ENTRY("abort", &JSONFFIEngineImpl::Abort); TVM_MODULE_VTABLE_ENTRY("get_last_error", 
&JSONFFIEngineImpl::GetLastError); @@ -112,9 +122,20 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); TVM_MODULE_VTABLE_END(); - void InitBackgroundEngine(EngineConfig engine_config, - Optional request_stream_callback, + void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config, + Device device, Optional request_stream_callback, Optional trace_recorder) { + std::optional conv_template = + Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_); + if (!conv_template.has_value()) { + LOG(FATAL) << "Invalid conversation template JSON: " << err_; + } + this->conv_template_ = conv_template.value(); + this->model_generation_cfgs = json_ffi_engine_config->model_generation_cfgs; + + // Todo(mlc-team): decouple InitBackgroundEngine into two functions + // by removing `engine_config` from arguments, after properly handling + // streamers. this->streamer_ = TextStreamer(Tokenizer::FromPath(engine_config->model)); CHECK(request_stream_callback.defined()) @@ -129,10 +150,17 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { }; request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - this->engine_->InitBackgroundEngine( - std::move(engine_config), std::move(request_stream_callback), std::move(trace_recorder)); + this->engine_->InitBackgroundEngine(device, std::move(request_stream_callback), + std::move(trace_recorder)); + this->engine_->Reload(std::move(engine_config)); } + void Reload(EngineConfig engine_config) { this->engine_->Reload(std::move(engine_config)); } + + void Unload() { this->engine_->Unload(); } + + void Reset() { this->engine_->Reset(); } + void RunBackgroundLoop() { this->engine_->RunBackgroundLoop(); } void RunBackgroundStreamBackLoop() { this->engine_->RunBackgroundStreamBackLoop(); } diff --git a/cpp/json_ffi/json_ffi_engine.h b/cpp/json_ffi/json_ffi_engine.h index 83013b5876..d57384abb5 100644 --- a/cpp/json_ffi/json_ffi_engine.h +++ b/cpp/json_ffi/json_ffi_engine.h @@ -12,6 +12,7 @@ #include "../serve/threaded_engine.h" #include "../streamer.h" +#include "config.h" #include "openai_api_protocol.h" namespace mlc { @@ -47,6 +48,8 @@ class JSONFFIEngine { std::string err_; PackedFunc request_stream_callback_; TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request + Conversation conv_template_; + Map model_generation_cfgs; }; } // namespace json_ffi diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc index 41378fc3e0..13f4b140ce 100644 --- a/cpp/json_ffi/openai_api_protocol.cc +++ b/cpp/json_ffi/openai_api_protocol.cc @@ -11,14 +11,166 @@ namespace mlc { namespace llm { namespace json_ffi { -std::optional ChatCompletionMessage::FromJSON(const picojson::value& json, - std::string* err) { - if (!json.is()) { - *err += "Input is not a valid JSON object"; +std::string generate_uuid_string(size_t length) { + auto randchar = []() -> char { + const char charset[] = + "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[rand() % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; +} + +std::optional ChatFunction::FromJSON(const picojson::object& json_obj, + std::string* err) { + ChatFunction chatFunc; + + // description (optional) + std::string description; + if 
(json::ParseJSONField(json_obj, "description", description, err, false)) { + chatFunc.description = description; + } + + // name + std::string name; + if (!json::ParseJSONField(json_obj, "name", name, err, true)) { + return std::nullopt; + } + chatFunc.name = name; + + // parameters + picojson::object parameters_obj; + if (!json::ParseJSONField(json_obj, "parameters", parameters_obj, err, true)) { + return std::nullopt; + } + std::unordered_map parameters; + for (picojson::value::object::const_iterator i = parameters_obj.begin(); + i != parameters_obj.end(); ++i) { + parameters[i->first] = i->second.to_str(); + } + chatFunc.parameters = parameters; + + return chatFunc; +} + +picojson::object ChatFunction::ToJSON() const { + picojson::object obj; + if (this->description.has_value()) { + obj["description"] = picojson::value(this->description.value()); + } + obj["name"] = picojson::value(this->name); + picojson::object parameters_obj; + for (const auto& pair : this->parameters) { + parameters_obj[pair.first] = picojson::value(pair.second); + } + obj["parameters"] = picojson::value(parameters_obj); + return obj; +} + +std::optional ChatTool::FromJSON(const picojson::object& json_obj, std::string* err) { + ChatTool chatTool; + + // function + picojson::object function_obj; + if (!json::ParseJSONField(json_obj, "function", function_obj, err, true)) { + return std::nullopt; + } + + std::optional function = ChatFunction::FromJSON(function_obj, err); + if (!function.has_value()) { return std::nullopt; } - picojson::object json_obj = json.get(); + chatTool.function = function.value(); + + return chatTool; +} +picojson::object ChatTool::ToJSON() const { + picojson::object obj; + obj["type"] = picojson::value("function"); + obj["function"] = picojson::value(this->function.ToJSON()); + return obj; +} + +std::optional ChatFunctionCall::FromJSON(const picojson::object& json_obj, + std::string* err) { + ChatFunctionCall chatFuncCall; + + // name + std::string name; + if (!json::ParseJSONField(json_obj, "name", name, err, true)) { + return std::nullopt; + } + chatFuncCall.name = name; + + // arguments + picojson::object arguments_obj; + if (json::ParseJSONField(json_obj, "arguments", arguments_obj, err, false)) { + std::unordered_map arguments; + for (picojson::value::object::const_iterator i = arguments_obj.begin(); + i != arguments_obj.end(); ++i) { + arguments[i->first] = i->second.to_str(); + } + chatFuncCall.arguments = arguments; + } + + return chatFuncCall; +} + +picojson::object ChatFunctionCall::ToJSON() const { + picojson::object obj; + picojson::object arguments_obj; + if (this->arguments.has_value()) { + for (const auto& pair : this->arguments.value()) { + arguments_obj[pair.first] = picojson::value(pair.second); + } + obj["arguments"] = picojson::value(arguments_obj); + } + + obj["name"] = picojson::value(this->name); + return obj; +} + +std::optional ChatToolCall::FromJSON(const picojson::object& json_obj, + std::string* err) { + ChatToolCall chatToolCall; + + // function + picojson::object function_obj; + if (!json::ParseJSONField(json_obj, "function", function_obj, err, true)) { + return std::nullopt; + } + + std::optional function = ChatFunctionCall::FromJSON(function_obj, err); + if (!function.has_value()) { + return std::nullopt; + }; + chatToolCall.function = function.value(); + + // overwrite default id + std::string id; + if (!json::ParseJSONField(json_obj, "id", id, err, false)) { + return std::nullopt; + } + chatToolCall.id = id; + + return chatToolCall; +} + +picojson::object 
ChatToolCall::ToJSON() const { + picojson::object obj; + obj["id"] = picojson::value(this->id); + obj["function"] = picojson::value(this->function.ToJSON()); + obj["type"] = picojson::value("function"); + return obj; +} + +std::optional ChatCompletionMessage::FromJSON( + const picojson::object& json_obj, std::string* err) { ChatCompletionMessage message; // content @@ -65,7 +217,30 @@ std::optional ChatCompletionMessage::FromJSON(const picoj message.name = name; } - // TODO: tool_calls and tool_call_id + // tool calls + picojson::array tool_calls_arr; + if (json::ParseJSONField(json_obj, "tool_calls", tool_calls_arr, err, false)) { + std::vector tool_calls; + for (const auto& item : tool_calls_arr) { + if (!item.is()) { + *err += "Chat Tool Call item is not an object"; + return std::nullopt; + } + picojson::object item_obj = item.get(); + std::optional tool_call = ChatToolCall::FromJSON(item_obj, err); + if (!tool_call.has_value()) { + return std::nullopt; + }; + tool_calls.push_back(tool_call.value()); + } + message.tool_calls = tool_calls; + } + + // tool call id + std::string tool_call_id; + if (json::ParseJSONField(json_obj, "tool_call_id", tool_call_id, err, false)) { + message.tool_call_id = tool_call_id; + } return message; } @@ -81,7 +256,8 @@ std::optional ChatCompletionRequest::FromJSON( } std::vector messages; for (const auto& item : messages_arr) { - std::optional message = ChatCompletionMessage::FromJSON(item, err); + picojson::object item_obj = item.get(); + std::optional message = ChatCompletionMessage::FromJSON(item_obj, err); if (!message.has_value()) { return std::nullopt; } @@ -108,6 +284,32 @@ std::optional ChatCompletionRequest::FromJSON( request.presence_penalty = presence_penalty; } + // tool_choice + std::string tool_choice = "auto"; + request.tool_choice = tool_choice; + if (json::ParseJSONField(json_obj, "tool_choice", tool_choice, err, false)) { + request.tool_choice = tool_choice; + } + + // tools + picojson::array tools_arr; + if (json::ParseJSONField(json_obj, "tools", tools_arr, err, false)) { + std::vector tools; + for (const auto& item : tools_arr) { + if (!item.is()) { + *err += "Chat Tool item is not an object"; + return std::nullopt; + } + picojson::object item_obj = item.get(); + std::optional tool = ChatTool::FromJSON(item_obj, err); + if (!tool.has_value()) { + return std::nullopt; + }; + tools.push_back(tool.value()); + } + request.tools = tools; + } + // TODO: Other parameters return request; @@ -122,7 +324,7 @@ std::optional ChatCompletionRequest::FromJSON(const std:: return ChatCompletionRequest::FromJSON(json_obj.value(), err); } -picojson::object ChatCompletionMessage::ToJSON() { +picojson::object ChatCompletionMessage::ToJSON() const { picojson::object obj; picojson::array content_arr; for (const auto& item : this->content.value()) { @@ -142,13 +344,57 @@ picojson::object ChatCompletionMessage::ToJSON() { } else if (this->role == Role::tool) { obj["role"] = picojson::value("tool"); } - if (name.has_value()) { - obj["name"] = picojson::value(name.value()); + if (this->name.has_value()) { + obj["name"] = picojson::value(this->name.value()); + } + if (this->tool_call_id.has_value()) { + obj["tool_call_id"] = picojson::value(this->tool_call_id.value()); + } + if (this->tool_calls.has_value()) { + picojson::array tool_calls_arr; + for (const auto& tool_call : this->tool_calls.value()) { + tool_calls_arr.push_back(picojson::value(tool_call.ToJSON())); + } + obj["tool_calls"] = picojson::value(tool_calls_arr); } return obj; } -picojson::object 
ChatCompletionResponseChoice::ToJSON() { +bool ChatCompletionRequest::CheckFunctionCalling(Conversation& conv_template, std::string* err) { + if (!tools.has_value() || (tool_choice.has_value() && tool_choice.value() == "none")) { + conv_template.use_function_calling = false; + return true; + } + std::vector tools_ = tools.value(); + std::string tool_choice_ = tool_choice.value(); + + // TODO: support with tool choice as dict + for (const auto& tool : tools_) { + if (tool.function.name == tool_choice_) { + conv_template.use_function_calling = true; + picojson::value function_str(tool.function.ToJSON()); + conv_template.function_string = function_str.serialize(); + return true; + } + } + + if (tool_choice_ != "auto") { + *err += "Invalid tool_choice value: " + tool_choice_; + return false; + } + + picojson::array function_list; + for (const auto& tool : tools_) { + function_list.push_back(picojson::value(tool.function.ToJSON())); + } + + conv_template.use_function_calling = true; + picojson::value function_list_json(function_list); + conv_template.function_string = function_list_json.serialize(); + return true; +}; + +picojson::object ChatCompletionResponseChoice::ToJSON() const { picojson::object obj; if (!this->finish_reason.has_value()) { obj["finish_reason"] = picojson::value(); @@ -168,7 +414,7 @@ picojson::object ChatCompletionResponseChoice::ToJSON() { return obj; } -picojson::object ChatCompletionStreamResponseChoice::ToJSON() { +picojson::object ChatCompletionStreamResponseChoice::ToJSON() const { picojson::object obj; if (!this->finish_reason.has_value()) { obj["finish_reason"] = picojson::value(); @@ -189,11 +435,11 @@ picojson::object ChatCompletionStreamResponseChoice::ToJSON() { return obj; } -picojson::object ChatCompletionResponse::ToJSON() { +picojson::object ChatCompletionResponse::ToJSON() const { picojson::object obj; obj["id"] = picojson::value(this->id); picojson::array choices_arr; - for (auto& choice : this->choices) { + for (const auto& choice : this->choices) { choices_arr.push_back(picojson::value(choice.ToJSON())); } obj["choices"] = picojson::value(choices_arr); @@ -204,11 +450,11 @@ picojson::object ChatCompletionResponse::ToJSON() { return obj; } -picojson::object ChatCompletionStreamResponse::ToJSON() { +picojson::object ChatCompletionStreamResponse::ToJSON() const { picojson::object obj; obj["id"] = picojson::value(this->id); picojson::array choices_arr; - for (auto& choice : this->choices) { + for (const auto& choice : this->choices) { choices_arr.push_back(picojson::value(choice.ToJSON())); } obj["choices"] = picojson::value(choices_arr); diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h index 1579b5f337..429050da3c 100644 --- a/cpp/json_ffi/openai_api_protocol.h +++ b/cpp/json_ffi/openai_api_protocol.h @@ -8,10 +8,12 @@ #include #include +#include #include #include #include +#include "config.h" #include "picojson.h" namespace mlc { @@ -22,7 +24,8 @@ enum class Role { system, user, assistant, tool }; enum class Type { text, json_object, function }; enum class FinishReason { stop, length, tool_calls, error }; -// TODO: Implement the following class +std::string generate_uuid_string(size_t length); + class ChatFunction { public: std::optional description = std::nullopt; @@ -30,32 +33,37 @@ class ChatFunction { std::unordered_map parameters; // Assuming parameters are string key-value pairs - static std::optional FromJSON(const picojson::value& json, std::string* err); + static std::optional FromJSON(const picojson::object& 
json, std::string* err); + picojson::object ToJSON() const; }; -// TODO: Implement the following class class ChatTool { public: Type type = Type::function; ChatFunction function; - static std::optional FromJSON(const picojson::value& json, std::string* err); + static std::optional FromJSON(const picojson::object& json, std::string* err); + picojson::object ToJSON() const; }; -// TODO: Implement the following class class ChatFunctionCall { public: std::string name; std::optional> arguments = std::nullopt; // Assuming arguments are string key-value pairs + + static std::optional FromJSON(const picojson::object& json, std::string* err); + picojson::object ToJSON() const; }; -// TODO: Implement the following class class ChatToolCall { public: - std::string id; // TODO: python code initializes this to an random string + std::string id = "call_" + generate_uuid_string(8); Type type = Type::function; ChatFunctionCall function; + + static std::optional FromJSON(const picojson::object& json, std::string* err); + picojson::object ToJSON() const; }; class ChatCompletionMessage { @@ -64,12 +72,12 @@ class ChatCompletionMessage { std::nullopt; // Assuming content is a list of string key-value pairs Role role; std::optional name = std::nullopt; - std::optional> tool_calls = std::nullopt; // TODO: Implement this - std::optional tool_call_id = std::nullopt; // TODO: Implement this + std::optional> tool_calls = std::nullopt; + std::optional tool_call_id = std::nullopt; - static std::optional FromJSON(const picojson::value& json, + static std::optional FromJSON(const picojson::object& json, std::string* err); - picojson::object ToJSON(); + picojson::object ToJSON() const; }; class RequestResponseFormat { @@ -82,8 +90,8 @@ class ChatCompletionRequest { public: std::vector messages; std::string model; - double frequency_penalty = 0.0; - double presence_penalty = 0.0; + std::optional frequency_penalty = std::nullopt; + std::optional presence_penalty = std::nullopt; bool logprobs = false; int top_logprobs = 0; std::optional> logit_bias = std::nullopt; @@ -92,8 +100,8 @@ class ChatCompletionRequest { std::optional seed = std::nullopt; std::optional> stop = std::nullopt; bool stream = false; - double temperature = 1.0; - double top_p = 1.0; + std::optional temperature = std::nullopt; + std::optional top_p = std::nullopt; std::optional> tools = std::nullopt; std::optional tool_choice = std::nullopt; std::optional user = std::nullopt; @@ -113,6 +121,7 @@ class ChatCompletionRequest { static std::optional FromJSON(const std::string& json_str, std::string* err); + bool CheckFunctionCalling(Conversation& conv_template, std::string* err); // TODO: check_penalty_range, check_logit_bias, check_logprobs }; @@ -123,7 +132,7 @@ class ChatCompletionResponseChoice { ChatCompletionMessage message; // TODO: logprobs - picojson::object ToJSON(); + picojson::object ToJSON() const; }; class ChatCompletionStreamResponseChoice { @@ -133,7 +142,7 @@ class ChatCompletionStreamResponseChoice { ChatCompletionMessage delta; // TODO: logprobs - picojson::object ToJSON(); + picojson::object ToJSON() const; }; class ChatCompletionResponse { @@ -146,7 +155,7 @@ class ChatCompletionResponse { std::string object = "chat.completion"; // TODO: usage_info - picojson::object ToJSON(); + picojson::object ToJSON() const; }; class ChatCompletionStreamResponse { @@ -158,7 +167,7 @@ class ChatCompletionStreamResponse { std::string system_fingerprint; std::string object = "chat.completion.chunk"; - picojson::object ToJSON(); + picojson::object ToJSON() 
const; }; } // namespace json_ffi diff --git a/cpp/metadata/json_parser.h b/cpp/metadata/json_parser.h index f6ff10e1ac..99a284fc42 100644 --- a/cpp/metadata/json_parser.h +++ b/cpp/metadata/json_parser.h @@ -149,6 +149,22 @@ inline ValueType Lookup(const picojson::object& json, const std::string& key) { return it->second.get(); } +template +inline ValueType LookupOrDefault(const picojson::object& json, const std::string& key, + const ValueType& default_value) { + auto it = json.find(key); + if (it == json.end()) { + return default_value; + } + + if (it->second.is()) { + return default_value; + } + + CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; + return it->second.get(); +} + template inline ValueType Lookup(const picojson::array& json, int index) { CHECK(index < json.size()) << "IndexError: json::array index out of range"; diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 5d647ec532..3bb809ad67 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -161,19 +161,35 @@ GenerationConfig::GenerationConfig(String config_json_str) { data_ = std::move(n); } -Optional GenerationConfig::FromJSON(const std::string& json_str, - std::string* err) { - std::optional json_obj = json::LoadJSONFromString(json_str, err); - if (!err->empty() || !json_obj.has_value()) { +Optional GenerationConfig::Create( + const std::string& json_str, std::string* err, const Conversation& conv_template, + const ModelDefinedGenerationConfig& model_defined_gen_config) { + std::optional optional_json_obj = json::LoadJSONFromString(json_str, err); + if (!err->empty() || !optional_json_obj.has_value()) { return NullOpt; } + picojson::object& json_obj = optional_json_obj.value(); ObjectPtr n = make_object(); - // TODO(mlc-team): Pass the parameters from `json_obj` to `n`. 
+ n->temperature = + json::LookupOrDefault(json_obj, "temperature", model_defined_gen_config->temperature); + n->top_p = json::LookupOrDefault(json_obj, "top_p", model_defined_gen_config->top_p); + n->frequency_penalty = json::LookupOrDefault(json_obj, "frequency_penalty", + model_defined_gen_config->frequency_penalty); + n->presence_penalty = json::LookupOrDefault(json_obj, "presence_penalty", + model_defined_gen_config->presence_penalty); + n->logprobs = json::LookupOrDefault(json_obj, "logprobs", false); + n->top_logprobs = static_cast(json::LookupOrDefault(json_obj, "top_logprobs", 0)); + n->ignore_eos = json::LookupOrDefault(json_obj, "ignore_eos", false); - if (!err->empty()) { - return NullOpt; + // Copy stop str from conversation template to generation config + for (auto& stop_str : conv_template.stop_str) { + n->stop_strs.push_back(stop_str); + } + for (auto& stop_token_id : conv_template.stop_token_ids) { + n->stop_token_ids.push_back(stop_token_id); } + GenerationConfig gen_config; gen_config.data_ = std::move(n); return gen_config; @@ -228,37 +244,85 @@ String GenerationConfigNode::AsJSONString() const { TVM_REGISTER_OBJECT_TYPE(EngineConfigNode); EngineConfig::EngineConfig(String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, DLDevice device, - int kv_cache_page_size, int max_num_sequence, - int max_total_sequence_length, int max_single_sequence_length, - int prefill_chunk_size, SpeculativeMode speculative_mode, - int spec_draft_length) { + Array additional_model_lib_paths, int kv_cache_page_size, + int max_num_sequence, int max_total_sequence_length, + int max_single_sequence_length, int prefill_chunk_size, + int max_history_size, KVStateKind kv_state_kind, + SpeculativeMode speculative_mode, int spec_draft_length) { ObjectPtr n = make_object(); n->model = std::move(model); n->model_lib_path = std::move(model_lib_path); n->additional_models = std::move(additional_models); n->additional_model_lib_paths = std::move(additional_model_lib_paths); - n->device = device; n->kv_cache_page_size = kv_cache_page_size; n->max_num_sequence = max_num_sequence; n->max_total_sequence_length = max_total_sequence_length; n->max_single_sequence_length = max_single_sequence_length; n->prefill_chunk_size = prefill_chunk_size; + n->max_history_size = max_history_size; + n->kv_state_kind = kv_state_kind; n->spec_draft_length = spec_draft_length; n->speculative_mode = speculative_mode; data_ = std::move(n); } +EngineConfig EngineConfig::FromJSONString(const std::string& json_str) { + picojson::value config_json; + std::string err = picojson::parse(config_json, json_str); + if (!err.empty()) { + LOG(FATAL) << err; + } + + // Get json fields. 
+ picojson::object config = config_json.get(); + String model = json::Lookup(config, "model"); + String model_lib_path = json::Lookup(config, "model_lib_path"); + std::vector additional_models; + std::vector additional_model_lib_paths; + int kv_cache_page_size = json::Lookup(config, "kv_cache_page_size"); + int max_num_sequence = json::Lookup(config, "max_num_sequence"); + int max_total_sequence_length = json::Lookup(config, "max_total_sequence_length"); + int max_single_sequence_length = json::Lookup(config, "max_single_sequence_length"); + int prefill_chunk_size = json::Lookup(config, "prefill_chunk_size"); + int max_history_size = json::Lookup(config, "max_history_size"); + KVStateKind kv_state_kind = + static_cast(json::Lookup(config, "kv_state_kind")); + SpeculativeMode speculative_mode = + static_cast(json::Lookup(config, "speculative_mode")); + int spec_draft_length = json::Lookup(config, "spec_draft_length"); + + picojson::array additional_models_arr = + json::Lookup(config, "additional_models"); + picojson::array additional_model_lib_paths_arr = + json::Lookup(config, "additional_model_lib_paths"); + CHECK_EQ(additional_models_arr.size(), additional_model_lib_paths_arr.size()) + << "The number of additional model lib paths does not match the number of additional models"; + int num_additional_models = additional_models_arr.size(); + additional_models.reserve(num_additional_models); + additional_model_lib_paths.reserve(num_additional_models); + for (int i = 0; i < num_additional_models; ++i) { + additional_models.push_back(json::Lookup(additional_models_arr, i)); + additional_model_lib_paths.push_back( + json::Lookup(additional_model_lib_paths_arr, i)); + } + + return EngineConfig(std::move(model), std::move(model_lib_path), additional_models, + additional_model_lib_paths, kv_cache_page_size, max_num_sequence, + max_total_sequence_length, max_single_sequence_length, prefill_chunk_size, + max_history_size, kv_state_kind, speculative_mode, spec_draft_length); +} + TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig") .set_body_typed([](String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, DLDevice device, - int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, - int max_single_sequence_length, int prefill_chunk_size, int speculative_mode, - int spec_draft_length) { + Array additional_model_lib_paths, int kv_cache_page_size, + int max_num_sequence, int max_total_sequence_length, + int max_single_sequence_length, int prefill_chunk_size, int max_history_size, + int kv_state_kind, int speculative_mode, int spec_draft_length) { return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models), - std::move(additional_model_lib_paths), device, kv_cache_page_size, + std::move(additional_model_lib_paths), kv_cache_page_size, max_num_sequence, max_total_sequence_length, max_single_sequence_length, - prefill_chunk_size, SpeculativeMode(speculative_mode), spec_draft_length); + prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind), + SpeculativeMode(speculative_mode), spec_draft_length); }); } // namespace serve diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 404566fe2c..fd76dd49f0 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -11,12 +11,15 @@ #include +#include "../json_ffi/config.h" + namespace mlc { namespace llm { namespace serve { using namespace tvm; using namespace tvm::runtime; +using namespace mlc::llm::json_ffi; /****************** GenerationConfig 
******************/ @@ -60,10 +63,13 @@ class GenerationConfig : public ObjectRef { explicit GenerationConfig(String config_json_str); /*! - * \brief Parse the generation config from the given JSON string. - * When parsing fails, errors are dumped to the input error string, and NullOpt is returned. + * \brief Create a generation config from a ChatCompletionRequest. + * If the request does not contain a generation config, the model-defined + * generation config will be used. */ - static Optional FromJSON(const std::string& json_str, std::string* err); + static Optional Create( + const std::string& json_str, std::string* err, const Conversation& conv_template, + const ModelDefinedGenerationConfig& model_defined_gen_config); TVM_DEFINE_OBJECT_REF_METHODS(GenerationConfig, ObjectRef, GenerationConfigNode); }; @@ -80,6 +86,12 @@ enum class SpeculativeMode : int { kEagle = 2, }; +/*! \brief The kind of cache. */ +enum KVStateKind { + kAttention = 0, + kRNNState = 1, +}; + /*! \brief The configuration of engine execution config. */ class EngineConfigNode : public Object { public: @@ -94,11 +106,6 @@ class EngineConfigNode : public Object { /*! \brief The path to the additional models' libraries. */ Array additional_model_lib_paths; - /*************** Device ***************/ - - /*! \brief The device where the models run. */ - DLDevice device; - /*************** KV cache config and engine capacities ***************/ /*! \brief The number of consecutive tokens handled in each page in paged KV cache. */ @@ -117,6 +124,10 @@ class EngineConfigNode : public Object { int max_single_sequence_length; /*! \brief The maximum total sequence length in a prefill. */ int prefill_chunk_size; + /*! \brief The maximum history size for RNN state. KV cache does not need this. */ + int max_history_size; + /*! \brief The kind of cache. Whether it's KV cache or RNN state. */ + KVStateKind kv_state_kind; /*************** Speculative decoding ***************/ @@ -136,11 +147,15 @@ class EngineConfigNode : public Object { class EngineConfig : public ObjectRef { public: explicit EngineConfig(String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, DLDevice device, - int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, + Array additional_model_lib_paths, int kv_cache_page_size, + int max_num_sequence, int max_total_sequence_length, int max_single_sequence_length, int prefill_chunk_size, + int max_history_size, KVStateKind kv_state_kind, SpeculativeMode speculative_mode, int spec_draft_length); + /*! \brief Create EngineConfig from JSON string. */ + static EngineConfig FromJSONString(const std::string& json_str); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode); }; diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 85d1c66c2d..d82c886355 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -44,7 +45,8 @@ class EngineImpl : public Engine { public: /********************** Engine Management **********************/ - explicit EngineImpl(EngineConfig engine_config, Optional request_stream_callback, + explicit EngineImpl(EngineConfig engine_config, DLDevice device, + Optional request_stream_callback, Optional trace_recorder) { // Step 1. 
Initialize metadata and singleton states inside the engine this->estate_->Reset(); @@ -62,14 +64,24 @@ class EngineImpl : public Engine { this->models_.clear(); this->model_workspaces_.clear(); - auto f_create_model = [this, &engine_config, &trace_recorder](const String& model_path, - const String& model_lib_path) { - Model model = Model::Create(model_lib_path, std::move(model_path), engine_config->device, - engine_config->max_num_sequence, + std::vector model_configs; + model_configs.push_back(Model::LoadModelConfig(engine_config->model)); + for (const auto& model_path : engine_config->additional_models) { + model_configs.push_back(Model::LoadModelConfig(model_path)); + } + + Optional session = CreateDiscoSession(model_configs, device); + + auto f_create_model = [this, &engine_config, &device, &trace_recorder, &model_configs, + &session](const String& model_path, const String& model_lib_path, + int model_index) { + Model model = Model::Create(model_lib_path, std::move(model_path), model_configs[model_index], + device, engine_config->max_num_sequence, session, /*trace_enabled=*/trace_recorder.defined()); model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence, engine_config->max_total_sequence_length, - engine_config->prefill_chunk_size); + engine_config->prefill_chunk_size, engine_config->max_history_size, + engine_config->kv_state_kind); CHECK_GE(model->GetMaxWindowSize(), engine_config->max_single_sequence_length) << "The window size of the model, " << model->GetMaxWindowSize() << ", is smaller than the pre-defined max single sequence length, " @@ -79,18 +91,18 @@ class EngineImpl : public Engine { ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); }; - f_create_model(engine_config->model, engine_config->model_lib_path); + f_create_model(engine_config->model, engine_config->model_lib_path, /*model_index=*/0); CHECK_EQ(engine_config->additional_models.size(), engine_config->additional_model_lib_paths.size()) << "The additional model and lib path list has mismatched size."; for (int i = 0; i < static_cast(engine_config->additional_models.size()); ++i) { f_create_model(engine_config->additional_models[i], - engine_config->additional_model_lib_paths[i]); + engine_config->additional_model_lib_paths[i], /*model_index=*/i + 1); } int max_num_tokens = engine_config->max_num_sequence; if (engine_config->speculative_mode != SpeculativeMode::kDisable) { - max_num_tokens *= engine_config->spec_draft_length; + max_num_tokens *= engine_config->spec_draft_length + 1; } LogitProcessor logit_processor = this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); @@ -102,18 +114,18 @@ class EngineImpl : public Engine { ICHECK_GT(this->models_.size(), 1U); switch (engine_config->speculative_mode) { case SpeculativeMode::kEagle: - this->actions_ = { - EngineAction::EagleNewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - engine_config, // - this->trace_recorder_), - EngineAction::EagleBatchDraft(this->models_, logit_processor, sampler, - this->model_workspaces_, this->trace_recorder_), - EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, - this->model_workspaces_, engine_config, - this->trace_recorder_)}; + this->actions_ = {EngineAction::EagleNewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + engine_config, // + this->trace_recorder_), + EngineAction::EagleBatchDraft( + this->models_, logit_processor, 
sampler, this->model_workspaces_, + this->trace_recorder_, engine_config->spec_draft_length), + EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, + this->model_workspaces_, engine_config, + this->trace_recorder_)}; break; default: this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // @@ -143,6 +155,7 @@ class EngineImpl : public Engine { } void Reset() final { + AbortAllRequests(); estate_->Reset(); for (Model model : models_) { model->Reset(); @@ -167,7 +180,8 @@ class EngineImpl : public Engine { request = Request::FromUntokenized(request, tokenizer_); ICHECK_NE(request->input_total_length, -1); - if (request->input_total_length >= engine_config_->max_single_sequence_length) { + if (request->input_total_length >= engine_config_->max_single_sequence_length && + request_stream_callback_.defined()) { // If the request input length exceeds the maximum allowed single sequence length, // invoke callback and do not process the request. Array output{RequestStreamOutput( @@ -240,6 +254,28 @@ class EngineImpl : public Engine { // The request to abort is in waiting queue estate_->waiting_queue.erase(it_waiting); } + + // Send a callback to notice the abortion. + if (request_stream_callback_.defined()) { + Array output{RequestStreamOutput( + request_id, std::vector(request->generation_cfg->n), + Optional>>(), + std::vector>(request->generation_cfg->n, String("abort")))}; + request_stream_callback_.value()(std::move(output)); + } + } + + void AbortAllRequests() final { + // - Collect all the request ids. + std::vector request_ids; + request_ids.reserve(estate_->request_states.size()); + for (const auto& kv : estate_->request_states) { + request_ids.push_back(kv.first); + } + // - Abort all the requests. + for (const String& request_id : request_ids) { + AbortRequest(request_id); + } } /*********************** Engine Action ***********************/ @@ -261,6 +297,51 @@ class EngineImpl : public Engine { "action (e.g. prefill, decode, etc.) but it does not."; } + /************** Utility Functions **************/ + Optional CreateDiscoSession(std::vector model_configs, Device device) { + const auto& base_model_config = model_configs[0]; + + auto f_get_num_shards = [](const picojson::object& model_config) -> int { + constexpr auto kNumShardsKey = "tensor_parallel_shards"; + if (model_config.count(kNumShardsKey)) { + const auto& val = model_config.at(kNumShardsKey); + CHECK(val.is()); + return static_cast(val.get()); + } else { + LOG(FATAL) << "Key \"tensor_parallel_shards\" not found."; + } + throw; + }; + + int num_shards = std::transform_reduce( + model_configs.begin(), model_configs.end(), 1, [](int a, int b) { return std::max(a, b); }, + f_get_num_shards); + Optional session = NullOpt; + if (num_shards > 1) { + constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool"; + if (Registry::Get(f_create_process_pool) == nullptr) { + LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. " + << "Multi-GPU inference depends on MLC LLM Python API to launch process."; + } + std::string ccl; + if (device.device_type == kDLCUDA) { + ccl = "nccl"; + } else if (device.device_type == kDLROCM) { + ccl = "rccl"; + } else { + LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type) + << " is not supported. 
Currently, only NCCL and RCCL are integrated."; + } + std::vector device_ids(num_shards); + for (int i = 0; i < num_shards; ++i) { + device_ids[i] = i; + } + session = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); + session.value()->InitCCL(ccl, ShapeTuple(device_ids)); + } + return session; + } + /************** Debug/Profile **************/ void DebugCallFuncOnAllAllWorker(const String& func_name) final { @@ -314,10 +395,11 @@ class EngineImpl : public Engine { Optional trace_recorder_; }; -std::unique_ptr Engine::Create(EngineConfig engine_config, +std::unique_ptr Engine::Create(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder) { - return std::make_unique(std::move(engine_config), std::move(request_stream_callback), + return std::make_unique(std::move(engine_config), device, + std::move(request_stream_callback), std::move(trace_recorder)); } @@ -343,10 +425,10 @@ class EngineModule : public ModuleNode { TVM_MODULE_VTABLE_END(); /*! \brief Initialize the engine with config and other fields. */ - void Init(EngineConfig engine_config, Optional request_stream_callback, + void Init(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder) { - this->engine_ = Engine::Create(std::move(engine_config), std::move(request_stream_callback), - std::move(trace_recorder)); + this->engine_ = Engine::Create(std::move(engine_config), device, + std::move(request_stream_callback), std::move(trace_recorder)); } /*! \brief Construct an EngineModule. */ static tvm::runtime::Module Create() { return Module(make_object()); } diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index fc5e4205ae..2fc0a4d730 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -51,11 +51,12 @@ class Engine { /*! * \brief Create an engine in unique pointer. * \param engine_config The engine config. + * \param device The device where the run models. * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. * \return The created Engine in pointer. */ - static std::unique_ptr Create(EngineConfig engine_config, + static std::unique_ptr Create(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder); @@ -82,6 +83,9 @@ class Engine { /*! \brief Abort the input request (specified by id string) from engine. */ virtual void AbortRequest(const String& request_id) = 0; + /*! \brief Abort all requests from the engine. */ + virtual void AbortAllRequests() = 0; + /*********************** Engine Action ***********************/ /*! 
diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h index aea455a1be..78e3937d0b 100644 --- a/cpp/serve/engine_actions/action_commons.h +++ b/cpp/serve/engine_actions/action_commons.h @@ -47,7 +47,7 @@ void ActionStepPostProcess(Array requests, EngineState estate, Array sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); - std::vector sample_results = sampler_->BatchSampleTokens( + std::vector sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP( probs_on_device, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), num_rsentries); diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index b56f7fa9b6..c1ddeb6e4e 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -116,8 +116,10 @@ class BatchDraftActionObj : public EngineActionObj { std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); std::vector prob_dist; - std::vector sample_results = sampler_->BatchSampleTokens( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 6f38292ba3..42c9bbe018 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -7,6 +7,7 @@ #include #include +#include #include "../../random.h" #include "../config.h" @@ -42,8 +43,8 @@ class BatchVerifyActionObj : public EngineActionObj { return {}; } - const auto& [rsentries, draft_lengths, total_draft_length] = GetDraftsToVerify(estate); - ICHECK_EQ(rsentries.size(), draft_lengths.size()); + const auto& [rsentries, verify_lengths, total_verify_length] = GetDraftsToVerify(estate); + ICHECK_EQ(rsentries.size(), verify_lengths.size()); if (rsentries.empty()) { return {}; } @@ -62,7 +63,7 @@ class BatchVerifyActionObj : public EngineActionObj { std::vector> draft_output_tokens; std::vector> draft_output_prob_dist; request_internal_ids.reserve(num_rsentries); - all_tokens_to_verify.reserve(total_draft_length); + all_tokens_to_verify.reserve(total_verify_length); verify_request_mstates.reserve(num_rsentries); rngs.reserve(num_rsentries); generation_cfg.reserve(num_rsentries); @@ -73,12 +74,12 @@ class BatchVerifyActionObj : public EngineActionObj { RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_]; RequestModelState draft_mstate = rsentries[i]->mstates[draft_model_id_]; request_internal_ids.push_back(verify_mstate->internal_id); - ICHECK(!draft_lengths.empty()); - ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_tokens.size()); - ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_prob_dist.size()); - // the last committed token + all the draft tokens but the last one. + ICHECK(!verify_lengths.empty()); + ICHECK_EQ(verify_lengths[i], draft_mstate->draft_output_tokens.size() + 1); + ICHECK_EQ(verify_lengths[i], draft_mstate->draft_output_prob_dist.size() + 1); + // the last committed token + all the draft tokens. 
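// A concrete example of the bookkeeping above (illustrative numbers, not taken from this change):
// with spec_draft_length = 2 and draft tokens [d0, d1], verify_lengths[i] is 3 and the tokens
// pushed below are [last_committed, d0, d1]. The verifier's logits at these three positions are
// used to accept or reject d0 and d1 and, when every draft is accepted, to sample one extra
// token, which is why the engine now budgets spec_draft_length + 1 token slots per sequence.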
all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().sampled_token_id.first); - for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()) - 1; ++j) { + for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()); ++j) { all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].sampled_token_id.first); } verify_request_mstates.push_back(verify_mstate); @@ -95,19 +96,19 @@ class BatchVerifyActionObj : public EngineActionObj { RECORD_EVENT(trace_recorder_, request_ids, "start verify"); NDArray logits = - models_[verify_model_id_]->BatchVerify(embeddings, request_internal_ids, draft_lengths); + models_[verify_model_id_]->BatchVerify(embeddings, request_internal_ids, verify_lengths); RECORD_EVENT(trace_recorder_, request_ids, "finish verify"); ICHECK_EQ(logits->ndim, 3); ICHECK_EQ(logits->shape[0], 1); - ICHECK_EQ(logits->shape[1], total_draft_length); + ICHECK_EQ(logits->shape[1], total_verify_length); // - Update logits. std::vector cum_verify_lengths = {0}; cum_verify_lengths.reserve(num_rsentries + 1); for (int i = 0; i < num_rsentries; ++i) { - cum_verify_lengths.push_back(cum_verify_lengths.back() + draft_lengths[i]); + cum_verify_lengths.push_back(cum_verify_lengths.back() + verify_lengths[i]); } - logits = logits.CreateView({total_draft_length, logits->shape[2]}, logits->dtype); + logits = logits.CreateView({total_verify_length, logits->shape[2]}, logits->dtype); logit_processor_->InplaceUpdateLogits(logits, generation_cfg, verify_request_mstates, request_ids, &cum_verify_lengths, &draft_output_tokens); @@ -115,9 +116,14 @@ class BatchVerifyActionObj : public EngineActionObj { NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits( logits, generation_cfg, request_ids, &cum_verify_lengths); - std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokens( - probs_on_device, request_ids, cum_verify_lengths, generation_cfg, rngs, draft_output_tokens, - draft_output_prob_dist); + std::vector sample_indices(num_rsentries); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector> sample_results_arr = + sampler_->BatchVerifyDraftTokensWithProbAfterTopP( + renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs, + draft_output_tokens, draft_output_prob_dist); ICHECK_EQ(sample_results_arr.size(), num_rsentries); for (int i = 0; i < num_rsentries; ++i) { @@ -128,10 +134,8 @@ class BatchVerifyActionObj : public EngineActionObj { rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result); } estate->stats.total_accepted_length += accept_length; - // - Minus one because the last draft token has no kv cache entry - // - Take max with 0 in case of all accepted. int rollback_length = - std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length - 1, 0); + std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0); // rollback kv cache // NOTE: when number of small models is more than 1 (in the future), // it is possible to re-compute prefill for the small models. @@ -158,10 +162,10 @@ class BatchVerifyActionObj : public EngineActionObj { struct DraftRequestStateEntries { /*! \brief The request state entries to verify. */ Array draft_rsentries; - /*! \brief The draft length of each request state. */ - std::vector draft_lengths; + /*! \brief The length to verify for each request state. */ + std::vector verify_lengths; /*! 
\brief The total draft length. */ - int total_draft_length; + int total_verify_length; }; /*! @@ -171,8 +175,8 @@ class BatchVerifyActionObj : public EngineActionObj { * state and input length. */ DraftRequestStateEntries GetDraftsToVerify(EngineState estate) { - std::vector draft_lengths; - int total_draft_length = 0; + std::vector verify_lengths; + int total_verify_length = 0; int total_required_pages = 0; int num_available_pages = models_[verify_model_id_]->GetNumAvailablePages(); @@ -184,24 +188,24 @@ class BatchVerifyActionObj : public EngineActionObj { int draft_length = rsentry->mstates[draft_model_id_]->draft_output_tokens.size(); int num_require_pages = (draft_length + engine_config_->kv_cache_page_size - 1) / engine_config_->kv_cache_page_size; - draft_lengths.push_back(draft_length); + verify_lengths.push_back(draft_length + 1); num_page_requirement.push_back(num_require_pages); - total_draft_length += draft_length; + total_verify_length += draft_length + 1; total_required_pages += num_require_pages; } while (!CanVerify(total_required_pages)) { RequestStateEntry preempted = PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { - total_draft_length -= draft_lengths.back(); + total_verify_length -= verify_lengths.back(); total_required_pages -= num_page_requirement.back(); - draft_lengths.pop_back(); + verify_lengths.pop_back(); num_page_requirement.pop_back(); running_rsentries.pop_back(); } } - return {running_rsentries, draft_lengths, total_draft_length}; + return {running_rsentries, verify_lengths, total_verify_length}; } bool CanVerify(int num_required_pages) { diff --git a/cpp/serve/engine_actions/eagle_batch_draft.cc b/cpp/serve/engine_actions/eagle_batch_draft.cc index 50393c38a2..fde314a5c5 100644 --- a/cpp/serve/engine_actions/eagle_batch_draft.cc +++ b/cpp/serve/engine_actions/eagle_batch_draft.cc @@ -145,8 +145,10 @@ class EagleBatchDraftActionObj : public EngineActionObj { std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); std::vector prob_dist; - std::vector sample_results = sampler_->BatchSampleTokens( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index 043f68b9c2..b259417050 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -88,7 +88,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj { generation_cfg.push_back(rsentries[i]->request->generation_cfg); rngs.push_back(&rsentries[i]->rng); draft_output_tokens.push_back(draft_mstate->draft_output_tokens); - CHECK(draft_mstate->draft_output_prob_dist[0]->device.device_type == kDLCPU); draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist); } @@ -129,10 +128,14 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // - Compute probability distributions. 
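// The sampler calls below follow the two-step pattern introduced in this change. A minimal
// sketch, using only the function names visible in this diff:
//   NDArray renormalized = sampler_->BatchRenormalizeProbsByTopP(probs, indices, ids, cfg);
//   auto results = sampler_->BatchVerifyDraftTokensWithProbAfterTopP(renormalized, ...);
// That is, top-p renormalization is applied once to the probability tensor, and the
// *WithProbAfterTopP variants then sample or verify without re-applying top-p, while the
// *WithProbBeforeTopP variant keeps the previous single-step behavior.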
NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits( logits, generation_cfg, request_ids, &cum_verify_lengths); - - std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokens( - probs_on_device, request_ids, cum_verify_lengths, generation_cfg, rngs, draft_output_tokens, - draft_output_prob_dist); + std::vector sample_indices(num_rsentries); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector> sample_results_arr = + sampler_->BatchVerifyDraftTokensWithProbAfterTopP( + renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs, + draft_output_tokens, draft_output_prob_dist); ICHECK_EQ(sample_results_arr.size(), num_rsentries); std::vector last_hidden_states; @@ -230,8 +233,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj { std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); std::vector prob_dist; - std::vector sample_results = sampler_->BatchSampleTokens( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index 133c23e8a1..a687e7eb7f 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -277,8 +277,10 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { } } std::vector prob_dist; - std::vector sample_results = sampler_->BatchSampleTokens( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); // - Update the committed tokens of states. @@ -459,7 +461,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // No exceeding of the maximum allowed requests that can // run simultaneously. int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable - ? engine_config_->spec_draft_length + ? 
(engine_config_->spec_draft_length + 1) : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index c3f7491960..b4192a04f1 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -229,7 +229,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { rsentry_activated.push_back(true); } } - std::vector sample_results = sampler_->BatchSampleTokens( + std::vector sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP( probs_on_device, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); @@ -396,10 +396,15 @@ class NewRequestPrefillActionObj : public EngineActionObj { int num_running_rsentries) { ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); + // For RNN State, it can prefill as long as it can be instantiated. + if (engine_config_->kv_state_kind == KVStateKind::kRNNState) { + return true; + } + // No exceeding of the maximum allowed requests that can // run simultaneously. int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable - ? engine_config_->spec_draft_length + ? (engine_config_->spec_draft_length + 1) : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { diff --git a/cpp/serve/event_trace_recorder.h b/cpp/serve/event_trace_recorder.h index fd98cc844a..76e87ca710 100644 --- a/cpp/serve/event_trace_recorder.h +++ b/cpp/serve/event_trace_recorder.h @@ -22,7 +22,7 @@ using namespace tvm::runtime; class EventTraceRecorderObj : public Object { public: /*! - * \brief Record a event for the the input request in the trace recorder. + * \brief Record a event for the input request in the trace recorder. * \param request_id The subject request of the event. * \param event The event in a string name. * It can have one of the following patterns: diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index fa24828399..3267f1dd38 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -69,7 +69,8 @@ PackedFunc FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func, }); } -void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config) { +void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config, + Optional session) { local_gpu_device = device; Device null_device{DLDeviceType(0), 0}; int num_shards; @@ -85,29 +86,10 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object this->cached_buffers = Map(); if (num_shards > 1) { - constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool"; - if (Registry::Get(f_create_process_pool) == nullptr) { - LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. " - << "Multi-GPU inference depends on MLC LLM Python API to launch process."; - } - std::string ccl; - if (device.device_type == kDLCUDA) { - ccl = "nccl"; - } else if (device.device_type == kDLROCM) { - ccl = "rccl"; - } else { - LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type) - << " is not supported. 
Currently, only NCCL and RCCL are integrated."; - } - std::vector device_ids(num_shards); - for (int i = 0; i < num_shards; ++i) { - device_ids[i] = i; - } + this->sess = session.value(); this->use_disco = true; - this->sess = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); - this->sess->InitCCL(ccl, ShapeTuple(device_ids)); this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"), - std::move(reload_lib_path), null_device); + reload_lib_path, null_device); this->mod_get_func = [this, fmodule_get_function = sess->GetGlobalFunc("runtime.ModuleGetFunction")]( const std::string& name) -> PackedFunc { @@ -130,14 +112,23 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object this->_InitFunctions(); } else { Module executable{nullptr}; - if (false) { - // Todo(mlc-team): system lib reload // reload_lib_path starts with "system://" + PackedFunc fload_exec{nullptr}; + if (StartsWith(reload_lib_path, "system://")) { + const PackedFunc* f_load_system_lib = Registry::Get("runtime.SystemLib"); + ICHECK_NOTNULL(f_load_system_lib); + std::string system_lib_prefix = std::string(reload_lib_path).substr(9); + std::replace(system_lib_prefix.begin(), system_lib_prefix.end(), /*old=*/'-', /*new=*/'_'); + executable = (*f_load_system_lib)(system_lib_prefix + "_"); + fload_exec = executable->GetFunction("vm_load_executable"); + ICHECK(fload_exec.defined()) + << "Cannot find system lib with " << system_lib_prefix + << ", please make sure you set model_lib field consistently with the compilation "; } else { executable = tvm::runtime::Module::LoadFromFile(reload_lib_path); + fload_exec = executable->GetFunction("vm_load_executable"); + ICHECK(fload_exec.defined()) << "TVM runtime cannot find vm_load_executable"; } this->use_disco = false; - auto fload_exec = executable->GetFunction("vm_load_executable"); - ICHECK(fload_exec.defined()) << "TVM runtime cannot find vm_load_executable"; this->local_vm = fload_exec(); this->local_vm->GetFunction("vm_initialization")( static_cast(device.device_type), device.device_id, @@ -225,8 +216,8 @@ void FunctionTable::_InitFunctions() { this->verify_to_last_hidden_func_ = mod_get_func("batch_verify_to_last_hidden_states"); this->fuse_embed_hidden_func_ = mod_get_func("fuse_embed_hidden_states"); Module mod = this->use_disco ? 
this->disco_mod->DebugGetFromRemote(0) : this->local_vm; - this->get_logits_func_ = mod->GetFunction("get_logits", true); - this->batch_get_logits_func_ = mod->GetFunction("batch_get_logits", true); + this->get_logits_func_ = mod_get_func("get_logits"); + this->batch_get_logits_func_ = mod_get_func("batch_get_logits"); this->batch_select_last_hidden_func_ = mod->GetFunction("batch_select_last_hidden_states", true); this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true); this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true); @@ -235,7 +226,12 @@ void FunctionTable::_InitFunctions() { this->alloc_embedding_tensor_func_ = mod_get_func("alloc_embedding_tensor"); this->create_kv_cache_func_ = mod_get_func("create_flashinfer_paged_kv_cache"); if (!this->create_kv_cache_func_.defined()) { - this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); + PackedFunc f_create_rnn_state = mod_get_func("create_rnn_state"); + if (f_create_rnn_state.defined()) { + this->create_kv_cache_func_ = f_create_rnn_state; + } else { + this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); + } ICHECK(this->create_kv_cache_func_.defined()); } this->reset_kv_cache_func_ = get_global_func("vm.builtin.kv_state_clear"); @@ -256,6 +252,8 @@ void FunctionTable::_InitFunctions() { gpu_argsort_probs_func_ = mod->GetFunction("argsort_probs", true); gpu_sample_with_top_p_func_ = mod->GetFunction("sample_with_top_p", true); gpu_sampler_take_probs_func_ = mod->GetFunction("sampler_take_probs", true); + gpu_verify_draft_tokens_func_ = mod->GetFunction("sampler_verify_draft_tokens", true); + gpu_renormalize_by_top_p_func_ = mod->GetFunction("renormalize_by_top_p", true); } this->nd_view_func_ = get_global_func("vm.builtin.reshape"); this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of"); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index f6a156b8a3..bc2b4f21c8 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -41,7 +41,8 @@ using namespace tvm::runtime; struct FunctionTable { static PackedFunc SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name); - void Init(String reload_lib_path, Device device, picojson::object model_config); + void Init(String reload_lib_path, Device device, picojson::object model_config, + Optional session); ObjectRef LoadParams(const std::string& model_path, Device device); @@ -104,6 +105,8 @@ struct FunctionTable { PackedFunc gpu_argsort_probs_func_; PackedFunc gpu_sample_with_top_p_func_; PackedFunc gpu_sampler_take_probs_func_; + PackedFunc gpu_verify_draft_tokens_func_; + PackedFunc gpu_renormalize_by_top_p_func_; PackedFunc nd_view_func_; PackedFunc nd_get_shape_func_; PackedFunc nd_copy_embedding_to_offset_func_; diff --git a/cpp/serve/grammar/grammar_serializer.h b/cpp/serve/grammar/grammar_serializer.h index 8746b1f6ae..4ad5c2103b 100644 --- a/cpp/serve/grammar/grammar_serializer.h +++ b/cpp/serve/grammar/grammar_serializer.h @@ -77,7 +77,7 @@ class BNFGrammarPrinter : public BNFGrammarSerializer { }; /*! - * \brief Serialize the the raw representation of the BNF AST to a string with JSON format. + * \brief Serialize the raw representation of the BNF AST to a string with JSON format. * \sa BNFJSONParser::Parse for parsing the JSON string. 
* \details JSON format: * { diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index d9954f1e28..5c4ef98efe 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -469,9 +469,10 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenizer") TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenTable") .set_body([](TVMArgs args, TVMRetValue* rv) { BNFGrammar grammar = args[0]; + Array token_table_arr = args[1]; std::vector token_table; - for (int i = 1; i < args.size() - 1; ++i) { - token_table.push_back(args[i]); + for (int i = 0; i < token_table_arr.size(); ++i) { + token_table.push_back(token_table_arr[i]); } int max_rollback_steps = args[args.size() - 1]; auto init_ctx = GrammarStateMatcher::CreateInitContext(grammar, token_table); diff --git a/cpp/serve/grammar/json_schema_converter.cc b/cpp/serve/grammar/json_schema_converter.cc index 93d693f3c6..83be710cf5 100644 --- a/cpp/serve/grammar/json_schema_converter.cc +++ b/cpp/serve/grammar/json_schema_converter.cc @@ -23,6 +23,14 @@ namespace serve { using namespace tvm::runtime; +// EMCC somehow cannot pickup operator overload from picojson.h, so we copy here. +#ifdef COMPILE_MLC_WASM_RUNTIME +inline std::ostream& operator<<(std::ostream& os, const picojson::value& x) { + x.serialize(std::ostream_iterator(os)); + return os; +} +#endif + /*! * \brief Manage the indent and separator for the generation of EBNF grammar. * \param indent The number of spaces for each indent. If it is std::nullopt, there will be no diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 17121d8e28..6f34220219 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -13,6 +13,7 @@ #include +#include "config.h" #include "logit_processor.h" namespace mlc { @@ -25,10 +26,27 @@ class ModelImpl; TVM_REGISTER_OBJECT_TYPE(ModelObj); -Model Model::Create(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled) { - return Model( - make_object(reload_lib_path, model_path, device, max_num_sequence, trace_enabled)); +Model Model::Create(String reload_lib_path, String model_path, const picojson::object& model_config, + DLDevice device, int max_num_sequence, const Optional& session, + bool trace_enabled) { + return Model(make_object(reload_lib_path, model_path, model_config, device, + max_num_sequence, session, trace_enabled)); +} + +picojson::object Model::LoadModelConfig(const String& model_path) { + picojson::object model_config; + std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str()); + std::ostringstream config_ostream; + ICHECK(config_istream); + config_ostream << config_istream.rdbuf(); + std::string config_str = config_ostream.str(); + picojson::value config_json; + std::string err = picojson::parse(config_json, config_str); + if (!err.empty()) { + LOG(FATAL) << err; + } + picojson::object config = config_json.get(); + return config; } class ModelImpl : public ModelObj { @@ -37,23 +55,16 @@ class ModelImpl : public ModelObj { * \brief Constructor of ModelImpl. * \sa Model::Create */ - explicit ModelImpl(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled) + explicit ModelImpl(String reload_lib_path, String model_path, picojson::object model_config, + DLDevice device, int max_num_sequence, const Optional& session, + bool trace_enabled) : device_(device) { // Step 1. Process model config json string. 
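// With this change the JSON in mlc-chat-config.json is parsed once by Model::LoadModelConfig
// (defined earlier in this file) and passed in as a picojson::object, so the engine can read
// fields such as "tensor_parallel_shards" and create the shared disco session before the models
// are constructed; LoadModelConfigJSON below now only extracts fields from the pre-parsed object.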
- picojson::object model_config; - { - std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str()); - std::ostringstream config_ostream; - ICHECK(config_istream); - config_ostream << config_istream.rdbuf(); - std::string config_str = config_ostream.str(); - model_config = LoadModelConfigJSON(config_str); - } + LoadModelConfigJSON(model_config); // Step 2. Initialize vm, we use the packed function mechanism // so there is no explicit abi dependency on these extra // classes other than basic tvm runtime. - this->ft_.Init(reload_lib_path, device_, model_config); + this->ft_.Init(reload_lib_path, device_, model_config, session); // Step 3. Load params in nd-array cache. this->params_ = ft_.LoadParams(model_path, device_); // Step 4. Set max_num_sequence @@ -68,6 +79,12 @@ class ModelImpl : public ModelObj { token_ids_storage_ = memory::Storage( allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32)), allocator); this->logit_pos_arr_ = NDArray::Empty({max_num_sequence}, DataType::Int(32), device_host); + // Step 7. Set model type + if (model_config["model_type"].get().find("rwkv") != std::string::npos) { + this->kind = KVStateKind::kRNNState; + } else { + this->kind = KVStateKind::kAttention; + } } /*********************** Model Computation ***********************/ @@ -136,16 +153,23 @@ class ModelImpl : public ModelObj { ICHECK_EQ(hidden_states->device.device_type, device_.device_type); ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - hidden_states_dref_or_nd = + hidden_states = hidden_states.CreateView({batch_size * seq_len, hidden_size_}, hidden_states->dtype); + // This copy can be avoided by not copying the hidden states to engine. + hidden_states_dref_or_nd = ft_.CopyToWorker0( + hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); ObjectRef ret = ft_.get_logits_func_(hidden_states_dref_or_nd, params_); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } - NDArray logits; - logits = Downcast(ret); + NDArray logits{nullptr}; + if (ret->IsInstance()) { + logits = Downcast(ret)->DebugGetFromRemote(0); + } else { + logits = Downcast(ret); + } CHECK(logits.defined()); // logits: (b * s, v) ICHECK_EQ(logits->ndim, 2); @@ -185,8 +209,11 @@ class ModelImpl : public ModelObj { ICHECK_EQ(hidden_states->device.device_type, device_.device_type); ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - hidden_states_dref_or_nd = - hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); + hidden_states = hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); + + // This copy can be avoided by not copying the hidden states to engine. + hidden_states_dref_or_nd = ft_.CopyToWorker0( + hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); ObjectRef ret = ft_.batch_get_logits_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_); @@ -218,8 +245,15 @@ class ModelImpl : public ModelObj { p_logit_pos[i] = total_length - 1; } NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); + + // This step runs on the engine thread. + // By temporarily turning off the disco flag, this copies the logit_pos_nd to the cached device + // tensor without actually copying to the worker. 
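// In outline (see the lines that follow): save ft_.use_disco, clear it so that CopyToWorker0
// targets the cached device tensor rather than copying to worker 0, perform the copy, and then
// restore the flag.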
+ bool use_disco = ft_.use_disco; + ft_.use_disco = false; ObjectRef logit_pos_dref_or_nd = ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); + ft_.use_disco = use_disco; CHECK(ft_.batch_select_last_hidden_func_.defined()) << "`batch_select_last_hidden_states` function is not found in the model."; @@ -240,7 +274,7 @@ class ModelImpl : public ModelObj { hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); ObjectRef ret = - ft_.batch_select_last_hidden_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_); + ft_.batch_select_last_hidden_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } @@ -265,10 +299,17 @@ class ModelImpl : public ModelObj { // No ICHECK_EQ(hidden->shape[0], hidden_size_) here to allow different hidden_sizes. hidden = hidden.CreateView({1, hidden_size_}, hidden->dtype); // Reuse the copy embedding function - ft_.nd_copy_embedding_to_offset_func_(hidden, *dst, cum_length); + ObjectRef hidden_dref_or_nd = + ft_.CopyToWorker0(hidden, "hidden_for_concat", {1, hidden_size_}); + ft_.nd_copy_embedding_to_offset_func_(hidden_dref_or_nd, *dst, cum_length); cum_length += 1; } - NDArray ret = Downcast(*dst); + NDArray ret{nullptr}; + if ((*dst)->IsInstance()) { + ret = Downcast(*dst)->DebugGetFromRemote(0); + } else { + ret = Downcast(*dst); + } ret = ret.CreateView({cum_length, hidden_size_}, hidden_states[0]->dtype); return ret; } @@ -295,7 +336,7 @@ class ModelImpl : public ModelObj { return embeddings_nd.CreateView({batch_size, seq_len, hidden_size_}, embeddings_nd->dtype); } } else { - ShapeTuple embedding_shape{batch_size, seq_len, hidden_size_}; + ShapeTuple embedding_shape{batch_size * seq_len, hidden_size_}; embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) { @@ -715,16 +756,26 @@ class ModelImpl : public ModelObj { /*********************** KV Cache Management ***********************/ void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size) final { - IntTuple max_num_sequence_tuple{max_num_sequence}; - IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; - IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; - IntTuple page_size_tuple{page_size}; - IntTuple support_sliding_window{sliding_window_size_ != -1}; - kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple, - prefill_chunk_size_tuple, page_size_tuple, - support_sliding_window); - local_kv_cache_ = ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + int prefill_chunk_size, int max_history_size, + KVStateKind kv_state_kind) final { + if (kv_state_kind == KVStateKind::kAttention) { + IntTuple max_num_sequence_tuple{max_num_sequence}; + IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; + IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; + IntTuple page_size_tuple{page_size}; + IntTuple support_sliding_window{sliding_window_size_ != -1}; + kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple, + prefill_chunk_size_tuple, page_size_tuple, + support_sliding_window); + local_kv_cache_ = + ft_.use_disco ? 
Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + } else { + IntTuple max_num_sequence_tuple{max_num_sequence}; + IntTuple max_history_size_tuple = {std::max(max_history_size, 1)}; + kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_history_size_tuple); + local_kv_cache_ = + ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + } } void AddNewSequence(int64_t seq_id) final { ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id); } @@ -751,11 +802,21 @@ class ModelImpl : public ModelObj { /************** Raw Info Query **************/ int GetNumAvailablePages() const final { - return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_); + if (this->kind == KVStateKind::kRNNState) { + // RNNState does not introduce new page at runtime + return std::numeric_limits::max(); + } else { + return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_); + } } int GetCurrentTotalSequenceLength() const final { - return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_); + if (this->kind == KVStateKind::kRNNState) { + // RNNState does not have a total sequence length limit + return 0; + } else { + return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_); + } } /*********************** Utilities ***********************/ @@ -768,9 +829,7 @@ class ModelImpl : public ModelObj { Sampler CreateSampler(int max_num_sample, int num_models, Optional trace_recorder) { - if (num_models > 1) { // speculative decoding uses cpu sampler - return Sampler::CreateCPUSampler(std::move(trace_recorder)); - } else if (Sampler::SupportGPUSampler(device_)) { + if (Sampler::SupportGPUSampler(device_)) { return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_, std::move(trace_recorder)); } else { @@ -842,15 +901,7 @@ class ModelImpl : public ModelObj { private: /*! \brief Load model configuration from JSON. */ - picojson::object LoadModelConfigJSON(const std::string& config_str) { - picojson::value config_json; - std::string err = picojson::parse(config_json, config_str); - if (!err.empty()) { - LOG(FATAL) << err; - } - - // Get json fields. - picojson::object config = config_json.get(); + picojson::object LoadModelConfigJSON(picojson::object config) { if (config.count("context_window_size")) { CHECK(config["context_window_size"].is()); this->max_window_size_ = config["context_window_size"].get(); @@ -924,6 +975,8 @@ class ModelImpl : public ModelObj { NDArray logit_pos_arr_{nullptr}; // A boolean indicating if tracing is enabled. bool trace_enabled_; + // An enum indicating whether it's RNN-based. + KVStateKind kind; }; TVM_REGISTER_GLOBAL("mlc.copy_embedding_to_offset") diff --git a/cpp/serve/model.h b/cpp/serve/model.h index da532f83e8..bc63840a74 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -234,9 +234,13 @@ class ModelObj : public Object { * in the engine. * \param prefill_chunk_size The maximum total number of tokens whose KV data * are allowed to exist in the KV cache at any time. + * \param max_history_size The maximum history size for RNN state to roll back. + * The KV cache does not need this. + * \param kv_state_kind The kind of cache. It can be KV cache or RNN state. */ virtual void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size) = 0; + int prefill_chunk_size, int max_history_size, + KVStateKind kv_state_kind) = 0; /*! \brief Add a new sequence with the given sequence id to the KV cache. 
*/ virtual void AddNewSequence(int64_t seq_id) = 0; @@ -315,13 +319,24 @@ class Model : public ObjectRef { * \brief Create the runtime module for LLM functions. * \param reload_lib_path The model library path. * \param model_path The path to the model weight parameters. + * \param model_config The model config json object. * \param device The device to run the model on. * \param max_num_sequence The maximum number of sequences to be processed + * \param session The session to run the model on. * \param trace_enabled A boolean indicating whether tracing is enabled. * \return The created runtime module. */ - TVM_DLL static Model Create(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled); + TVM_DLL static Model Create(String reload_lib_path, String model_path, + const picojson::object& model_config, DLDevice device, + int max_num_sequence, const Optional& session, + bool trace_enabled); + + /*! + * Load the model config from the given model path. + * \param model_path The path to the model weight parameters. + * \return The model config json object. + */ + static picojson::object LoadModelConfig(const String& model_path); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj); }; diff --git a/cpp/serve/radix_tree.cc b/cpp/serve/radix_tree.cc new file mode 100644 index 0000000000..5d5c311593 --- /dev/null +++ b/cpp/serve/radix_tree.cc @@ -0,0 +1,718 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/radix_tree.cc + */ +#include "radix_tree.h" + +#include + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! + * \brief The sequence ID linked list structure in paged radix tree node. + */ +struct SequenceIDNode { + /*! \brief The stored sequence ID. */ + int64_t id = 0; + /*! \brief The pointer to the next sequence ID. */ + SequenceIDNode* next = nullptr; +}; + +/*! + * \brief The sequence Id node pool. + * + * The sequence Id node pool allocates all sequence ID nodes when construction and frees when + * destruction, to avoid frequent memory operation. + */ +class SequenceIDNodePool { + public: + /*! \brief The constructor of sequence Id node pool, allocating memory for each node. */ + SequenceIDNodePool(size_t num_nodes) : num_nodes_(num_nodes) { + nodes_.reserve(num_nodes); + free_node_indicess_.reserve(num_nodes); + used_nodes_.clear(); + raw_pool_ = new SequenceIDNode[num_nodes_]; + for (size_t i = 0; i < num_nodes; ++i) { + nodes_.push_back(&raw_pool_[i]); + free_node_indicess_.push_back(i); + } + } + + /*! + * \brief Get a radix page from pool, and assign the fields. + * \param seq_id The assigned sequence ID of allocated sequence ID node. + * \param node The next sequence ID node pointer of allocated sequence ID node. + * \return The allocated radix page. + * \throw Error if no free radix page available in pool. + */ + SequenceIDNode* Allocate(int64_t seq_id, SequenceIDNode* next) { + CHECK(!free_node_indicess_.empty()) << "Sequence ID node pool has no free sequence ID nodes."; + size_t id = free_node_indicess_.back(); + free_node_indicess_.pop_back(); + SequenceIDNode* node = nodes_[id]; + used_nodes_[node] = id; + node->id = seq_id; + node->next = next; + return node; + } + + /*! + * \brief Free a sequence ID node to pool. + * \param node The sequence ID node to free. + */ + void Free(SequenceIDNode* node) { + CHECK(used_nodes_.find(node) != used_nodes_.end()); + free_node_indicess_.push_back(used_nodes_[node]); + used_nodes_.erase(node); + } + + /*! 
\brief The destructor of sequence Id node pool, freeing memory for each node. */ + ~SequenceIDNodePool() { delete[] raw_pool_; } + + private: + /*! \brief The number of nodes in sequence ID node pool. */ + size_t num_nodes_; + /*! \brief The raw sequence ID node pool. */ + SequenceIDNode* raw_pool_; + /*! \brief The sequence ID node pool. */ + std::vector nodes_; + /*! \brief The indices of free sequence ID node in node pool. */ + std::vector free_node_indicess_; + /*! \brief The map from used paged sequence ID node to its index in node pool. */ + std::unordered_map used_nodes_; +}; + +/*! + * \brief The paged radix tree node data structure. + * + * The paged radix tree node is similar to original radix tree node, but with the limited length for + * prefix in page, so that the memory usage in each page is the same and is fixed once allocated. + * Since the page only consists of pointers and int tokens, the page memory layout is int array + * indeed. The lower offset is the pointers and page information, while the higher offset is the + * stored prefix tokens. + * + * And since the vocabulary size may be very large, the paged Radix tree is represented + * as left-child, right-sibling binary tree. + * + * Also, due to possible pop/push front/back tokens in page, the page is designed as circular + * buffer, to make full use of each page. + * + * Each page records the sequence excatly ends with the prefix tokens stored in page. In other word, + * all sequences locate in the boundary of each page, or the end of each page. + */ +struct RedixPage { + /*! \brief The parent page. */ + RedixPage* parent; + /*! \brief The first child page. */ + RedixPage* first_child; + /*! \brief The sibling page shareing the same parent page. */ + RedixPage* next_sibiling; + /*! \brief The head of sequence ID linked list. */ + SequenceIDNode* seq_ids; + /*! \brief The capacity of maximum stored prefix tokens. */ + size_t capacity; + /*! \brief The start offset of stored prefix tokens. The legal value is of [0, capacity). */ + size_t offset; + /*! \brief The length of stored prefix tokens. The legal value is of [0, capacity). */ + size_t length; + /*! \brief The offset of first prefix token in memory layout. */ + static constexpr int DATA_OFFSET = (sizeof(RedixPage*) * 3 + sizeof(SequenceIDNode*) + + sizeof(size_t) * 3 + sizeof(int32_t) - 1) / + sizeof(int32_t); + + /*! + * \brief Overload opeartor [] to get the prefix tokens by index as simple int array. + * \param i The prefix token index. + * \return The value of i-th prefix token. + */ + int32_t& operator[](size_t i) { + return reinterpret_cast(this)[DATA_OFFSET + (i + offset) % capacity]; + } + + /*! + * \brief Extend or push back a suffix tokens in page. + * \param suffix The suffix tokens array. + * \param suffix_length The suffix length to extend. + * \throw Error if suffix length is larger than current vacant space. + */ + void Extend(const int64_t* suffix, size_t suffix_length) { + CHECK_LE(suffix_length + length, capacity); + for (int i = 0; i < suffix_length; ++i) { + (*this)[i + length] = (int32_t)suffix[i]; + } + length += suffix_length; + } + + /*! + * \brief Add a sequence ID in page. + * \param pool The sequence ID node pool to allocate new node. + * \param id The sequence ID to add. + */ + void AddSequence(SequenceIDNodePool* pool, int64_t id) { seq_ids = pool->Allocate(id, seq_ids); } + + /*! + * \brief Pop a sequence ID in page. + * \param pool The sequence ID node pool to free popped node. + * \param id The sequence ID to pop. 
+ * \throw Error if no such sequence ID in page.
+ */
+  void PopSequence(SequenceIDNodePool* pool, int64_t id) {
+    if (seq_ids->id == id) {
+      // If the popped sequence ID is the first node in the linked list,
+      // directly advance the head and free it.
+      SequenceIDNode* next = seq_ids->next;
+      pool->Free(seq_ids);
+      seq_ids = next;
+    } else {
+      // If the popped sequence ID is not the first node in the linked list,
+      // unlink it from the previous node and free it.
+      SequenceIDNode* last = seq_ids;
+      SequenceIDNode* cur = seq_ids->next;
+      while (cur) {
+        if (cur->id == id) {
+          last->next = cur->next;
+          pool->Free(cur);
+          return;
+        }
+        // Advance both pointers; otherwise the loop never terminates on a miss.
+        last = cur;
+        cur = cur->next;
+      }
+      LOG(FATAL) << "Sequence ID = " << id << " not found.";
+    }
+  }
+
+  /*!
+   * \brief Get all sequence IDs in page.
+   * \return The std::vector of sequence IDs in page.
+   */
+  std::vector<int64_t> GetLocalSequence() {
+    std::vector<int64_t> output;
+    for (SequenceIDNode* node = seq_ids; node; node = node->next) {
+      output.push_back(node->id);
+    }
+    return output;
+  }
+
+  /*!
+   * \brief Get any sequence ID in current page or child pages.
+   * Since there is always a sequence in leaf pages, it only checks the first child if there is
+   * no sequence ID in the current page.
+   * \return Any sequence ID in current page or child pages.
+   */
+  int32_t FindAnyChildSequence() {
+    if (seq_ids) return seq_ids->id;
+    return first_child->FindAnyChildSequence();
+  }
+
+  /*!
+   * \brief Get all sequence IDs in current page and child pages, using the Iterate method with a
+   * lambda expression as callback to avoid frequent memory allocation of std::vector.
+   * \return The std::vector of all sequence IDs in current page and child pages.
+   */
+  std::vector<int64_t> FindAllChildSequence() {
+    std::vector<int64_t> output = GetLocalSequence();
+    if (first_child) {
+      first_child->Iterate([&output](const RedixPage* page) {
+        for (SequenceIDNode* node = page->seq_ids; node; node = node->next) {
+          output.push_back(node->id);
+        }
+      });
+    }
+    return output;
+  }
+
+  /*!
+   * \brief The iteration method for tree or sub-tree traversal.
+   * \param f The callback function to invoke at each radix page visited.
+   */
+  template <typename CallbackFunc>
+  void Iterate(CallbackFunc f) {
+    f(this);
+    if (next_sibiling) next_sibiling->Iterate(f);
+    if (first_child) first_child->Iterate(f);
+  }
+
+  /*!
+   * \brief Get the last sibling of current page.
+   * \return The page whose next sibling is the current page, or nullptr if the current page is
+   * the first_child of its parent page.
+   */
+  RedixPage* GetLastSibling() {
+    if (parent == nullptr) return nullptr;
+    if (parent->first_child == this) return nullptr;
+    for (RedixPage* child = parent->first_child; child; child = child->next_sibiling) {
+      if (child->next_sibiling == this) return child;
+    }
+    return nullptr;
+  }
+
+  /*!
+   * \brief Find the child indexed by the first token.
+   * \return The child page starting with the first token, or nullptr if no such child page exists.
+   */
+  RedixPage* FindChild(int64_t first_token) {
+    int32_t casted = first_token;
+    // Iterate over all child radix pages, as the child radix pages are not stored in any order.
+    for (RedixPage* child = first_child; child; child = child->next_sibiling) {
+      if ((*child)[0] == casted) return child;
+    }
+    return nullptr;
+  }
+
+  /*! \brief Insert a new child page. */
+  void InsertChild(RedixPage* child) {
+    child->parent = this;
+    child->next_sibiling = first_child;
+    first_child = child;
+  }
+
+  /*!
+   * \brief Remove a child page.
+   * \throw Error if the page to be removed is not a child page.
+ */ + void RemoveChild(RedixPage* child) { + CHECK(child->parent == this); + if (first_child == child) { + first_child = child->next_sibiling; + } else { + child->GetLastSibling()->next_sibiling = child->next_sibiling; + } + } + + /*! + * \brief Check current page is mergable with its child page. + * The page is mergable if and only if + * 1. No sequence ID in current page, as sequence ID is not allowed to exist within page. + * 2. The current page has child page. + * 3. The current page has only one child page. + * 4. The current page perfix and the child page prefix can be concatenated into one page. + * \return True if current page is mergable, or false. + */ + bool Mergeable() { + if (seq_ids) return false; + if (!first_child) return false; + if (first_child->next_sibiling) return false; + if (length + first_child->length > capacity) return false; + return true; + } + + /*! + * \brief Match the given prefix within page. + * \param prefix The prefix token array. + * \param prefix_length The length of prefix token array. + * \return The matched prefix offset within page, or the first mismatched token position. The + * possible return value is [0, page->length], where page->length means the page is completely the + * prefix of given prefix. + */ + size_t MatchPrefix(const int64_t* prefix, size_t prefix_length) { + size_t n = std::min(length, prefix_length); + for (int i = 0; i < n; ++i) { + if ((*this)[i] != prefix[i]) return i; + } + return n; + } +}; + +/*! + * \brief The paged radix tree page pool. + * + * The paged radix tree page pool allocates all radix tree pages when construction and frees when + * destruction, to avoid frequent memory operation. + */ +class RadixPagePool { + public: + /*! \brief The constructor of paged radix tree page pool, allocating memory for each page. */ + RadixPagePool(size_t page_size, size_t num_pages) : page_size_(page_size), num_pages_(num_pages) { + pages_.reserve(num_pages); + free_page_indices_.reserve(num_pages); + raw_pool_ = new int32_t[num_pages * page_size / sizeof(int32_t)]; + int32_t num_int = page_size / sizeof(int32_t); + for (size_t i = 0; i < num_pages; ++i) { + pages_.push_back(reinterpret_cast(raw_pool_ + i * num_int)); + free_page_indices_.push_back(i); + } + } + + /*! + * \brief Get a radix page from pool. + * \return The allocated radix page. + * \throw Error if no free radix page available in pool. + */ + RedixPage* Allocate() { + CHECK(!free_page_indices_.empty()) << "Radix page pool has no free radix tree pages."; + int id = free_page_indices_.back(); + free_page_indices_.pop_back(); + RedixPage* page = pages_[id]; + used_pages_[page] = id; + page->parent = page->first_child = page->next_sibiling = nullptr; + page->capacity = page_size_ / sizeof(int32_t) - RedixPage::DATA_OFFSET; + page->offset = page->length = 0; + page->seq_ids = nullptr; + return page; + } + + /*! + * \brief Free a radix page to pool. + * \param page The radix page to free. + */ + void Free(RedixPage* page) { + CHECK_EQ(page->seq_ids, nullptr); + CHECK(used_pages_.find(page) != used_pages_.end()); + free_page_indices_.push_back(used_pages_[page]); + CHECK(used_pages_.erase(page)); + } + + /*! + * \brief Get the token capacity of free pages. + * \return The the token capacity of free pages. + */ + size_t FreeCapacity() { + return free_page_indices_.size() * (page_size_ / sizeof(int32_t) - RedixPage::DATA_OFFSET); + } + + /*! \brief The destructor of paged radix tree page pool, freeing memory for each page. 
*/
+  ~RadixPagePool() { delete[] raw_pool_; }
+
+ private:
+  /*! \brief The page size of each paged radix tree page. */
+  size_t page_size_;
+  /*! \brief The number of pages in paged radix tree page pool. */
+  size_t num_pages_;
+  /*! \brief The raw paged radix tree page pool. */
+  int32_t* raw_pool_;
+  /*! \brief The paged radix tree page pool. */
+  std::vector<RedixPage*> pages_;
+  /*! \brief The indices of free paged radix pages in the page pool. */
+  std::vector<size_t> free_page_indices_;
+  /*! \brief The map from used paged radix tree page to its index in the page pool. */
+  std::unordered_map<RedixPage*, size_t> used_pages_;
+};
+
+// PagedRadixTree
+
+/*!
+ * \brief The paged radix tree data structure.
+ */
+class PagedRadixTreeImpl : public PagedRadixTreeObj {
+ public:
+  /*! \brief The page size of each paged radix tree node. */
+  size_t page_size;
+  /*! \brief The number of pages in paged radix tree page pool. */
+  size_t num_pages;
+  /*! \brief The maximum number of sequence IDs in paged radix tree page pool. */
+  size_t num_seqs;
+  /*! \brief The map from a sequence to the paged radix tree node where it is stored. */
+  std::unordered_map<int64_t, RedixPage*> seq2page;
+  /*! \brief The sequence ID node pool. */
+  SequenceIDNodePool* seq_id_node_pool = nullptr;
+  /*! \brief The radix page pool. */
+  RadixPagePool* radix_page_pool = nullptr;
+  /*! \brief The root page of paged radix tree. */
+  RedixPage* root = nullptr;
+
+  explicit PagedRadixTreeImpl(size_t num_pages, size_t page_size, size_t num_seqs) {
+    // Assign the members through this->; a plain `num_pages = num_pages` would only
+    // self-assign the parameters and leave the members uninitialized.
+    this->num_pages = num_pages;
+    this->page_size = page_size;
+    this->num_seqs = num_seqs;
+
+    seq_id_node_pool = new SequenceIDNodePool(num_seqs);
+    radix_page_pool = new RadixPagePool(page_size, num_pages);
+
+    root = reinterpret_cast<RedixPage*>(new int32_t[RedixPage::DATA_OFFSET]);
+    root->parent = root->first_child = root->next_sibiling = nullptr;
+    root->offset = root->length = root->capacity = 0;
+    root->seq_ids = nullptr;
+  }
+
+  /*!
+   * \brief Get all tokens of a sequence.
+   * \param seq_id The sequence ID for index.
+   * \return The sequence tokens.
+   * \throw Error if sequence ID is not valid.
+   */
+  IntTuple GetSequence(int64_t seq_id) {
+    CHECK(seq2page.find(seq_id) != seq2page.end());
+    size_t length = GetSequenceLength(seq_id);
+    std::vector<int64_t> output(length);
+    size_t offset = length;
+    for (RedixPage* page = seq2page[seq_id]; page; page = page->parent) {
+      offset -= page->length;
+      for (int i = 0; i < page->length; ++i) {
+        output[offset + i] = (*page)[i];
+      }
+    }
+    return IntTuple(output);
+  }
+
+  /*!
+   * \brief Get all sequences with the longest common prefix with the given prefix tokens.
+   * \param tokens The prefix tokens for reference.
+   * \return The pair of matched prefix length and the array of matched sequence indices.
+   */
+  std::pair<size_t, std::vector<int64_t>> MatchPrefix(IntTuple tokens) {
+    const int64_t* prefix = tokens.data();
+    size_t length = tokens.size();
+    auto [page, offset, in_page_offset] = MatchSequence(root, prefix, length);
+    if (!offset) return std::make_pair(0, std::vector<int64_t>());
+    return std::make_pair(offset, page->FindAllChildSequence());
+  }
+
+  /*!
+   * \brief Get a sequence's length.
+   * \param seq_id The sequence ID for index.
+   * \return The sequence length.
+   * \throw Error if sequence ID is not valid.
+   */
+  size_t GetSequenceLength(int64_t seq_id) {
+    CHECK(seq2page.find(seq_id) != seq2page.end());
+    size_t length = 0;
+    for (RedixPage* page = seq2page[seq_id]; page; page = page->parent) {
+      length += page->length;
+    }
+    return length;
+  }
+
+  /*!
+   * \brief Fork a sequence from a parent sequence at a given position.
+   * \param seq_id The new sequence ID.
+ * \param parent_seq_id The parent sequence ID to fork from. + * \param forked_offset The position of parent sequence to fork at. + * The valid value is [1, length of forked sequence]. If the position equals the length of forked + * sequence, the new sequence will copy the entire forked sequence. + * \throw Error if sequence ID or + * forked postion is not valid. + */ + void ForkSequence(int64_t seq_id, int64_t parent_seq_id, size_t forked_offset) { + CHECK(seq2page.find(seq_id) == seq2page.end()); + CHECK(seq2page.find(parent_seq_id) != seq2page.end()); + CHECK_GT(forked_offset, 0); + size_t length = GetSequenceLength(parent_seq_id); + CHECK_LE(forked_offset, length); + for (RedixPage* page = seq2page[parent_seq_id]; page; page = page->parent) { + if (forked_offset >= length - page->length) { + if (forked_offset < length) { + // Split radix page if forked position is within page + page = SplitPage(page, forked_offset + page->length - length); + } + page->AddSequence(seq_id_node_pool, seq_id); + seq2page[seq_id] = page; + return; + } + length -= page->length; + } + } + + /*! + * \brief Add an empty sequence at root. + * \param seq_id The new sequence ID. + * \throw Error if sequence ID is not valid. + */ + void AddSequence(int64_t seq_id) { + CHECK(seq2page.find(seq_id) == seq2page.end()); + root->AddSequence(seq_id_node_pool, seq_id); + seq2page[seq_id] = root; + } + + /*! + * \brief Extend a sequence with given tokens. + * \param seq_id The sequence ID for index. + * \param tokens The given tokens to extend. + * \throw Error if sequence ID is not valid. + */ + void ExtendSequence(int64_t seq_id, IntTuple tokens) { + CHECK(seq2page.find(seq_id) != seq2page.end()); + const int64_t* suffix = tokens.data(); + size_t length = tokens.size(); + RedixPage* original_page = seq2page[seq_id]; + original_page->PopSequence(seq_id_node_pool, seq_id); + auto [page, offset, in_page_offset] = MatchSequence(original_page, suffix, length); + if (in_page_offset < page->length) { + // Split page if extended sequence mismatches within page + page = SplitPage(page, in_page_offset); + } + if (offset < length && !page->seq_ids && !page->first_child && page->capacity > page->length) { + // Extend in the existing leaf page first if possible. + size_t suffix_length = std::min(page->capacity - page->length, length - offset); + page->Extend(suffix + offset, suffix_length); + offset += suffix_length; + } + while (offset < length) { + // Allocate new radix page and extend tokens + RedixPage* new_page = radix_page_pool->Allocate(); + page->InsertChild(new_page); + page = new_page; + size_t suffix_length = std::min(page->capacity - page->length, length - offset); + page->Extend(suffix + offset, suffix_length); + offset += suffix_length; + } + page->AddSequence(seq_id_node_pool, seq_id); + seq2page[seq_id] = page; + if (original_page->Mergeable()) { + // The original page may be mergeable, as the sequence ID changes + MergePage(original_page); + } + } + + /*! + * \brief Remove a sequence. + * \param seq_id The sequence ID to remove. + * \throw Error if sequence ID is not valid. 
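Taken together, `AddSequence`, `ExtendSequence`, `ForkSequence`, and `MatchPrefix` form the whole front-end of the radix tree. The sketch below is a hypothetical usage walkthrough of that API; the include path, the `IntTuple` initializer-list construction, and all token values and pool sizes are assumptions made for illustration.

```cpp
// Hypothetical walkthrough of the PagedRadixTree front-end declared in
// cpp/serve/radix_tree.h (include path and sizes are assumptions).
#include "serve/radix_tree.h"

using namespace mlc::llm::serve;
using tvm::runtime::IntTuple;

void Example() {
  PagedRadixTree tree(/*num_pages=*/128, /*page_size=*/64, /*num_seqs=*/16);

  // Sequence 0 starts empty at the root and is extended with five tokens.
  tree->AddSequence(0);
  tree->ExtendSequence(0, IntTuple({1, 2, 3, 4, 5}));

  // Sequence 1 shares the first three tokens with sequence 0, then diverges.
  tree->ForkSequence(/*seq_id=*/1, /*parent_seq_id=*/0, /*forked_offset=*/3);
  tree->ExtendSequence(1, IntTuple({9, 9}));

  // MatchPrefix reports the longest stored prefix of the query and the
  // sequences that pass through it: here {1, 2, 3} is shared by both.
  auto [matched, seq_ids] = tree->MatchPrefix(IntTuple({1, 2, 3, 7}));
  // matched == 3; seq_ids contains 0 and 1.
}
```

The fork at offset 3 is what exercises the page split path shown further below.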
+ */ + void RemoveSequence(int64_t seq_id) { + RedixPage* page = seq2page[seq_id]; + page->PopSequence(seq_id_node_pool, seq_id); + seq2page.erase(seq_id); + while (page->parent && !page->seq_ids && !page->first_child) { + RedixPage* parent = page->parent; + parent->RemoveChild(page); + radix_page_pool->Free(page); + page = parent; + } + if (page && page->Mergeable()) { + // The remaining page may be mergeable, as the sequence ID changes + MergePage(page); + } + } + + /*! + * \brief Get the remaining token capacity of the paged radix tree. + * \return The the remaining token capacity of the paged radix tree. + */ + size_t FreeCapacity() { return radix_page_pool->FreeCapacity(); } + + /*! \brief The destructor to free root page. */ + ~PagedRadixTreeImpl() { + delete[] reinterpret_cast(root); + delete seq_id_node_pool; + delete radix_page_pool; + } + + private: + /*! + * \brief Merge a radix tree page with its child radix tree page, to save radix tree page. + * e.g. MergePage([1, 2, _, _, _] -> [3, 4, 5, _, _]) = [1, 2, 3, 4, 5]. + * And the page to be merged should be page->Mergeable(). + * \param page The parent radix tree page. + */ + void MergePage(RedixPage* page) { + CHECK(page->Mergeable()); + RedixPage* child = page->first_child; + for (int i = 0; i < child->length; ++i) { + (*page)[i + page->length] = (*child)[i]; + } + page->length += child->length; + page->first_child = child->first_child; + for (RedixPage* p = child->first_child; p; p = p->next_sibiling) { + p->parent = page; + } + page->seq_ids = child->seq_ids; + std::vector seq_ids = page->GetLocalSequence(); + for (int64_t id : seq_ids) seq2page[id] = page; + child->seq_ids = nullptr; + radix_page_pool->Free(child); + } + + /*! + * \brief Split a radix tree page at given postition, to accept new sequence. + * e.g. SplitPage([1, 2, 3, 4, 5], 2) = [1, 2, _, _, _] -> [3, 4, 5, _, _]. + * \param page The radix tree page to split. + * \param offset The position to split the radix tree page. + * \return The splitted radix tree page. It can be different from the input radix tree page, as + * there may be implicit radix tree page merge. + */ + RedixPage* SplitPage(RedixPage* page, size_t offset) { + CHECK_LT(offset, page->length); + RedixPage* child = radix_page_pool->Allocate(); + child->parent = page; + child->first_child = page->first_child; + for (RedixPage* p = page->first_child; p; p = p->next_sibiling) { + p->parent = child; + } + page->first_child = child; + for (int i = offset; i < page->length; ++i) { + (*child)[i - offset] = (*page)[i]; + } + child->length = page->length - offset; + page->length = offset; + if (child->Mergeable()) { + // The child page may be mergeable + MergePage(child); + } + if (page->parent && page->parent->Mergeable()) { + // The parent page may be mergeable + page = page->parent; + MergePage(page); + } + return page; + } + + /*! + * \brief Match with given token from a radix tree page, stopping at first mismatch. + * \param page The radix tree page to start matching. + * \param tokens The given tokens to match. + * \param length The length of given tokens. 
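`SplitPage` and `MergePage` are inverse operations over the same invariant: a page is split so that a fork or a mismatched extension can branch mid-prefix, and a parent with exactly one child, no sequence IDs, and enough spare capacity is folded back into one page. The toy below replays the example from the comments (`[1, 2, 3, 4, 5]` split at offset 2) with plain vectors; `ToyPage` and its capacity are illustrative, and a single child pointer stands in for the real sibling list.

```cpp
// Toy replay of SplitPage / MergePage on std::vector-backed "pages".
#include <cassert>
#include <cstdint>
#include <vector>

struct ToyPage {
  std::vector<int64_t> tokens;  // the stored prefix slice
  ToyPage* child = nullptr;     // single child is enough for this sketch
  size_t capacity = 5;
};

// Split([1,2,3,4,5], 2)  =>  [1,2] -> [3,4,5]
void Split(ToyPage* page, size_t offset, ToyPage* fresh_child) {
  assert(offset < page->tokens.size());
  fresh_child->tokens.assign(page->tokens.begin() + offset, page->tokens.end());
  fresh_child->child = page->child;  // the new page inherits the old children
  page->tokens.resize(offset);
  page->child = fresh_child;
}

// Merge([1,2] -> [3,4,5])  =>  [1,2,3,4,5], freeing the child slot
void Merge(ToyPage* page) {
  ToyPage* child = page->child;
  assert(child != nullptr);
  assert(page->tokens.size() + child->tokens.size() <= page->capacity);
  page->tokens.insert(page->tokens.end(), child->tokens.begin(), child->tokens.end());
  page->child = child->child;
}

int main() {
  ToyPage root, spare;
  root.tokens = {1, 2, 3, 4, 5};
  Split(&root, 2, &spare);  // root: [1,2], spare: [3,4,5]
  Merge(&root);             // root: [1,2,3,4,5] again
  assert(root.tokens.size() == 5);
  return 0;
}
```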
+ */ + std::tuple MatchSequence(RedixPage* page, const int64_t* tokens, + size_t length) { + size_t offset = 0; + while (offset < length) { + if (RedixPage* child = page->FindChild(tokens[offset])) { + // If child page starts with offset-th token, common prefix at least ends with child page + size_t matched_offset = child->MatchPrefix(tokens + offset, length - offset); + offset += matched_offset; + if (matched_offset < child->length) { + // Common prefix ends within child page + return std::make_tuple(child, offset, matched_offset); + } + page = child; + } else { + // No child page starts with offset-th token, common prefix ends with current page + return std::make_tuple(page, offset, page->length); + } + } + return std::make_tuple(page, length, page->length); + } +}; + +TVM_REGISTER_OBJECT_TYPE(PagedRadixTreeImpl); + +PagedRadixTree::PagedRadixTree(size_t num_pages, size_t page_size, size_t num_seqs) { + data_ = std::move(make_object(num_pages, page_size, num_pages)); +} + +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTree") + .set_body_typed([](uint64_t num_pages, uint64_t page_size, uint64_t num_seqs) { + return PagedRadixTree(num_pages, page_size, num_seqs); + }); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeMatchPrefix") + .set_body_typed([](PagedRadixTree paged_radix_tree, IntTuple tokens) { + auto [offset, seq_ids] = paged_radix_tree->MatchPrefix(tokens); + seq_ids.insert(seq_ids.begin(), offset); + return IntTuple(seq_ids); + }); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeExtendSequence") + .set_body_method(&PagedRadixTreeObj::ExtendSequence); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeForkSequence") + .set_body_typed([](PagedRadixTree paged_radix_tree, int64_t seq_id, int64_t parent_seq_id, + uint64_t forked_offset) { + paged_radix_tree->ForkSequence(seq_id, parent_seq_id, forked_offset); + }); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeAddSequence") + .set_body_method(&PagedRadixTreeObj::AddSequence); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeRemoveSequence") + .set_body_method(&PagedRadixTreeObj::RemoveSequence); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeGetSequence") + .set_body_method(&PagedRadixTreeObj::GetSequence); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeGetSequenceLength") + .set_body_typed([](PagedRadixTree paged_radix_tree, int64_t seq_id) { + return (int64_t)paged_radix_tree->GetSequenceLength(seq_id); + }); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeFreeCapacity") + .set_body_typed([](PagedRadixTree paged_radix_tree) { + return (int64_t)paged_radix_tree->FreeCapacity(); + }); +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/radix_tree.h b/cpp/serve/radix_tree.h new file mode 100644 index 0000000000..ed831c17b1 --- /dev/null +++ b/cpp/serve/radix_tree.h @@ -0,0 +1,110 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/radix_tree.h + */ +#ifndef MLC_LLM_SERVE_RADIX_TREE_H_ +#define MLC_LLM_SERVE_RADIX_TREE_H_ +#include +#include + +#include +#include + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! + * \brief The paged radix tree data structure. + */ +class PagedRadixTreeObj : public Object { + public: + /*! + * \brief Get a sequence's all tokens. + * \param seq_id The sequence ID for index. + * \return The sequence tokens. + * \throw Error if sequence ID is not valid. + */ + virtual IntTuple GetSequence(int64_t seq_id) = 0; + + /*! + * \brief Get all sequences with longest common prefix with give prefix tokens. 
+ * \param tokens The prefix tokens for reference. + * \return The pair of matched prefix length and the array of matched sequences indices. + */ + virtual std::pair> MatchPrefix(IntTuple tokens) = 0; + + /*! + * \brief Get a sequence's length. + * \param seq_id The sequence ID for index. + * \return The sequence length. + * \throw Error if sequence ID is not valid. + */ + virtual size_t GetSequenceLength(int64_t seq_id) = 0; + + /*! + * \brief Fork a sequence from parent sequence at given position. + * \param seq_id The new sequence ID. + * \param parent_seq_id The parent sequence ID to fork from. + * \param forked_offset The position of parent sequence to fork at. + * The valid value is [1, length of forked sequence]. If the position equals the length of forked + * sequence, the new sequence will copy the entire forked sequence. + * \throw Error if sequence ID or + * forked postion is not valid. + */ + virtual void ForkSequence(int64_t seq_id, int64_t parent_seq_id, size_t forked_offset) = 0; + + /*! + * \brief Add an empty sequence at root. + * \param seq_id The new sequence ID. + * \throw Error if sequence ID is not valid. + */ + virtual void AddSequence(int64_t seq_id) = 0; + + /*! + * \brief Extend a sequence with given tokens. + * \param seq_id The sequence ID for index. + * \param tokens The given tokens to extend. + * \throw Error if sequence ID is not valid. + */ + virtual void ExtendSequence(int64_t seq_id, IntTuple tokens) = 0; + + /*! + * \brief Remove a sequence. + * \param seq_id The sequence ID to remove. + * \throw Error if sequence ID is not valid. + */ + virtual void RemoveSequence(int64_t seq_id) = 0; + + /*! + * \brief Get the remaining token capacity of the paged radix tree. + * \return The the remaining token capacity of the paged radix tree. + */ + virtual size_t FreeCapacity() = 0; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "mlc.serve.PagedRadixTree"; + TVM_DECLARE_BASE_OBJECT_INFO(PagedRadixTreeObj, Object) +}; + +TVM_REGISTER_OBJECT_TYPE(PagedRadixTreeObj); + +class PagedRadixTree : public ObjectRef { + public: + /*! + * \brief Constructor of paged radix tree. + * \param num_pages The number of radix tree pages. + * \param page_size The page size of each radix tree page. + * \param num_seqs The maximum number of sequence ID. + */ + PagedRadixTree(size_t num_pages, size_t page_size, size_t num_seqs); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PagedRadixTree, ObjectRef, PagedRadixTreeObj); +}; +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_RADIX_TREE_H_ diff --git a/cpp/serve/sampler/cpu_sampler.cc b/cpp/serve/sampler/cpu_sampler.cc index 02b7e2a81d..98080c979d 100644 --- a/cpp/serve/sampler/cpu_sampler.cc +++ b/cpp/serve/sampler/cpu_sampler.cc @@ -8,6 +8,7 @@ #include #include +#include #include #include "../../random.h" @@ -43,12 +44,7 @@ TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_o ICHECK(prob.IsContiguous()); ICHECK(prob.DataType() == DataType::Float(32)); - - if (prob->device.device_type != kDLCPU) { - prob = prob.CopyTo(DLDevice{kDLCPU, 0}); - } - - ICHECK(prob->device.device_type == kDLCPU); + ICHECK_EQ(prob->device.device_type, DLDeviceType::kDLCPU); int64_t ndata = prob->shape[prob->ndim - 1]; const float* __restrict p_prob = @@ -186,6 +182,98 @@ TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_o return {sampled_index.second, sampled_index.first}; } +/*! 
+ * \brief Renormalize the probability distribution by the top p value. + * \param prob The input batch of probability distributions. + * \param unit_offset The offset specifying which distribution to output + * \param top_p The top p value for renormalization. + * \param eps A small epsilon value for comparison stability. + */ +void RenormalizeProbByTopP(NDArray prob, int unit_offset, double top_p, double eps) { + // prob: (*, v) + // The prob array may have arbitrary ndim and shape. + // The last dimension corresponds to the prob distribution size. + // We use the `unit_offset` parameter to determine which slice + // of the prob array we will renormalize. + ICHECK(prob.IsContiguous()); + ICHECK(prob.DataType() == DataType::Float(32)); + ICHECK_EQ(prob->device.device_type, DLDeviceType::kDLCPU); + + int vocab_size = prob->shape[prob->ndim - 1]; + float* __restrict p_prob = + static_cast(__builtin_assume_aligned(prob->data, 4)) + (unit_offset * vocab_size); + + // We manually choice the cutoff values of "top_p / 256" and "top_p / 8192". + // In most of the cases, only one round is needed. + std::vector cutoff_values{top_p / 256, top_p / 8192, 0.0f}; + + // Create the upper partition vector and the lower partition rolling vectors. + std::vector upper_partition; + std::vector lower_partitions[2]; + upper_partition.reserve(vocab_size); + lower_partitions[0].reserve(vocab_size); + lower_partitions[1].reserve(vocab_size); + float upper_partition_sum = 0.0; + for (int round = 0; round < static_cast(cutoff_values.size()); ++round) { + const float* lower_partition_begin; + const float* lower_partition_end; + if (round == 0) { + lower_partition_begin = p_prob; + lower_partition_end = p_prob + vocab_size; + } else { + int idx = (round - 1) & 1; + lower_partition_begin = lower_partitions[idx].data(); + lower_partition_end = lower_partitions[idx].data() + lower_partitions[idx].size(); + } + + // - Partition the last round lower partition into upper and lower + // based on the new cutoff value. + std::vector& lower_partition = lower_partitions[round & 1]; + lower_partition.clear(); + for (const float* ptr = lower_partition_begin; ptr != lower_partition_end; ++ptr) { + if (*ptr >= cutoff_values[round]) { + upper_partition.push_back(*ptr); + upper_partition_sum += *ptr; + } else { + lower_partition.push_back(*ptr); + } + } + // - If the upper partition sum is at least top p, exit the loop. + if (upper_partition_sum >= top_p - eps) { + break; + } + } + + // - Sort the upper partition in descending order. + std::sort(upper_partition.begin(), upper_partition.end(), std::greater<>()); + // - Find the top p boundary prob value. + float boundary_value = -1.0; + upper_partition_sum = 0.0; + for (float upper_value : upper_partition) { + upper_partition_sum += upper_value; + if (upper_partition_sum >= top_p - eps) { + boundary_value = upper_value; + break; + } + } + // - Mask all values smaller than the boundary to 0. + float renormalize_sum = 0.0; + std::vector upper_partition_indices; + upper_partition_indices.reserve(vocab_size); + for (int i = 0; i < vocab_size; ++i) { + if (p_prob[i] >= boundary_value) { + upper_partition_indices.push_back(i); + renormalize_sum += p_prob[i]; + } else { + p_prob[i] = 0.0; + } + } + // - Renormalize. + for (int idx : upper_partition_indices) { + p_prob[idx] /= renormalize_sum; + } +} + namespace detail { /*! \brief Implementation of getting top probs on CPU. 
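The partition-based renormalization above avoids a full sort in the common case by trying progressively smaller cutoffs (`top_p / 256`, then `top_p / 8192`, then 0). A plain sort-based version makes the intended result easy to state and to test against on small inputs; the sketch below is that reference, not the production path, and the default epsilon is an assumption.

```cpp
// Sort-based reference for top-p renormalization: keep the smallest set of
// largest probabilities whose sum reaches top_p, zero the rest, renormalize.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

void ReferenceRenormalizeTopP(std::vector<float>& prob, double top_p, double eps = 1e-5) {
  std::vector<float> sorted(prob);
  std::sort(sorted.begin(), sorted.end(), std::greater<float>());
  // Walk the sorted probs until the running mass reaches top_p; the last
  // value taken is the boundary.
  float boundary = 0.0f;
  double running_sum = 0.0;
  for (float v : sorted) {
    running_sum += v;
    if (running_sum >= top_p - eps) {
      boundary = v;
      break;
    }
  }
  // Keep everything at or above the boundary, zero the rest, renormalize.
  double kept_sum = 0.0;
  for (float& v : prob) {
    if (v >= boundary) kept_sum += v; else v = 0.0f;
  }
  for (float& v : prob) v = static_cast<float>(v / kept_sum);
}

int main() {
  std::vector<float> prob = {0.5f, 0.3f, 0.15f, 0.05f};
  ReferenceRenormalizeTopP(prob, /*top_p=*/0.8);
  for (float v : prob) std::printf("%.3f ", v);  // 0.625 0.375 0.000 0.000
  std::printf("\n");
  return 0;
}
```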
*/ @@ -266,68 +354,87 @@ class CPUSampler : public SamplerObj { } } - std::vector BatchSampleTokens(NDArray probs_on_device, // - const std::vector& sample_indices, // - const Array& request_ids, // - const Array& generation_cfg, // - const std::vector& rngs, // - std::vector* output_prob_dist) final { + NDArray BatchRenormalizeProbsByTopP(NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg) final { // probs_on_device: (n, v) - RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); CHECK_EQ(probs_on_device->ndim, 2); // - Copy probs to CPU RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); - NDArray probs_host = CopyProbsToCPU(probs_on_device); + NDArray probs_on_host = CopyProbsToCPU(probs_on_device); RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); - - // - Sample tokens from probabilities. - int n = request_ids.size(); - ICHECK_EQ(generation_cfg.size(), n); - ICHECK_EQ(rngs.size(), n); - - std::vector sample_results; - sample_results.resize(n); - if (output_prob_dist) { - output_prob_dist->resize(n); + int num_samples = sample_indices.size(); + int num_probs = probs_on_device->shape[0]; + int vocab_size = probs_on_device->shape[1]; + ICHECK_EQ(request_ids.size(), num_samples); + ICHECK_EQ(generation_cfg.size(), num_samples); + + std::vector top_p_indices; + std::vector top_p_values; + for (int i = 0; i < num_samples; ++i) { + if (top_p_indices.empty() || top_p_indices.back() != sample_indices[i]) { + top_p_indices.push_back(sample_indices[i]); + top_p_values.push_back(generation_cfg[i]->top_p); + } else { + CHECK(fabs(top_p_values.back() - generation_cfg[i]->top_p) < eps_) + << "Sampler requires the top_p values for each prob distribution are the same."; + } + } + if (top_p_indices.empty()) { + // Return if no top p needs to apply. + return probs_on_host; } tvm::runtime::parallel_for_with_threading_backend( - [this, &sample_results, &probs_host, &generation_cfg, &rngs, &request_ids, sample_indices, - output_prob_dist](int i) { - RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); - // Sample top p from probability. - sample_results[i].sampled_token_id = SampleTopPFromProb( - probs_host, i, sample_indices[i], - generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, - rngs[i]->GetRandomNumber(), output_prob_dist); - if (output_prob_dist == nullptr) { - // When `output_prob_dist` is not nullptr, it means right now - // we are sampling for a small model in speculation, in which - // case we do not need to get the top probs. 
- sample_results[i].top_prob_tokens = - ComputeTopProbs(probs_host, i, generation_cfg[i]->top_logprobs); - } - RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish sample token"); + [this, &probs_on_host, &request_ids, &top_p_indices, &top_p_values](int i) { + RECORD_EVENT(this->trace_recorder_, request_ids[i], "start renormalize by top p"); + RenormalizeProbByTopP(probs_on_host, top_p_indices[i], top_p_values[i], eps_); + RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish renormalize by top p"); }, - 0, n); - RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); - return sample_results; + 0, static_cast(top_p_indices.size())); + + return probs_on_host; } - std::vector> BatchVerifyDraftTokens( - NDArray probs_on_device, const Array& request_ids, - const std::vector& cum_verify_lengths, const Array& generation_cfg, - const std::vector& rngs, - const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) final { + std::vector BatchSampleTokensWithProbBeforeTopP( + NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs) final { // probs_on_device: (n, v) - RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); CHECK_EQ(probs_on_device->ndim, 2); // - Copy probs to CPU RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); - NDArray probs_host = CopyProbsToCPU(probs_on_device); + NDArray probs_on_host = CopyProbsToCPU(probs_on_device); RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); + return BatchSampleTokensImpl(probs_on_host, sample_indices, request_ids, generation_cfg, rngs, + /*top_p_applied=*/false); + } + + std::vector BatchSampleTokensWithProbAfterTopP( + NDArray probs_on_host, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + std::vector* output_prob_dist) final { + return BatchSampleTokensImpl(probs_on_host, sample_indices, request_ids, generation_cfg, rngs, + /*top_p_applied=*/true, output_prob_dist); + } + + std::vector> BatchVerifyDraftTokensWithProbAfterTopP( + NDArray probs_on_host, const Array& request_ids, + const std::vector& cum_verify_lengths, const Array& generation_cfg, + const std::vector& rngs, + const std::vector>& draft_output_tokens, + const std::vector>& draft_output_prob_dist) final { + // probs_on_host: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); + CHECK_EQ(probs_on_host->ndim, 2); + int num_sequence = static_cast(cum_verify_lengths.size()) - 1; CHECK_EQ(rngs.size(), num_sequence); CHECK_EQ(draft_output_tokens.size(), num_sequence); @@ -337,8 +444,8 @@ class CPUSampler : public SamplerObj { sample_results.resize(num_sequence); float* __restrict global_p_probs = - static_cast(__builtin_assume_aligned(probs_host->data, 4)); - int vocab_size = probs_host->shape[1]; + static_cast(__builtin_assume_aligned(probs_on_host->data, 4)); + int vocab_size = probs_on_host->shape[1]; tvm::runtime::parallel_for_with_threading_backend( [&](int i) { @@ -355,7 +462,7 @@ class CPUSampler : public SamplerObj { if (p_value >= q_value) { sample_results[i].push_back( SampleResult{{cur_token, p_value}, - ComputeTopProbs(probs_host, verify_start + cur_token_idx, + ComputeTopProbs(probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs)}); continue; } @@ -363,7 +470,7 @@ class CPUSampler : public SamplerObj { if (r < p_value / (q_value + 
eps_)) { sample_results[i].push_back( SampleResult{{cur_token, p_value}, - ComputeTopProbs(probs_host, verify_start + cur_token_idx, + ComputeTopProbs(probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs)}); continue; } @@ -388,11 +495,10 @@ class CPUSampler : public SamplerObj { // sample a new token from the new distribution SampleResult sample_result; sample_result.sampled_token_id = SampleTopPFromProb( - probs_host, verify_start + cur_token_idx, verify_start + cur_token_idx, - generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, - rngs[i]->GetRandomNumber()); + probs_on_host, verify_start + cur_token_idx, verify_start + cur_token_idx, + /*top_p=*/1.0f, rngs[i]->GetRandomNumber()); sample_result.top_prob_tokens = ComputeTopProbs( - probs_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); + probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); sample_results[i].push_back(sample_result); break; } @@ -403,11 +509,10 @@ class CPUSampler : public SamplerObj { SampleResult sample_result; // sample a new token from the original distribution sample_result.sampled_token_id = SampleTopPFromProb( - probs_host, verify_start + cur_token_idx, verify_start + cur_token_idx, - generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, - rngs[i]->GetRandomNumber()); + probs_on_host, verify_start + cur_token_idx, verify_start + cur_token_idx, + /*top_p=*/1.0f, rngs[i]->GetRandomNumber()); sample_result.top_prob_tokens = ComputeTopProbs( - probs_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); + probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); sample_results[i].push_back(sample_result); } }, @@ -417,6 +522,56 @@ class CPUSampler : public SamplerObj { } private: + std::vector BatchSampleTokensImpl( + NDArray probs_on_host, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + bool top_p_applied, // + std::vector* output_prob_dist = nullptr) { + // probs_on_host: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); + ICHECK_EQ(probs_on_host->ndim, 2); + ICHECK_EQ(probs_on_host->device.device_type, DLDeviceType::kDLCPU); + + // - Sample tokens from probabilities. + int n = request_ids.size(); + ICHECK_EQ(generation_cfg.size(), n); + ICHECK_EQ(rngs.size(), n); + + std::vector sample_results; + sample_results.resize(n); + if (output_prob_dist) { + output_prob_dist->resize(n); + } + + tvm::runtime::parallel_for_with_threading_backend( + [this, &sample_results, &probs_on_host, &generation_cfg, &rngs, &request_ids, top_p_applied, + sample_indices, output_prob_dist](int i) { + RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); + // Sample top p from probability. + double top_p = + top_p_applied + ? 1.0f + : (generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p); + sample_results[i].sampled_token_id = + SampleTopPFromProb(probs_on_host, i, sample_indices[i], top_p, + rngs[i]->GetRandomNumber(), output_prob_dist); + if (output_prob_dist == nullptr) { + // When `output_prob_dist` is not nullptr, it means right now + // we are sampling for a small model in speculation, in which + // case we do not need to get the top probs. 
+ sample_results[i].top_prob_tokens = + ComputeTopProbs(probs_on_host, i, generation_cfg[i]->top_logprobs); + } + RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish sample token"); + }, + 0, n); + RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); + return sample_results; + } + /*! \brief Copy prob distributions from device to CPU. */ NDArray CopyProbsToCPU(NDArray probs_on_device) { // probs_on_device: (n, v) diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index b376523dac..58a27c24f7 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -43,12 +43,17 @@ class GPUSampler : public SamplerObj { gpu_argsort_probs_func_(ft->gpu_argsort_probs_func_), gpu_sample_with_top_p_func_(ft->gpu_sample_with_top_p_func_), gpu_sampler_take_probs_func_(ft->gpu_sampler_take_probs_func_), + gpu_verify_draft_tokens_func_(ft->gpu_verify_draft_tokens_func_), + gpu_renormalize_by_top_p_func_(ft->gpu_renormalize_by_top_p_func_), trace_recorder_(std::move(trace_recorder)) { ICHECK(gpu_multinomial_from_uniform_func_.defined()); ICHECK(gpu_argsort_probs_func_.defined()); ICHECK(gpu_sample_with_top_p_func_.defined()); ICHECK(gpu_sampler_take_probs_func_.defined()); + flashinfer_multinomial_sample_func_ = + Registry::Get("flashinfer.sampling.parallel_sampling_from_prob"); + DLDevice device_cpu{DLDeviceType::kDLCPU, /*device_id=*/0}; // We support at most 5 top prob results for each sequence. // Initialize auxiliary arrays on CPU. @@ -56,6 +61,10 @@ class GPUSampler : public SamplerObj { sample_indices_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); top_p_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu); top_prob_offsets_host_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device_cpu); + draft_tokens_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); + token_tree_first_child_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); + token_tree_next_sibling_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); + token_tree_parent_ptr_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); sampled_token_ids_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); sampled_probs_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu); top_prob_probs_host_ = NDArray::Empty({max_num_sample * 5}, dtype_f32_, device_cpu); @@ -65,6 +74,12 @@ class GPUSampler : public SamplerObj { sample_indices_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); top_p_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device); top_prob_offsets_device_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device); + draft_probs_device_ = NDArray::Empty({max_num_sample, vocab_size}, dtype_f32_, device); + draft_tokens_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + token_tree_first_child_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + token_tree_next_sibling_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + token_tree_parent_ptr_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + sampled_token_ids_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); // If the device is CUDA/ROCm, we create a standalone copy stream, in // purpose to hide the latency of auxiliary stream copy. 
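All of the host and device buffers above are sized once for `max_num_sample` at construction and then narrowed per step with `CreateView`, so the steady-state sampling path never allocates. The fragment below sketches that pattern with TVM's `NDArray`; the shapes, the CUDA device id, and the blocking copy are illustrative stand-ins for the stream-aware copy helper used in the sampler.

```cpp
// "Preallocate once, view per batch" sketch for the sampler's auxiliary arrays.
#include <tvm/runtime/ndarray.h>

using tvm::runtime::NDArray;

void AuxBufferSketch() {
  DLDevice cpu{kDLCPU, 0};
  DLDevice gpu{kDLCUDA, 0};
  DLDataType i32{kDLInt, 32, 1};

  constexpr int kMaxNumSample = 256;
  // Allocated once at construction time ...
  NDArray draft_tokens_host = NDArray::Empty({kMaxNumSample}, i32, cpu);
  NDArray draft_tokens_device = NDArray::Empty({kMaxNumSample}, i32, gpu);

  // ... and narrowed to the live batch with zero-copy views every step.
  int num_nodes = 17;
  NDArray host_view = draft_tokens_host.CreateView({num_nodes}, i32);
  NDArray device_view = draft_tokens_device.CreateView({num_nodes}, i32);

  // Fill the host view, then issue one copy to the device (the real code
  // does this asynchronously on a dedicated copy stream).
  static_cast<int32_t*>(host_view->data)[0] = 42;
  device_view.CopyFrom(host_view);
}
```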
@@ -83,20 +98,237 @@ class GPUSampler : public SamplerObj { } } - std::vector BatchSampleTokens(NDArray probs_on_device, // - const std::vector& sample_indices, // - const Array& request_ids, // - const Array& generation_cfg, // - const std::vector& rngs, // - std::vector* output_prob_dist) final { - NVTXScopedRange nvtx_scope("BatchSampleTokens"); + NDArray BatchRenormalizeProbsByTopP(NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg) final { + NVTXScopedRange nvtx_scope("BatchRenormalizeProbsByTopP"); + // probs_on_device: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start renormalization by top p"); + CHECK_EQ(probs_on_device->ndim, 2); + int num_samples = sample_indices.size(); + int num_probs = probs_on_device->shape[0]; + int vocab_size = probs_on_device->shape[1]; + ICHECK_LE(num_probs, max_num_sample_); + ICHECK_EQ(request_ids.size(), num_samples); + ICHECK_EQ(generation_cfg.size(), num_samples); + + // - Check if there is need for applying top p. + bool need_top_p = CheckTopP(generation_cfg, sample_indices, num_probs, num_samples, vocab_size); + if (!need_top_p) { + return probs_on_device; + } + + // - Argsort the probability. + Array argsort_results = gpu_argsort_probs_func_(probs_on_device); + ICHECK_EQ(argsort_results.size(), 2); + NDArray sorted_probs_on_device = argsort_results[0]; + NDArray sorted_indices_on_device = argsort_results[1]; + + // - Copy auxiliary array for top-p. + NDArray top_p_host = top_p_host_.CreateView({num_probs}, dtype_f32_); + NDArray top_p_device = top_p_device_.CreateView({num_probs}, dtype_f32_); + CopyArray(/*src=*/top_p_host, /*dst=*/top_p_device, copy_stream_); + SyncCopyStream(device_, compute_stream_, copy_stream_); + + // - Renormalize the prob with top p. 
+ NDArray renormed_probs_on_device = + gpu_renormalize_by_top_p_func_(probs_on_device, sorted_probs_on_device, top_p_device); + + RECORD_EVENT(trace_recorder_, request_ids, "finish renormalization by top p"); + return renormed_probs_on_device; + } + + std::vector BatchSampleTokensWithProbBeforeTopP( + NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs) final { + NVTXScopedRange nvtx_scope("BatchSampleTokensWithProbBeforeTopP"); + return BatchSampleTokensImpl(std::move(probs_on_device), sample_indices, request_ids, + generation_cfg, rngs, /*top_p_applied=*/false); + } + + std::vector BatchSampleTokensWithProbAfterTopP( + NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + std::vector* output_prob_dist = nullptr) final { + NVTXScopedRange nvtx_scope("BatchSampleTokensWithProbAfterTopP"); + return BatchSampleTokensImpl(std::move(probs_on_device), sample_indices, request_ids, + generation_cfg, rngs, /*top_p_applied=*/true, output_prob_dist); + } + + std::vector> BatchVerifyDraftTokensWithProbAfterTopP( + NDArray probs_on_device, const Array& request_ids, + const std::vector& cum_verify_lengths, const Array& generation_cfg, + const std::vector& rngs, + const std::vector>& draft_output_tokens, + const std::vector>& draft_output_prob_dist) final { + NVTXScopedRange nvtx_scope("BatchVerifyDraftTokensWithProbAfterTopP"); + std::vector> sample_results; + // probs_on_device: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); + CHECK_EQ(probs_on_device->ndim, 2); + + int num_sequence = static_cast(cum_verify_lengths.size()) - 1; + CHECK_EQ(rngs.size(), num_sequence); + CHECK_EQ(draft_output_tokens.size(), num_sequence); + CHECK_EQ(draft_output_prob_dist.size(), num_sequence); + sample_results.resize(num_sequence); + + int num_nodes = cum_verify_lengths.back(); + NDArray uniform_samples_host = uniform_samples_host_.CreateView({num_nodes}, dtype_f32_); + NDArray uniform_samples_device = uniform_samples_device_.CreateView({num_nodes}, dtype_f32_); + NDArray draft_probs_device = + draft_probs_device_.CreateView({num_nodes, vocab_size_}, dtype_f32_); + NDArray draft_tokens_host = draft_tokens_host_.CreateView({num_nodes}, dtype_i32_); + NDArray draft_tokens_device = draft_tokens_device_.CreateView({num_nodes}, dtype_i32_); + + // Concat draft prob distributions to a ragged tensor (num_nodes, vocab_size) + for (int i = 0; i < num_sequence; i++) { + const std::vector& draft_output_tokens_i = draft_output_tokens[i]; + const std::vector& draft_output_prob_dist_i = draft_output_prob_dist[i]; + int start = cum_verify_lengths[i]; + int end = cum_verify_lengths[i + 1]; + // start/end is the range of the sequence i in probs_on_device, which includes the prob dist + // of the draft tokens and the last committed token + ICHECK_EQ(draft_output_tokens_i.size() + 1, end - start); + ICHECK_EQ(draft_output_prob_dist_i.size() + 1, end - start); + for (int j = 0; j < end - start - 1; j++) { + // Copy prob dist + ICHECK_EQ(draft_probs_device->dtype.bits, 32); + float* p_draft_probs = + static_cast(draft_probs_device->data) + + (j + start + 1) * + vocab_size_; // shift by one, q of the last committed token is undefined + // Copy sampled token id + draft_output_prob_dist_i[j].CopyToBytes(p_draft_probs, vocab_size_ * sizeof(float)); + *(static_cast(draft_tokens_host->data) + j + 
start + 1) = + draft_output_tokens_i[j].sampled_token_id.first; + } + } + CopyArray(draft_tokens_host, draft_tokens_device, copy_stream_); + + float* p_uniform_samples = static_cast(uniform_samples_host->data); + for (int i = 0; i < num_sequence; ++i) { + int start = cum_verify_lengths[i]; + int end = cum_verify_lengths[i + 1]; + for (int j = start; j < end; j++) { + p_uniform_samples[j] = rngs[i]->GetRandomNumber(); + } + } + CopyArray(uniform_samples_host, uniform_samples_device, copy_stream_); + + NDArray token_tree_first_child_host = + token_tree_first_child_host_.CreateView({num_nodes}, dtype_i32_); + NDArray token_tree_first_child_device = + token_tree_first_child_device_.CreateView({num_nodes}, dtype_i32_); + NDArray token_tree_next_sibling_host = + token_tree_next_sibling_host_.CreateView({num_nodes}, dtype_i32_); + NDArray token_tree_next_sibling_device = + token_tree_next_sibling_device_.CreateView({num_nodes}, dtype_i32_); + NDArray token_tree_parent_ptr_host = + token_tree_parent_ptr_host_.CreateView({num_sequence}, dtype_i32_); + NDArray token_tree_parent_ptr_device = + token_tree_parent_ptr_device_.CreateView({num_sequence}, dtype_i32_); + std::vector token_tree_child_to_parent(/*n=*/num_nodes); + + // Build the tree structure on CPU + for (int i = 0; i < num_sequence; i++) { + // Assuming no tree structure for now + int start = cum_verify_lengths[i]; + int end = cum_verify_lengths[i + 1]; + ICHECK_GE(end - start, 2); + token_tree_child_to_parent[start] = -1; // root has no parent + for (int j = 0; j < end - start; j++) { + int cur_node = j + start; + int child_node = j + 1 >= end - start ? -1 : cur_node + 1; + static_cast(token_tree_first_child_host->data)[cur_node] = child_node; + if (child_node != -1) { + token_tree_child_to_parent[child_node] = cur_node; + } + static_cast(token_tree_next_sibling_host->data)[cur_node] = -1; + } + static_cast(token_tree_parent_ptr_host->data)[i] = start; // point to the root + } + // Copy token tree structure to GPU + CopyArray(token_tree_first_child_host, token_tree_first_child_device, copy_stream_); + CopyArray(token_tree_next_sibling_host, token_tree_next_sibling_device, copy_stream_); + CopyArray(token_tree_parent_ptr_host, token_tree_parent_ptr_device, copy_stream_); + + SyncCopyStream(device_, compute_stream_, copy_stream_); + + gpu_verify_draft_tokens_func_(draft_probs_device, draft_tokens_device, probs_on_device, + token_tree_first_child_device, token_tree_next_sibling_device, + uniform_samples_device, token_tree_parent_ptr_device); + + CopyArray(token_tree_parent_ptr_device, token_tree_parent_ptr_host, compute_stream_); + TVMSynchronize(device_.device_type, device_.device_id, compute_stream_); + + std::vector sample_indices; + + for (int i = 0; i < num_sequence; i++) { + int start = cum_verify_lengths[i]; + int end = cum_verify_lengths[i + 1]; + int last_accepted = static_cast(token_tree_parent_ptr_host->data)[i]; + int num_accepted = 0; + for (int cur_node = last_accepted; cur_node != start; + cur_node = token_tree_child_to_parent[cur_node]) { + sample_results[i].push_back(draft_output_tokens[i][cur_node - start - 1]); + num_accepted++; + } + std::reverse(sample_results[i].rbegin(), sample_results[i].rbegin() + num_accepted); + sample_indices.push_back(last_accepted); + } + std::vector additional_sample_result; + additional_sample_result = this->BatchSampleTokensWithProbAfterTopP( + probs_on_device, sample_indices, request_ids, generation_cfg, rngs); + ICHECK_EQ(additional_sample_result.size(), num_sequence); + for (int i = 0; i < 
num_sequence; i++) { + sample_results[i].push_back(additional_sample_result[i]); + } + + RECORD_EVENT(trace_recorder_, request_ids, "finish draft verification"); + return sample_results; + } + + private: + std::vector BatchSampleTokensImpl( + NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + bool top_p_applied, // + std::vector* output_prob_dist = nullptr) { // probs_on_device: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); - CHECK(output_prob_dist == nullptr) << "GPU sampler does not support collecting output probs."; CHECK_EQ(probs_on_device->ndim, 2); + CHECK_EQ(probs_on_device->device.device_id, device_.device_id); + CHECK_EQ(probs_on_device->device.device_type, device_.device_type); int num_samples = sample_indices.size(); int num_probs = probs_on_device->shape[0]; int vocab_size = probs_on_device->shape[1]; + if (output_prob_dist != nullptr) { + ICHECK(output_prob_dist->empty()); + output_prob_dist->reserve(num_samples); + for (int i = 0; i < num_samples; ++i) { + NDArray prob_dist = NDArray::Empty({vocab_size}, dtype_f32_, device_); + float* p_prob = static_cast(probs_on_device->data) + sample_indices[i] * vocab_size; + prob_dist.CopyFromBytes(p_prob, vocab_size * sizeof(float)); + output_prob_dist->push_back(std::move(prob_dist)); + } + } + if (num_samples == 0) { + // This synchronization is necessary for making sure that this round + // of model forward is finished. + TVMSynchronize(device_.device_type, device_.device_id, compute_stream_); + return {}; + } ICHECK_EQ(request_ids.size(), num_samples); ICHECK_EQ(generation_cfg.size(), num_samples); ICHECK_EQ(rngs.size(), num_samples); @@ -105,7 +337,8 @@ class GPUSampler : public SamplerObj { // we apply chunking to support large `num_samples`. 
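The verification loops above (both the CPU path and the GPU token-tree variant) implement the standard speculative-decoding acceptance rule: a draft token `t` is accepted with probability `min(1, p[t] / q[t])`, and the first rejection is replaced by a sample from the renormalized residual `max(p - q, 0)`. The single-sequence reference below states that rule in isolation; it assumes dense `std::vector` distributions and skips the extra bonus token sampled after a fully accepted draft.

```cpp
// Single-sequence reference of the draft-verification (accept/reject) rule.
#include <algorithm>
#include <random>
#include <vector>

std::vector<int> VerifyDraft(const std::vector<std::vector<double>>& p,  // target probs per position
                             const std::vector<std::vector<double>>& q,  // draft probs per position
                             const std::vector<int>& draft_tokens, std::mt19937& rng) {
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  std::vector<int> accepted;
  for (size_t i = 0; i < draft_tokens.size(); ++i) {
    int t = draft_tokens[i];
    double accept_prob = q[i][t] > 0 ? std::min(1.0, p[i][t] / q[i][t]) : 0.0;
    if (uniform(rng) < accept_prob) {
      accepted.push_back(t);  // the draft token survives verification
      continue;
    }
    // First rejection: resample from the residual max(p - q, 0) and stop.
    std::vector<double> residual(p[i].size());
    double sum = 0.0;
    for (size_t v = 0; v < residual.size(); ++v) {
      residual[v] = std::max(p[i][v] - q[i][v], 0.0);
      sum += residual[v];
    }
    double r = uniform(rng) * sum;
    int replacement = static_cast<int>(residual.size()) - 1;
    for (size_t v = 0; v < residual.size(); ++v) {
      r -= residual[v];
      if (r <= 0.0) { replacement = static_cast<int>(v); break; }
    }
    accepted.push_back(replacement);
    break;  // everything after the first rejection is discarded
  }
  return accepted;
}
```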
std::vector sample_results; if (num_samples <= max_num_sample_) { - sample_results = ChunkSampleTokensImpl(probs_on_device, sample_indices, generation_cfg, rngs); + sample_results = ChunkSampleTokensImpl(probs_on_device, sample_indices, generation_cfg, rngs, + top_p_applied); } else { for (int chunk_start = 0; chunk_start < num_samples; chunk_start += max_num_sample_) { int chunk_end = std::min(chunk_start + max_num_sample_, num_samples); @@ -116,7 +349,7 @@ class GPUSampler : public SamplerObj { std::vector rngs_chunk(rngs.begin() + chunk_start, rngs.begin() + chunk_end); std::vector sample_results_chunk = ChunkSampleTokensImpl( - probs_on_device, sample_indices_chunk, generation_cfg_chunk, rngs_chunk); + probs_on_device, sample_indices_chunk, generation_cfg_chunk, rngs_chunk, top_p_applied); sample_results.insert(sample_results.end(), sample_results_chunk.begin(), sample_results_chunk.end()); } @@ -126,20 +359,11 @@ class GPUSampler : public SamplerObj { return sample_results; } - std::vector> BatchVerifyDraftTokens( - NDArray probs_on_device, const Array& request_ids, - const std::vector& cum_verify_lengths, const Array& generation_cfg, - const std::vector& rngs, - const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) final { - LOG(FATAL) << "GPU sampler does not support batch verification for now."; - } - - private: std::vector ChunkSampleTokensImpl(NDArray probs_on_device, // const std::vector& sample_indices, // const Array& generation_cfg, // - const std::vector& rngs) { + const std::vector& rngs, // + bool top_p_applied) { // probs_on_device: (n, v) int num_samples = sample_indices.size(); int num_probs = probs_on_device->shape[0]; @@ -153,11 +377,13 @@ class GPUSampler : public SamplerObj { // - Check if there is need for applying top p or prob values, // so that argsort is needed. bool need_top_p = false; - bool need_prob_values = false; + if (!top_p_applied) { + need_top_p = CheckTopP(generation_cfg, sample_indices, num_probs, num_samples, vocab_size); + } // The indptr array of the number of top probs for each sample. std::vector top_prob_offset_indptr; - CheckTopPAndProbValues(generation_cfg, sample_indices, num_probs, num_samples, vocab_size, - &need_top_p, &need_prob_values, &top_prob_offset_indptr); + bool need_prob_values = CheckProbValues(generation_cfg, sample_indices, num_probs, num_samples, + vocab_size, &top_prob_offset_indptr); // - Sample tokens on GPU, and take out the probability values if needed. std::vector device_arrays = @@ -217,30 +443,39 @@ class GPUSampler : public SamplerObj { return {uniform_samples_device, sample_indices_device}; } - /*! \brief Check if top p and prob values are needed, and collect info when necessary. */ - void CheckTopPAndProbValues(const Array& generation_cfg, - const std::vector& sample_indices, int num_probs, - int num_samples, int vocab_size, bool* need_top_p, - bool* need_prob_values, std::vector* top_prob_offset_indptr) { - top_prob_offset_indptr->reserve(num_samples + 1); - top_prob_offset_indptr->push_back(0); + /*! \brief Check if top p is needed. Update host top p array in place. */ + bool CheckTopP(const Array& generation_cfg, + const std::vector& sample_indices, int num_probs, int num_samples, + int vocab_size) { // Initialize top p values with -1. 
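`CheckTopP` enforces one rule: every sample that reads the same probability row must request the same `top_p`, and device-side top-p filtering is only scheduled when some requested value differs from 1.0. The helper below restates that rule on plain vectors so it can be checked in isolation; the names and the epsilon default are assumptions.

```cpp
// Standalone restatement of the CheckTopP consistency rule.
#include <cassert>
#include <cmath>
#include <vector>

bool CheckTopPSketch(const std::vector<int>& sample_indices,  // which prob row each sample reads
                     const std::vector<double>& top_p,        // requested top_p per sample
                     int num_probs, std::vector<double>* top_p_per_prob, double eps = 1e-5) {
  top_p_per_prob->assign(num_probs, -1.0);  // -1 marks "not seen yet"
  bool need_top_p = false;
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    double& slot = (*top_p_per_prob)[sample_indices[i]];
    if (slot == -1.0) {
      slot = top_p[i];
      need_top_p |= top_p[i] != 1.0;
    } else {
      // Two samples drawing from the same distribution must agree on top_p.
      assert(std::fabs(slot - top_p[i]) < eps);
    }
  }
  return need_top_p;
}
```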
float* p_top_p = static_cast(top_p_host_->data); for (int i = 0; i < num_probs; ++i) { p_top_p[i] = -1.0; } - int* p_top_prob_offsets = static_cast(top_prob_offsets_host_->data); - int num_top_probs = 0; + bool need_top_p = false; for (int i = 0; i < num_samples; ++i) { if (p_top_p[sample_indices[i]] == -1.0) { p_top_p[sample_indices[i]] = generation_cfg[i]->top_p; - *need_top_p |= generation_cfg[i]->top_p != 1.0; + need_top_p |= generation_cfg[i]->top_p != 1.0; } else { CHECK(fabs(p_top_p[sample_indices[i]] - generation_cfg[i]->top_p) < eps_) << "GPU sampler requires the top_p values for each prob distribution are the same."; } + } + return need_top_p; + } - *need_prob_values |= generation_cfg[i]->logprobs; + /*! \brief Check whether prob values are needed, and collect info when necessary. */ + bool CheckProbValues(const Array& generation_cfg, + const std::vector& sample_indices, int num_probs, int num_samples, + int vocab_size, std::vector* top_prob_offset_indptr) { + top_prob_offset_indptr->reserve(num_samples + 1); + top_prob_offset_indptr->push_back(0); + int* p_top_prob_offsets = static_cast(top_prob_offsets_host_->data); + int num_top_probs = 0; + bool need_prob_values = false; + for (int i = 0; i < num_samples; ++i) { + need_prob_values |= generation_cfg[i]->logprobs; for (int j = 0; j < generation_cfg[i]->top_logprobs; ++j) { p_top_prob_offsets[num_top_probs++] = sample_indices[i] * vocab_size + j; } @@ -248,6 +483,7 @@ class GPUSampler : public SamplerObj { generation_cfg[i]->top_logprobs); } ICHECK_EQ(num_top_probs, top_prob_offset_indptr->back()); + return need_prob_values; } /*! \brief Sample tokens on GPU. Take out the probability values when needed. */ @@ -263,8 +499,15 @@ class GPUSampler : public SamplerObj { if (!need_top_p && !need_prob_values) { // - Short path: If top_p and prob values are not needed, we directly sample from multinomial. SyncCopyStream(device_, compute_stream_, copy_stream_); - sampled_token_ids_device = gpu_multinomial_from_uniform_func_( - probs_on_device, uniform_samples_device, sample_indices_device); + if (flashinfer_multinomial_sample_func_ != nullptr) { + sampled_token_ids_device = + sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_); + (*flashinfer_multinomial_sample_func_)(probs_on_device, uniform_samples_device, + sample_indices_device, sampled_token_ids_device); + } else { + sampled_token_ids_device = gpu_multinomial_from_uniform_func_( + probs_on_device, uniform_samples_device, sample_indices_device); + } return {sampled_token_ids_device, sampled_probs_device, top_prob_probs_device, top_prob_indices_device}; } @@ -299,8 +542,15 @@ class GPUSampler : public SamplerObj { uniform_samples_device, sample_indices_device, top_p_device); } else { // - Sample without top_p. - sampled_token_ids_device = gpu_multinomial_from_uniform_func_( - probs_on_device, uniform_samples_device, sample_indices_device); + if (flashinfer_multinomial_sample_func_ != nullptr) { + sampled_token_ids_device = + sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_); + (*flashinfer_multinomial_sample_func_)(probs_on_device, uniform_samples_device, + sample_indices_device, sampled_token_ids_device); + } else { + sampled_token_ids_device = gpu_multinomial_from_uniform_func_( + probs_on_device, uniform_samples_device, sample_indices_device); + } } if (need_prob_values) { @@ -354,7 +604,7 @@ class GPUSampler : public SamplerObj { } // Synchronize for CPU to get the correct array results. 
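The sampling path above prefers an externally registered FlashInfer kernel and silently falls back to the compiled TIR kernel when the registry lookup returns `nullptr`. The fragment below isolates that lookup-or-fallback pattern; the struct and the `Sample` wrapper are illustrative scaffolding, and only the registry name is taken from the diff.

```cpp
// Optional-kernel fallback: prefer a registered FlashInfer function, else use
// the always-available compiled kernel.
#include <utility>

#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

using tvm::runtime::PackedFunc;
using tvm::runtime::Registry;

struct SamplerKernels {
  PackedFunc gpu_multinomial_from_uniform;           // compiled TIR kernel, set elsewhere
  const PackedFunc* flashinfer_sampling = nullptr;   // optional external kernel

  void Init() {
    // Returns nullptr when FlashInfer is not built or not registered.
    flashinfer_sampling = Registry::Get("flashinfer.sampling.parallel_sampling_from_prob");
  }

  template <typename... Args>
  void Sample(Args&&... args) {
    if (flashinfer_sampling != nullptr) {
      (*flashinfer_sampling)(std::forward<Args>(args)...);
    } else {
      gpu_multinomial_from_uniform(std::forward<Args>(args)...);
    }
  }
};
```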
- TVMSynchronize(device_.device_type, device_.device_id, nullptr); + TVMSynchronize(device_.device_type, device_.device_id, compute_stream_); return {sampled_token_ids_host, sampled_probs_host, top_prob_probs_host, top_prob_indices_host}; } @@ -370,11 +620,18 @@ class GPUSampler : public SamplerObj { PackedFunc gpu_argsort_probs_func_; PackedFunc gpu_sample_with_top_p_func_; PackedFunc gpu_sampler_take_probs_func_; + PackedFunc gpu_verify_draft_tokens_func_; + PackedFunc gpu_renormalize_by_top_p_func_; + const PackedFunc* flashinfer_multinomial_sample_func_; // Auxiliary NDArrays on CPU NDArray uniform_samples_host_; NDArray sample_indices_host_; NDArray top_p_host_; NDArray top_prob_offsets_host_; + NDArray draft_tokens_host_; + NDArray token_tree_first_child_host_; + NDArray token_tree_next_sibling_host_; + NDArray token_tree_parent_ptr_host_; NDArray sampled_token_ids_host_; NDArray sampled_probs_host_; NDArray top_prob_probs_host_; @@ -384,6 +641,12 @@ class GPUSampler : public SamplerObj { NDArray sample_indices_device_; NDArray top_p_device_; NDArray top_prob_offsets_device_; + NDArray draft_probs_device_; + NDArray draft_tokens_device_; + NDArray token_tree_first_child_device_; + NDArray token_tree_next_sibling_device_; + NDArray token_tree_parent_ptr_device_; + NDArray sampled_token_ids_device_; // The event trace recorder for requests. */ Optional trace_recorder_; // The device stream for the default computation operations. diff --git a/cpp/serve/sampler/sampler.h b/cpp/serve/sampler/sampler.h index 03d031bdb7..7943231e55 100644 --- a/cpp/serve/sampler/sampler.h +++ b/cpp/serve/sampler/sampler.h @@ -26,14 +26,33 @@ using namespace tvm::runtime; /*! * \brief The base class of runtime sampler. - * Its main function is `BatchSampleTokens`, which takes a batch of + * Its main function is `BatchSampleTokensWithProbBeforeTopP`, which takes a batch of * logits and corresponding configuration, and sample one token * for each instance of the batch. */ class SamplerObj : public Object { public: + /*! + * \brief Renormalize the input batch of probability distributions with top p values. + * \param probs_on_device The batch of prob distributions before normalization. + * \param sample_indices Specifying which request we will sample for + * in i-th output for the sampling later on. + * The output result of the sampling will be as follow: + * result[i] = sample_from(prob_on_device[sample_indices[i],:], generation_config[i])); + * For renormalization, the sample indices are used for determine the top-p grouping. + * \param request_ids The id of each request. + * \param generation_cfg The generation config of each request in the input batch. + * \return The renormalized probability distributions, residing on device + * if the sampler is GPU sampler, or on host if the sampler is CPU sampler. + */ + virtual NDArray BatchRenormalizeProbsByTopP(NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg) = 0; + /*! * \brief Sample tokens from the input batch of prob distribution on device. + * The input prob distributions are not yet applied with top-p. * \param probs_on_device The prob distributions on GPU to sample tokens from. * \param sample_indices Specifying which request we should sample for * in i-th output. The output result is sample as follow: @@ -42,22 +61,46 @@ class SamplerObj : public Object { * \param generation_cfg The generation config of each request * in the input batch. 
* \param rngs The random number generator of each sequence. - * \param output_prob_dist The output probability distribution * \return The batch of sampling results, which contain the sampled token id * and other probability info. */ - virtual std::vector BatchSampleTokens( + virtual std::vector BatchSampleTokensWithProbBeforeTopP( NDArray probs_on_device, // const std::vector& sample_indices, // const Array& request_ids, // const Array& generation_cfg, // + const std::vector& rngs) = 0; + + /*! + * \brief Sample tokens from the input batch of prob distribution on device. + * The input prob distributions are already applied with top-p. + * \param probs The prob distributions. + * It resides on GPU if the sampler is GPU sampler, or on host if hte sampler is CPU sampler. + * \param sample_indices Specifying which request we should sample for + * in i-th output. The output result is sample as follow: + * result[i] = sample_from(prob_on_device[sample_indices[i],:], generation_config[i])); + * \param request_ids The id of each request. + * \param generation_cfg The generation config of each request + * in the input batch. + * \param rngs The random number generator of each sequence. + * \param output_prob_dist The output probability distribution + * \return The batch of sampling results, which contain the sampled token id + * and other probability info. + */ + virtual std::vector BatchSampleTokensWithProbAfterTopP( + NDArray probs, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // const std::vector& rngs, // std::vector* output_prob_dist = nullptr) = 0; /*! * \brief Verify draft tokens generated by small models in the large model * in speculative decoding. The input corresponds to a batch of sequences. - * \param probs_on_device The prob distributions on GPU to sample tokens from. + * The input prob distributions are already applied with top-p. + * \param probs The prob distributions on GPU to sample tokens from. + * It resides on GPU if the sampler is GPU sampler, or on host if hte sampler is CPU sampler. * \param request_ids The id of each request. * \param cum_verify_lengths The cumulative draft lengths to verify of all sequences. * \param generation_cfg The generation config of each request @@ -69,10 +112,9 @@ class SamplerObj : public Object { * small model for each sequence. * \return The list of accepted tokens for each request. */ - virtual std::vector> BatchVerifyDraftTokens( - NDArray probs_on_device, const Array& request_ids, - const std::vector& cum_verify_lengths, const Array& generation_cfg, - const std::vector& rngs, + virtual std::vector> BatchVerifyDraftTokensWithProbAfterTopP( + NDArray probs, const Array& request_ids, const std::vector& cum_verify_lengths, + const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, const std::vector>& draft_output_prob_dist) = 0; diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index 458d2ae5d7..2f6f77a3a0 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -29,37 +29,59 @@ enum class InstructionKind : int { kAbortRequest = 1, kUnloadEngine = 2, kReloadEngine = 3, - kDebugCallFuncOnAllAllWorker = 4, + kResetEngine = 4, + kDebugCallFuncOnAllAllWorker = 5, }; /*! \brief The implementation of ThreadedEngine. 
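The split interface is meant to be driven in two phases: renormalize a batch once with `BatchRenormalizeProbsByTopP`, then hand the result to the `...WithProbAfterTopP` entry points, which treat top-p as already applied. The sketch below shows that call order only; it assumes the engine's own headers provide `Sampler`, `SampleResult`, `GenerationConfig`, and `RandomGenerator` as used elsewhere in this diff, and the include path is a guess.

```cpp
// Call-order sketch for the two-phase sampler interface (types assumed from
// the surrounding engine code; this is not a drop-in snippet).
#include <vector>

#include "serve/sampler/sampler.h"

using namespace mlc::llm::serve;
using tvm::runtime::Array;
using tvm::runtime::NDArray;
using tvm::runtime::String;

std::vector<SampleResult> SampleBatch(Sampler sampler, NDArray probs_on_device,
                                      const std::vector<int>& sample_indices,
                                      const Array<String>& request_ids,
                                      const Array<GenerationConfig>& generation_cfg,
                                      const std::vector<RandomGenerator*>& rngs) {
  // Phase 1: apply top-p once per batch. The result stays on device for the
  // GPU sampler and lands on host for the CPU sampler.
  NDArray renormalized = sampler->BatchRenormalizeProbsByTopP(
      probs_on_device, sample_indices, request_ids, generation_cfg);
  // Phase 2: sample from the renormalized rows; top_p is treated as already
  // applied, so the samplers internally use top_p = 1.0.
  return sampler->BatchSampleTokensWithProbAfterTopP(
      renormalized, sample_indices, request_ids, generation_cfg, rngs);
}
```

`BatchVerifyDraftTokensWithProbAfterTopP` follows the same convention: the probabilities it receives are expected to be renormalized already.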
*/ class ThreadedEngineImpl : public ThreadedEngine { public: - void InitBackgroundEngine(EngineConfig engine_config, - Optional request_stream_callback, + void InitBackgroundEngine(Device device, Optional request_stream_callback, Optional trace_recorder) final { + device_ = device; CHECK(request_stream_callback.defined()) << "ThreadedEngine requires request stream callback function, but it is not given."; request_stream_callback_ = request_stream_callback.value(); + trace_recorder_ = trace_recorder; + } - auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args.size(), 1); - Array delta_outputs = args[0]; - bool need_notify = false; - { - std::lock_guard lock(request_stream_callback_mutex_); - request_stream_callback_inputs_.push_back(std::move(delta_outputs)); - ++pending_request_stream_callback_cnt_; - need_notify = stream_callback_waiting_; - } - if (need_notify) { - request_stream_callback_cv_.notify_one(); - } - }; + void Reload(EngineConfig engine_config) final { + bool need_notify = false; + { + std::lock_guard lock(background_loop_mutex_); + instruction_queue_.emplace_back(InstructionKind::kReloadEngine, std::move(engine_config)); + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); + } + } + + void Unload() final { + bool need_notify = false; + { + std::lock_guard lock(background_loop_mutex_); + instruction_queue_.emplace_back(InstructionKind::kUnloadEngine, ObjectRef(nullptr)); + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); + } + } - request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - background_engine_ = Engine::Create( - std::move(engine_config), std::move(request_stream_callback), std::move(trace_recorder)); + void Reset() final { + bool need_notify = false; + { + std::lock_guard lock(background_loop_mutex_); + instruction_queue_.emplace_back(InstructionKind::kResetEngine, ObjectRef(nullptr)); + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); + } } void AddRequest(Request request) final { @@ -97,7 +119,8 @@ class ThreadedEngineImpl : public ThreadedEngine { std::unique_lock lock(background_loop_mutex_); engine_waiting_ = true; background_loop_cv_.wait(lock, [this] { - return !background_engine_->Empty() || pending_request_operation_cnt_.load() > 0 || + return (background_engine_ != nullptr && !background_engine_->Empty()) || + pending_request_operation_cnt_.load() > 0 || exit_now_.load(std::memory_order_relaxed); }); engine_waiting_ = false; @@ -108,22 +131,30 @@ class ThreadedEngineImpl : public ThreadedEngine { } for (const auto& [kind, arg] : local_instruction_queue) { if (kind == InstructionKind::kAddRequest) { + CHECK(background_engine_ != nullptr) << "Background engine is not loaded."; background_engine_->AddRequest(Downcast(arg)); } else if (kind == InstructionKind::kAbortRequest) { + CHECK(background_engine_ != nullptr) << "Background engine is not loaded."; background_engine_->AbortRequest(Downcast(arg)); } else if (kind == InstructionKind::kUnloadEngine) { - // Todo(mlc-team): implement engine unload - LOG(FATAL) << "Not implemented yet."; + EngineUnloadImpl(); } else if (kind == InstructionKind::kReloadEngine) { - // Todo(mlc-team): implement engine reload - LOG(FATAL) << "Not implemented yet."; + EngineUnloadImpl(); + EngineReloadImpl(Downcast(arg)); + } else if 
(kind == InstructionKind::kResetEngine) { + if (background_engine_ != nullptr) { + background_engine_->Reset(); + } } else if (kind == InstructionKind::kDebugCallFuncOnAllAllWorker) { + CHECK(background_engine_ != nullptr) << "Background engine is not loaded."; background_engine_->DebugCallFuncOnAllAllWorker(Downcast(arg)); } else { LOG(FATAL) << "Cannot reach here"; } } - background_engine_->Step(); + if (background_engine_ != nullptr) { + background_engine_->Step(); + } } } @@ -184,10 +215,47 @@ class ThreadedEngineImpl : public ThreadedEngine { } private: + void EngineReloadImpl(EngineConfig engine_config) { + auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.size(), 1); + Array delta_outputs = args[0]; + bool need_notify = false; + { + std::lock_guard lock(request_stream_callback_mutex_); + request_stream_callback_inputs_.push_back(std::move(delta_outputs)); + ++pending_request_stream_callback_cnt_; + need_notify = stream_callback_waiting_; + } + if (need_notify) { + request_stream_callback_cv_.notify_one(); + } + }; + + Optional request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); + background_engine_ = Engine::Create(std::move(engine_config), device_, + std::move(request_stream_callback), trace_recorder_); + } + + void EngineUnloadImpl() { + if (background_engine_ != nullptr) { + background_engine_->AbortAllRequests(); + background_engine_ = nullptr; + // Clear the allocated memory in cached memory pool. + const PackedFunc* fclear_memory_manager = + tvm::runtime::Registry::Get("vm.builtin.memory_manager.clear"); + ICHECK(fclear_memory_manager) << "Cannot find env function vm.builtin.memory_manager.clear"; + (*fclear_memory_manager)(); + } + } + + /*! \brief The device to run models on. */ + Device device_; /*! \brief The background normal engine for request processing. */ std::unique_ptr background_engine_; /*! \brief The request stream callback. */ PackedFunc request_stream_callback_; + /*! \brief Event trace recorder. */ + Optional trace_recorder_; /*! \brief The mutex ensuring only one thread can access critical regions. */ std::mutex background_loop_mutex_; @@ -237,6 +305,7 @@ class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { public: TVM_MODULE_VTABLE_BEGIN("mlc.serve.async_threaded_engine"); TVM_MODULE_VTABLE_ENTRY("init_background_engine", &ThreadedEngineImpl::InitBackgroundEngine); + TVM_MODULE_VTABLE_ENTRY("reload", &ThreadedEngineImpl::Reload); TVM_MODULE_VTABLE_ENTRY("add_request", &ThreadedEngineImpl::AddRequest); TVM_MODULE_VTABLE_ENTRY("abort_request", &ThreadedEngineImpl::AbortRequest); TVM_MODULE_VTABLE_ENTRY("run_background_loop", &ThreadedEngineImpl::RunBackgroundLoop); diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h index 3d11ba36f1..49ba8f2175 100644 --- a/cpp/serve/threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -35,14 +35,25 @@ class ThreadedEngine { /*! * \brief Initialize the threaded engine from packed arguments in TVMArgs. - * \param engine_config The engine config. + * \param device The device where to run models. * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. */ - virtual void InitBackgroundEngine(EngineConfig engine_config, - Optional request_stream_callback, + virtual void InitBackgroundEngine(Device device, Optional request_stream_callback, Optional trace_recorder) = 0; + /*! + * \brief Reload the engine with the new engine config. 
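+ * \note In the threaded implementation shown above, reloading first unloads any existing background engine (aborting in-flight requests and clearing cached memory) before creating a new engine with the given config.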
+ * \param engine_config The engine config. + */ + virtual void Reload(EngineConfig engine_config) = 0; + + /*! \brief Unload the background engine. */ + virtual void Unload() = 0; + + /*! \brief Reset the engine to the initial state. */ + virtual void Reset() = 0; + /*! \brief Starts the background request processing loop. */ virtual void RunBackgroundLoop() = 0; diff --git a/cpp/support/utils.h b/cpp/support/utils.h index 5360f0496c..6c53e35715 100644 --- a/cpp/support/utils.h +++ b/cpp/support/utils.h @@ -10,6 +10,7 @@ namespace mlc { namespace llm { +/*! \brief Split the input string by the given delimiter character. */ inline std::vector Split(const std::string& str, char delim) { std::string item; std::istringstream is(str); @@ -20,5 +21,21 @@ inline std::vector Split(const std::string& str, char delim) { return ret; } +/*! + * \brief Check whether the string starts with a given prefix. + * \param str The given string. + * \param prefix The given prefix. + * \return Whether the prefix matched. + */ +inline bool StartsWith(const std::string& str, const char* prefix) { + size_t n = str.length(); + for (size_t i = 0; i < n; i++) { + if (prefix[i] == '\0') return true; + if (str.data()[i] != prefix[i]) return false; + } + // return true if the str is equal to the prefix + return prefix[n] == '\0'; +} + } // namespace llm } // namespace mlc diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index 00beb5cc4d..4706e09811 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -235,7 +235,7 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con RuntimeError: Cannot find libraries: wasm_runtime.bc .. note:: - For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or lower ``context_window_size`` to decrease memory usage. + For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage. Otherwise, you may run into issues like: .. code:: text @@ -664,7 +664,7 @@ generalized to any model variant, as long as mlc-llm supports the architecture. RuntimeError: Cannot find libraries: wasm_runtime.bc .. note:: - For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or lower ``context_window_size`` to decrease memory usage. + For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage. Otherwise, you may run into issues like: .. code:: text @@ -793,7 +793,7 @@ generalized to any model variant, as long as mlc-llm supports the architecture. RuntimeError: Cannot find libraries: wasm_runtime.bc .. note:: - For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or lower ``context_window_size`` to decrease memory usage. + For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage. Otherwise, you may run into issues like: .. 
code:: text diff --git a/docs/deploy/cli.rst b/docs/deploy/cli.rst index f341e31e71..a7ebe28d6d 100644 --- a/docs/deploy/cli.rst +++ b/docs/deploy/cli.rst @@ -54,15 +54,15 @@ To run a model with MLC LLM in any platform, you can either: **Option 1: Use model prebuilts** To run ``mlc_llm``, you can specify the Huggingface MLC prebuilt model repo path with the prefix ``HF://``. -For example, to run the MLC Llama 2 7B Q4F16_1 model (`Repo link `_), -simply use ``HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC``. The model weights and library will be downloaded +For example, to run the MLC Llama 3 8B Q4F16_1 model (`Repo link `_), +simply use ``HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC``. The model weights and library will be downloaded automatically from Huggingface. .. code:: shell - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC --device "cuda:0" --overrides context_window_size=1024 + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --device "cuda:0" --overrides context_window_size=1024 -.. code:: shell +.. code:: You can use the following special commands: /help print the special commands @@ -74,13 +74,11 @@ automatically from Huggingface. Note: Separate stop words in the `stop` option with commas (,). Multi-line input: Use escape+enter to start a new line. - [INST]: What's the meaning of life - [/INST]: - Ah, a question that has puzzled philosophers and theologians for centuries! The meaning - of life is a deeply personal and subjective topic, and there are many different - perspectives on what it might be. However, here are some possible answers that have been - proposed by various thinkers and cultures: - ... + user: What's the meaning of life + assistant: + What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life. + + The concept of the meaning of life has been debated and... **Option 2: Use locally compiled model weights and libraries** diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst index c0217db9e9..75a5cdbdc7 100644 --- a/docs/deploy/ios.rst +++ b/docs/deploy/ios.rst @@ -341,10 +341,24 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con mlc_llm gen_config ./dist/models/phi-2/ \ --quantization q4f16_1 --conv-template phi-2 \ -o dist/phi-2-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json + # 2. mkdir: create a directory to store the compiled model library + mkdir -p dist/libs + # 3. compile: compile model library with specification in mlc-chat-config.json mlc_llm compile ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json \ --device iphone -o dist/libs/phi-2-q4f16_1-iphone.tar +Given the compiled library, it is possible to calculate an upper bound for the VRAM +usage during runtime. This useful to better understand if a model is able to fit particular +hardware. +That information will be displayed at the end of the console log when the ``compile`` is executed. +It might look something like this: + +.. code:: shell + + [2024-04-25 03:19:56] INFO model_metadata.py:96: Total memory usage: 1625.73 MB (Parameters: 1492.45 MB. KVCache: 0.00 MB. Temporary buffer: 133.28 MB) + [2024-04-25 03:19:56] INFO model_metadata.py:105: To reduce memory usage, tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size` + [2024-04-25 03:19:56] INFO compile.py:198: Generated: dist/libs/phi-2-q4f16_1-iphone.tar + .. 
note:: When compiling larger models like ``Llama-2-7B``, you may want to add a lower chunk size while prefilling prompts ``--prefill_chunk_size 128`` or even lower ``context_window_size``\ @@ -388,21 +402,7 @@ This would result in something like `phi-2-q4f16_1-MLC `_. -**Step 4. Calculate estimated VRAM usage** - -Given the compiled library, it is possible to calculate an upper bound for the VRAM -usage during runtime. This useful to better understand if a model is able to fit particular -hardware. We can calculate this estimate using the following command: - -.. code:: shell - - ~/mlc-llm > python -m mlc_llm.cli.model_metadata ./dist/libs/phi-2-q4f16_1-iphone.tar \ - > --memory-only --mlc-chat-config ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json - INFO model_metadata.py:90: Total memory usage: 3042.96 MB (Parameters: 1492.45 MB. KVCache: 640.00 MB. Temporary buffer: 910.51 MB) - INFO model_metadata.py:99: To reduce memory usage, tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size` - - -**Step 5. Register as a ModelRecord** +**Step 4. Register as a ModelRecord** Finally, we update the code snippet for `app-config.json `__ diff --git a/docs/deploy/python_engine.rst b/docs/deploy/python_engine.rst index c5d9a072a7..89c60ac422 100644 --- a/docs/deploy/python_engine.rst +++ b/docs/deploy/python_engine.rst @@ -4,12 +4,261 @@ Python API ========== .. note:: - This page introduces the Python API with LLMEngine in MLC LLM. - If you want to check out the old Python API which uses :class:`mlc_llm.ChatModule`, - please go to :ref:`deploy-python-chat-module` + This page introduces the Python API with MLCEngine in MLC LLM. + If you want to check out the old Python API which uses :class:`mlc_llm.ChatModule`, + please go to :ref:`deploy-python-chat-module` .. contents:: Table of Contents - :local: - :depth: 2 + :local: + :depth: 2 -🚧 Under construction... + +MLC LLM provides a Python API through classes :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine` +which **support full OpenAI API completeness** for easy integration into other Python projects. + +This page introduces how to use the engines in MLC LLM. +The Python API is a part of the MLC-LLM package, for which we have prepared pre-built pip wheels via +the :ref:`installation page `. + + +Verify Installation +------------------- + +.. code:: bash + + python -c "from mlc_llm import MLCEngine; print(MLCEngine)" + +You are expected to see the output of ````. + +If the command above results in an error, follow :ref:`install-mlc-packages` to install prebuilt pip +packages or build MLC LLM from source. + + +Run MLCEngine +------------- + +:class:`mlc_llm.MLCEngine` provides the interface of OpenAI chat completion synchronously. +:class:`mlc_llm.MLCEngine` does not batch concurrent requests due to its synchronous design; +please use :ref:`AsyncMLCEngine ` if you need request batching. + +**Stream Response.** In :ref:`quick-start` and :ref:`introduction-to-mlc-llm`, +we introduced the basic use of :class:`mlc_llm.MLCEngine`. + +.. code:: python + + from mlc_llm import MLCEngine + + # Create engine + model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" + engine = MLCEngine(model) + + # Run chat completion in OpenAI API. 
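+    # With stream=True, create() returns an iterator of streaming chunks; each chunk's choice.delta.content carries the newly generated text.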
+ for response in engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, + ): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) + print("\n") + + engine.terminate() + +This code example first creates an :class:`mlc_llm.MLCEngine` instance with the 8B Llama-3 model. +**We design the Python API** :class:`mlc_llm.MLCEngine` **to align with OpenAI API**, +which means you can use :class:`mlc_llm.MLCEngine` in the same way of using +`OpenAI's Python package `_ +for both synchronous and asynchronous generation. + +**Non-stream Response.** The code example above uses the synchronous chat completion +interface and iterates over all the stream responses. +If you want to run without streaming, you can run + +.. code:: python + + response = engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=False, + ) + print(response) + +Please refer to `OpenAI's Python package `_ +and `OpenAI chat completion API `_ +for the complete chat completion interface. + + +.. _python-engine-async-llm-engine: + +Run AsyncMLCEngine +------------------ + +:class:`mlc_llm.AsyncMLCEngine` provides the interface of OpenAI chat completion with +asynchronous features. +**We recommend using** :class:`mlc_llm.AsyncMLCEngine` **to batch concurrent requests for better throughput.** + +**Stream Response.** The core use of :class:`mlc_llm.AsyncMLCEngine` for stream responses is as follows. + +.. code:: python + + async for response in await engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, + ): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) + +.. collapse:: A complete runnable example of AsyncMLCEngine in Python. + + .. code:: python + + import asyncio + from typing import Dict + + from mlc_llm.serve import AsyncMLCEngine + + model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" + prompts = [ + "Write a three-day travel plan to Pittsburgh.", + "What is the meaning of life?", + ] + + + async def test_completion(): + # Create engine + async_engine = AsyncMLCEngine(model=model) + + num_requests = len(prompts) + output_texts: Dict[str, str] = {} + + async def generate_task(prompt: str): + async for response in await async_engine.chat.completions.create( + messages=[{"role": "user", "content": prompt}], + model=model, + stream=True, + ): + if response.id not in output_texts: + output_texts[response.id] = "" + output_texts[response.id] += response.choices[0].delta.content + + tasks = [asyncio.create_task(generate_task(prompts[i])) for i in range(num_requests)] + await asyncio.gather(*tasks) + + # Print output. + for request_id, output in output_texts.items(): + print(f"Output of request {request_id}:\n{output}\n") + + async_engine.terminate() + + + asyncio.run(test_completion()) + +| + +**Non-stream Response.** Similarly, :class:`mlc_llm.AsyncMLCEngine` provides the non-stream response +interface. + +.. code:: python + + response = await engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=False, + ) + print(response) + +Please refer to `OpenAI's Python package `_ +and `OpenAI chat completion API `_ +for the complete chat completion interface. 
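As a minimal end-to-end sketch (assuming the same pre-quantized Llama-3 model and the constructor and response fields shown above), the non-stream call can be wrapped in an ``asyncio`` entry point roughly as follows:

.. code:: python

    import asyncio

    from mlc_llm.serve import AsyncMLCEngine

    model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"


    async def main():
        # Create the asynchronous engine and issue a single non-stream chat completion.
        engine = AsyncMLCEngine(model=model)
        response = await engine.chat.completions.create(
            messages=[{"role": "user", "content": "What is the meaning of life?"}],
            model=model,
            stream=False,
        )
        # With stream=False, the full reply is available in choices[0].message.content.
        print(response.choices[0].message.content)
        engine.terminate()


    asyncio.run(main())

The same pattern extends to issuing multiple requests concurrently with ``asyncio.gather``, as in the streaming example above.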
+ + +Engine Mode +----------- + +To ease the engine configuration, the constructors of :class:`mlc_llm.MLCEngine` and +:class:`mlc_llm.AsyncMLCEngine` have an optional argument ``mode``, +which falls into one of the three options ``"local"``, ``"interactive"`` or ``"server"``. +The default mode is ``"local"``. + +Each mode denotes a pre-defined configuration of the engine to satisfy different use cases. +The choice of the mode controls the request concurrency of the engine, +as well as engine's KV cache token capacity (or in other words, the maximum +number of tokens that the engine's KV cache can hold), +and further affects the GPU memory usage of the engine. + +In short, + +- mode ``"local"`` uses low request concurrency and low KV cache capacity, which is suitable for cases where **concurrent requests are not too many, and the user wants to save GPU memory usage**. +- mode ``"interactive"`` uses 1 as the request concurrency and low KV cache capacity, which is designed for **interactive use cases** such as chats and conversations. +- mode ``"server"`` uses as much request concurrency and KV cache capacity as possible. This mode aims to **fully utilize the GPU memory for large server scenarios** where concurrent requests may be many. + +**For system benchmark, please select mode** ``"server"``. +Please refer to :ref:`python-engine-api-reference` for detailed documentation of the engine mode. + + +Deploy Your Own Model with Python API +------------------------------------- + +The :ref:`introduction page ` introduces how we can deploy our +own models with MLC LLM. +This section introduces how you can use the model weights you convert and the model library you build +in :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine`. + +We use the `Phi-2 `_ as the example model. + +**Specify Model Weight Path.** Assume you have converted the model weights for your own model, +you can construct a :class:`mlc_llm.MLCEngine` as follows: + +.. code:: python + + from mlc_llm import MLCEngine + + model = "models/phi-2" # Assuming the converted phi-2 model weights are under "models/phi-2" + engine = MLCEngine(model) + + +**Specify Model Library Path.** Further, if you build the model library on your own, +you can use it in :class:`mlc_llm.MLCEngine` by passing the library path through argument ``model_lib_path``. + +.. code:: python + + from mlc_llm import MLCEngine + + model = "models/phi-2" + model_lib_path = "models/phi-2/lib.so" # Assuming the phi-2 model library is built at "models/phi-2/lib.so" + engine = MLCEngine(model, model_lib_path=model_lib_path) + + +The same applies to :class:`mlc_llm.AsyncMLCEngine`. + + +.. _python-engine-api-reference: + +API Reference +------------- + +The :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine` classes provide the following constructors. + +The MLCEngine and AsyncMLCEngine have full OpenAI API completeness. +Please refer to `OpenAI's Python package `_ +and `OpenAI chat completion API `_ +for the complete chat completion interface. + +.. currentmodule:: mlc_llm + +.. autoclass:: MLCEngine + :members: + :exclude-members: evaluate + :undoc-members: + :show-inheritance: + + .. automethod:: __init__ + +.. autoclass:: AsyncMLCEngine + :members: + :exclude-members: evaluate + :undoc-members: + :show-inheritance: + + .. automethod:: __init__ diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index e59abc1257..07d39dbfad 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -1,10 +1,6 @@ .. 
_deploy-rest-api: -<<<<<<< HEAD -Rest API -======= REST API ->>>>>>> upstream/main ======== .. contents:: Table of Contents @@ -17,18 +13,6 @@ for a user to interact with MLC-LLM in their own programs. Install MLC-LLM Package ------------------------ -<<<<<<< HEAD -SERVE is a part of the MLC-Chat package, installation instruction for which we be found here :doc:`<../install/mlc_llm>`. - -Verify Installation -^^^^^^^^^^^^^^^^^^^ - -.. code:: bash - - python -m mlc_llm.serve.server --help - -You are expected to see the help information of the MLC SERVE. -======= SERVE is a part of the MLC-LLM package, installation instruction for which can be found :ref:`here `. Once you have install the MLC-LLM package, you can run the following command to check if the installation was successful: .. code:: bash @@ -36,13 +20,10 @@ SERVE is a part of the MLC-LLM package, installation instruction for which can b mlc_llm serve --help You should see serve help message if the installation was successful. ->>>>>>> upstream/main Quick start ------------ -<<<<<<< HEAD -======= This section provides a quick start guide to work with MLC-LLM REST API. To launch a server, run the following command: .. code:: bash @@ -77,24 +58,15 @@ Once you have launched the Server, you can use the API in your own program to se .. _rest_launch_server: ->>>>>>> upstream/main Launch the Server ----------------- -<<<<<<< HEAD -To launch the MLC Server for MLC-Chat, run the following command in your terminal. - -.. code:: bash - - python -m mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] -======= To launch the MLC Server for MLC-LLM, run the following command in your terminal. .. code:: bash mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] ->>>>>>> upstream/main MODEL The model folder after compiling with MLC-LLM build process. The parameter can either be the model name with its quantization scheme @@ -131,11 +103,7 @@ The REST API provides the following endpoints: ------------------------------------------------ -<<<<<<< HEAD - Get a list of models available for MLC-Chat. -======= Get a list of models available for MLC-LLM. ->>>>>>> upstream/main **Example** @@ -153,118 +121,8 @@ The REST API provides the following endpoints: print(response.json()) else: print("Error:", response.status_code) -<<<<<<< HEAD -.. http:post:: /v1/chat/completions -======= ->>>>>>> upstream/main - - -<<<<<<< HEAD - Get a response from MLC-Chat using a prompt, either with or without streaming. - -**Chat Completion Request Object** - -- **messages** (*List[ChatCompletionMessage]*, required): A sequence of messages that have been exchanged in the conversation so far. 
Each message in the conversation is represented by a `ChatCompletionMessage` object, which includes the following fields: - - **content** (*Optional[Union[str, List[Dict[str, str]]]]*): The text content of the message or structured data in case of tool-generated messages. - - **role** (*Literal["system", "user", "assistant", "tool"]*): The role of the message sender, indicating whether the message is from the system, user, assistant, or a tool. - - **name** (*Optional[str]*): An optional name for the sender of the message. - - **tool_calls** (*Optional[List[ChatToolCall]]*): A list of calls to external tools or functions made within this message, applicable when the role is `tool`. - - **tool_call_id** (*Optional[str]*): A unique identifier for the tool call, relevant when integrating external tools or services. - -- **model** (*str*, required): The model to be used for generating responses. - -- **frequency_penalty** (*float*, optional, default=0.0): Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model’s likelihood to repeat tokens. - -- **presence_penalty** (*float*, optional, default=0.0): Positive values penalize new tokens if they are already present in the text so far, decreasing the model’s likelihood to repeat tokens. - -- **logprobs** (*bool*, optional, default=False): Indicates whether to include log probabilities for each token in the response. - -- **top_logprobs** (*int*, optional, default=0): An integer ranging from 0 to 5. It determines the number of tokens, most likely to appear at each position, to be returned. Each token is accompanied by a log probability. If this parameter is used, 'logprobs' must be set to true. - -- **logit_bias** (*Optional[Dict[int, float]]*): Allows specifying biases for or against specific tokens during generation. - -- **max_tokens** (*Optional[int]*): The maximum number of tokens to generate in the response(s). - -- **n** (*int*, optional, default=1): Number of responses to generate for the given prompt. - -- **seed** (*Optional[int]*): A seed for deterministic generation. Using the same seed and inputs will produce the same output. - -- **stop** (*Optional[Union[str, List[str]]]*): One or more strings that, if encountered, will cause generation to stop. - -- **stream** (*bool*, optional, default=False): If `True`, responses are streamed back as they are generated. - -- **temperature** (*float*, optional, default=1.0): Controls the randomness of the generation. Lower values lead to less random completions. - -- **top_p** (*float*, optional, default=1.0): Nucleus sampling parameter that controls the diversity of the generated responses. - -- **tools** (*Optional[List[ChatTool]]*): Specifies external tools or functions that can be called as part of the chat. - -- **tool_choice** (*Optional[Union[Literal["none", "auto"], Dict]]*): Controls how tools are selected for use in responses. - -- **user** (*Optional[str]*): An optional identifier for the user initiating the request. - -- **ignore_eos** (*bool*, optional, default=False): If `True`, the model will ignore the end-of-sequence token for generating responses. - -- **response_format** (*RequestResponseFormat*, optional): Specifies the format of the response. Can be either "text" or "json_object", with optional schema definition for JSON responses. - -**Returns** - -- If `stream` is `False`, a `ChatCompletionResponse` object containing the generated response(s). 
-- If `stream` is `True`, a stream of `ChatCompletionStreamResponse` objects, providing a real-time feed of generated responses. - - -**ChatCompletionResponseChoice** - -- **finish_reason** (*Optional[Literal["stop", "length", "tool_calls", "error"]]*, optional): The reason the completion process was terminated. It can be due to reaching a stop condition, the maximum length, output of tool calls, or an error. - -- **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices. - -- **message** (*ChatCompletionMessage*, required): The message part of the chat completion, containing the content of the chat response. - -- **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token - -**ChatCompletionStreamResponseChoice** - -- **finish_reason** (*Optional[Literal["stop", "length", "tool_calls"]]*, optional): Specifies why the streaming completion process ended. Valid reasons are "stop", "length", and "tool_calls". - -- **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices. - -- **delta** (*ChatCompletionMessage*, required): Represents the incremental update or addition to the chat completion message in the stream. - -- **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token - -**ChatCompletionResponse** - -- **id** (*str*, required): A unique identifier for the chat completion session. - -- **choices** (*List[ChatCompletionResponseChoice]*, required): A collection of `ChatCompletionResponseChoice` objects, representing the potential responses generated by the model. - -- **created** (*int*, required, default=current time): The UNIX timestamp representing when the response was generated. - -- **model** (*str*, required): The name of the model used to generate the chat completions. - -- **system_fingerprint** (*str*, required): A system-generated fingerprint that uniquely identifies the computational environment. - -- **object** (*Literal["chat.completion"]*, required, default="chat.completion"): A string literal indicating the type of object, here always "chat.completion". - -- **usage** (*UsageInfo*, required, default=empty `UsageInfo` object): Contains information about the API usage for this specific request. - -**ChatCompletionStreamResponse** - -- **id** (*str*, required): A unique identifier for the streaming chat completion session. - -- **choices** (*List[ChatCompletionStreamResponseChoice]*, required): A list of `ChatCompletionStreamResponseChoice` objects, each representing a part of the streaming chat response. - -- **created** (*int*, required, default=current time): The creation time of the streaming response, represented as a UNIX timestamp. - -- **model** (*str*, required): Specifies the model that was used for generating the streaming chat completions. - -- **system_fingerprint** (*str*, required): A unique identifier for the system generating the streaming completions. - -- **object** (*Literal["chat.completion.chunk"]*, required, default="chat.completion.chunk"): A literal indicating that this object represents a chunk of a streaming chat completion. -======= .. 
http:post:: /v1/chat/completions ------------------------------------------------ @@ -343,7 +201,6 @@ The REST API provides the following endpoints: - **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token **ChatCompletionResponse** ->>>>>>> upstream/main - **id** (*str*, required): A unique identifier for the chat completion session. @@ -359,10 +216,8 @@ The REST API provides the following endpoints: - **usage** (*UsageInfo*, required, default=empty `UsageInfo` object): Contains information about the API usage for this specific request. +**ChatCompletionStreamResponse** -<<<<<<< HEAD -**Example** -======= - **id** (*str*, required): A unique identifier for the streaming chat completion session. - **choices** (*List[ChatCompletionStreamResponseChoice]*, required): A list of `ChatCompletionStreamResponseChoice` objects, each representing a part of the streaming chat response. @@ -374,69 +229,14 @@ The REST API provides the following endpoints: - **system_fingerprint** (*str*, required): A unique identifier for the system generating the streaming completions. - **object** (*Literal["chat.completion.chunk"]*, required, default="chat.completion.chunk"): A literal indicating that this object represents a chunk of a streaming chat completion. ->>>>>>> upstream/main - -Once you have launched the Server, you can use the API in your own program. Below is an example of using the API to interact with MLC-Chat in Python without Streaming (suppose the server is running on ``http://127.0.0.1:8080/``): -.. code:: bash +------------------------------------------------ -<<<<<<< HEAD - import requests - - # Get a response using a prompt without streaming - payload = { - "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/", - "messages": [ - {"role": "user", "content": "Hello! Our project is MLC LLM."}, - { - "role": "assistant", - "content": "Hello! It's great to hear about your project, MLC LLM.", - }, - {"role": "user", "content": "What is the name of our project?"}, - ], - "stream": False, - # "n": 1, - "max_tokens": 300, - } - r = requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload) - choices = r.json()["choices"] - for choice in choices: - print(f"{choice['message']['content']}\n") -======= **Example** ->>>>>>> upstream/main Below is an example of using the API to interact with MLC-LLM in Python with Streaming. -<<<<<<< HEAD -Below is an example of using the API to interact with MLC-Chat in Python with Streaming. - -.. code:: bash - - import requests - import json - - # Get a response using a prompt with streaming - payload = { - "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/", - "messages": [{"role": "user", "content": "Write a haiku"}], - "stream": True, - } - with requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload, stream=True) as r: - for chunk in r.iter_content(chunk_size=None): - chunk = chunk.decode("utf-8") - if "[DONE]" in chunk[6:]: - break - response = json.loads(chunk[6:]) - content = response["choices"][0]["delta"].get("content", "") - print(content, end="", flush=True) - print("\n") - ------------------------------------------------- - - -======= .. code:: bash import requests @@ -460,7 +260,6 @@ Below is an example of using the API to interact with MLC-Chat in Python with St ------------------------------------------------ ->>>>>>> upstream/main There is also support for function calling similar to OpenAI (https://platform.openai.com/docs/guides/function-calling). 
Below is an example on how to use function calling in Python. .. code:: bash diff --git a/docs/get_started/introduction.rst b/docs/get_started/introduction.rst index 282b4764c2..29060d5a60 100644 --- a/docs/get_started/introduction.rst +++ b/docs/get_started/introduction.rst @@ -32,12 +32,12 @@ You are expected to see the installation path of MLC LLM Python package. Chat CLI -------- -As the first example, we try out the chat CLI in MLC LLM with 4-bit quantized 7B Llama-2 model. +As the first example, we try out the chat CLI in MLC LLM with 4-bit quantized 8B Llama-3 model. You can run MLC chat through a one-liner command: .. code:: bash - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC It may take 1-2 minutes for the first time running this command. After waiting, this command launch a chat interface where you can enter your prompt and chat with the model. @@ -54,17 +54,19 @@ After waiting, this command launch a chat interface where you can enter your pro Note: Separate stop words in the `stop` option with commas (,). Multi-line input: Use escape+enter to start a new line. - [INST]: What's the meaning of life? - [/INST]: - Ah, a question that has puzzled philosophers and theologians for centuries! ... + user: What's the meaning of life + assistant: + What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life. + + The concept of the meaning of life has been debated and... The figure below shows what run under the hood of this chat CLI command. For the first time running the command, there are three major phases. -- **Phase 1. Pre-quantized weight download.** This phase automatically downloads pre-quantized Llama-2 model from `Hugging Face `_ and saves it to your local cache directory. -- **Phase 2. Model compilation.** This phase automatically optimizes the Llama-2 model to accelerate model inference on GPU with techniques of machine learning compilation in `Apache TVM `_ compiler, and generate the binary model library that enables the execution language models on your local GPU. -- **Phase 3. Chat runtime.** This phase consumes the model library built in phase 2 and the model weights downloaded in phase 1, launches a platform-native chat runtime to drive the execution of Llama-2 model. +- **Phase 1. Pre-quantized weight download.** This phase automatically downloads pre-quantized Llama-3 model from `Hugging Face `_ and saves it to your local cache directory. +- **Phase 2. Model compilation.** This phase automatically optimizes the Llama-3 model to accelerate model inference on GPU with techniques of machine learning compilation in `Apache TVM `_ compiler, and generate the binary model library that enables the execution language models on your local GPU. +- **Phase 3. Chat runtime.** This phase consumes the model library built in phase 2 and the model weights downloaded in phase 1, launches a platform-native chat runtime to drive the execution of Llama-3 model. We cache the pre-quantized model weights and compiled model library locally. Therefore, phase 1 and 2 will only execute **once** over multiple runs. @@ -83,16 +85,16 @@ Therefore, phase 1 and 2 will only execute **once** over multiple runs. Python API ---------- -In the second example, we run the Llama-2 model with the chat completion Python API of MLC LLM. +In the second example, we run the Llama-3 model with the chat completion Python API of MLC LLM. 
You can save the code below into a Python file and run it. .. code:: python - from mlc_llm import LLMEngine + from mlc_llm import MLCEngine # Create engine - model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" - engine = LLMEngine(model) + model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" + engine = MLCEngine(model) # Run chat completion in OpenAI API. for response in engine.chat.completions.create( @@ -112,9 +114,9 @@ You can save the code below into a Python file and run it. MLC LLM Python API -This code example first creates an :class:`mlc_llm.LLMEngine` instance with the the 4-bit quantized Llama-2 model. -**We design the Python API** :class:`mlc_llm.LLMEngine` **to align with OpenAI API**, -which means you can use :class:`mlc_llm.LLMEngine` in the same way of using +This code example first creates an :class:`mlc_llm.MLCEngine` instance with the 4-bit quantized Llama-3 model. +**We design the Python API** :class:`mlc_llm.MLCEngine` **to align with OpenAI API**, +which means you can use :class:`mlc_llm.MLCEngine` in the same way of using `OpenAI's Python package `_ for both synchronous and asynchronous generation. @@ -132,17 +134,17 @@ If you want to run without streaming, you can run print(response) You can also try different arguments supported in `OpenAI chat completion API `_. -If you would like to do concurrent asynchronous generation, you can use :class:`mlc_llm.AsyncLLMEngine` instead. +If you would like to do concurrent asynchronous generation, you can use :class:`mlc_llm.AsyncMLCEngine` instead. REST Server ----------- -For the third example, we launch a REST server to serve the 4-bit quantized Llama-2 model +For the third example, we launch a REST server to serve the 4-bit quantized Llama-3 model for OpenAI chat completion requests. The server can be launched in command line with .. code:: bash - mlc_llm serve HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC The server is hooked at ``http://127.0.0.1:8000`` by default, and you can use ``--host`` and ``--port`` to set a different host and port. @@ -154,7 +156,7 @@ we can open a new shell and send a cURL request via the following command: curl -X POST \ -H "Content-Type: application/json" \ -d '{ - "model": "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC", + "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC", "messages": [ {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} ] @@ -165,6 +167,7 @@ The server will process this request and send back the response. Similar to :ref:`introduction-to-mlc-llm-python-api`, you can pass argument ``"stream": true`` to request for stream responses. +.. _introduction-deploy-your-own-model: Deploy Your Own Model --------------------- @@ -226,7 +229,7 @@ You can also use this model in Python API, MLC serve and other use scenarios. (Optional) Compile Model Library ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -In previous sections, model libraries are compiled when the :class:`mlc_llm.LLMEngine` launches, +In previous sections, model libraries are compiled when the :class:`mlc_llm.MLCEngine` launches, which is what we call "JIT (Just-in-Time) model compilation". In some cases, it is beneficial to explicitly compile the model libraries. We can deploy LLMs with reduced dependencies by shipping the library for deployment without going through compilation. @@ -254,12 +257,12 @@ At runtime, we need to specify this model library path to use it. For example, .. 
code:: python - from mlc_llm import LLMEngine + from mlc_llm import MLCEngine # For Python API model = "models/phi-2" model_lib_path = "models/phi-2/lib.so" - engine = LLMEngine(model, model_lib_path=model_lib_path) + engine = MLCEngine(model, model_lib_path=model_lib_path) :ref:`compile-model-libraries` introduces the model compilation command in detail, where you can find instructions and example commands to compile model to different @@ -280,7 +283,7 @@ environments (e.g. SteamDeck). .. code:: bash - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC --device vulkan + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --device vulkan The same core LLM runtime engine powers all the backends, enabling the same model to be deployed across backends as long as they fit within the memory and computing budget of the corresponding hardware backend. @@ -298,7 +301,7 @@ To briefly summarize this page, - We went through three examples (chat CLI, Python API, and REST server) of MLC LLM, - we introduced how to convert model weights for your own models to run with MLC LLM, and (optionally) how to compile your models. -- We also discussed the the universal deployment capability of MLC LLM. +- We also discussed the universal deployment capability of MLC LLM. Next, please feel free to check out the pages below for quick start examples and more detailed information on specific platforms diff --git a/docs/get_started/quick_start.rst b/docs/get_started/quick_start.rst index bd3b41218e..8349197eda 100644 --- a/docs/get_started/quick_start.rst +++ b/docs/get_started/quick_start.rst @@ -6,7 +6,7 @@ Quick Start Examples -------- -To begin with, try out MLC LLM support for int4-quantized Llama2 7B. +To begin with, try out MLC LLM support for int4-quantized Llama3 8B. It is recommended to have at least 6GB free VRAM to run it. .. tabs:: @@ -20,11 +20,11 @@ It is recommended to have at least 6GB free VRAM to run it. .. code:: python - from mlc_llm import LLMEngine + from mlc_llm import MLCEngine # Create engine - model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" - engine = LLMEngine(model) + model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" + engine = MLCEngine(model) # Run chat completion in OpenAI API. for response in engine.chat.completions.create( @@ -57,7 +57,7 @@ It is recommended to have at least 6GB free VRAM to run it. .. code:: shell - mlc_llm serve HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC **Send requests to server.** When the server is ready (showing ``INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)``), open a new shell and send a request via the following command: @@ -67,7 +67,7 @@ It is recommended to have at least 6GB free VRAM to run it. curl -X POST \ -H "Content-Type: application/json" \ -d '{ - "model": "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC", + "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC", "messages": [ {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} ] @@ -94,7 +94,7 @@ It is recommended to have at least 6GB free VRAM to run it. .. code:: bash - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC If you are using windows/linux/steamdeck and would like to use vulkan, @@ -133,7 +133,7 @@ It is recommended to have at least 6GB free VRAM to run it. | - **Requirement**. 
Llama2-7B model needs an iOS device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. + **Requirement**. Llama3-8B model needs an iOS device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. **Tutorial and source code**. The source code of the iOS app is fully `open source `__, and a :ref:`tutorial ` is included in documentation. @@ -154,7 +154,7 @@ It is recommended to have at least 6GB free VRAM to run it. | - **Requirement**. Llama2-7B model needs a device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. + **Requirement**. Llama3-8B model needs a device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. The demo is tested on - Samsung S23 with Snapdragon 8 Gen 2 chip diff --git a/docs/index.rst b/docs/index.rst index e9835e152d..2d5597d18e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -46,7 +46,6 @@ Check out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a compilation/convert_weights.rst compilation/compile_models.rst compilation/define_new_models.rst - compilation/configure_quantization.rst .. toctree:: :maxdepth: 1 diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst index c6602559ae..ce15616957 100644 --- a/docs/install/mlc_llm.rst +++ b/docs/install/mlc_llm.rst @@ -118,6 +118,13 @@ Select your operating system/compute platform and run the command in your termin python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly .. note:: + Make sure you also install vulkan loader and clang to avoid vulkan + not found error or clang not found(needed for jit compile) + + .. code-block:: bash + + conda install -c conda-forge clang libvulkan-loader + If encountering the error below: .. code-block:: bash @@ -207,7 +214,9 @@ There are two ways to do so: .. code-tab :: bash Install via environment variable - export PYTHONPATH=/path-to-mlc-llm/python:$PYTHONPATH + export MLC_LLM_HOME=/path-to-mlc-llm + export PYTHONPATH=$MLC_LLM_HOME/python:$PYTHONPATH + alias mlc_llm="python -m mlc_llm" .. code-tab :: bash Install via pip local project diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst index 849152cce6..ed4977e5e3 100644 --- a/docs/install/tvm.rst +++ b/docs/install/tvm.rst @@ -112,6 +112,13 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly .. note:: + Make sure you also install vulkan loader and clang to avoid vulkan + not found error or clang not found(needed for jit compile) + + .. code-block:: bash + + conda install -c conda-forge clang libvulkan-loader + If encountering the error below: .. code-block:: bash @@ -213,7 +220,7 @@ While it is generally recommended to always use the prebuilt TVM Unity, if you r If you are using CUDA and your compute capability is above 80, then it is require to build with ``set(USE_FLASHINFER ON)``. Otherwise, you may run into ``Cannot find PackedFunc`` issue during runtime. - + To check your CUDA compute capability, you can use ``nvidia-smi --query-gpu=compute_cap --format=csv``. Once ``config.cmake`` is edited accordingly, kick off build with the commands below: diff --git a/docs/prebuilt_models.rst b/docs/prebuilt_models.rst index f97909a515..2f772a5d7e 100644 --- a/docs/prebuilt_models.rst +++ b/docs/prebuilt_models.rst @@ -68,7 +68,7 @@ For more, please see :ref:`the CLI page `, and the :ref:`the Python .. 
code:: shell - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC To run the model with Python API, see :ref:`the Python page ` (all other downloading steps are the same as CLI). diff --git a/docs/requirements.txt b/docs/requirements.txt index bc020bc662..0156a180b0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,5 +6,9 @@ tlcpack-sphinx-addon==0.2.2 sphinxcontrib_httpdomain==1.8.1 sphinxcontrib-napoleon==0.7 sphinx-reredirects==0.1.2 +shortuuid +pydantic +uvicorn +fastapi --find-links https://mlc.ai/wheels mlc-ai-nightly diff --git a/examples/python/sample_mlc_engine.py b/examples/python/sample_mlc_engine.py index e26e17f1e2..e4f869930f 100644 --- a/examples/python/sample_mlc_engine.py +++ b/examples/python/sample_mlc_engine.py @@ -1,8 +1,8 @@ -from mlc_llm import LLMEngine +from mlc_llm import MLCEngine # Create engine -model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" -engine = LLMEngine(model) +model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" +engine = MLCEngine(model) # Run chat completion in OpenAI API. for response in engine.chat.completions.create( diff --git a/python/mlc_llm/chat_module.py b/python/mlc_llm/chat_module.py index 943f98c7e2..24ad8faecf 100644 --- a/python/mlc_llm/chat_module.py +++ b/python/mlc_llm/chat_module.py @@ -664,7 +664,7 @@ def _inspect_model_lib_metadata_memory_usage(model_lib_path, config_file_path): "--mlc-chat-config", config_file_path, ] - subprocess.run(cmd, check=False) + subprocess.run(cmd, check=False, env=os.environ) class ChatModule: # pylint: disable=too-many-instance-attributes @@ -768,7 +768,7 @@ def __init__( # pylint: disable=too-many-arguments self.chat_config = _get_chat_config(self.config_file_path, chat_config) # 4. Look up model library - try: + if model_lib_path is not None: self.model_lib_path = _get_lib_module_path( model, self.model_path, @@ -777,8 +777,8 @@ def __init__( # pylint: disable=too-many-arguments self.device.MASK2STR[self.device.device_type], self.config_file_path, ) - except FileNotFoundError: - logger.info("Model lib not found. 
Now compiling model lib on device...") + else: + logger.info("Now compiling model lib on device...") from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel self.model_lib_path = str( diff --git a/python/mlc_llm/cli/delivery.py b/python/mlc_llm/cli/delivery.py index 50b9c7e170..a7dd6408b0 100644 --- a/python/mlc_llm/cli/delivery.py +++ b/python/mlc_llm/cli/delivery.py @@ -1,7 +1,9 @@ """Continuous model delivery for MLC LLM models.""" + import argparse import dataclasses import json +import os import shutil import subprocess import sys @@ -131,7 +133,9 @@ def _run_quantization( cmd += ["--" + optional_arg.replace("_", "-"), str(optional_arg_val)] print(" ".join(cmd), file=log_file, flush=True) - subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT) + subprocess.run( + cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ + ) cmd = [ sys.executable, "-m", @@ -146,7 +150,9 @@ def _run_quantization( output_dir, ] print(" ".join(cmd), file=log_file, flush=True) - subprocess.run(cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT) + subprocess.run( + cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ + ) logger.info("[MLC] Complete!") if not (Path(output_dir) / "ndarray-cache.json").exists(): logger.error( diff --git a/python/mlc_llm/cli/lib_delivery.py b/python/mlc_llm/cli/lib_delivery.py new file mode 100644 index 0000000000..a5d678fbe2 --- /dev/null +++ b/python/mlc_llm/cli/lib_delivery.py @@ -0,0 +1,200 @@ +"""Continuous model delivery for MLC LLM models.""" + +import argparse +import dataclasses +import json +import os +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Any, Callable, Dict, List + +from mlc_llm.support import logging +from mlc_llm.support.argparse import ArgumentParser +from mlc_llm.support.constants import MLC_TEMP_DIR +from mlc_llm.support.style import bold, green, red + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ModelInfo: # pylint: disable=too-many-instance-attributes + """Necessary information for the model delivery""" + + model_id: str + model: Path + quantization: str + device: str + # overrides the `context_window_size`, `prefill_chunk_size`, + # `sliding_window_size`, `attention_sink_size`, `max_batch_size` + # and `tensor_parallel_shards in mlc-chat-config.json + overrides: Dict[str, int] + + +class DeferredScope: + """A context manager that defers execution of functions until exiting the scope.""" + + def __init__(self): + self.deferred_functions = [] + + def add(self, func: Callable[[], None]): + """Add a function to be executed when exiting the scope.""" + self.deferred_functions.append(func) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + for func in reversed(self.deferred_functions): + func() + return False + + def create_temp_dir(self) -> Path: + """Create a temporary directory that will be deleted when exiting the scope.""" + temp_dir = tempfile.mkdtemp(dir=MLC_TEMP_DIR) + self.add(lambda: shutil.rmtree(temp_dir, ignore_errors=True)) + return Path(temp_dir) + + +def _run_compilation(model_info: ModelInfo, repo_dir: Path) -> bool: + """Run the compilation of the model library.""" + + def get_lib_ext(device: str) -> str: + if device in ["cuda", "vulkan", "metal"]: + return ".so" + if device in ["android", "ios"]: + return ".tar" + if device in ["webgpu"]: + return ".wasm" + + return "" + + succeeded = True 
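+    # Compile into a scratch temporary directory; the library is copied into the repo directory only if the expected artifact was produced.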
+ with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as temp_dir: + log_path = Path(temp_dir) / "logs.txt" + model_lib_name = f"{model_info.model_id}-{model_info.quantization}-{model_info.device}" + lib_ext = get_lib_ext(model_info.device) + if lib_ext == "": + raise ValueError(f"Unsupported device: {model_info.device}") + model_lib_name += lib_ext + with log_path.open("a", encoding="utf-8") as log_file: + overrides = ";".join(f"{key}={value}" for key, value in model_info.overrides.items()) + cmd = [ + sys.executable, + "-m", + "mlc_llm", + "compile", + str(model_info.model), + "--device", + model_info.device, + "--quantization", + model_info.quantization, + "--overrides", + overrides, + "--output", + os.path.join(temp_dir, model_lib_name), + ] + print(" ".join(cmd), file=log_file, flush=True) + subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT) + logger.info("[MLC] Compilation Complete!") + if not (Path(temp_dir) / model_lib_name).exists(): + logger.error( + "[%s] Model %s. Device %s. No compiled library found.", + red("FAILED"), + model_info.model_id, + model_info.device, + ) + succeeded = False + return succeeded + + # overwrite git repo file with the compiled library + repo_filepath = repo_dir / model_info.model_id / model_lib_name + if not repo_filepath.parent.exists(): + repo_filepath.parent.mkdir(parents=True, exist_ok=True) + # copy lib from Path(temp_dir) / model_lib_name to repo_filepath + shutil.copy(Path(temp_dir) / model_lib_name, repo_filepath) + logger.info("Saved library %s at %s", model_lib_name, repo_filepath) + return succeeded + + +def _main( # pylint: disable=too-many-locals + spec: Dict[str, Any], +): + """Compile the model libs in the spec and save them to the binary_libs_dir.""" + failed_cases: List[Any] = [] + for task_index, task in enumerate(spec["tasks"], 1): + logger.info( + bold("[{task_index}/{total_tasks}] Processing model: ").format( + task_index=task_index, + total_tasks=len(spec["tasks"]), + ) + + green(task["model_id"]) + ) + model_info = { + "model_id": task["model_id"], + "model": task["model"], + } + for compile_opt in spec["default_compile_options"] + task.get("compile_options", []): + for quantization in spec["default_quantization"] + task.get("quantization", []): + model_info["quantization"] = quantization + model_info["device"] = compile_opt["device"] + model_info["overrides"] = compile_opt.get("overrides", {}) + logger.info( + "[Config] " + + bold("model_id: ") + + model_info["model_id"] + + bold(", quantization: ") + + model_info["quantization"] + + bold(", device: ") + + model_info["device"] + + bold(", overrides: ") + + json.dumps(model_info["overrides"]) + ) + + result = _run_compilation( + ModelInfo(**model_info), + repo_dir=Path(spec["binary_libs_dir"]), + ) + if not result: + failed_cases.append(model_info) + + if failed_cases: + logger.info("Total %s %s:", len(failed_cases), red("failures")) + for case in failed_cases: + logger.info( + "model_id %s, quantization %s, device %s, overrides %s", + case["model_id"], + case["quantization"], + case["device"], + json.dumps(case["overrides"]), + ) + + +def main(): + """Entry point.""" + + def _load_spec(path_spec: str) -> Dict[str, Any]: + path = Path(path_spec) + if not path.exists(): + raise argparse.ArgumentTypeError(f"Spec file does not exist: {path}") + with path.open("r", encoding="utf-8") as i_f: + return json.load(i_f) + + parser = ArgumentParser("MLC LLM continuous library delivery") + parser.add_argument( + "--spec", + type=_load_spec, + required=True, + 
help="Path to the spec file", + ) + parsed = parser.parse_args() + _main( + spec=parsed.spec, + ) + + +if __name__ == "__main__": + main() diff --git a/python/mlc_llm/cli/model_metadata.py b/python/mlc_llm/cli/model_metadata.py index 9b45561665..81473b1ec7 100644 --- a/python/mlc_llm/cli/model_metadata.py +++ b/python/mlc_llm/cli/model_metadata.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Any, Dict, List, Union -import numpy as np +from tvm.runtime import DataType from mlc_llm.support import logging from mlc_llm.support.argparse import ArgumentParser @@ -81,7 +81,7 @@ def _compute_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBa else: # Contains dynamic shape; use config to look up concrete values param_shape = _read_dynamic_shape(param["shape"], config) - params_bytes += math.prod(param_shape) * np.dtype(param["dtype"]).itemsize + params_bytes += math.prod(param_shape) * DataType(param["dtype"]).itemsize() temp_func_bytes = 0.0 for _func_name, func_bytes in metadata["memory_usage"].items(): temp_func_bytes = max(temp_func_bytes, func_bytes) diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index 9f7c1c3580..6663a0c230 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -44,6 +44,9 @@ def main(argv): "--max-total-seq-length", type=int, help=HELP["max_total_sequence_length_serve"] ) parser.add_argument("--prefill-chunk-size", type=int, help=HELP["prefill_chunk_size_serve"]) + parser.add_argument( + "--max-history-size", type=int, default=1, help=HELP["max_history_size_serve"] + ) parser.add_argument( "--gpu-memory-utilization", type=float, help=HELP["gpu_memory_utilization_serve"] ) @@ -100,6 +103,7 @@ def main(argv): max_batch_size=parsed.max_batch_size, max_total_sequence_length=parsed.max_total_seq_length, prefill_chunk_size=parsed.prefill_chunk_size, + max_history_size=parsed.max_history_size, gpu_memory_utilization=parsed.gpu_memory_utilization, speculative_mode=SpeculativeMode[parsed.speculative_mode], spec_draft_length=parsed.spec_draft_length, diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py b/python/mlc_llm/compiler_pass/attach_sampler.py index 1b7b0328a9..46dc40c106 100644 --- a/python/mlc_llm/compiler_pass/attach_sampler.py +++ b/python/mlc_llm/compiler_pass/attach_sampler.py @@ -7,6 +7,8 @@ from tvm.relax.frontend import nn from tvm.script import tir as T +from ..op.batch_spec_verify import batch_spec_verify + @tvm.transform.module_pass(opt_level=0, name="AttachGPUSamplingFunc") class AttachGPUSamplingFunc: # pylint: disable=too-few-public-methods @@ -46,6 +48,8 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR _attach_argsort_func(bb, vocab_size), _attach_sample_with_top_p(bb, vocab_size), _attach_take_probs_func(bb, vocab_size), + _attach_batch_verifier(bb, vocab_size), + _attach_renormalize_by_top_p(bb, vocab_size), ] ] @@ -126,6 +130,17 @@ def _attach_argsort_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): return gv +@T.prim_func +def full(var_result: T.handle, value: T.int32): + """The filling function for top k.""" + batch_size = T.int32(is_size_var=True) + result = T.match_buffer(var_result, (batch_size, 1), "int32") + for i in T.serial(batch_size): + with T.block("block"): + vi = T.axis.spatial(batch_size, i) + result[vi, 0] = value + + def _attach_sample_with_top_p( # pylint: disable=too-many-locals bb: relax.BlockBuilder, vocab_size: tir.PrimExpr ): @@ -143,15 +158,6 @@ def _attach_sample_with_top_p( # pylint: disable=too-many-locals 
sample_indices = relax.Var("sample_indices", relax.TensorStructInfo((num_samples,), "int32")) top_p = relax.Var("top_p", relax.TensorStructInfo((batch_size,), "float32")) - @T.prim_func - def full(var_result: T.handle, value: T.int32): - batch_size = T.int32(is_size_var=True) - result = T.match_buffer(var_result, (batch_size, 1), "int32") - for i in T.serial(batch_size): - with T.block("block"): - vi = T.axis.spatial(batch_size, i) - result[vi, 0] = value - with bb.function( "sample_with_top_p", [sorted_probs, sorted_indices, uniform_samples, sample_indices, top_p], @@ -221,6 +227,44 @@ def full(var_result: T.handle, value: T.int32): return gv +def _attach_renormalize_by_top_p(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): + batch_size = tir.Var("batch_size", "int64") + probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")) + sorted_probs = relax.Var( + "sorted_probs", relax.TensorStructInfo((batch_size, vocab_size), "float32") + ) + top_p = relax.Var("top_p", relax.TensorStructInfo((batch_size,), "float32")) + with bb.function("renormalize_by_top_p", [probs, sorted_probs, top_p]): + with bb.dataflow(): + probs_tensor = nn.wrap_nested(probs, name="probs") + sorted_probs_tensor = nn.wrap_nested(sorted_probs, name="sorted_probs") + top_p_shape = relax.ShapeExpr([batch_size, 1]) + top_p_tensor = nn.wrap_nested( + relax.call_pure_packed( + "vm.builtin.reshape", + top_p, + top_p_shape, + sinfo_args=relax.TensorStructInfo(top_p_shape, "float32"), + ), + name="sample_indices", + ) + top_k_tensor = nn.tensor_ir_op( + full, + name_hint="full", + args=[vocab_size], + out=nn.Tensor.placeholder( + [batch_size, 1], + "int32", + ), + ) + renormalized_probs = nn.renormalize_top_p_top_k_prob( + probs_tensor, sorted_probs_tensor, top_p_tensor, top_k_tensor + ) + bb.emit_output(renormalized_probs._expr) # pylint: disable=protected-access + gv = bb.emit_func_output(renormalized_probs._expr) # pylint: disable=protected-access + return gv + + def _attach_take_probs_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): batch_size = tir.Var("batch_size", "int64") num_samples = tir.Var("num_samples", "int64") @@ -289,3 +333,50 @@ def sampler_take_probs_tir( # pylint: disable=too-many-locals,too-many-argument bb.emit_output(taken_probs_indices) gv = bb.emit_func_output(taken_probs_indices) return gv + + +def _attach_batch_verifier(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): + num_nodes = tir.Var("num_nodes", "int64") + nbatch = tir.Var("nbatch", "int64") + draft_probs = relax.Var( + "draft_probs", relax.TensorStructInfo((num_nodes, vocab_size), "float32") + ) + draft_tokens = relax.Var("draft_tokens", relax.TensorStructInfo((num_nodes,), "int32")) + model_probs = relax.Var( + "model_probs", relax.TensorStructInfo((num_nodes, vocab_size), "float32") + ) + token_tree_first_child = relax.Var( + "token_tree_first_child", relax.TensorStructInfo((num_nodes,), "int32") + ) + token_tree_next_sibling = relax.Var( + "token_tree_next_sibling", relax.TensorStructInfo((num_nodes,), "int32") + ) + uniform_samples = relax.Var("uniform_samples", relax.TensorStructInfo((num_nodes,), "float32")) + token_tree_parent_ptr = relax.Var( + "token_tree_parent_ptr", relax.TensorStructInfo((nbatch,), "int32") + ) + args = [ + draft_probs, + draft_tokens, + model_probs, + token_tree_first_child, + token_tree_next_sibling, + uniform_samples, + token_tree_parent_ptr, + ] + with bb.function("sampler_verify_draft_tokens", args): + with bb.dataflow(): + res = bb.emit( + relax.call_tir_inplace( + 
bb.add_func(batch_spec_verify(vocab_size), "batch_verify_on_gpu_single_kernel"), + args, + inplace_indices=[args.index(model_probs), args.index(token_tree_parent_ptr)], + out_sinfo=[ + model_probs.struct_info, # pylint: disable=no-member + token_tree_parent_ptr.struct_info, # pylint: disable=no-member + ], + ) + ) + bb.emit_output(res) + gv = bb.emit_func_output(res) + return gv diff --git a/python/mlc_llm/compiler_pass/estimate_memory_usage.py b/python/mlc_llm/compiler_pass/estimate_memory_usage.py index d69d99109d..83007fde66 100644 --- a/python/mlc_llm/compiler_pass/estimate_memory_usage.py +++ b/python/mlc_llm/compiler_pass/estimate_memory_usage.py @@ -25,6 +25,8 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR func_name = "_metadata" + func_name = "_metadata" + def _emit_metadata(metadata): bb = relax.BlockBuilder() # pylint: disable=invalid-name with bb.function(func_name, params=[]): diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index b85a6a2cf6..57b68f742d 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -33,6 +33,7 @@ from .fuse_transpose_matmul import FuseTransposeMatmul from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc from .low_batch_specialization import LowBatchGemvSpecialize +from .rewrite_softmax import RewriteTwoStageSoftmax from .scatter_tuple_get_item import ScatterTupleGetItem logger = logging.getLogger(__name__) @@ -117,6 +118,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline _LogProgress("Lowering to TVM TIR kernels"), tvm.relax.backend.DispatchSortScan(), + RewriteTwoStageSoftmax(target=target), tvm.relax.transform.LegalizeOps(), tvm.relax.transform.AnnotateTIROpPattern(), tvm.relax.transform.FoldConstant(), diff --git a/python/mlc_llm/compiler_pass/rewrite_softmax.py b/python/mlc_llm/compiler_pass/rewrite_softmax.py new file mode 100644 index 0000000000..1a6e41eafc --- /dev/null +++ b/python/mlc_llm/compiler_pass/rewrite_softmax.py @@ -0,0 +1,190 @@ +"""A compiler pass that rewrites one-shot softmax into two-stage softmax.""" + +import math + +import tvm +from tvm import relax +from tvm.ir.module import IRModule +from tvm.relax.expr import Expr +from tvm.relax.expr_functor import PyExprMutator, mutator +from tvm.script import tir as T + +from ..support.max_thread_check import get_max_num_threads_per_block + + +@tvm.transform.module_pass(opt_level=0, name="RewriteTwoStageSoftmax") +class RewriteTwoStageSoftmax: # pylint: disable=too-few-public-methods + """Rewrites one-shot softmax into two-stage softmax.""" + + def __init__(self, target: tvm.target.Target) -> None: + self.target = target + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + return _Rewriter(mod, self.target).transform() + + +@mutator +class _Rewriter(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod: IRModule, target: tvm.target.Target) -> None: + super().__init__(mod) + self.mod = mod + self.target = target + self.chunk_size = 4096 + + def transform(self) -> IRModule: + """Entry point""" + gv = self.mod.get_global_var("softmax_with_temperature") + updated_func = self.visit_expr(self.mod[gv]) + self.builder_.update_func(gv, updated_func) + return self.builder_.get() + + def visit_call_(self, call: relax.Call) -> Expr: # pylint: 
disable=arguments-renamed + if call.op != tvm.ir.Op.get("relax.nn.softmax"): + return call + x = call.args[0] + if call.attrs.axis not in [-1, x.struct_info.ndim - 1]: + return call + # Currently the softmax input is 3-dim, and dtype is float32. + assert x.struct_info.ndim == 3 + assert x.struct_info.dtype == "float32" + x_shape = x.struct_info.shape + new_shape = relax.ShapeExpr([x_shape[0] * x_shape[1], x_shape[2]]) + x_reshaped = relax.call_pure_packed( + "vm.builtin.reshape", + x, + new_shape, + sinfo_args=relax.TensorStructInfo(new_shape, x.struct_info.dtype), + ) + f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(self.target, self.chunk_size) + chunked_lse = relax.call_tir( + self.builder_.add_func(f_chunk_lse, "chunk_lse"), + args=[x_reshaped], + out_sinfo=relax.TensorStructInfo( + (new_shape[0], (new_shape[1] + self.chunk_size - 1) // self.chunk_size), + x.struct_info.dtype, + ), + ) + softmax = relax.call_tir( + self.builder_.add_func(f_softmax_with_lse, "softmax_with_chunked_lse"), + args=[x_reshaped, chunked_lse], + out_sinfo=relax.TensorStructInfo(new_shape, x.struct_info.dtype), + ) + return relax.call_pure_packed( + "vm.builtin.reshape", softmax, x_shape, sinfo_args=x.struct_info + ) + + +def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements + target: tvm.target.Target, chunk_size: int +): + log2e = math.log2(math.exp(1)) + + # pylint: disable=invalid-name + @T.prim_func + def chunk_lse(var_A: T.handle, var_chunked_lse: T.handle): # pylint: disable=too-many-locals + T.func_attr({"tir.noalias": T.bool(True)}) + batch_size = T.int64(is_size_var=True) + vocab_size = T.int64(is_size_var=True) + num_chunks = T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") + chunked_lse = T.match_buffer(var_chunked_lse, (batch_size, num_chunks), dtype="float32") + A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(chunk_size)), dtype="float32") + temp_max = T.alloc_buffer((batch_size, num_chunks), dtype="float32") + temp_sum = T.alloc_buffer((batch_size, num_chunks), dtype="float32") + + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("pad"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + A_pad[v0, v1, v2] = T.if_then_else( + v1 * T.int64(chunk_size) + v2 < vocab_size, + A[v0, v1 * T.int64(chunk_size) + v2], + T.min_value("float32"), + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("max"): + v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) + with T.init(): + temp_max[v0, v1] = T.min_value("float32") + temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2]) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("sum_exp"): + v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) + with T.init(): + temp_sum[v0, v1] = T.float32(0) + temp_sum[v0, v1] += T.if_then_else( + v1 * T.int64(chunk_size) + v2 < vocab_size, + T.exp2((A_pad[v0, v1, v2] - temp_max[v0, v1]) * log2e), + T.float32(0), + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)): + with T.block("log"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + chunked_lse[v0, v1] = T.log2(temp_sum[v0, v1]) + temp_max[v0, v1] * log2e + + @T.prim_func + def softmax_with_chunked_lse(var_A: T.handle, var_chunked_lse: T.handle, var_softmax: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + batch_size = T.int64(is_size_var=True) + vocab_size = T.int64(is_size_var=True) + num_chunks = 
T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") + chunked_lse = T.match_buffer(var_chunked_lse, (batch_size, num_chunks), dtype="float32") + softmax = T.match_buffer(var_softmax, (batch_size, vocab_size), dtype="float32") + temp_max = T.alloc_buffer((batch_size,), dtype="float32") + temp_sum = T.alloc_buffer((batch_size,), dtype="float32") + lse = T.alloc_buffer((batch_size,), dtype="float32") + for l0, l1 in T.grid(batch_size, num_chunks): + with T.block("max"): + v0, v1 = T.axis.remap("SR", [l0, l1]) + with T.init(): + temp_max[v0] = T.min_value("float32") + temp_max[v0] = T.max(temp_max[v0], chunked_lse[v0, v1]) + for l0, l1 in T.grid(batch_size, num_chunks): + with T.block("sum_exp"): + v0, v1 = T.axis.remap("SR", [l0, l1]) + with T.init(): + temp_sum[v0] = T.float32(0) + temp_sum[v0] += T.exp2(chunked_lse[v0, v1] - temp_max[v0]) + for l0 in T.serial(0, batch_size): + with T.block("log"): + v0 = T.axis.remap("S", [l0]) + lse[v0] = T.log2(temp_sum[v0]) + temp_max[v0] + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("pad"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + if v1 * T.int64(chunk_size) + v2 < vocab_size: + softmax[v0, v1 * T.int64(chunk_size) + v2] = T.exp2( + A[v0, v1 * T.int64(chunk_size) + v2] * log2e - lse[v0] + ) + + sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_lse": softmax_with_chunked_lse})) + max_threads = get_max_num_threads_per_block(target) + TX = 32 + TY = max_threads // TX + unroll_depth = 64 + # pylint: enable=invalid-name + + sch.work_on("softmax_with_chunked_lse") + sch.compute_inline("log") + l0, l1, l2 = sch.get_loops("pad") + bx = sch.fuse(l0, l1) + sch.bind(bx, "blockIdx.x") + unroll, ty, tx = sch.split(l2, [None, TY, TX]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1) + + for block_name in ["sum_exp", "max"]: + block = sch.get_block(block_name) + sch.set_scope(block, buffer_index=0, storage_scope="shared") + sch.compute_at(block, bx) + r_loop = sch.get_loops(block)[-1] + r_loop, tx = sch.split(r_loop, [None, TX]) + sch.reorder(tx, r_loop) + sch.bind(tx, "threadIdx.x") + sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1) + + return chunk_lse, sch.mod["softmax_with_chunked_lse"] diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index 1b2a06feab..1c599fa875 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -36,6 +36,27 @@ def get_conv_template(name: str) -> Optional[Conversation]: ############## Preset Conversation Templates ############## +# Llama3 +# See https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models +# and https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py +ConvTemplateRegistry.register_conv_template( + Conversation( + name="llama-3", + system_template=( + f"<|start_header_id|>system<|end_header_id|>\n\n{MessagePlaceholders.SYSTEM.value}" + ), + system_message="You are a helpful, respectful and honest assistant.", + roles={"user": "user", "assistant": "assistant"}, + seps=["<|eot_id|><|start_header_id|>"], + role_content_sep="<|end_header_id|>\n\n", + role_empty_sep="<|end_header_id|>\n\n", + stop_str=["<|end_of_text|>", "<|eot_id|>"], + 
stop_token_ids=[128001, 128009], # "<|end_of_text|>", "<|eot_id|>" + system_prefix_token_ids=[128000], # "<|begin_of_text|>" + add_role_after_system_message=True, + ) +) + # Llama2 ConvTemplateRegistry.register_conv_template( Conversation( @@ -344,7 +365,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: # RWKV World ConvTemplateRegistry.register_conv_template( Conversation( - name="rwkv-world", + name="rwkv_world", system_template=f"User: hi\n\nAssistant: {MessagePlaceholders.SYSTEM.value}", system_message=( "Hi. I am your assistant and I will provide expert full response " diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index b4321ebdec..86930fa5ea 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -152,6 +152,11 @@ The maximum number of tokens the model passes for prefill each time. It should not exceed the prefill chunk size in model config. If not specified, this defaults to the prefill chunk size in model config. +""".strip(), + "max_history_size_serve": """ +The maximum history length for rolling back the RNN state. +If unspecified, the default value is 1. +KV cache does not need this. """.strip(), "enable_tracing_serve": """ Enable Chrome Tracing for the server. @@ -188,7 +193,7 @@ "gpu_memory_utilization_serve": """ A number in (0, 1) denoting the fraction of GPU memory used by the server in total. It is used to infer to maximum possible KV cache capacity. -When it is unspecified, it defaults to 0.90. +When it is unspecified, it defaults to 0.85. Under mode "local" or "interactive", the actual memory usage may be significantly smaller than this number. Under mode "server", the actual memory usage may be slightly larger than this number. """, @@ -203,7 +208,7 @@ The number of draft tokens to generate in speculative proposal. The default values is 4. """, "engine_config_serve": """ -The LLMEngine execution configuration. +The MLCEngine execution configuration. Currently speculative decoding mode is specified via engine config. For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" to specify the eagle-style speculative decoding. 
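Note: both `--engine-config` above and the `--overrides` flag used by the delivery scripts take a semicolon-separated list of `key=value` pairs (the delivery code builds the same format with `";".join(f"{key}={value}" ...)`). Below is a minimal sketch of how such a string can be parsed; `parse_kv_overrides` is an illustrative helper, not part of mlc_llm:

```python
def parse_kv_overrides(spec: str) -> dict:
    """Parse a string like 'spec_draft_length=4;speculative_mode=EAGLE' into a dict.

    Values are kept as strings; the caller decides how to convert or validate them.
    """
    overrides = {}
    for item in spec.split(";"):
        item = item.strip()
        if not item:
            continue
        key, _, value = item.partition("=")
        overrides[key.strip()] = value.strip()
    return overrides


assert parse_kv_overrides("spec_draft_length=4;speculative_mode=EAGLE") == {
    "spec_draft_length": "4",
    "speculative_mode": "EAGLE",
}
```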
diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py index 2d0d668672..77b611d139 100644 --- a/python/mlc_llm/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -2,7 +2,6 @@ import dataclasses import enum -import re from io import StringIO from typing import Optional diff --git a/python/mlc_llm/interface/convert_weight.py b/python/mlc_llm/interface/convert_weight.py index b54318ef4c..179c872e50 100644 --- a/python/mlc_llm/interface/convert_weight.py +++ b/python/mlc_llm/interface/convert_weight.py @@ -7,10 +7,9 @@ from pathlib import Path from typing import Any, Dict, Iterator, Tuple -import numpy as np from tvm import tir from tvm.contrib import tvmjs -from tvm.runtime import Device, NDArray +from tvm.runtime import DataType, Device, NDArray from tvm.runtime import cpu as cpu_device from tvm.target import Target diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index d22aa7d231..8e617fc3d2 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -274,6 +274,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b # FIXME: Copy RWKV tokenizer file # pylint: disable=fixme CONV_TEMPLATES = { + "llama-3", "chatml", "open_hermes_mistral", "neural_hermes_mistral", diff --git a/python/mlc_llm/interface/jit.py b/python/mlc_llm/interface/jit.py index 25548e0e4a..e999a36468 100644 --- a/python/mlc_llm/interface/jit.py +++ b/python/mlc_llm/interface/jit.py @@ -93,7 +93,11 @@ def _run_jit(opt: str, overrides: str, device: str, dst: str): ] logger.info("Compiling using commands below:") logger.info("%s", blue(shlex.join(cmd))) - subprocess.run(cmd, check=True) + subprocess.run(cmd, check=False, env=os.environ) + # note on windows: compilation can succeed but return code is still nonzero + # check whether file exists instead + if not os.path.isfile(dso_path): + raise RuntimeError("Cannot find compilation output, compilation failed") shutil.move(dso_path, dst) logger.info("Using compiled model lib: %s", bold(dst)) diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index c5696ef473..40fa9fdda8 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -22,6 +22,7 @@ def serve( max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], + max_history_size: Optional[int], gpu_memory_utilization: Optional[float], speculative_mode: SpeculativeMode, spec_draft_length: int, @@ -35,7 +36,7 @@ def serve( ): # pylint: disable=too-many-arguments, too-many-locals """Serve the model with the specified configuration.""" # Create engine and start the background loop - async_engine = engine.AsyncLLMEngine( + async_engine = engine.AsyncMLCEngine( model=model, device=device, model_lib_path=model_lib_path, @@ -44,6 +45,7 @@ def serve( max_batch_size=max_batch_size, max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, gpu_memory_utilization=gpu_memory_utilization, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, diff --git a/python/mlc_llm/json_ffi/__init__.py b/python/mlc_llm/json_ffi/__init__.py new file mode 100644 index 0000000000..8a7059153d --- /dev/null +++ b/python/mlc_llm/json_ffi/__init__.py @@ -0,0 +1,8 @@ +"""JSON FFI is a pure string based interface of MLC LLM Engine. 
+ +We build interfacing with JSON FFI for both testing purposes +and internal use. For most python API usage, please use MLCEngine +and MLCAsyncEngine +""" + +from .engine import JSONFFIEngine diff --git a/python/mlc_llm/json_ffi/engine.py b/python/mlc_llm/json_ffi/engine.py new file mode 100644 index 0000000000..0c604a2ef3 --- /dev/null +++ b/python/mlc_llm/json_ffi/engine.py @@ -0,0 +1,310 @@ +# pylint: disable=chained-comparison,missing-docstring,too-few-public-methods,too-many-instance-attributes +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +import json +import queue +import threading +from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union + +import tvm + +from mlc_llm.protocol import openai_api_protocol +from mlc_llm.serve import engine_utils +from mlc_llm.serve.engine_base import ( + EngineConfig, + SpeculativeMode, + _infer_kv_cache_config, + _parse_models, + _process_model_args, + detect_device, +) +from mlc_llm.tokenizer import Tokenizer + + +# TODO(mlc-team): further minimize the JSONFFIEngine +# construction to not depend on any config and directly pass in JSON +# model defined generation config should be read from the JSONFFIEngine via Reload +def create_model_defined_generation_config( + temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float +) -> tvm.runtime.Object: + return tvm.get_global_func("mlc.json_ffi.ModelDefinedGenerationConfig")( + temperature, + top_p, + frequency_penalty, + presence_penalty, + ) + + +# TODO(mlc-team): further minimize the JSONFFIEngine +# Engine config should be passed as json str +# and backend should have good default +# only model and model_lib should be mandatory +def create_json_ffi_engine_config( + conv_template: str, model_generation_cfgs: Dict[str, tvm.runtime.Object] +) -> tvm.runtime.Object: + return tvm.get_global_func("mlc.json_ffi.JSONFFIEngineConfig")( + conv_template, model_generation_cfgs + ) + + +class EngineState: + sync_queue: queue.Queue + + def get_request_stream_callback(self) -> Callable[[List[str]], None]: + # ChatCompletionStreamResponse + + def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: + self._sync_request_stream_callback(chat_completion_stream_responses_json_str) + + return _callback + + def _sync_request_stream_callback( + self, chat_completion_stream_responses_json_str: List[str] + ) -> None: + # Put the delta outputs to the queue in the unblocking way. + self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) + + +class JSONFFIEngine: + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + model: str, + device: Union[str, tvm.runtime.Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + max_history_size: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, + gpu_memory_utilization: Optional[float] = None, + ) -> None: + # - Initialize model loading info. 
+ models = _parse_models(model, model_lib_path, additional_models) + if isinstance(device, str): + device = detect_device(device) + assert isinstance(device, tvm.runtime.Device) + ( + model_args, + model_config_paths, + self.conv_template, + ) = _process_model_args(models, device) + + # TODO(mlc-team) Remove the model config parsing, estimation below + # in favor of a simple direct passing of parameters into backend. + # JSONFFIEngine do not have to support automatic mode + # + # Instead, its config should default to interactive mode always + # and allow overrides of parameters through json config via reload + # + # This is to simplify the logic of users of JSONFFI + # since we won't have similar logics in android/iOS + # + # - Load the raw model config into dict + self.model_config_dicts = [] + for i, model_info in enumerate(models): + model_info.model_lib_path = model_args[i][1] + with open(model_config_paths[i], "r", encoding="utf-8") as file: + self.model_config_dicts.append(json.load(file)) + + # - Decide the KV cache config based on mode and user input. + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + max_history_size, + kv_state_kind, + ) = _infer_kv_cache_config( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_history_size, + gpu_memory_utilization, + models, + device, + self.model_config_dicts, + model_config_paths, + ) + self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + + # - Initialize engine state and engine. + self.state = EngineState() + module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() + self._ffi = { + key: module[key] + for key in [ + "init_background_engine", + "reload", + "unload", + "reset", + "chat_completion", + "abort", + "get_last_error", + "run_background_loop", + "run_background_stream_back_loop", + "exit_background_loop", + ] + } + self.tokenizer = Tokenizer(model_args[0][0]) + + self.engine_config = EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ) + + self.json_ffi_engine_config = create_json_ffi_engine_config( + conv_template=self.conv_template.model_dump_json(), + model_generation_cfgs={ + model.model: create_model_defined_generation_config( + temperature=model_config["temperature"], + top_p=model_config["top_p"], + frequency_penalty=model_config["frequency_penalty"], + presence_penalty=model_config["presence_penalty"], + ) + for model, model_config in zip(models, self.model_config_dicts) + }, + ) + + self._ffi["init_background_engine"]( + self.json_ffi_engine_config, + self.engine_config, + device, + self.state.get_request_stream_callback(), + None, + ) + + def _background_loop(): + self._ffi["run_background_loop"]() + + def _background_stream_back_loop(): + self._ffi["run_background_stream_back_loop"]() + + # Create the background engine-driving thread and start the loop. 
+ self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) + self._background_stream_back_loop_thread: threading.Thread = threading.Thread( + target=_background_stream_back_loop + ) + self._background_loop_thread.start() + self._background_stream_back_loop_thread.start() + self._terminated = False + + def terminate(self): + self._terminated = True + self._ffi["exit_background_loop"]() + self._background_loop_thread.join() + self._background_stream_back_loop_thread.join() + + def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ).model_dump_json(), + n=n, + request_id=request_id, + ) + for response in chatcmpl_generator: + yield response + + def _handle_chat_completion( + self, request_json_str: str, n: int, request_id: str + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + self.state.sync_queue = queue.Queue() + num_unfinished_requests = n + + success = bool(self._ffi["chat_completion"](request_json_str, request_id)) + + try: + while num_unfinished_requests > 0: + chat_completion_stream_responses_json_str = self.state.sync_queue.get() + for chat_completion_response_json_str in chat_completion_stream_responses_json_str: + chat_completion_response = ( + openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( + chat_completion_response_json_str + ) + ) + for choice in chat_completion_response.choices: + if choice.finish_reason is not None: + num_unfinished_requests -= 1 + yield chat_completion_response + except Exception as exception: # pylint: disable=broad-exception-caught + self._ffi["abort"](request_id) + raise exception + + def _test_reload(self): + self._ffi["reload"](self.engine_config) + + def _test_reset(self): + self._ffi["reset"]() + + def _test_unload(self): + self._ffi["unload"]() diff --git 
a/python/mlc_llm/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py index 28c34353e2..ede9dc350f 100644 --- a/python/mlc_llm/model/gpt2/gpt2_model.py +++ b/python/mlc_llm/model/gpt2/gpt2_model.py @@ -28,7 +28,7 @@ class GPT2Config(ConfigBase): # pylint: disable=too-many-instance-attributes n_embd: int n_layer: int n_head: int - layer_norm_epsilon: int + layer_norm_epsilon: float n_inner: int = -1 context_window_size: int = 0 prefill_chunk_size: int = 0 diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index 2ae5500c6d..18238f688e 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -224,15 +224,41 @@ def batch_forward( hidden_states = self.model(input_embeds, paged_kv_cache) if logit_positions is not None: hidden_states = op.take(hidden_states, logit_positions, axis=1) + return self.get_logits(hidden_states) + + def batch_forward_to_last_hidden_states( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + ): + op_ext.configure() + + hidden_states = self.model(input_embeds, paged_kv_cache) + return hidden_states + + def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + return self.model.embed_tokens(input_ids) + + def get_logits(self, hidden_states: Tensor): + op_ext.configure() logits = self.lm_head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") return logits - def embed(self, input_ids: Tensor): + def batch_get_logits(self, hidden_states: Tensor, logit_positions: Tensor): + op_ext.configure() if self.tensor_parallel_shards > 1: - input_ids = op.ccl_broadcast_from_worker0(input_ids) - return self.model.embed_tokens(input_ids) + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) + hidden_states = op.take(hidden_states, logit_positions, axis=0) + return self.get_logits(hidden_states) + + def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): + op_ext.configure() + hidden_states = op.take(hidden_states, logit_positions, axis=0) + return hidden_states def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): op_ext.configure() @@ -243,20 +269,28 @@ def _index(x: te.Tensor): # x[:-1,:] hidden_states = self.model(input_embed, paged_kv_cache) hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) - logits = self.lm_head(hidden_states) - if logits.dtype != "float32": - logits = logits.astype("float32") + logits = self.get_logits(hidden_states) return logits, paged_kv_cache def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) - logits = self.lm_head(hidden_states) - if logits.dtype != "float32": - logits = logits.astype("float32") + logits = self.get_logits(hidden_states) return logits, paged_kv_cache + def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + return hidden_states, paged_kv_cache + + def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + return hidden_states, paged_kv_cache + def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): @@ -273,6 +307,24 @@ def batch_verify(self, input_embeds: 
Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache + def batch_prefill_to_last_hidden_states( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): + hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) + return hidden_states, paged_kv_cache + + def batch_decode_to_last_hidden_states( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): + hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) + return hidden_states, paged_kv_cache + + def batch_verify_to_last_hidden_states( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): + hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) + return hidden_states, paged_kv_cache + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) @@ -309,6 +361,29 @@ def get_default_spec(self): "effect_mode": "none", }, }, + "get_logits": { + "hidden_states": nn.spec.Tensor(["batch_size", self.hidden_size], self.dtype), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_get_logits": { + "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_select_last_hidden_states": { + "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, "prefill": { "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), @@ -325,6 +400,22 @@ def get_default_spec(self): "effect_mode": "none", }, }, + "prefill_to_last_hidden_states": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode_to_last_hidden_states": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "batch_prefill": { "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), @@ -350,6 +441,30 @@ def get_default_spec(self): "effect_mode": "none", }, }, + "batch_prefill_to_last_hidden_states": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode_to_last_hidden_states": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify_to_last_hidden_states": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "softmax_with_temperature": { "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), 
"temperature": nn.spec.Tensor(["batch_size"], "float32"), diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 1c513e15d3..272cffdc80 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -85,7 +85,7 @@ class Model: "group-quant": llama_quantization.group_quant, "ft-quant": llama_quantization.ft_quant, "awq": llama_quantization.awq_quant, - "smoothquant": llama_quantization.smooth_quant, + "smoothquant": llama_quantization.smooth_quant }, ), "mistral": Model( diff --git a/python/mlc_llm/model/model_preset.py b/python/mlc_llm/model/model_preset.py index 3bfe1cb891..41abf0292c 100644 --- a/python/mlc_llm/model/model_preset.py +++ b/python/mlc_llm/model/model_preset.py @@ -660,4 +660,54 @@ "eos_token_id": 2, "pad_token_id": 0, }, + "llama3_8b": { + "architectures": ["LlamaForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 8192, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 500000.0, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0.dev0", + "use_cache": True, + "vocab_size": 128256, + }, + "llama3_70b": { + "architectures": ["LlamaForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 8192, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 500000.0, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0.dev0", + "use_cache": True, + "vocab_size": 128256, + }, } diff --git a/python/mlc_llm/model/rwkv5/rwkv5_model.py b/python/mlc_llm/model/rwkv5/rwkv5_model.py index 49386720da..81c9e9aa7f 100644 --- a/python/mlc_llm/model/rwkv5/rwkv5_model.py +++ b/python/mlc_llm/model/rwkv5/rwkv5_model.py @@ -40,6 +40,7 @@ class RWKV5Config(ConfigBase): # pylint: disable=too-many-instance-attributes context_window_size: int = -1 # RWKV does not have context window limitation. 
prefill_chunk_size: int = 4096 num_heads: int = 0 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -129,23 +130,18 @@ def wkv_func( def token_shift(state: Tensor, x: Tensor): - # x.shape = (batch, seq_len, hidden_size) - # state.shape = (batch, hidden_size) - seq_len = x.shape[1] - def _te_token_shift(state: te.Tensor, x: te.Tensor): return te.compute( x.shape, lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]), ) - return state if seq_len == 1 else op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) + return op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) def last_token(x: Tensor): # x.shape = (batch, seq_len, hidden_size) batch, seq_len, hidden_size = x.shape - assert batch == 1 def _te_last_token(x: te.Tensor): return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, x.shape[1] - 1, j]) @@ -350,10 +346,14 @@ def to(self, dtype: Optional[str] = None): def embed(self, input_ids: Tensor): return self.model.embeddings(input_ids) - def forward(self, input_embed: Tensor, state: RNNState): + def forward( + self, input_embed: Tensor, state: RNNState, logit_positions: Optional[Tensor] = None + ): """Forward pass.""" hidden_states, state = self.model(input_embed, state) hidden_states = last_token(hidden_states) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") @@ -367,11 +367,27 @@ def decode(self, input_embed: Tensor, state: RNNState): """Decoding step.""" return self.forward(input_embed, state) + def batch_prefill(self, input_embeds: Tensor, logit_positions: Tensor, state: RNNState): + """Prefilling the prompt.""" + return self.forward(input_embeds, state, logit_positions=logit_positions) + + def batch_decode(self, input_embeds: Tensor, state: RNNState): + """Decoding step.""" + return self.forward(input_embeds, state) + + def batch_verify(self, input_embeds: Tensor, state: RNNState): + """Verify step.""" + return self.forward(input_embeds, state) + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): """Softmax.""" - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Object: + def create_rnn_state( + self, + max_batch_size: tir.Var, + max_history: tir.Var, + ) -> Object: """Create RNN state.""" init_values = [ op.zeros((self.hidden_size,), dtype=self.dtype), # ATT_X @@ -386,7 +402,6 @@ def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Obj ) def get_default_spec(self): - batch_size = 1 mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), @@ -396,9 +411,7 @@ def get_default_spec(self): }, }, "prefill": { - "input_embed": nn.spec.Tensor( - [batch_size, "seq_len", self.hidden_size], self.dtype - ), + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -406,7 +419,32 @@ def get_default_spec(self): }, }, "decode": { - "input_embed": nn.spec.Tensor([batch_size, 1, self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + 
"batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -414,8 +452,8 @@ def get_default_spec(self): }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([batch_size, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/rwkv6/rwkv6_model.py b/python/mlc_llm/model/rwkv6/rwkv6_model.py index 0e1887310d..a8faf48a6b 100644 --- a/python/mlc_llm/model/rwkv6/rwkv6_model.py +++ b/python/mlc_llm/model/rwkv6/rwkv6_model.py @@ -40,6 +40,7 @@ class RWKV6Config(ConfigBase): # pylint: disable=too-many-instance-attributes context_window_size: int = -1 # RWKV does not have context window limitation. prefill_chunk_size: int = 4096 num_heads: int = 0 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -126,20 +127,17 @@ def wkv_func( def token_shift(state: Tensor, x: Tensor): - seq_len = x.shape[1] - def _te_token_shift(state: te.Tensor, x: te.Tensor): return te.compute( x.shape, lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]), ) - return state if seq_len == 1 else op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) + return op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) def last_token(x: Tensor): batch, seq_len, hidden_size = x.shape - assert batch == 1 def _te_last_token(x: te.Tensor): return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, x.shape[1] - 1, j]) @@ -390,10 +388,14 @@ def to(self, dtype: Optional[str] = None): def embed(self, input_ids: Tensor): return self.model.embeddings(input_ids) - def forward(self, input_embed: Tensor, state: RNNState): + def forward( + self, input_embed: Tensor, state: RNNState, logit_positions: Optional[Tensor] = None + ): """Forward pass.""" hidden_states, state = self.model(input_embed, state) hidden_states = last_token(hidden_states) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") @@ -407,11 +409,27 @@ def decode(self, input_embed: Tensor, state: RNNState): """Decoding step.""" return self.forward(input_embed, state) + def batch_prefill(self, input_embeds: Tensor, logit_positions: Tensor, state: RNNState): + """Prefilling the prompt.""" + return self.forward(input_embeds, state, logit_positions=logit_positions) + + def batch_decode(self, input_embeds: Tensor, state: RNNState): + """Decoding step.""" + return self.forward(input_embeds, state) + + def batch_verify(self, input_embeds: Tensor, state: RNNState): + """Verify step.""" + return self.forward(input_embeds, state) + def softmax_with_temperature(self, logits: Tensor, temperature: 
Tensor): """Softmax.""" - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Object: + def create_rnn_state( + self, + max_batch_size: tir.Var, + max_history: tir.Var, + ) -> Object: """Create RNN state.""" init_values = [ op.zeros((self.hidden_size,), dtype=self.dtype), # ATT_X @@ -426,7 +444,6 @@ def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Obj ) def get_default_spec(self): - batch_size = 1 mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), @@ -436,9 +453,7 @@ def get_default_spec(self): }, }, "prefill": { - "input_embed": nn.spec.Tensor( - [batch_size, "seq_len", self.hidden_size], self.dtype - ), + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -446,7 +461,32 @@ def get_default_spec(self): }, }, "decode": { - "input_embed": nn.spec.Tensor([batch_size, 1, self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -454,8 +494,8 @@ def get_default_spec(self): }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([batch_size, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py index 4a058c6e03..e4cbf1c047 100644 --- a/python/mlc_llm/nn/kv_cache.py +++ b/python/mlc_llm/nn/kv_cache.py @@ -887,7 +887,7 @@ def _attention_decode( THREAD_LIMIT = 512 TILE_SIZE_PER_BDX = 2 if target.kind.name == "opencl" and "android" in str(target.host): - THREAD_LIMIT = 64 + THREAD_LIMIT = 256 TILE_SIZE_PER_BDX = 1 max_num_threads_per_block = get_max_num_threads_per_block(target) thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) diff --git a/python/mlc_llm/op/__init__.py b/python/mlc_llm/op/__init__.py index 342568639d..850312a8a7 100644 --- a/python/mlc_llm/op/__init__.py +++ b/python/mlc_llm/op/__init__.py @@ -1,6 +1,9 @@ """Extern module for compiler.""" + from . 
import moe_matmul, moe_misc
 from .attention import attention
+from .batch_spec_verify import batch_spec_verify
 from .extern import configure, enable, get_store
 from .ft_gemm import faster_transformer_dequantize_gemm
 from .position_embedding import llama_rope
+from .top_p_pivot import top_p_pivot, top_p_renorm
diff --git a/python/mlc_llm/op/batch_spec_verify.py b/python/mlc_llm/op/batch_spec_verify.py
new file mode 100644
index 0000000000..d1a57fc71c
--- /dev/null
+++ b/python/mlc_llm/op/batch_spec_verify.py
@@ -0,0 +1,177 @@
+"""Operators for batch verification in speculative decoding."""
+
+from tvm.script import tir as T
+
+# mypy: disable-error-code="attr-defined,valid-type,name-defined"
+# pylint: disable=too-many-locals,invalid-name,too-many-arguments,
+# pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches
+
+
+def batch_spec_verify(vocab_size):
+    """Batch draft-verification function. This function verifies the token tree.
+
+    Before calling the function:
+
+    - token_tree_parent_ptr[b] should store the root of the tree.
+    - draft_probs[node_id, :] stores the probability used to sample the corresponding tree node.
+    - model_probs[node_id, :] stores the probability that should be used to sample its children.
+    - Note the storage-convention difference between model_probs and draft_probs:
+      draft_probs is stored on the token node, while model_probs is stored on the parent.
+      This is an intentional design, since we can sample different child tokens with different
+      proposal draft probabilities, but the ground-truth model_prob is unique per parent.
+
+    After calling the function:
+    - token_tree_parent_ptr[b] points to the last accepted token.
+    - A follow-up sampling step should sample from model_probs[token_tree_parent_ptr[b], :];
+      that token is appended to the generated tokens.
+
+    This function updates model_probs in place when a token is rejected and renormalization is needed.
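+
+    Acceptance rule (as implemented in the kernel below): a child token t is accepted when
+    model_probs[parent, t] >= u * draft_probs[child, t], where u is the uniform sample attached
+    to the child node. On rejection, model_probs[parent, :] is replaced by the renormalized
+    residual max(model_probs[parent, :] - draft_probs[child, :], 0), and verification moves on
+    to the next sibling.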
+ + Parameters + ---------- + draft_probs: + The draft probability attached to each tree node + + draft_tokens: + The draft token in each node + + model_probs: + The model proability attached to each parent + + token_tree_first_child: + The first child of each tree node, if there is no child, it should be -1 + + token_tree_next_sibling + The next sibling of each tree node, if there is no next sibling, it should be -1 + + uniform_samples + Per node uniform sample used to check rejection + + token_tree_parent_ptr: + Current parent ptr state + """ + TX = 1024 + + def _var(dtype="int32"): + return T.alloc_buffer((1,), dtype, scope="local") + + # fmt: off + @T.prim_func(private=True) + def _func( + var_draft_probs: T.handle, + var_draft_tokens: T.handle, + var_model_probs: T.handle, + var_token_tree_first_child: T.handle, + var_token_tree_next_sibling: T.handle, + var_uniform_samples: T.handle, + var_token_tree_parent_ptr: T.handle, + ): + """ + [ + blockIdx.x on batch, + threadIdx.x on vocab_size, + for loop over excessive amounts + ] + """ + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + num_nodes = T.int32(is_size_var=True) + nbatch = T.int32(is_size_var=True) + + draft_probs = T.match_buffer(var_draft_probs, (num_nodes, vocab_size), "float32") + draft_tokens = T.match_buffer(var_draft_tokens, (num_nodes,), "int32") + model_probs = T.match_buffer(var_model_probs, (num_nodes, vocab_size), "float32") + token_tree_first_child = T.match_buffer(var_token_tree_first_child, (num_nodes,), "int32") + token_tree_next_sibling = T.match_buffer(var_token_tree_next_sibling, (num_nodes,), "int32") + uniform_samples = T.match_buffer(var_uniform_samples, (num_nodes,), "float32") + token_tree_parent_ptr = T.match_buffer(var_token_tree_parent_ptr, (nbatch,), "int32") + + with T.block("kernel"): + child_ptr = _var() + parent_ptr = _var() + child_token = _var() + done = _var("bool") + psum = _var("float32") + t0 = _var("float32") + model_prob_local = _var("float32") + draft_prob_local = _var("float32") + p_child = _var("float32") + q_child = _var("float32") + uniform_sample = _var("float32") + + pred_shared = T.alloc_buffer((1,), "bool", scope="shared") + pred_local = T.alloc_buffer((1,), "bool", scope="local") + + for _bx in T.thread_binding(0, nbatch, thread="blockIdx.x"): + for _tx in T.thread_binding(0, TX, thread="threadIdx.x"): + with T.block("CTA"): + # batch size + b = T.axis.S(nbatch, _bx) + tx = T.axis.S(TX, _tx) + + parent_ptr[0] = token_tree_parent_ptr[b] + child_ptr[0] = token_tree_first_child[parent_ptr[0]] + done[0] = False + + while T.Not(done[0]): + T.tvm_storage_sync("shared") # ensure all effects last round are visible + if child_ptr[0] == -1: + done[0] = True + T.tvm_storage_sync("shared") # sync before exit + else: + # decide to validate current ptr + if tx == 0: + child_token[0] = draft_tokens[child_ptr[0]] + p_child[0] = model_probs[parent_ptr[0], child_token[0]] + q_child[0] = draft_probs[child_ptr[0], child_token[0]] + uniform_sample[0] = uniform_samples[child_ptr[0]] + pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0] # use multiplication to avoid division by zero + T.tvm_storage_sync("shared") # make sure all read of model_probs are done + pred_local[0] = pred_shared[0] + + # accept the proposal, we move to child + if pred_local[0]: + parent_ptr[0] = child_ptr[0] + child_ptr[0] = token_tree_first_child[child_ptr[0]] + else: + psum[0] = 0.0 + # renormalize probability, predicated by stopped_expansion[b]: + for i in T.serial(T.ceildiv(vocab_size, TX)): + k = 
T.meta_var(i * TX + tx) + if k < vocab_size: + model_prob_local[0] = model_probs[parent_ptr[0], k] + draft_prob_local[0] = draft_probs[child_ptr[0], k] + model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0) + psum[0] += model_prob_local[0] + + with T.block("block_cross_thread"): + T.reads(psum[0]) + T.writes(t0[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), psum[0], True, t0[0], tx, dtype="handle") + + if t0[0] < 1e-7: + # accept the proposal, we move to child + parent_ptr[0] = child_ptr[0] + child_ptr[0] = token_tree_first_child[child_ptr[0]] + else: + # renormalize + for i in T.serial(T.ceildiv(vocab_size, TX)): + k = T.meta_var(i * TX + tx) + if k < vocab_size: + model_prob_local[0] = model_probs[parent_ptr[0], k] + draft_prob_local[0] = draft_probs[child_ptr[0], k] + model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0) + model_probs[parent_ptr[0], k] = model_prob_local[0] / t0[0] + + child_ptr[0] = token_tree_next_sibling[child_ptr[0]] + + if tx == 0: + token_tree_parent_ptr[b] = parent_ptr[0] + # fmt: on + + return _func diff --git a/python/mlc_llm/op/moe_matmul.py b/python/mlc_llm/op/moe_matmul.py index 6978d8ba0e..b4ebb5b630 100644 --- a/python/mlc_llm/op/moe_matmul.py +++ b/python/mlc_llm/op/moe_matmul.py @@ -2,7 +2,7 @@ from typing import Literal, Optional -from tvm import DataType, tir +from tvm import DataType, DataTypeCode, tir from tvm.relax.frontend.nn import Tensor, op from tvm.script import tir as T @@ -335,6 +335,7 @@ def _dequantize(w, s, e, i, j): if num_elem_per_storage == 1: w = tir.reinterpret(quantize_dtype, w[e, i, j]) else: + assert DataType(storage_dtype).type_code == DataTypeCode.UINT tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype) w = w[e, i, j // num_elem_per_storage] shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) diff --git a/python/mlc_llm/op/top_p_pivot.py b/python/mlc_llm/op/top_p_pivot.py new file mode 100644 index 0000000000..9c97959bff --- /dev/null +++ b/python/mlc_llm/op/top_p_pivot.py @@ -0,0 +1,315 @@ +"""Operators for choosing the pivot to cut-off top-p percentile """ + +import tvm +from tvm.script import tir as T + +# mypy: disable-error-code="attr-defined,valid-type,name-defined" +# pylint: disable=too-many-locals,invalid-name,too-many-arguments,unnecessary-lambda +# pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches + + +def top_p_pivot(pN): + """Top-p pivot function. This function finds the pivot to cut-off top-p percentile. 
+ + A valide pivot should satisfy the following conditions: + - lsum >= top_p + - top_p > lsum - cmin * lmin + where lsum is the sum of elements that are larger or equal to the pivot, + lmin is the minimum elements that is larger or equal to the pivot, + cmin is the count of elements that are equal to lmin, + + Parameters + ---------- + prob: + The probability vector + + top_p_global: + The top-p threshold + + init_pivots: + The initial pivot candidates + + final_pivot: + The final pivot to cut-off top-p percentile + """ + TX = 1024 + K = 32 + eps_LR = 1e-7 + + def _var(dtype="int32"): + return T.alloc_buffer((1,), dtype, scope="local") + + def valid(lsum, lmin, cmin, top_p): + return tvm.tir.all(lsum >= top_p, top_p > lsum - cmin * lmin) + + # fmt: off + @T.prim_func(private=True) + def _func( + var_prob: T.handle, + top_p_global: T.buffer([1], dtype="float32"), + var_init_pivots: T.handle, + var_final_pivot: T.handle, + var_final_lsum: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + B = T.int32() + N = T.int32() + prob = T.match_buffer(var_prob, (B, N,), "float32") + init_pivots = T.match_buffer(var_init_pivots, (pN,), "float32") + final_pivot = T.match_buffer(var_final_pivot, (B,), "float32") + final_lsum = T.match_buffer(var_final_lsum, (B,), "float32") + + with T.block("kernel"): + pivot = T.alloc_buffer((pN,), "float32", scope="local") + top_p = _var("float32") + + L = T.alloc_buffer((1,), "float32", scope="shared") + R = T.alloc_buffer((1,), "float32", scope="shared") + L_local = _var("float32") + R_local = _var("float32") + + q = _var("float32") + lsum = T.alloc_buffer((pN,), "float32", scope="local") + lmin_broadcast = T.alloc_buffer((1), "float32", scope="shared") + lmin_broadcast_local = _var("float32") + lmin = T.alloc_buffer((pN,), "float32", scope="local") + cmin = T.alloc_buffer((pN,), "int32", scope="local") + total_sum = _var("float32") + + it = _var("int32") + es_local = _var("bool") + es = T.alloc_buffer((1,), "bool", scope="shared") + find_pivot_local = _var("bool") + find_pivot = T.alloc_buffer((1,), "bool", scope="shared") + + total_sum_reduce = _var("float32") + lsum_reduce = _var("float32") + lmin_reduce = _var("float32") + cmin_reduce = _var("int32") + + for _bx in T.thread_binding(0, B, thread="blockIdx.x"): + for _tx in T.thread_binding(0, TX, thread="threadIdx.x"): + with T.block("CTA"): + b, tx = T.axis.remap("SS", [_bx, _tx]) + + top_p[0] = top_p_global[0] + + if tx == 0: + # leader thread initializes L, R + L[0] = 1.0 - top_p[0] + R[0] = eps_LR + find_pivot[0] = False + T.tvm_storage_sync("shared") + + L_local[0] = L[0] + R_local[0] = R[0] + for i in T.unroll(0, pN): + # pivots are in descending order + pivot[i] = init_pivots[i] + find_pivot_local[0] = False + + while T.tvm_thread_invariant( + L_local[0] - R_local[0] > eps_LR + and T.Not(find_pivot_local[0]) + ): + # sync before each iteration + T.tvm_storage_sync("shared") + + ### get lsum, lmin, total_sum + for pidx in T.unroll(0, pN): + lsum[pidx] = 0.0 + lmin[pidx] = 1.0 + cmin[pidx] = 0 + total_sum[0] = 0.0 + it[0] = 0 + es_local[0] = False + while it[0] < T.ceildiv(N, TX) and T.Not(es_local[0]): + idx = T.meta_var(it[0] * TX + tx) + q[0] = T.if_then_else(idx < N, prob[b, idx], 0.0) + total_sum[0] += q[0] + for pidx in T.unroll(0, pN): + if q[0] >= pivot[pidx]: + lsum[pidx] += q[0] + if lmin[pidx] > q[0]: + lmin[pidx] = q[0] + cmin[pidx] = 1 + elif lmin[pidx] == q[0]: + cmin[pidx] += 1 + it[0] += 1 + + # early stop every K iterations + if it[0] % K == 0: + # reduce total_sum 
over tx + # T.tvm_storage_sync("shared") + with T.block("block_cross_thread"): + T.reads(total_sum[0]) + T.writes(total_sum_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), total_sum[0], True, total_sum_reduce[0], tx, dtype="handle") + # T.tvm_storage_sync("shared") + + if tx == 0: + # leader thread checks if we can stop early + es[0] = 1 - total_sum_reduce[0] < pivot[pN - 1] + T.tvm_storage_sync("shared") + es_local[0] = es[0] + + T.tvm_storage_sync("shared") + + # reduce lsum, lmin, cmin, over tx + for pidx in T.serial(0, pN): + # reduce lsum over tx for pivot[j] + with T.block("block_cross_thread"): + T.reads(lsum[pidx]) + T.writes(lsum_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], True, lsum_reduce[0], tx, dtype="handle") + + # reduce lmin over tx for pivot[j] + with T.block("block_cross_thread"): + T.reads(lmin[pidx]) + T.writes(lmin_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], True, lmin_reduce[0], tx, dtype="handle") + + if tx == 0: + # broadcast lmin to all threads + lmin_broadcast[0] = lmin_reduce[0] + T.tvm_storage_sync("shared") + lmin_broadcast_local[0] = lmin_broadcast[0] + if lmin[pidx] > lmin_broadcast_local[0]: + cmin[pidx] = 0 + if tx == 0: + # only the leader thread updates lsum, lmin + lsum[pidx] = lsum_reduce[0] + lmin[pidx] = lmin_reduce[0] + + # reduce cmin over tx for pivot[j] + with T.block("block_cross_thread"): + T.reads(cmin[pidx]) + T.writes(cmin_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.int32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], True, cmin_reduce[0], tx, dtype="handle") + + if tx == 0: + # only the leader thread updates cmin + cmin[pidx] = cmin_reduce[0] + + T.tvm_storage_sync("shared") + + if tx == 0: + # leader thread checks if we have found the pivot, or updates L, R + it[0] = 0 + while it[0] < pN and T.Not(find_pivot_local[0]): + pidx = T.meta_var(it[0]) + if valid(lsum[pidx], lmin[pidx], cmin[pidx], top_p[0]): + find_pivot[0] = True + find_pivot_local[0] = True + # write back the pivot and lsum + final_pivot[b] = pivot[pidx] + final_lsum[b] = lsum[pidx] + elif lsum[pidx] - lmin[pidx] * cmin[pidx] >= top_p[0]: + R[0] = pivot[pidx] + elif lsum[pidx] < top_p[0]: + L[0] = pivot[pidx] + it[0] += 1 + + T.tvm_storage_sync("shared") + + L_local[0] = L[0] + R_local[0] = R[0] + find_pivot_local[0] = find_pivot[0] + # new pivots for next iteration + # uniform spacing between L and R + for pidx in T.unroll(0, pN): + pivot[pidx] = L[0] - (pidx + 1) * (L_local[0] - R_local[0]) / (pN + 1) + + if tx == 0: + # leader thread writes back the pivot + if T.Not(find_pivot_local[0]): + final_pivot[b] = -1e5 + # fmt: on + + return _func + + +def top_p_renorm(): + """Top-p renormalization function. This function renormalizes the probability vector. 
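# NOTE: illustrative sketch, not part of this changeset. A plain NumPy restatement of the
# pivot-validity condition used by `top_p_pivot` above and of the renormalization that
# `top_p_renorm` performs once the pivot and its kept mass (lsum) are known. Function and
# variable names here are made up and the sketch is for intuition only.
import numpy as np


def pivot_is_valid(prob, pivot, top_p):
    kept = prob[prob >= pivot]
    if kept.size == 0:
        return False
    lsum = kept.sum()                      # probability mass kept by this pivot
    lmin = kept.min()                      # smallest kept probability
    cmin = int((kept == lmin).sum())       # how many kept entries equal lmin
    # keep enough mass to cover top_p, but dropping the smallest kept entries would not
    return lsum >= top_p and top_p > lsum - cmin * lmin


def renorm_with_pivot(prob, pivot):
    lsum = prob[prob >= pivot].sum()
    return np.where(prob >= pivot, prob / lsum, 0.0)


prob = np.array([0.5, 0.25, 0.15, 0.07, 0.03])
assert pivot_is_valid(prob, pivot=0.15, top_p=0.85)      # kept mass ~0.9 covers 0.85; 0.75 would not
assert not pivot_is_valid(prob, pivot=0.25, top_p=0.85)  # keeps too little mass (0.75 < 0.85)
assert not pivot_is_valid(prob, pivot=0.07, top_p=0.85)  # keeps more mass than necessary
print(renorm_with_pivot(prob, 0.15))                     # ~[0.556, 0.278, 0.167, 0.0, 0.0]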
+ + Given the pivot, the probability vector is renormalized as follows: + - if prob >= pivot, renorm_prob = prob / lsum + - otherwise, renorm_prob = 0 + + Parameters + ---------- + prob: + The probability vector + + final_pivot: + The final pivot to cut-off top-p percentile + + final_lsum: + The sum of elements that are larger or equal to the pivot + + renorm_prob: + The renormalized probability vector + """ + TX = 1024 + CTA_COUNT = 512 + + def _var(dtype="int32"): + return T.alloc_buffer((1,), dtype, scope="local") + + # fmt: off + @T.prim_func(private=True) + def _func( + var_prob: T.handle, + var_final_pivot: T.handle, + var_final_lsum: T.handle, + var_renorm_prob: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + B = T.int32() + N = T.int32() + prob = T.match_buffer(var_prob, (B, N,), "float32") + final_pivot = T.match_buffer(var_final_pivot, (B,), "float32") + final_lsum = T.match_buffer(var_final_lsum, (B,), "float32") + renorm_prob = T.match_buffer(var_renorm_prob, (B, N,), "float32") + + with T.block("kernel"): + pivot = _var("float32") + lsum = _var("float32") + BX = T.meta_var(T.ceildiv(CTA_COUNT, B)) + + for _by in T.thread_binding(0, B, thread="blockIdx.y"): + for _bx in T.thread_binding(0, BX, thread="blockIdx.x"): + for _tx in T.thread_binding(0, TX, thread="threadIdx.x"): + with T.block("CTA"): + by, bx, tx = T.axis.remap("SSS", [_by, _bx, _tx]) + + pivot[0] = final_pivot[by] + lsum[0] = final_lsum[by] + + for i in T.serial(T.ceildiv(N, BX * TX)): + idx = T.meta_var(i * BX * TX + bx * TX + tx) + if idx < N: + renorm_prob[by, idx] = T.if_then_else(prob[by, idx] >= pivot[0], prob[by, idx] / lsum[0], 0.0) + # fmt: on + + return _func diff --git a/python/mlc_llm/protocol/protocol_utils.py b/python/mlc_llm/protocol/protocol_utils.py index f4273d0302..3005909bbd 100644 --- a/python/mlc_llm/protocol/protocol_utils.py +++ b/python/mlc_llm/protocol/protocol_utils.py @@ -23,13 +23,14 @@ def get_unsupported_fields(request: RequestProtocol) -> List[str]: def get_generation_config( request: RequestProtocol, + model_config: Dict[str, Any], extra_stop_token_ids: Optional[List[int]] = None, extra_stop_str: Optional[List[str]] = None, ) -> GenerationConfig: """Create the generation config in MLC LLM out from the input request protocol.""" kwargs: Dict[str, Any] if isinstance(request, (OpenAICompletionRequest, OpenAIChatCompletionRequest)): - kwargs = openai_api_get_generation_config(request) + kwargs = openai_api_get_generation_config(request, model_config) else: raise RuntimeError("Cannot reach here") diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index 8e06de7b54..59358c1646 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -2,11 +2,10 @@ # Load MLC LLM library by importing base from .. 
import base -from .async_engine import AsyncThreadedEngine -from .config import EngineMode, GenerationConfig, KVCacheConfig +from .config import EngineConfig, GenerationConfig, SpeculativeMode from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData -from .engine import Engine +from .engine import AsyncMLCEngine, MLCEngine from .grammar import BNFGrammar, GrammarStateMatcher -from .json_schema_converter import json_schema_to_ebnf +from .radix_tree import PagedRadixTree from .request import Request from .server import PopenServer diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index e539ec7e56..6b808ac37b 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -1,9 +1,14 @@ """Configuration dataclasses used in MLC LLM serving""" +import enum import json from dataclasses import asdict, dataclass, field from typing import Dict, List, Literal, Optional +import tvm + +from . import _ffi_api + @dataclass class ResponseFormat: @@ -123,66 +128,101 @@ def from_json(json_str: str) -> "GenerationConfig": return GenerationConfig(**json.loads(json_str)) -@dataclass -class KVCacheConfig: - """The KV cache initialization configuration. +class KVStateKind(enum.IntEnum): # pylint: disable=too-few-public-methods + """Possible kinds of KV state.""" + + ATTENTION = 0 + RNNSTATE = 1 + + +class SpeculativeMode(enum.IntEnum): + """The speculative mode.""" + + # Disable speculative decoding. + DISABLE = 0 + # The normal speculative decoding (small draft) mode. + SMALL_DRAFT = 1 + # The eagle-style speculative decoding. + EAGLE = 2 + + +@tvm._ffi.register_object("mlc.serve.EngineConfig") # pylint: disable=protected-access +class EngineConfig(tvm.runtime.Object): + """The class of MLCEngine execution configuration. Parameters ---------- - page_size : int - The number of consecutive tokens handled in each page in paged KV cache. + model : str + The path to the model directory. - max_num_sequence : int - The maximum number of sequences that are allowed to processed by the KV - cache at any time. + model_lib_path : str + The path to the model library. - max_total_sequence_length : Optional[int] - The maximum total number of tokens whose KV data are allowed to exist - in the KV cache at any time. - Set it to None to enable automatic computation of the max total - sequence length. + additional_models : List[str] + The path to the additional models' directories. - prefill_chunk_size : Optional[int] - The maximum total sequence length in a prefill. - If not specified, it will be automatically inferred from model config. - """ + additional_model_lib_paths : List[str] + The path to the additional models' libraries. - page_size: int = 16 - max_num_sequence: int = 32 - max_total_sequence_length: Optional[int] = None - prefill_chunk_size: Optional[int] = None + kv_cache_page_size : int + The number of consecutive tokens handled in each page in paged KV cache. - def asjson(self) -> str: - """Return the config in string of JSON format.""" - return json.dumps(asdict(self)) + max_num_sequence : int + The maximum number of sequences that are allowed to be + processed by the KV cache at any time. - @staticmethod - def from_json(json_str: str) -> "KVCacheConfig": - """Construct a config from JSON string.""" - return KVCacheConfig(**json.loads(json_str)) + max_total_sequence_length : int + The maximum length allowed for a single sequence in the engine. 
+ max_single_sequence_length : int + The maximum total number of tokens whose KV data are allowed + to exist in the KV cache at any time. -@dataclass -class EngineMode: - """The Engine execution mode. + prefill_chunk_size : int + The maximum total sequence length in a prefill. - Parameters - ---------- - enable_speculative : bool - Whether the speculative decoding mode is enabled, default False. + max_history_size: int + The maximum history size for RNN state to rool back. - spec_draft_length : int - The number of tokens to generate in speculative proposal (draft), default 4. - """ + kv_state_kind: KVStateKind + The kind of cache. - enable_speculative: bool = False - spec_draft_length: int = 4 + speculative_mode : SpeculativeMode + The speculative mode. - def asjson(self) -> str: - """Return the config in string of JSON format.""" - return json.dumps(asdict(self)) + spec_draft_length : int + The number of tokens to generate in speculative proposal (draft). + """ - @staticmethod - def from_json(json_str: str) -> "EngineMode": - """Construct a config from JSON string.""" - return EngineMode(**json.loads(json_str)) + def __init__( # pylint: disable=too-many-arguments + self, + model: str, + model_lib_path: str, + additional_models: List[str], + additional_model_lib_paths: List[str], + kv_cache_page_size: int, + max_num_sequence: int, + max_total_sequence_length: int, + max_single_sequence_length: int, + prefill_chunk_size: int, + max_history_size: int, + kv_state_kind: KVStateKind, + speculative_mode: SpeculativeMode, + spec_draft_length: int, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.EngineConfig, # type: ignore # pylint: disable=no-member + model, + model_lib_path, + additional_models, + additional_model_lib_paths, + kv_cache_page_size, + max_num_sequence, + max_total_sequence_length, + max_single_sequence_length, + prefill_chunk_size, + max_history_size, + kv_state_kind, + speculative_mode, + spec_draft_length, + ) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 607f970a1e..413c856db1 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -1,306 +1,838 @@ """The MLC LLM Serving Engine.""" -import json -import os -import subprocess +# pylint: disable=too-many-lines + +import asyncio +import queue import sys -from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import weakref +from typing import ( + Any, + AsyncGenerator, + Dict, + Iterator, + List, + Literal, + Optional, + Union, + overload, +) -import tvm from tvm.runtime import Device -from mlc_llm.protocol.conversation_protocol import Conversation -from mlc_llm.serve import data +from mlc_llm.protocol import openai_api_protocol +from mlc_llm.serve import data, engine_utils +from mlc_llm.serve.config import GenerationConfig, SpeculativeMode +from mlc_llm.serve.request import Request +from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging -from mlc_llm.support.auto_device import detect_device -from mlc_llm.support.style import green -from ..chat_module import _get_chat_config, _get_lib_module_path, _get_model_path -from ..streamer import TextStreamer -from ..tokenizer import Tokenizer -from . import data -from .config import EngineMode, GenerationConfig, KVCacheConfig -from .event_trace_recorder import EventTraceRecorder -from .request import Request +from . 
import engine_base logging.enable_logging() logger = logging.getLogger(__name__) -@dataclass -class ModelInfo: - """The model info dataclass. +class Chat: # pylint: disable=too-few-public-methods + """The proxy class to direct to chat completions.""" - Parameters - ---------- - model : str - The identifier of the input model. - It may be a compiled model's id (e.g., "Llama-2-7b-chat-hf-q4f16_1"), - or a full path to a model directory - (e.g., "dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1") - - device : str - The device where to run the model. - It can be "auto", "device_name" (e.g., "cuda") or - "device_name:device_id" (e.g., "cuda:1"). - - model_lib_path : str - The path to the compiled library of the model. - E.g., "dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - """ + def __init__(self, engine: weakref.ReferenceType) -> None: + assert isinstance(engine(), (AsyncMLCEngine, MLCEngine)) + self.completions = ( + AsyncChatCompletion(engine) # type: ignore + if isinstance(engine(), AsyncMLCEngine) + else ChatCompletion(engine) # type: ignore + ) - model: str - model_lib_path: str - device: Device = "auto" # type: ignore - - def __post_init__(self): - if isinstance(self.device, str): - self.device = detect_device(self.device) - assert isinstance(self.device, Device) - - -def _create_tvm_module( - creator: str, ffi_funcs: Sequence[str], creator_args: Optional[List[Any]] = None -) -> Dict[str, Callable]: - """Internal method to create a module.""" - if creator_args is None: - creator_args = [] - module = tvm.get_global_func(creator, allow_missing=False)(*creator_args) - return {key: module[key] for key in ffi_funcs} - - -def _process_model_args( - models: List[ModelInfo], -) -> Tuple[List[Any], List[str], str, int, int, Optional[str]]: - """Process the input ModelInfo to get the engine initialization arguments.""" - max_single_sequence_length = int(1e9) - prefill_chunk_size = int(1e9) - tokenizer_path: Optional[str] = None - conv_template_name: Optional[str] = None - config_file_paths: List[str] = [] - - def _convert_model_info(model: ModelInfo) -> List[Any]: - nonlocal max_single_sequence_length, prefill_chunk_size, tokenizer_path, conv_template_name - - device = model.device - model_path, config_file_path = _get_model_path(model.model) - config_file_paths.append(config_file_path) - chat_config = _get_chat_config(config_file_path, user_chat_config=None) - if chat_config.context_window_size and chat_config.context_window_size != -1: - max_single_sequence_length = min( - max_single_sequence_length, - chat_config.context_window_size, - ) - if chat_config.prefill_chunk_size: - prefill_chunk_size = min(prefill_chunk_size, chat_config.prefill_chunk_size) - if tokenizer_path is None: - tokenizer_path = model_path - if conv_template_name is None: - assert isinstance(chat_config.conv_template, Conversation) - conv_template_name = chat_config.conv_template.name - # Try look up model library, and do JIT compile if model library not found. 
- try: - model_lib_path = _get_lib_module_path( - model=model.model, - model_path=model_path, - chat_config=chat_config, - model_lib_path=model.model_lib_path, - device_name=device.MASK2STR[device.device_type], - config_file_path=config_file_path, - ) - except FileNotFoundError: - from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel - - model_lib_path = str( - jit.jit( - model_path=Path(model_path), - chat_config=asdict(chat_config), - device=device, - ) - ) - return [model_lib_path, model_path, device.device_type, device.device_id] - - model_args: List[Any] = sum( - (_convert_model_info(model) for model in models), - start=[], - ) - - assert prefill_chunk_size != int(1e9) - return ( - model_args, - config_file_paths, - tokenizer_path, - max_single_sequence_length, - prefill_chunk_size, - conv_template_name, - ) - - -def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals - models: List[ModelInfo], config_file_paths: List[str], max_num_sequence: int -) -> int: - """Estimate the max total sequence length (capacity) of the KV cache.""" - assert len(models) != 0 - - kv_bytes_per_token = 0 - kv_aux_workspace_bytes = 0 - model_workspace_bytes = 0 - logit_processor_workspace_bytes = 0 - params_bytes = 0 - temp_func_bytes = 0 - - for model, config_file_path in zip(models, config_file_paths): - # Read metadata for the parameter size and the temporary memory size. - cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-memory-usage-in-json", - "--mlc-chat-config", - config_file_path, - ] - usage_str = subprocess.check_output(cmd, universal_newlines=True) - usage_json = json.loads(usage_str) - params_bytes += usage_json["params_bytes"] - temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) - - cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-kv-cache-metadata-in-json", - ] - kv_cache_metadata_str = subprocess.check_output(cmd, universal_newlines=True) - kv_cache_metadata = json.loads(kv_cache_metadata_str) - - # Read model config and compute the kv size per token. 
- with open(config_file_path, mode="rt", encoding="utf-8") as file: - json_object = json.load(file) - model_config = json_object["model_config"] - vocab_size = model_config["vocab_size"] - prefill_chunk_size = model_config["prefill_chunk_size"] - num_layers = kv_cache_metadata["num_hidden_layers"] - head_dim = kv_cache_metadata["head_dim"] - num_qo_heads = kv_cache_metadata["num_attention_heads"] - num_kv_heads = kv_cache_metadata["num_key_value_heads"] - hidden_size = head_dim * num_qo_heads - kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25 - kv_aux_workspace_bytes += ( - (max_num_sequence + 1) * 88 - + prefill_chunk_size * (num_qo_heads + 1) * 8 - + prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 - + 48 * 1024 * 1024 + +class AsyncChatCompletion: # pylint: disable=too-few-public-methods + """The proxy class to direct to async chat completions.""" + + if sys.version_info >= (3, 9): + engine: weakref.ReferenceType["AsyncMLCEngine"] + else: + engine: weakref.ReferenceType + + def __init__(self, engine: weakref.ReferenceType) -> None: + self.engine = engine + + @overload + async def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + stream: Literal[True], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any]: + """Asynchronous streaming chat completion interface with OpenAI API compatibility. + The method is a coroutine that streams ChatCompletionStreamResponse + that conforms to OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Yields + ------ + stream_response : ChatCompletionStreamResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/streaming for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
+ """ + + @overload + async def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Literal[False] = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> openai_api_protocol.ChatCompletionResponse: + """Asynchronous non-streaming chat completion interface with OpenAI API compatibility. + The method is a coroutine that streams ChatCompletionStreamResponse + that conforms to OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Returns + ------ + response : ChatCompletionResponse + The chat completion response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + async def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any], + openai_api_protocol.ChatCompletionResponse, + ]: + """Asynchronous chat completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
+ """ + return await self.engine()._chat_completion( # pylint: disable=protected-access + messages=messages, + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=tools, + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=response_format, + request_id=request_id, ) - model_workspace_bytes += ( - prefill_chunk_size * 4 - + max_num_sequence * 4 - + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 + + +class ChatCompletion: # pylint: disable=too-few-public-methods + """The proxy class to direct to chat completions.""" + + if sys.version_info >= (3, 9): + engine: weakref.ReferenceType["MLCEngine"] + else: + engine: weakref.ReferenceType + + def __init__(self, engine: weakref.ReferenceType) -> None: + self.engine = engine + + @overload + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + stream: Literal[True], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + """Synchronous streaming chat completion interface with OpenAI API compatibility. + The method streams back ChatCompletionStreamResponse that conforms to + OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Yields + ------ + stream_response : ChatCompletionStreamResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/streaming for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
+ """ + + @overload + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Literal[False] = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> openai_api_protocol.ChatCompletionResponse: + """Synchronous non-streaming chat completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Returns + ------ + response : ChatCompletionResponse + The chat completion response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + Iterator[openai_api_protocol.ChatCompletionStreamResponse], + openai_api_protocol.ChatCompletionResponse, + ]: + """Synchronous chat completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
+ """ + return self.engine()._chat_completion( # pylint: disable=protected-access + messages=messages, + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=tools, + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=response_format, + request_id=request_id, ) - logit_processor_workspace_bytes += ( - max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 + + +class AsyncCompletion: # pylint: disable=too-few-public-methods + """The proxy class to direct to async completions.""" + + if sys.version_info >= (3, 9): + engine: weakref.ReferenceType["AsyncMLCEngine"] + else: + engine: weakref.ReferenceType + + def __init__(self, engine: weakref.ReferenceType) -> None: + self.engine = engine + + @overload + async def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + stream: Literal[True], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: + """Asynchronous streaming completion interface with OpenAI API compatibility. + The method is a coroutine that streams CompletionResponse + that conforms to OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + @overload + async def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Literal[False] = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> openai_api_protocol.CompletionResponse: + """Asynchronous non-streaming completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. 
+ + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Returns + ------ + response : CompletionResponse + The completion response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + async def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + AsyncGenerator[openai_api_protocol.CompletionResponse, Any], + openai_api_protocol.CompletionResponse, + ]: + """Asynchronous completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + return await self.engine()._completion( # pylint: disable=protected-access + model=model, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + suffix=suffix, + temperature=temperature, + top_p=top_p, + user=user, + ignore_eos=ignore_eos, + response_format=response_format, + request_id=request_id, ) - # Get single-card GPU size. - gpu_size_bytes = os.environ.get("MLC_GPU_SIZE_BYTES", default=None) - if gpu_size_bytes is None: - gpu_size_bytes = models[0].device.total_global_memory - if gpu_size_bytes is None: - raise ValueError( - "Cannot read total GPU global memory from device. " - 'Please the GPU memory size in bytes through "MLC_GPU_SIZE_BYTES" env variable.' 
- ) - max_total_sequence_length = int( - ( - int(gpu_size_bytes) * 0.90 - - params_bytes - - temp_func_bytes - - kv_aux_workspace_bytes - - model_workspace_bytes - - logit_processor_workspace_bytes +class Completion: # pylint: disable=too-few-public-methods + """The proxy class to direct to completions.""" + + if sys.version_info >= (3, 9): + engine: weakref.ReferenceType["MLCEngine"] + else: + engine: weakref.ReferenceType + + def __init__(self, engine: weakref.ReferenceType) -> None: + self.engine = engine + + @overload + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + stream: Literal[True], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> openai_api_protocol.CompletionResponse: + """Synchronous streaming completion interface with OpenAI API compatibility. + The method streams back CompletionResponse that conforms to + OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + @overload + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Literal[False] = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.CompletionResponse]: + """Synchronous non-streaming completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Returns + ------ + response : CompletionResponse + The completion response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
+ """ + + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.CompletionResponse]: + """Synchronous completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + return self.engine()._completion( # pylint: disable=protected-access + model=model, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + suffix=suffix, + temperature=temperature, + top_p=top_p, + user=user, + ignore_eos=ignore_eos, + response_format=response_format, + request_id=request_id, ) - / kv_bytes_per_token - ) - assert max_total_sequence_length > 0, ( - "Cannot estimate KV cache capacity. " - f"The model weight size {params_bytes} may be larger than GPU memory size {gpu_size_bytes}" - ) - - if models[0].device.device_type == Device.kDLMetal: - # NOTE: Metal runtime has severe performance issues with large buffers. - # To work around the issue, we limit the KV cache capacity to 32768. - max_total_sequence_length = min(max_total_sequence_length, 32768) - - total_size = ( - params_bytes - + temp_func_bytes - + kv_aux_workspace_bytes - + model_workspace_bytes - + logit_processor_workspace_bytes - + kv_bytes_per_token * max_total_sequence_length - ) - logger.info( - "%s: %d.", - green('Estimated KVCacheConfig "max_total_sequence_length"'), - max_total_sequence_length, - ) - logger.info( - "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. Temporary buffer: %.2f MB)", - green("Estimated total single GPU memory usage"), - total_size / 1024 / 1024, - params_bytes / 1024 / 1024, - (kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes) / 1024 / 1024, - (model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes) / 1024 / 1024, - ) - return int(max_total_sequence_length) - - -class Engine: - """The Python interface of request serving engine for MLC LLM. - - The engine can run one or multiple LLM models internally for - text generation. Usually, when there are multiple models, - speculative inference will be activated, where the first model - (index 0) is the main "large model" that has better generation - quality, and all other models are "small" models that used for - speculation. - - The engine receives requests from the "add_request" method. For - an given request, the engine will keep generating new tokens for - the request until finish (under certain criterion). 
After finish, - the engine will return the generation result through the callback - function provided by the request. + + +class AsyncMLCEngine(engine_base.MLCEngineBase): + """The AsyncMLCEngine in MLC LLM that provides the asynchronous + interfaces with regard to OpenAI API. Parameters ---------- - models : Union[ModelInfo, List[ModelInfo]] - One or a list of model info (specifying which models to load and - which device to load to) to launch the engine. - - kv_cache_config : KVCacheConfig - The configuration of the paged KV cache. - - request_stream_callback : Optional[Callable[[str, data.TokenData, Optional[str]], None]] - The provided callback function to handle the generation - output. It has the signature of `(str, data.TokenData, bool) -> None`, - where - - the first string is the request id, - - the TokenData contains the generated **delta** token ids since - the last invocation of the callback on the specific request, - - the optional string value denotes the finish reason if the - generation of the request is finished, or None if it has not finished. - - The callback function is optional at construction, but it needs to - be set before the engine executing requests. This can be done via - the `set_request_stream_callback` method. Otherwise, the engine will raise - exception. - - engine_mode : Optional[EngineMode] - The Engine execution mode. + models : str + A path to ``mlc-chat-config.json``, or an MLC model directory that contains + `mlc-chat-config.json`. + It can also be a link to a HF repository pointing to an MLC compiled model. + + device: Union[str, Device] + The device used to deploy the model such as "cuda" or "cuda:0". + Will default to "auto" and detect from local available GPUs if not specified. + + model_lib_path : Optional[str] + The full path to the model library file to use (e.g. a ``.so`` file). + If unspecified, we will use the provided ``model`` to search over possible paths. + It the model lib path is not found, it will be compiled in a JIT manner. + + mode : Literal["local", "interactive", "server"] + The engine mode in MLC LLM. + We provide three preset modes: "local", "interactive" and "server". + The default mode is "local". + The choice of mode decides the values of "max_batch_size", "max_total_sequence_length" + and "prefill_chunk_size" when they are not explicitly specified. + 1. Mode "local" refers to the local server deployment which has low + request concurrency. So the max batch size will be set to 4, and max + total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 2. Mode "interactive" refers to the interactive use of server, which + has at most 1 concurrent request. So the max batch size will be set to 1, + and max total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 3. Mode "server" refers to the large server use case which may handle + many concurrent request and want to use GPU memory as much as possible. + In this mode, we will automatically infer the largest possible max batch + size and max total sequence length. + + You can manually specify arguments "max_batch_size", "max_total_sequence_length" and + "prefill_chunk_size" to override the automatic inferred values. + + additional_models : Optional[List[str]] + The model paths and (optional) model library paths of additional models + (other than the main model). + When engine is enabled with speculative decoding, additional models are needed. 
+ Each string in the list is either in form "model_path" or "model_path:model_lib_path". + When the model lib path of a model is not given, JIT model compilation will + be activated to compile the model automatically. + + max_batch_size : Optional[int] + The maximum allowed batch size set for the KV cache to concurrently support. + + max_total_sequence_length : Optional[int] + The KV cache total token capacity, i.e., the maximum total number of tokens that + the KV cache support. This decides the GPU memory size that the KV cache consumes. + If not specified, system will automatically estimate the maximum capacity based + on the vRAM size on GPU. + + prefill_chunk_size : Optional[int] + The maximum number of tokens the model passes for prefill each time. + It should not exceed the prefill chunk size in model config. + If not specified, this defaults to the prefill chunk size in model config. + + max_history_size : Optional[int] + The maximum history for RNN state. + + gpu_memory_utilization : Optional[float] + A number in (0, 1) denoting the fraction of GPU memory used by the server in total. + It is used to infer to maximum possible KV cache capacity. + When it is unspecified, it defaults to 0.85. + Under mode "local" or "interactive", the actual memory usage may be + significantly smaller than this number. Under mode "server", the actual + memory usage may be slightly larger than this number. + + engine_config : Optional[EngineConfig] + The MLCEngine execution configuration. + Currently speculative decoding mode is specified via engine config. + For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" + to specify the eagle-style speculative decoding. + Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. enable_tracing : bool A boolean indicating if to enable event logging for requests. @@ -308,245 +840,1021 @@ class Engine: def __init__( # pylint: disable=too-many-arguments self, - models: Union[ModelInfo, List[ModelInfo]], - kv_cache_config: KVCacheConfig, - engine_mode: Optional[EngineMode] = None, - request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None, + model: str, + device: Union[str, Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, + gpu_memory_utilization: Optional[float] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, enable_tracing: bool = False, - ): - if isinstance(models, ModelInfo): - models = [models] + ) -> None: + super().__init__( + "async", + model=model, + device=device, + model_lib_path=model_lib_path, + mode=mode, + additional_models=additional_models, + max_batch_size=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + gpu_memory_utilization=gpu_memory_utilization, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + enable_tracing=enable_tracing, + ) + self.chat = Chat(weakref.ref(self)) + self.completions = AsyncCompletion(weakref.ref(self)) + + async def abort(self, request_id: str) -> None: + """Generation abortion interface. 
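# NOTE: hypothetical usage sketch, not part of this changeset. It exercises the
# AsyncMLCEngine constructor documented above together with the streaming chat API exposed
# through `engine.chat.completions`. The model id is a placeholder; enabling
# SpeculativeMode.SMALL_DRAFT or SpeculativeMode.EAGLE would additionally require
# `additional_models=[...]` as described in the docstring.
import asyncio

from mlc_llm.serve import AsyncMLCEngine


async def main():
    engine = AsyncMLCEngine(
        model="HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC",
        mode="interactive",          # single concurrent request: max batch size 1
    )
    stream = await engine.chat.completions.create(
        messages=[{"role": "user", "content": "What does speculative decoding do?"}],
        stream=True,
    )
    async for chunk in stream:       # ChatCompletionStreamResponse chunks
        for choice in chunk.choices:
            print(choice.delta.content or "", end="", flush=True)
    print()


asyncio.run(main())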
+ + Parameter + --------- + request_id : str + The id of the request to abort. + """ + self._abort(request_id) + + async def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any], + openai_api_protocol.ChatCompletionResponse, + ]: + """Asynchronous chat completion internal interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ), + request_id=request_id, + ) + if stream: + # Stream response. + return chatcmpl_generator + # Normal response. 
+ num_prompt_tokens = 0 + num_completion_tokens = 0 + output_texts = ["" for _ in range(n)] + finish_reasons: List[Optional[str]] = [None for _ in range(n)] + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( + [[] for _ in range(n)] if logprobs else None + ) + async for response in chatcmpl_generator: + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + assert isinstance(choice.delta.content, str) + output_texts[choice.index] += choice.delta.content + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + + assert all(finish_reason is not None for finish_reason in finish_reasons) + use_function_calling, tool_calls_list = engine_base.process_function_call_output( + output_texts, finish_reasons + ) + return engine_base.wrap_chat_completion_response( + request_id=request_id, + model=model, + output_texts=output_texts, + finish_reasons=finish_reasons, + tool_calls_list=tool_calls_list, + logprob_results=logprob_results, + use_function_calling=use_function_calling, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) + + async def _completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + AsyncGenerator[openai_api_protocol.CompletionResponse, Any], + openai_api_protocol.CompletionResponse, + ]: + """Asynchronous completion internal interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + if request_id is None: + request_id = f"cmpl-{engine_utils.random_uuid()}" + cmpl_generator = self._handle_completion( + openai_api_protocol.CompletionRequest( + model=model, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + suffix=suffix, + temperature=temperature, + top_p=top_p, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ), + request_id, + ) + if stream: + # Stream response. + return cmpl_generator + # Normal response. 
+ num_prompt_tokens = 0 + num_completion_tokens = 0 + output_texts = ["" for _ in range(n)] + finish_reasons: List[Optional[str]] = [None for _ in range(n)] + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( + [[] for _ in range(n)] if logprobs else None + ) + + async for response in cmpl_generator: + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + output_texts[choice.index] += choice.text + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + + assert all(finish_reason is not None for finish_reason in finish_reasons) + return engine_base.wrap_completion_response( + request_id=request_id, + model=model, + output_texts=output_texts, + finish_reasons=finish_reasons, + logprob_results=logprob_results, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) + + async def _handle_chat_completion( + self, request: openai_api_protocol.ChatCompletionRequest, request_id: str + ) -> AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any]: + """The implementation fo asynchronous ChatCompletionRequest handling. + + Yields + ------ + stream_response : ChatCompletionStreamResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/streaming for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
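As a usage-level illustration of the chat-completion path this method implements, here is a hedged sketch of streaming through the OpenAI-style facade; the `engine.chat.completions.create` entry point is assumed from the `Chat`/`AsyncCompletion` attributes wired up in the constructor, and field names follow the OpenAI protocol classes referenced above.

```python
# Hedged sketch: streaming chat completions through the async engine.
# Assumes `engine` is an AsyncMLCEngine and that the Chat/AsyncCompletion
# helpers expose an OpenAI-style `create` method, as suggested by the
# constructor wiring above.
import asyncio

async def stream_chat(engine) -> None:
    stream = await engine.chat.completions.create(
        messages=[{"role": "user", "content": "Explain KV cache in one sentence."}],
        stream=True,
    )
    async for chunk in stream:
        for choice in chunk.choices:
            print(choice.delta.content or "", end="", flush=True)
    print()

# asyncio.run(stream_chat(engine))
```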
+ """ ( - model_args, - config_file_paths, - tokenizer_path, - max_single_sequence_length, - prefill_chunk_size, - self.conv_template_name, - ) = _process_model_args(models) - self._ffi = _create_tvm_module( - "mlc.serve.create_engine", - ffi_funcs=[ - "init", - "add_request", - "abort_request", - "step", - "stats", - "reset", - "get_request_stream_callback", - "set_request_stream_callback", - ], + prompts, + generation_cfg, + use_function_calling, + prompt_length, + ) = engine_base.process_chat_completion_request( + request, + request_id, + self.state, + self.model_config_dicts[0], + self.tokenizer.encode, + self.max_input_sequence_length, + self.conv_template.model_copy(deep=True), ) - self.trace_recorder = EventTraceRecorder() if enable_tracing else None - self.max_input_sequence_length = max_single_sequence_length - if kv_cache_config.max_total_sequence_length is None: - kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( - models, config_file_paths, kv_cache_config.max_num_sequence + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + num_completion_tokens = 0 + self.state.record_event(request_id, event="invoke generate") + async for delta_outputs in self._generate( + prompts, generation_cfg, request_id # type: ignore + ): + response, num_completion_tokens = engine_base.process_chat_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + use_function_calling, + prompt_length, + finish_reasons, + num_completion_tokens, ) - if kv_cache_config.prefill_chunk_size is None: - kv_cache_config.prefill_chunk_size = prefill_chunk_size - elif kv_cache_config.prefill_chunk_size > prefill_chunk_size: - raise ValueError( - f"The specified prefill chunk size {kv_cache_config.prefill_chunk_size} is " - f"larger than the maximum prefill chunk size {prefill_chunk_size} supported by " - "models. Please specify a smaller prefill chunk size." + if response is not None: + yield response + self.state.record_event(request_id, event="finish") + + async def _handle_completion( + self, request: openai_api_protocol.CompletionRequest, request_id: str + ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: + """The implementation fo asynchronous CompletionRequest handling. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
+ """ + ( + prompt, + generation_cfg, + prompt_length, + echo_response, + ) = engine_base.process_completion_request( + request, + request_id, + self.state, + self.model_config_dicts[0], + self.tokenizer, + self.max_input_sequence_length, + ) + if echo_response is not None: + yield echo_response + + num_completion_tokens = 0 + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + self.state.record_event(request_id, event="invoke generate") + async for delta_outputs in self._generate( + prompt, generation_cfg, request_id # type: ignore + ): + response, num_completion_tokens = engine_base.process_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + prompt_length, + finish_reasons, + num_completion_tokens, ) + if response is not None: + yield response - if engine_mode is None: - # The default engine mode: non-speculative - engine_mode = EngineMode() - - self._ffi["init"]( - max_single_sequence_length, - tokenizer_path, - kv_cache_config.asjson(), - engine_mode.asjson(), - request_stream_callback, - self.trace_recorder, - *model_args, + suffix_response = engine_base.create_completion_suffix_response( + request, request_id, prompt_length, finish_reasons, num_completion_tokens ) - self.tokenizer = Tokenizer(tokenizer_path) + if suffix_response is not None: + yield suffix_response + self.state.record_event(request_id, event="finish") - def generate( # pylint: disable=too-many-locals + async def _generate( self, - prompts: Union[str, List[str], List[int], List[List[int]], List[List[data.Data]]], - generation_config: Union[GenerationConfig, List[GenerationConfig]], - ) -> Tuple[List[List[str]], List[Optional[List[List[str]]]]]: - """Generate texts for a list of input prompts. - Each prompt can be a string or a list of token ids. - The generation for each prompt is independent. - Return the generation results, one for each prompt. + prompt: Union[str, List[int], List[Union[str, List[int], data.Data]]], + generation_config: GenerationConfig, + request_id: str, + ) -> AsyncGenerator[List[engine_base.CallbackStreamOutput], Any]: + """Internal asynchronous text generation interface of AsyncMLCEngine. + The method is a coroutine that streams a list of CallbackStreamOutput + at a time via yield. The returned list length is the number of + parallel generations specified by `generation_config.n`. Parameters ---------- - prompts : Union[str, List[str], List[int], List[List[int]]] - One or a list of input prompts for text generation. - Each prompt can be a string or a list of token ids. + prompt : Union[str, List[int], List[Union[str, List[int], data.Data]]] + The input prompt in forms of text strings, lists of token ids or data. - generation_config : Union[GenerationConfig, List[GenerationConfig]] - The generation config for each requests. - If the it is a single GenerationConfig instance, - this config will be shared by all the prompts. - Otherwise, one generation config is required for every - prompt. + generation_config : GenerationConfig + The generation config of the request. - Returns - ------- - output_text : List[List[str]] - The text generation results, one list of strings for each input prompt. - The length of each list is the parallel generation `n` in - generation config. - - output_logprobs_str : List[Optional[List[List[str]]]] - The logprob strings of each token for each input prompt, or None - if an input prompt does not require logprobs. 
+ request_id : str + The unique identifier (in string) or this generation request. + + Yields + ------ + request_output : List[engine_base.CallbackStreamOutput] + The delta generated outputs in a list. + The number of list elements equals to `generation_config.n`, + and each element corresponds to the delta output of a parallel + generation. """ - if isinstance(prompts, str): - # `prompts` is a single string. - prompts = [prompts] - else: - assert isinstance(prompts, list), ( - "Input `prompts` is expected to be a string, a list of " - "str, a list of token ids or multiple lists of token ids. " - ) - if len(prompts) == 0: - return [], [] - if isinstance(prompts[0], int): - # `prompts` is a list of token ids - prompts = [prompts] # type: ignore - - num_requests = len(prompts) - if not isinstance(generation_config, list): - generation_config = [generation_config] * num_requests - - assert ( - len(generation_config) == num_requests - ), "Number of generation config and number of prompts mismatch" - - num_finished_generations = 0 - output_texts: List[List[str]] = [] - output_logprobs_str: List[Optional[List[List[str]]]] = [] - text_streamers: List[List[TextStreamer]] = [] - for i in range(num_requests): - output_texts.append([]) - output_logprobs_str.append([] if generation_config[i].logprobs else None) - text_streamers.append([]) - for _ in range(generation_config[i].n): - output_texts[i].append("") - text_streamers[i].append(TextStreamer(self.tokenizer)) - if output_logprobs_str[i] is not None: - output_logprobs_str[i].append([]) - - num_total_generations = sum(cfg.n for cfg in generation_config) - - # Save a copy of the original function callback since `generate` - # overrides the callback function. - # The original callback will be set back later on. - original_callback = self._ffi["get_request_stream_callback"]() - - # Define the callback function for request generation results - def request_stream_callback(delta_outputs: List[data.RequestStreamOutput]): - nonlocal num_finished_generations - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - rid = int(request_id) - - assert len(stream_outputs) == generation_config[rid].n - for i, (stream_output, text_streamer) in enumerate( - zip(stream_outputs, text_streamers[rid]) - ): - if output_logprobs_str[rid] is not None: - assert stream_output.delta_logprob_json_strs is not None - output_logprobs_str[rid][i] += stream_output.delta_logprob_json_strs - - delta_text = ( - text_streamer.put(stream_output.delta_token_ids) - if len(stream_output.delta_token_ids) > 0 - else "" - ) - if stream_output.finish_reason is not None: - delta_text += text_streamer.finish() - - output_texts[rid][i] += delta_text - if stream_output.finish_reason is not None: - num_finished_generations += 1 - - # Override the callback function in engine. - self._ffi["set_request_stream_callback"](request_stream_callback) - - def convert_to_data(prompt: Union[str, List[int], List[data.Data]]) -> List[data.Data]: - if isinstance(prompt, str): - return [data.TextData(prompt)] - if isinstance(prompt[0], int): - return [data.TokenData(prompt)] # type: ignore - return prompt # type: ignore - - # Add requests to engine. 
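Setting the diff aside for a moment: the new `_generate` documented above is the low-level streaming primitive behind both facades. A hedged sketch of consuming it directly, where the `GenerationConfig` constructor fields are assumptions and `CallbackStreamOutput.delta_text` is taken from `engine_base`:

```python
# Hedged sketch: consuming the internal `_generate` stream directly.
# Normal code should go through the chat/completions facades; this only
# illustrates the yield contract documented above (one CallbackStreamOutput
# per parallel generation).
from mlc_llm.serve.config import GenerationConfig

async def raw_generate(engine, prompt: str, request_id: str = "demo-0") -> str:
    cfg = GenerationConfig(n=1, max_tokens=64)  # assumed constructor fields
    pieces = []
    async for delta_outputs in engine._generate(prompt, cfg, request_id):
        # delta_outputs has length cfg.n; each entry carries a delta text piece.
        pieces.append(delta_outputs[0].delta_text)
    return "".join(pieces)
```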
- for req_id, (prompt, generation_cfg) in enumerate(zip(prompts, generation_config)): - input_data = convert_to_data(prompt) # type: ignore - self.add_request( - Request( - request_id=str(req_id), - inputs=input_data, - generation_config=generation_cfg, + if self._terminated: + raise ValueError("The AsyncThreadedEngine has terminated.") + self.state.async_lazy_init_event_loop() + + # Create the request with the given id, input data, generation + # config and the created callback. + input_data = engine_utils.convert_prompts_to_data(prompt) + request = Request(request_id, input_data, generation_config) + + # Create the unique async request stream of the request. + stream = engine_base.AsyncRequestStream() + if request_id in self.state.async_streamers: + # Report error in the stream if the request id already exists. + stream.push( + RuntimeError( + f'The request id "{request_id} already exists. ' + 'Please make sure the request id is unique."' ) ) + else: + # Record the stream in the tracker + self.state.async_streamers[request_id] = ( + stream, + [TextStreamer(self.tokenizer) for _ in range(generation_config.n)], + ) + self.state.async_num_unfinished_generations[request_id] = generation_config.n + self._ffi["add_request"](request) + + # Iterate the stream asynchronously and yield the output. + try: + async for request_output in stream: + yield request_output + except ( + Exception, + asyncio.CancelledError, + ) as exception: # pylint: disable=broad-exception-caught + await self.abort(request_id) + raise exception + + def _abort(self, request_id: str): + """Internal implementation of request abortion.""" + self.state.async_streamers.pop(request_id, None) + self.state.async_num_unfinished_generations.pop(request_id, None) + self._ffi["abort_request"](request_id) - while num_finished_generations != num_total_generations: - self.step() - # Restore the callback function in engine. - self._ffi["set_request_stream_callback"](original_callback) - return output_texts, output_logprobs_str +class MLCEngine(engine_base.MLCEngineBase): + """The MLCEngine in MLC LLM that provides the synchronous + interfaces with regard to OpenAI API. - def add_request(self, request: Request) -> None: - """Add a new request to the engine. + Parameters + ---------- + models : str + A path to ``mlc-chat-config.json``, or an MLC model directory that contains + `mlc-chat-config.json`. + It can also be a link to a HF repository pointing to an MLC compiled model. + + device: Union[str, Device] + The device used to deploy the model such as "cuda" or "cuda:0". + Will default to "auto" and detect from local available GPUs if not specified. + + model_lib_path : Optional[str] + The full path to the model library file to use (e.g. a ``.so`` file). + If unspecified, we will use the provided ``model`` to search over possible paths. + It the model lib path is not found, it will be compiled in a JIT manner. + + mode : Literal["local", "interactive", "server"] + The engine mode in MLC LLM. + We provide three preset modes: "local", "interactive" and "server". + The default mode is "local". + The choice of mode decides the values of "max_batch_size", "max_total_sequence_length" + and "prefill_chunk_size" when they are not explicitly specified. + 1. Mode "local" refers to the local server deployment which has low + request concurrency. So the max batch size will be set to 4, and max + total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 2. 
Mode "interactive" refers to the interactive use of server, which + has at most 1 concurrent request. So the max batch size will be set to 1, + and max total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 3. Mode "server" refers to the large server use case which may handle + many concurrent request and want to use GPU memory as much as possible. + In this mode, we will automatically infer the largest possible max batch + size and max total sequence length. + + You can manually specify arguments "max_batch_size", "max_total_sequence_length" and + "prefill_chunk_size" to override the automatic inferred values. + + additional_models : Optional[List[str]] + The model paths and (optional) model library paths of additional models + (other than the main model). + When engine is enabled with speculative decoding, additional models are needed. + Each string in the list is either in form "model_path" or "model_path:model_lib_path". + When the model lib path of a model is not given, JIT model compilation will + be activated to compile the model automatically. + + max_batch_size : Optional[int] + The maximum allowed batch size set for the KV cache to concurrently support. + + max_total_sequence_length : Optional[int] + The KV cache total token capacity, i.e., the maximum total number of tokens that + the KV cache support. This decides the GPU memory size that the KV cache consumes. + If not specified, system will automatically estimate the maximum capacity based + on the vRAM size on GPU. + + prefill_chunk_size : Optional[int] + The maximum number of tokens the model passes for prefill each time. + It should not exceed the prefill chunk size in model config. + If not specified, this defaults to the prefill chunk size in model config. + + gpu_memory_utilization : Optional[float] + A number in (0, 1) denoting the fraction of GPU memory used by the server in total. + It is used to infer to maximum possible KV cache capacity. + When it is unspecified, it defaults to 0.85. + Under mode "local" or "interactive", the actual memory usage may be + significantly smaller than this number. Under mode "server", the actual + memory usage may be slightly larger than this number. + + engine_config : Optional[EngineConfig] + The MLCEngine execution configuration. + Currently speculative decoding mode is specified via engine config. + For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" + to specify the eagle-style speculative decoding. + Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. + + enable_tracing : bool + A boolean indicating if to enable event logging for requests. 
+ """ + + def __init__( # pylint: disable=too-many-arguments + self, + model: str, + device: Union[str, Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, + gpu_memory_utilization: Optional[float] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, + enable_tracing: bool = False, + ) -> None: + super().__init__( + "sync", + model=model, + device=device, + model_lib_path=model_lib_path, + mode=mode, + additional_models=additional_models, + max_batch_size=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + gpu_memory_utilization=gpu_memory_utilization, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + enable_tracing=enable_tracing, + ) + self.chat = Chat(weakref.ref(self)) + self.completions = Completion(weakref.ref(self)) + + def abort(self, request_id: str) -> None: + """Generation abortion interface. + + Parameter + --------- + request_id : str + The id of the request to abort. + """ + self._ffi["abort_request"](request_id) + + def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + Iterator[openai_api_protocol.ChatCompletionStreamResponse], + openai_api_protocol.ChatCompletionResponse, + ]: + """Synchronous chat completion internal interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. Parameters ---------- - request : Request - The request to add. + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
""" - self._ffi["add_request"](request) + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ), + request_id=request_id, + ) + if stream: + # Stream response. + return chatcmpl_generator + # Normal response. + num_prompt_tokens = 0 + num_completion_tokens = 0 + output_texts = ["" for _ in range(n)] + finish_reasons: List[Optional[str]] = [None for _ in range(n)] + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( + [[] for _ in range(n)] if logprobs else None + ) + for response in chatcmpl_generator: + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + assert isinstance(choice.delta.content, str) + output_texts[choice.index] += choice.delta.content + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + + assert all(finish_reason is not None for finish_reason in finish_reasons) + use_function_calling, tool_calls_list = engine_base.process_function_call_output( + output_texts, finish_reasons + ) + return engine_base.wrap_chat_completion_response( + request_id=request_id, + model=model, + output_texts=output_texts, + finish_reasons=finish_reasons, + tool_calls_list=tool_calls_list, + logprob_results=logprob_results, + use_function_calling=use_function_calling, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) - def abort_request(self, request_id: str) -> None: - """Abort the generation of the request corresponding to the input request id. + def _completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + prompt: Union[str, List[int]], + model: Optional[str] = None, + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.CompletionResponse]: + """Synchronous completion internal interface with OpenAI API compatibility. 
+ + See https://platform.openai.com/docs/api-reference/completions/create for specification. Parameters ---------- - request_id : str - The unique id of the request to abort. + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. """ - self._ffi["abort_request"](request_id) + if request_id is None: + request_id = f"cmpl-{engine_utils.random_uuid()}" + cmpl_generator = self._handle_completion( + openai_api_protocol.CompletionRequest( + model=model, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + suffix=suffix, + temperature=temperature, + top_p=top_p, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ), + request_id, + ) + if stream: + # Stream response. + return cmpl_generator + # Normal response. + num_prompt_tokens = 0 + num_completion_tokens = 0 + output_texts = ["" for _ in range(n)] + finish_reasons: List[Optional[str]] = [None for _ in range(n)] + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( + [[] for _ in range(n)] if logprobs else None + ) - def step(self) -> None: - """The main function that the engine takes a step of action. + for response in cmpl_generator: + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + output_texts[choice.index] += choice.text + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + + assert all(finish_reason is not None for finish_reason in finish_reasons) + return engine_base.wrap_completion_response( + request_id=request_id, + model=model, + output_texts=output_texts, + finish_reasons=finish_reasons, + logprob_results=logprob_results, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) - At each step, the engine may decide to - - run prefill for one (or more) requests, - - run one-step decode for the all existing requests - ... + def _handle_chat_completion( + self, request: openai_api_protocol.ChatCompletionRequest, request_id: str + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + """The implementation fo synchronous ChatCompletionRequest handling. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/streaming for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
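And the streaming variant of the synchronous API, again as a hedged sketch: with `stream=True` the call returns an iterator of `ChatCompletionStreamResponse` chunks, matching the `Iterator` return type in the signatures above.

```python
# Hedged sketch: streaming through the synchronous engine. Each chunk is a
# ChatCompletionStreamResponse carrying delta text per choice.
for chunk in engine.chat.completions.create(
    messages=[{"role": "user", "content": "Write a haiku about GPUs."}],
    stream=True,
):
    for choice in chunk.choices:
        print(choice.delta.content or "", end="", flush=True)
print()
```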
+ """ + ( + prompts, + generation_cfg, + use_function_calling, + prompt_length, + ) = engine_base.process_chat_completion_request( + request, + request_id, + self.state, + self.model_config_dicts[0], + self.tokenizer.encode, + self.max_input_sequence_length, + self.conv_template.model_copy(deep=True), + ) - In the end of certain actions (e.g., decode), the engine will - check if any request has finished, and will return the - generation results for those finished requests. + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + num_completion_tokens = 0 + self.state.record_event(request_id, event="invoke generate") + for delta_outputs in self._generate(prompts, generation_cfg, request_id): # type: ignore + response, num_completion_tokens = engine_base.process_chat_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + use_function_calling, + prompt_length, + finish_reasons, + num_completion_tokens, + ) + if response is not None: + yield response + self.state.record_event(request_id, event="finish") + + def _handle_completion( + self, request: openai_api_protocol.CompletionRequest, request_id: str + ) -> Iterator[openai_api_protocol.CompletionResponse]: + """The implementation fo synchronous CompletionRequest handling. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. """ - self._ffi["step"]() - - def reset(self) -> None: - """Reset the engine, clean up all running data and statistics.""" - self._ffi["reset"]() - - def stats(self) -> Dict[str, float]: - """The engine runtime statistics. - We collect the following entries: - - single token prefill latency (s/tok): avg latency of processing one token in prefill - - single token decode latency (s/tok): avg latency of processing one token in decode - - engine time for prefill (sec) - - engine time for decode (sec) - - total number of processed tokens in prefill. - - total number of processed tokens in decode. 
+ ( + prompt, + generation_cfg, + prompt_length, + echo_response, + ) = engine_base.process_completion_request( + request, + request_id, + self.state, + self.model_config_dicts[0], + self.tokenizer, + self.max_input_sequence_length, + ) + if echo_response is not None: + yield echo_response + + num_completion_tokens = 0 + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + self.state.record_event(request_id, event="invoke generate") + for delta_outputs in self._generate(prompt, generation_cfg, request_id): # type: ignore + response, num_completion_tokens = engine_base.process_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + prompt_length, + finish_reasons, + num_completion_tokens, + ) + if response is not None: + yield response + + suffix_response = engine_base.create_completion_suffix_response( + request, request_id, prompt_length, finish_reasons, num_completion_tokens + ) + if suffix_response is not None: + yield suffix_response + self.state.record_event(request_id, event="finish") + + def _generate( # pylint: disable=too-many-locals + self, + prompt: Union[str, List[int], List[Union[str, List[int], data.Data]]], + generation_config: GenerationConfig, + request_id: str, + ) -> Iterator[List[engine_base.CallbackStreamOutput]]: + """Internal synchronous text generation interface of AsyncMLCEngine. + The method is a coroutine that streams a list of CallbackStreamOutput + at a time via yield. The returned list length is the number of + parallel generations specified by `generation_config.n`. + + Parameters + ---------- + prompt : Union[str, List[int], List[Union[str, List[int], data.Data]]] + The input prompt in forms of text strings, lists of token ids or data. + + generation_config : GenerationConfig + The generation config of the request. + + request_id : str + The unique identifier (in string) or this generation request. + + Yields + ------ + request_output : List[engine_base.CallbackStreamOutput] + The delta generated outputs in a list. + The number of list elements equals to `generation_config.n`, + and each element corresponds to the delta output of a parallel + generation. """ - stats_json_str = self._ffi["stats"]() - return json.loads(stats_json_str) + if self._terminated: + raise ValueError("The engine has terminated.") + + # Create the request with the given id, input data, generation + # config and the created callback. + input_data = engine_utils.convert_prompts_to_data(prompt) + request = Request(request_id, input_data, generation_config) + + # Record the stream in the tracker + self.state.sync_output_queue = queue.Queue() + self.state.sync_text_streamers = [ + TextStreamer(self.tokenizer) for _ in range(generation_config.n) + ] + self.state.sync_num_unfinished_generations = generation_config.n + self._ffi["add_request"](request) + + # Iterate the stream asynchronously and yield the token. 
+ try: + while self.state.sync_num_unfinished_generations > 0: + delta_outputs = self.state.sync_output_queue.get() + request_outputs = self._request_stream_callback_impl(delta_outputs) + for request_output in request_outputs: + yield request_output + except Exception as exception: # pylint: disable=broad-exception-caught + self.abort(request_id) + raise exception + + def _request_stream_callback_impl( + self, delta_outputs: List[data.RequestStreamOutput] + ) -> List[List[engine_base.CallbackStreamOutput]]: + """The underlying implementation of request stream callback of MLCEngine.""" + batch_outputs: List[List[engine_base.CallbackStreamOutput]] = [] + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + self.state.record_event(request_id, event="start callback") + outputs: List[engine_base.CallbackStreamOutput] = [] + for stream_output, text_streamer in zip(stream_outputs, self.state.sync_text_streamers): + self.state.record_event(request_id, event="start detokenization") + delta_text = ( + text_streamer.put(stream_output.delta_token_ids) + if len(stream_output.delta_token_ids) > 0 + else "" + ) + if stream_output.finish_reason is not None: + delta_text += text_streamer.finish() + self.state.record_event(request_id, event="finish detokenization") + + outputs.append( + engine_base.CallbackStreamOutput( + delta_text=delta_text, + num_delta_tokens=len(stream_output.delta_token_ids), + delta_logprob_json_strs=stream_output.delta_logprob_json_strs, + finish_reason=stream_output.finish_reason, + ) + ) + if stream_output.finish_reason is not None: + self.state.sync_num_unfinished_generations -= 1 + batch_outputs.append(outputs) + self.state.record_event(request_id, event="finish callback") + return batch_outputs diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 4c95f6e612..65b41a66ac 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -20,7 +20,12 @@ from mlc_llm.protocol import openai_api_protocol, protocol_utils from mlc_llm.protocol.conversation_protocol import Conversation from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import EngineConfig, GenerationConfig, SpeculativeMode +from mlc_llm.serve.config import ( + EngineConfig, + GenerationConfig, + KVStateKind, + SpeculativeMode, +) from mlc_llm.serve.event_trace_recorder import EventTraceRecorder from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging @@ -89,8 +94,10 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: if conversation is None: assert isinstance(chat_config.conv_template, Conversation) conversation = chat_config.conv_template - # Try look up model library, and do JIT compile if model library not found. 
- try: + + if model.model_lib_path is not None: + # do model lib search if the model lib path is provided + # error out if file not found model_lib_path = _get_lib_module_path( model=model.model, model_path=model_path, @@ -99,7 +106,9 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: device_name=device.MASK2STR[device.device_type], config_file_path=config_file_path, ) - except FileNotFoundError: + else: + # TODO(mlc-team) add logging information + # Run jit if model_lib_path is not provided from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel model_lib_path = str( @@ -117,7 +126,7 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: return model_args, config_file_paths, conversation -def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-many-locals,too-many-arguments +def _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( # pylint: disable=too-many-locals,too-many-arguments models: List[ModelInfo], device: tvm.runtime.Device, model_config_paths: List[str], @@ -195,7 +204,7 @@ def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-ma if gpu_size_bytes is None: raise ValueError("Cannot read total GPU global memory from device.") if gpu_memory_utilization is None: - gpu_memory_utilization = 0.90 + gpu_memory_utilization = 0.85 model_max_total_sequence_length = int( ( @@ -236,6 +245,90 @@ def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-ma ) +def _estimate_mem_usage_and_max_history_size_for_rnn_state( # pylint: disable=too-many-arguments, too-many-locals, unused-argument + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_paths: List[str], + model_config_dicts: List[Dict[str, Any]], + max_num_sequence: int, + gpu_memory_utilization: Optional[float], +) -> Tuple[float, float, float, int]: + # Get single-card GPU size. + gpu_size_bytes = device.total_global_memory + if gpu_size_bytes is None: + raise ValueError("Cannot read total GPU global memory from device.") + if gpu_memory_utilization is None: + gpu_memory_utilization = 0.90 + + rnn_state_base_bytes = 0.0 # the memory usage for rnn state when history = 1 + param_bytes = 0.0 + temp_func_bytes = 0.0 + model_workspace_bytes = 0.0 + logit_processor_workspace_bytes = 0.0 + for model, model_config_path, model_config_dict in zip( + models, model_config_paths, model_config_dicts + ): + # Read metadata for the parameter size and the temporary memory size. 
+ cmd = [ + sys.executable, + "-m", + "mlc_llm.cli.model_metadata", + model.model_lib_path, + "--print-memory-usage-in-json", + "--mlc-chat-config", + model_config_path, + ] + usage_str = subprocess.check_output(cmd, universal_newlines=True) + usage_json = json.loads(usage_str) + param_bytes += usage_json["params_bytes"] + temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) + + model_config = model_config_dict["model_config"] + vocab_size = model_config_dict["vocab_size"] + head_size = model_config["head_size"] + num_heads = model_config["num_heads"] + num_layers = model_config["num_hidden_layers"] + hidden_size = model_config["hidden_size"] + prefill_chunk_size = model_config["prefill_chunk_size"] + logit_processor_workspace_bytes += ( + max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 + ) + + model_workspace_bytes += ( + prefill_chunk_size * 4 + + max_num_sequence * 4 + + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 + ) + + rnn_state_base_bytes += ( + max_num_sequence * hidden_size * num_layers * 2 * 2 + + max_num_sequence * num_heads * head_size * head_size * num_layers * 2 + ) + + max_history_size = int( + ( + gpu_size_bytes * gpu_memory_utilization + - logit_processor_workspace_bytes + - model_workspace_bytes + - param_bytes + - temp_func_bytes + ) + / rnn_state_base_bytes + ) + if max_history_size < 1: + raise ValueError( + f"Memory required by models may be larger than available GPU memory " + f"size {gpu_size_bytes * gpu_memory_utilization} bytes." + ) + + return ( + param_bytes, + model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes, + rnn_state_base_bytes, + max_history_size, + ) + + def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[int, int, int]: """Read the model config dictionaries, and return the maximum single sequence length the models can support, the maximum prefill chunk @@ -290,7 +383,7 @@ def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[i return model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size -def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements +def _infer_kv_cache_config_for_kv_cache( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements mode: Literal["local", "interactive", "server"], max_batch_size: Optional[int], max_total_sequence_length: Optional[int], @@ -300,12 +393,13 @@ def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-local device: tvm.runtime.Device, model_config_dicts: List[Dict[str, Any]], model_config_paths: List[str], -) -> Tuple[int, int, int, int]: +) -> Tuple[int, int, int, KVStateKind, int]: """Initialize the KV cache config with user input and GPU memory usage estimation. 
The returned four integers are: - max_batch_size - max_total_sequence_length - prefill_chunk_size + - kv_state_kind - model_max_single_sequence_length """ ( @@ -319,7 +413,7 @@ def infer_args_under_mode( max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], - ) -> Tuple[Tuple[int, int, int], List[float]]: + ) -> Tuple[Tuple[int, int, int, KVStateKind], List[float]]: logging_msg = "" # - max_batch_size if max_batch_size is None: @@ -339,7 +433,7 @@ def infer_args_under_mode( kv_aux_workspace_bytes, temp_workspace_bytes, model_max_total_sequence_length, - ) = _estimate_mem_usage_and_max_total_sequence_length( + ) = _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( models, device, model_config_paths, @@ -396,7 +490,12 @@ def infer_args_under_mode( # - Construct the KV cache config # - Estimate total GPU memory usage on single GPU. - return (max_batch_size, max_total_sequence_length, prefill_chunk_size), [ + return ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + KVStateKind.ATTENTION, + ), [ total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token, model_params_bytes, kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes, @@ -458,9 +557,192 @@ def infer_args_under_mode( return *kv_cache_config, model_max_single_sequence_length +def _infer_kv_cache_config_for_rnn_state( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + mode: Literal["local", "interactive", "server"], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + max_history_size: Optional[int], + gpu_memory_utilization: Optional[float], + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_dicts: List[Dict[str, Any]], + model_config_paths: List[str], +) -> Tuple[int, int, int, KVStateKind, int]: + """Initialize the RNN state config with user input and GPU memory usage estimation. + The returned four integers are: + - max_batch_size + - max_total_sequence_length + - prefill_chunk_size + - kv_state_kind + - max_history_size + """ + logging_msg = "" + prefill_chunk_size = 0 + + if prefill_chunk_size is None: + prefill_chunk_size = min( + config["prefill_chunk_size"] if "prefill_chunk_size" in config else 4096 + for config in model_config_dicts + ) + logging_msg += f"prefill chunk size is set to {prefill_chunk_size}. " + else: + logging_msg += f"prefill chunk size {prefill_chunk_size} is specified by user. " + if max_batch_size is None: + max_batch_size = 1 if mode == "interactive" else 4 + logging_msg += f"max batch size is set to {max_batch_size}, " + else: + logging_msg += f"max batch size {max_batch_size} is specified by user, " + + if mode == "local": + logging_msg += ( + "We choose small max batch size and RNN state capacity to use less GPU memory." + ) + elif mode == "interactive": + logging_msg += "We fix max batch size to 1 for interactive single sequence use." + else: + logging_msg += ( + "We use as much GPU memory as possible (within the" " limit of gpu_memory_utilization)." 
+ ) + logger.info('Under mode "%s", %s', mode, logging_msg) + + ( + model_param_bytes, + model_temp_bytes, + model_rnn_state_base_bytes, + model_max_history_size, + ) = _estimate_mem_usage_and_max_history_size_for_rnn_state( + models, + device, + model_config_paths, + model_config_dicts, + max_batch_size, + gpu_memory_utilization, + ) + if max_history_size is None: + max_history_size = model_max_history_size + else: + max_history_size = min(max_history_size, model_max_history_size) + max_total_sequence_length = 32768 + prefill_chunk_size = 0 + kind = KVStateKind.RNNSTATE + + logger.info( + "%s: %.2f MB (Parameters: %.2f MB. RNNState: %.2f MB. Temporary buffer: %.2f MB). " + "The actual usage might be slightly larger than the estimated number.", + green("Estimated total single GPU memory usage"), + (model_param_bytes + model_temp_bytes + model_rnn_state_base_bytes) / 1024 / 1024, + model_param_bytes / 1024 / 1024, + max_history_size * model_rnn_state_base_bytes / 1024 / 1024, + model_temp_bytes / 1024 / 1024, + ) + + return ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + kind, + max_history_size, + ) + + +def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + mode: Literal["local", "interactive", "server"], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + max_history_size: Optional[int], + gpu_memory_utilization: Optional[float], + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_dicts: List[Dict[str, Any]], + model_config_paths: List[str], +) -> Tuple[int, int, int, int, int, KVStateKind]: + """Initialize the cache config with user input and GPU memory usage estimation. + The returned four integers are: + - max_batch_size + - max_total_sequence_length + - prefill_chunk_size + - max_single_sequence_length + - max_history_size + - kv_state_kind + """ + if all("rwkv" not in model.model for model in models): + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + kv_state_kind, + max_single_sequence_length, + ) = _infer_kv_cache_config_for_kv_cache( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + gpu_memory_utilization, + models, + device, + model_config_dicts, + model_config_paths, + ) + max_history_size = 0 # KV cache doesn't need this + elif all("rwkv" in model.model for model in models): + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + kv_state_kind, + max_history_size, + ) = _infer_kv_cache_config_for_rnn_state( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_history_size, + gpu_memory_utilization, + models, + device, + model_config_dicts, + model_config_paths, + ) + max_single_sequence_length = max_total_sequence_length # RNN state doesn't need this + else: + raise ValueError("The models should be either all KV cache models or all RNN state models.") + return ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + max_history_size, + kv_state_kind, + ) + + +def _infer_generation_config( + model_config_dicts: List[Dict[str, Any]] +) -> List[Tuple[float, float, float, float]]: + """Infer the generation config from the model config dictionaries. 
+ The returned four floats are: + - temperature + - top_p + - frequency_penalty + - presence_penalty + """ + generation_configs = [] + + for model_config in model_config_dicts: + temperature = model_config.get("temperature", 1.0) + top_p = model_config.get("top_p", 1.0) + frequency_penalty = model_config.get("frequency_penalty", 0.0) + presence_penalty = model_config.get("presence_penalty", 0.0) + generation_configs.append((temperature, top_p, frequency_penalty, presence_penalty)) + + return generation_configs + + @dataclass class CallbackStreamOutput: - """The output of LLMEngine._generate and AsyncLLMEngine._generate + """The output of MLCEngine._generate and AsyncMLCEngine._generate Attributes ---------- @@ -485,7 +767,7 @@ class CallbackStreamOutput: class AsyncRequestStream: - """The asynchronous stream for requests in AsyncLLMEngine. + """The asynchronous stream for requests in AsyncMLCEngine. Each request has its own unique stream. The stream exposes the method `push` for engine to push new generated @@ -544,29 +826,29 @@ async def __anext__(self) -> List[CallbackStreamOutput]: class EngineState: """The engine states that the request stream callback function may use. - This class is used for both AsyncLLMEngine and LLMEngine. - AsyncLLMEngine uses the fields and methods starting with "async", - and LLMEngine uses the ones starting with "sync". + This class is used for both AsyncMLCEngine and MLCEngine. + AsyncMLCEngine uses the fields and methods starting with "async", + and MLCEngine uses the ones starting with "sync". - - For AsyncLLMEngine, the state contains an asynchronous event loop, + - For AsyncMLCEngine, the state contains an asynchronous event loop, the streamers and the number of unfinished generations for each request being processed. - - For LLMEngine, the state contains a callback output blocking queue, + - For MLCEngine, the state contains a callback output blocking queue, the text streamers and the number of unfinished requests. We use this state class to avoid the callback function from capturing - the AsyncLLMEngine. + the AsyncMLCEngine. The state also optionally maintains an event trace recorder, which can provide Chrome tracing when enabled. """ trace_recorder = None - # States used for AsyncLLMEngine + # States used for AsyncMLCEngine async_event_loop: Optional[asyncio.AbstractEventLoop] = None async_streamers: Dict[str, Tuple[AsyncRequestStream, List[TextStreamer]]] = {} async_num_unfinished_generations: Dict[str, int] = {} - # States used for LLMEngine + # States used for MLCEngine sync_output_queue: queue.Queue = queue.Queue() sync_text_streamers: List[TextStreamer] = [] sync_num_unfinished_generations: int = 0 @@ -577,7 +859,7 @@ def __init__(self, enable_tracing: bool) -> None: self.trace_recorder = EventTraceRecorder() def record_event(self, request_id: str, event: str) -> None: - """Record a event for the the input request in the trace + """Record a event for the input request in the trace recorder when the recorder exists. Parameters @@ -628,7 +910,7 @@ def async_lazy_init_event_loop(self) -> None: self.async_event_loop = asyncio.get_event_loop() def _async_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: - """The request stream callback function for AsyncLLMEngine to stream back + """The request stream callback function for AsyncMLCEngine to stream back the request generation results. 
Note @@ -648,7 +930,7 @@ def _async_request_stream_callback(self, delta_outputs: List[data.RequestStreamO def _async_request_stream_callback_impl( self, delta_outputs: List[data.RequestStreamOutput] ) -> None: - """The underlying implementation of request stream callback for AsyncLLMEngine.""" + """The underlying implementation of request stream callback for AsyncMLCEngine.""" for delta_output in delta_outputs: request_id, stream_outputs = delta_output.unpack() streamers = self.async_streamers.get(request_id, None) @@ -689,28 +971,28 @@ def _async_request_stream_callback_impl( self.record_event(request_id, event="finish callback") def _sync_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: - """The request stream callback function for LLMEngine to stream back + """The request stream callback function for MLCEngine to stream back the request generation results. """ # Put the delta outputs to the queue in the unblocking way. self.sync_output_queue.put_nowait(delta_outputs) -class LLMEngineBase: # pylint: disable=too-many-instance-attributes,too-few-public-methods +class MLCEngineBase: # pylint: disable=too-many-instance-attributes,too-few-public-methods """The base engine class, which implements common functions that - are shared by LLMEngine and AsyncLLMEngine. + are shared by MLCEngine and AsyncMLCEngine. This class wraps a threaded engine that runs on a standalone thread inside and streams back the delta generated results via callback functions. The internal threaded engine keeps running an loop that drives the engine. - LLMEngine and AsyncLLMEngine inherits this LLMEngineBase class, and implements + MLCEngine and AsyncMLCEngine inherits this MLCEngineBase class, and implements their own methods to process the delta generated results received from callback functions and yield the processed delta results in the forms of standard API protocols. - Checkout subclasses AsyncLLMEngine/LLMEngine for the docstring of constructor parameters. + Checkout subclasses AsyncMLCEngine/MLCEngine for the docstring of constructor parameters. 
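The docstring above describes the core design: a standalone thread drives the engine loop, and results flow back through callback functions that the subclasses turn into API-level responses. The following toy sketch (schematic only, not MLC code) shows that thread-plus-callback-plus-queue pattern in isolation:

```python
# Schematic illustration only (not MLC code): a background thread drives a
# "loop" that produces delta results; a stream callback feeds them into a
# queue, which the caller drains synchronously.
import queue
import threading

class ToyThreadedEngine:
    def __init__(self):
        self._outputs: queue.Queue = queue.Queue()
        self._thread = threading.Thread(target=self._background_loop, daemon=True)
        self._thread.start()

    def _background_loop(self):
        # Stand-in for the background engine loop: produce a few deltas,
        # then signal completion with None.
        for token in ["Hello", ", ", "world", None]:
            self._stream_callback(token)

    def _stream_callback(self, delta):
        # Stand-in for the request stream callback: push deltas to the queue.
        self._outputs.put(delta)

    def generate(self):
        # Drain the queue until the completion signal arrives.
        while (delta := self._outputs.get()) is not None:
            yield delta

print("".join(ToyThreadedEngine().generate()))  # -> "Hello, world"
```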
""" def __init__( # pylint: disable=too-many-arguments,too-many-locals @@ -724,6 +1006,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], + max_history_size: Optional[int], gpu_memory_utilization: Optional[float], speculative_mode: SpeculativeMode, spec_draft_length: int, @@ -753,11 +1036,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length, prefill_chunk_size, max_single_sequence_length, + max_history_size, + kv_state_kind, ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, prefill_chunk_size, + max_history_size, gpu_memory_utilization, models, device, @@ -776,32 +1062,37 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "abort_request", "run_background_loop", "run_background_stream_back_loop", + "reload", "init_background_engine", "exit_background_loop", "debug_call_func_on_all_worker", ] } self.tokenizer = Tokenizer(model_args[0][0]) + self._ffi["init_background_engine"]( + device, + self.state.get_request_stream_callback(kind), + self.state.trace_recorder, + ) + self._ffi["reload"]( + EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ) + ) def _background_loop(): - self._ffi["init_background_engine"]( - EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ), - self.state.get_request_stream_callback(kind), - self.state.trace_recorder, - ) self._ffi["run_background_loop"]() def _background_stream_back_loop(): @@ -919,6 +1210,7 @@ def process_chat_completion_request( # pylint: disable=too-many-arguments # Process generation config. Create request id. generation_cfg = protocol_utils.get_generation_config( request, + model_config, extra_stop_token_ids=conv_template.stop_token_ids, extra_stop_str=conv_template.stop_str, ) @@ -1039,10 +1331,11 @@ def process_chat_completion_stream_output( # pylint: disable=too-many-arguments return response, num_completion_tokens -def process_completion_request( +def process_completion_request( # pylint: disable=too-many-arguments request: openai_api_protocol.CompletionRequest, request_id: str, engine_state: EngineState, + model_config: Dict[str, Any], tokenizer: Tokenizer, max_input_sequence_length: int, ) -> Tuple[List[int], GenerationConfig, int, Optional[openai_api_protocol.CompletionResponse]]: @@ -1094,7 +1387,7 @@ def process_completion_request( assert isinstance(prompt, list) # Process generation config. Create request id. 
- generation_cfg = protocol_utils.get_generation_config(request) + generation_cfg = protocol_utils.get_generation_config(request, model_config) # - Echo back the prompt. echo_response = None diff --git a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py index b95fd4faae..af1613c027 100644 --- a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py @@ -5,8 +5,8 @@ import fastapi -from ..server import ServerContext -from . import entrypoint_utils +from mlc_llm.protocol import error_protocol +from mlc_llm.serve.server import ServerContext app = fastapi.APIRouter() @@ -26,11 +26,11 @@ async def debug_dump_event_trace(request: fastapi.Request): # Parse the JSON string request_dict = json.loads(request_json_str) except json.JSONDecodeError: - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" ) if "model" not in request_dict: - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" ) @@ -41,12 +41,41 @@ async def debug_dump_event_trace(request: fastapi.Request): async_engine = server_context.get_engine(model) if async_engine is None: - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{model}" is not served.' ) if async_engine.state.trace_recorder is None: - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{model}" does not enable tracing' ) return json.loads(async_engine.state.trace_recorder.dump_json()) + + +################ /debug/cuda_profiler_start/end ################ + + +@app.post("/debug/cuda_profiler_start") +async def debug_cuda_profiler_start(_request: fastapi.Request): + """Start the cuda profiler for the engine. Only for debug purpose.""" + server_context: ServerContext = ServerContext.current() + # Since the CUDA profiler is process-wise, call the function for one model is sufficient. + for model in server_context.get_model_list(): + async_engine = server_context.get_engine(model) + async_engine._debug_call_func_on_all_worker( # pylint: disable=protected-access + "mlc.debug_cuda_profiler_start" + ) + break + + +@app.post("/debug/cuda_profiler_stop") +async def debug_cuda_profiler_stop(_request: fastapi.Request): + """Stop the cuda profiler for the engine. Only for debug purpose.""" + server_context: ServerContext = ServerContext.current() + # Since the CUDA profiler is process-wise, call the function for one model is sufficient. 
+ for model in server_context.get_model_list(): + async_engine = server_context.get_engine(model) + async_engine._debug_call_func_on_all_worker( # pylint: disable=protected-access + "mlc.debug_cuda_profiler_stop" + ) + break diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index ac8503d5df..23a279021f 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -1,37 +1,21 @@ """OpenAI API-compatible server entrypoints in MLC LLM""" # pylint: disable=too-many-locals,too-many-return-statements,too-many-statements -import ast -import json from http import HTTPStatus -from typing import AsyncGenerator, Dict, List, Optional, Sequence, Union +from typing import AsyncGenerator, List, Optional import fastapi -from mlc_llm.serve import data - -from ...protocol import protocol_utils -from ...protocol.conversation_protocol import Conversation -from ...protocol.openai_api_protocol import ( - ChatCompletionMessage, +from mlc_llm.protocol import error_protocol +from mlc_llm.protocol.openai_api_protocol import ( ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatCompletionStreamResponse, - ChatCompletionStreamResponseChoice, - ChatFunctionCall, - ChatToolCall, CompletionRequest, - CompletionResponse, - CompletionResponseChoice, ListResponse, - LogProbs, LogProbsContent, ModelResponse, - UsageInfo, ) -from ..server import ServerContext -from . import entrypoint_utils +from mlc_llm.serve import engine_base, engine_utils +from mlc_llm.serve.server import ServerContext app = fastapi.APIRouter() @@ -59,130 +43,30 @@ async def request_completion(request: CompletionRequest, raw_request: fastapi.Re server_context: ServerContext = ServerContext.current() async_engine = server_context.get_engine(request.model) if async_engine is None: - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{request.model}" is not served.' ) - request_id = f"cmpl-{entrypoint_utils.random_uuid()}" - async_engine.state.record_event(request_id, event="receive request") - - # - Check if unsupported arguments are specified. - error = entrypoint_utils.check_unsupported_fields(request) - if error is not None: - return error - - # - Process prompt and check validity. - async_engine.state.record_event(request_id, event="start tokenization") - prompts = entrypoint_utils.process_prompts(request.prompt, async_engine.tokenizer.encode) - async_engine.state.record_event(request_id, event="finish tokenization") - if isinstance(prompts, fastapi.responses.JSONResponse): - # Errored when processing the prompts - return prompts - if len(prompts) > 1: - return entrypoint_utils.create_error_response( - HTTPStatus.BAD_REQUEST, - message="Entrypoint /v1/completions only accept single prompt. " - f"However, {len(prompts)} prompts {prompts} are received.", - ) - error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_input_sequence_length) - if error is not None: - return error - prompt = prompts[0] - - # Process generation config. Create request id. - generation_cfg = protocol_utils.get_generation_config(request) + request_id = f"cmpl-{engine_utils.random_uuid()}" # Streaming response. if request.stream: + # We manually get the first response from generator to + # capture potential exceptions in this scope, rather then + # the StreamingResponse scope. 
+ stream_generator = async_engine._handle_completion( # pylint: disable=protected-access + request, request_id + ) + first_response = await anext( # type: ignore # pylint: disable=undefined-variable + stream_generator + ) async def completion_stream_generator() -> AsyncGenerator[str, None]: - # - Echo back the prompt. - if request.echo: - text = async_engine.tokenizer.decode(prompt) - response = CompletionResponse( - id=request_id, - choices=[ - CompletionResponseChoice(index=i, text=text) - for i in range(generation_cfg.n) - ], - model=request.model, - usage=UsageInfo( - prompt_tokens=len(prompt), - completion_tokens=0, - ), - ) - yield f"data: {response.model_dump_json()}\n\n" - - # - Generate new tokens. - num_completion_tokens = 0 - finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - async_engine.state.record_event(request_id, event="invoke generate") - async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): - assert len(delta_outputs) == generation_cfg.n - choices = [] - for i, delta_output in enumerate(delta_outputs): - finish_reason_updated = False - if delta_output.finish_reason is not None and finish_reasons[i] is None: - finish_reasons[i] = delta_output.finish_reason - finish_reason_updated = True - num_completion_tokens += delta_output.num_delta_tokens - if not finish_reason_updated and delta_output.delta_text == "": - # Ignore empty delta text when finish reason is not updated. - continue - - choices.append( - CompletionResponseChoice( - index=i, - finish_reason=finish_reasons[i], - text=delta_output.delta_text, - logprobs=( - LogProbs( - content=[ - LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in delta_output.delta_logprob_json_strs - ] - ) - if delta_output.delta_logprob_json_strs is not None - else None - ), - ) - ) - - if len(choices) == 0: - # Skip yield when there is no delta output. - continue - response = CompletionResponse( - id=request_id, - choices=choices, - model=request.model, - usage=UsageInfo( - prompt_tokens=len(prompt), - completion_tokens=num_completion_tokens, - ), - ) - yield f"data: {response.model_dump_json()}\n\n" - async_engine.state.record_event(request_id, event="finish") - - # - Echo the suffix. - if request.suffix is not None: - assert all(finish_reason is not None for finish_reason in finish_reasons) - response = CompletionResponse( - id=request_id, - choices=[ - CompletionResponseChoice( - index=i, - finish_reason=finish_reason, - text=request.suffix, - ) - for i, finish_reason in enumerate(finish_reasons) - ], - model=request.model, - usage=UsageInfo( - prompt_tokens=len(prompt), - completion_tokens=num_completion_tokens, - ), - ) + if isinstance(first_response, StopAsyncIteration): + yield "data: [DONE]\n\n" + return + yield f"data: {first_response.model_dump_json()}\n\n" + async for response in stream_generator: yield f"data: {response.model_dump_json()}\n\n" - yield "data: [DONE]\n\n" return fastapi.responses.StreamingResponse( @@ -190,165 +74,51 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) # Normal response. 
- init_output_text = "" if not request.echo else async_engine.tokenizer.decode(prompt) - output_texts = [init_output_text for _ in range(generation_cfg.n)] + num_prompt_tokens = 0 num_completion_tokens = 0 - finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - logprob_json_strs_list: Optional[List[List[str]]] = ( - [[] for _ in range(generation_cfg.n)] if generation_cfg.logprobs else None + output_texts = ["" for _ in range(request.n)] + finish_reasons: List[Optional[str]] = [None for _ in range(request.n)] + logprob_results: Optional[List[List[LogProbsContent]]] = ( + [[] for _ in range(request.n)] if request.logprobs else None ) - async_engine.state.record_event(request_id, event="invoke generate") - async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): + + async for response in async_engine._handle_completion( # pylint: disable=protected-access + request, request_id + ): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified # when the request is disconnected. # Therefore, we check if it is disconnected each time, # and abort the request from engine if so. await async_engine.abort(request_id) - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message="The request has disconnected" ) + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + output_texts[choice.index] += choice.text + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[choice.index] += choice.logprobs.content - assert len(delta_outputs) == generation_cfg.n - for i, delta_output in enumerate(delta_outputs): - if delta_output.finish_reason is not None and finish_reasons[i] is None: - finish_reasons[i] = delta_output.finish_reason - output_texts[i] += delta_output.delta_text - num_completion_tokens += delta_output.num_delta_tokens - if logprob_json_strs_list is not None: - assert delta_output.delta_logprob_json_strs is not None - logprob_json_strs_list[i] += delta_output.delta_logprob_json_strs assert all(finish_reason is not None for finish_reason in finish_reasons) - suffix = request.suffix if request.suffix is not None else "" - async_engine.state.record_event(request_id, event="finish") - response = CompletionResponse( - id=request_id, - choices=[ - CompletionResponseChoice( - index=i, - finish_reason=finish_reason, - text=output_text + suffix, - logprobs=( - LogProbs( - content=[ - LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in logprob_json_strs_list[ # pylint: disable=unsubscriptable-object - i - ] - ] - ) - if logprob_json_strs_list is not None - else None - ), - ) - for i, (output_text, finish_reason) in enumerate(zip(output_texts, finish_reasons)) - ], + return engine_base.wrap_completion_response( + request_id=request_id, model=request.model, - usage=UsageInfo( - prompt_tokens=len(prompt), - completion_tokens=num_completion_tokens, - ), + output_texts=output_texts, + finish_reasons=finish_reasons, + logprob_results=logprob_results, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, ) - return response ################ v1/chat/completions ################ -def chat_completion_check_message_validity( - messages: 
List[ChatCompletionMessage], -) -> Optional[str]: - """Check if the given chat messages are valid. Return error message if invalid.""" - for i, message in enumerate(messages): - if message.role == "system" and i != 0: - return f"System prompt at position {i} in the message list is invalid." - if message.role == "tool": - return "Tool as the message author is not supported yet." - if message.tool_call_id is not None: - if message.role != "tool": - return "Non-tool message having `tool_call_id` is invalid." - if isinstance(message.content, list): - if message.role != "user": - return "Non-user message having a list of content is invalid." - if message.tool_calls is not None: - if message.role != "assistant": - return "Non-assistant message having `tool_calls` is invalid." - return "Assistant message having `tool_calls` is not supported yet." - return None - - -def check_function_call_usage( - request: ChatCompletionRequest, conv_template: Conversation -) -> Optional[str]: - """Check if function calling is used and update the conversation template. - Return error message if invalid request format for function calling. - """ - - # return if no tools are provided or tool_choice is set to none - if request.tools is None or ( - isinstance(request.tool_choice, str) and request.tool_choice == "none" - ): - conv_template.use_function_calling = False - return None - - # select the tool based on the tool_choice if specified - if isinstance(request.tool_choice, dict): - if request.tool_choice["type"] != "function": - return "Only 'function' tool choice is supported" - - if len(request.tool_choice["function"]) > 1: - return "Only one tool is supported when tool_choice is specified" - - for tool in request.tools: - if tool.function.name == request.tool_choice["function"]["name"]: - conv_template.use_function_calling = True - conv_template.function_string = tool.function.model_dump_json() - return None - - return ( - f"The tool_choice function {request.tool_choice['function']['name']}" - " is not found in the tools list" - ) - - if isinstance(request.tool_choice, str) and request.tool_choice != "auto": - return f"Invalid tool_choice value: {request.tool_choice}" - - function_list = [] - for tool in request.tools: - if tool.type != "function": - return "Only 'function' tool type is supported" - function_list.append(tool.function.model_dump()) - - conv_template.use_function_calling = True - conv_template.function_string = json.dumps(function_list) - return None - - -def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]: - """Convert a (possibly list) of function call string to a list of json objects. 
- Return None for invalid function call string.""" - - def parse_function_call(call_str: str): - node = ast.parse(call_str, mode="eval") - call_node = node.body - if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name): - name = call_node.func.id - arguments = {} - for keyword in call_node.keywords: - arguments[keyword.arg] = ast.literal_eval(keyword.value) - return {"name": name, "arguments": arguments} - return None - - if ( - stringified_calls[0] == "[" and stringified_calls[-1] == "]" - ): # hacky way to check if string list - calls = ast.literal_eval(stringified_calls) - else: - calls = [stringified_calls] - function_calls_json = [parse_function_call(call_str) for call_str in calls] - return function_calls_json - - @app.post("/v1/chat/completions") async def request_chat_completion( request: ChatCompletionRequest, raw_request: fastapi.Request @@ -360,132 +130,30 @@ async def request_chat_completion( server_context: ServerContext = ServerContext.current() async_engine = server_context.get_engine(request.model) if async_engine is None: - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{request.model}" is not served.' ) - request_id = f"chatcmpl-{entrypoint_utils.random_uuid()}" - async_engine.state.record_event(request_id, event="receive request") - - # - Check if the model supports chat conversation. - conv_template = server_context.get_conv_template(request.model) - if conv_template is None: - return entrypoint_utils.create_error_response( - HTTPStatus.BAD_REQUEST, - message=f'The requested model "{request.model}" does not support chat.', - ) - - # - Check if unsupported arguments are specified. - error = entrypoint_utils.check_unsupported_fields(request) - if error is not None: - return error - - # - Process messages and update the conversation template in three steps: - # i. Check the message validity. - # ii. Add the input messages to the conversation template. - # iii. Add the additional message for the assistant. - error_msg = chat_completion_check_message_validity(request.messages) - if error_msg is not None: - return entrypoint_utils.create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) - - # Check for function calling usage and update the conversation template - error_msg = check_function_call_usage(request, conv_template) - if error_msg is not None: - return entrypoint_utils.create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) - - for message in request.messages: - role = message.role - content = message.content - if role == "system": - assert isinstance(content, str) - conv_template.system_message = content if content is not None else "" - continue - - assert role != "tool", "Internal error: tool role." - conv_template.messages.append((role, content)) - conv_template.messages.append(("assistant", None)) - - # - Get the prompt from template, and encode to token ids. 
- # - Check prompt length - async_engine.state.record_event(request_id, event="start tokenization") - - model_config = server_context.get_model_config(request.model) - prompts = entrypoint_utils.process_prompts( - conv_template.as_prompt(model_config), - async_engine.tokenizer.encode, - ) - - async_engine.state.record_event(request_id, event="finish tokenization") - - if conv_template.system_prefix_token_ids is not None: - prompts[0] = conv_template.system_prefix_token_ids + prompts[0] - error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_input_sequence_length) - if error is not None: - return error - - prompt: Sequence[Union[List[int], data.ImageData]] = prompts - - # Process generation config. Create request id. - generation_cfg = protocol_utils.get_generation_config( - request, - extra_stop_token_ids=conv_template.stop_token_ids, - extra_stop_str=conv_template.stop_str, - ) + request_id = f"chatcmpl-{engine_utils.random_uuid()}" # Streaming response. if request.stream: + # We manually get the first response from generator to + # capture potential exceptions in this scope, rather then + # the StreamingResponse scope. + stream_generator = async_engine._handle_chat_completion( # pylint: disable=protected-access + request, request_id + ) + first_response = await anext( # type: ignore # pylint: disable=undefined-variable + stream_generator + ) async def completion_stream_generator() -> AsyncGenerator[str, None]: - async_engine.state.record_event(request_id, event="invoke generate") - finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): - assert len(delta_outputs) == generation_cfg.n - choices = [] - for i, delta_output in enumerate(delta_outputs): - finish_reason_updated = False - if delta_output.finish_reason is not None and finish_reasons[i] is None: - finish_reasons[i] = ( - delta_output.finish_reason - if not conv_template.use_function_calling - else "tool_calls" - ) - finish_reason_updated = True - if not finish_reason_updated and delta_output.delta_text == "": - # Ignore empty delta text when finish reason is not updated. - async_engine.state.record_event(request_id, event="skip empty delta text") - continue - - choices.append( - ChatCompletionStreamResponseChoice( - index=i, - finish_reason=finish_reasons[i], - delta=ChatCompletionMessage( - content=delta_output.delta_text, role="assistant" - ), - logprobs=( - LogProbs( - content=[ - LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in delta_output.delta_logprob_json_strs - ] - ) - if delta_output.delta_logprob_json_strs is not None - else None - ), - ) - ) - - if len(choices) == 0: - # Skip yield when there is no delta output. - continue - response = ChatCompletionStreamResponse( - id=request_id, - choices=choices, - model=request.model, - system_fingerprint="", - ) - async_engine.state.record_event(request_id, event="yield delta output") + if isinstance(first_response, StopAsyncIteration): + yield "data: [DONE]\n\n" + return + yield f"data: {first_response.model_dump_json()}\n\n" + async for response in stream_generator: yield f"data: {response.model_dump_json()}\n\n" - async_engine.state.record_event(request_id, event="finish") yield "data: [DONE]\n\n" return fastapi.responses.StreamingResponse( @@ -493,93 +161,49 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) # Normal response. 
- output_texts = ["" for _ in range(generation_cfg.n)] + num_prompt_tokens = 0 num_completion_tokens = 0 - finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - logprob_json_strs_list: Optional[List[List[str]]] = ( - [[] for _ in range(generation_cfg.n)] if generation_cfg.logprobs else None + output_texts = ["" for _ in range(request.n)] + finish_reasons: List[Optional[str]] = [None for _ in range(request.n)] + logprob_results: Optional[List[List[LogProbsContent]]] = ( + [[] for _ in range(request.n)] if request.logprobs else None ) - async_engine.state.record_event(request_id, event="invoke generate") - async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): + + async for response in async_engine._handle_chat_completion( # pylint: disable=protected-access + request, request_id + ): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified # when the request is disconnected. # Therefore, we check if it is disconnected each time, # and abort the request from engine if so. await async_engine.abort(request_id) - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message="The request has disconnected" ) + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + assert isinstance(choice.delta.content, str) + output_texts[choice.index] += choice.delta.content + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[choice.index] += choice.logprobs.content - assert len(delta_outputs) == generation_cfg.n - for i, delta_output in enumerate(delta_outputs): - if delta_output.finish_reason is not None and finish_reasons[i] is None: - finish_reasons[i] = delta_output.finish_reason - output_texts[i] += delta_output.delta_text - num_completion_tokens += delta_output.num_delta_tokens - if logprob_json_strs_list is not None: - assert delta_output.delta_logprob_json_strs is not None - logprob_json_strs_list[i] += delta_output.delta_logprob_json_strs assert all(finish_reason is not None for finish_reason in finish_reasons) - - async_engine.state.record_event(request_id, event="finish") - - tool_calls_list: List[List[ChatToolCall]] = [[] for _ in range(generation_cfg.n)] - if conv_template.use_function_calling: - for i, output_text in enumerate(output_texts): - try: - fn_json_list = convert_function_str_to_json(output_text) - except (SyntaxError, ValueError): - output_text = "Got an invalid function call output from model" - finish_reasons[i] = "error" - else: - tool_calls_list[i] = [ - ChatToolCall( - type="function", - function=ChatFunctionCall( - name=fn_json_obj["name"], arguments=fn_json_obj["arguments"] - ), - ) - for fn_json_obj in fn_json_list - if fn_json_obj is not None - ] - if len(tool_calls_list[i]) == 0: - output_texts[i] = "Got an invalid function call output from model" - finish_reasons[i] = "error" - else: - finish_reasons[i] = "tool_calls" - - return ChatCompletionResponse( - id=request_id, - choices=[ - ChatCompletionResponseChoice( - index=i, - finish_reason=finish_reasons[i], - message=( - ChatCompletionMessage(role="assistant", content=output_text) - if (not conv_template.use_function_calling or finish_reason == "error") - else ChatCompletionMessage(role="assistant", 
tool_calls=tool_calls) - ), - logprobs=( - LogProbs( - content=[ - LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in logprob_json_strs_list[ # pylint: disable=unsubscriptable-object - i - ] - ] - ) - if logprob_json_strs_list is not None - else None - ), - ) - for i, (output_text, finish_reason, tool_calls) in enumerate( - zip(output_texts, finish_reasons, tool_calls_list) - ) - ], + use_function_calling, tool_calls_list = engine_base.process_function_call_output( + output_texts, finish_reasons + ) + return engine_base.wrap_chat_completion_response( + request_id=request_id, model=request.model, - system_fingerprint="", - usage=UsageInfo( - prompt_tokens=sum(len(item) for item in prompt), completion_tokens=num_completion_tokens - ), + output_texts=output_texts, + finish_reasons=finish_reasons, + tool_calls_list=tool_calls_list, + logprob_results=logprob_results, + use_function_calling=use_function_calling, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, ) diff --git a/python/mlc_llm/serve/event_trace_recorder.py b/python/mlc_llm/serve/event_trace_recorder.py index 7a8a8177fe..457918d598 100644 --- a/python/mlc_llm/serve/event_trace_recorder.py +++ b/python/mlc_llm/serve/event_trace_recorder.py @@ -17,7 +17,7 @@ def __init__(self) -> None: ) def add_event(self, request_id: str, event: str) -> None: - """Record a event for the the input request in the trace recorder. + """Record a event for the input request in the trace recorder. Parameters ---------- diff --git a/python/mlc_llm/serve/grammar.py b/python/mlc_llm/serve/grammar.py index d5ad862a42..cf491884c2 100644 --- a/python/mlc_llm/serve/grammar.py +++ b/python/mlc_llm/serve/grammar.py @@ -247,7 +247,7 @@ def __init__( self.__init_handle_by_constructor__( _ffi_api.GrammarStateMatcherFromTokenTable, # type: ignore # pylint: disable=no-member grammar, - *tokenizer, + tokenizer, max_rollback_steps, ) else: diff --git a/python/mlc_llm/serve/radix_tree.py b/python/mlc_llm/serve/radix_tree.py new file mode 100644 index 0000000000..102cdac675 --- /dev/null +++ b/python/mlc_llm/serve/radix_tree.py @@ -0,0 +1,150 @@ +"""The Paged Radix Tree class.""" + +from typing import List, Tuple, Union + +import tvm +import tvm._ffi +from tvm.runtime import Object, ShapeTuple + +from . import _ffi_api + + +@tvm._ffi.register_object("mlc.serve.PagedRadixTree") # pylint: disable=protected-access +class PagedRadixTree(Object): + """The paged radix tree to manage prefix and sequence.""" + + def __init__(self, num_pages: int, page_size: int, num_seqs: int): + """ + Constructor of paged radix tree. + + Parameters + ---------- + num_pages : int + The number of radix tree pages. + page_size : int + The page size of each radix tree page. + num_seqs : int + The maximum number of sequence ID. + """ + self.__init_handle_by_constructor__(_ffi_api.PagedRadixTree, num_pages, page_size, num_seqs) # type: ignore # pylint: disable=no-member + + def match(self, tokens: Union[ShapeTuple, List, Tuple]) -> Tuple[int, ShapeTuple]: + """ + Get all sequences with longest common prefix with given prefix tokens. + + Parameters + ---------- + tokens : Union[ShapeTuple, List, Tuple] + The prefix tokens for reference. + + Returns + ------ + matched_offset : int + The matched prefix length. + seq_ids : ShapeTuple + The array of matched sequence indice. 
+ """ + if isinstance(tokens, (list, tuple)): + tokens = ShapeTuple(tokens) + output = _ffi_api.PagedRadixTreeMatchPrefix(self, tokens) # type: ignore # pylint: disable=no-member + if len(output) == 1: + return output[0], [] + return output[0], output[1:] + + def add(self, seq_id: int) -> None: + """ + Get all sequences with longest common prefix with give prefix tokens. + + Parameters + ---------- + seq_id : int + The sequence ID for index. + """ + _ffi_api.PagedRadixTreeAddSequence(self, seq_id) # type: ignore # pylint: disable=no-member + + def remove(self, seq_id: int) -> None: + """ + Remove a sequence. + + Parameters + ---------- + seq_id : int + The sequence ID to remove. + """ + _ffi_api.PagedRadixTreeRemoveSequence(self, seq_id) # type: ignore # pylint: disable=no-member + + def extend(self, seq_id: int, tokens: Union[ShapeTuple, List, Tuple]) -> None: + """ + Get all sequences with longest common prefix with give prefix tokens. + + Parameters + ---------- + seq_id : int + The sequence ID for index. + tokens : Union[ShapeTuple, List, Tuple] + The given tokens to extend. + """ + if isinstance(tokens, (list, tuple)): + tokens = ShapeTuple(tokens) + _ffi_api.PagedRadixTreeExtendSequence(self, seq_id, tokens) # type: ignore # pylint: disable=no-member + + def fork(self, seq_id: int, parent_seq_id: int, forked_offset: int) -> None: + """ + Fork a sequence from parent sequence at given position. + + Parameters + ---------- + seq_id : int + The new sequence ID. + parent_seq_id : int + The parent sequence ID to fork from. + forked_offset : int + The position of parent sequence to fork at. + The valid value is [1, length of forked sequence]. + If the position equals the length of forked sequence, + the new sequence will copy the entire forked sequence. + """ + _ffi_api.PagedRadixTreeForkSequence(self, seq_id, parent_seq_id, forked_offset) # type: ignore # pylint: disable=no-member + + def get(self, seq_id: int) -> ShapeTuple: + """ + Get a sequence's all tokens. + + Parameters + ---------- + seq_id : int + The sequence ID for index. + + Returns + ------ + tokens : ShapeTuple + The sequence tokens. + """ + return _ffi_api.PagedRadixTreeGetSequence(self, seq_id) # type: ignore # pylint: disable=no-member + + def get_length(self, seq_id: int) -> int: + """ + Get a sequence's length. + + Parameters + ---------- + seq_id : int + The sequence ID for index. + + Returns + ------ + length : int + The sequence length. + """ + return _ffi_api.PagedRadixTreeGetSequenceLength(self, seq_id) # type: ignore # pylint: disable=no-member + + def free_capacity(self) -> int: + """ + Get the remaining token capacity of the paged radix tree. + + Returns + ------ + capacity : int + The remaining token capacity of the paged radix tree. 
+ """ + return _ffi_api.PagedRadixTreeFreeCapacity(self) # type: ignore # pylint: disable=no-member diff --git a/python/mlc_llm/serve/server/server_context.py b/python/mlc_llm/serve/server/server_context.py index 0a9a1b0b1f..d6acd4a2be 100644 --- a/python/mlc_llm/serve/server/server_context.py +++ b/python/mlc_llm/serve/server/server_context.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional -from ..engine import AsyncLLMEngine +from ..engine import AsyncMLCEngine class ServerContext: @@ -13,7 +13,7 @@ class ServerContext: server_context: Optional["ServerContext"] = None def __init__(self): - self._models: Dict[str, AsyncLLMEngine] = {} + self._models: Dict[str, AsyncMLCEngine] = {} def __enter__(self): if ServerContext.server_context is not None: @@ -31,14 +31,17 @@ def current(): """Returns the current ServerContext.""" return ServerContext.server_context - def add_model(self, hosted_model: str, engine: AsyncLLMEngine) -> None: + def add_model(self, hosted_model: str, engine: AsyncMLCEngine) -> None: """Add a new model to the server context together with the engine.""" if hosted_model in self._models: raise RuntimeError(f"Model {hosted_model} already running.") self._models[hosted_model] = engine - def get_engine(self, model: str) -> Optional[AsyncLLMEngine]: - """Get the async engine of the requested model.""" + def get_engine(self, model: Optional[str]) -> Optional[AsyncMLCEngine]: + """Get the async engine of the requested model, or the unique async engine + if only one engine is served.""" + if len(self._models) == 1: + return next(iter(self._models.values())) return self._models.get(model, None) def get_model_list(self) -> List[str]: diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 23b151d5c7..1be841cb08 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -41,7 +41,7 @@ def _create_tvm_module( return {key: module[key] for key in ffi_funcs} -class SyncLLMEngine: +class SyncMLCEngine: """The Python interface of synchronize request serving engine for MLC LLM. The engine receives requests from the "add_request" method. 
For @@ -98,6 +98,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, enable_tracing: bool = False, speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, @@ -128,11 +129,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length, prefill_chunk_size, max_single_sequence_length, + max_history_size, + kv_state_kind, ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, prefill_chunk_size, + max_history_size, gpu_memory_utilization, models, device, @@ -162,15 +166,17 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model_lib_path=model_args[0][1], additional_models=[model_arg[0] for model_arg in model_args[1:]], additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, kv_cache_page_size=16, max_num_sequence=max_batch_size, max_total_sequence_length=max_total_sequence_length, max_single_sequence_length=max_single_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ), + device, request_stream_callback, self.trace_recorder, ) diff --git a/python/mlc_llm/support/auto_config.py b/python/mlc_llm/support/auto_config.py index f0247a6ef9..be0ee8af98 100644 --- a/python/mlc_llm/support/auto_config.py +++ b/python/mlc_llm/support/auto_config.py @@ -62,7 +62,7 @@ def detect_mlc_chat_config(mlc_chat_config: str) -> Path: # search mlc-chat-config.json under path mlc_chat_config_json_path = mlc_chat_config_path / "mlc-chat-config.json" if not mlc_chat_config_json_path.exists(): - raise ValueError(f"Fail to find mlc_chat_config.json under {mlc_chat_config_path}.") + raise ValueError(f"Fail to find mlc-chat-config.json under {mlc_chat_config_path}.") else: mlc_chat_config_json_path = mlc_chat_config_path diff --git a/python/mlc_llm/support/auto_device.py b/python/mlc_llm/support/auto_device.py index cf6d09495a..bddb9954c6 100644 --- a/python/mlc_llm/support/auto_device.py +++ b/python/mlc_llm/support/auto_device.py @@ -1,4 +1,6 @@ """Automatic detection of the device available on the local machine.""" + +import os import subprocess import sys from typing import Dict, Optional @@ -65,6 +67,7 @@ def _device_exists(device: Device) -> bool: capture_output=True, text=True, check=False, + env=os.environ, ) .stdout.strip() .splitlines() diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index 5c61af6f07..4c32feb6ff 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -295,12 +295,18 @@ def build(mod: IRModule, args: "CompileArgs", pipeline=None): def detect_cuda_arch_list(target: Target) -> List[str]: """Detect the CUDA architecture list from the target.""" + + def convert_to_num(arch_str): + arch_num_str = "".join(filter(str.isdigit, arch_str)) + assert arch_num_str, f"'{arch_str}' does not contain any digits" + return int(arch_num_str) + assert target.kind.name == "cuda", f"Expect target to be CUDA, but got {target}" if MLC_MULTI_ARCH is not None: - multi_arch = [x.strip() for x in MLC_MULTI_ARCH.split(",")] + multi_arch = [convert_to_num(x) for x in MLC_MULTI_ARCH.split(",")] else: assert target.arch.startswith("sm_") - multi_arch = 
[target.arch[3:]] + multi_arch = [convert_to_num(target.arch[3:])] multi_arch = list(set(multi_arch)) return multi_arch diff --git a/python/mlc_llm/support/download.py b/python/mlc_llm/support/download.py index a109c967bc..770833e9af 100644 --- a/python/mlc_llm/support/download.py +++ b/python/mlc_llm/support/download.py @@ -36,11 +36,13 @@ def git_clone(url: str, destination: Path, ignore_lfs: bool) -> None: command = ["git", "clone", url, repo_name] _ensure_directory_not_exist(destination, force_redo=False) try: + env = os.environ.copy() + env["GIT_LFS_SKIP_SMUDGE"] = "1" with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir: logger.info("[Git] Cloning %s to %s", bold(url), destination) subprocess.run( command, - env={"GIT_LFS_SKIP_SMUDGE": "1"}, + env=env, cwd=tmp_dir, check=True, stdout=subprocess.DEVNULL, diff --git a/python/mlc_llm/support/max_thread_check.py b/python/mlc_llm/support/max_thread_check.py index 6c078c3bbf..6711fb5c55 100644 --- a/python/mlc_llm/support/max_thread_check.py +++ b/python/mlc_llm/support/max_thread_check.py @@ -3,7 +3,7 @@ from tvm.target import Target -def get_max_num_threads_per_block(target: Target): +def get_max_num_threads_per_block(target: Target) -> int: """ max(max_num_threads, max_threads_per_block); if latter does not exist, return max_num_threads. We add this method since some targets have both fields and `max_threads_per_block` is larger. diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py index 2a70154bba..4f1cfe103d 100644 --- a/python/mlc_llm/testing/debug_chat.py +++ b/python/mlc_llm/testing/debug_chat.py @@ -118,7 +118,7 @@ def __call__(self, func, name, before_run, ret_val, *args): print(f"{red(f'{func_name} has INF')}: {num_infs}") self.first_inf_occurred = True - # Save the the arguments to npz + # Save the arguments to npz arg_dict = {} for i, arg in enumerate(args): if isinstance(arg, tvm.nd.NDArray): diff --git a/scripts/build_mlc_for_docs.sh b/scripts/build_mlc_for_docs.sh new file mode 100755 index 0000000000..50eee3231a --- /dev/null +++ b/scripts/build_mlc_for_docs.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -euxo pipefail + +mkdir -p build +cd build +cmake .. +make -j$(nproc) +cd - diff --git a/scripts/build_site.sh b/scripts/build_site.sh index 6340ee838e..062f8094de 100755 --- a/scripts/build_site.sh +++ b/scripts/build_site.sh @@ -1,6 +1,7 @@ #!/bin/bash set -euxo pipefail +export PYTHONPATH=$PWD/python cd docs && make html && cd .. cd site && jekyll b && cd .. diff --git a/scripts/gh_deploy_site.sh b/scripts/gh_deploy_site.sh index 1b21c52d16..326c280484 100755 --- a/scripts/gh_deploy_site.sh +++ b/scripts/gh_deploy_site.sh @@ -4,6 +4,7 @@ set -euxo pipefail +scripts/build_mlc_for_docs.sh scripts/build_site.sh git fetch diff --git a/site/index.md b/site/index.md index 44befd4abc..ac0367cdb2 100644 --- a/site/index.md +++ b/site/index.md @@ -6,62 +6,41 @@ notitle: true # MLC LLM -MLC LLM is a universal solution that allows any language model to be deployed natively on a diverse set of hardware backends and native applications. - -Please visit [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) for detailed instructions. - -## Demos - -- [iOS](#ios) -- [Android](#android) -- [Windows Linux Mac](#windows-linux-mac) -- [Web browser](#web-browser) - -### iOS - -Our iOS app, MLCChat, is available on [App Store](https://apps.apple.com/us/app/mlc-chat/id6448482937) for iPhone and iPad. 
-You can try out the [Testflight app](https://testflight.apple.com/join/57zd7oxa) that sometimes contains beta release of latest models.
-This app is tested on iPhone 15 Pro Max, iPhone 14 Pro Max, iPhone 14 Pro and iPhone 12 Pro.
-Besides the [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) page,
-[documentation](https://llm.mlc.ai/docs/deploy/ios.html) is available for building iOS apps with MLC LLM.
+Documentation: [https://llm.mlc.ai/docs](https://llm.mlc.ai/docs)
+**M**achine **L**earning **C**ompilation for **L**arge **L**anguage **M**odels (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language model with native APIs and compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques.
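+
+As a rough illustration of these native APIs, a minimal Python sketch is shown below. It assumes the `MLCEngine` class and the `HF://` model shorthand described in the documentation linked above, and uses an example pre-quantized model id for demonstration; refer to the documentation for the authoritative usage.
+
+```python
+from mlc_llm import MLCEngine
+
+# Example model id for illustration (a 4-bit quantized Llama-3 8B build).
+model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"
+engine = MLCEngine(model)
+
+# OpenAI-style chat completion, streamed back chunk by chunk.
+for response in engine.chat.completions.create(
+    messages=[{"role": "user", "content": "What is the meaning of life?"}],
+    model=model,
+    stream=True,
+):
+    for choice in response.choices:
+        print(choice.delta.content, end="", flush=True)
+print()
+
+engine.terminate()
+```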


-Note: Llama-7B takes 4GB of RAM and RedPajama-3B takes 2.2GB to run. We recommend a latest device with 6GB RAM for Llama-7B, or 4GB RAM for RedPajama-3B, to run the app. The text generation speed could vary from time to time, for example, slow in the beginning but recover to a normal speed then. +## Installation -### Android +MLC LLM is available via [pip](https://llm.mlc.ai/docs/install/mlc_llm.html#install-mlc-packages). +It is always recommended to install it in an isolated conda virtual environment. -The demo APK is available to [download](https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android/mlc-chat.apk). The demo is tested on Samsung S23 with Snapdragon 8 Gen 2 chip, Redmi Note 12 Pro with Snapdragon 685 and Google Pixel phones. -Besides the [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) page, -[documentation](https://llm.mlc.ai/docs/deploy/android.html) is available for building android apps with MLC LLM. +To verify the installation, activate your virtual environment, run -


+```bash +python -c "import mlc_llm; print(mlc_llm.__path__)" +``` -### Windows Linux Mac +You are expected to see the installation path of MLC LLM Python package. -Our cpp interface runs on AMD, Intel, Apple and NVIDIA GPUs. -Besides the [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) page, -[documentation](https://llm.mlc.ai/docs/deploy/cli.html) is available for building C++ apps with MLC LLM. +## Quick Start -


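+As a flavor of what the quick start covers, the sketch below shows one way to query the OpenAI-compatible REST server from Python. This is illustrative only: it assumes a server already launched with the `mlc_llm serve` command and listening on the default local address, and reuses the example model id from above; the linked guide below is the authoritative reference.
+
+```python
+import requests
+
+# Assumed local server address; adjust to wherever `mlc_llm serve` is listening.
+url = "http://127.0.0.1:8000/v1/chat/completions"
+payload = {
+    "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC",
+    "messages": [{"role": "user", "content": "Write a haiku about compilers."}],
+    "stream": False,
+}
+response = requests.post(url, json=payload, timeout=120)
+print(response.json()["choices"][0]["message"]["content"])
+```
+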
+Please check out our documentation for the [quick start](https://llm.mlc.ai/docs/get_started/quick_start.html). -### Web Browser +## Introduction -[WebLLM](https://webllm.mlc.ai/) is our companion project that deploys MLC LLM natively to browsers using WebGPU and WebAssembly. Still everything runs inside the browser without server resources, and accelerated by local GPUs (e.g. AMD, Intel, Apple or NVIDIA). +Please check out our documentation for the [introduction](https://llm.mlc.ai/docs/get_started/introduction.html). ## Links -* Our official [GitHub repo](https://github.com/mlc-ai/mlc-llm); -* Our companion project [WebLLM](https://webllm.mlc.ai/) that enables running LLMs purely in browser. -* [Web Stable Diffusion](https://websd.mlc.ai/) is another MLC-series that runs the diffusion models purely in the browser. -* [Machine Learning Compilation course](https://mlc.ai) is available for a systematic walkthrough of our approach to universal deployment. +- You might want to check out our online public [Machine Learning Compilation course](https://mlc.ai) for a systematic +walkthrough of our approaches. +- [WebLLM](https://webllm.mlc.ai/) is a companion project using MLC LLM's WebGPU and WebAssembly backend. +- [WebStableDiffusion](https://websd.mlc.ai/) is a companion project for diffusion models with the WebGPU backend. ## Disclaimer diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index b86fd423a9..c52571b522 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -1,25 +1,8 @@ -# pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable -import json -import queue -import threading from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union -import tvm +from mlc_llm.json_ffi import JSONFFIEngine -from mlc_llm.protocol import openai_api_protocol -from mlc_llm.serve import engine_utils -from mlc_llm.serve.engine_base import ( - EngineConfig, - SpeculativeMode, - _infer_kv_cache_config, - _parse_models, - _process_model_args, - detect_device, -) -from mlc_llm.tokenizer import Tokenizer - -prompts = [ +chat_completion_prompts = [ "What is the meaning of life?", "Introduce the history of Pittsburgh to me. Please elaborate in detail.", "Write a three-day Seattle travel plan. Please elaborate in detail.", @@ -32,227 +15,40 @@ "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", ] +function_calling_prompts = [ + "What is the temperature in Pittsburgh, PA?", + "What is the temperature in Tokyo, JP?", + "What is the temperature in Pittsburgh, PA and Tokyo, JP?", +] -class EngineState: - sync_queue: queue.Queue - - def get_request_stream_callback(self) -> Callable[[List[str]], None]: - # ChatCompletionStreamResponse - - def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: - self._sync_request_stream_callback(chat_completion_stream_responses_json_str) - - return _callback - - def _sync_request_stream_callback( - self, chat_completion_stream_responses_json_str: List[str] - ) -> None: - # Put the delta outputs to the queue in the unblocking way. 
- self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) - - -class JSONFFIEngine: - def __init__( # pylint: disable=too-many-arguments,too-many-locals - self, - model: str, - device: Union[str, tvm.runtime.Device] = "auto", - *, - model_lib_path: Optional[str] = None, - mode: Literal["local", "interactive", "server"] = "local", - additional_models: Optional[List[str]] = None, - max_batch_size: Optional[int] = None, - max_total_sequence_length: Optional[int] = None, - prefill_chunk_size: Optional[int] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, - spec_draft_length: int = 4, - gpu_memory_utilization: Optional[float] = None, - ) -> None: - # - Initialize model loading info. - models = _parse_models(model, model_lib_path, additional_models) - if isinstance(device, str): - device = detect_device(device) - assert isinstance(device, tvm.runtime.Device) - ( - model_args, - model_config_paths, - self.conv_template, - ) = _process_model_args(models, device) - - # - Load the raw model config into dict - self.model_config_dicts = [] - for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] - with open(model_config_paths[i], "r", encoding="utf-8") as file: - self.model_config_dicts.append(json.load(file)) - - # - Decide the KV cache config based on mode and user input. - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) - - # - Initialize engine state and engine. - self.state = EngineState() - module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() - self._ffi = { - key: module[key] - for key in [ - "init_background_engine", - "chat_completion", - "abort", - "get_last_error", - "run_background_loop", - "run_background_stream_back_loop", - "exit_background_loop", - ] - } - self.tokenizer = Tokenizer(model_args[0][0]) - - def _background_loop(): - self._ffi["init_background_engine"]( - EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ), - self.state.get_request_stream_callback(), - None, - ) - self._ffi["run_background_loop"]() - - def _background_stream_back_loop(): - self._ffi["run_background_stream_back_loop"]() - - # Create the background engine-driving thread and start the loop. 
- self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) - self._background_stream_back_loop_thread: threading.Thread = threading.Thread( - target=_background_stream_back_loop - ) - self._background_loop_thread.start() - self._background_stream_back_loop_thread.start() - self._terminated = False - - def terminate(self): - self._terminated = True - self._ffi["exit_background_loop"]() - self._background_loop_thread.join() - self._background_stream_back_loop_thread.join() - - def chat_completion( # pylint: disable=too-many-arguments,too-many-locals - self, - *, - messages: List[Dict[str, Any]], - model: str, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - logprobs: bool = False, - top_logprobs: int = 0, - logit_bias: Optional[Dict[int, float]] = None, - max_tokens: Optional[int] = None, - n: int = 1, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, - tools: Optional[List[Dict[str, Any]]] = None, - tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, - user: Optional[str] = None, - ignore_eos: bool = False, - response_format: Optional[Dict[str, Any]] = None, - request_id: Optional[str] = None, - ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: - if request_id is None: - request_id = f"chatcmpl-{engine_utils.random_uuid()}" - - chatcmpl_generator = self._handle_chat_completion( - openai_api_protocol.ChatCompletionRequest( - messages=[ - openai_api_protocol.ChatCompletionMessage.model_validate(message) - for message in messages - ], - model=model, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - logprobs=logprobs, - top_logprobs=top_logprobs, - logit_bias=logit_bias, - max_tokens=max_tokens, - n=n, - seed=seed, - stop=stop, - stream=stream, - temperature=temperature, - top_p=top_p, - tools=( - [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] - if tools is not None - else None - ), - tool_choice=tool_choice, - user=user, - ignore_eos=ignore_eos, - response_format=( - openai_api_protocol.RequestResponseFormat.model_validate(response_format) - if response_format is not None - else None - ), - ).model_dump_json(), - n=n, - request_id=request_id, - ) - for response in chatcmpl_generator: - yield response - - def _handle_chat_completion( - self, request_json_str: str, n: int, request_id: str - ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: - self.state.sync_queue = queue.Queue() - num_unfinished_requests = n - - success = bool(self._ffi["chat_completion"](request_json_str, request_id)) - - try: - while num_unfinished_requests > 0: - chat_completion_stream_responses_json_str = self.state.sync_queue.get() - for chat_completion_response_json_str in chat_completion_stream_responses_json_str: - chat_completion_response = ( - openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( - chat_completion_response_json_str - ) - ) - for choice in chat_completion_response.choices: - if choice.finish_reason is not None: - num_unfinished_requests -= 1 - yield chat_completion_response - except Exception as exception: # pylint: disable=broad-exception-caught - self._ffi["abort"](request_id) - raise exception +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + 
"description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } +] -def test_chat_completion(engine: JSONFFIEngine): +def run_chat_completion( + engine: JSONFFIEngine, + model: str, + prompts: List[str] = chat_completion_prompts, + tools: Optional[List[Dict]] = None, +): num_requests = 2 max_tokens = 64 n = 1 @@ -266,6 +62,7 @@ def test_chat_completion(engine: JSONFFIEngine): max_tokens=max_tokens, n=n, request_id=str(rid), + tools=tools, ): for choice in response.choices: assert choice.delta.role == "assistant" @@ -284,24 +81,61 @@ def test_chat_completion(engine: JSONFFIEngine): print(f"Output {req_id}({i}):{output}\n") -def test_malformed_request(engine: JSONFFIEngine): +def test_chat_completion(): + # Create engine. + model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + engine = JSONFFIEngine( + model, + max_total_sequence_length=1024, + ) + + run_chat_completion(engine, model) + + # Test malformed requests. for response in engine._handle_chat_completion("malformed_string", n=1, request_id="123"): assert len(response.choices) == 1 assert response.choices[0].finish_reason == "error" + engine.terminate() -if __name__ == "__main__": + +def test_reload_reset_unload(): # Create engine. - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + engine = JSONFFIEngine( + model, + max_total_sequence_length=1024, + ) + + # Run chat completion before and after reload/reset. + run_chat_completion(engine, model) + engine._test_reload() + run_chat_completion(engine, model) + engine._test_reset() + run_chat_completion(engine, model) + engine._test_unload() + + engine.terminate() + + +def test_function_calling(): + model = "dist/gorilla-openfunctions-v1-q4f16_1-MLC" + model_lib_path = ( + "dist/gorilla-openfunctions-v1-q4f16_1-MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so" + ) engine = JSONFFIEngine( model, model_lib_path=model_lib_path, max_total_sequence_length=1024, ) - test_chat_completion(engine) - test_malformed_request(engine) + # run function calling + run_chat_completion(engine, model, function_calling_prompts, tools) engine.terminate() - del engine + + +if __name__ == "__main__": + test_chat_completion() + test_reload_reset_unload() + test_function_calling() diff --git a/tests/python/op/test_batch_spec_verify.py b/tests/python/op/test_batch_spec_verify.py new file mode 100644 index 0000000000..f35a39d71e --- /dev/null +++ b/tests/python/op/test_batch_spec_verify.py @@ -0,0 +1,160 @@ +import numpy as np +import pytest +import tvm +import tvm.testing + +from mlc_llm.op.batch_spec_verify import batch_spec_verify + + +@pytest.mark.parametrize("nbatch", [32, 64]) +@pytest.mark.parametrize("vocab", [3, 32, 64, 32000, 33, 65, 32001, 128000]) +@pytest.mark.parametrize("plist", [[0.5, 0.5], [1, 0], [0, 1]]) +def test_batch_spec_verify(nbatch, vocab, plist): + def numpy_reference( + draft_probs, + draft_tokens, + model_probs, + token_tree_first_child, + token_tree_next_sibling, + uniform_samples, + token_tree_parent_ptr, + ): + nbatch = token_tree_parent_ptr.shape[0] + for b in range(nbatch): + parent_ptr = token_tree_parent_ptr[b] + child_ptr = token_tree_first_child[parent_ptr] + while child_ptr != -1: + child_token = draft_tokens[child_ptr] + p_child = model_probs[parent_ptr, child_token] + q_child = draft_probs[child_ptr, child_token] + uniform_sample = 
uniform_samples[child_ptr] + if p_child / q_child >= uniform_sample: + parent_ptr = child_ptr + child_ptr = token_tree_first_child[child_ptr] + else: + model_probs[parent_ptr, :] = np.maximum( + model_probs[parent_ptr, :] - draft_probs[child_ptr, :], 0.0 + ) + psum = np.sum(model_probs[parent_ptr, :]) + model_probs[parent_ptr, :] /= psum + child_ptr = token_tree_next_sibling[child_ptr] + token_tree_parent_ptr[b] = parent_ptr + + np.random.seed(0) + + def gen_chain(num_nodes, base): + token_tree_first_child = list() + token_tree_next_sibling = list() + for i in range(num_nodes): + token_tree_first_child.append(base + i + 1 if i + 1 < num_nodes else -1) + token_tree_next_sibling.append(-1) + return token_tree_first_child, token_tree_next_sibling, base, base + 1 + + def gen_full_binary_tree(height, base): + token_tree_first_child = list() + token_tree_next_sibling = list() + num_nodes = 2**height - 1 + for i in range(num_nodes): + token_tree_first_child.append(base + i * 2 + 1 if i * 2 + 1 < num_nodes else -1) + token_tree_next_sibling.append(base + i * 2 + 2 if i * 2 + 2 < num_nodes else -1) + return token_tree_first_child, token_tree_next_sibling, base, base + 1 + + ### Inputs + num_nodes = 0 + token_tree_first_child = list() + token_tree_next_sibling = list() + token_tree_parent_ptr = list() + + for _ in range(nbatch): + choice = np.random.choice(2, 1, p=plist) + if choice == 0: + nodes_batch = np.random.randint(3, 32) + res = gen_chain(nodes_batch, num_nodes) + num_nodes += nodes_batch + else: + height = np.random.randint(3, 5) + res = gen_full_binary_tree(height, num_nodes) + num_nodes += 2**height - 1 + token_tree_first_child.extend(res[0]) + token_tree_next_sibling.extend(res[1]) + token_tree_parent_ptr.append(res[2]) + + token_tree_first_child = np.array(token_tree_first_child).astype("int32") + token_tree_next_sibling = np.array(token_tree_next_sibling).astype("int32") + token_tree_parent_ptr = np.array(token_tree_parent_ptr).astype("int32") + + draft_probs = np.random.rand(num_nodes, vocab).astype("float32") + draft_probs /= np.sum(draft_probs, axis=1, keepdims=True) + draft_tokens = np.random.randint(0, vocab, num_nodes).astype("int32") + model_probs = np.random.rand(num_nodes, vocab).astype("float32") + model_probs /= np.sum(model_probs, axis=1, keepdims=True) + uniform_samples = np.random.rand(num_nodes).astype("float32") + + ### TVM Inputs + dev = tvm.cuda(0) + draft_probs_tvm = tvm.nd.array(draft_probs, dev) + draft_tokens_tvm = tvm.nd.array(draft_tokens, dev) + model_probs_tvm = tvm.nd.array(model_probs, dev) + token_tree_first_child_tvm = tvm.nd.array(token_tree_first_child, dev) + token_tree_next_sibling_tvm = tvm.nd.array(token_tree_next_sibling, dev) + uniform_samples_tvm = tvm.nd.array(uniform_samples, dev) + token_tree_parent_ptr_tvm = tvm.nd.array(token_tree_parent_ptr, dev) + + # print("draft_probs", draft_probs) + # print("draft_tokens", draft_tokens) + # print("model_probs", model_probs) + # print("token_tree_first_child", token_tree_first_child) + # print("token_tree_next_sibling", token_tree_next_sibling) + # print("uniform_samples", uniform_samples) + # print("token_tree_parent_ptr", token_tree_parent_ptr) + + ### Numpy reference + numpy_reference( + draft_probs, + draft_tokens, + model_probs, + token_tree_first_child, + token_tree_next_sibling, + uniform_samples, + token_tree_parent_ptr, + ) + # print("model_probs", model_probs) + # print("token_tree_parent_ptr", token_tree_parent_ptr) + + ### TVM + kernel = batch_spec_verify(vocab) + mod = tvm.build(kernel, 
target="cuda") + mod( + draft_probs_tvm, + draft_tokens_tvm, + model_probs_tvm, + token_tree_first_child_tvm, + token_tree_next_sibling_tvm, + uniform_samples_tvm, + token_tree_parent_ptr_tvm, + ) + # print("model_probs", model_probs_tvm.asnumpy()) + # print("token_tree_parent_ptr", token_tree_parent_ptr_tvm.asnumpy()) + + tvm.testing.assert_allclose(model_probs, model_probs_tvm.asnumpy()) + tvm.testing.assert_allclose( + token_tree_parent_ptr, token_tree_parent_ptr_tvm.asnumpy(), rtol=0, atol=0 + ) + + time_evaluator = mod.time_evaluator(mod.entry_name, dev, number=10, repeat=3) + print(f"batch_size: {nbatch}, vocab_size: {vocab}, tree_structure: {plist}") + print( + time_evaluator( + draft_probs_tvm, + draft_tokens_tvm, + model_probs_tvm, + token_tree_first_child_tvm, + token_tree_next_sibling_tvm, + uniform_samples_tvm, + token_tree_parent_ptr_tvm, + ) + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/op/test_top_p_pivot.py b/tests/python/op/test_top_p_pivot.py new file mode 100644 index 0000000000..7cfeb60e9c --- /dev/null +++ b/tests/python/op/test_top_p_pivot.py @@ -0,0 +1,83 @@ +import numpy as np +import pytest +import tvm +import tvm.testing + +from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm + +# mypy: disable-error-code="var-annotated" + + +@pytest.mark.parametrize("batch_size", [32, 64]) +@pytest.mark.parametrize("vocab", [3, 32, 64, 128]) +def test_top_p_renorm(batch_size, vocab): + top_p = 0.95 + init_pivots_np = np.array([1 - top_p, 0.02, 0.01]).astype(np.float32) + top_p_np = np.array([top_p]).astype(np.float32) + + p_np = np.random.exponential(3, size=(batch_size, vocab)).astype(np.float32) + p_np /= np.sum(p_np, axis=-1, keepdims=True) + final_pivot_np = np.zeros(batch_size).astype(np.float32) + final_lsum_np = np.zeros(batch_size).astype(np.float32) + + dev = tvm.cuda(0) + var_prob = tvm.nd.array(p_np, dev) + var_init_pivots = tvm.nd.array(init_pivots_np, dev) + top_p_global = tvm.nd.array(top_p_np, dev) + var_final_pivot = tvm.nd.array(final_pivot_np, dev) + var_final_lsum = tvm.nd.array(final_lsum_np, dev) + + kernel = top_p_pivot(init_pivots_np.shape[0]) + mod = tvm.build(kernel, target="cuda") + mod(var_prob, top_p_global, var_init_pivots, var_final_pivot, var_final_lsum) + + final_pivot = var_final_pivot.asnumpy() + final_lsum = var_final_lsum.asnumpy() + + renorm_np = p_np.copy() + var_renorm = tvm.nd.array(renorm_np, dev) + + kernel_renorm = top_p_renorm() + mod_renorm = tvm.build(kernel_renorm, target="cuda") + mod_renorm(var_prob, var_final_pivot, var_final_lsum, var_renorm) + + renorm = var_renorm.asnumpy() + + def verify_pivot(probs: np.ndarray, pivot: float, lsum: float, renorm: np.ndarray): + sorted_probs = np.sort(probs, axis=-1)[::-1] + num_larger_than_pivot = np.sum(sorted_probs >= pivot) + filtered_sorted_probs = sorted_probs[:num_larger_than_pivot] + min_larger_than_pivot = min(filtered_sorted_probs) + + sum_larger_than_pivot = np.sum(np.where(sorted_probs >= pivot, sorted_probs, 0)) + sum_larger_than_pivot_exclude_min = np.sum( + np.where(filtered_sorted_probs != min_larger_than_pivot, filtered_sorted_probs, 0) + ) + + probs[probs < pivot] = 0 + renorm_prob = probs / np.sum(probs, axis=-1, keepdims=True) + try: + assert sum_larger_than_pivot >= top_p + assert sum_larger_than_pivot_exclude_min < top_p + assert abs(lsum - sum_larger_than_pivot) < 1e-6 + assert np.allclose(renorm, renorm_prob, atol=1e-6, rtol=1e-6) + except AssertionError: + print("Failed") + print("probs:", repr(probs)) + print("pivot:", 
pivot) + print("sorted_probs:", sorted_probs) + print("num_larger_than_pivot:", num_larger_than_pivot) + print("filtered_sorted_probs:", filtered_sorted_probs) + print("min_larger_than_pivot:", min_larger_than_pivot) + print("sum_larger_than_pivot:", sum_larger_than_pivot) + print("sum_larger_than_pivot_exclude_min:", sum_larger_than_pivot_exclude_min) + print("renom_prob:", renorm_prob) + print("renorm:", renorm) + raise + + for i in range(batch_size): + verify_pivot(p_np[i], final_pivot[i], final_lsum[i], renorm[i]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/op/test_two_stage_softmax.py b/tests/python/op/test_two_stage_softmax.py new file mode 100644 index 0000000000..1d3d55d8e3 --- /dev/null +++ b/tests/python/op/test_two_stage_softmax.py @@ -0,0 +1,47 @@ +import numpy as np +import scipy.special +import tvm +from tvm import dlight + +from mlc_llm.compiler_pass.rewrite_softmax import _get_lse_and_softmax_func + + +def test_two_stage_softmax(): + chunk_size = 4096 + target = tvm.target.Target("cuda") + f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(target, chunk_size) + mod = tvm.IRModule({"chunk_lse": f_chunk_lse, "softmax_with_chunked_lse": f_softmax_with_lse}) + with target: + mod = dlight.ApplyDefaultSchedule(dlight.gpu.GeneralReduction())(mod) + + runtime_mod = tvm.build(mod, target=target) + device = tvm.cuda() + + num_runs = 5 + vocab_size = 128256 + for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]: + for _ in range(num_runs): + x_np = np.random.uniform(low=-10, high=10, size=(batch_size, vocab_size)).astype( + "float32" + ) + y_np = scipy.special.softmax(x_np, axis=-1) + + x_nd = tvm.nd.array(x_np, device=device) + r_nd = tvm.nd.empty( + (batch_size, (vocab_size + chunk_size - 1) // chunk_size), + x_np.dtype, + device=device, + ) + y_nd = tvm.nd.empty(x_np.shape, x_np.dtype, device=device) + + runtime_mod["chunk_lse"](x_nd, r_nd) + runtime_mod["softmax_with_chunked_lse"](x_nd, r_nd, y_nd) + + y_nd_arr = y_nd.numpy() + np.testing.assert_allclose(y_nd_arr, y_np, atol=1e-6, rtol=1e-6) + + print(f"pass batch size {batch_size}") + + +if __name__ == "__main__": + test_two_stage_softmax() diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py index 4e541b7437..c89a9e2c38 100644 --- a/tests/python/serve/evaluate_engine.py +++ b/tests/python/serve/evaluate_engine.py @@ -5,7 +5,7 @@ from typing import List, Tuple from mlc_llm.serve import GenerationConfig -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine def _parse_args(): @@ -41,7 +41,7 @@ def benchmark(args: argparse.Namespace): random.seed(args.seed) # Create engine - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=args.model, device=args.device, model_lib_path=args.model_lib_path, diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index ad4fa01a82..e4f64d2ce4 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -329,23 +329,6 @@ def test_openai_v1_completions_openai_package( ) -def test_openai_v1_completions_invalid_requested_model( - launch_server, # pylint: disable=unused-argument -): - # `launch_server` is a pytest fixture defined in conftest.py. 
- - model = "unserved_model" - payload = { - "model": model, - "prompt": "What is the meaning of life?", - "max_tokens": 10, - } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - expect_error( - response_str=response.json(), msg_prefix=f'The requested model "{model}" is not served.' - ) - - @pytest.mark.parametrize("stream", [False, True]) def test_openai_v1_completions_echo( served_model: Tuple[str, str], @@ -620,51 +603,6 @@ class Schema(BaseModel): "response_format": {"type": "json_object", "schema": schema_str}, } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions_json( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - prompt = "Response with a json object:" - max_tokens = 128 - payload = { - "model": served_model[0], - "prompt": prompt, - "max_tokens": max_tokens, - "stream": stream, - "response_format": {"type": "json_object"}, - } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) if not stream: check_openai_nonstream_response( @@ -1364,7 +1302,6 @@ def test_debug_dump_event_trace( test_openai_v1_completions(MODEL, None, stream=True) test_openai_v1_completions_openai_package(MODEL, None, stream=False) test_openai_v1_completions_openai_package(MODEL, None, stream=True) - test_openai_v1_completions_invalid_requested_model(None) test_openai_v1_completions_echo(MODEL, None, stream=False) test_openai_v1_completions_echo(MODEL, None, stream=True) test_openai_v1_completions_suffix(MODEL, None, stream=False) diff --git a/tests/python/serve/test_radix_tree.py b/tests/python/serve/test_radix_tree.py new file mode 100644 index 0000000000..cea421cd95 --- /dev/null +++ b/tests/python/serve/test_radix_tree.py @@ -0,0 +1,79 @@ +from tvm import TVMError +from tvm.runtime import ShapeTuple + +from mlc_llm.serve import PagedRadixTree + + +def test_add(): + prt = PagedRadixTree(16, 128, 16) + prt.add(0) + assert prt.get(0) == [] + + +def test_remove(): + prt = PagedRadixTree(32, 128, 16) + capacity = prt.free_capacity() + prt.add(0) + prt.remove(0) + prt.add(0) + prt.extend(0, [1 for _ in range(200)]) + prt.remove(0) + assert prt.free_capacity() == capacity + + prt.add(1) + prt.extend(1, [1 for _ in range(200)]) + capacity = prt.free_capacity() + prt.add(2) + prt.extend(2, [1 for _ in range(100)] + [2 for _ in range(100)]) + prt.remove(2) + assert prt.free_capacity() == capacity + + prt.add(3) + prt.extend(3, [1 for _ in range(200)]) + prt.remove(3) + assert prt.free_capacity() == capacity + + +def test_extend(): + prt = PagedRadixTree(1024, 256, 256) + L = prt.free_capacity() // 1024 + H = L // 2 + Q = L // 4 + seq_id = 0 + for start_pos in [0, H, L, L + H]: + for length in [Q, L - H, L, 2 * L - H, 2 * L]: + 
prt.add(seq_id) + if start_pos: + tokens_1 = [seq_id for _ in range(start_pos)] + prt.extend(seq_id, tokens_1) + assert prt.get(seq_id) == tokens_1 + else: + tokens_1 = [] + tokens_2 = [seq_id for _ in range(length)] + prt.extend(seq_id, tokens_2) + assert prt.get(seq_id) == tokens_1 + tokens_2 + seq_id += 1 + + +def test_fork(): + prt = PagedRadixTree(1024, 256, 256) + L = prt.free_capacity() // 1024 + H = L // 2 + Q = L // 4 + seq_id = 0 + length_list = [Q, H, L, L + Q, L + H, L * 2] + for p_idx in range(1, len(length_list)): + for c_idx in range(0, p_idx + 1): + prt.add(seq_id) + tokens = [seq_id for _ in range(length_list[p_idx])] + prt.extend(seq_id, tokens) + prt.fork(seq_id + 1, seq_id, length_list[c_idx]) + assert prt.get(seq_id + 1) == tokens[: length_list[c_idx]] + seq_id += 2 + + +if __name__ == "__main__": + test_add() + test_remove() + test_extend() + test_fork() diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index 9bece30578..6e3835238a 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -3,7 +3,7 @@ import asyncio from typing import List -from mlc_llm.serve import AsyncLLMEngine, GenerationConfig +from mlc_llm.serve import AsyncMLCEngine, GenerationConfig prompts = [ "What is the meaning of life?", @@ -23,7 +23,7 @@ async def test_engine_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -39,7 +39,7 @@ async def test_engine_generate(): ] async def generate_task( - async_engine: AsyncLLMEngine, + async_engine: AsyncMLCEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, @@ -80,7 +80,7 @@ async def test_chat_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -132,7 +132,7 @@ async def test_chat_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -183,7 +183,7 @@ async def test_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -235,7 +235,7 @@ async def test_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py index 6915224f81..c3963af613 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -3,17 +3,7 @@ import asyncio from typing import List -<<<<<<< HEAD 
-from mlc_llm.serve import ( - AsyncThreadedEngine, - EngineMode, - GenerationConfig, - KVCacheConfig, -) -from mlc_llm.serve.engine import ModelInfo -======= -from mlc_llm.serve import AsyncLLMEngine, GenerationConfig, SpeculativeMode ->>>>>>> upstream/main +from mlc_llm.serve import AsyncMLCEngine, GenerationConfig, SpeculativeMode prompts = [ "What is the meaning of life?", @@ -37,7 +27,7 @@ async def test_engine_generate(): small_model_lib_path = ( "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -54,18 +44,14 @@ async def test_engine_generate(): ] async def generate_task( - async_engine: AsyncLLMEngine, + async_engine: AsyncMLCEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, ): print(f"generate task for request {request_id}") rid = int(request_id) -<<<<<<< HEAD - async for delta_outputs in async_engine.generate( -======= async for delta_outputs in async_engine._generate( ->>>>>>> upstream/main prompt, generation_cfg, request_id=request_id ): assert len(delta_outputs) == generation_cfg.n diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index 330bd4cf82..37d1833b14 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -2,7 +2,9 @@ # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable from typing import List -from mlc_llm.serve import GenerationConfig, LLMEngine +import pytest + +from mlc_llm.serve import GenerationConfig, MLCEngine prompts = [ "What is the meaning of life?", @@ -17,17 +19,39 @@ "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? 
Please elaborate in detail.", ] +test_models = [ + ( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ), + ( + "dist/rwkv-6-world-1b6-q0f16-MLC", + "dist/rwkv-6-world-1b6-q0f16-MLC/rwkv-6-world-1b6-q0f16-MLC-cuda.so", + ), +] -def test_engine_generate(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) + +def create_engine(model: str, model_lib_path: str): + if "rwkv" in model: + return MLCEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_batch_size=8, + max_history_size=1, + ) + else: + return MLCEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) + + +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_engine_generate(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 10 max_tokens = 256 @@ -57,16 +81,10 @@ def test_engine_generate(): del engine -def test_chat_completion(): +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_chat_completion(model: str, model_lib_path: str): # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 64 @@ -101,16 +119,9 @@ def test_chat_completion(): del engine -def test_chat_completion_non_stream(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_chat_completion_non_stream(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 64 @@ -144,16 +155,9 @@ def test_chat_completion_non_stream(): del engine -def test_completion(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_completion(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 128 @@ -188,16 +192,9 @@ def test_completion(): del engine -def test_completion_non_stream(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_completion_non_stream(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 128 @@ -232,8 +229,9 @@ def test_completion_non_stream(): if __name__ == 
"__main__": - test_engine_generate() - test_chat_completion() - test_chat_completion_non_stream() - test_completion() - test_completion_non_stream() + for model, model_lib_path in test_models: + test_engine_generate(model, model_lib_path) + test_chat_completion(model, model_lib_path) + test_chat_completion_non_stream(model, model_lib_path) + test_completion(model, model_lib_path) + test_completion_non_stream(model, model_lib_path) diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index 7f2a33b230..b764c62cd2 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -7,9 +7,9 @@ import pytest from pydantic import BaseModel -from mlc_llm.serve import AsyncLLMEngine, GenerationConfig +from mlc_llm.serve import AsyncMLCEngine, GenerationConfig from mlc_llm.serve.config import ResponseFormat -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine prompts_list = [ "Generate a JSON string containing 20 objects:", @@ -22,7 +22,7 @@ def test_batch_generation_with_grammar(): # Create engine - engine = SyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + engine = SyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompt_len = len(prompts_list) prompts = prompts_list * 3 @@ -69,7 +69,7 @@ def test_batch_generation_with_grammar(): def test_batch_generation_with_schema(): # Create engine - engine = SyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + engine = SyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompt = ( "Generate a json containing three fields: an integer field named size, a " @@ -121,7 +121,7 @@ class Schema(BaseModel): async def run_async_engine(): # Create engine - async_engine = AsyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + async_engine = AsyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompts = prompts_list * 20 @@ -142,7 +142,7 @@ async def run_async_engine(): ] async def generate_task( - async_engine: AsyncLLMEngine, + async_engine: AsyncMLCEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, diff --git a/tests/python/serve/test_serve_engine_image.py b/tests/python/serve/test_serve_engine_image.py index ff64e7235b..59e8c97196 100644 --- a/tests/python/serve/test_serve_engine_image.py +++ b/tests/python/serve/test_serve_engine_image.py @@ -2,7 +2,7 @@ from pathlib import Path from mlc_llm.serve import GenerationConfig, data -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine def get_test_image(config) -> data.ImageData: @@ -13,7 +13,7 @@ def test_engine_generate(): # Create engine model = "dist/llava-1.5-7b-hf-q4f16_1-MLC/params" model_lib_path = "dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index 60be02ce1a..33c06b1c5e 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -11,7 +11,7 @@ SpeculativeMode, data, ) -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine prompts = [ "What is the meaning of life?", @@ -90,7 +90,7 @@ def 
fcallback(delta_outputs: List[RequestStreamOutput]): small_model_lib_path = ( "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -158,7 +158,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): small_model_lib_path = ( "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" ) - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -242,7 +242,7 @@ def step(self) -> None: "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) timer = CallbackTimer() - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -328,7 +328,7 @@ def step(self) -> None: "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" ) timer = CallbackTimer() - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -364,7 +364,19 @@ def step(self) -> None: # assert fin_time == request.generation_config.max_tokens - 1 -def test_engine_generate(): +def compare_output_text(output_text1, output_text2): + if isinstance(output_text1, list) and isinstance(output_text2, list): + for item1, item2 in zip(output_text1, output_text2): + if not compare_output_text(item1, item2): + return False + elif output_text1 != output_text2: + print(output_text1) + print(output_text2) + return False + return True + + +def test_engine_generate(compare_precision=False): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" @@ -372,7 +384,8 @@ def test_engine_generate(): small_model_lib_path = ( "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - engine = SyncLLMEngine( + + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -385,9 +398,31 @@ def test_engine_generate(): max_tokens = 256 # Generate output. 
- output_texts, _ = engine.generate( - prompts[:num_requests], GenerationConfig(max_tokens=max_tokens, n=3) - ) + if compare_precision: + print("compare precision") + generation_config = GenerationConfig( + temperature=0.0, top_p=0, max_tokens=1024, stop_token_ids=[2], n=1 + ) + engine_single_model = SyncMLCEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) + output_texts_single_model, _ = engine_single_model.generate( + prompts[:num_requests], generation_config + ) + for req_id, outputs in enumerate(output_texts_single_model): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + # TODO: Add pytorch precision + else: + generation_config = GenerationConfig(max_tokens=max_tokens, n=3) + output_texts, _ = engine.generate(prompts[:num_requests], generation_config) for req_id, outputs in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") if len(outputs) == 1: @@ -395,6 +430,12 @@ def test_engine_generate(): else: for i, output in enumerate(outputs): print(f"Output {req_id}({i}):{output}\n") + if compare_precision: + precision_flag = compare_output_text(output_texts, output_texts_single_model) + if precision_flag: + print(f"Accuracy verification succeed\n") + else: + print(f"Accuracy verification failed\n") def test_engine_eagle_generate(): @@ -405,7 +446,7 @@ def test_engine_eagle_generate(): small_model_lib_path = ( "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" ) - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -453,7 +494,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC" model_lib_path = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -525,7 +566,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # small_model_lib_path = ( # "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so" # ) - spec_engine = SyncLLMEngine( + spec_engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -595,7 +636,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): small_model_lib_path = ( "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" ) - spec_engine = SyncLLMEngine( + spec_engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -643,7 +684,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): test_engine_eagle_basic() test_engine_continuous_batching_1() test_engine_eagle_continuous_batching_1() - test_engine_generate() + test_engine_generate(compare_precision=True) test_engine_eagle_generate() test_engine_efficiency() test_engine_spec_efficiency() diff --git a/tests/python/serve/test_serve_sync_engine.py b/tests/python/serve/test_serve_sync_engine.py index c5d521b02d..f68f48b7c5 100644 --- a/tests/python/serve/test_serve_sync_engine.py +++ b/tests/python/serve/test_serve_sync_engine.py @@ -5,7 +5,7 @@ import numpy as np from mlc_llm.serve import GenerationConfig, Request, RequestStreamOutput, data -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine prompts = [ "What is the 
meaning of life?", @@ -80,7 +80,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -156,7 +156,7 @@ def step(self) -> None: timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -237,7 +237,7 @@ def step(self) -> None: timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -323,7 +323,7 @@ def all_finished(self) -> bool: timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -365,7 +365,7 @@ def test_engine_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/web/emcc/mlc_wasm_runtime.cc b/web/emcc/mlc_wasm_runtime.cc index 3f05eb259f..b9a7f55bfa 100644 --- a/web/emcc/mlc_wasm_runtime.cc +++ b/web/emcc/mlc_wasm_runtime.cc @@ -29,6 +29,8 @@ // Pass in COMPILE_MLC_WASM_RUNTIME so unsupported code would not be compiled in to the .bc file #define COMPILE_MLC_WASM_RUNTIME 1 +#define __STDC_FORMAT_MACROS 1 +#define PICOJSON_USE_INT64 #define DMLC_USE_LOGGING_LIBRARY @@ -38,4 +40,5 @@ #include "serve/grammar/grammar_serializer.cc" #include "serve/grammar/grammar_simplifier.cc" #include "serve/grammar/grammar_state_matcher.cc" +#include "serve/grammar/json_schema_converter.cc" #include "support/encoding.cc"
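
Note on the chunked softmax exercised by `tests/python/op/test_two_stage_softmax.py` above: that test builds two kernels, `chunk_lse` (a per-chunk reduction over 4096-wide slices of the vocabulary) and `softmax_with_chunked_lse` (which combines the partial reductions into the final distribution), and only checks the end result against `scipy.special.softmax`. The sketch below is an illustrative NumPy-only restatement of why such a two-stage decomposition is exact; the helper name and the choice of per-chunk log-sum-exp as the intermediate are assumptions for illustration, not code from this change, and the TVM kernels may store a different intermediate.

```python
import numpy as np
import scipy.special


def two_stage_softmax_reference(x: np.ndarray, chunk_size: int) -> np.ndarray:
    """Illustrative reference: softmax(x) = exp(x - lse(x)), with lse computed chunk-wise."""
    batch, vocab = x.shape
    num_chunks = (vocab + chunk_size - 1) // chunk_size
    # Pad the vocabulary dimension with -inf so the padding contributes exp(-inf) = 0.
    padded = np.pad(x, ((0, 0), (0, num_chunks * chunk_size - vocab)), constant_values=-np.inf)
    chunks = padded.reshape(batch, num_chunks, chunk_size)
    # Stage 1: one log-sum-exp per chunk (the role played here by the `chunk_lse` kernel;
    # this intermediate is an assumption for illustration).
    per_chunk_lse = scipy.special.logsumexp(chunks, axis=-1)  # shape (batch, num_chunks)
    # Stage 2: fold the per-chunk values into the global log-sum-exp, then normalize.
    global_lse = scipy.special.logsumexp(per_chunk_lse, axis=-1, keepdims=True)
    return np.exp(x - global_lse)


if __name__ == "__main__":
    x = np.random.uniform(low=-10, high=10, size=(4, 128256))
    np.testing.assert_allclose(
        two_stage_softmax_reference(x, chunk_size=4096),
        scipy.special.softmax(x, axis=-1),
        rtol=1e-6,
        atol=1e-6,
    )
    print("two-stage softmax matches scipy.special.softmax")
```

Because log-sum-exp composes (the LSE of the whole row equals the LSE of the per-chunk LSEs), splitting the vocabulary into chunks changes only the reduction order, not the result, which is what lets the test compare the chunked kernels directly against an unchunked `scipy` softmax.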