Skip to content

Commit 9b830d2

Browse files
committed
hack query planner
1 parent 73d288e commit 9b830d2

File tree

2 files changed

+110
-132
lines changed

2 files changed

+110
-132
lines changed

src/include/wvlet_extension.hpp

+57-6
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
#include "duckdb/function/table_function.hpp"
77
#include "duckdb/main/client_context.hpp"
88

9-
// Declare the external wvlet_compile_query function
109
extern "C" {
11-
int wvlet_compile_main(const char*);
12-
const char* wvlet_compile_query(const char* json_query);
10+
extern int ScalaNativeInit(void);
11+
extern int wvlet_compile_main(const char*);
12+
extern const char* wvlet_compile_query(const char* json_query); // Changed from wvlet_compile_compile
1313
}
1414

1515
namespace duckdb {
@@ -34,11 +34,62 @@ struct WvletScriptFunction {
3434
vector<unique_ptr<Expression>> &arguments);
3535
};
3636

37+
3738
class WvletExtension : public Extension {
3839
public:
39-
void Load(DuckDB &db) override;
40-
std::string Name() override;
41-
std::string Version() const override;
40+
void Load(DuckDB &db) override;
41+
std::string Name() override { return "wvlet"; }
42+
};
43+
44+
BoundStatement wvlet_bind(ClientContext &context, Binder &binder,
45+
OperatorExtensionInfo *info, SQLStatement &statement);
46+
47+
struct WvletOperatorExtension : public OperatorExtension {
48+
WvletOperatorExtension() : OperatorExtension() { Bind = wvlet_bind; }
49+
50+
std::string GetName() override { return "wvlet"; }
51+
52+
unique_ptr<LogicalExtensionOperator>
53+
Deserialize(Deserializer &deserializer) override {
54+
throw InternalException("wvlet operator should not be serialized");
55+
}
56+
};
57+
58+
ParserExtensionParseResult wvlet_parse(ParserExtensionInfo *,
59+
const std::string &query);
60+
61+
ParserExtensionPlanResult wvlet_plan(ParserExtensionInfo *, ClientContext &,
62+
unique_ptr<ParserExtensionParseData>);
63+
64+
struct WvletParserExtension : public ParserExtension {
65+
WvletParserExtension() : ParserExtension() {
66+
parse_function = wvlet_parse;
67+
plan_function = wvlet_plan;
68+
}
69+
};
70+
71+
struct WvletParseData : ParserExtensionParseData {
72+
unique_ptr<SQLStatement> statement;
73+
74+
unique_ptr<ParserExtensionParseData> Copy() const override {
75+
return make_uniq_base<ParserExtensionParseData, WvletParseData>(
76+
statement->Copy());
77+
}
78+
79+
virtual string ToString() const override { return "WvletParseData"; }
80+
81+
WvletParseData(unique_ptr<SQLStatement> statement)
82+
: statement(std::move(statement)) {}
83+
};
84+
85+
class WvletState : public ClientContextState {
86+
public:
87+
explicit WvletState(unique_ptr<ParserExtensionParseData> parse_data)
88+
: parse_data(std::move(parse_data)) {}
89+
90+
void QueryEnd() override { parse_data.reset(); }
91+
92+
unique_ptr<ParserExtensionParseData> parse_data;
4293
};
4394

4495
} // namespace duckdb

src/wvlet_extension.cpp

+53-126
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,31 @@
11
#define DUCKDB_EXTENSION_MAIN
2+
23
#include "wvlet_extension.hpp"
34
#include "duckdb.hpp"
45
#include "duckdb/common/exception.hpp"
6+
#include "duckdb/parser/parser.hpp"
7+
#include "duckdb/parser/statement/extension_statement.hpp"
58
#include "duckdb/common/string_util.hpp"
69
#include "duckdb/function/table_function.hpp"
710
#include "duckdb/main/extension_util.hpp"
811
#include <duckdb/parser/parsed_data/create_table_function_info.hpp>
912
#include <fstream>
1013
#include <sstream>
1114
#include <stdexcept>
15+
#include <codecvt>
16+
#include <string>
1217

