Skip to content

Commit c0d634a

Browse files
authored
SECRET support (#20)
* secret manager * ENV support * cast unique_ptr to the correct type * cast unique_ptr to the correct type * resync * add tests * fix env, secrets handling * Update README.md
1 parent ca96d15 commit c0d634a

8 files changed

+212
-15
lines changed

CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ set(LOADABLE_EXTENSION_NAME ${TARGET_NAME}_loadable_extension)
1414
project(${TARGET_NAME})
1515
include_directories(src/include duckdb/third_party/httplib)
1616

17-
set(EXTENSION_SOURCES src/open_prompt_extension.cpp)
17+
set(EXTENSION_SOURCES src/open_prompt_extension.cpp src/open_prompt_secret.cpp)
1818

1919
if(MINGW)
2020
set(OPENSSL_USE_STATIC_LIBS TRUE)

docs/README.md

+20
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,29 @@ Setup the completions API configuration w/ optional auth token and model name
2828
SET VARIABLE openprompt_api_url = 'http://localhost:11434/v1/chat/completions';
2929
SET VARIABLE openprompt_api_token = 'your_api_key_here';
3030
SET VARIABLE openprompt_model_name = 'qwen2.5:0.5b';
31+
```
32+
33+
Alternatively the following ENV variables can be used at runtime
34+
```
35+
OPEN_PROMPT_API_URL='http://localhost:11434/v1/chat/completions'
36+
OPEN_PROMPT_API_TOKEN='your_api_key_here'
37+
OPEN_PROMPT_MODEL_NAME='qwen2.5:0.5b'
38+
OPEN_PROMPT_API_TIMEOUT='30'
39+
```
3140

41+
For persistent usage, configure parameters using DuckDB SECRETS
42+
```sql
43+
CREATE SECRET IF NOT EXISTS open_prompt (
44+
TYPE open_prompt,
45+
PROVIDER config,
46+
api_token 'your-api-token',
47+
api_url 'http://localhost:11434/v1/chat/completions',
48+
model_name 'qwen2.5:0.5b',
49+
api_timeout '30'
50+
);
3251
```
3352

53+
3454
### Usage
3555
```sql
3656
D SELECT open_prompt('Write a one-line poem about ducks') AS response;

duckdb

Submodule duckdb updated 3121 files

src/include/open_prompt_secret.hpp

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#pragma once
2+
3+
#include "duckdb/main/secret/secret.hpp"
4+
#include "duckdb/main/extension_util.hpp"
5+
6+
namespace duckdb {
7+
8+
struct CreateOpenPromptSecretFunctions {
9+
public:
10+
static void Register(DatabaseInstance &instance);
11+
};
12+
13+
} // namespace duckdb

src/open_prompt_extension.cpp

+85-12
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,22 @@
77
#include "duckdb/common/exception/http_exception.hpp"
88
#include <duckdb/parser/parsed_data/create_scalar_function_info.hpp>
99

10+
#include "duckdb/main/secret/secret_manager.hpp"
11+
#include "duckdb/main/secret/secret.hpp"
12+
#include "duckdb/main/secret/secret_storage.hpp"
13+
14+
#include "open_prompt_secret.hpp"
15+
1016
#ifdef USE_ZLIB
1117
#define CPPHTTPLIB_ZLIB_SUPPORT
1218
#endif
1319

1420
#define CPPHTTPLIB_OPENSSL_SUPPORT
1521
#include "httplib.hpp"
1622

23+
#include <cstdlib>
24+
#include <algorithm>
25+
#include <cctype>
1726
#include <string>
1827
#include <sstream>
1928
#include <mutex>
@@ -29,13 +38,13 @@ namespace duckdb {
2938
idx_t model_idx;
3039
idx_t json_schema_idx;
3140
idx_t json_system_prompt_idx;
32-
unique_ptr<FunctionData> Copy() const {
33-
auto res = make_uniq<OpenPromptData>();
34-
res->model_idx = model_idx;
35-
res->json_schema_idx = json_schema_idx;
36-
res->json_system_prompt_idx = json_system_prompt_idx;
37-
return res;
38-
};
41+
unique_ptr<FunctionData> Copy() const override {
42+
auto res = make_uniq<OpenPromptData>();
43+
res->model_idx = model_idx;
44+
res->json_schema_idx = json_schema_idx;
45+
res->json_system_prompt_idx = json_system_prompt_idx;
46+
return unique_ptr<FunctionData>(std::move(res));
47+
};
3948
bool Equals(const FunctionData &other) const {
4049
return model_idx == other.Cast<OpenPromptData>().model_idx &&
4150
json_schema_idx == other.Cast<OpenPromptData>().json_schema_idx &&
@@ -142,14 +151,75 @@ namespace duckdb {
142151

143152
// Settings management
144153
static std::string GetConfigValue(ClientContext &context, const string &var_name, const string &default_value) {
145-
Value value;
146-
auto &config = ClientConfig::GetConfig(context);
147-
if (!config.GetUserVariable(var_name, value) || value.IsNull()) {
148-
return default_value;
154+
// Try environment variables
155+
{
156+
// Create uppercase ENV version: OPEN_PROMPT_SETTING
157+
std::string stripped_name = var_name;
158+
const std::string prefix = "openprompt_";
159+
if (stripped_name.substr(0, prefix.length()) == prefix) {
160+
stripped_name = stripped_name.substr(prefix.length());
161+
}
162+
std::string env_var_name = "OPEN_PROMPT_" + stripped_name;
163+
std::transform(env_var_name.begin(), env_var_name.end(), env_var_name.begin(), ::toupper);
164+
// std::cout << "SEARCH ENV FOR " << env_var_name << "\n";
165+
166+
const char* env_value = std::getenv(env_var_name.c_str());
167+
if (env_value != nullptr && strlen(env_value) > 0) {
168+
// std::cout << "USING ENV FOR " << var_name << "\n";
169+
std::string result(env_value);
170+
return result;
171+
}
172+
}
173+
174+
// Try to get from secrets
175+
{
176+
// Create lowercase secret version: open_prompt_setting
177+
std::string secret_key = var_name;
178+
const std::string prefix = "openprompt_";
179+
if (secret_key.substr(0, prefix.length()) == prefix) {
180+
secret_key = secret_key.substr(prefix.length());
181+
}
182+
// secret_key = "open_prompt_" + secret_key;
183+
std::transform(secret_key.begin(), secret_key.end(), secret_key.begin(), ::tolower);
184+
185+
auto &secret_manager = SecretManager::Get(context);
186+
try {
187+
// std::cout << "SEARCH SECRET FOR " << secret_key << "\n";
188+
auto transaction = CatalogTransaction::GetSystemCatalogTransaction(context);
189+
auto secret_match = secret_manager.LookupSecret(transaction, "open_prompt", "open_prompt");
190+
if (secret_match.HasMatch()) {
191+
auto &secret = secret_match.GetSecret();
192+
if (secret.GetType() != "open_prompt") {
193+
throw InvalidInputException("Invalid secret type. Expected 'open_prompt', got '%s'", secret.GetType());
194+
}
195+
const auto *kv_secret = dynamic_cast<const KeyValueSecret*>(&secret);
196+
if (!kv_secret) {
197+
throw InvalidInputException("Invalid secret format for 'open_prompt' secret");
198+
}
199+
Value secret_value;
200+
if (kv_secret->TryGetValue(secret_key, secret_value)) {
201+
// std::cout << "USING SECRET FOR " << var_name << "\n";
202+
return secret_value.ToString();
203+
}
204+
}
205+
} catch (...) {
206+
// If secret lookup fails, fall back to user variables
149207
}
150-
return value.ToString();
151208
}
152209

210+
// Fall back to user variables if secret not found (using original var_name)
211+
Value value;
212+
auto &config = ClientConfig::GetConfig(context);
213+
if (!config.GetUserVariable(var_name, value) || value.IsNull()) {
214+
// std::cout << "USING SET FOR " << var_name << "\n";
215+
return default_value;
216+
}
217+
218+
// std::cout << "USING DEFAULT FOR " << var_name << "\n";
219+
return value.ToString();
220+
}
221+
222+
153223
static void SetConfigValue(DataChunk &args, ExpressionState &state, Vector &result,
154224
const string &var_name, const string &value_type) {
155225
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
@@ -356,6 +426,9 @@ namespace duckdb {
356426
LogicalType::VARCHAR, OpenPromptRequestFunction,
357427
OpenPromptBind));
358428

429+
// Register Secret functions
430+
CreateOpenPromptSecretFunctions::Register(instance);
431+
359432
ExtensionUtil::RegisterFunction(instance, open_prompt);
360433

361434
ExtensionUtil::RegisterFunction(instance, ScalarFunction(

src/open_prompt_secret.cpp

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#include "open_prompt_secret.hpp"
2+
#include "duckdb/common/exception.hpp"
3+
#include "duckdb/main/secret/secret.hpp"
4+
#include "duckdb/main/extension_util.hpp"
5+
6+
namespace duckdb {
7+
8+
static void CopySecret(const std::string &key, const CreateSecretInput &input, KeyValueSecret &result) {
9+
auto val = input.options.find(key);
10+
if (val != input.options.end()) {
11+
result.secret_map[key] = val->second;
12+
}
13+
}
14+
15+
static void RegisterCommonSecretParameters(CreateSecretFunction &function) {
16+
// Register open_prompt common parameters
17+
function.named_parameters["api_token"] = LogicalType::VARCHAR;
18+
function.named_parameters["api_url"] = LogicalType::VARCHAR;
19+
function.named_parameters["model_name"] = LogicalType::VARCHAR;
20+
function.named_parameters["api_timeout"] = LogicalType::VARCHAR;
21+
}
22+
23+
static void RedactCommonKeys(KeyValueSecret &result) {
24+
// Redact sensitive information
25+
result.redact_keys.insert("api_token");
26+
}
27+
28+
static unique_ptr<BaseSecret> CreateOpenPromptSecretFromConfig(ClientContext &context, CreateSecretInput &input) {
29+
auto scope = input.scope;
30+
auto result = make_uniq<KeyValueSecret>(scope, input.type, input.provider, input.name);
31+
32+
// Copy all relevant secrets
33+
CopySecret("api_token", input, *result);
34+
CopySecret("api_url", input, *result);
35+
CopySecret("model_name", input, *result);
36+
CopySecret("api_timeout", input, *result);
37+
38+
// Redact sensitive keys
39+
RedactCommonKeys(*result);
40+
41+
return std::move(result);
42+
}
43+
44+
void CreateOpenPromptSecretFunctions::Register(DatabaseInstance &instance) {
45+
string type = "open_prompt";
46+
47+
// Register the new type
48+
SecretType secret_type;
49+
secret_type.name = type;
50+
secret_type.deserializer = KeyValueSecret::Deserialize<KeyValueSecret>;
51+
secret_type.default_provider = "config";
52+
ExtensionUtil::RegisterSecretType(instance, secret_type);
53+
54+
// Register the config secret provider
55+
CreateSecretFunction config_function = {type, "config", CreateOpenPromptSecretFromConfig};
56+
RegisterCommonSecretParameters(config_function);
57+
ExtensionUtil::RegisterFunction(instance, config_function);
58+
}
59+
60+
} // namespace duckdb

test/sql/open_prompt.test

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# name: test/sql/rusty_quack.test
2+
# description: test rusty_quack extension
3+
# group: [quack]
4+
5+
# Before we load the extension, this will fail
6+
statement error
7+
SELECT open_prompt('error');
8+
----
9+
Catalog Error: Scalar Function with name open_prompt does not exist!
10+
11+
# Require statement will ensure the extension is loaded from now on
12+
require open_prompt
13+
14+
# Confirm the extension works by setting a secret
15+
query I
16+
CREATE SECRET IF NOT EXISTS open_prompt (
17+
TYPE open_prompt,
18+
PROVIDER config,
19+
api_token 'xxxxx',
20+
api_url 'https://api.groq.com/openai/v1/chat/completions',
21+
model_name 'llama-3.3-70b-versatile',
22+
api_timeout '30'
23+
);
24+
----
25+
true
26+
27+
# Confirm the secret exists
28+
query I
29+
SELECT name FROM duckdb_secrets() WHERE name = 'open_prompt' ;
30+
----
31+
open_prompt

0 commit comments

Comments
 (0)