diff --git a/tests/python/conftest.py b/conftest.py
similarity index 100%
rename from tests/python/conftest.py
rename to conftest.py
diff --git a/cpp/conversation.cc b/cpp/conversation.cc
index a3a432397a..d05021dc6c 100644
--- a/cpp/conversation.cc
+++ b/cpp/conversation.cc
@@ -11,6 +11,130 @@ namespace llm {
 void Conversation::LoadJSONOverride(const picojson::value& config_json, bool partial_update) {
   std::string err_templ = " in conversion template json file.";
   picojson::object config = config_json.get<picojson::object>();
+
+  if (config.count("name")) {
+    CHECK(config["name"].is<std::string>()) << "Invalid name" << err_templ;
+    this->name = config["name"].get<std::string>();
+  } else {
+    CHECK(partial_update) << "Key \"name\" not found.";
+  }
+
+  if (config.count("system_template") && config.count("system_message")) {
+    std::string system_placeholder = "{system_message}";
+    CHECK(config["system_template"].is<std::string>()) << "Invalid system template" << err_templ;
+    CHECK(config["system_message"].is<std::string>()) << "Invalid system message" << err_templ;
+    std::string system_template = config["system_template"].get<std::string>();
+    std::string system_msg = config["system_message"].get<std::string>();
+    std::string system = system_template.replace(system_template.find(system_placeholder),
+                                                 system_placeholder.length(), system_msg);
+    this->system = system;
+  } else {
+    CHECK(partial_update) << "Key \"system_template\" or \"system_message\" not found.";
+  }
+
+  if (config.count("system_prefix_token_ids")) {
+    CHECK(config["system_prefix_token_ids"].is<picojson::array>())
+        << "Invalid system_prefix_token_ids" << err_templ;
+    picojson::array prefix_tokens_arr = config["system_prefix_token_ids"].get<picojson::array>();
+    std::vector<int32_t> prefix_tokens;
+    for (const picojson::value& prefix_token : prefix_tokens_arr) {
+      CHECK(prefix_token.is<int64_t>()) << "Invalid prefix_tokens" << err_templ;
+      prefix_tokens.push_back(prefix_token.get<int64_t>());
+    }
+    this->prefix_tokens = prefix_tokens;
+  }
+
+  if (config.count("roles")) {
+    CHECK(config["roles"].is<picojson::object>()) << "Invalid roles" << err_templ;
+    picojson::object roles_json = config["roles"].get<picojson::object>();
+    std::vector<std::string> roles(2);
+    for (auto [role, role_name] : roles_json) {
+      CHECK(role_name.is<std::string>());
+      if (role == "user") {
+        roles.at(0) = role_name.get<std::string>();
+      }
+      if (role == "assistant") {
+        roles.at(1) = role_name.get<std::string>();
+      }
+    }
+    this->roles = roles;
+  }
+
+  if (config.count("messages")) {
+    CHECK(config["messages"].is<picojson::array>()) << "Invalid messages" << err_templ;
+    std::vector<std::vector<std::string>> messages;
+    picojson::array msgs_arr = config["messages"].get<picojson::array>();
+    for (const picojson::value& msgs_i : msgs_arr) {
+      CHECK(msgs_i.is<picojson::array>()) << "Invalid messages" << err_templ;
+      picojson::array msgs_i_arr = msgs_i.get<picojson::array>();
+      std::vector<std::string> messages_i;
+      for (const picojson::value& msg_v : msgs_i_arr) {
+        CHECK(msg_v.is<std::string>()) << "Invalid messages" << err_templ;
+        messages_i.push_back(msg_v.get<std::string>());
+      }
+      messages.push_back(messages_i);
+    }
+    this->messages = messages;
+    this->offset = messages.size();
+  } else {
+    this->offset = 0;
+  }
+
+  if (config.count("seps")) {
+    std::vector<std::string> seps;
+    CHECK(config["seps"].is<picojson::array>()) << "Invalid seps" << err_templ;
+    picojson::array seps_arr = config["seps"].get<picojson::array>();
+    for (const picojson::value& sep : seps_arr) {
+      CHECK(sep.is<std::string>()) << "Invalid seps" << err_templ;
+      seps.push_back(sep.get<std::string>());
+    }
+    this->seps = seps;
+  } else {
+    CHECK(partial_update) << "Key \"seps\" not found.";
+  }
+
+  if (config.count("role_content_sep")) {
+    CHECK(config["role_content_sep"].is<std::string>()) << "Invalid role_content_sep" << err_templ;
+    this->role_msg_sep = config["role_content_sep"].get<std::string>();
+  } else {
+    CHECK(partial_update) << "Key \"role_content_sep\" not found.";
+  }
+
+  if (config.count("role_empty_sep")) {
+    CHECK(config["role_empty_sep"].is<std::string>()) << "Invalid role_empty_sep" << err_templ;
+    this->role_empty_sep = config["role_empty_sep"].get<std::string>();
+  } else {
+    CHECK(partial_update) << "Key \"role_empty_sep\" not found.";
+  }
+
+  if (config.count("stop_str")) {
+    CHECK(config["stop_str"].is<picojson::array>()) << "Invalid stop_str" << err_templ;
+    picojson::array stop_str_arr = config["stop_str"].get<picojson::array>();
+    if (stop_str_arr.size() >= 1) {
+      picojson::value stop_str = stop_str_arr.at(0);
+      CHECK(stop_str.is<std::string>());
+      this->stop_str = stop_str.get<std::string>();
+    }
+  } else {
+    CHECK(partial_update) << "Key \"stop_str\" not found.";
+  }
+
+  if (config.count("stop_token_ids")) {
+    CHECK(config["stop_token_ids"].is<picojson::array>()) << "Invalid stop_token_ids" << err_templ;
+    picojson::array stop_tokens_arr = config["stop_token_ids"].get<picojson::array>();
+    std::vector<int32_t> stop_tokens;
+    for (const picojson::value& stop_token : stop_tokens_arr) {
+      CHECK(stop_token.is<int64_t>()) << "Invalid stop_tokens" << err_templ;
+      stop_tokens.push_back(stop_token.get<int64_t>());
+    }
+    this->stop_tokens = stop_tokens;
+  } else {
+    CHECK(partial_update) << "Key \"stop_token_ids\" not found.";
+  }
+}
+
+void Conversation::LoadJSONOverrideLegacy(const picojson::value& config_json,
+                                          bool partial_update) {
+  std::string err_templ = " in conversion template json file.";
+  picojson::object config = config_json.get<picojson::object>();
   if (config.count("name")) {
     CHECK(config["name"].is<std::string>()) << "Invalid name" << err_templ;
     this->name = config["name"].get<std::string>();
@@ -134,7 +258,13 @@ void Conversation::LoadJSONOverride(const std::string& config_str, bool partial_
     LOG(FATAL) << err;
     return;
   }
-  LoadJSONOverride(config_json, partial_update);
+
+  picojson::object config = config_json.get<picojson::object>();
+  try {
+    LoadJSONOverride(config_json, partial_update);
+  } catch (...) {
+    LoadJSONOverrideLegacy(config_json, partial_update);
+  }
 }
 
 picojson::value Conversation::SerializeToJSON() const {
diff --git a/cpp/conversation.h b/cpp/conversation.h
index 14cbd44149..7a75e8748a 100644
--- a/cpp/conversation.h
+++ b/cpp/conversation.h
@@ -154,6 +154,18 @@ class Conversation {
    */
  void LoadJSONOverride(const picojson::value& config_json, bool partial_update = false);
 
+  /*!
+   * \brief Load legacy JSON config and overrides options.
+   *
+   * \param config_json A json config in picojson type that partially specifies
+   *   some of the options.
+   * \param partial_update Whether it's a partial update or full update, if set to true,
+   *   we perform a partial update on some of the provided options; if set to false, all
+   *   options must be provided.
+   * \note DEPRECATED. This function loads the legacy JSON config value.
+   */
+  void LoadJSONOverrideLegacy(const picojson::value& config_json, bool partial_update = false);
+
   /*!
    * \brief Serialize the Conversation to JSON.
    * \return Serialized conversion in JSON format.
diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc
index 5577f9b87d..09c2ce9a37 100644
--- a/cpp/llm_chat.cc
+++ b/cpp/llm_chat.cc
@@ -558,16 +558,31 @@ class LLMChat {
       CHECK(partial_update) << "Key \"shift_fill_factor\" not found.";
     }
     if (config.count("conv_template")) {
-      ICHECK(config["conv_template"].is<std::string>());
-      std::string conv_template = config["conv_template"].get<std::string>();
-      this->conversation_ = Conversation::FromTemplate(conv_template);
+      if (config["conv_template"].is<picojson::object>()) {
+        this->conversation_.LoadJSONOverride(config["conv_template"], false);
+      } else {
+        ICHECK(config["conv_template"].is<std::string>());
+        LOG(WARNING)
+            << "Legacy conversation template detected. It will be deprecated in the future. "
+               "Please regenerate mlc-chat-config.json with the latest version";
+        std::string conv_template = config["conv_template"].get<std::string>();
+        this->conversation_ = Conversation::FromTemplate(conv_template);
+      }
       if (config.count("conv_config")) {
         // conv_config can override conv_template
-        this->conversation_.LoadJSONOverride(config["conv_config"], true);
+        try {
+          this->conversation_.LoadJSONOverride(config["conv_config"], true);
+        } catch (...) {
+          this->conversation_.LoadJSONOverrideLegacy(config["conv_config"], true);
+        }
       }
     } else if (config.count("conv_config")) {
       // without conv template, conv_config needs to be a complete config
-      this->conversation_.LoadJSONOverride(config["conv_config"], false);
+      try {
+        this->conversation_.LoadJSONOverride(config["conv_config"], false);
+      } catch (...) {
+        this->conversation_.LoadJSONOverrideLegacy(config["conv_config"], false);
+      }
     } else {
       CHECK(partial_update) << "Key \"conv_template\" and \"conv_config\" not found.";
     }
diff --git a/docs/deploy/python.rst b/docs/deploy/python.rst
index d5edcf82aa..38cdec2f85 100644
--- a/docs/deploy/python.rst
+++ b/docs/deploy/python.rst
@@ -184,7 +184,7 @@ We provide an example below.
 
     # Using a `ConvConfig`, we modify `system`, a field in the conversation template
    # `system` refers to the prompt encoded before starting the chat
-    conv_config = ConvConfig(system='Please show as much happiness as you can when talking to me.')
+    conv_config = ConvConfig(system_message='Please show as much happiness as you can when talking to me.')
 
     # We then include the `ConvConfig` instance in `ChatConfig` while overriding `max_gen_len`
     # Note that `conv_config` is an optional subfield of `chat_config`
diff --git a/docs/get_started/mlc_chat_config.rst b/docs/get_started/mlc_chat_config.rst
index ccaa97b4fc..482e68d368 100644
--- a/docs/get_started/mlc_chat_config.rst
+++ b/docs/get_started/mlc_chat_config.rst
@@ -52,14 +52,21 @@ Below is the ``mlc-chat-config.json`` file corresponding to Llama2 model:
         "tokenizer_config.json"
     ]
 
-    // 3. Chat related fields that affect runtime behavior
+    // 3. Conversation template related fields
+    "conv_template": {
+        "name": "llama-2",
+        "system_template": "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n ",
+        "system_message": "You are a helpful, respectful and honest assistant.",
+        // more fields here...
+    },
+
+    // 4. Chat related fields that affect runtime behavior
     "mean_gen_len": 128,
     "max_gen_len": 512,
     "shift_fill_factor": 0.3,
     "temperature": 0.6,
     "repetition_penalty": 1.0,
-    "top_p": 0.9,
-    "conv_template": "llama-2",
+    "top_p": 0.9
 }
 
 .. note::
@@ -70,7 +77,11 @@ Below is the ``mlc-chat-config.json`` file corresponding to Llama2 model:
 can be customized to change the behavior of the model.**
 
 ``conv_template``
-    The name of the conversation template that this chat uses. For more information, please refer to :ref:`conversation structure `.
+    .. note::
+        Legacy ``mlc-chat-config.json`` may specify a string for this field to look up a registered conversation
+        template. It will be deprecated in the future. Re-generate the config using the latest version of mlc_llm
+        to make sure this field is a complete JSON object.
+    The conversation template that this chat uses. For more information, please refer to :ref:`conversation structure `.
 
 ``temperature``
     The temperature applied to logits before sampling. The default value is ``0.7``. A higher temperature encourages more diverse outputs, while a lower temperature produces more deterministic outputs.
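The ``system`` to ``system_message`` rename shown in the ``docs/deploy/python.rst`` change above carries over directly to the Python API. Below is a minimal usage sketch with the new field name; the model id is a placeholder (any locally built model works the same way), and the imports assume the ``mlc_llm.chat_module`` module modified later in this diff.

.. code:: python

   from mlc_llm.chat_module import ChatConfig, ChatModule, ConvConfig

   # Partial override: only the system message changes; the rest of the
   # conversation template comes from the model's mlc-chat-config.json.
   conv_config = ConvConfig(system_message="Please answer as concisely as you can.")
   chat_config = ChatConfig(max_gen_len=256, conv_config=conv_config)

   # Placeholder model id; substitute your own compiled model.
   cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1", chat_config=chat_config)
   print(cm.generate("What is the capital of Canada?"))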
@@ -99,32 +110,17 @@ can be customized to change the behavior of the model.**
 Conversation Structure
 ^^^^^^^^^^^^^^^^^^^^^^
 
-There are three options of loading conversation configurations:
-
-1. Load from pre-defined conversation templates.
-2. Load from JSON format conversation configuration.
-3. First load from pre-defined conversation templates, then override some fields with JSON format conversation configuration.
-
-.. _load-predefined-conv-template:
-
-Load from Pre-defined Conversation Templates
--------------------------------------------
-
-MLC-LLM provided a set of pre-defined conversation templates, which you can directly use by specifying the template name in ``conv_template`` field in the ``mlc-chat-config.json``, below is a list (not complete) of supported conversation templates:
+MLC-LLM provides a set of pre-defined conversation templates, which you can directly use by
+specifying ``--conv-template [name]`` when generating the config. Below is a list (not complete) of
+supported conversation templates:
 
 - ``llama-2``
-- ``vicuna_v1.1``
-- ``redpajama_chat``
-- ``rwkv``
-- ``dolly``
+- ``mistral_default``
+- ``chatml``
+- ``phi-2``
 - ...
 
-Please refer to `conv_template.cc `_ for the full list of supported templates and their implementations.
-
-.. _load-json-conv-config:
-
-Load from JSON Conversation Configuration
-----------------------------------------
+Please refer to `conversation_template.py <https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/conversation_template.py>`_ for the full list of supported templates and their implementations.
 
-Below is a generic structure of a JSON conversation configuration (we use vicuna as an example):
+Below is a generic structure of a JSON conversation configuration (we use llama-2 as an example):
 
@@ -133,122 +129,81 @@ Below is a generic structure of a JSON conversation configuration (we use vicuna
    // mlc-chat-config.json
    {
      // ...
-    "conv_config": {
+    "conv_template": {
+      "name": "llama-2",
+      "system_template": "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n ",
+      "system_message": "You are a helpful, respectful and honest assistant.",
+      "roles": {
+        "user": "[INST]",
+        "assistant": "[/INST]",
+        "tool": "[INST]"
+      },
+      "role_templates": {
+        "user": "{user_message}",
+        "assistant": "{assistant_message}",
+        "tool": "{tool_message}"
+      },
+      "messages": [],
       "seps": [
-        " ",
-        "<\/s>"
+        " "
       ],
-      "stop_tokens": [
-        2
+      "role_content_sep": " ",
+      "role_empty_sep": " ",
+      "stop_str": [
+        "[INST]"
       ],
-      "offset": 0,
-      "separator_style": 0,
-      "messages": [],
-      "stop_str": "<\/s>",
-      "roles": [
-        "USER",
-        "ASSISTANT"
+      "stop_token_ids": [
+        2
       ],
-      "role_msg_sep": ": ",
-      "role_empty_sep": ": ",
-      "system": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.",
-      "add_bos": true,
-      "name": "vicuna_v1.1"
+      "function_string": "",
+      "use_function_calling": false
     }
   }
 
+``name``
+    Name of the conversation.
+``system_template``
+    The system prompt template; it optionally contains the system
+    message placeholder, which will be replaced with
+    the system message below.
+``system_message``
+    The content of the system prompt (without the template format).
+``system_prefix_token_ids``
+    The system token ids to be prepended at the beginning of the tokenized
+    generated prompt.
 ``roles``
-    An array that describes the role names of the user and the model. These names are specific to the model being used.
-``system``
-    The prompt encoded before starting the chat. It can be customized to a user-defined prompt.
-``add_bos``
-    Determines whether a beginning-of-string (bos) token should be added before the input tokens.
-``stop_str``
-    When the ``stop_str`` is encountered, the model will stop generating output.
-``stop_tokens``
-    A list of token IDs that act as stop tokens.
-``seps``
-    An array of strings indicating the separators to be used after a user message and a model message respectively.
+    The conversation roles.
+``role_templates``
+    The per-role prompt templates; each optionally contains a default
+    message placeholder that will be replaced by the actual message content.
 ``messages``
-    The chat history represented as an array of string pairs in the following format: ``[[role_0, msg_0], [role_1, msg_1], ...]``
-``offset``
-    The offset used to begin the chat from the chat history. When ``offset`` is not ``0``, ``messages[0:offset-1]`` will be encoded.
-``separator_style``
-    Specifies whether we are in chat-bot mode (``0``) or pure LM prompt mode (``1``).
-``role_msg_sep``
-    A string indicating the separator between a role and a message.
+    The conversation history messages.
+    Each message is a pair of strings, denoting "(role, content)".
+    The content can be None.
+``seps``
+    An array of strings indicating the separators to be used after a user
+    message and a model message respectively.
+``role_content_sep``
+    The separator between the role and the content in a message.
 ``role_empty_sep``
-    A string indicating the separator to append to a role when there is no message yet.
-
-
-When the value of ``separator_style`` is set to 0 (or ``kSepRoleMsg``), each round of conversation follows the format:
-
-.. code:: text
-
-   {role[0]}{separator_style}{user_input}{sep[0]}
-   {role[1]}{separator_style}{model_output}{sep[1]}
-
-Here, ``{user_input}`` represents the input provided by the user, and ``{model_output}`` represents the output generated by the model.
+    The separator between the role and empty contents.
+``stop_str``
+    When the ``stop_str`` is encountered, the model will stop generating output.
+``stop_token_ids``
+    A list of token IDs that act as stop tokens.
+``function_string``
+    The function calling string.
+``use_function_calling``
+    Whether function calling is used; helps check the output message format in API calls.
 
-On the other hand, if the value of ``separator_style`` is set to 1 (or ``kLM``), the model is not aware of the chat history and generates the response immediately after the user input prompt:
+Given a conversation template, the corresponding prompt generated
+from it has the following format:
 
 .. code:: text
 
-   {user_prompt}{model_output}
-
-
-.. _customize-conv-template:
-
-Customize Conversation Template
--------------------------------
-
-In the ``mlc-chat-config.json`` file, you have the option to specify both ``conv_template`` and ``conv_config``. MLC-LLM will first load the predefined template with the name specified in ``conv_template`` and then override some of the configurations specified in ``conv_config``. It's important to note that the configurations in ``conv_config`` don't need to be complete, allowing for partial updates.
-
-.. _example_replace_system_prompt:
-
-Example 1: Replace System Prompt
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-If you're tired of the default system prompt, here's an example of how you can replace it:
-
-.. code:: json
-
-   // mlc-chat-config.json
-   {
-       // ...
-       "conv_template": "vicuna_v1.1",
-       "conv_config": {
-           "system": "You are not Vicuna, your name is Guanaco, now let's chat!"
-       }
-   }
-
-
-The next time you run ``mlc_llm`` CLI, you will start a chat with Vicuna using a new system prompt.
-
-.. _example_resume_chat_history:
-
-Example 2: Resume from Chat History
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-The following example demonstrates how to chat with Vicuna and resume from a chat history:
-
-.. code:: json
-
-   // mlc-chat-config.json
-   {
-       // ...
-       "conv_template": "vicuna_v1.1",
-       "conv_config": {
-           "messages": [
-               ["USER", "Suppose we already have projects llama, alpaca and vicuna, what do you think would be a great name for the next project?"],
-               ["ASSISTANT", "Based on the previous projects, a possible name for the next project could be \"cervidae\" which is the scientific name for deer family. This name reflects the collaboration and teamwork involved in the development of the project, and also nods to the previous projects that have been developed by the team."],
-               ["USER", "I like cervidae, but the name is too long!"],
-               ["ASSISTANT", "In that case, a shorter and catchier name for the next project could be \"DeerRun\" which plays on the idea of the project being fast and efficient, just like a deer running through the woods. This name is memorable and easy to pronounce, making it a good choice for a project name."]
-           ],
-           "offset": 4
-       }
-   }
-
-
-The next time you start ``mlc_llm`` CLI, or use Python API, you will initiate a chat with Vicuna and resume from the provided chat history.
+   <<system>><<messages[0][0]>><<role_content_sep>><<messages[0][1]>><<seps[0]>>
+   <<messages[1][0]>><<role_content_sep>><<messages[1][1]>><<seps[1]>>
+   ...
+   <<messages[n][0]>><<role_content_sep>><<messages[n][1]>><<seps[0]>>
+   <<roles[1]>><<role_empty_sep>>
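To make the generic format above concrete, here is a small, self-contained Python sketch that expands a template by hand using the llama-2-style values from the JSON sample. It is illustrative only; the real runtime additionally handles role templates, system prefix tokens, and function calling.

.. code:: python

   # Values mirror the llama-2-style JSON sample above.
   system = "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant.\n<</SYS>>\n\n "
   seps = [" "]
   role_content_sep, role_empty_sep = " ", " "
   # (role, content) pairs; a None content marks the turn the model should fill in.
   messages = [("[INST]", "Hello!"), ("[/INST]", None)]

   prompt = system
   for i, (role, content) in enumerate(messages):
       if content is not None:
           prompt += role + role_content_sep + content + seps[i % len(seps)]
       else:
           prompt += role + role_empty_sep
   print(prompt)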
diff --git a/python/mlc_llm/chat_module.py b/python/mlc_llm/chat_module.py
index 675e1e7c94..18c3258514 100644
--- a/python/mlc_llm/chat_module.py
+++ b/python/mlc_llm/chat_module.py
@@ -16,6 +16,7 @@
 import tvm
 from tvm.runtime import disco  # pylint: disable=unused-import
 
+from mlc_llm.protocol.conversation_protocol import Conversation
 from mlc_llm.support import logging
 from mlc_llm.support.auto_device import detect_device
 from mlc_llm.support.config import ConfigBase
@@ -44,58 +45,61 @@ class ConvConfig:  # pylint: disable=too-many-instance-attributes
 
     Since the configuration is partial, everything will be ``Optional``.
 
+    The parameters are the same as :class:`mlc_llm.protocol.conversation_protocol.Conversation`.
+
     Parameters
     ----------
     name : Optional[str]
         Name of the conversation.
-    system : Optional[str]
-        The prompt encoded before starting the chat.
-    roles : Optional[List[str]]
-        An array that describes the role names of the user and the model. These
-        names are specific to the model being used.
-    messages : Optional[List[List[str]]]
-        The chat history represented as an array of string pairs in the following
-        format: ``[[role_0, msg_0], [role_1, msg_1], ...]``.
-    offset : Optional[int]
-        The offset used to begin the chat from the chat history. When offset
-        is not ``0``, ``messages[0:offset-1]`` will be encoded.
-    separator_style : Optional[int]
-        Specifies whether we are in chat-bot mode (``0``) or pure LM prompt mode (``1``).
+    system_template : Optional[str]
+        The system prompt template; it optionally contains the system
+        message placeholder, which will be replaced with
+        the system message below.
+    system_message : Optional[str]
+        The content of the system prompt (without the template format).
+    system_prefix_token_ids : Optional[List[int]]
+        The system token ids to be prepended at the beginning of the tokenized
+        generated prompt.
+    roles : Optional[Dict[str, str]]
+        The conversation roles.
+    role_templates : Optional[Dict[str, str]]
+        The per-role prompt templates; each optionally contains a default
+        message placeholder that will be replaced by the actual message content.
+    messages : Optional[List[Tuple[str, Optional[str]]]]
+        The conversation history messages.
+        Each message is a pair of strings, denoting "(role, content)".
+        The content can be None.
     seps : Optional[List[str]]
         An array of strings indicating the separators to be used after a user
         message and a model message respectively.
-    role_msg_sep : Optional[str]
-        A string indicating the separator between a role and a message.
+    role_content_sep : Optional[str]
+        The separator between the role and the content in a message.
     role_empty_sep : Optional[str]
-        A string indicating the separator to append to a role when there is no message yet.
-    stop_str : Optional[str]
+        The separator between the role and empty contents.
+    stop_str : Optional[List[str]]
         When the ``stop_str`` is encountered, the model will stop generating output.
-    stop_tokens : Optional[List[int]]
+    stop_token_ids : Optional[List[int]]
         A list of token IDs that act as stop tokens.
-    prefix_tokens : Optional[List[int]]
-        Token list prefixing the conversation.
-    add_bos : Optional[bool]
-        Determines whether a beginning-of-string (bos) token should be added
-        before the input tokens.
+    function_string : Optional[str]
+        The function calling string.
+    use_function_calling : Optional[bool]
+        Whether function calling is used; helps check the output message format in API calls.
     """
 
     name: Optional[str] = None
-    system: Optional[str] = None
-    roles: Optional[List[str]] = None
-    messages: Optional[List[List[str]]] = None
-    offset: Optional[int] = None
-    separator_style: Optional[int] = None
+    system_template: Optional[str] = None
+    system_message: Optional[str] = None
+    system_prefix_token_ids: Optional[List[int]] = None
+    roles: Optional[Dict[str, str]] = None
+    role_templates: Optional[Dict[str, str]] = None
+    messages: Optional[List[Tuple[str, Optional[str]]]] = None
     seps: Optional[List[str]] = None
-    role_msg_sep: Optional[str] = None
+    role_content_sep: Optional[str] = None
    role_empty_sep: Optional[str] = None
-    stop_str: Optional[str] = None
-    stop_tokens: Optional[List[int]] = None
-    prefix_tokens: Optional[List[int]] = None
-    add_bos: Optional[bool] = None
-
-    def __post_init__(self):
-        if self.messages is not None and self.offset is None:
-            self.offset = len(self.messages)
+    stop_str: Optional[List[str]] = None
+    stop_token_ids: Optional[List[int]] = None
+    function_string: Optional[str] = None
+    use_function_calling: Optional[bool] = None
 
 
 @dataclass
@@ -192,7 +196,7 @@ class ChatConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
 
     model_lib: Optional[str] = None
     local_id: Optional[str] = None
-    conv_template: Optional[str] = None
+    conv_template: Optional[Union[str, Conversation]] = None
     temperature: Optional[float] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
@@ -217,6 +221,8 @@ class ChatConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
 
     @classmethod
     def _from_json(cls, json_obj: dict):
+        if "conv_template" in json_obj and isinstance(json_obj["conv_template"], dict):
+            json_obj["conv_template"] = Conversation.from_json_dict(json_obj["conv_template"])
         return cls(**{k: v for k, v in json_obj.items() if k in inspect.signature(cls).parameters})
 
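A minimal sketch of what the ``_from_json`` hook above enables: a new-style ``mlc-chat-config.json`` stores ``conv_template`` as a JSON object, and the hook converts it back into a ``Conversation`` instance. The example calls the private helper purely for illustration and assumes the ``llama-2`` template is registered.

.. code:: python

   from mlc_llm.chat_module import ChatConfig
   from mlc_llm.conversation_template import ConvTemplateRegistry
   from mlc_llm.protocol.conversation_protocol import Conversation

   # Serialize a registered template the way gen_config embeds it into the config file...
   conv_dict = ConvTemplateRegistry.get_conv_template("llama-2").to_json_dict()
   # ...and let ChatConfig turn the dict back into a Conversation on load.
   cfg = ChatConfig._from_json({"conv_template": conv_dict, "temperature": 0.6})  # pylint: disable=protected-access
   assert isinstance(cfg.conv_template, Conversation)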
@@ -440,6 +446,13 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi
                     "override the full model library path instead."
                 )
                 warnings.warn(warn_msg)
+            elif field_name == "conv_template" and isinstance(field_value, Conversation):
+                warn_msg = (
+                    'WARNING: Do not override "conv_template" in ChatConfig. '
+                    'Please override "conv_config" instead. '
+                    "This override will be ignored."
+                )
+                warnings.warn(warn_msg)
             else:
                 setattr(final_chat_config, field_name, field_value)
     return final_chat_config
@@ -613,6 +626,9 @@ def _convert_chat_config_to_json_str(
                 conv_dict[conv_k] = conv_v
             chat_dict[key] = conv_dict
             continue
+        if key == "conv_template" and isinstance(value, Conversation):
+            chat_dict[key] = Conversation.to_json_dict(value)
+            continue
         if value is not None:
             chat_dict[key] = value
 
diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py
index f4d39aa8ba..4bce52aa20 100644
--- a/python/mlc_llm/interface/gen_config.py
+++ b/python/mlc_llm/interface/gen_config.py
@@ -4,8 +4,9 @@
 import json
 import shutil
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
+from mlc_llm.conversation_template import ConvTemplateRegistry
 from mlc_llm.model import Model
 from mlc_llm.quantization import Quantization
 from mlc_llm.support import convert_tiktoken, logging
@@ -45,7 +46,7 @@ class MLCChatConfig:  # pylint: disable=too-many-instance-attributes
     repetition_penalty: float = None
     top_p: float = None
     # Conversation template
-    conv_template: str = None
+    conv_template: Union[str, Dict[str, Any]] = None
     pad_token_id: int = None
     bos_token_id: int = None
     eos_token_id: int = None
@@ -89,6 +90,17 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
 ):
     """Entrypoint of MLC Chat configuration generation."""
     # Step 1. Initialize `mlc-chat-config.json` using `config.json`
+    conversation_reg = ConvTemplateRegistry.get_conv_template(conv_template)
+    if conversation_reg is None:
+        logger.warning(
+            "%s: Conversation template is not registered in ConvTemplateRegistry: %s",
+            red("Warning"),
+            conv_template,
+        )
+        conversation = conv_template  # type: ignore
+    else:
+        conversation = conversation_reg.to_json_dict()  # type: ignore
+
     model_config = ModelConfigOverride(
         context_window_size=context_window_size,
         sliding_window_size=sliding_window_size,
@@ -107,7 +119,7 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
         prefill_chunk_size=model_config.prefill_chunk_size,
         attention_sink_size=getattr(model_config, "attention_sink_size", -1),
         tensor_parallel_shards=model_config.tensor_parallel_shards,
-        conv_template=conv_template,
+        conv_template=conversation,
     )
     # Step 2. Load `generation_config.json` and `config.json` for text-generation related configs
     for generation_config_filename in ["generation_config.json", "config.json"]:
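The registry lookup added to ``gen_config`` above boils down to a name-to-object resolution with a legacy fallback. A condensed restatement follows; the helper name is ours, not part of the codebase.

.. code:: python

   from typing import Any, Dict, Union

   from mlc_llm.conversation_template import ConvTemplateRegistry

   def resolve_conv_template(name: str) -> Union[str, Dict[str, Any]]:
       """Known names become full JSON objects; unknown names stay as legacy strings."""
       conv = ConvTemplateRegistry.get_conv_template(name)
       return name if conv is None else conv.to_json_dict()

   assert isinstance(resolve_conv_template("llama-2"), dict)        # registered template
   assert isinstance(resolve_conv_template("my-custom-name"), str)  # passed through for legacy handling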
diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py
index 01c145db7d..fa99b95c16 100644
--- a/python/mlc_llm/protocol/conversation_protocol.py
+++ b/python/mlc_llm/protocol/conversation_protocol.py
@@ -1,7 +1,7 @@
 """The standard conversation protocol in MLC LLM"""
 
 from enum import Enum
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
 
 from pydantic import BaseModel, Field, field_validator
 
@@ -17,6 +17,9 @@ class MessagePlaceholders(Enum):
     FUNCTION = "{function_string}"
 
 
+T = TypeVar("T", bound="BaseModel")
+
+
 class Conversation(BaseModel):
     """Class that specifies the convention template of conversation
     and contains the conversation history.
@@ -95,6 +98,15 @@ def check_message_seps(cls, seps: List[str]) -> List[str]:
             raise ValueError("seps should have size 1 or 2.")
         return seps
 
+    def to_json_dict(self) -> Dict[str, Any]:
+        """Convert to a json dictionary"""
+        return self.model_dump(exclude_none=True)
+
+    @classmethod
+    def from_json_dict(cls: Type[T], json_dict: Dict[str, Any]) -> T:
+        """Convert from a json dictionary"""
+        return Conversation.model_validate(json_dict)
+
     def as_prompt(self) -> str:
         """Convert the conversation template and history messages to
         a single prompt.
diff --git a/tests/cpp/conv_unittest.cc b/tests/cpp/conv_unittest.cc
index 98d01a58ba..d49c7107cd 100644
--- a/tests/cpp/conv_unittest.cc
+++ b/tests/cpp/conv_unittest.cc
@@ -1,6 +1,61 @@
 #include <conversation.h>
 #include <gtest/gtest.h>
 
+void _TestConversationLoadJSON() {
+  std::string conv_template =
+      "{\n"
+      "  \"name\": \"test\",\n"
+      "  \"system_template\": \"abc{system_message}\",\n"
+      "  \"system_message\": \"de\",\n"
+      "  \"roles\": {\n"
+      "    \"user\": \"Instruct\",\n"
+      "    \"assistant\": \"Output\",\n"
+      "    \"tool\": \"Instruct\"\n"
+      "  },\n"
+      "  \"role_templates\": {\n"
+      "    \"user\": \"{user_message}\",\n"
+      "    \"assistant\": \"{assistant_message}\",\n"
+      "    \"tool\": \"{tool_message}\"\n"
+      "  },\n"
+      "  \"messages\": [[\"Instruct\", \"Hello\"], [\"Output\", \"Hey\"]],\n"
+      "  \"seps\": [\n"
+      "    \"\\n\"\n"
+      "  ],\n"
+      "  \"role_content_sep\": \": \",\n"
+      "  \"role_empty_sep\": \":\",\n"
+      "  \"stop_str\": [\n"
+      "    \"<|endoftext|>\"\n"
+      "  ],\n"
+      "  \"stop_token_ids\": [\n"
+      "    50256\n"
+      "  ],\n"
+      "  \"function_string\": \"\",\n"
+      "  \"use_function_calling\": false\n"
+      "}";
+  mlc::llm::Conversation conv;
+  conv.LoadJSONOverride(conv_template, true);
+  ASSERT_EQ(conv.name, "test");
+  ASSERT_EQ(conv.system, "abcde");
+
+  std::vector<std::string> expected_roles{"Instruct", "Output"};
+  ASSERT_EQ(conv.roles, expected_roles);
+
+  std::vector<std::vector<std::string>> expected_messages = {{"Instruct", "Hello"},
+                                                             {"Output", "Hey"}};
+  ASSERT_EQ(conv.messages, expected_messages);
+  ASSERT_EQ(conv.offset, 2);
+
+  std::vector<std::string> expected_seps = {"\n"};
+  ASSERT_EQ(conv.seps, expected_seps);
+
+  ASSERT_EQ(conv.role_msg_sep, ": ");
+  ASSERT_EQ(conv.role_empty_sep, ":");
+  ASSERT_EQ(conv.stop_str, "<|endoftext|>");
+
+  std::vector<int32_t> expected_stop_tokens = {50256};
+  ASSERT_EQ(conv.stop_tokens, expected_stop_tokens);
+}
+
 void _TestConversationJSONRoundTrip(std::string templ_name) {
   mlc::llm::Conversation conv = mlc::llm::Conversation::FromTemplate(templ_name);
   std::string conv_json = conv.GetConfigJSON();
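For comparison with the C++ unit test above, the same template JSON should also validate against the Python-side ``Conversation`` model. This is a sketch only; field handling follows the docstrings introduced in this change, and required fields or validators may differ in your version.

.. code:: python

   from mlc_llm.protocol.conversation_protocol import Conversation

   conv = Conversation.from_json_dict(
       {
           "name": "test",
           "system_template": "abc{system_message}",
           "system_message": "de",
           "roles": {"user": "Instruct", "assistant": "Output", "tool": "Instruct"},
           "messages": [["Instruct", "Hello"], ["Output", "Hey"]],
           "seps": ["\n"],
           "role_content_sep": ": ",
           "role_empty_sep": ":",
           "stop_str": ["<|endoftext|>"],
           "stop_token_ids": [50256],
       }
   )
   assert conv.system_message == "de"
   assert conv.stop_token_ids == [50256]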
@@ -11,12 +66,14 @@ void _TestConversationJSONRoundTrip(std::string templ_name) {
 
 void _TestConversationPartialUpdate() {
   mlc::llm::Conversation conv;
-  std::string json_str = "{\"offset\": -1}";
+  std::string json_str = "{\"name\": \"test\"}";
   ASSERT_ANY_THROW(conv.LoadJSONOverride(json_str, false));
   conv.LoadJSONOverride(json_str, true);
-  ASSERT_EQ(conv.offset, -1);
+  ASSERT_EQ(conv.name, "test");
 }
 
+TEST(ConversationTest, ConversationLoadJSONTest) { _TestConversationLoadJSON(); }
+
 TEST(ConversationTest, ConversationJSONRoundTripTest) {
   _TestConversationJSONRoundTrip("vicuna_v1.1");
   _TestConversationJSONRoundTrip("conv_one_shot");
diff --git a/tests/python/protocol/test_converation_protocol.py b/tests/python/protocol/test_converation_protocol.py
new file mode 100644
index 0000000000..9656eb8b18
--- /dev/null
+++ b/tests/python/protocol/test_converation_protocol.py
@@ -0,0 +1,20 @@
+import pytest
+
+from mlc_llm.conversation_template import ConvTemplateRegistry
+from mlc_llm.protocol.conversation_protocol import Conversation
+
+
+def get_conv_templates():
+    return ["llama-2", "mistral_default", "gorilla", "chatml", "phi-2"]
+
+
+@pytest.mark.parametrize("conv_template_name", get_conv_templates())
+def test_json(conv_template_name):
+    template = ConvTemplateRegistry.get_conv_template(conv_template_name)
+    j = template.to_json_dict()
+    template_parsed = Conversation.from_json_dict(j)
+    assert template == template_parsed
+
+
+if __name__ == "__main__":
+    for conv_template_name in get_conv_templates():
+        test_json(conv_template_name)