Unify schema for conversation template and embed into mlc-chat-config.json
rickzx committed Mar 15, 2024
1 parent c0b2ccd commit d31099b
Showing 9 changed files with 318 additions and 45 deletions.
File renamed without changes.
131 changes: 130 additions & 1 deletion cpp/conversation.cc
@@ -11,6 +11,129 @@ namespace llm {
void Conversation::LoadJSONOverride(const picojson::value& config_json, bool partial_update) {
std::string err_templ = " in conversation template json file.";
picojson::object config = config_json.get<picojson::object>();

if (config.count("name")) {
CHECK(config["name"].is<std::string>()) << "Invalid name" << err_templ;
this->name = config["name"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"name\" not found.";
}

if (config.count("system_template") && config.count("system_message")) {
std::string system_placeholder = "{system_message}";
CHECK(config["system_template"].is<std::string>()) << "Invalid system template" << err_templ;
CHECK(config["system_message"].is<std::string>()) << "Invalid system message" << err_templ;
std::string system_template = config["system_template"].get<std::string>();
std::string system_msg = config["system_message"].get<std::string>();
std::string system = system_template.replace(system_template.find(system_placeholder),
system_placeholder.length(), system_msg);
this->system = system;
} else {
CHECK(partial_update) << "Key \"system_template\" or \"system_message\" not found.";
}

if (config.count("system_prefix_token_ids")) {
CHECK(config["system_prefix_token_ids"].is<picojson::array>()) << "Invalid system_prefix_token_ids" << err_templ;
picojson::array prefix_tokens_arr = config["system_prefix_token_ids"].get<picojson::array>();
std::vector<int32_t> prefix_tokens;
for (const picojson::value& prefix_token : prefix_tokens_arr) {
CHECK(prefix_token.is<int64_t>()) << "Invalid prefix_tokens" << err_templ;
prefix_tokens.push_back(prefix_token.get<int64_t>());
}
this->prefix_tokens = prefix_tokens;
}

if (config.count("roles")) {
CHECK(config["roles"].is<picojson::object>()) << "Invalid roles" << err_templ;
picojson::object roles_json = config["roles"].get<picojson::object>();
std::vector<std::string> roles(2);
for (auto [role, role_name] : roles_json) {
CHECK(role_name.is<std::string>());
if (role == "user") {
roles.at(0) = role_name.get<std::string>();
}
if (role == "assistant") {
roles.at(1) = role_name.get<std::string>();
}
}
this->roles = roles;
}

if (config.count("messages")) {
CHECK(config["messages"].is<picojson::array>()) << "Invalid messages" << err_templ;
std::vector<std::vector<std::string>> messages;
picojson::array msgs_arr = config["messages"].get<picojson::array>();
for (const picojson::value& msgs_i : msgs_arr) {
CHECK(msgs_i.is<picojson::array>()) << "Invalid messages" << err_templ;
picojson::array msgs_i_arr = msgs_i.get<picojson::array>();
std::vector<std::string> messages_i;
for (const picojson::value& msg_v : msgs_i_arr) {
CHECK(msg_v.is<std::string>()) << "Invalid messages" << err_templ;
messages_i.push_back(msg_v.get<std::string>());
}
messages.push_back(messages_i);
}
this->messages = messages;
this->offset = messages.size();
} else {
this->offset = 0;
}

if (config.count("seps")) {
std::vector<std::string> seps;
CHECK(config["seps"].is<picojson::array>()) << "Invalid seps" << err_templ;
picojson::array seps_arr = config["seps"].get<picojson::array>();
for (const picojson::value& sep : seps_arr) {
CHECK(sep.is<std::string>()) << "Invalid seps" << err_templ;
seps.push_back(sep.get<std::string>());
}
this->seps = seps;
} else {
CHECK(partial_update) << "Key \"seps\" not found.";
}

if (config.count("role_content_sep")) {
CHECK(config["role_content_sep"].is<std::string>()) << "Invalid role_content_sep" << err_templ;
this->role_msg_sep = config["role_content_sep"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"role_content_sep\" not found.";
}
if (config.count("role_empty_sep")) {
CHECK(config["role_empty_sep"].is<std::string>()) << "Invalid role_empty_sep" << err_templ;
this->role_empty_sep = config["role_empty_sep"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"role_empty_sep\" not found.";
}

if (config.count("stop_str")) {
CHECK(config["stop_str"].is<picojson::array>()) << "Invalid stop_str" << err_templ;
picojson::array stop_str_arr = config["stop_str"].get<picojson::array>();
if (stop_str_arr.size() >= 1) {
picojson::value stop_str = stop_str_arr.at(0);
CHECK(stop_str.is<std::string>());
this->stop_str = stop_str.get<std::string>();
}
} else {
CHECK(partial_update) << "Key \"stop_str\" not found.";
}

if (config.count("stop_token_ids")) {
CHECK(config["stop_token_ids"].is<picojson::array>()) << "Invalid stop_token_ids" << err_templ;
picojson::array stop_tokens_arr = config["stop_token_ids"].get<picojson::array>();
std::vector<int32_t> stop_tokens;
for (const picojson::value& stop_token : stop_tokens_arr) {
CHECK(stop_token.is<int64_t>()) << "Invalid stop_tokens" << err_templ;
stop_tokens.push_back(stop_token.get<int64_t>());
}
this->stop_tokens = stop_tokens;
} else {
CHECK(partial_update) << "Key \"stop_token_ids\" not found.";
}
}

void Conversation::LoadJSONOverrideLegacy(const picojson::value& config_json, bool partial_update) {
std::string err_templ = " in conversation template json file.";
picojson::object config = config_json.get<picojson::object>();
if (config.count("name")) {
CHECK(config["name"].is<std::string>()) << "Invalid name" << err_templ;
this->name = config["name"].get<std::string>();
@@ -134,7 +257,13 @@ void Conversation::LoadJSONOverride(const std::string& config_str, bool partial_
LOG(FATAL) << err;
return;
}
LoadJSONOverride(config_json, partial_update);

picojson::object config = config_json.get<picojson::object>();
try {
LoadJSONOverride(config_json, partial_update);
} catch (...) {
LoadJSONOverrideLegacy(config_json, partial_update);
}
}

picojson::value Conversation::SerializeToJSON() const {
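For reference, a conversation record in the new unified schema looks like the sketch below. The key set mirrors what `Conversation::LoadJSONOverride` parses above; the concrete values (template name, role strings, separators, token IDs) are hypothetical and not taken from any template shipped with this commit.

```python
import json

# Hypothetical conversation template in the new unified schema.
# Keys mirror the parser in cpp/conversation.cc: "roles" is an object, "stop_str" is a
# list of strings, and "stop_token_ids" / "system_prefix_token_ids" are integer lists.
conv_template = {
    "name": "demo_chat",
    "system_template": "[SYS] {system_message} [/SYS]",
    "system_message": "You are a helpful assistant.",
    "system_prefix_token_ids": [1],
    "roles": {"user": "USER", "assistant": "ASSISTANT"},
    "messages": [],                # optional history as [[role, content], ...]
    "seps": [" ", "</s>"],
    "role_content_sep": ": ",
    "role_empty_sep": ":",
    "stop_str": ["</s>"],
    "stop_token_ids": [2],
}

print(json.dumps(conv_template, indent=2))
```

The string-based overload tries this schema first and falls back to `LoadJSONOverrideLegacy` when parsing throws, so configs written against the old schema keep loading.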
12 changes: 12 additions & 0 deletions cpp/conversation.h
@@ -154,6 +154,18 @@ class Conversation {
*/
void LoadJSONOverride(const picojson::value& config_json, bool partial_update = false);

/*!
* \brief Load legacy JSON config and overrides options.
*
* \param config_json A json config in picojson type that is partially specifies
* some of the options.
* \param partial_update Whether it's a partial update or full update, if set to true,
* we perform a partial update on some of the provided options; if set to false, all
* options must be provided.
* \note DEPRECATED. This function loads the legacy JSON config value.
*/
void LoadJSONOverrideLegacy(const picojson::value& config_json, bool partial_update = false);

/*!
* \brief Serialize the Conversation to JSON.
* \return Serialized conversion in JSON format.
24 changes: 19 additions & 5 deletions cpp/llm_chat.cc
@@ -558,16 +558,30 @@ class LLMChat {
CHECK(partial_update) << "Key \"shift_fill_factor\" not found.";
}
if (config.count("conv_template")) {
ICHECK(config["conv_template"].is<std::string>());
std::string conv_template = config["conv_template"].get<std::string>();
this->conversation_ = Conversation::FromTemplate(conv_template);
if (config["conv_template"].is<picojson::object>()) {
this->conversation_.LoadJSONOverride(config["conv_template"], false);
} else {
ICHECK(config["conv_template"].is<std::string>());
LOG(WARNING) << "Legacy conversation template detected. It will be deprecated in the future. "
"Please regenerate mlc-chat-config.json with the latest version";
std::string conv_template = config["conv_template"].get<std::string>();
this->conversation_ = Conversation::FromTemplate(conv_template);
}
if (config.count("conv_config")) {
// conv_config can override conv_template
this->conversation_.LoadJSONOverride(config["conv_config"], true);
try {
this->conversation_.LoadJSONOverride(config["conv_config"], true);
} catch (...) {
this->conversation_.LoadJSONOverrideLegacy(config["conv_config"], true);
}
}
} else if (config.count("conv_config")) {
// without conv template, conv_config needs to be a complete config
this->conversation_.LoadJSONOverride(config["conv_config"], false);
try {
this->conversation_.LoadJSONOverride(config["conv_config"], false);
} catch (...) {
this->conversation_.LoadJSONOverrideLegacy(config["conv_config"], true);
}
} else {
CHECK(partial_update) << "Key \"conv_template\" and \"conv_config\" not found.";
}
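To illustrate the branch above: `conv_template` in `mlc-chat-config.json` may now be either an embedded conversation object or a legacy template name. The sketch below is a hypothetical Python mirror of that dispatch, not code from this commit; the config fragments and the helper name are made up.

```python
# Hypothetical mirror of the conv_template dispatch in cpp/llm_chat.cc.
new_style = {
    "conv_template": {  # embedded object -> conversation_.LoadJSONOverride(..., false)
        "name": "demo_chat",
        "roles": {"user": "USER", "assistant": "ASSISTANT"},
    }
}
legacy_style = {
    "conv_template": "demo_chat"  # plain string -> Conversation::FromTemplate + deprecation warning
}


def describe_conv_template(config: dict) -> str:
    """Mimic the is<picojson::object>() / is<std::string>() check."""
    value = config["conv_template"]
    if isinstance(value, dict):
        return "embedded conversation object (new schema)"
    return f"legacy named template '{value}'"


print(describe_conv_template(new_style))
print(describe_conv_template(legacy_style))
```

`conv_config` overrides follow the same pattern: the new-schema loader is tried first and the legacy loader is used only if it throws.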
88 changes: 54 additions & 34 deletions python/mlc_llm/chat_module.py
@@ -16,6 +16,7 @@
import tvm
from tvm.runtime import disco # pylint: disable=unused-import

from mlc_llm.protocol.conversation_protocol import Conversation
from mlc_llm.support import logging
from mlc_llm.support.auto_device import detect_device
from mlc_llm.support.config import ConfigBase
@@ -44,54 +45,61 @@ class ConvConfig: # pylint: disable=too-many-instance-attributes
Since the configuration is partial, everything will be ``Optional``.
The parameters are the same as :class:`mlc_llm.protocol.conversation_protocol.Conversation`
Parameters
----------
name : Optional[str]
Name of the conversation.
system : Optional[str]
The prompt encoded before starting the chat.
roles : Optional[List[str]]
An array that describes the role names of the user and the model. These
names are specific to the model being used.
messages : Optional[List[List[str]]]
The chat history represented as an array of string pairs in the following
format: ``[[role_0, msg_0], [role_1, msg_1], ...]``.
offset : Optional[int]
The offset used to begin the chat from the chat history. When offset
is not ``0``, ``messages[0:offset-1]`` will be encoded.
separator_style : Optional[int]
Specifies whether we are in chat-bot mode (``0``) or pure LM prompt mode (``1``).
system_template : Optional[str]
The system prompt template, it optionally contains the system
message placeholder, and the placeholder will be replaced with
the system message below.
system_message : Optional[str]
The content of the system prompt (without the template format).
system_prefix_token_ids : Optional[List[int]]
The system token ids to be prepended at the beginning of tokenized
generated prompt.
roles : Optional[Dict[str, str]]
The conversation roles
role_templates : Optional[Dict[str, str]]
The roles prompt template, it optionally contains the defaults
message placeholders and will be replaced by actual content
messages : Optional[List[Tuple[str, Optional[str]]]]
The conversation history messages.
Each message is a pair of strings, denoting "(role, content)".
The content can be None.
seps : Optional[List[str]]
An array of strings indicating the separators to be used after a user
message and a model message respectively.
role_msg_sep : Optional[str]
A string indicating the separator between a role and a message.
role_content_sep : Optional[str]
The separator between the role and the content in a message.
role_empty_sep : Optional[str]
A string indicating the separator to append to a role when there is no message yet.
stop_str : Optional[str]
The separator between the role and empty contents.
stop_str : Optional[List[str]]
When the ``stop_str`` is encountered, the model will stop generating output.
stop_tokens : Optional[List[int]]
stop_token_ids : Optional[List[int]]
A list of token IDs that act as stop tokens.
prefix_tokens : Optional[List[int]]
Token list prefixing the conversation.
add_bos : Optional[bool]
Determines whether a beginning-of-string (bos) token should be added
before the input tokens.
function_string : Optional[str]
The function calling string.
use_function_calling : Optional[bool]
Whether using function calling or not, helps check for output message format in API call.
"""

name: Optional[str] = None
system: Optional[str] = None
roles: Optional[List[str]] = None
messages: Optional[List[List[str]]] = None
offset: Optional[int] = None
separator_style: Optional[int] = None
system_template: Optional[str] = None
system_message: Optional[str] = None
system_prefix_token_ids: Optional[List[int]] = None
roles: Optional[Dict[str, str]] = None
role_templates: Optional[Dict[str, str]] = None
messages: Optional[List[Tuple[str, Optional[str]]]] = None
seps: Optional[List[str]] = None
role_msg_sep: Optional[str] = None
role_content_sep: Optional[str] = None
role_empty_sep: Optional[str] = None
stop_str: Optional[str] = None
stop_tokens: Optional[List[int]] = None
prefix_tokens: Optional[List[int]] = None
add_bos: Optional[bool] = None
stop_str: Optional[List[str]] = None
stop_token_ids: Optional[List[int]] = None
function_string: Optional[str] = None
use_function_calling: Optional[bool] = None

def __post_init__(self):
if self.messages is not None and self.offset is None:
@@ -192,7 +200,7 @@ class ChatConfig(ConfigBase): # pylint: disable=too-many-instance-attributes

model_lib: Optional[str] = None
local_id: Optional[str] = None
conv_template: Optional[str] = None
conv_template: Optional[str | Conversation] = None
temperature: Optional[float] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
@@ -217,6 +225,8 @@ class ChatConfig(ConfigBase): # pylint: disable=too-many-instance-attributes

@classmethod
def _from_json(cls, json_obj: dict):
if "conv_template" in json_obj and isinstance(json_obj["conv_template"], dict):
json_obj["conv_template"] = Conversation.from_json_dict(json_obj["conv_template"])
return cls(**{k: v for k, v in json_obj.items() if k in inspect.signature(cls).parameters})


@@ -440,6 +450,13 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi
"override the full model library path instead."
)
warnings.warn(warn_msg)
elif field_name == "conv_template" and isinstance(field_value, Conversation):
warn_msg = (
'WARNING: Do not override "conv_template" in ChatConfig. '
'Please override "conv_config" instead. '
"This override will be ignored."
)
warnings.warn(warn_msg)
else:
setattr(final_chat_config, field_name, field_value)
return final_chat_config
@@ -613,6 +630,9 @@ def _convert_chat_config_to_json_str(
conv_dict[conv_k] = conv_v
chat_dict[key] = conv_dict
continue
if key == "conv_template" and isinstance(value, Conversation):
chat_dict[key] = Conversation.to_json_dict(value)
continue
if value is not None:
chat_dict[key] = value

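The `chat_module.py` changes amount to a small round trip: a dict-valued `conv_template` read from `mlc-chat-config.json` is promoted to a `Conversation` object, and a `Conversation` held in `ChatConfig` is serialized back to a dict when the config is written out. The sketch below uses a stand-in `Conversation` class and helper names of its own, not the real `mlc_llm.protocol.conversation_protocol.Conversation` API.

```python
from dataclasses import dataclass, field
from typing import Dict, Union


@dataclass
class Conversation:
    """Stand-in for mlc_llm.protocol.conversation_protocol.Conversation."""

    name: str = "demo_chat"
    roles: Dict[str, str] = field(
        default_factory=lambda: {"user": "USER", "assistant": "ASSISTANT"}
    )

    @classmethod
    def from_json_dict(cls, obj: dict) -> "Conversation":
        return cls(**obj)

    def to_json_dict(self) -> dict:
        return {"name": self.name, "roles": self.roles}


def load_conv_template(value: Union[str, dict]) -> Union[str, Conversation]:
    """Mirror ChatConfig._from_json: dicts become Conversation, strings stay legacy names."""
    return Conversation.from_json_dict(value) if isinstance(value, dict) else value


def dump_conv_template(value: Union[str, Conversation]) -> Union[str, dict]:
    """Mirror _convert_chat_config_to_json_str: Conversation serializes back to a dict."""
    return value.to_json_dict() if isinstance(value, Conversation) else value


loaded = load_conv_template({"name": "demo_chat", "roles": {"user": "U", "assistant": "A"}})
print(dump_conv_template(loaded))  # {'name': 'demo_chat', 'roles': {'user': 'U', 'assistant': 'A'}}
```

Overriding `conv_template` through `ChatConfig` with a `Conversation` object is now warned about and ignored; `conv_config` remains the supported override path.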
16 changes: 14 additions & 2 deletions python/mlc_llm/interface/gen_config.py
@@ -10,6 +10,7 @@
from mlc_llm.quantization import Quantization
from mlc_llm.support import convert_tiktoken, logging
from mlc_llm.support.style import bold, green, red
from mlc_llm.conversation_template import ConvTemplateRegistry

from .compiler_flags import ModelConfigOverride

@@ -45,7 +46,7 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes
repetition_penalty: float = None
top_p: float = None
# Conversation template
conv_template: str = None
conv_template: str | Dict[str, Any] = None
pad_token_id: int = None
bos_token_id: int = None
eos_token_id: int = None
@@ -89,6 +90,17 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
):
"""Entrypoint of MLC Chat configuration generation."""
# Step 1. Initialize `mlc-chat-config.json` using `config.json`
conversation = ConvTemplateRegistry.get_conv_template(conv_template)
if conversation is None:
logger.warning(
"%s: Conversation template is not registered in ConvTemplateRegistry: %s",
red("Warning"),
conv_template,
)
conversation = conv_template
else:
conversation = conversation.to_json_dict()

model_config = ModelConfigOverride(
context_window_size=context_window_size,
sliding_window_size=sliding_window_size,
@@ -107,7 +119,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
prefill_chunk_size=model_config.prefill_chunk_size,
attention_sink_size=getattr(model_config, "attention_sink_size", -1),
tensor_parallel_shards=model_config.tensor_parallel_shards,
conv_template=conv_template,
conv_template=conversation,
)
# Step 2. Load `generation_config.json` and `config.json` for text-generation related configs
for generation_config_filename in ["generation_config.json", "config.json"]:
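The `gen_config` change resolves the named template through `ConvTemplateRegistry` and embeds its full JSON dict into `mlc-chat-config.json`, falling back to the raw string (with a warning) when the name is not registered. Below is a hypothetical standalone mirror of that flow; the registry contents and function name are made up, and the real `ConvTemplateRegistry` API may differ in detail.

```python
import logging
from typing import Any, Dict, Optional, Union

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Made-up registry standing in for mlc_llm.conversation_template.ConvTemplateRegistry.
_FAKE_REGISTRY: Dict[str, Dict[str, Any]] = {
    "demo_chat": {"name": "demo_chat", "roles": {"user": "USER", "assistant": "ASSISTANT"}},
}


def resolve_conv_template(conv_template: str) -> Union[str, Dict[str, Any]]:
    """Return the template as an embeddable dict, or the raw name if unregistered."""
    conversation: Optional[Dict[str, Any]] = _FAKE_REGISTRY.get(conv_template)
    if conversation is None:
        logger.warning(
            "Conversation template not registered, keeping legacy string: %s", conv_template
        )
        return conv_template
    return conversation


print(resolve_conv_template("demo_chat"))     # embedded dict written to mlc-chat-config.json
print(resolve_conv_template("unknown-chat"))  # legacy string, with a warning
```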