Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify schema for conversation template and embed into mlc-chat-config.json #1965

Merged
merged 3 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
132 changes: 131 additions & 1 deletion cpp/conversation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,130 @@ namespace llm {
void Conversation::LoadJSONOverride(const picojson::value& config_json, bool partial_update) {
std::string err_templ = " in conversion template json file.";
picojson::object config = config_json.get<picojson::object>();

if (config.count("name")) {
CHECK(config["name"].is<std::string>()) << "Invalid name" << err_templ;
this->name = config["name"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"name\" not found.";
}

if (config.count("system_template") && config.count("system_message")) {
std::string system_placeholder = "{system_message}";
CHECK(config["system_template"].is<std::string>()) << "Invalid system template" << err_templ;
CHECK(config["system_message"].is<std::string>()) << "Invalid system message" << err_templ;
std::string system_template = config["system_template"].get<std::string>();
std::string system_msg = config["system_message"].get<std::string>();
std::string system = system_template.replace(system_template.find(system_placeholder),
system_placeholder.length(), system_msg);
this->system = system;
} else {
CHECK(partial_update) << "Key \"system_template\" or \"system_message\" not found.";
}

if (config.count("system_prefix_token_ids")) {
CHECK(config["system_prefix_token_ids"].is<picojson::array>())
<< "Invalid system_prefix_token_ids" << err_templ;
picojson::array prefix_tokens_arr = config["system_prefix_token_ids"].get<picojson::array>();
std::vector<int32_t> prefix_tokens;
for (const picojson::value& prefix_token : prefix_tokens_arr) {
CHECK(prefix_token.is<int64_t>()) << "Invalid prefix_tokens" << err_templ;
prefix_tokens.push_back(prefix_token.get<int64_t>());
}
this->prefix_tokens = prefix_tokens;
}

if (config.count("roles")) {
CHECK(config["roles"].is<picojson::object>()) << "Invalid roles" << err_templ;
picojson::object roles_json = config["roles"].get<picojson::object>();
std::vector<std::string> roles(2);
for (auto [role, role_name] : roles_json) {
CHECK(role_name.is<std::string>());
if (role == "user") {
roles.at(0) = role_name.get<std::string>();
}
if (role == "assistant") {
roles.at(1) = role_name.get<std::string>();
}
}
this->roles = roles;
}

if (config.count("messages")) {
CHECK(config["messages"].is<picojson::array>()) << "Invalid messages" << err_templ;
std::vector<std::vector<std::string>> messages;
picojson::array msgs_arr = config["messages"].get<picojson::array>();
for (const picojson::value& msgs_i : msgs_arr) {
CHECK(msgs_i.is<picojson::array>()) << "Invalid messages" << err_templ;
picojson::array msgs_i_arr = msgs_i.get<picojson::array>();
std::vector<std::string> messages_i;
for (const picojson::value& msg_v : msgs_i_arr) {
CHECK(msg_v.is<std::string>()) << "Invalid messages" << err_templ;
messages_i.push_back(msg_v.get<std::string>());
}
messages.push_back(messages_i);
}
this->messages = messages;
this->offset = messages.size();
} else {
this->offset = 0;
}

if (config.count("seps")) {
std::vector<std::string> seps;
CHECK(config["seps"].is<picojson::array>()) << "Invalid seps" << err_templ;
picojson::array seps_arr = config["seps"].get<picojson::array>();
for (const picojson::value& sep : seps_arr) {
CHECK(sep.is<std::string>()) << "Invalid seps" << err_templ;
seps.push_back(sep.get<std::string>());
}
this->seps = seps;
} else {
CHECK(partial_update) << "Key \"seps\" not found.";
}

if (config.count("role_content_sep")) {
CHECK(config["role_content_sep"].is<std::string>()) << "Invalid role_content_sep" << err_templ;
this->role_msg_sep = config["role_content_sep"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"role_msg_sep\" not found.";
}
if (config.count("role_empty_sep")) {
CHECK(config["role_empty_sep"].is<std::string>()) << "Invalid role_empty_sep" << err_templ;
this->role_empty_sep = config["role_empty_sep"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"role_empty_sep\" not found.";
}

if (config.count("stop_str")) {
CHECK(config["stop_str"].is<picojson::array>()) << "Invalid stop_str" << err_templ;
picojson::array stop_str_arr = config["stop_str"].get<picojson::array>();
if (stop_str_arr.size() >= 1) {
picojson::value stop_str = stop_str_arr.at(0);
CHECK(stop_str.is<std::string>());
this->stop_str = stop_str.get<std::string>();
}
} else {
CHECK(partial_update) << "Key \"stop_str\" not found.";
}

if (config.count("stop_token_ids")) {
CHECK(config["stop_token_ids"].is<picojson::array>()) << "Invalid stop_token_ids" << err_templ;
picojson::array stop_tokens_arr = config["stop_token_ids"].get<picojson::array>();
std::vector<int32_t> stop_tokens;
for (const picojson::value& stop_token : stop_tokens_arr) {
CHECK(stop_token.is<int64_t>()) << "Invalid stop_tokens" << err_templ;
stop_tokens.push_back(stop_token.get<int64_t>());
}
this->stop_tokens = stop_tokens;
} else {
CHECK(partial_update) << "Key \"stop_token_ids\" not found.";
}
}

void Conversation::LoadJSONOverrideLegacy(const picojson::value& config_json, bool partial_update) {
std::string err_templ = " in conversion template json file.";
picojson::object config = config_json.get<picojson::object>();
if (config.count("name")) {
CHECK(config["name"].is<std::string>()) << "Invalid name" << err_templ;
this->name = config["name"].get<std::string>();
Expand Down Expand Up @@ -134,7 +258,13 @@ void Conversation::LoadJSONOverride(const std::string& config_str, bool partial_
LOG(FATAL) << err;
return;
}
LoadJSONOverride(config_json, partial_update);

picojson::object config = config_json.get<picojson::object>();
try {
LoadJSONOverride(config_json, partial_update);
} catch (...) {
LoadJSONOverrideLegacy(config_json, partial_update);
}
}

picojson::value Conversation::SerializeToJSON() const {
Expand Down
12 changes: 12 additions & 0 deletions cpp/conversation.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ class Conversation {
*/
void LoadJSONOverride(const picojson::value& config_json, bool partial_update = false);

/*!
* \brief Load legacy JSON config and overrides options.
*
* \param config_json A json config in picojson type that is partially specifies
* some of the options.
* \param partial_update Whether it's a partial update or full update, if set to true,
* we perform a partial update on some of the provided options; if set to false, all
* options must be provided.
* \note DEPRECATED. This function loads the legacy JSON config value.
*/
void LoadJSONOverrideLegacy(const picojson::value& config_json, bool partial_update = false);

/*!
* \brief Serialize the Conversation to JSON.
* \return Serialized conversion in JSON format.
Expand Down
25 changes: 20 additions & 5 deletions cpp/llm_chat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -558,16 +558,31 @@ class LLMChat {
CHECK(partial_update) << "Key \"shift_fill_factor\" not found.";
}
if (config.count("conv_template")) {
ICHECK(config["conv_template"].is<std::string>());
std::string conv_template = config["conv_template"].get<std::string>();
this->conversation_ = Conversation::FromTemplate(conv_template);
if (config["conv_template"].is<picojson::object>()) {
this->conversation_.LoadJSONOverride(config["conv_template"], false);
} else {
ICHECK(config["conv_template"].is<std::string>());
LOG(WARNING)
<< "Legacy conversation template detected. It will be deprecated in the future. "
"Please regenerate mlc-chat-config.json with the latest version";
std::string conv_template = config["conv_template"].get<std::string>();
this->conversation_ = Conversation::FromTemplate(conv_template);
}
if (config.count("conv_config")) {
// conv_config can override conv_template
this->conversation_.LoadJSONOverride(config["conv_config"], true);
try {
this->conversation_.LoadJSONOverride(config["conv_config"], true);
} catch (...) {
this->conversation_.LoadJSONOverrideLegacy(config["conv_config"], true);
}
}
} else if (config.count("conv_config")) {
// without conv template, conv_config needs to be a complete config
this->conversation_.LoadJSONOverride(config["conv_config"], false);
try {
this->conversation_.LoadJSONOverride(config["conv_config"], false);
} catch (...) {
this->conversation_.LoadJSONOverrideLegacy(config["conv_config"], false);
}
} else {
CHECK(partial_update) << "Key \"conv_template\" and \"conv_config\" not found.";
}
Expand Down
2 changes: 1 addition & 1 deletion docs/deploy/python.rst
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ We provide an example below.

# Using a `ConvConfig`, we modify `system`, a field in the conversation template
# `system` refers to the prompt encoded before starting the chat
conv_config = ConvConfig(system='Please show as much happiness as you can when talking to me.')
conv_config = ConvConfig(system_message='Please show as much happiness as you can when talking to me.')

# We then include the `ConvConfig` instance in `ChatConfig` while overriding `max_gen_len`
# Note that `conv_config` is an optional subfield of `chat_config`
Expand Down
Loading
Loading