Skip to content

Commit d4eec76

Browse files
authored
Fix w/ Settings
1 parent ee27c9f commit d4eec76

File tree

1 file changed

+119
-149
lines changed

1 file changed

+119
-149
lines changed

src/open_prompt_extension.cpp

+119-149
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,10 @@
1414
#include <sstream>
1515
#include <mutex>
1616
#include <iostream>
17-
#include <yyjson.hpp>
18-
17+
#include "yyjson.hpp"
1918

2019
namespace duckdb {
21-
struct OpenPromptData: FunctionData {
22-
unique_ptr<FunctionData> Copy() const {
23-
throw std::runtime_error("OpenPromptData::Copy");
24-
};
25-
bool Equals(const FunctionData &other) const {
26-
throw std::runtime_error("OpenPromptData::Equals");
27-
};
28-
};
29-
30-
// Helper function to parse URL and setup client
20+
3121
static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient(const std::string &url) {
3222
std::string scheme, domain, path;
3323
size_t pos = url.find("://");
@@ -46,7 +36,6 @@ static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient(co
4636
path = "/";
4737
}
4838

49-
// Create client and set a reasonable timeout (e.g., 10 seconds)
5039
duckdb_httplib_openssl::Client client(domain.c_str());
5140
client.set_read_timeout(10, 0); // 10 seconds
5241
client.set_follow_location(true); // Follow redirects
@@ -98,184 +87,167 @@ static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std
9887
throw std::runtime_error(err_message);
9988
}
10089

101-
102-
// Open Prompt
103-
// Global settings
104-
static std::string api_url = "http://localhost:11434/v1/chat/completions";
105-
static std::string api_token; // Store your API token here
106-
static std::string model_name = "qwen2.5:0.5b"; // Default model
107-
static std::mutex settings_mutex;
108-
109-
// Function to set API token
110-
void SetApiToken(DataChunk &args, ExpressionState &state, Vector &result) {
111-
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
112-
[&](string_t token) {
113-
try {
114-
auto _token = token.GetData();
115-
if (token.Empty()) {
116-
throw std::invalid_argument("API token cannot be empty.");
117-
}
118-
ClientConfig::GetConfig(state.GetContext()).SetUserVariable(
119-
"openprompt_api_token",
120-
Value::CreateValue(token.GetString()));
121-
return StringVector::AddString(result, string("token : ") + string(_token, token.GetSize()));
122-
} catch (std::exception &e) {
123-
string_t res(e.what());
124-
res.Finalize();
125-
return res;
126-
}
127-
});
90+
// Settings management
91+
static std::string GetConfigValue(ClientContext &context, const string &var_name, const string &default_value) {
92+
Value value;
93+
auto &config = ClientConfig::GetConfig(context);
94+
if (!config.GetUserVariable(var_name, value) || value.IsNull()) {
95+
return default_value;
12896
}
97+
return value.ToString();
98+
}
12999

130-
// Function to set API URL
131-
void SetApiUrl(DataChunk &args, ExpressionState &state, Vector &result) {
132-
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
133-
[&](string_t token) {
100+
static void SetConfigValue(DataChunk &args, ExpressionState &state, Vector &result,
101+
const string &var_name, const string &value_type) {
102+
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
103+
[&](string_t value) {
134104
try {
135-
auto _token = token.GetData();
136-
if (token.Empty()) {
137-
throw std::invalid_argument("API token cannot be empty.");
105+
if (value == "" || value.GetSize() == 0) {
106+
throw std::invalid_argument(value_type + " cannot be empty.");
138107
}
108+
139109
ClientConfig::GetConfig(state.GetContext()).SetUserVariable(
140-
"openprompt_api_url",
141-
Value::CreateValue(token.GetString()));
142-
return StringVector::AddString(result, string("url : ") + string(_token, token.GetSize()));
110+
var_name,
111+
Value::CreateValue(value.GetString())
112+
);
113+
return StringVector::AddString(result, value_type + " set to: " + value.GetString());
143114
} catch (std::exception &e) {
144-
string_t res(e.what());
145-
res.Finalize();
146-
return res;
115+
return StringVector::AddString(result, "Failed to set " + value_type + ": " + e.what());
147116
}
148117
});
149-
}
150-
151-
// Function to set model name
152-
void SetModelName(DataChunk &args, ExpressionState &state, Vector &result) {
153-
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
154-
[&](string_t token) {
155-
try {
156-
auto _token = token.GetData();
157-
if (token.Empty()) {
158-
throw std::invalid_argument("API token cannot be empty.");
159-
}
160-
ClientConfig::GetConfig(state.GetContext()).SetUserVariable(
161-
"openprompt_model_name",
162-
Value::CreateValue(token.GetString()));
163-
return StringVector::AddString(result, string("name : ") + string(_token, token.GetSize()));
164-
} catch (std::exception &e) {
165-
string_t res(e.what());
166-
res.Finalize();
167-
return res;
168-
}
169-
});
170-
}
118+
}
171119

