Skip to content

Commit

Permalink
Address lint error
Browse files Browse the repository at this point in the history
  • Loading branch information
rickzx committed Mar 15, 2024
1 parent d31099b commit 9b01bd3
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 54 deletions.
7 changes: 4 additions & 3 deletions cpp/conversation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,16 @@ void Conversation::LoadJSONOverride(const picojson::value& config_json, bool par
CHECK(config["system_message"].is<std::string>()) << "Invalid system message" << err_templ;
std::string system_template = config["system_template"].get<std::string>();
std::string system_msg = config["system_message"].get<std::string>();
std::string system = system_template.replace(system_template.find(system_placeholder),
std::string system = system_template.replace(system_template.find(system_placeholder),
system_placeholder.length(), system_msg);
this->system = system;
} else {
CHECK(partial_update) << "Key \"system_template\" or \"system_message\" not found.";
}

if (config.count("system_prefix_token_ids")) {
CHECK(config["system_prefix_token_ids"].is<picojson::array>()) << "Invalid system_prefix_token_ids" << err_templ;
CHECK(config["system_prefix_token_ids"].is<picojson::array>())
<< "Invalid system_prefix_token_ids" << err_templ;
picojson::array prefix_tokens_arr = config["system_prefix_token_ids"].get<picojson::array>();
std::vector<int32_t> prefix_tokens;
for (const picojson::value& prefix_token : prefix_tokens_arr) {
Expand Down Expand Up @@ -129,7 +130,7 @@ void Conversation::LoadJSONOverride(const picojson::value& config_json, bool par
} else {
CHECK(partial_update) << "Key \"stop_token_ids\" not found.";
}
}
}

void Conversation::LoadJSONOverrideLegacy(const picojson::value& config_json, bool partial_update) {
std::string err_templ = " in conversion template json file.";
Expand Down
2 changes: 1 addition & 1 deletion cpp/conversation.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class Conversation {
*/
void LoadJSONOverride(const picojson::value& config_json, bool partial_update = false);

/*!
/*!
* \brief Load legacy JSON config and overrides options.
*
* \param config_json A json config in picojson type that is partially specifies
Expand Down
5 changes: 3 additions & 2 deletions cpp/llm_chat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,9 @@ class LLMChat {
this->conversation_.LoadJSONOverride(config["conv_template"], false);
} else {
ICHECK(config["conv_template"].is<std::string>());
LOG(WARNING) << "Legacy conversation template detected. It will be deprecated in the future. "
"Please regenerate mlc-chat-config.json with the latest version";
LOG(WARNING)
<< "Legacy conversation template detected. It will be deprecated in the future. "
"Please regenerate mlc-chat-config.json with the latest version";
std::string conv_template = config["conv_template"].get<std::string>();
this->conversation_ = Conversation::FromTemplate(conv_template);
}
Expand Down
6 changes: 1 addition & 5 deletions python/mlc_llm/chat_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,6 @@ class ConvConfig: # pylint: disable=too-many-instance-attributes
function_string: Optional[str] = None
use_function_calling: Optional[bool] = None

def __post_init__(self):
if self.messages is not None and self.offset is None:
self.offset = len(self.messages)


@dataclass
class ChatConfig(ConfigBase): # pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -200,7 +196,7 @@ class ChatConfig(ConfigBase): # pylint: disable=too-many-instance-attributes

model_lib: Optional[str] = None
local_id: Optional[str] = None
conv_template: Optional[str | Conversation] = None
conv_template: Optional[Union[str, Conversation]] = None
temperature: Optional[float] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
Expand Down
12 changes: 6 additions & 6 deletions python/mlc_llm/interface/gen_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import json
import shutil
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from mlc_llm.conversation_template import ConvTemplateRegistry
from mlc_llm.model import Model
from mlc_llm.quantization import Quantization
from mlc_llm.support import convert_tiktoken, logging
from mlc_llm.support.style import bold, green, red
from mlc_llm.conversation_template import ConvTemplateRegistry

from .compiler_flags import ModelConfigOverride

Expand Down Expand Up @@ -46,7 +46,7 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes
repetition_penalty: float = None
top_p: float = None
# Conversation template
conv_template: str | Dict[str, Any] = None
conv_template: Union[str, Dict[str, Any]] = None
pad_token_id: int = None
bos_token_id: int = None
eos_token_id: int = None
Expand Down Expand Up @@ -90,16 +90,16 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
):
"""Entrypoint of MLC Chat configuration generation."""
# Step 1. Initialize `mlc-chat-config.json` using `config.json`
conversation = ConvTemplateRegistry.get_conv_template(conv_template)
if conversation is None:
conversation_reg = ConvTemplateRegistry.get_conv_template(conv_template)
if conversation_reg is None:
logger.warning(
"%s: Conversation template is not registered in ConvTemplateRegistry: %s",
red("Warning"),
conv_template,
)
conversation = conv_template
else:
conversation = conversation.to_json_dict()
conversation = conversation_reg.to_json_dict()

model_config = ModelConfigOverride(
context_window_size=context_window_size,
Expand Down
11 changes: 8 additions & 3 deletions python/mlc_llm/protocol/conversation_protocol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""The standard conversation protocol in MLC LLM"""

from enum import Enum
from typing import Any, Dict, List, Optional, Self, Tuple
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar

from pydantic import BaseModel, Field, field_validator

Expand All @@ -17,6 +17,9 @@ class MessagePlaceholders(Enum):
FUNCTION = "{function_string}"


T = TypeVar("T", bound="BaseModel")


class Conversation(BaseModel):
"""Class that specifies the convention template of conversation
and contains the conversation history.
Expand Down Expand Up @@ -96,10 +99,12 @@ def check_message_seps(cls, seps: List[str]) -> List[str]:
return seps

def to_json_dict(self) -> Dict[str, Any]:
"""Convert to a json dictionary"""
return self.model_dump(exclude_none=True)

@staticmethod
def from_json_dict(json_dict) -> Self:
@classmethod
def from_json_dict(cls: Type[T], json_dict: Dict[str, Any]) -> T:
"""Convert from a json dictionary"""
return Conversation.model_validate(json_dict)

def as_prompt(self) -> str:
Expand Down
66 changes: 32 additions & 34 deletions tests/cpp/conv_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,45 @@

void _TestConversationLoadJSON() {
std::string conv_template =
"{\n"
" \"name\": \"test\",\n"
" \"system_template\": \"abc{system_message}\",\n"
" \"system_message\": \"de\",\n"
" \"roles\": {\n"
" \"user\": \"Instruct\",\n"
" \"assistant\": \"Output\",\n"
" \"tool\": \"Instruct\"\n"
" },\n"
" \"role_templates\": {\n"
" \"user\": \"{user_message}\",\n"
" \"assistant\": \"{assistant_message}\",\n"
" \"tool\": \"{tool_message}\"\n"
" },\n"
" \"messages\": [[\"Instruct\", \"Hello\"], [\"Output\", \"Hey\"]],\n"
" \"seps\": [\n"
" \"\\n\"\n"
" ],\n"
" \"role_content_sep\": \": \",\n"
" \"role_empty_sep\": \":\",\n"
" \"stop_str\": [\n"
" \"<|endoftext|>\"\n"
" ],\n"
" \"stop_token_ids\": [\n"
" 50256\n"
" ],\n"
" \"function_string\": \"\",\n"
" \"use_function_calling\": false\n"
" }";
"{\n"
" \"name\": \"test\",\n"
" \"system_template\": \"abc{system_message}\",\n"
" \"system_message\": \"de\",\n"
" \"roles\": {\n"
" \"user\": \"Instruct\",\n"
" \"assistant\": \"Output\",\n"
" \"tool\": \"Instruct\"\n"
" },\n"
" \"role_templates\": {\n"
" \"user\": \"{user_message}\",\n"
" \"assistant\": \"{assistant_message}\",\n"
" \"tool\": \"{tool_message}\"\n"
" },\n"
" \"messages\": [[\"Instruct\", \"Hello\"], [\"Output\", \"Hey\"]],\n"
" \"seps\": [\n"
" \"\\n\"\n"
" ],\n"
" \"role_content_sep\": \": \",\n"
" \"role_empty_sep\": \":\",\n"
" \"stop_str\": [\n"
" \"<|endoftext|>\"\n"
" ],\n"
" \"stop_token_ids\": [\n"
" 50256\n"
" ],\n"
" \"function_string\": \"\",\n"
" \"use_function_calling\": false\n"
"}";
mlc::llm::Conversation conv;
conv.LoadJSONOverride(conv_template, true);
ASSERT_EQ(conv.name, "test");
ASSERT_EQ(conv.system, "abcde");

std::vector<std::string> expected_roles {"Instruct", "Output"};
std::vector<std::string> expected_roles{"Instruct", "Output"};
ASSERT_EQ(conv.roles, expected_roles);

std::vector<std::vector<std::string>> expected_messages = {
{"Instruct", "Hello"},
{"Output", "Hey"}
};
std::vector<std::vector<std::string>> expected_messages = {{"Instruct", "Hello"},
{"Output", "Hey"}};
ASSERT_EQ(conv.messages, expected_messages);
ASSERT_EQ(conv.offset, 2);

Expand Down

0 comments on commit 9b01bd3

Please sign in to comment.