Skip to content

Commit 3e777af

Browse files
authored
Validate connection in bad state before query execution in the stdlib database/sql driver (#1396)
The current behavior of a library is to invalidate connection if it encounters any of ClickHouse errors. Connection in bad state shouldn't be reused. Stdlib driver attempts to use connection without checking whether the connection is in a good condition.
1 parent 2bd4c82 commit 3e777af

File tree

2 files changed

+137
-2
lines changed

2 files changed

+137
-2
lines changed

clickhouse_std.go

+37-2
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,32 @@ func (std *stdDriver) ResetSession(ctx context.Context) error {
239239

240240
var _ driver.SessionResetter = (*stdDriver)(nil)
241241

242-
func (std *stdDriver) Ping(ctx context.Context) error { return std.conn.ping(ctx) }
242+
func (std *stdDriver) Ping(ctx context.Context) error {
243+
if std.conn.isBad() {
244+
std.debugf("Ping: connection is bad")
245+
return driver.ErrBadConn
246+
}
247+
248+
return std.conn.ping(ctx)
249+
}
243250

244251
var _ driver.Pinger = (*stdDriver)(nil)
245252

246-
func (std *stdDriver) Begin() (driver.Tx, error) { return std, nil }
253+
func (std *stdDriver) Begin() (driver.Tx, error) {
254+
if std.conn.isBad() {
255+
std.debugf("Begin: connection is bad")
256+
return nil, driver.ErrBadConn
257+
}
258+
259+
return std, nil
260+
}
261+
247262
func (std *stdDriver) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
263+
if std.conn.isBad() {
264+
std.debugf("BeginTx: connection is bad")
265+
return nil, driver.ErrBadConn
266+
}
267+
248268
return std, nil
249269
}
250270

@@ -280,6 +300,11 @@ func (std *stdDriver) CheckNamedValue(nv *driver.NamedValue) error { return nil
280300
var _ driver.NamedValueChecker = (*stdDriver)(nil)
281301

282302
func (std *stdDriver) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
303+
if std.conn.isBad() {
304+
std.debugf("ExecContext: connection is bad")
305+
return nil, driver.ErrBadConn
306+
}
307+
283308
var err error
284309
if options := queryOptions(ctx); options.async.ok {
285310
err = std.conn.asyncInsert(ctx, query, options.async.wait, rebind(args)...)
@@ -299,6 +324,11 @@ func (std *stdDriver) ExecContext(ctx context.Context, query string, args []driv
299324
}
300325

301326
func (std *stdDriver) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
327+
if std.conn.isBad() {
328+
std.debugf("QueryContext: connection is bad")
329+
return nil, driver.ErrBadConn
330+
}
331+
302332
r, err := std.conn.query(ctx, func(*connect, error) {}, query, rebind(args)...)
303333
if isConnBrokenError(err) {
304334
std.debugf("QueryContext got a fatal error, resetting connection: %v\n", err)
@@ -319,6 +349,11 @@ func (std *stdDriver) Prepare(query string) (driver.Stmt, error) {
319349
}
320350

321351
func (std *stdDriver) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
352+
if std.conn.isBad() {
353+
std.debugf("PrepareContext: connection is bad")
354+
return nil, driver.ErrBadConn
355+
}
356+
322357
batch, err := std.conn.prepareBatch(ctx, query, ldriver.PrepareBatchOptions{}, func(*connect, error) {}, func(context.Context) (*connect, error) { return nil, nil })
323358
if err != nil {
324359
if isConnBrokenError(err) {

tests/issues/1395_test.go

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// Licensed to ClickHouse, Inc. under one or more contributor
2+
// license agreements. See the NOTICE file distributed with
3+
// this work for additional information regarding copyright
4+
// ownership. ClickHouse, Inc. licenses this file to you under
5+
// the Apache License, Version 2.0 (the "License"); you may
6+
// not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package issues
19+
20+
import (
21+
"context"
22+
"database/sql"
23+
"database/sql/driver"
24+
"testing"
25+
26+
"github.com/ClickHouse/clickhouse-go/v2"
27+
clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests"
28+
"github.com/pkg/errors"
29+
"github.com/stretchr/testify/require"
30+
)
31+
32+
func Test1395(t *testing.T) {
33+
testEnv, err := clickhouse_tests.GetTestEnvironment("issues")
34+
require.NoError(t, err)
35+
opts := clickhouse_tests.ClientOptionsFromEnv(testEnv, clickhouse.Settings{}, false)
36+
conn, err := sql.Open("clickhouse", clickhouse_tests.OptionsToDSN(&opts))
37+
require.NoError(t, err)
38+
39+
ctx := context.Background()
40+
41+
singleConn, err := conn.Conn(ctx)
42+
if err != nil {
43+
t.Fatalf("Get single conn from pool: %v", err)
44+
}
45+
46+
tx1 := func(c *sql.Conn) error {
47+
tx, err := c.BeginTx(ctx, nil)
48+
if err != nil {
49+
return errors.Wrap(err, "begin tx")
50+
}
51+
defer tx.Rollback()
52+
53+
_, err = tx.ExecContext(ctx, `
54+
CREATE TABLE IF NOT EXISTS test_table
55+
ON CLUSTER my
56+
(id UInt32, name String)
57+
ENGINE = MergeTree()
58+
ORDER BY id`)
59+
if err != nil {
60+
return errors.Wrap(err, "create table")
61+
}
62+
63+
err = tx.Commit()
64+
if err != nil {
65+
return errors.Wrap(err, "commit tx")
66+
}
67+
68+
return nil
69+
}
70+
71+
err = tx1(singleConn)
72+
require.Error(t, err, "expected error due to cluster is not configured")
73+
74+
tx2 := func(c *sql.Conn) error {
75+
tx, err := c.BeginTx(ctx, nil)
76+
if err != nil {
77+
return errors.Wrap(err, "begin tx")
78+
}
79+
defer tx.Rollback()
80+
81+
_, err = tx.ExecContext(ctx, "INSERT INTO test_table (id, name) VALUES (?, ?)", 1, "test_name")
82+
if err != nil {
83+
return errors.Wrap(err, "failed to insert record")
84+
}
85+
err = tx.Commit()
86+
if err != nil {
87+
return errors.Wrap(err, "commit tx")
88+
}
89+
90+
return nil
91+
}
92+
require.NotPanics(
93+
t,
94+
func() {
95+
err := tx2(singleConn)
96+
require.ErrorIs(t, err, driver.ErrBadConn)
97+
},
98+
"must not panics",
99+
)
100+
}

0 commit comments

Comments
 (0)