172-
// Retrieve the API URL from the stored settings
173-
static std::string GetApiUrl() {
174-
std::lock_guard<std::mutex> guard(settings_mutex);
175-
return api_url.empty() ? "http://localhost:11434/v1/chat/completions" : api_url;
176-
}
120+
static void SetApiToken(DataChunk &args, ExpressionState &state, Vector &result) {
121+
SetConfigValue(args, state, result, "openprompt_api_token", "API token");
122+
}
177123

178-
// Retrieve the API token from the stored settings
179-
static std::string GetApiToken() {
180-
std::lock_guard<std::mutex> guard(settings_mutex);
181-
return api_token;
182-
}
124+
static void SetApiUrl(DataChunk &args, ExpressionState &state, Vector &result) {
125+
SetConfigValue(args, state, result, "openprompt_api_url", "API URL");
126+
}
183127

184-
// Retrieve the model name from the stored settings
185-
static std::string GetModelName() {
186-
std::lock_guard<std::mutex> guard(settings_mutex);
187-
return model_name.empty() ? "qwen2.5:0.5b" : model_name;
188-
}
128+
static void SetModelName(DataChunk &args, ExpressionState &state, Vector &result) {
129+
SetConfigValue(args, state, result, "openprompt_model_name", "Model name");
130+
}
189131

190-
template<typename a> a assert_null(a val) {
191-
if (val == nullptr) {
192-
throw std::runtime_error("Failed to parse the first message content in the API response.");
193-
}
194-
return val;
195-
}
196-
// Open Prompt Function
132+
// Main Function
197133
static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
134+
D_ASSERT(args.data.size() >= 1); // At least prompt required
135+
198136
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
199137
[&](string_t user_prompt) {
200-
auto &conf = ClientConfig::GetConfig(state.GetContext());
201-
Value api_url;
202-
Value api_token;
203-
Value model_name;
204-
conf.GetUserVariable("openprompt_api_url", api_url);
205-
conf.GetUserVariable("openprompt_api_token", api_token);
206-
conf.GetUserVariable("openprompt_model_name", model_name);
207-
208-
// Manually construct the JSON body as a string. TODO use json parser from extension.
138+
auto &context = state.GetContext();
139+
140+
// Get configuration with defaults
141+
std::string api_url = GetConfigValue(context, "openprompt_api_url",
142+
"http://localhost:11434/v1/chat/completions");
143+
std::string api_token = GetConfigValue(context, "openprompt_api_token", "");
144+
std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b");
145+
146+
// Override model if provided as second argument
147+
if (args.data.size() > 1 && !args.data[1].GetValue(0).IsNull()) {
148+
model_name = args.data[1].GetValue(0).ToString();
149+
}
150+
209151
std::string request_body = "{";
210-
request_body += "\"model\":\"" + model_name.ToString() + "\",";
152+
request_body += "\"model\":\"" + model_name + "\",";
211153
request_body += "\"messages\":[";
212154
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
213155
request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}";
214156
request_body += "]}";
215157

