14
14
#include < sstream>
15
15
#include < mutex>
16
16
#include < iostream>
17
- #include < yyjson.hpp>
18
-
17
+ #include " yyjson.hpp"
19
18
20
19
namespace duckdb {
21
- struct OpenPromptData : FunctionData {
22
- unique_ptr<FunctionData> Copy () const {
23
- throw std::runtime_error (" OpenPromptData::Copy" );
24
- };
25
- bool Equals (const FunctionData &other) const {
26
- throw std::runtime_error (" OpenPromptData::Equals" );
27
- };
28
- };
29
-
30
- // Helper function to parse URL and setup client
20
+
31
21
static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient (const std::string &url) {
32
22
std::string scheme, domain, path;
33
23
size_t pos = url.find (" ://" );
@@ -46,7 +36,6 @@ static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient(co
46
36
path = " /" ;
47
37
}
48
38
49
- // Create client and set a reasonable timeout (e.g., 10 seconds)
50
39
duckdb_httplib_openssl::Client client (domain.c_str ());
51
40
client.set_read_timeout (10 , 0 ); // 10 seconds
52
41
client.set_follow_location (true ); // Follow redirects
@@ -98,184 +87,167 @@ static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std
98
87
throw std::runtime_error (err_message);
99
88
}
100
89
101
-
102
- // Open Prompt
103
- // Global settings
104
- static std::string api_url = " http://localhost:11434/v1/chat/completions" ;
105
- static std::string api_token; // Store your API token here
106
- static std::string model_name = " qwen2.5:0.5b" ; // Default model
107
- static std::mutex settings_mutex;
108
-
109
- // Function to set API token
110
- void SetApiToken (DataChunk &args, ExpressionState &state, Vector &result) {
111
- UnaryExecutor::Execute<string_t , string_t >(args.data [0 ], result, args.size (),
112
- [&](string_t token) {
113
- try {
114
- auto _token = token.GetData ();
115
- if (token.Empty ()) {
116
- throw std::invalid_argument (" API token cannot be empty." );
117
- }
118
- ClientConfig::GetConfig (state.GetContext ()).SetUserVariable (
119
- " openprompt_api_token" ,
120
- Value::CreateValue (token.GetString ()));
121
- return StringVector::AddString (result, string (" token : " ) + string (_token, token.GetSize ()));
122
- } catch (std::exception &e) {
123
- string_t res (e.what ());
124
- res.Finalize ();
125
- return res;
126
- }
127
- });
90
+ // Settings management
91
+ static std::string GetConfigValue (ClientContext &context, const string &var_name, const string &default_value) {
92
+ Value value;
93
+ auto &config = ClientConfig::GetConfig (context);
94
+ if (!config.GetUserVariable (var_name, value) || value.IsNull ()) {
95
+ return default_value;
128
96
}
97
+ return value.ToString ();
98
+ }
129
99
130
- // Function to set API URL
131
- void SetApiUrl (DataChunk &args, ExpressionState &state, Vector &result ) {
132
- UnaryExecutor::Execute<string_t , string_t >(args.data [0 ], result, args.size (),
133
- [&](string_t token ) {
100
+ static void SetConfigValue (DataChunk &args, ExpressionState &state, Vector &result,
101
+ const string &var_name, const string &value_type ) {
102
+ UnaryExecutor::Execute<string_t , string_t >(args.data [0 ], result, args.size (),
103
+ [&](string_t value ) {
134
104
try {
135
- auto _token = token.GetData ();
136
- if (token.Empty ()) {
137
- throw std::invalid_argument (" API token cannot be empty." );
105
+ if (value == " " || value.GetSize () == 0 ) {
106
+ throw std::invalid_argument (value_type + " cannot be empty." );
138
107
}
108
+
139
109
ClientConfig::GetConfig (state.GetContext ()).SetUserVariable (
140
- " openprompt_api_url" ,
141
- Value::CreateValue (token.GetString ()));
142
- return StringVector::AddString (result, string (" url : " ) + string (_token, token.GetSize ()));
110
+ var_name,
111
+ Value::CreateValue (value.GetString ())
112
+ );
113
+ return StringVector::AddString (result, value_type + " set to: " + value.GetString ());
143
114
} catch (std::exception &e) {
144
- string_t res (e.what ());
145
- res.Finalize ();
146
- return res;
115
+ return StringVector::AddString (result, " Failed to set " + value_type + " : " + e.what ());
147
116
}
148
117
});
149
- }
150
-
151
- // Function to set model name
152
- void SetModelName (DataChunk &args, ExpressionState &state, Vector &result) {
153
- UnaryExecutor::Execute<string_t , string_t >(args.data [0 ], result, args.size (),
154
- [&](string_t token) {
155
- try {
156
- auto _token = token.GetData ();
157
- if (token.Empty ()) {
158
- throw std::invalid_argument (" API token cannot be empty." );
159
- }
160
- ClientConfig::GetConfig (state.GetContext ()).SetUserVariable (
161
- " openprompt_model_name" ,
162
- Value::CreateValue (token.GetString ()));
163
- return StringVector::AddString (result, string (" name : " ) + string (_token, token.GetSize ()));
164
- } catch (std::exception &e) {
165
- string_t res (e.what ());
166
- res.Finalize ();
167
- return res;
168
- }
169
- });
170
- }
118
+ }
171
119
172
- // Retrieve the API URL from the stored settings
173
- static std::string GetApiUrl () {
174
- std::lock_guard<std::mutex> guard (settings_mutex);
175
- return api_url.empty () ? " http://localhost:11434/v1/chat/completions" : api_url;
176
- }
120
+ static void SetApiToken (DataChunk &args, ExpressionState &state, Vector &result) {
121
+ SetConfigValue (args, state, result, " openprompt_api_token" , " API token" );
122
+ }
177
123
178
- // Retrieve the API token from the stored settings
179
- static std::string GetApiToken () {
180
- std::lock_guard<std::mutex> guard (settings_mutex);
181
- return api_token;
182
- }
124
+ static void SetApiUrl (DataChunk &args, ExpressionState &state, Vector &result) {
125
+ SetConfigValue (args, state, result, " openprompt_api_url" , " API URL" );
126
+ }
183
127
184
- // Retrieve the model name from the stored settings
185
- static std::string GetModelName () {
186
- std::lock_guard<std::mutex> guard (settings_mutex);
187
- return model_name.empty () ? " qwen2.5:0.5b" : model_name;
188
- }
128
+ static void SetModelName (DataChunk &args, ExpressionState &state, Vector &result) {
129
+ SetConfigValue (args, state, result, " openprompt_model_name" , " Model name" );
130
+ }
189
131
190
- template <typename a> a assert_null (a val) {
191
- if (val == nullptr ) {
192
- throw std::runtime_error (" Failed to parse the first message content in the API response." );
193
- }
194
- return val;
195
- }
196
- // Open Prompt Function
132
+ // Main Function
197
133
static void OpenPromptRequestFunction (DataChunk &args, ExpressionState &state, Vector &result) {
134
+ D_ASSERT (args.data .size () >= 1 ); // At least prompt required
135
+
198
136
UnaryExecutor::Execute<string_t , string_t >(args.data [0 ], result, args.size (),
199
137
[&](string_t user_prompt) {
200
- auto &conf = ClientConfig::GetConfig (state.GetContext ());
201
- Value api_url;
202
- Value api_token;
203
- Value model_name;
204
- conf.GetUserVariable (" openprompt_api_url" , api_url);
205
- conf.GetUserVariable (" openprompt_api_token" , api_token);
206
- conf.GetUserVariable (" openprompt_model_name" , model_name);
207
-
208
- // Manually construct the JSON body as a string. TODO use json parser from extension.
138
+ auto &context = state.GetContext ();
139
+
140
+ // Get configuration with defaults
141
+ std::string api_url = GetConfigValue (context, " openprompt_api_url" ,
142
+ " http://localhost:11434/v1/chat/completions" );
143
+ std::string api_token = GetConfigValue (context, " openprompt_api_token" , " " );
144
+ std::string model_name = GetConfigValue (context, " openprompt_model_name" , " qwen2.5:0.5b" );
145
+
146
+ // Override model if provided as second argument
147
+ if (args.data .size () > 1 && !args.data [1 ].GetValue (0 ).IsNull ()) {
148
+ model_name = args.data [1 ].GetValue (0 ).ToString ();
149
+ }
150
+
209
151
std::string request_body = " {" ;
210
- request_body += " \" model\" :\" " + model_name. ToString () + " \" ," ;
152
+ request_body += " \" model\" :\" " + model_name + " \" ," ;
211
153
request_body += " \" messages\" :[" ;
212
154
request_body += " {\" role\" :\" system\" ,\" content\" :\" You are a helpful assistant.\" }," ;
213
155
request_body += " {\" role\" :\" user\" ,\" content\" :\" " + user_prompt.GetString () + " \" }" ;
214
156
request_body += " ]}" ;
215
157
216
158
try {
217
- // Make the POST request
218
- auto client_and_path = SetupHttpClient (api_url.ToString ());
159
+ auto client_and_path = SetupHttpClient (api_url);
219
160
auto &client = client_and_path.first ;
220
161
auto &path = client_and_path.second ;
221
162
222
- // Setup headers
223
- duckdb_httplib_openssl::Headers header_map;
224
- header_map.emplace (" Content-Type" , " application/json" );
225
- if (!api_token.ToString ().empty ()) {
226
- header_map.emplace (" Authorization" , " Bearer " + api_token.ToString ());
163
+ duckdb_httplib_openssl::Headers headers;
164
+ headers.emplace (" Content-Type" , " application/json" );
165
+ if (!api_token.empty ()) {
166
+ headers.emplace (" Authorization" , " Bearer " + api_token);
167
+ }
168
+
169
+ auto res = client.Post (path.c_str (), headers, request_body, " application/json" );
170
+
171
+ if (!res) {
172
+ HandleHttpError (res, " POST" );
173
+ }
174
+
175
+ if (res->status != 200 ) {
176
+ throw std::runtime_error (" HTTP error " + std::to_string (res->status ) + " : " + res->reason );
227
177
}
228
178
229
- // Send the request
230
- auto res = client.Post (path.c_str (), header_map, request_body, " application/json" );
231
- if (res && res->status == 200 ) {
232
- // Extract the first choice's message content from the response
233
- std::string response_body = res->body ;
234
- unique_ptr<duckdb_yyjson::yyjson_doc, void (*)(struct duckdb_yyjson ::yyjson_doc *)> doc (
235
- nullptr , &duckdb_yyjson::yyjson_doc_free
236
- );
237
- doc.reset (assert_null (
238
- duckdb_yyjson::yyjson_read (response_body.c_str (), response_body.length (), 0 )
239
- ));
240
- auto root = assert_null (duckdb_yyjson::yyjson_doc_get_root (doc.get ()));
241
- auto choices = assert_null (duckdb_yyjson::yyjson_obj_get (root, " choices" ));
242
- auto choices_0 = assert_null (duckdb_yyjson::yyjson_arr_get_first (choices));
243
- auto message = assert_null (duckdb_yyjson::yyjson_obj_get (choices_0, " message" ));
244
- auto content = assert_null (duckdb_yyjson::yyjson_obj_get (message, " content" ));
245
- auto c_content = assert_null (duckdb_yyjson::yyjson_get_str (content));
246
- return StringVector::AddString (result, c_content);
179
+ try {
180
+ unique_ptr<duckdb_yyjson::yyjson_doc, void (*)(duckdb_yyjson::yyjson_doc *)> doc (
181
+ duckdb_yyjson::yyjson_read (res->body .c_str (), res->body .length (), 0 ),
182
+ &duckdb_yyjson::yyjson_doc_free
183
+ );
184
+
185
+ if (!doc) {
186
+ throw std::runtime_error (" Failed to parse JSON response" );
187
+ }
188
+
189
+ auto root = duckdb_yyjson::yyjson_doc_get_root (doc.get ());
190
+ if (!root) {
191
+ throw std::runtime_error (" Invalid JSON response: no root object" );
192
+ }
193
+
194
+ auto choices = duckdb_yyjson::yyjson_obj_get (root, " choices" );
195
+ if (!choices || !duckdb_yyjson::yyjson_is_arr (choices)) {
196
+ throw std::runtime_error (" Invalid response format: missing choices array" );
197
+ }
198
+
199
+ auto first_choice = duckdb_yyjson::yyjson_arr_get_first (choices);
200
+ if (!first_choice) {
201
+ throw std::runtime_error (" Empty choices array in response" );
202
+ }
203
+
204
+ auto message = duckdb_yyjson::yyjson_obj_get (first_choice, " message" );
205
+ if (!message) {
206
+ throw std::runtime_error (" Missing message in response" );
207
+ }
208
+
209
+ auto content = duckdb_yyjson::yyjson_obj_get (message, " content" );
210
+ if (!content) {
211
+ throw std::runtime_error (" Missing content in response" );
212
+ }
213
+
214
+ auto content_str = duckdb_yyjson::yyjson_get_str (content);
215
+ if (!content_str) {
216
+ throw std::runtime_error (" Invalid content in response" );
217
+ }
218
+
219
+ return StringVector::AddString (result, content_str);
220
+ } catch (std::exception &e) {
221
+ throw std::runtime_error (" Failed to parse response: " + std::string (e.what ()));
247
222
}
248
- throw std::runtime_error (" HTTP POST error: " + std::to_string (res->status ) + " - " + res->reason );
249
223
} catch (std::exception &e) {
250
- // In case of any error, return the original input text to avoid disruption
251
- return StringVector::AddString (result, e.what ());
224
+ // Log error and return error message
225
+ return StringVector::AddString (result, " Error: " + std::string ( e.what () ));
252
226
}
253
227
});
254
228
}
255
229
256
-
230
+ // LoadInternal function
257
231
static void LoadInternal (DatabaseInstance &instance) {
258
- // Register open_prompt function with two arguments: prompt and model
259
232
ScalarFunctionSet open_prompt (" open_prompt" );
233
+
234
+ // Register with both single and two-argument variants
260
235
open_prompt.AddFunction (ScalarFunction (
261
236
{LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
237
+ open_prompt.AddFunction (ScalarFunction (
238
+ {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
239
+
262
240
ExtensionUtil::RegisterFunction (instance, open_prompt);
263
241
264
- // Other set_* functions remain the same as before
242
+ // Register setting functions
265
243
ExtensionUtil::RegisterFunction (instance, ScalarFunction (
266
- " set_api_token" , {LogicalType::VARCHAR}, LogicalType::VARCHAR,
267
- SetApiToken));
268
-
244
+ " set_api_token" , {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiToken));
269
245
ExtensionUtil::RegisterFunction (instance, ScalarFunction (
270
- " set_api_url" , {LogicalType::VARCHAR}, LogicalType::VARCHAR,
271
- SetApiUrl));
272
-
246
+ " set_api_url" , {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl));
273
247
ExtensionUtil::RegisterFunction (instance, ScalarFunction (
274
- " set_model_name" , {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName
275
- ));
248
+ " set_model_name" , {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName));
276
249
}
277
250
278
-
279
251
void OpenPromptExtension::Load (DuckDB &db) {
280
252
LoadInternal (*db.instance );
281
253
}
@@ -292,7 +264,6 @@ std::string OpenPromptExtension::Version() const {
292
264
#endif
293
265
}
294
266
295
-
296
267
} // namespace duckdb
297
268
298
269
extern " C" {
@@ -309,4 +280,3 @@ DUCKDB_EXTENSION_API const char *open_prompt_version() {
309
280
#ifndef DUCKDB_EXTENSION_MAIN
310
281
#error DUCKDB_EXTENSION_MAIN not defined
311
282
#endif
312
-
0 commit comments