Skip to content

Commit 868f201

Browse files
authored
sdk: Add functions for type-safe conversions from integers to ChainID (#4271)
* sdk: Add functions for type-safe conversions from integers to ChainID
1 parent 7d140ef commit 868f201

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

sdk/vaa/structs.go

+51
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"errors"
1010
"fmt"
1111
"io"
12+
"math"
1213
"math/big"
1314
"strings"
1415
"time"
@@ -91,6 +92,11 @@ type (
9192
AddSignature(key *ecdsa.PrivateKey, index uint8)
9293
GetEmitterChain() ChainID
9394
}
95+
96+
// number is a constraint for generic functions that can safely convert integer types to a ChainID (uint16).
97+
number interface {
98+
~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
99+
}
94100
)
95101

96102
const (
@@ -268,6 +274,51 @@ func (c ChainID) String() string {
268274
}
269275
}
270276

277+
// ChainIDFromNumber converts an unsigned integer into a ChainID. This function only determines whether the input is valid
278+
// with respect to its type; it does not check whether the ChainID is actually registered or used anywhere.
279+
// This function can be used to validate ChainID values that are deserialized from protobuf messages. (As protobuf
280+
// does not support the uint16 type, ChainIDs are usually encoded as uint32.)
281+
// https://protobuf.dev/reference/protobuf/proto3-spec/#fields
282+
// Returns an error if the argument would overflow uint16.
283+
func ChainIDFromNumber[N number](n N) (ChainID, error) {
284+
if n < 0 {
285+
return ChainIDUnset, fmt.Errorf("chainID cannot be negative but got %d", n)
286+
}
287+
switch any(n).(type) {
288+
case int8, uint8, int16, uint16:
289+
// Because these values have been checked to be non-negative, we can return early with a simple conversion.
290+
return ChainID(n), nil
291+
292+
}
293+
// Use intermediate uint64 to safely handle conversion and allow comparison with MaxUint16.
294+
// This is safe to do because the negative case is already handled.
295+
val := uint64(n)
296+
if val > uint64(math.MaxUint16) {
297+
return ChainIDUnset, fmt.Errorf("chainID must be less than or equal to %d but got %d", math.MaxUint16, n)
298+
}
299+
return ChainID(n), nil
300+
301+
}
302+
303+
// KnownChainIDFromNumber converts an unsigned integer into a known ChainID. It is a wrapper function for ChainIDFromNumber
304+
// that also checks whether the ChainID corresponds to a real, configured chain.
305+
func KnownChainIDFromNumber[N number](n N) (ChainID, error) {
306+
id, err := ChainIDFromNumber(n)
307+
if err != nil {
308+
return ChainIDUnset, err
309+
}
310+
311+
// NOTE: slice.Contains is not used here because some SDK integrators (e.g. wormchain, maybe others) use old versions of Go.
312+
for _, known := range GetAllNetworkIDs() {
313+
if id == known {
314+
return id, nil
315+
}
316+
}
317+
318+
return ChainIDUnset, fmt.Errorf("no known ChainID for input %d", n)
319+
320+
}
321+
271322
func ChainIDFromString(s string) (ChainID, error) {
272323
s = strings.ToLower(s)
273324

sdk/vaa/structs_test.go

+58
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"time"
1717

1818
"github.com/ethereum/go-ethereum/common"
19+
"github.com/ethereum/go-ethereum/common/math"
1920
"github.com/ethereum/go-ethereum/crypto"
2021
"github.com/stretchr/testify/assert"
2122
"github.com/stretchr/testify/require"
@@ -1125,3 +1126,60 @@ func TestUnmarshalBody(t *testing.T) {
11251126
})
11261127
}
11271128
}
1129+
1130+
func TestChainIDFromNumber(t *testing.T) {
1131+
// Define test case struct that works with any Number type
1132+
type testCase[N number] struct {
1133+
name string
1134+
input N
1135+
expected ChainID
1136+
wantErr bool
1137+
errMsg string
1138+
wantKnown bool
1139+
}
1140+
// Using the int64 type here because it can be representative of the error conditions (overflow, negative)
1141+
// NOTE: more test cases could be added with different concrete types.
1142+
tests := []testCase[int64]{
1143+
{
1144+
name: "valid",
1145+
input: int64(1),
1146+
expected: ChainIDSolana,
1147+
wantErr: false,
1148+
wantKnown: true,
1149+
},
1150+
{
1151+
name: "valid but unknown",
1152+
input: int64(math.MaxUint16),
1153+
expected: ChainID(math.MaxUint16),
1154+
wantErr: false,
1155+
wantKnown: false,
1156+
},
1157+
{
1158+
name: "overflow",
1159+
input: math.MaxUint16 + 1,
1160+
expected: ChainIDUnset,
1161+
wantErr: true,
1162+
wantKnown: false,
1163+
},
1164+
}
1165+
1166+
for _, testCase := range tests {
1167+
t.Run(testCase.name, func(t *testing.T) {
1168+
got, err := ChainIDFromNumber(testCase.input)
1169+
require.Equal(t, testCase.expected, got)
1170+
if testCase.wantErr {
1171+
require.ErrorContains(t, err, testCase.errMsg)
1172+
require.Equal(t, ChainIDUnset, got)
1173+
}
1174+
1175+
got, err = KnownChainIDFromNumber(testCase.input)
1176+
if testCase.wantKnown {
1177+
require.NoError(t, err)
1178+
require.Equal(t, testCase.expected, got)
1179+
} else {
1180+
require.Error(t, err)
1181+
require.Equal(t, ChainIDUnset, got)
1182+
}
1183+
})
1184+
}
1185+
}

0 commit comments

Comments
 (0)