216158
try {
217-
// Make the POST request
218-
auto client_and_path = SetupHttpClient(api_url.ToString());
159+
auto client_and_path = SetupHttpClient(api_url);
219160
auto &client = client_and_path.first;
220161
auto &path = client_and_path.second;
221162

222-
// Setup headers
223-
duckdb_httplib_openssl::Headers header_map;
224-
header_map.emplace("Content-Type", "application/json");
225-
if (!api_token.ToString().empty()) {
226-
header_map.emplace("Authorization", "Bearer " + api_token.ToString());
163+
duckdb_httplib_openssl::Headers headers;
164+
headers.emplace("Content-Type", "application/json");
165+
if (!api_token.empty()) {
166+
headers.emplace("Authorization", "Bearer " + api_token);
167+
}
168+
169+
auto res = client.Post(path.c_str(), headers, request_body, "application/json");
170+
171+
if (!res) {
172+
HandleHttpError(res, "POST");
173+
}
174+
175+
if (res->status != 200) {
176+
throw std::runtime_error("HTTP error " + std::to_string(res->status) + ": " + res->reason);
227177
}
228178

229-
// Send the request
230-
auto res = client.Post(path.c_str(), header_map, request_body, "application/json");
231-
if (res && res->status == 200) {
232-
// Extract the first choice's message content from the response
233-
std::string response_body = res->body;
234-
unique_ptr<duckdb_yyjson::yyjson_doc, void(*)(struct duckdb_yyjson::yyjson_doc *)> doc(
235-
nullptr, &duckdb_yyjson::yyjson_doc_free
236-
);
237-
doc.reset(assert_null(
238-
duckdb_yyjson::yyjson_read(response_body.c_str(), response_body.length(), 0)
239-
));
240-
auto root = assert_null(duckdb_yyjson::yyjson_doc_get_root(doc.get()));
241-
auto choices = assert_null(duckdb_yyjson::yyjson_obj_get(root, "choices"));
242-
auto choices_0 = assert_null(duckdb_yyjson::yyjson_arr_get_first(choices));
243-
auto message = assert_null(duckdb_yyjson::yyjson_obj_get(choices_0, "message"));
244-
auto content = assert_null(duckdb_yyjson::yyjson_obj_get(message, "content"));
245-
auto c_content = assert_null(duckdb_yyjson::yyjson_get_str(content));
246-
return StringVector::AddString(result, c_content);
179+
try {
180+
unique_ptr<duckdb_yyjson::yyjson_doc, void(*)(duckdb_yyjson::yyjson_doc *)> doc(
181+
duckdb_yyjson::yyjson_read(res->body.c_str(), res->body.length(), 0),
182+
&duckdb_yyjson::yyjson_doc_free
183+
);
184+
185+
if (!doc) {
186+
throw std::runtime_error("Failed to parse JSON response");
187+
}
188+
189+
auto root = duckdb_yyjson::yyjson_doc_get_root(doc.get());
190+
if (!root) {
191+
throw std::runtime_error("Invalid JSON response: no root object");
192+
}
193+
194+
auto choices = duckdb_yyjson::yyjson_obj_get(root, "choices");
195+
if (!choices || !duckdb_yyjson::yyjson_is_arr(choices)) {
196+
throw std::runtime_error("Invalid response format: missing choices array");
197+
}
198+
199+
auto first_choice = duckdb_yyjson::yyjson_arr_get_first(choices);
200+
if (!first_choice) {
201+
throw std::runtime_error("Empty choices array in response");
202+
}
203+
204+
auto message = duckdb_yyjson::yyjson_obj_get(first_choice, "message");
205+
if (!message) {
206+
throw std::runtime_error("Missing message in response");
207+
}
208+
209+
auto content = duckdb_yyjson::yyjson_obj_get(message, "content");
210+
if (!content) {
211+
throw std::runtime_error("Missing content in response");
212+
}
213+
214+
auto content_str = duckdb_yyjson::yyjson_get_str(content);
215+
if (!content_str) {
216+
throw std::runtime_error("Invalid content in response");
217+
}
218+
219+
return StringVector::AddString(result, content_str);
220+
} catch (std::exception &e) {
221+
throw std::runtime_error("Failed to parse response: " + std::string(e.what()));
247222
}
248-
throw std::runtime_error("HTTP POST error: " + std::to_string(res->status) + " - " + res->reason);
249223
} catch (std::exception &e) {
250-
// In case of any error, return the original input text to avoid disruption
251-
return StringVector::AddString(result, e.what());
224+
// Log error and return error message
225+
return StringVector::AddString(result, "Error: " + std::string(e.what()));
252226
}
253227
});
254228
}
255229

256-
230+
// LoadInternal function
257231
static void LoadInternal(DatabaseInstance &instance) {
258-
// Register open_prompt function with two arguments: prompt and model
259232
ScalarFunctionSet open_prompt("open_prompt");
233+
234+
// Register with both single and two-argument variants
260235
open_prompt.AddFunction(ScalarFunction(
261236
{LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
237+
open_prompt.AddFunction(ScalarFunction(
238+
{LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
239+
262240
ExtensionUtil::RegisterFunction(instance, open_prompt);
263241

264-
// Other set_* functions remain the same as before
242+
// Register setting functions
265243
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
266-
"set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
267-
SetApiToken));
268-
244+
"set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiToken));
269245
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
270-
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
271-
SetApiUrl));
272-
246+
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl));
273247
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
274-
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName
275-
));
248+
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName));
276249
}
277250

278-
279251
void OpenPromptExtension::Load(DuckDB &db) {
280252
LoadInternal(*db.instance);
281253
}
@@ -292,7 +264,6 @@ std::string OpenPromptExtension::Version() const {
292264
#endif
293265
}
294266

295-
296267
} // namespace duckdb
297268

298269
extern "C" {
@@ -309,4 +280,3 @@ DUCKDB_EXTENSION_API const char *open_prompt_version() {
309280
#ifndef DUCKDB_EXTENSION_MAIN
310281
#error DUCKDB_EXTENSION_MAIN not defined
311282
#endif
312-

0 commit comments

Comments
 (0)