@@ -25,6 +25,7 @@ import (
25
25
"log"
26
26
"net"
27
27
"os"
28
+ "sync"
28
29
"syscall"
29
30
"time"
30
31
@@ -42,6 +43,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
42
43
conn net.Conn
43
44
debugf = func (format string , v ... any ) {}
44
45
)
46
+
45
47
switch {
46
48
case opt .DialContext != nil :
47
49
conn , err = opt .DialContext (ctx , addr )
@@ -53,9 +55,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
53
55
conn , err = net .DialTimeout ("tcp" , addr , opt .DialTimeout )
54
56
}
55
57
}
58
+
56
59
if err != nil {
57
60
return nil , err
58
61
}
62
+
59
63
if opt .Debug {
60
64
if opt .Debugf != nil {
61
65
debugf = func (format string , v ... any ) {
@@ -68,6 +72,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
68
72
debugf = log .New (os .Stdout , fmt .Sprintf ("[clickhouse][conn=%d][%s]" , num , conn .RemoteAddr ()), 0 ).Printf
69
73
}
70
74
}
75
+
71
76
compression := CompressionNone
72
77
if opt .Compression != nil {
73
78
switch opt .Compression .Method {
@@ -96,9 +101,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
96
101
maxCompressionBuffer : opt .MaxCompressionBuffer ,
97
102
}
98
103
)
104
+
99
105
if err := connect .handshake (opt .Auth .Database , opt .Auth .Username , opt .Auth .Password ); err != nil {
100
106
return nil , err
101
107
}
108
+
102
109
if connect .revision >= proto .DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM {
103
110
if err := connect .sendAddendum (); err != nil {
104
111
return nil , err
@@ -109,6 +116,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
109
116
if num == 1 && ! resources .ClientMeta .IsSupportedClickHouseVersion (connect .server .Version ) {
110
117
debugf ("[handshake] WARNING: version %v of ClickHouse is not supported by this client - client supports %v" , connect .server .Version , resources .ClientMeta .SupportedVersions ())
111
118
}
119
+
112
120
return connect , nil
113
121
}
114
122
@@ -131,6 +139,8 @@ type connect struct {
131
139
readTimeout time.Duration
132
140
blockBufferSize uint8
133
141
maxCompressionBuffer int
142
+ readerMutex sync.Mutex
143
+ closeMutex sync.Mutex
134
144
}
135
145
136
146
func (c * connect ) settings (querySettings Settings ) []proto.Setting {
@@ -153,15 +163,16 @@ func (c *connect) settings(querySettings Settings) []proto.Setting {
153
163
for k , v := range c .opt .Settings {
154
164
settings = append (settings , settingToProtoSetting (k , v ))
155
165
}
166
+
156
167
for k , v := range querySettings {
157
168
settings = append (settings , settingToProtoSetting (k , v ))
158
169
}
170
+
159
171
return settings
160
172
}
161
173
162
174
func (c * connect ) isBad () bool {
163
- switch {
164
- case c .closed :
175
+ if c .isClosed () {
165
176
return true
166
177
}
167
178
@@ -172,19 +183,43 @@ func (c *connect) isBad() bool {
172
183
if err := c .connCheck (); err != nil {
173
184
return true
174
185
}
186
+
175
187
return false
176
188
}
177
189
190
+ func (c * connect ) isClosed () bool {
191
+ c .closeMutex .Lock ()
192
+ defer c .closeMutex .Unlock ()
193
+
194
+ return c .closed
195
+ }
196
+
197
+ func (c * connect ) setClosed () {
198
+ c .closeMutex .Lock ()
199
+ defer c .closeMutex .Unlock ()
200
+
201
+ c .closed = true
202
+ }
203
+
178
204
func (c * connect ) close () error {
205
+ c .closeMutex .Lock ()
179
206
if c .closed {
207
+ c .closeMutex .Unlock ()
180
208
return nil
181
209
}
182
210
c .closed = true
183
- c .buffer = nil
184
- c . reader = nil
211
+ c .closeMutex . Unlock ()
212
+
185
213
if err := c .conn .Close (); err != nil {
186
214
return err
187
215
}
216
+
217
+ c .buffer = nil
218
+
219
+ c .readerMutex .Lock ()
220
+ c .reader = nil
221
+ c .readerMutex .Unlock ()
222
+
188
223
return nil
189
224
}
190
225
@@ -193,6 +228,7 @@ func (c *connect) progress() (*Progress, error) {
193
228
if err := progress .Decode (c .reader , c .revision ); err != nil {
194
229
return nil , err
195
230
}
231
+
196
232
c .debugf ("[progress] %s" , & progress )
197
233
return & progress , nil
198
234
}
@@ -202,6 +238,7 @@ func (c *connect) exception() error {
202
238
if err := e .Decode (c .reader ); err != nil {
203
239
return err
204
240
}
241
+
205
242
c .debugf ("[exception] %s" , e .Error ())
206
243
return & e
207
244
}
@@ -218,6 +255,12 @@ func (c *connect) compressBuffer(start int) error {
218
255
}
219
256
220
257
func (c * connect ) sendData (block * proto.Block , name string ) error {
258
+ if c .isClosed () {
259
+ err := errors .New ("attempted sending on closed connection" )
260
+ c .debugf ("[send data] err: %v" , err )
261
+ return err
262
+ }
263
+
221
264
c .debugf ("[send data] compression=%q" , c .compression )
222
265
c .buffer .PutByte (proto .ClientData )
223
266
c .buffer .PutString (name )
@@ -227,6 +270,7 @@ func (c *connect) sendData(block *proto.Block, name string) error {
227
270
if err := block .EncodeHeader (c .buffer , c .revision ); err != nil {
228
271
return err
229
272
}
273
+
230
274
for i := range block .Columns {
231
275
if err := block .EncodeColumn (c .buffer , c .revision , i ); err != nil {
232
276
return err
@@ -242,33 +286,50 @@ func (c *connect) sendData(block *proto.Block, name string) error {
242
286
compressionOffset = 0
243
287
}
244
288
}
289
+
245
290
if err := c .compressBuffer (compressionOffset ); err != nil {
246
291
return err
247
292
}
293
+
248
294
if err := c .flush (); err != nil {
249
295
switch {
250
296
case errors .Is (err , syscall .EPIPE ):
251
297
c .debugf ("[send data] pipe is broken, closing connection" )
252
- c .closed = true
298
+ c .setClosed ()
253
299
case errors .Is (err , io .EOF ):
254
300
c .debugf ("[send data] unexpected EOF, closing connection" )
255
- c .closed = true
301
+ c .setClosed ()
256
302
default :
257
303
c .debugf ("[send data] unexpected error: %v" , err )
258
304
}
259
305
return err
260
306
}
307
+
261
308
defer func () {
262
309
c .buffer .Reset ()
263
310
}()
311
+
264
312
return nil
265
313
}
266
314
267
315
func (c * connect ) readData (ctx context.Context , packet byte , compressible bool ) (* proto.Block , error ) {
316
+ if c .isClosed () {
317
+ err := errors .New ("attempted reading on closed connection" )
318
+ c .debugf ("[read data] err: %v" , err )
319
+ return nil , err
320
+ }
321
+
322
+ if c .reader == nil {
323
+ err := errors .New ("attempted reading on nil reader" )
324
+ c .debugf ("[read data] err: %v" , err )
325
+ return nil , err
326
+ }
327
+
268
328
if _ , err := c .reader .Str (); err != nil {
269
329
c .debugf ("[read data] str error: %v" , err )
270
330
return nil , err
271
331
}
332
+
272
333
if compressible && c .compression != CompressionNone {
273
334
c .reader .EnableCompression ()
274
335
defer c .reader .DisableCompression ()
@@ -285,6 +346,7 @@ func (c *connect) readData(ctx context.Context, packet byte, compressible bool)
285
346
c .debugf ("[read data] decode error: %v" , err )
286
347
return nil , err
287
348
}
349
+
288
350
block .Packet = packet
289
351
c .debugf ("[read data] compression=%q. block: columns=%d, rows=%d" , c .compression , len (block .Columns ), block .Rows ())
290
352
return & block , nil
@@ -295,10 +357,12 @@ func (c *connect) flush() error {
295
357
// Nothing to flush.
296
358
return nil
297
359
}
360
+
298
361
n , err := c .conn .Write (c .buffer .Buf )
299
362
if err != nil {
300
363
return errors .Wrap (err , "write" )
301
364
}
365
+
302
366
if n != len (c .buffer .Buf ) {
303
367
return errors .New ("wrote less than expected" )
304
368
}
0 commit comments