1318
#ifdef __cplusplus
1419
extern "C" {
1520
#endif
1621
extern int ScalaNativeInit(void);
17-
1822
extern int wvlet_compile_main(const char*);
19-
extern const char* wvlet_compile_compile(const char*);
20-
23+
extern const char* wvlet_compile_query(const char* json_query);
2124
#ifdef __cplusplus
2225
}
2326
#endif
2427

28+
2529
namespace duckdb {
2630

2731
// EXPERIMENT INIT
@@ -46,146 +50,69 @@ bool InitializeWvletRuntime() {
4650
}
4751
}
4852

49-
void WvletScriptFunction::ParseWvletScript(DataChunk &args, ExpressionState &state, Vector &result) {
50-
auto &input_vector = args.data[0];
51-
auto input = FlatVector::GetData<string_t>(input_vector);
52-
53-
for (idx_t i = 0; i < args.size(); i++) {
54-
string query = input[i].GetString();
55-
std::string json = "[\"-q\", \"" + query + "\"]";
56-
57-
const char* sql_result = wvlet_compile_query(json.c_str());
58-
59-
if (!sql_result || strlen(sql_result) == 0) {
60-
throw std::runtime_error("Failed to compile wvlet script");
61-
}
62-
63-
FlatVector::GetData<string_t>(result)[i] = StringVector::AddString(result, sql_result);
64-
}
65-
66-
result.Verify(args.size());
67-
}
68-
69-
unique_ptr<FunctionData> WvletScriptFunction::Bind(ClientContext &context, ScalarFunction &bound_function,
70-
vector<unique_ptr<Expression>> &arguments) {
71-
return nullptr;
72-
}
73-
74-
static unique_ptr<FunctionData> WvletBind(ClientContext &context, TableFunctionBindInput &input,
75-
vector<LogicalType> &return_types, vector<string> &names) {
76-
auto result = make_uniq<WvletBindData>();
77-
result->query = input.inputs[0].GetValue<string>();
78-
79-
std::string json = "[\"-q\", \"" + result->query + "\"]";
80-
81-
wvlet_compile_main(json.c_str());
82-
const char* sql_result = wvlet_compile_query(json.c_str());
83-
84-
if (!sql_result || strlen(sql_result) == 0) {
85-
throw std::runtime_error("Failed to compile wvlet script");
86-
}
87-
88-
result->query = std::string(sql_result);
89-
90-
// Create a temporary connection to execute the query and get the schema
91-
Connection conn(*context.db);
92-
auto result_set = conn.Query(result->query);
93-
94-
if (result_set->HasError()) {
95-
throw std::runtime_error(result_set->GetError());
96-
}
97-
98-
// Get the types and names of the columns from the result set
99-
for (auto &column : result_set->types) {
100-
return_types.push_back(column);
101-
}
102-
for (auto &name : result_set->names) {
103-
names.push_back(name);
104-
}
105-
106-
return std::move(result);
107-
}
108-
109-
static void WvletFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
110-
auto &bind_data = data_p.bind_data->Cast<WvletBindData>();
111-
112-
if (!bind_data.query_result) {
113-
throw std::runtime_error("query_result is nullptr");
114-
}
115-
116-
if (!bind_data.query_result->initialized) {
117-
118-
try {
119-
Connection conn(*context.db);
120-
121-
auto result = conn.Query(bind_data.query);
122-
123-
if (result->HasError()) {
124-
throw std::runtime_error(result->GetError());
125-
}
126-
127-
bind_data.query_result->result = std::move(result);
128-
bind_data.query_result->initialized = true;
129-
130-
auto &types = bind_data.query_result->result->types;
131-
132-
output.Destroy(); // Clean up the existing chunk
133-
output.Initialize(context, types); // Initialize with actual types
134-
} catch (const std::exception &e) {
135-
throw;
136-
}
137-
}
138-
139-
auto chunk = bind_data.query_result->result->Fetch();
140-
141-
if (!chunk || chunk->size() == 0) {
142-
output.SetCardinality(0);
143-
return;
144-
}
145-
146-
output.Reference(*chunk);
147-
output.SetCardinality(chunk->size());
148-
}
149-
15053
static void LoadInternal(DatabaseInstance &instance) {
151-
auto wvlet_fun = ScalarFunction("wvlet", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
152-
WvletScriptFunction::ParseWvletScript,
153-
WvletScriptFunction::Bind);
154-
ExtensionUtil::RegisterFunction(instance, wvlet_fun);
155-
156-
TableFunction wvlet_func("wvlet", {LogicalType::VARCHAR}, WvletFunction, WvletBind);
157-
ExtensionUtil::RegisterFunction(instance, wvlet_func);
54+
auto &config = DBConfig::GetConfig(instance);
55+
// Register the custom Wvlet parser extension
56+
WvletParserExtension wvlet_parser;
57+
config.parser_extensions.push_back(wvlet_parser);
58+
// No operator extensions added for now
15859
}
15960

