Skip to content

Commit 014acb5

Browse files
authored
Add ability to handle context cancellations for TCP protocol (#1389)
* Add ability to handle context cancellations. * Fixed ability to handle context cancellations. * Removed obsolete comments. * Added missing change. * Fixed data race on connection close(). * Synchronisation fix. * Fixed data race on connection close(). * Sync conn.close() calls eparately. Add test. * Add one more test. * Final clean-up. * Close net connection first, to release blocked reader, then all pending ops on that connection will be properly released. * Make new tests pretty. * Clean-up.
1 parent 360020f commit 014acb5

File tree

4 files changed

+391
-17
lines changed

4 files changed

+391
-17
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ _testmain.go
2727

2828
coverage.txt
2929
.idea/**
30+
.vscode/**
3031
dev/*
3132
.run/**
3233

conn.go

+70-6
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"log"
2626
"net"
2727
"os"
28+
"sync"
2829
"syscall"
2930
"time"
3031

@@ -42,6 +43,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
4243
conn net.Conn
4344
debugf = func(format string, v ...any) {}
4445
)
46+
4547
switch {
4648
case opt.DialContext != nil:
4749
conn, err = opt.DialContext(ctx, addr)
@@ -53,9 +55,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
5355
conn, err = net.DialTimeout("tcp", addr, opt.DialTimeout)
5456
}
5557
}
58+
5659
if err != nil {
5760
return nil, err
5861
}
62+
5963
if opt.Debug {
6064
if opt.Debugf != nil {
6165
debugf = func(format string, v ...any) {
@@ -68,6 +72,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
6872
debugf = log.New(os.Stdout, fmt.Sprintf("[clickhouse][conn=%d][%s]", num, conn.RemoteAddr()), 0).Printf
6973
}
7074
}
75+
7176
compression := CompressionNone
7277
if opt.Compression != nil {
7378
switch opt.Compression.Method {
@@ -96,9 +101,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
96101
maxCompressionBuffer: opt.MaxCompressionBuffer,
97102
}
98103
)
104+
99105
if err := connect.handshake(opt.Auth.Database, opt.Auth.Username, opt.Auth.Password); err != nil {
100106
return nil, err
101107
}
108+
102109
if connect.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM {
103110
if err := connect.sendAddendum(); err != nil {
104111
return nil, err
@@ -109,6 +116,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
109116
if num == 1 && !resources.ClientMeta.IsSupportedClickHouseVersion(connect.server.Version) {
110117
debugf("[handshake] WARNING: version %v of ClickHouse is not supported by this client - client supports %v", connect.server.Version, resources.ClientMeta.SupportedVersions())
111118
}
119+
112120
return connect, nil
113121
}
114122

@@ -131,6 +139,8 @@ type connect struct {
131139
readTimeout time.Duration
132140
blockBufferSize uint8
133141
maxCompressionBuffer int
142+
readerMutex sync.Mutex
143+
closeMutex sync.Mutex
134144
}
135145

136146
func (c *connect) settings(querySettings Settings) []proto.Setting {
@@ -153,15 +163,16 @@ func (c *connect) settings(querySettings Settings) []proto.Setting {
153163
for k, v := range c.opt.Settings {
154164
settings = append(settings, settingToProtoSetting(k, v))
155165
}
166+
156167
for k, v := range querySettings {
157168
settings = append(settings, settingToProtoSetting(k, v))
158169
}
170+
159171
return settings
160172
}
161173

162174
func (c *connect) isBad() bool {
163-
switch {
164-
case c.closed:
175+
if c.isClosed() {
165176
return true
166177
}
167178

@@ -172,19 +183,43 @@ func (c *connect) isBad() bool {
172183
if err := c.connCheck(); err != nil {
173184
return true
174185
}
186+
175187
return false
176188
}
177189

190+
func (c *connect) isClosed() bool {
191+
c.closeMutex.Lock()
192+
defer c.closeMutex.Unlock()
193+
194+
return c.closed
195+
}
196+
197+
func (c *connect) setClosed() {
198+
c.closeMutex.Lock()
199+
defer c.closeMutex.Unlock()
200+
201+
c.closed = true
202+
}
203+
178204
func (c *connect) close() error {
205+
c.closeMutex.Lock()
179206
if c.closed {
207+
c.closeMutex.Unlock()
180208
return nil
181209
}
182210
c.closed = true
183-
c.buffer = nil
184-
c.reader = nil
211+
c.closeMutex.Unlock()
212+
185213
if err := c.conn.Close(); err != nil {
186214
return err
187215
}
216+
217+
c.buffer = nil
218+
219+
c.readerMutex.Lock()
220+
c.reader = nil
221+
c.readerMutex.Unlock()
222+
188223
return nil
189224
}
190225

@@ -193,6 +228,7 @@ func (c *connect) progress() (*Progress, error) {
193228
if err := progress.Decode(c.reader, c.revision); err != nil {
194229
return nil, err
195230
}
231+
196232
c.debugf("[progress] %s", &progress)
197233
return &progress, nil
198234
}
@@ -202,6 +238,7 @@ func (c *connect) exception() error {
202238
if err := e.Decode(c.reader); err != nil {
203239
return err
204240
}
241+
205242
c.debugf("[exception] %s", e.Error())
206243
return &e
207244
}
@@ -218,6 +255,12 @@ func (c *connect) compressBuffer(start int) error {
218255
}
219256

220257
func (c *connect) sendData(block *proto.Block, name string) error {
258+
if c.isClosed() {
259+
err := errors.New("attempted sending on closed connection")
260+
c.debugf("[send data] err: %v", err)
261+
return err
262+
}
263+
221264
c.debugf("[send data] compression=%q", c.compression)
222265
c.buffer.PutByte(proto.ClientData)
223266
c.buffer.PutString(name)
@@ -227,6 +270,7 @@ func (c *connect) sendData(block *proto.Block, name string) error {
227270
if err := block.EncodeHeader(c.buffer, c.revision); err != nil {
228271
return err
229272
}
273+
230274
for i := range block.Columns {
231275
if err := block.EncodeColumn(c.buffer, c.revision, i); err != nil {
232276
return err
@@ -242,33 +286,50 @@ func (c *connect) sendData(block *proto.Block, name string) error {
242286
compressionOffset = 0
243287
}
244288
}
289+
245290
if err := c.compressBuffer(compressionOffset); err != nil {
246291
return err
247292
}
293+
248294
if err := c.flush(); err != nil {
249295
switch {
250296
case errors.Is(err, syscall.EPIPE):
251297
c.debugf("[send data] pipe is broken, closing connection")
252-
c.closed = true
298+
c.setClosed()
253299
case errors.Is(err, io.EOF):
254300
c.debugf("[send data] unexpected EOF, closing connection")
255-
c.closed = true
301+
c.setClosed()
256302
default:
257303
c.debugf("[send data] unexpected error: %v", err)
258304
}
259305
return err
260306
}
307+
261308
defer func() {
262309
c.buffer.Reset()
263310
}()
311+
264312
return nil
265313
}
266314

267315
func (c *connect) readData(ctx context.Context, packet byte, compressible bool) (*proto.Block, error) {
316+
if c.isClosed() {
317+
err := errors.New("attempted reading on closed connection")
318+
c.debugf("[read data] err: %v", err)
319+
return nil, err
320+
}
321+
322+
if c.reader == nil {
323+
err := errors.New("attempted reading on nil reader")
324+
c.debugf("[read data] err: %v", err)
325+
return nil, err
326+
}
327+
268328
if _, err := c.reader.Str(); err != nil {
269329
c.debugf("[read data] str error: %v", err)
270330
return nil, err
271331
}
332+
272333
if compressible && c.compression != CompressionNone {
273334
c.reader.EnableCompression()
274335
defer c.reader.DisableCompression()
@@ -285,6 +346,7 @@ func (c *connect) readData(ctx context.Context, packet byte, compressible bool)
285346
c.debugf("[read data] decode error: %v", err)
286347
return nil, err
287348
}
349+
288350
block.Packet = packet
289351
c.debugf("[read data] compression=%q. block: columns=%d, rows=%d", c.compression, len(block.Columns), block.Rows())
290352
return &block, nil
@@ -295,10 +357,12 @@ func (c *connect) flush() error {
295357
// Nothing to flush.
296358
return nil
297359
}
360+
298361
n, err := c.conn.Write(c.buffer.Buf)
299362
if err != nil {
300363
return errors.Wrap(err, "write")
301364
}
365+
302366
if n != len(c.buffer.Buf) {
303367
return errors.New("wrote less than expected")
304368
}

0 commit comments

Comments
 (0)