Skip to content

Commit c04adbf

Browse files
committed
Add receive chunk tracker for better received chunk handling
1 parent e4788a9 commit c04adbf

4 files changed

+290
-22
lines changed

association.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ type Association struct {
175175
myMaxNumInboundStreams uint16
176176
myMaxNumOutboundStreams uint16
177177
myCookie *paramStateCookie
178-
payloadQueue *payloadQueue
178+
payloadQueue *receivedChunkTracker
179179
inflightQueue *payloadQueue
180180
pendingQueue *pendingQueue
181181
controlQueue *controlQueue
@@ -318,7 +318,7 @@ func createAssociation(config Config) *Association {
318318
myMaxNumOutboundStreams: math.MaxUint16,
319319
myMaxNumInboundStreams: math.MaxUint16,
320320

321-
payloadQueue: newPayloadQueue(),
321+
payloadQueue: newReceivedPacketTracker(),
322322
inflightQueue: newPayloadQueue(),
323323
pendingQueue: newPendingQueue(),
324324
controlQueue: newControlQueue(),
@@ -1378,7 +1378,7 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet {
13781378
a.name, d.tsn, d.immediateSack, len(d.userData))
13791379
a.stats.incDATAs()
13801380

1381-
canPush := a.payloadQueue.canPush(d, a.peerLastTSN)
1381+
canPush := a.payloadQueue.canPush(d.tsn, a.peerLastTSN)
13821382
if canPush {
13831383
s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown)
13841384
if s == nil {
@@ -1390,14 +1390,14 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet {
13901390

13911391
if a.getMyReceiverWindowCredit() > 0 {
13921392
// Pass the new chunk to stream level as soon as it arrives
1393-
a.payloadQueue.push(d, a.peerLastTSN)
1393+
a.payloadQueue.push(d.tsn, a.peerLastTSN)
13941394
s.handleData(d)
13951395
} else {
13961396
// Receive buffer is full
13971397
lastTSN, ok := a.payloadQueue.getLastTSNReceived()
13981398
if ok && sna32LT(d.tsn, lastTSN) {
13991399
a.log.Debugf("[%s] receive buffer full, but accepted as this is a missing chunk with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber)
1400-
a.payloadQueue.push(d, a.peerLastTSN)
1400+
a.payloadQueue.push(d.tsn, a.peerLastTSN)
14011401
s.handleData(d)
14021402
} else {
14031403
a.log.Debugf("[%s] receive buffer full. dropping DATA with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber)
@@ -1421,7 +1421,7 @@ func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool)
14211421
// Meaning, if peerLastTSN+1 points to a chunk that is received,
14221422
// advance peerLastTSN until peerLastTSN+1 points to unreceived chunk.
14231423
for {
1424-
if _, popOk := a.payloadQueue.pop(a.peerLastTSN + 1); !popOk {
1424+
if popOk := a.payloadQueue.pop(a.peerLastTSN + 1); !popOk {
14251425
break
14261426
}
14271427
a.peerLastTSN++

association_test.go

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,14 +1310,7 @@ func TestHandleForwardTSN(t *testing.T) {
13101310
prevTSN := a.peerLastTSN
13111311

13121312
// this chunk is blocked by the missing chunk at tsn=1
1313-
a.payloadQueue.push(&chunkPayloadData{
1314-
beginningFragment: true,
1315-
endingFragment: true,
1316-
tsn: a.peerLastTSN + 2,
1317-
streamIdentifier: 0,
1318-
streamSequenceNumber: 1,
1319-
userData: []byte("ABC"),
1320-
}, a.peerLastTSN)
1313+
a.payloadQueue.push(a.peerLastTSN+2, a.peerLastTSN)
13211314

13221315
fwdtsn := &chunkForwardTSN{
13231316
newCumulativeTSN: a.peerLastTSN + 1,
@@ -1347,14 +1340,7 @@ func TestHandleForwardTSN(t *testing.T) {
13471340
prevTSN := a.peerLastTSN
13481341

13491342
// this chunk is blocked by the missing chunk at tsn=1
1350-
a.payloadQueue.push(&chunkPayloadData{
1351-
beginningFragment: true,
1352-
endingFragment: true,
1353-
tsn: a.peerLastTSN + 3,
1354-
streamIdentifier: 0,
1355-
streamSequenceNumber: 1,
1356-
userData: []byte("ABC"),
1357-
}, a.peerLastTSN)
1343+
a.payloadQueue.push(a.peerLastTSN+3, a.peerLastTSN)
13581344

13591345
fwdtsn := &chunkForwardTSN{
13601346
newCumulativeTSN: a.peerLastTSN + 1,

received_packet_tracker.go

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
2+
// SPDX-License-Identifier: MIT
3+
4+
package sctp
5+
6+
import (
7+
"fmt"
8+
"strings"
9+
)
10+
11+
// receivedChunkTracker tracks received chunks for maintaining ACK ranges
12+
type receivedChunkTracker struct {
13+
chunks map[uint32]struct{}
14+
dupTSN []uint32
15+
ranges []ackRange
16+
}
17+
18+
// ackRange is a contiguous range of chunks that we have received
19+
type ackRange struct {
20+
start uint32
21+
end uint32
22+
}
23+
24+
func newReceivedPacketTracker() *receivedChunkTracker {
25+
return &receivedChunkTracker{chunks: make(map[uint32]struct{})}
26+
}
27+
28+
func (q *receivedChunkTracker) canPush(tsn uint32, cumulativeTSN uint32) bool {
29+
_, ok := q.chunks[tsn]
30+
if ok || sna32LTE(tsn, cumulativeTSN) {
31+
return false
32+
}
33+
return true
34+
}
35+
36+
// push pushes a payload data. If the payload data is already in our queue or
37+
// older than our cumulativeTSN marker, it will be recorded as duplications,
38+
// which can later be retrieved using popDuplicates.
39+
func (q *receivedChunkTracker) push(tsn uint32, cumulativeTSN uint32) bool {
40+
_, ok := q.chunks[tsn]
41+
if ok || sna32LTE(tsn, cumulativeTSN) {
42+
// Found the packet, log in dups
43+
q.dupTSN = append(q.dupTSN, tsn)
44+
return false
45+
}
46+
q.chunks[tsn] = struct{}{}
47+
48+
insert := true
49+
var pos int
50+
for pos = len(q.ranges) - 1; pos >= 0; pos-- {
51+
if tsn == q.ranges[pos].end+1 {
52+
q.ranges[pos].end++
53+
insert = false
54+
break
55+
}
56+
if tsn == q.ranges[pos].start-1 {
57+
q.ranges[pos].start--
58+
insert = false
59+
break
60+
}
61+
if tsn > q.ranges[pos].end {
62+
break
63+
}
64+
}
65+
if insert {
66+
// pos is at the element just before the insertion point
67+
pos++
68+
q.ranges = append(q.ranges, ackRange{})
69+
copy(q.ranges[pos+1:], q.ranges[pos:])
70+
q.ranges[pos] = ackRange{start: tsn, end: tsn}
71+
} else {
72+
// extended element at pos, check if we can merge it with adjacent elements
73+
if pos-1 >= 0 {
74+
if q.ranges[pos-1].end+1 == q.ranges[pos].start {
75+
q.ranges[pos-1] = ackRange{
76+
start: q.ranges[pos-1].start,
77+
end: q.ranges[pos].end,
78+
}
79+
copy(q.ranges[pos:], q.ranges[pos+1:])
80+
q.ranges = q.ranges[:len(q.ranges)-1]
81+
// We have merged pos and pos-1 in to pos-1, update pos to reflect that.
82+
// Not updating this won't be an error but it's nice to maintain the invariant
83+
pos--
84+
}
85+
}
86+
if pos+1 < len(q.ranges) {
87+
if q.ranges[pos+1].start-1 == q.ranges[pos].end {
88+
q.ranges[pos+1] = ackRange{
89+
start: q.ranges[pos].start,
90+
end: q.ranges[pos+1].end,
91+
}
92+
copy(q.ranges[pos:], q.ranges[pos+1:])
93+
q.ranges = q.ranges[:len(q.ranges)-1]
94+
}
95+
}
96+
}
97+
return true
98+
}
99+
100+
// pop pops only if the oldest chunk's TSN matches the given TSN.
101+
func (q *receivedChunkTracker) pop(tsn uint32) bool {
102+
if len(q.ranges) == 0 || q.ranges[0].start != tsn {
103+
return false
104+
}
105+
q.ranges[0].start++
106+
if q.ranges[0].start > q.ranges[0].end {
107+
q.ranges = q.ranges[1:]
108+
}
109+
delete(q.chunks, tsn)
110+
return true
111+
}
112+
113+
// popDuplicates returns an array of TSN values that were found duplicate.
114+
func (q *receivedChunkTracker) popDuplicates() []uint32 {
115+
dups := q.dupTSN
116+
q.dupTSN = []uint32{}
117+
return dups
118+
}
119+
120+
// receivedPacketTracker getGapACKBlocks returns gapAckBlocks after the cummulative TSN
121+
func (q *receivedChunkTracker) getGapAckBlocks(cumulativeTSN uint32) []gapAckBlock {
122+
gapAckBlocks := make([]gapAckBlock, 0, len(q.ranges))
123+
for _, ar := range q.ranges {
124+
if ar.end > cumulativeTSN {
125+
st := ar.start
126+
if st < cumulativeTSN {
127+
st = cumulativeTSN + 1
128+
}
129+
gapAckBlocks = append(gapAckBlocks, gapAckBlock{
130+
start: uint16(st - cumulativeTSN),
131+
end: uint16(ar.end - cumulativeTSN),
132+
})
133+
}
134+
}
135+
return gapAckBlocks
136+
}
137+
138+
func (q *receivedChunkTracker) getGapAckBlocksString(cumulativeTSN uint32) string {
139+
gapAckBlocks := q.getGapAckBlocks(cumulativeTSN)
140+
sb := strings.Builder{}
141+
sb.WriteString(fmt.Sprintf("cumTSN=%d", cumulativeTSN))
142+
for _, b := range gapAckBlocks {
143+
sb.WriteString(fmt.Sprintf(",%d-%d", b.start, b.end))
144+
}
145+
return sb.String()
146+
}
147+
148+
func (q *receivedChunkTracker) getLastTSNReceived() (uint32, bool) {
149+
if len(q.ranges) == 0 {
150+
return 0, false
151+
}
152+
return q.ranges[len(q.ranges)-1].end, true
153+
}
154+
155+
func (q *receivedChunkTracker) size() int {
156+
return len(q.chunks)
157+
}

received_packet_tracker_test.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
2+
// SPDX-License-Identifier: MIT
3+
4+
package sctp
5+
6+
import (
7+
"fmt"
8+
"math/rand"
9+
"testing"
10+
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestReceivedPacketTrackerPushPop(t *testing.T) {
15+
q := newReceivedPacketTracker()
16+
for i := uint32(1); i < 100; i++ {
17+
q.push(i, 0)
18+
}
19+
// leave a gap at position 100
20+
for i := uint32(101); i < 200; i++ {
21+
q.push(i, 0)
22+
}
23+
for i := uint32(2); i < 200; i++ {
24+
require.False(t, q.pop(i)) // all pop will fail till we pop the first tsn
25+
}
26+
for i := uint32(1); i < 100; i++ {
27+
require.True(t, q.pop(i))
28+
}
29+
// 101 is the smallest value now
30+
for i := uint32(102); i < 200; i++ {
31+
require.False(t, q.pop(i))
32+
}
33+
q.push(100, 99)
34+
for i := uint32(100); i < 200; i++ {
35+
require.True(t, q.pop(i))
36+
}
37+
38+
// q is empty now
39+
require.Equal(t, q.size(), 0)
40+
for i := uint32(0); i < 200; i++ {
41+
require.False(t, q.pop(i))
42+
}
43+
}
44+
45+
func TestReceivedPacketTrackerGapACKBlocksStress(t *testing.T) {
46+
testChunks := func(chunks []uint32, st uint32) {
47+
if len(chunks) == 0 {
48+
return
49+
}
50+
expected := make([]gapAckBlock, 0, len(chunks))
51+
cr := ackRange{start: chunks[0], end: chunks[0]}
52+
for i := 1; i < len(chunks); i++ {
53+
if cr.end+1 != chunks[i] {
54+
expected = append(expected, gapAckBlock{
55+
start: uint16(cr.start - st),
56+
end: uint16(cr.end - st),
57+
})
58+
cr = ackRange{start: chunks[i], end: chunks[i]}
59+
} else {
60+
cr.end++
61+
}
62+
}
63+
expected = append(expected, gapAckBlock{
64+
start: uint16(cr.start - st),
65+
end: uint16(cr.end - st),
66+
})
67+
68+
q := newReceivedPacketTracker()
69+
rand.Shuffle(len(chunks), func(i, j int) {
70+
chunks[i], chunks[j] = chunks[j], chunks[i]
71+
})
72+
for _, t := range chunks {
73+
q.push(t, 0)
74+
}
75+
res := q.getGapAckBlocks(0)
76+
require.Equal(t, expected, res, chunks)
77+
}
78+
chunks := make([]uint32, 0, 10)
79+
for i := 1; i < (1 << 10); i++ {
80+
for j := 0; j < 10; j++ {
81+
if i&(1<<j) != 0 {
82+
chunks = append(chunks, uint32(j+1))
83+
}
84+
}
85+
testChunks(chunks, 0)
86+
chunks = chunks[:0]
87+
}
88+
}
89+
90+
func TestReceivedPacketTrackerGapACKBlocksStress2(t *testing.T) {
91+
92+
tests := []struct {
93+
chunks []uint32
94+
cummulativeTSN uint32
95+
result []gapAckBlock
96+
}{
97+
{
98+
chunks: []uint32{3, 4, 1, 2, 7, 8, 10000},
99+
cummulativeTSN: 3,
100+
result: []gapAckBlock{{1, 1}, {4, 5}, {10000 - 3, 10000 - 3}},
101+
},
102+
{
103+
chunks: []uint32{3, 5, 1, 2, 7, 8, 10000},
104+
cummulativeTSN: 3,
105+
result: []gapAckBlock{{2, 2}, {4, 5}, {10000 - 3, 10000 - 3}},
106+
},
107+
{
108+
chunks: []uint32{3, 4, 1, 2, 7, 8, 10000},
109+
cummulativeTSN: 0,
110+
result: []gapAckBlock{{1, 4}, {7, 8}, {10000, 10000}},
111+
},
112+
}
113+
114+
for i, tc := range tests {
115+
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
116+
q := newReceivedPacketTracker()
117+
for _, t := range tc.chunks {
118+
q.push(t, 0)
119+
}
120+
res := q.getGapAckBlocks(tc.cummulativeTSN)
121+
require.Equal(t, tc.result, res)
122+
})
123+
}
124+
125+
}

0 commit comments

Comments
 (0)