Skip to content

Commit 780e38f

Browse files
authored
set_api_timeout (#15)
Optional `set_api_timeout` settings function
1 parent 23a4f5e commit 780e38f

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

src/open_prompt_extension.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ static void SetModelName(DataChunk &args, ExpressionState &state, Vector &result
173173
SetConfigValue(args, state, result, "openprompt_model_name", "Model name");
174174
}
175175

176+
static void SetApiTimeout(DataChunk &args, ExpressionState &state, Vector &result) {
177+
SetConfigValue(args, state, result, "openprompt_api_timeout", "API timeout");
178+
}
179+
176180
// Main Function
177181
static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
178182
D_ASSERT(args.data.size() >= 1); // At least prompt required
@@ -187,6 +191,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
187191
"http://localhost:11434/v1/chat/completions");
188192
std::string api_token = GetConfigValue(context, "openprompt_api_token", "");
189193
std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b");
194+
std::string api_timeout = GetConfigValue(context, "openprompt_api_timeout", "");
190195
std::string json_schema;
191196
std::string system_prompt;
192197

@@ -259,6 +264,10 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
259264
headers.emplace("Authorization", "Bearer " + api_token);
260265
}
261266

267+
if (!api_timeout.empty()){
268+
client.set_read_timeout(stoi(api_timeout), 0);
269+
}
270+
262271
auto res = client.Post(path.c_str(), headers, str_request_body, "application/json");
263272

264273
if (!res) {
@@ -349,6 +358,8 @@ static void LoadInternal(DatabaseInstance &instance) {
349358
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl));
350359
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
351360
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName));
361+
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
362+
"set_api_timeout", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiTimeout));
352363
}
353364

354365
void OpenPromptExtension::Load(DuckDB &db) {

0 commit comments

Comments
 (0)