
Commit 9dac1a1

Fix split OOB and zeroing (#238)

* Fix split OOB and zeroing: fix invalid slicing in #237; zero values in shards taken from the capacity of data shards.
* Add more tests.
* Fix swapped params for the Go fallback.
* Tweak default tests.
* Add conservative retraction.

1 parent 7e59db9 · commit 9dac1a1

6 files changed: +149, -73 lines
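
The commit centers on how Split reuses spare capacity of the input slice. A minimal usage sketch of the scenario being fixed, using the public reedsolomon API (the shard counts and buffer sizes here are arbitrary):

```go
package main

import (
	"fmt"

	"github.com/klauspost/reedsolomon"
)

func main() {
	// 4 data + 2 parity shards; the counts are arbitrary for this sketch.
	enc, err := reedsolomon.New(4, 2)
	if err != nil {
		panic(err)
	}

	// A payload whose backing array has spare capacity containing stale bytes.
	backing := make([]byte, 4096)
	for i := range backing {
		backing[i] = 0xAA // garbage beyond the real payload
	}
	payload := backing[:1000]

	// Split may reuse the spare capacity for the parity shards instead of
	// allocating; this commit clamps that reuse to what is actually needed
	// and zeroes it first.
	shards, err := enc.Split(payload)
	if err != nil {
		panic(err)
	}
	if err := enc.Encode(shards); err != nil {
		panic(err)
	}
	ok, err := enc.Verify(shards)
	fmt.Println(len(shards), ok, err) // 6 true <nil>
}
```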

galois_amd64.go: +1, -1

```diff
@@ -533,7 +533,7 @@ func mulAdd8(x, y []byte, log_m ffe8, o *options) {
 		y = y[done:]
 		x = x[done:]
 	}
-	refMulAdd8(y, x, log_m)
+	refMulAdd8(x, y, log_m)
 }
 
 // 2-way butterfly
```
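
For context, refMulAdd8 treats its first slice as the destination and its second as the source (visible in the leopard8.go hunk further down: `src := y[...]`, `dst := x[...]`), so the scalar fallback for the tail bytes was accumulating into the wrong buffer. A standalone sketch of that destination-first multiply-add shape, with a hypothetical gfMul standing in for the library's table lookups:

```go
package main

import "fmt"

// mulAddRef mirrors the destination-first shape of a reference multiply-add:
// dst[i] ^= gfMul(c, src[i]). Swapping the first two arguments at a call site
// silently accumulates into the source instead of the destination.
func mulAddRef(dst, src []byte, c byte) {
	for i, s := range src {
		dst[i] ^= gfMul(c, s)
	}
}

// gfMul is a hypothetical stand-in for the library's lookup-table multiply:
// a bitwise GF(2^8) multiplication reduced by the polynomial 0x11d.
func gfMul(a, b byte) byte {
	var p byte
	for i := 0; i < 8; i++ {
		if b&1 != 0 {
			p ^= a
		}
		carry := a & 0x80
		a <<= 1
		if carry != 0 {
			a ^= 0x1d
		}
		b >>= 1
	}
	return p
}

func main() {
	dst := []byte{1, 2, 3, 4}
	src := []byte{5, 6, 7, 8}
	mulAddRef(dst, src, 0x57)
	fmt.Println(dst, src) // dst is updated in place; src is untouched
}
```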

go.mod: +5, -2

```diff
@@ -6,5 +6,8 @@ require github.com/klauspost/cpuid/v2 v2.1.1
 
 require golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e // indirect
 
-// https://github.com/klauspost/reedsolomon/pull/229
-retract v1.11.2
+
+retract (
+	v1.11.2 // https://github.com/klauspost/reedsolomon/pull/229
+	[v1.11.3, v1.11.5] // https://github.com/klauspost/reedsolomon/pull/238
+)
```

leopard.go: +19, -7

```diff
@@ -279,23 +279,35 @@ func (r *leopardFF16) Split(data []byte) ([][]byte, error) {
 	// Calculate number of bytes per data shard.
 	perShard := (len(data) + r.dataShards - 1) / r.dataShards
 	perShard = ((perShard + 63) / 64) * 64
+	needTotal := r.totalShards * perShard
 
 	if cap(data) > len(data) {
-		data = data[:cap(data)]
+		if cap(data) > needTotal {
+			data = data[:needTotal]
+		} else {
+			data = data[:cap(data)]
+		}
+		clear := data[dataLen:]
+		for i := range clear {
+			clear[i] = 0
+		}
 	}
 
 	// Only allocate memory if necessary
 	var padding [][]byte
-	if len(data) < (r.totalShards * perShard) {
+	if len(data) < needTotal {
 		// calculate maximum number of full shards in `data` slice
 		fullShards := len(data) / perShard
 		padding = AllocAligned(r.totalShards-fullShards, perShard)
-		copyFrom := data[perShard*fullShards : dataLen]
-		for i := range padding {
-			if len(copyFrom) <= 0 {
-				break
+		if dataLen > perShard*fullShards {
+			// Copy partial shards
+			copyFrom := data[perShard*fullShards : dataLen]
+			for i := range padding {
+				if len(copyFrom) <= 0 {
+					break
+				}
+				copyFrom = copyFrom[copy(padding[i], copyFrom):]
 			}
-			copyFrom = copyFrom[copy(padding[i], copyFrom):]
 		}
 	} else {
 		zero := data[dataLen : r.totalShards*perShard]
```
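
The hunk above introduces the clamp-and-zero step that the commit applies in every Split implementation: extend the input slice into its spare capacity only up to needTotal bytes, and zero everything past the original length so stale bytes in the capacity cannot leak into the shards. A self-contained sketch of just that step (the function name and toy sizes are illustrative, not library code):

```go
package main

import "fmt"

// clampAndZero extends data into its spare capacity, but never past needTotal
// bytes, and zeroes everything beyond the caller's original length.
func clampAndZero(data []byte, needTotal int) []byte {
	dataLen := len(data)
	if cap(data) > len(data) {
		if cap(data) > needTotal {
			data = data[:needTotal]
		} else {
			data = data[:cap(data)]
		}
		tail := data[dataLen:] // the diff names this slice "clear"
		for i := range tail {
			tail[i] = 0
		}
	}
	return data
}

func main() {
	// Backing array larger than needed, with non-zero garbage in the capacity.
	backing := []byte{1, 2, 3, 9, 9, 9, 9, 9}
	data := backing[:3] // len 3, cap 8
	fmt.Println(clampAndZero(data, 6)) // [1 2 3 0 0 0]: clamped to 6, tail zeroed
}
```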

leopard8.go: +20, -13

```diff
@@ -320,28 +320,35 @@ func (r *leopardFF8) Split(data []byte) ([][]byte, error) {
 	// Calculate number of bytes per data shard.
 	perShard := (len(data) + r.dataShards - 1) / r.dataShards
 	perShard = ((perShard + 63) / 64) * 64
+	needTotal := r.totalShards * perShard
 
 	if cap(data) > len(data) {
-		data = data[:cap(data)]
+		if cap(data) > needTotal {
+			data = data[:needTotal]
+		} else {
+			data = data[:cap(data)]
+		}
+		clear := data[dataLen:]
+		for i := range clear {
+			clear[i] = 0
+		}
 	}
 
 	// Only allocate memory if necessary
 	var padding [][]byte
-	if len(data) < (r.totalShards * perShard) {
+	if len(data) < needTotal {
 		// calculate maximum number of full shards in `data` slice
 		fullShards := len(data) / perShard
 		padding = AllocAligned(r.totalShards-fullShards, perShard)
-		copyFrom := data[perShard*fullShards : dataLen]
-		for i := range padding {
-			if len(copyFrom) <= 0 {
-				break
+		if dataLen > perShard*fullShards {
+			// Copy partial shards
+			copyFrom := data[perShard*fullShards : dataLen]
+			for i := range padding {
+				if len(copyFrom) <= 0 {
+					break
+				}
+				copyFrom = copyFrom[copy(padding[i], copyFrom):]
 			}
-			copyFrom = copyFrom[copy(padding[i], copyFrom):]
-		}
-	} else {
-		zero := data[dataLen : r.totalShards*perShard]
-		for i := range zero {
-			zero[i] = 0
 		}
 	}
 
@@ -877,7 +884,7 @@ func refMulAdd8(x, y []byte, log_m ffe8) {
 	for len(x) >= 64 {
 		// Assert sizes for no bounds checks in loop
 		src := y[:64]
-		dst := x[:64] // Needed, but not checked...
+		dst := x[:len(src)] // Needed, but not checked...
 		for i, y1 := range src {
 			dst[i] ^= byte(lut.Value[y1])
 		}
```
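
The second hunk swaps the fixed-size destination re-slice x[:64] for x[:len(src)]. Re-slicing the destination to the source's length is a common Go idiom for bounds-check elimination: the compiler can then prove dst[i] is in range while ranging over src, which is what the "Assert sizes for no bounds checks in loop" comment is about. A minimal illustration of the idiom, unrelated to the library's actual field arithmetic:

```go
package main

import "fmt"

// xorInto XORs src into dst. Re-slicing dst to len(src) before the loop both
// asserts up front that dst is long enough and lets the compiler drop the
// per-iteration bounds check on dst[i], since i < len(src) == len(dst).
func xorInto(dst, src []byte) {
	dst = dst[:len(src)] // one check here instead of one per iteration
	for i, v := range src {
		dst[i] ^= v
	}
}

func main() {
	dst := []byte{0x00, 0xFF, 0x0F, 0xF0}
	src := []byte{0xAA, 0xAA, 0xAA, 0xAA}
	xorInto(dst, src)
	fmt.Printf("%x\n", dst) // aa55a55a
}
```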

reedsolomon.go: +30, -15

```diff
@@ -103,12 +103,16 @@ type Encoder interface {
 	Update(shards [][]byte, newDatashards [][]byte) error
 
 	// Split a data slice into the number of shards given to the encoder,
-	// and create empty parity shards.
+	// and create empty parity shards if necessary.
 	//
 	// The data will be split into equally sized shards.
-	// If the data size isn't dividable by the number of shards,
+	// If the data size isn't divisible by the number of shards,
 	// the last shard will contain extra zeros.
 	//
+	// If there is extra capacity on the provided data slice
+	// it will be used instead of allocating parity shards.
+	// It will be zeroed out.
+	//
 	// There must be at least 1 byte otherwise ErrShortData will be
 	// returned.
 	//
@@ -1542,6 +1546,10 @@ var ErrShortData = errors.New("not enough data to fill the number of requested s
 // If the data size isn't divisible by the number of shards,
 // the last shard will contain extra zeros.
 //
+// If there is extra capacity on the provided data slice
+// it will be used instead of allocating parity shards.
+// It will be zeroed out.
+//
 // There must be at least 1 byte otherwise ErrShortData will be
 // returned.
 //
@@ -1558,29 +1566,36 @@ func (r *reedSolomon) Split(data []byte) ([][]byte, error) {
 	dataLen := len(data)
 	// Calculate number of bytes per data shard.
 	perShard := (len(data) + r.dataShards - 1) / r.dataShards
+	needTotal := r.totalShards * perShard
 
 	if cap(data) > len(data) {
-		data = data[:cap(data)]
+		if cap(data) > needTotal {
+			data = data[:needTotal]
+		} else {
+			data = data[:cap(data)]
+		}
+		clear := data[dataLen:]
+		for i := range clear {
+			clear[i] = 0
+		}
 	}
 
 	// Only allocate memory if necessary
 	var padding [][]byte
-	if len(data) < (r.totalShards * perShard) {
+	if len(data) < needTotal {
 		// calculate maximum number of full shards in `data` slice
 		fullShards := len(data) / perShard
 		padding = AllocAligned(r.totalShards-fullShards, perShard)
-		copyFrom := data[perShard*fullShards : dataLen]
-		for i := range padding {
-			if len(copyFrom) <= 0 {
-				break
+
+		if dataLen > perShard*fullShards {
+			// Copy partial shards
+			copyFrom := data[perShard*fullShards : dataLen]
+			for i := range padding {
+				if len(copyFrom) <= 0 {
+					break
+				}
+				copyFrom = copyFrom[copy(padding[i], copyFrom):]
 			}
-			copyFrom = copyFrom[copy(padding[i], copyFrom):]
-		}
-		data = data[0 : perShard*fullShards]
-	} else {
-		zero := data[dataLen : r.totalShards*perShard]
-		for i := range zero {
-			zero[i] = 0
 		}
 	}
 
```
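
The new `if dataLen > perShard*fullShards` guard makes sure the partial-shard copy, and the data[perShard*fullShards : dataLen] sub-slice it relies on, only happen when the original data actually extends past the last full shard. A sketch of that guarded copy into freshly allocated padding shards (AllocAligned is replaced by plain make here, and all sizes are illustrative):

```go
package main

import "fmt"

// padPartial copies whatever remains after the last full shard into newly
// allocated padding shards. Sketch only: the library allocates via AllocAligned.
func padPartial(data []byte, dataLen, perShard, totalShards int) [][]byte {
	fullShards := len(data) / perShard
	padding := make([][]byte, totalShards-fullShards)
	for i := range padding {
		padding[i] = make([]byte, perShard)
	}
	if dataLen > perShard*fullShards {
		// Copy partial shards: only slice past the full shards when the
		// original data really has a tail, so low <= high always holds.
		copyFrom := data[perShard*fullShards : dataLen]
		for i := range padding {
			if len(copyFrom) <= 0 {
				break
			}
			copyFrom = copyFrom[copy(padding[i], copyFrom):]
		}
	}
	return padding
}

func main() {
	perShard, totalShards := 4, 5
	data := []byte{1, 2, 3, 4, 5, 6} // one full shard of 4 bytes plus a 2-byte tail
	fmt.Println(padPartial(data, len(data), perShard, totalShards))
	// [[5 6 0 0] [0 0 0 0] [0 0 0 0] [0 0 0 0]]
}
```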

reedsolomon_test.go: +74, -35

```diff
@@ -168,7 +168,7 @@ func TestBuildMatrixPAR1Singular(t *testing.T) {
 func testOpts() [][]Option {
 	if testing.Short() {
 		return [][]Option{
-			{WithPAR1Matrix()}, {WithCauchyMatrix()},
+			{WithCauchyMatrix()}, {WithLeopardGF16(true)}, {WithLeopardGF(true)},
 		}
 	}
 	opts := [][]Option{
@@ -1603,7 +1603,7 @@ func testEncoderReconstruct(t *testing.T, o ...Option) {
 	fillRandom(data)
 
 	// Create 5 data slices of 50000 elements each
-	enc, err := New(5, 3, testOptions(o...)...)
+	enc, err := New(7, 6, testOptions(o...)...)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -1675,43 +1675,82 @@ func testEncoderReconstruct(t *testing.T, o ...Option) {
 }
 
 func TestSplitJoin(t *testing.T) {
-	var data = make([]byte, 250000)
-	fillRandom(data)
-
-	enc, _ := New(5, 3, testOptions()...)
-	shards, err := enc.Split(data)
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	_, err = enc.Split([]byte{})
-	if err != ErrShortData {
-		t.Errorf("expected %v, got %v", ErrShortData, err)
-	}
+	opts := [][]Option{
+		testOptions(),
+		append(testOptions(), WithLeopardGF(true)),
+		append(testOptions(), WithLeopardGF16(true)),
+	}
+	for i, opts := range opts {
+		t.Run("opt-"+strconv.Itoa(i), func(t *testing.T) {
+			for _, dp := range [][2]int{{1, 0}, {5, 0}, {5, 1}, {12, 4}, {2, 15}, {17, 1}} {
+				enc, _ := New(dp[0], dp[1], opts...)
+				ext := enc.(Extensions)
+
+				_, err := enc.Split([]byte{})
+				if err != ErrShortData {
+					t.Errorf("expected %v, got %v", ErrShortData, err)
+				}
 
-	buf := new(bytes.Buffer)
-	err = enc.Join(buf, shards, 50)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if !bytes.Equal(buf.Bytes(), data[:50]) {
-		t.Fatal("recovered data does match original")
-	}
+				buf := new(bytes.Buffer)
+				err = enc.Join(buf, [][]byte{}, 0)
+				if err != ErrTooFewShards {
+					t.Errorf("expected %v, got %v", ErrTooFewShards, err)
+				}
+				for _, size := range []int{ext.DataShards(), 1337, 2699} {
+					for _, extra := range []int{0, 1, ext.ShardSizeMultiple(), ext.ShardSizeMultiple() * ext.DataShards(), ext.ShardSizeMultiple()*ext.ParityShards() + 1, 255} {
+						buf.Reset()
+						t.Run(fmt.Sprintf("d-%d-p-%d-sz-%d-cap%d", ext.DataShards(), ext.ParityShards(), size, extra), func(t *testing.T) {
+							var data = make([]byte, size, size+extra)
+							var ref = make([]byte, size, size)
+							fillRandom(data)
+							copy(ref, data)
+
+							shards, err := enc.Split(data)
+							if err != nil {
+								t.Fatal(err)
+							}
+							err = enc.Encode(shards)
+							if err != nil {
+								t.Fatal(err)
+							}
+							_, err = enc.Verify(shards)
+							if err != nil {
+								t.Fatal(err)
+							}
+							for i := range shards[:ext.ParityShards()] {
+								// delete data shards up to parity
+								shards[i] = nil
+							}
+							err = enc.Reconstruct(shards)
+							if err != nil {
+								t.Fatal(err)
+							}
 
-	err = enc.Join(buf, [][]byte{}, 0)
-	if err != ErrTooFewShards {
-		t.Errorf("expected %v, got %v", ErrTooFewShards, err)
-	}
+							// Rejoin....
+							err = enc.Join(buf, shards, size)
+							if err != nil {
+								t.Fatal(err)
+							}
+							if !bytes.Equal(buf.Bytes(), ref) {
+								t.Log("")
+								t.Fatal("recovered data does match original")
+							}
 
-	err = enc.Join(buf, shards, len(data)+1)
-	if err != ErrShortData {
-		t.Errorf("expected %v, got %v", ErrShortData, err)
-	}
+							err = enc.Join(buf, shards, len(data)+ext.DataShards()*ext.ShardSizeMultiple())
+							if err != ErrShortData {
+								t.Errorf("expected %v, got %v", ErrShortData, err)
+							}
 
-	shards[0] = nil
-	err = enc.Join(buf, shards, len(data))
-	if err != ErrReconstructRequired {
-		t.Errorf("expected %v, got %v", ErrReconstructRequired, err)
+							shards[0] = nil
+							err = enc.Join(buf, shards, len(data))
+							if err != ErrReconstructRequired {
+								t.Errorf("expected %v, got %v", ErrReconstructRequired, err)
+							}
+						})
+					}
+				}
+			}
+		})
 	}
 }
 
```
