@@ -24,19 +24,23 @@ namespace duckdb {
24
24
struct OpenPromptData : FunctionData {
25
25
idx_t model_idx;
26
26
idx_t json_schema_idx;
27
+ idx_t json_system_prompt_idx;
27
28
unique_ptr<FunctionData> Copy () const {
28
29
auto res = make_uniq<OpenPromptData>();
29
30
res->model_idx = model_idx;
30
31
res->json_schema_idx = json_schema_idx;
32
+ res->json_system_prompt_idx = json_system_prompt_idx;
31
33
return res;
32
34
};
33
35
bool Equals (const FunctionData &other) const {
34
36
return model_idx == other.Cast <OpenPromptData>().model_idx &&
35
- json_schema_idx == other.Cast <OpenPromptData>().json_schema_idx ;
37
+ json_schema_idx == other.Cast <OpenPromptData>().json_schema_idx &&
38
+ json_system_prompt_idx==other.Cast <OpenPromptData>().json_system_prompt_idx ;
36
39
};
37
40
OpenPromptData () {
38
41
model_idx = 0 ;
39
42
json_schema_idx = 0 ;
43
+ json_system_prompt_idx = 0 ;
40
44
}
41
45
};
42
46
@@ -49,6 +53,8 @@ namespace duckdb {
49
53
res->model_idx = i;
50
54
} else if (argument->alias == " json_schema" ) {
51
55
res->json_schema_idx = i;
56
+ } else if (argument->alias == " system_prompt" ) {
57
+ res->json_system_prompt_idx = i;
52
58
}
53
59
}
54
60
return std::move (res);
@@ -182,26 +188,65 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
182
188
std::string api_token = GetConfigValue (context, " openprompt_api_token" , " " );
183
189
std::string model_name = GetConfigValue (context, " openprompt_model_name" , " qwen2.5:0.5b" );
184
190
std::string json_schema;
191
+ std::string system_prompt;
185
192
186
193
if (info.model_idx != 0 ) {
187
194
model_name = args.data [info.model_idx ].GetValue (0 ).ToString ();
188
195
}
189
196
if (info.json_schema_idx != 0 ) {
190
197
json_schema = args.data [info.json_schema_idx ].GetValue (0 ).ToString ();
191
198
}
199
+ if (info.json_system_prompt_idx != 0 ) {
200
+ system_prompt = args.data [info.json_system_prompt_idx ].GetValue (0 ).ToString ();
201
+ }
192
202
193
- std::string request_body = " {" ;
194
- request_body += " \" model\" :\" " + model_name + " \" ," ;
203
+ unique_ptr<duckdb_yyjson::yyjson_mut_doc, void (*)(duckdb_yyjson::yyjson_mut_doc*)> doc (
204
+ duckdb_yyjson::yyjson_mut_doc_new (nullptr ), &duckdb_yyjson::yyjson_mut_doc_free);
205
+ auto obj = duckdb_yyjson::yyjson_mut_obj (doc.get ());
206
+ duckdb_yyjson::yyjson_mut_doc_set_root (doc.get (), obj);
207
+ duckdb_yyjson::yyjson_mut_obj_add (obj,
208
+ duckdb_yyjson::yyjson_mut_str (doc.get (), " model" ),
209
+ duckdb_yyjson::yyjson_mut_str (doc.get (), model_name.c_str ())
210
+ );
195
211
if (!json_schema.empty ()) {
196
- request_body += " \" response_format\" :{\" type\" :\" json_object\" , \" schema\" :" ;
197
- request_body += json_schema;
198
- request_body += " }," ;
212
+ auto response_format = duckdb_yyjson::yyjson_mut_obj (doc.get ());
213
+ duckdb_yyjson::yyjson_mut_obj_add (response_format,
214
+ duckdb_yyjson::yyjson_mut_str (doc.get (), " type" ),
215
+ duckdb_yyjson::yyjson_mut_str (doc.get (), " json_object" ));
216
+ auto yyschema = duckdb_yyjson::yyjson_mut_raw (doc.get (), json_schema.c_str ());
217
+ duckdb_yyjson::yyjson_mut_obj_add (response_format,
218
+ duckdb_yyjson::yyjson_mut_str (doc.get (), " schema" ),
219
+ yyschema);
220
+ duckdb_yyjson::yyjson_mut_obj_add (obj,
221
+ duckdb_yyjson::yyjson_mut_str (doc.get ()," response_format" ),
222
+ response_format);
199
223
}
200
- request_body += " \" messages\" :[" ;
201
- request_body += " {\" role\" :\" system\" ,\" content\" :\" You are a helpful assistant.\" }," ;
202
- request_body += " {\" role\" :\" user\" ,\" content\" :\" " + user_prompt.GetString () + " \" }" ;
203
- request_body += " ]}" ;
204
-
224
+ auto messages = duckdb_yyjson::yyjson_mut_arr (doc.get ());
225
+ string str_messages[2 ][2 ] = {
226
+ {" system" , system_prompt},
227
+ {" user" , user_prompt.GetString ()}
228
+ };
229
+ for (auto message : str_messages) {
230
+ if (message[1 ].empty ()) {
231
+ continue ;
232
+ }
233
+ auto yymessage = duckdb_yyjson::yyjson_mut_arr_add_obj (doc.get (),messages);
234
+ duckdb_yyjson::yyjson_mut_obj_add (yymessage,
235
+ duckdb_yyjson::yyjson_mut_str (doc.get (), " role" ),
236
+ duckdb_yyjson::yyjson_mut_str (doc.get (), message[0 ].c_str ()));
237
+ duckdb_yyjson::yyjson_mut_obj_add (yymessage,
238
+ duckdb_yyjson::yyjson_mut_str (doc.get (), " content" ),
239
+ duckdb_yyjson::yyjson_mut_str (doc.get (), message[1 ].c_str ()));
240
+ }
241
+ duckdb_yyjson::yyjson_mut_obj_add (obj, duckdb_yyjson::yyjson_mut_str (doc.get (), " messages" ),
242
+ messages);
243
+ duckdb_yyjson::yyjson_write_err err;
244
+ auto request_body = duckdb_yyjson::yyjson_mut_write_opts (doc.get (), 0 , nullptr , nullptr , &err);
245
+ if (request_body == nullptr ) {
246
+ throw std::runtime_error (err.msg );
247
+ }
248
+ string str_request_body (request_body);
249
+ free (request_body);
205
250
206
251
try {
207
252
auto client_and_path = SetupHttpClient (api_url);
@@ -214,7 +259,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
214
259
headers.emplace (" Authorization" , " Bearer " + api_token);
215
260
}
216
261
217
- auto res = client.Post (path.c_str (), headers, request_body , " application/json" );
262
+ auto res = client.Post (path.c_str (), headers, str_request_body , " application/json" );
218
263
219
264
if (!res) {
220
265
HandleHttpError (res, " POST" );
@@ -286,10 +331,14 @@ static void LoadInternal(DatabaseInstance &instance) {
286
331
open_prompt.AddFunction (ScalarFunction (
287
332
{LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction,
288
333
OpenPromptBind));
289
- open_prompt.AddFunction (ScalarFunction (
334
+ open_prompt.AddFunction (ScalarFunction (
290
335
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
291
336
LogicalType::VARCHAR, OpenPromptRequestFunction,
292
337
OpenPromptBind));
338
+ open_prompt.AddFunction (ScalarFunction (
339
+ {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
340
+ LogicalType::VARCHAR, OpenPromptRequestFunction,
341
+ OpenPromptBind));
293
342
294
343
ExtensionUtil::RegisterFunction (instance, open_prompt);
295
344
0 commit comments