16061
void WvletExtension::Load(DuckDB &db) {
161-
LoadInternal(*db.instance);
162-
// EXPERIMENT
163-
if (!InitializeWvletRuntime()) {
62+
LoadInternal(*db.instance);
63+
if (!InitializeWvletRuntime()) {
16464
throw std::runtime_error("Failed to initialize Wvlet runtime");
165-
}
65+
}
16666
}
16767

168-
std::string WvletExtension::Name() {
169-
return "wvlet";
68+
ParserExtensionParseResult wvlet_parse(ParserExtensionInfo *, const std::string &query) {
69+
// Directly pass through the query with no transformation
70+
auto sql_query = query;
71+
72+
std::string json = "[\"-q\", \"" + query + "\"]";
73+
std::cout << "in: " << json << "\n";
74+
wvlet_compile_main(json.c_str());
75+
std::cout << "in2: " << json.c_str() << "\n";
76+
const char* sql_result = wvlet_compile_query(json.c_str());
77+
std::cout << "out: " << sql_result << "\n";
78+
if (!sql_result || strlen(sql_result) == 0) {
79+
throw std::runtime_error("Failed to compile wvlet script");
80+
}
81+
82+
Parser parser; // Parse the SQL query
83+
parser.ParseQuery(sql_query);
84+
auto statements = std::move(parser.statements);
85+
86+
return ParserExtensionParseResult(
87+
make_uniq_base<ParserExtensionParseData, WvletParseData>(
88+
std::move(statements[0])));
17089
}
17190

172-
std::string WvletExtension::Version() const {
173-
#ifdef EXT_VERSION_WVLET
174-
return EXT_VERSION_WVLET;
175-
#else
176-
return "";
177-
#endif
91+
ParserExtensionPlanResult wvlet_plan(ParserExtensionInfo *, ClientContext &context,
92+
unique_ptr<ParserExtensionParseData> parse_data) {
93+
// Placeholder plan result
94+
return ParserExtensionPlanResult();
95+
}
96+
97+
BoundStatement wvlet_bind(ClientContext &context, Binder &binder,
98+
OperatorExtensionInfo *info, SQLStatement &statement) {
99+
// Directly return a no-op bound statement
100+
return {};
178101
}
179102

180103
} // namespace duckdb
181104

182105
extern "C" {
106+
183107
DUCKDB_EXTENSION_API void wvlet_init(duckdb::DatabaseInstance &db) {
184-
duckdb::DuckDB db_wrapper(db);
185-
db_wrapper.LoadExtension<duckdb::WvletExtension>();
108+
LoadInternal(db);
186109
}
187110

188111
DUCKDB_EXTENSION_API const char *wvlet_version() {
189-
return duckdb::DuckDB::LibraryVersion();
112+
return duckdb::DuckDB::LibraryVersion();
190113
}
191114
}
115+
116+
#ifndef DUCKDB_EXTENSION_MAIN
117+
#error DUCKDB_EXTENSION_MAIN not defined
118+
#endif

0 commit comments

Comments
 (0)