Skip to content

Commit a2417f8

Browse files
committed
evm: optimize Manager
1 parent ce88d05 commit a2417f8

File tree

5 files changed

+78
-78
lines changed

5 files changed

+78
-78
lines changed

evm/foundry.toml

+1-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
solc_version = "0.8.19"
33
optimizer = true
44
optimizer_runs = 200
5-
via_ir = false
5+
via_ir = true
66
evm_version = "london"
77
src = "src"
88
out = "out"
@@ -13,8 +13,4 @@ line_length = 100
1313
multiline_func_header = "params_first"
1414
# wrap_comments = true
1515

16-
17-
[profile.production]
18-
via_ir = true
19-
2016
# See more config options https://github.com/foundry-rs/foundry/blob/master/crates/config/README.md#all-options

evm/src/EndpointRegistry.sol

+24-24
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ abstract contract EndpointRegistry {
2626
/// invariant: numRegisteredEndpoints <= MAX_ENDPOINTS
2727
/// invariant: forall (i: uint8),
2828
/// i < numRegisteredEndpoints <=> exists (a: address), endpointInfos[a].index == i
29-
struct _NumRegisteredEndpoints {
30-
uint8 num;
29+
struct _NumEndpoints {
30+
uint8 registered;
31+
uint8 enabled;
3132
}
3233

3334
uint8 constant MAX_ENDPOINTS = 64;
@@ -98,11 +99,7 @@ abstract contract EndpointRegistry {
9899
}
99100
}
100101

101-
function _getNumRegisteredEndpointsStorage()
102-
internal
103-
pure
104-
returns (_NumRegisteredEndpoints storage $)
105-
{
102+
function _getNumEndpointsStorage() internal pure returns (_NumEndpoints storage $) {
106103
uint256 slot = uint256(NUM_REGISTERED_ENDPOINTS_SLOT);
107104
assembly ("memory-safe") {
108105
$.slot := slot
@@ -116,27 +113,27 @@ abstract contract EndpointRegistry {
116113
_EnabledEndpointBitmap storage _enabledEndpointBitmap = _getEndpointBitmapStorage();
117114
address[] storage _enabledEndpoints = _getEnabledEndpointsStorage();
118115

119-
_NumRegisteredEndpoints storage _numRegisteredEndpoints =
120-
_getNumRegisteredEndpointsStorage();
116+
_NumEndpoints storage _numEndpoints = _getNumEndpointsStorage();
121117

122118
if (endpoint == address(0)) {
123119
revert InvalidEndpointZeroAddress();
124120
}
125121

126-
if (_numRegisteredEndpoints.num >= MAX_ENDPOINTS) {
122+
if (_numEndpoints.registered >= MAX_ENDPOINTS) {
127123
revert TooManyEndpoints();
128124
}
129125

130126
if (endpointInfos[endpoint].registered) {
131127
endpointInfos[endpoint].enabled = true;
132128
} else {
133129
endpointInfos[endpoint] =
134-
EndpointInfo({registered: true, enabled: true, index: _numRegisteredEndpoints.num});
135-
_numRegisteredEndpoints.num++;
130+
EndpointInfo({registered: true, enabled: true, index: _numEndpoints.registered});
131+
_numEndpoints.registered++;
136132
_getRegisteredEndpointsStorage().push(endpoint);
137133
}
138134

139135
_enabledEndpoints.push(endpoint);
136+
_numEndpoints.enabled++;
140137

141138
uint64 updatedEnabledEndpointBitmap =
142139
_enabledEndpointBitmap.bitmap | uint64(1 << endpointInfos[endpoint].index);
@@ -171,6 +168,7 @@ abstract contract EndpointRegistry {
171168
}
172169

173170
endpointInfos[endpoint].enabled = false;
171+
_getNumEndpointsStorage().enabled--;
174172

175173
uint64 updatedEnabledEndpointBitmap =
176174
_enabledEndpointBitmap.bitmap & uint64(~(1 << endpointInfos[endpoint].index));
@@ -180,9 +178,10 @@ abstract contract EndpointRegistry {
180178

181179
bool removed = false;
182180

183-
for (uint256 i = 0; i < _enabledEndpoints.length; i++) {
181+
uint256 numEnabledEndpoints = _enabledEndpoints.length;
182+
for (uint256 i = 0; i < numEnabledEndpoints; i++) {
184183
if (_enabledEndpoints[i] == endpoint) {
185-
_enabledEndpoints[i] = _enabledEndpoints[_enabledEndpoints.length - 1];
184+
_enabledEndpoints[i] = _enabledEndpoints[numEnabledEndpoints - 1];
186185
_enabledEndpoints.pop();
187186
removed = true;
188187
break;
@@ -213,31 +212,32 @@ abstract contract EndpointRegistry {
213212
/// Checking these invariants is somewhat costly, but we only need to do it
214213
/// when modifying the endpoints, which happens infrequently.
215214
function _checkEndpointsInvariants() internal view {
216-
_NumRegisteredEndpoints storage _numRegisteredEndpoints =
217-
_getNumRegisteredEndpointsStorage();
215+
_NumEndpoints storage _numEndpoints = _getNumEndpointsStorage();
218216
address[] storage _enabledEndpoints = _getEnabledEndpointsStorage();
219217

220-
for (uint256 i = 0; i < _enabledEndpoints.length; i++) {
218+
uint256 numEndpointsEnabled = _numEndpoints.enabled;
219+
assert(numEndpointsEnabled == _enabledEndpoints.length);
220+
221+
for (uint256 i = 0; i < numEndpointsEnabled; i++) {
221222
_checkEndpointInvariants(_enabledEndpoints[i]);
222223
}
223224

224225
// invariant: each endpoint is only enabled once
225-
for (uint256 i = 0; i < _enabledEndpoints.length; i++) {
226-
for (uint256 j = i + 1; j < _enabledEndpoints.length; j++) {
226+
for (uint256 i = 0; i < numEndpointsEnabled; i++) {
227+
for (uint256 j = i + 1; j < numEndpointsEnabled; j++) {
227228
assert(_enabledEndpoints[i] != _enabledEndpoints[j]);
228229
}
229230
}
230231

231232
// invariant: numRegisteredEndpoints <= MAX_ENDPOINTS
232-
assert(_numRegisteredEndpoints.num <= MAX_ENDPOINTS);
233+
assert(_numEndpoints.registered <= MAX_ENDPOINTS);
233234
}
234235

235236
// @dev Check that the endpoint is in a valid state.
236237
function _checkEndpointInvariants(address endpoint) private view {
237238
mapping(address => EndpointInfo) storage endpointInfos = _getEndpointInfosStorage();
238239
_EnabledEndpointBitmap storage _enabledEndpointBitmap = _getEndpointBitmapStorage();
239-
_NumRegisteredEndpoints storage _numRegisteredEndpoints =
240-
_getNumRegisteredEndpointsStorage();
240+
_NumEndpoints storage _numEndpoints = _getNumEndpointsStorage();
241241
address[] storage _enabledEndpoints = _getEnabledEndpointsStorage();
242242

243243
EndpointInfo memory endpointInfo = endpointInfos[endpoint];
@@ -251,7 +251,7 @@ abstract contract EndpointRegistry {
251251

252252
bool endpointInEnabledEndpoints = false;
253253

254-
for (uint256 i = 0; i < _enabledEndpoints.length; i++) {
254+
for (uint256 i = 0; i < _numEndpoints.enabled; i++) {
255255
if (_enabledEndpoints[i] == endpoint) {
256256
endpointInEnabledEndpoints = true;
257257
break;
@@ -265,6 +265,6 @@ abstract contract EndpointRegistry {
265265
// invariant: endpointInfos[endpoint].enabled <=> endpoint in _enabledEndpoints
266266
assert(endpointInEnabledEndpoints == endpointEnabled);
267267

268-
assert(endpointInfo.index < _numRegisteredEndpoints.num);
268+
assert(endpointInfo.index < _numEndpoints.registered);
269269
}
270270
}

evm/src/Manager.sol

+46-38
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,7 @@ contract Manager is
130130
/// @notice Returns the number of Endpoints that must attest to a msgId for
131131
/// it to be considered valid and acted upon.
132132
function getThreshold() public view returns (uint8) {
133-
_Threshold storage _threshold = _getThresholdStorage();
134-
return _threshold.num;
133+
return _getThresholdStorage().num;
135134
}
136135

137136
function setEndpoint(address endpoint) external onlyOwner {
@@ -155,19 +154,17 @@ contract Manager is
155154
_threshold.num = 1;
156155
}
157156

158-
address[] storage _enabledEndpoints = _getEnabledEndpointsStorage();
159-
160-
emit EndpointAdded(endpoint, _enabledEndpoints.length, _threshold.num);
157+
emit EndpointAdded(endpoint, _getNumEndpointsStorage().enabled, _threshold.num);
161158
}
162159

163160
function removeEndpoint(address endpoint) external onlyOwner {
164161
_removeEndpoint(endpoint);
165162

166163
_Threshold storage _threshold = _getThresholdStorage();
167-
address[] storage _enabledEndpoints = _getEnabledEndpointsStorage();
164+
uint8 numEnabledEndpoints = _getNumEndpointsStorage().enabled;
168165

169-
if (_enabledEndpoints.length < _threshold.num) {
170-
_threshold.num = uint8(_enabledEndpoints.length);
166+
if (numEnabledEndpoints < _threshold.num) {
167+
_threshold.num = numEnabledEndpoints;
171168
}
172169

173170
emit EndpointRemoved(endpoint, _threshold.num);
@@ -234,21 +231,24 @@ contract Manager is
234231
/// This method should return an array of delivery prices corresponding to each endpoint.
235232
function quoteDeliveryPrice(
236233
uint16 recipientChain,
237-
EndpointStructs.EndpointInstruction[] memory endpointInstructions
238-
) public view returns (uint256[] memory) {
239-
address[] storage _enabledEndpoints = _getEnabledEndpointsStorage();
234+
EndpointStructs.EndpointInstruction[] memory endpointInstructions,
235+
address[] memory enabledEndpoints
236+
) public view returns (uint256[] memory, uint256) {
237+
uint256 numEnabledEndpoints = enabledEndpoints.length;
240238
mapping(address => EndpointInfo) storage endpointInfos = _getEndpointInfosStorage();
241239

242-
uint256[] memory priceQuotes = new uint256[](_enabledEndpoints.length);
243-
for (uint256 i = 0; i < _enabledEndpoints.length; i++) {
244-
address endpointAddr = _enabledEndpoints[i];
240+
uint256[] memory priceQuotes = new uint256[](numEnabledEndpoints);
241+
uint256 totalPriceQuote = 0;
242+
for (uint256 i = 0; i < numEnabledEndpoints; i++) {
243+
address endpointAddr = enabledEndpoints[i];
245244
uint8 registeredEndpointIndex = endpointInfos[endpointAddr].index;
246-
uint256 endpointPriceQuote = IEndpoint(_enabledEndpoints[i]).quoteDeliveryPrice(
245+
uint256 endpointPriceQuote = IEndpoint(endpointAddr).quoteDeliveryPrice(
247246
recipientChain, endpointInstructions[registeredEndpointIndex]
248247
);
249248
priceQuotes[i] = endpointPriceQuote;
249+
totalPriceQuote += endpointPriceQuote;
250250
}
251-
return priceQuotes;
251+
return (priceQuotes, totalPriceQuote);
252252
}
253253

254254
/// @dev This will either cross-call or internal call, depending on
@@ -257,13 +257,14 @@ contract Manager is
257257
uint16 recipientChain,
258258
uint256[] memory priceQuotes,
259259
EndpointStructs.EndpointInstruction[] memory endpointInstructions,
260+
address[] memory enabledEndpoints,
260261
bytes memory managerMessage
261262
) internal {
262-
address[] storage _enabledEndpoints = _getEnabledEndpointsStorage();
263+
uint256 numEnabledEndpoints = enabledEndpoints.length;
263264
mapping(address => EndpointInfo) storage endpointInfos = _getEndpointInfosStorage();
264265
// call into endpoint contracts to send the message
265-
for (uint256 i = 0; i < _enabledEndpoints.length; i++) {
266-
address endpointAddr = _enabledEndpoints[i];
266+
for (uint256 i = 0; i < numEnabledEndpoints; i++) {
267+
address endpointAddr = enabledEndpoints[i];
267268
uint8 registeredEndpointIndex = endpointInfos[endpointAddr].index;
268269
// send it to the recipient manager based on the chain
269270
IEndpoint(endpointAddr).sendMessage{value: priceQuotes[i]}(
@@ -415,9 +416,14 @@ contract Manager is
415416
revert ZeroAmount();
416417
}
417418

419+
if (recipient == bytes32(0)) {
420+
revert InvalidRecipient();
421+
}
422+
418423
// parse the instructions up front to ensure they:
419424
// - are encoded correctly
420425
// - follow payload length restrictions
426+
421427
EndpointStructs.parseEndpointInstructions(endpointInstructions);
422428

423429
{
@@ -528,10 +534,13 @@ contract Manager is
528534
EndpointStructs.EndpointInstruction[] memory sortedInstructions = EndpointStructs
529535
.sortEndpointInstructions(EndpointStructs.parseEndpointInstructions(endpointInstructions));
530536

531-
uint256[] memory priceQuotes = quoteDeliveryPrice(recipientChain, sortedInstructions);
537+
// cache enabled endpoints to avoid multiple storage reads
538+
address[] memory enabledEndpoints = _getEnabledEndpointsStorage();
539+
540+
(uint256[] memory priceQuotes, uint256 totalPriceQuote) =
541+
quoteDeliveryPrice(recipientChain, sortedInstructions, enabledEndpoints);
532542
{
533543
// check up front that msg.value will cover the delivery price
534-
uint256 totalPriceQuote = arraySum(priceQuotes);
535544
if (msg.value < totalPriceQuote) {
536545
revert DeliveryPaymentTooLow(totalPriceQuote, msg.value);
537546
}
@@ -543,22 +552,22 @@ contract Manager is
543552
}
544553
}
545554

546-
bytes memory encodedTransferPayload = EndpointStructs.encodeNativeTokenTransfer(
547-
EndpointStructs.NativeTokenTransfer(
548-
amount, toWormholeFormat(token), recipient, recipientChain
549-
)
550-
);
551-
552555
// construct the ManagerMessage payload
553556
bytes memory encodedManagerPayload = EndpointStructs.encodeManagerMessage(
554557
EndpointStructs.ManagerMessage(
555-
sequence, toWormholeFormat(sender), encodedTransferPayload
558+
sequence,
559+
toWormholeFormat(sender),
560+
EndpointStructs.encodeNativeTokenTransfer(
561+
EndpointStructs.NativeTokenTransfer(
562+
amount, toWormholeFormat(token), recipient, recipientChain
563+
)
564+
)
556565
)
557566
);
558567

559568
// send the message
560569
_sendMessageToEndpoints(
561-
recipientChain, priceQuotes, sortedInstructions, encodedManagerPayload
570+
recipientChain, priceQuotes, sortedInstructions, enabledEndpoints, encodedManagerPayload
562571
);
563572

564573
emit TransferSent(recipient, nttDenormalize(amount), recipientChain, sequence);
@@ -792,25 +801,24 @@ contract Manager is
792801
}
793802

794803
function _checkRegisteredEndpointsInvariants() internal view {
795-
if (_getRegisteredEndpointsStorage().length != _getNumRegisteredEndpointsStorage().num) {
804+
if (_getRegisteredEndpointsStorage().length != _getNumEndpointsStorage().registered) {
796805
revert RetrievedIncorrectRegisteredEndpoints(
797-
_getRegisteredEndpointsStorage().length, _getNumRegisteredEndpointsStorage().num
806+
_getRegisteredEndpointsStorage().length, _getNumEndpointsStorage().registered
798807
);
799808
}
800809
}
801810

802811
function _checkThresholdInvariants() internal view {
803-
_Threshold storage _threshold = _getThresholdStorage();
804-
address[] storage _enabledEndpoints = _getEnabledEndpointsStorage();
805-
address[] storage _registeredEndpoints = _getRegisteredEndpointsStorage();
812+
uint8 threshold = _getThresholdStorage().num;
813+
_NumEndpoints memory numEndpoints = _getNumEndpointsStorage();
806814

807815
// invariant: threshold <= enabledEndpoints.length
808-
if (_threshold.num > _enabledEndpoints.length) {
809-
revert ThresholdTooHigh(_threshold.num, _enabledEndpoints.length);
816+
if (threshold > numEndpoints.enabled) {
817+
revert ThresholdTooHigh(threshold, numEndpoints.enabled);
810818
}
811819

812-
if (_registeredEndpoints.length > 0) {
813-
if (_threshold.num == 0) {
820+
if (numEndpoints.registered > 0) {
821+
if (threshold == 0) {
814822
revert ZeroThreshold();
815823
}
816824
}

evm/src/interfaces/IManager.sol

+7-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ interface IManager {
1919
error MessageNotApproved(bytes32 msgHash);
2020
error InvalidTargetChain(uint16 targetChain, uint16 thisChain);
2121
error ZeroAmount();
22+
error InvalidRecipient();
2223
error BurnAmountDifferentThanBalanceDiff(uint256 burnAmount, uint256 balanceDiff);
2324

2425
/// @notice The mode is invalid. It is neither in LOCKING or BURNING mode.
@@ -101,11 +102,14 @@ interface IManager {
101102
// @param recipientChain The chain to transfer to.
102103
// @param endpointInstructions An additional instruction the endpoint can forward to
103104
// the recipient chain.
104-
// @return The delivery prices associated with each endpoint.
105+
// @param enabledEndpoints The endpoints that are enabled for the transfer.
106+
// @return The delivery prices associated with each endpoint, and the sum
107+
// of these prices.
105108
function quoteDeliveryPrice(
106109
uint16 recipientChain,
107-
EndpointStructs.EndpointInstruction[] memory endpointInstructions
108-
) external view returns (uint256[] memory);
110+
EndpointStructs.EndpointInstruction[] memory endpointInstructions,
111+
address[] memory enabledEndpoints
112+
) external view returns (uint256[] memory, uint256);
109113

110114
function nextMessageSequence() external view returns (uint64);
111115

evm/src/libraries/EndpointHelpers.sol

-8
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,3 @@ function isFork(uint256 evmChainId) view returns (bool) {
3535
function min(uint256 a, uint256 b) pure returns (uint256) {
3636
return a < b ? a : b;
3737
}
38-
39-
function arraySum(uint256[] memory arr) pure returns (uint256) {
40-
uint256 sum = 0;
41-
for (uint256 i = 0; i < arr.length; i++) {
42-
sum += arr[i];
43-
}
44-
return sum;
45-
}

0 commit comments

Comments
 (0)