|
9 | 9 | "errors"
|
10 | 10 | "fmt"
|
11 | 11 | "io"
|
| 12 | + "math" |
12 | 13 | "math/big"
|
13 | 14 | "strings"
|
14 | 15 | "time"
|
@@ -91,6 +92,11 @@ type (
|
91 | 92 | AddSignature(key *ecdsa.PrivateKey, index uint8)
|
92 | 93 | GetEmitterChain() ChainID
|
93 | 94 | }
|
| 95 | + |
| 96 | + // number is a constraint for generic functions that can safely convert integer types to a ChainID (uint16). |
| 97 | + number interface { |
| 98 | + ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
| 99 | + } |
94 | 100 | )
|
95 | 101 |
|
96 | 102 | const (
|
@@ -268,6 +274,51 @@ func (c ChainID) String() string {
|
268 | 274 | }
|
269 | 275 | }
|
270 | 276 |
|
| 277 | +// ChainIDFromNumber converts an unsigned integer into a ChainID. This function only determines whether the input is valid |
| 278 | +// with respect to its type; it does not check whether the ChainID is actually registered or used anywhere. |
| 279 | +// This function can be used to validate ChainID values that are deserialized from protobuf messages. (As protobuf |
| 280 | +// does not support the uint16 type, ChainIDs are usually encoded as uint32.) |
| 281 | +// https://protobuf.dev/reference/protobuf/proto3-spec/#fields |
| 282 | +// Returns an error if the argument would overflow uint16. |
| 283 | +func ChainIDFromNumber[N number](n N) (ChainID, error) { |
| 284 | + if n < 0 { |
| 285 | + return ChainIDUnset, fmt.Errorf("chainID cannot be negative but got %d", n) |
| 286 | + } |
| 287 | + switch any(n).(type) { |
| 288 | + case int8, uint8, int16, uint16: |
| 289 | + // Because these values have been checked to be non-negative, we can return early with a simple conversion. |
| 290 | + return ChainID(n), nil |
| 291 | + |
| 292 | + } |
| 293 | + // Use intermediate uint64 to safely handle conversion and allow comparison with MaxUint16. |
| 294 | + // This is safe to do because the negative case is already handled. |
| 295 | + val := uint64(n) |
| 296 | + if val > uint64(math.MaxUint16) { |
| 297 | + return ChainIDUnset, fmt.Errorf("chainID must be less than or equal to %d but got %d", math.MaxUint16, n) |
| 298 | + } |
| 299 | + return ChainID(n), nil |
| 300 | + |
| 301 | +} |
| 302 | + |
| 303 | +// KnownChainIDFromNumber converts an unsigned integer into a known ChainID. It is a wrapper function for ChainIDFromNumber |
| 304 | +// that also checks whether the ChainID corresponds to a real, configured chain. |
| 305 | +func KnownChainIDFromNumber[N number](n N) (ChainID, error) { |
| 306 | + id, err := ChainIDFromNumber(n) |
| 307 | + if err != nil { |
| 308 | + return ChainIDUnset, err |
| 309 | + } |
| 310 | + |
| 311 | + // NOTE: slice.Contains is not used here because some SDK integrators (e.g. wormchain, maybe others) use old versions of Go. |
| 312 | + for _, known := range GetAllNetworkIDs() { |
| 313 | + if id == known { |
| 314 | + return id, nil |
| 315 | + } |
| 316 | + } |
| 317 | + |
| 318 | + return ChainIDUnset, fmt.Errorf("no known ChainID for input %d", n) |
| 319 | + |
| 320 | +} |
| 321 | + |
271 | 322 | func ChainIDFromString(s string) (ChainID, error) {
|
272 | 323 | s = strings.ToLower(s)
|
273 | 324 |
|
|
0 commit comments