Skip to content

Commit e333d4e

Browse files
authored
Openprompt polishing (#8)
* yyjson rewrite * yyjson fix * yyjson fix
1 parent 060bb45 commit e333d4e

File tree

1 file changed

+62
-13
lines changed

1 file changed

+62
-13
lines changed

src/open_prompt_extension.cpp

+62-13
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,23 @@ namespace duckdb {
2424
struct OpenPromptData: FunctionData {
2525
idx_t model_idx;
2626
idx_t json_schema_idx;
27+
idx_t json_system_prompt_idx;
2728
unique_ptr<FunctionData> Copy() const {
2829
auto res = make_uniq<OpenPromptData>();
2930
res->model_idx = model_idx;
3031
res->json_schema_idx = json_schema_idx;
32+
res->json_system_prompt_idx = json_system_prompt_idx;
3133
return res;
3234
};
3335
bool Equals(const FunctionData &other) const {
3436
return model_idx == other.Cast<OpenPromptData>().model_idx &&
35-
json_schema_idx == other.Cast<OpenPromptData>().json_schema_idx;
37+
json_schema_idx == other.Cast<OpenPromptData>().json_schema_idx &&
38+
json_system_prompt_idx==other.Cast<OpenPromptData>().json_system_prompt_idx;
3639
};
3740
OpenPromptData() {
3841
model_idx = 0;
3942
json_schema_idx = 0;
43+
json_system_prompt_idx = 0;
4044
}
4145
};
4246

@@ -49,6 +53,8 @@ namespace duckdb {
4953
res->model_idx = i;
5054
} else if (argument->alias == "json_schema") {
5155
res->json_schema_idx = i;
56+
} else if (argument->alias == "system_prompt") {
57+
res->json_system_prompt_idx = i;
5258
}
5359
}
5460
return std::move(res);
@@ -182,26 +188,65 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
182188
std::string api_token = GetConfigValue(context, "openprompt_api_token", "");
183189
std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b");
184190
std::string json_schema;
191+
std::string system_prompt;
185192

186193
if (info.model_idx != 0) {
187194
model_name = args.data[info.model_idx].GetValue(0).ToString();
188195
}
189196
if (info.json_schema_idx != 0) {
190197
json_schema = args.data[info.json_schema_idx].GetValue(0).ToString();
191198
}
199+
if (info.json_system_prompt_idx != 0) {
200+
system_prompt = args.data[info.json_system_prompt_idx].GetValue(0).ToString();
201+
}
192202

193-
std::string request_body = "{";
194-
request_body += "\"model\":\"" + model_name + "\",";
203+
unique_ptr<duckdb_yyjson::yyjson_mut_doc, void (*)(duckdb_yyjson::yyjson_mut_doc*)> doc(
204+
duckdb_yyjson::yyjson_mut_doc_new(nullptr), &duckdb_yyjson::yyjson_mut_doc_free);
205+
auto obj = duckdb_yyjson::yyjson_mut_obj(doc.get());
206+
duckdb_yyjson::yyjson_mut_doc_set_root(doc.get(), obj);
207+
duckdb_yyjson::yyjson_mut_obj_add(obj,
208+
duckdb_yyjson::yyjson_mut_str(doc.get(), "model"),
209+
duckdb_yyjson::yyjson_mut_str(doc.get(), model_name.c_str())
210+
);
195211
if (!json_schema.empty()) {
196-
request_body += "\"response_format\":{\"type\":\"json_object\", \"schema\":";
197-
request_body += json_schema;
198-
request_body += "},";
212+
auto response_format = duckdb_yyjson::yyjson_mut_obj(doc.get());
213+
duckdb_yyjson::yyjson_mut_obj_add(response_format,
214+
duckdb_yyjson::yyjson_mut_str(doc.get(), "type"),
215+
duckdb_yyjson::yyjson_mut_str(doc.get(), "json_object"));
216+
auto yyschema = duckdb_yyjson::yyjson_mut_raw(doc.get(), json_schema.c_str());
217+
duckdb_yyjson::yyjson_mut_obj_add(response_format,
218+
duckdb_yyjson::yyjson_mut_str(doc.get(), "schema"),
219+
yyschema);
220+
duckdb_yyjson::yyjson_mut_obj_add(obj,
221+
duckdb_yyjson::yyjson_mut_str(doc.get(),"response_format"),
222+
response_format);
199223
}
200-
request_body += "\"messages\":[";
201-
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
202-
request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}";
203-
request_body += "]}";
204-
224+
auto messages = duckdb_yyjson::yyjson_mut_arr(doc.get());
225+
string str_messages[2][2] = {
226+
{"system", system_prompt},
227+
{"user", user_prompt.GetString()}
228+
};
229+
for (auto message : str_messages) {
230+
if (message[1].empty()) {
231+
continue;
232+
}
233+
auto yymessage = duckdb_yyjson::yyjson_mut_arr_add_obj(doc.get(),messages);
234+
duckdb_yyjson::yyjson_mut_obj_add(yymessage,
235+
duckdb_yyjson::yyjson_mut_str(doc.get(), "role"),
236+
duckdb_yyjson::yyjson_mut_str(doc.get(), message[0].c_str()));
237+
duckdb_yyjson::yyjson_mut_obj_add(yymessage,
238+
duckdb_yyjson::yyjson_mut_str(doc.get(), "content"),
239+
duckdb_yyjson::yyjson_mut_str(doc.get(), message[1].c_str()));
240+
}
241+
duckdb_yyjson::yyjson_mut_obj_add(obj, duckdb_yyjson::yyjson_mut_str(doc.get(), "messages"),
242+
messages);
243+
duckdb_yyjson::yyjson_write_err err;
244+
auto request_body = duckdb_yyjson::yyjson_mut_write_opts(doc.get(), 0, nullptr, nullptr, &err);
245+
if (request_body == nullptr) {
246+
throw std::runtime_error(err.msg);
247+
}
248+
string str_request_body(request_body);
249+
free(request_body);
205250

206251
try {
207252
auto client_and_path = SetupHttpClient(api_url);
@@ -214,7 +259,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
214259
headers.emplace("Authorization", "Bearer " + api_token);
215260
}
216261

217-
auto res = client.Post(path.c_str(), headers, request_body, "application/json");
262+
auto res = client.Post(path.c_str(), headers, str_request_body, "application/json");
218263

219264
if (!res) {
220265
HandleHttpError(res, "POST");
@@ -286,10 +331,14 @@ static void LoadInternal(DatabaseInstance &instance) {
286331
open_prompt.AddFunction(ScalarFunction(
287332
{LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction,
288333
OpenPromptBind));
289-
open_prompt.AddFunction(ScalarFunction(
334+
open_prompt.AddFunction(ScalarFunction(
290335
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
291336
LogicalType::VARCHAR, OpenPromptRequestFunction,
292337
OpenPromptBind));
338+
open_prompt.AddFunction(ScalarFunction(
339+
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
340+
LogicalType::VARCHAR, OpenPromptRequestFunction,
341+
OpenPromptBind));
293342

294343
ExtensionUtil::RegisterFunction(instance, open_prompt);
295344

0 commit comments

Comments
 (0)