forked from wormhole-foundation/native-token-transfers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathManagerStandalone.sol
193 lines (157 loc) · 6.91 KB
/
ManagerStandalone.sol
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
// SPDX-License-Identifier: Apache 2
pragma solidity >=0.8.8 <0.9.0;
import "./interfaces/IManagerStandalone.sol";
import "./interfaces/IEndpointStandalone.sol";
import "./Manager.sol";
import "./EndpointRegistry.sol";
import "./libraries/Implementation.sol";
/// @title ManagerStandalone
/// @notice Manager that dispatches messages through a set of standalone
///         Endpoint contracts and only executes an inbound message once a
///         configurable threshold of enabled endpoints has attested to it.
contract ManagerStandalone is IManagerStandalone, Manager, Implementation {
    /// @param token Token address forwarded to the Manager constructor.
    /// @param mode Mode forwarded to the Manager constructor.
    /// @param chainId Chain id forwarded to the Manager constructor.
    /// @param rateLimitDuration Rate-limit duration forwarded to the Manager constructor.
    constructor(
        address token,
        Mode mode,
        uint16 chainId,
        uint64 rateLimitDuration
    ) Manager(token, mode, chainId, rateLimitDuration) {
        _checkThresholdInvariants();
    }

    /// @dev Initialization hook (overrides the Implementation base). Runs the
    ///      Manager initializer, then verifies the threshold and endpoint
    ///      invariants against the resulting storage.
    function _initialize() internal override {
        __Manager_init();
        _checkThresholdInvariants();
        _checkEndpointsInvariants();
    }

    /// @dev Migration hook (overrides the Implementation base). No storage is
    ///      migrated at present; it only re-verifies the threshold and endpoint
    ///      invariants.
    function _migrate() internal override {
        // TODO: add storage migration logic here when a future upgrade needs it.
        _checkThresholdInvariants();
        _checkEndpointsInvariants();
    }

    /// @dev Asserts that every immutable value read through the external
    ///      getters matches the value held by this implementation.
    ///      When we add new immutables, this function should be updated.
    function _checkImmutables() internal view override {
        assert(this.token() == token);
        assert(this.mode() == mode);
        assert(this.chainId() == chainId);
        assert(this.evmChainId() == evmChainId);
        assert(this.rateLimitDuration() == rateLimitDuration);
    }

    /// @notice Upgrades this contract to `newImplementation`. Owner only.
    /// @param newImplementation Address of the new implementation contract.
    function upgrade(address newImplementation) external onlyOwner {
        _upgrade(newImplementation);
    }

    /// @notice Asks the endpoint at `endpoint` to upgrade itself to `newImplementation`. Owner only.
    /// @dev NOTE(review): `endpoint` is not checked against the registered
    ///      endpoint set; any address implementing `upgrade` is accepted.
    /// @param endpoint Endpoint contract to upgrade.
    /// @param newImplementation Address of the endpoint's new implementation.
    function upgradeEndpoint(address endpoint, address newImplementation) external onlyOwner {
        IEndpointStandalone(endpoint).upgrade(newImplementation);
    }

    // Storage layout for the attestation threshold.
    struct _Threshold {
        uint8 num;
    }

    // =============== STORAGE ===============================================

    /// @dev Storage slot holding the `_Threshold` struct (keccak256("ntt.threshold") - 1).
    bytes32 public constant THRESHOLD_SLOT = bytes32(uint256(keccak256("ntt.threshold")) - 1);

    /// @dev Returns a storage pointer to the `_Threshold` struct at `THRESHOLD_SLOT`.
    function _getThresholdStorage() private pure returns (_Threshold storage $) {
        uint256 slot = uint256(THRESHOLD_SLOT);
        assembly ("memory-safe") {
            $.slot := slot
        }
    }

    // =============== GETTERS/SETTERS ========================================

    /// @notice Sets the number of endpoint attestations required to approve a message. Owner only.
    /// @dev Reverts if the new value violates the threshold invariants.
    /// @param threshold New threshold.
    function setThreshold(uint8 threshold) external onlyOwner {
        _Threshold storage _threshold = _getThresholdStorage();
        uint8 oldThreshold = _threshold.num;

        _threshold.num = threshold;
        _checkThresholdInvariants();

        emit ThresholdChanged(oldThreshold, threshold);
    }

    /// @notice Returns the number of Endpoints that must attest to a msgId for
    /// it to be considered valid and acted upon.
    function getThreshold() public view returns (uint8) {
        _Threshold storage _threshold = _getThresholdStorage();
        return _threshold.num;
    }

    /// @notice Registers `endpoint` and increments the threshold by one. Owner only.
    /// @param endpoint Endpoint contract to register.
    function setEndpoint(address endpoint) external onlyOwner {
        _setEndpoint(endpoint);

        _Threshold storage _threshold = _getThresholdStorage();
        // We increase the threshold here. This might not be what the user
        // wants, in which case they can call setThreshold() afterwards.
        // However, this is the most sensible default behaviour, since
        // this makes the system more secure in the event that the user forgets
        // to call setThreshold().
        _threshold.num += 1;

        // Defensive: if _setEndpoint did not actually grow the enabled-endpoint
        // list, the increment above could push the threshold past the number of
        // enabled endpoints and make every message unapprovable. Revert instead.
        _checkThresholdInvariants();

        emit EndpointAdded(endpoint, _threshold.num);
    }

    /// @notice Removes `endpoint` and lowers the threshold if it now exceeds
    ///         the number of enabled endpoints. Owner only.
    /// @param endpoint Endpoint contract to remove.
    function removeEndpoint(address endpoint) external onlyOwner {
        _removeEndpoint(endpoint);

        _Threshold storage _threshold = _getThresholdStorage();
        address[] storage _enabledEndpoints = _getEnabledEndpointsStorage();

        if (_enabledEndpoints.length < _threshold.num) {
            // Safe narrowing: length < threshold <= type(uint8).max here.
            _threshold.num = uint8(_enabledEndpoints.length);
        }

        _checkThresholdInvariants();

        emit EndpointRemoved(endpoint, _threshold.num);
    }

    /// @notice Returns the sum of every enabled endpoint's delivery quote for `recipientChain`.
    /// @param recipientChain Destination chain id.
    /// @return Total fee required to send one message through all enabled endpoints.
    function quoteDeliveryPrice(uint16 recipientChain) public view override returns (uint256) {
        address[] storage _enabledEndpoints = _getEnabledEndpointsStorage();
        uint256 numEndpoints = _enabledEndpoints.length;

        uint256 totalPriceQuote = 0;
        for (uint256 i = 0; i < numEndpoints; i++) {
            totalPriceQuote +=
                IEndpointStandalone(_enabledEndpoints[i]).quoteDeliveryPrice(recipientChain);
        }
        return totalPriceQuote;
    }

    /// @dev Sends `payload` through every enabled endpoint, forwarding each
    ///      endpoint's own delivery quote as call value. The value is paid from
    ///      this contract's balance (presumably funded by the caller's
    ///      msg.value; confirm in Manager).
    /// @param recipientChain Destination chain id.
    /// @param payload Encoded message to deliver.
    function _sendMessageToEndpoint(
        uint16 recipientChain,
        bytes memory payload
    ) internal override {
        address[] storage _enabledEndpoints = _getEnabledEndpointsStorage();

        // call into endpoint contracts to send the message
        for (uint256 i = 0; i < _enabledEndpoints.length; i++) {
            IEndpointStandalone endpointContract = IEndpointStandalone(_enabledEndpoints[i]);
            uint256 endpointPriceQuote = endpointContract.quoteDeliveryPrice(recipientChain);
            endpointContract.sendMessage{value: endpointPriceQuote}(recipientChain, payload);
        }
    }

    /// @dev Called by an Endpoint contract to deliver a verified attestation.
    /// This function enforces attestation threshold and replay logic for messages.
    /// Once all validations are complete, this function calls _executeMsg to execute the command specified by the message.
    /// @param payload Manager message the calling endpoint attests to.
    function attestationReceived(EndpointStructs.ManagerMessage memory payload)
        external
        onlyEndpoint
    {
        bytes32 managerMessageHash = EndpointStructs.managerMessageDigest(payload);

        // set the attested flag for this endpoint.
        // TODO: this allows an endpoint to attest to a message multiple times.
        // This is fine, because attestation is idempotent (bitwise or 1), but
        // maybe we want to revert anyway?
        _setEndpointAttestedToMessage(managerMessageHash, msg.sender);

        if (isMessageApproved(managerMessageHash)) {
            _executeMsg(payload);
        }
    }

    /// @notice Counts the attestation bits recorded for message `digest`.
    /// @dev Counts every set bit returned by `_getMessageAttestations`.
    ///      NOTE(review): whether bits of since-disabled endpoints are masked
    ///      out depends on `_getMessageAttestations`; verify in EndpointRegistry.
    /// @param digest Manager message digest.
    /// @return count Number of set attestation bits.
    function messageAttestations(bytes32 digest) public view returns (uint8 count) {
        return countSetBits(_getMessageAttestations(digest));
    }

    /// @notice Returns true when the threshold is non-zero and at least
    ///         `threshold` attestations are recorded for `digest`.
    /// @param digest Manager message digest.
    /// @return Whether the message may be executed.
    function isMessageApproved(bytes32 digest) public view override returns (bool) {
        uint8 threshold = getThreshold();
        // Check the cheap zero-threshold case first so the bit count is skipped when unconfigured.
        return threshold > 0 && messageAttestations(digest) >= threshold;
    }

    /// @notice Counts the set bits (population count) of a uint64.
    /// @dev Kernighan's method: each `x &= x - 1` clears the lowest set bit.
    /// @param x Value to inspect.
    /// @return count Number of set bits in `x` (at most 64).
    function countSetBits(uint64 x) public pure returns (uint8 count) {
        while (x != 0) {
            x &= x - 1;
            count++;
        }
        return count;
    }

    // ============== INVARIANTS =============================================

    /// @dev Reverts with `ThresholdTooHigh` if the threshold exceeds the number
    ///      of enabled endpoints, and with `ZeroThreshold` if endpoints are
    ///      enabled but the threshold is zero.
    function _checkThresholdInvariants() internal view {
        _Threshold storage _threshold = _getThresholdStorage();
        address[] storage _enabledEndpoints = _getEnabledEndpointsStorage();

        // invariant: threshold <= enabledEndpoints.length
        if (_threshold.num > _enabledEndpoints.length) {
            revert ThresholdTooHigh(_threshold.num, _enabledEndpoints.length);
        }

        // invariant: threshold > 0 whenever at least one endpoint is enabled
        if (_enabledEndpoints.length > 0 && _threshold.num == 0) {
            revert ZeroThreshold();
        }
    }
}