Commit 2a9d367

Dynamic Reconnect and Multi-Channel Support (#19)
1 parent 862e101 commit 2a9d367

1 file changed: subscriber.go (+99 -32 lines)

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io/ioutil"
 	"log"
+	"math/rand"
 	"os"
 	"os/signal"
 	"runtime/pprof"
@@ -47,32 +48,71 @@ type testResult struct {
 	Addresses []string `json:"Addresses"`
 }
 
-func subscriberRoutine(mode, channel string, printMessages bool, ctx context.Context, wg *sync.WaitGroup, client *redis.Client) {
-	// tell the caller we've stopped
+func subscriberRoutine(mode string, channels []string, printMessages bool, connectionReconnectInterval int, ctx context.Context, wg *sync.WaitGroup, client *redis.Client) {
+	// Tell the caller we've stopped
 	defer wg.Done()
-	switch mode {
-	case "ssubscribe":
-		spubsub := client.SSubscribe(ctx, channel)
-		defer spubsub.Close()
-		for {
-			msg, err := spubsub.ReceiveMessage(ctx)
-			if err != nil {
-				panic(err)
-			}
-			if printMessages {
-				fmt.Println(fmt.Sprintf("received message in channel %s. Message: %s", msg.Channel, msg.Payload))
+	var reconnectTicker *time.Ticker
+	if connectionReconnectInterval > 0 {
+		reconnectTicker = time.NewTicker(time.Duration(connectionReconnectInterval) * time.Second)
+		defer reconnectTicker.Stop()
+	} else {
+		reconnectTicker = time.NewTicker(1 * time.Second)
+		reconnectTicker.Stop()
+	}
+
+	var pubsub *redis.PubSub
+
+	// Helper function to handle subscription based on mode
+	subscribe := func() {
+		if pubsub != nil {
+			// Unsubscribe based on mode before re-subscribing
+			if mode == "ssubscribe" {
+				if err := pubsub.SUnsubscribe(ctx, channels...); err != nil {
+					fmt.Printf("Error during SUnsubscribe: %v\n", err)
+				}
+			} else {
+				if err := pubsub.Unsubscribe(ctx, channels...); err != nil {
+					fmt.Printf("Error during Unsubscribe: %v\n", err)
+				}
 			}
-			atomic.AddUint64(&totalMessages, 1)
+			pubsub.Close()
+		}
+		switch mode {
+		case "ssubscribe":
+			pubsub = client.SSubscribe(ctx, channels...)
+		default:
+			pubsub = client.Subscribe(ctx, channels...)
 		}
-		break
-	case "subscribe":
-		fallthrough
-	default:
-		pubsub := client.Subscribe(ctx, channel)
-		defer pubsub.Close()
-		for {
+	}
+
+	subscribe()
+
+	for {
+		select {
+		case <-ctx.Done():
+			// Context cancelled, exit routine
+			if pubsub != nil {
+				if mode == "ssubscribe" {
+					_ = pubsub.SUnsubscribe(ctx, channels...)
+				} else {
+					_ = pubsub.Unsubscribe(ctx, channels...)
+				}
+				pubsub.Close()
+			}
+			return
+		case <-reconnectTicker.C:
+			// Reconnect interval triggered, unsubscribe and resubscribe
+			if reconnectTicker != nil {
+				subscribe()
+			}
+		default:
+			// Handle messages
 			msg, err := pubsub.ReceiveMessage(ctx)
 			if err != nil {
+				// Handle Redis connection errors, e.g., reconnect immediately
+				if err == redis.Nil || err == context.DeadlineExceeded || err == context.Canceled {
+					continue
+				}
 				panic(err)
 			}
 			if printMessages {
@@ -81,7 +121,6 @@ func subscriberRoutine(mode, channel string, printMessages bool, ctx context.Con
 			atomic.AddUint64(&totalMessages, 1)
 		}
 	}
-
 }
 
 func main() {
@@ -95,6 +134,10 @@ func main() {
 	channel_minimum := flag.Int("channel-minimum", 1, "channel ID minimum value ( each channel has a dedicated thread ).")
 	channel_maximum := flag.Int("channel-maximum", 100, "channel ID maximum value ( each channel has a dedicated thread ).")
 	subscribers_per_channel := flag.Int("subscribers-per-channel", 1, "number of subscribers per channel.")
+	min_channels_per_subscriber := flag.Int("min-number-channels-per-subscriber", 1, "min number of channels to subscribe to, per connection.")
+	max_channels_per_subscriber := flag.Int("max-number-channels-per-subscriber", 1, "max number of channels to subscribe to, per connection.")
+	min_reconnect_interval := flag.Int("min-reconnect-interval", 0, "min reconnect interval. if 0 disable (s)unsubscribe/(s)ubscribe.")
+	max_reconnect_interval := flag.Int("max-reconnect-interval", 0, "max reconnect interval. if 0 disable (s)unsubscribe/(s)ubscribe.")
 	messages_per_channel_subscriber := flag.Int64("messages", 0, "Number of total messages per subscriber per channel.")
 	json_out_file := flag.String("json-out-file", "", "Name of json output file, if not set, will not print to json.")
 	client_update_tick := flag.Int("client-update-tick", 1, "client update tick.")
@@ -191,16 +234,19 @@ func main() {
 		poolSize = subscriptions_per_node
 		log.Println(fmt.Sprintf("Setting per Node pool size of %d given you haven't specified a value and we have %d Subscriptions per node. You can control this option via --%s=<value>", poolSize, subscriptions_per_node, redisPoolSize))
 		clusterOptions.PoolSize = poolSize
-		log.Println("Reloading cluster state given we've changed pool size.")
-		clusterClient = redis.NewClusterClient(&clusterOptions)
-		// ReloadState reloads cluster state. It calls ClusterSlots func
-		// to get cluster slots information.
-		clusterClient.ReloadState(ctx)
-		err := clusterClient.Ping(ctx).Err()
-		if err != nil {
-			log.Fatal(err)
+		if *distributeSubscribers {
+			log.Println("Reloading cluster state given we've changed pool size.")
+			clusterClient = redis.NewClusterClient(&clusterOptions)
+			// ReloadState reloads cluster state. It calls ClusterSlots func
+			// to get cluster slots information.
+			clusterClient.ReloadState(ctx)
+			err := clusterClient.Ping(ctx).Err()
+			if err != nil {
+				log.Fatal(err)
+			}
+			nodeCount, nodeClients, nodesAddresses = updateSecondarySlicesCluster(clusterClient, ctx)
 		}
-		nodeCount, nodeClients, nodesAddresses = updateSecondarySlicesCluster(clusterClient, ctx)
+
 	}
 
 	log.Println(fmt.Sprintf("Detailing final setup used for benchmark."))
@@ -241,6 +287,18 @@ func main() {
 	for channel_id := *channel_minimum; channel_id <= *channel_maximum; channel_id++ {
 		channel := fmt.Sprintf("%s%d", *subscribe_prefix, channel_id)
 		for channel_subscriber_number := 1; channel_subscriber_number <= *subscribers_per_channel; channel_subscriber_number++ {
+			channels := []string{channel}
+			n_channels_this_conn := 1
+			if *max_channels_per_subscriber == *min_channels_per_subscriber {
+				n_channels_this_conn = *max_channels_per_subscriber
+			} else {
+				n_channels_this_conn = rand.Intn(*max_channels_per_subscriber - *min_channels_per_subscriber)
+			}
+			for channel_this_conn := 1; channel_this_conn < n_channels_this_conn; channel_this_conn++ {
+				new_channel_id := rand.Intn(*channel_maximum) + *channel_minimum
+				new_channel := fmt.Sprintf("%s%d", *subscribe_prefix, new_channel_id)
+				channels = append(channels, new_channel)
+			}
 			totalCreatedClients++
 			subscriberName := fmt.Sprintf("subscriber#%d-%s%d", channel_subscriber_number, *subscribe_prefix, channel_id)
 			var client *redis.Client
@@ -268,7 +326,16 @@ func main() {
 				}
 			}
 			wg.Add(1)
-			go subscriberRoutine(*mode, channel, *printMessages, ctx, &wg, client)
+			connectionReconnectInterval := 0
+			if *max_reconnect_interval == *min_reconnect_interval {
+				connectionReconnectInterval = *max_reconnect_interval
+			} else {
+				connectionReconnectInterval = rand.Intn(*max_reconnect_interval-*min_reconnect_interval) + *max_reconnect_interval
+			}
+			if connectionReconnectInterval > 0 {
+				log.Println(fmt.Sprintf("Using reconnection interval of %d for subscriber: %s", connectionReconnectInterval, subscriberName))
+			}
+			go subscriberRoutine(*mode, channels, *printMessages, connectionReconnectInterval, ctx, &wg, client)
 		}
 	}
 }
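Taken on its own, the rewritten subscriberRoutine reduces to a single loop that multiplexes three events: context cancellation, the reconnect ticker, and message delivery. The standalone sketch below reproduces that pattern outside the benchmark; it is an illustration only, assuming go-redis v9, a Redis server on localhost:6379, and made-up channel names.

// Minimal sketch of the ticker-driven unsubscribe/resubscribe loop,
// assuming go-redis v9 and a Redis server reachable at localhost:6379.
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/redis/go-redis/v9"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	client := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
	defer client.Close()

	// Hypothetical channel names, standing in for the benchmark's generated ones.
	channels := []string{"channel-1", "channel-2"}

	// Resubscribe every 5 seconds; a ticker that is created and immediately
	// stopped (as the patch does when the interval is 0) would never fire.
	reconnectTicker := time.NewTicker(5 * time.Second)
	defer reconnectTicker.Stop()

	var pubsub *redis.PubSub
	subscribe := func() {
		if pubsub != nil {
			// Drop the old subscription before opening a new one.
			_ = pubsub.Unsubscribe(ctx, channels...)
			pubsub.Close()
		}
		pubsub = client.Subscribe(ctx, channels...)
	}
	subscribe()

	for {
		select {
		case <-ctx.Done():
			pubsub.Close()
			return
		case <-reconnectTicker.C:
			// Periodic unsubscribe/resubscribe, mirroring the commit.
			subscribe()
		default:
			// ReceiveMessage blocks, so the ticker and the cancelled context
			// are only observed between messages, as in subscriberRoutine.
			msg, err := pubsub.ReceiveMessage(ctx)
			if err != nil {
				return
			}
			fmt.Println(msg.Channel, msg.Payload)
		}
	}
}

Because ReceiveMessage blocks, the reconnect ticker is only checked between received messages; and with a reconnect interval of 0 the commit stops the ticker right after creating it, so the resubscribe case simply never fires and the loop behaves like the previous receive-only version.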
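On the flag side, each subscriber connection now draws its channel count from the min/max channels-per-subscriber flags and its reconnect interval from the min/max reconnect-interval flags (0 keeps the old behaviour with no periodic resubscribe). A hypothetical invocation combining the new options, assuming the benchmark binary is built as pubsub-sub-bench:

pubsub-sub-bench -channel-minimum 1 -channel-maximum 100 -subscribers-per-channel 1 -min-number-channels-per-subscriber 2 -max-number-channels-per-subscriber 5 -min-reconnect-interval 10 -max-reconnect-interval 30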

0 commit comments
