Skip to content

Commit f948af5

Browse files
authored
🐛 Fix: use context for multi statement queries (#91)
1 parent b9f6381 commit f948af5

File tree

2 files changed

+32
-19
lines changed

2 files changed

+32
-19
lines changed

pkg/plugin/helper.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package plugin
22

33
import (
4+
"context"
45
"database/sql"
56
"encoding/json"
67
"fmt"
@@ -25,7 +26,7 @@ type defaultsResponseBody struct {
2526
DefaultSchema string `json:"defaultSchema"`
2627
}
2728

28-
func autocompletionQueries(req *backend.CallResourceRequest, sender backend.CallResourceResponseSender, d *Datasource) error {
29+
func autocompletionQueries(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender, d *Datasource) error {
2930
path := req.Path
3031
log.DefaultLogger.Info("CallResource called", "path", path)
3132
var body schemaRequestBody
@@ -36,7 +37,7 @@ func autocompletionQueries(req *backend.CallResourceRequest, sender backend.Call
3637
}
3738
switch path {
3839
case "catalogs":
39-
rows, err := d.ExecuteQuery("SHOW CATALOGS")
40+
rows, err := d.QueryContext(ctx, "SHOW CATALOGS")
4041
if err != nil {
4142
log.DefaultLogger.Error("CallResource Error", "err", err)
4243
return err
@@ -74,7 +75,7 @@ func autocompletionQueries(req *backend.CallResourceRequest, sender backend.Call
7475
queryString = fmt.Sprintf("SHOW SCHEMAS IN %s", body.Catalog)
7576
}
7677
log.DefaultLogger.Info("CallResource called", "queryString", queryString)
77-
rows, err := d.ExecuteQuery(queryString)
78+
rows, err := d.QueryContext(ctx, queryString)
7879
if err != nil {
7980
log.DefaultLogger.Error("CallResource Error", "err", err)
8081
return err
@@ -114,7 +115,7 @@ func autocompletionQueries(req *backend.CallResourceRequest, sender backend.Call
114115
}
115116
}
116117
log.DefaultLogger.Info("CallResource called", "queryString", queryString)
117-
rows, err := d.ExecuteQuery(queryString)
118+
rows, err := d.QueryContext(ctx, queryString)
118119
if err != nil {
119120
log.DefaultLogger.Error("CallResource Error", "err", err)
120121
return err
@@ -150,7 +151,7 @@ func autocompletionQueries(req *backend.CallResourceRequest, sender backend.Call
150151
case "columns":
151152
queryString := fmt.Sprintf("DESCRIBE TABLE %s", body.Table)
152153
log.DefaultLogger.Info("CallResource called", "queryString", queryString)
153-
rows, err := d.ExecuteQuery(queryString)
154+
rows, err := d.QueryContext(ctx, queryString)
154155
if err != nil {
155156
log.DefaultLogger.Error("CallResource Error", "err", err)
156157
return err
@@ -190,7 +191,7 @@ func autocompletionQueries(req *backend.CallResourceRequest, sender backend.Call
190191
case "defaults":
191192
queryString := "SELECT current_catalog(), current_schema();"
192193
log.DefaultLogger.Info("CallResource called", "queryString", queryString)
193-
rows, err := d.ExecuteQuery(queryString)
194+
rows, err := d.QueryContext(ctx, queryString)
194195
if err != nil {
195196
log.DefaultLogger.Error("CallResource Error", "err", err)
196197
return err

pkg/plugin/plugin.go

+25-13
Original file line numberDiff line numberDiff line change
@@ -159,23 +159,35 @@ func (d *Datasource) RefreshDBConnection() error {
159159
return nil
160160
}
161161

162-
func (d *Datasource) ExecuteQuery(queryString string) (*sql.Rows, error) {
163-
rows, err := d.databricksDB.Query(queryString)
162+
// ExecContext is a helper function to execute a query on the Databricks SQL DB without returning any rows and handling session expiration
163+
func (d *Datasource) ExecContext(ctx context.Context, queryString string) error {
164+
_, err := d.databricksDB.ExecContext(ctx, queryString)
164165
if err != nil {
165166
if strings.Contains(err.Error(), "Invalid SessionHandle") {
166167
err = d.RefreshDBConnection()
167168
if err != nil {
168-
return nil, err
169+
return err
169170
}
170-
rows, err = d.ExecuteQuery(queryString)
171+
return d.ExecContext(ctx, queryString)
172+
}
173+
return err
174+
}
175+
return nil
176+
}
177+
178+
// QueryContext is a helper function to query the Databricks SQL DB returning the rows and handling session expiration
179+
func (d *Datasource) QueryContext(ctx context.Context, queryString string) (*sql.Rows, error) {
180+
rows, err := d.databricksDB.QueryContext(ctx, queryString)
181+
if err != nil {
182+
if strings.Contains(err.Error(), "Invalid SessionHandle") {
183+
err = d.RefreshDBConnection()
171184
if err != nil {
172185
return nil, err
173186
}
174-
} else {
175-
return nil, err
187+
return d.QueryContext(ctx, queryString)
176188
}
189+
return nil, err
177190
}
178-
179191
return rows, nil
180192
}
181193

@@ -187,7 +199,7 @@ type Datasource struct {
187199
}
188200

189201
func (d *Datasource) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
190-
return autocompletionQueries(req, sender, d)
202+
return autocompletionQueries(ctx, req, sender, d)
191203
}
192204

193205
// Dispose here tells plugin SDK that plugin wants to clean up resources when a new instance
@@ -230,7 +242,7 @@ type queryModel struct {
230242
QuerySettings querySettings `json:"querySettings"`
231243
}
232244

233-
func (d *Datasource) query(_ context.Context, pCtx backend.PluginContext, query backend.DataQuery) backend.DataResponse {
245+
func (d *Datasource) query(ctx context.Context, pCtx backend.PluginContext, query backend.DataQuery) backend.DataResponse {
234246
response := backend.DataResponse{}
235247

236248
// Unmarshal the JSON into our queryModel.
@@ -260,7 +272,7 @@ func (d *Datasource) query(_ context.Context, pCtx backend.PluginContext, query
260272
if len(queries) > 1 {
261273
// Execute all but the last statement without returning any data
262274
for _, query := range queries[:len(queries)-1] {
263-
_, err := d.ExecuteQuery(query)
275+
err := d.ExecContext(ctx, query)
264276
if err != nil {
265277
response.Error = err
266278
log.DefaultLogger.Info("Error", "err", err)
@@ -276,7 +288,7 @@ func (d *Datasource) query(_ context.Context, pCtx backend.PluginContext, query
276288

277289
frame := data.NewFrame("response")
278290

279-
rows, err := d.ExecuteQuery(queryString)
291+
rows, err := d.QueryContext(ctx, queryString)
280292
if err != nil {
281293
response.Error = err
282294
log.DefaultLogger.Info("Error", "err", err)
@@ -332,10 +344,10 @@ func (d *Datasource) query(_ context.Context, pCtx backend.PluginContext, query
332344
// The main use case for these health checks is the test button on the
333345
// datasource configuration page which allows users to verify that
334346
// a datasource is working as expected.
335-
func (d *Datasource) CheckHealth(_ context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
347+
func (d *Datasource) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
336348
log.DefaultLogger.Info("CheckHealth called", "request", req)
337349

338-
rows, err := d.ExecuteQuery("SELECT 1")
350+
rows, err := d.QueryContext(ctx, "SELECT 1")
339351

340352
if err != nil {
341353
return &backend.CheckHealthResult{

0 commit comments

Comments
 (0)