From ef0a2be8ef1ffe1697b59b629bdff025da7b7587 Mon Sep 17 00:00:00 2001 From: Dirk Brink Date: Thu, 29 Feb 2024 11:41:03 -0800 Subject: [PATCH 1/3] evm: Always fetch token decimals rather than caching --- evm/src/NttManager/NttManager.sol | 22 ++++++++++++---------- evm/src/NttManager/NttManagerState.sol | 19 +++++++------------ 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/evm/src/NttManager/NttManager.sol b/evm/src/NttManager/NttManager.sol index ee02a0b6e..c47d0a340 100644 --- a/evm/src/NttManager/NttManager.sol +++ b/evm/src/NttManager/NttManager.sol @@ -155,8 +155,9 @@ contract NttManager is INttManager, NttManagerState { if (nativeTokenTransfer.toChain != chainId) { revert InvalidTargetChain(nativeTokenTransfer.toChain, chainId); } - TrimmedAmount nativeTransferAmount = - (nativeTokenTransfer.amount.untrim(tokenDecimals_)).trim(tokenDecimals_, tokenDecimals_); + uint8 tokenDecimals = tokenDecimals(); + TrimmedAmount memory nativeTransferAmount = + (nativeTokenTransfer.amount.untrim(tokenDecimals)).trim(tokenDecimals, tokenDecimals); address transferRecipient = fromWormholeFormat(nativeTokenTransfer.to); @@ -313,8 +314,8 @@ contract NttManager is INttManager, NttManagerState { } // trim amount after burning to ensure transfer amount matches (amount - fee) - TrimmedAmount trimmedAmount = _trimTransferAmount(amount, recipientChain); - TrimmedAmount internalAmount = trimmedAmount.shift(tokenDecimals_); + TrimmedAmount memory trimmedAmount = _trimTransferAmount(amount, recipientChain); + TrimmedAmount memory internalAmount = trimmedAmount.shift(tokenDecimals()); // get the sequence for this transfer uint64 sequence = _useMessageSequence(); @@ -415,7 +416,7 @@ contract NttManager is INttManager, NttManagerState { uint16 destinationChain = recipientChain; emit TransferSent( - recipient, amt.untrim(tokenDecimals_), totalPriceQuote, destinationChain, seq + recipient, amt.untrim(tokenDecimals()), totalPriceQuote, destinationChain, seq ); // return the sequence number @@ -429,7 +430,7 @@ contract NttManager is INttManager, NttManagerState { ) internal { // calculate proper amount of tokens to unlock/mint to recipient // untrim the amount - uint256 untrimmedAmount = amount.untrim(tokenDecimals_); + uint256 untrimmedAmount = amount.untrim(tokenDecimals()); emit TransferRedeemed(digest); @@ -444,9 +445,9 @@ contract NttManager is INttManager, NttManagerState { } } - /// @inheritdoc INttManager function tokenDecimals() public view override(INttManager, RateLimiter) returns (uint8) { - return tokenDecimals_; + (, bytes memory queriedDecimals) = token.staticcall(abi.encodeWithSignature("decimals()")); + return abi.decode(queriedDecimals, (uint8)); } // ==================== Internal Helpers =============================================== @@ -473,9 +474,10 @@ contract NttManager is INttManager, NttManagerState { TrimmedAmount trimmedAmount; { - trimmedAmount = amount.trim(tokenDecimals_, toDecimals); + uint8 fromDecimals = tokenDecimals(); + trimmedAmount = amount.trim(fromDecimals, toDecimals); // don't deposit dust that can not be bridged due to the decimal shift - uint256 newAmount = trimmedAmount.untrim(tokenDecimals_); + uint256 newAmount = trimmedAmount.untrim(fromDecimals); if (amount != newAmount) { revert TransferAmountHasDust(amount, amount - newAmount); } diff --git a/evm/src/NttManager/NttManagerState.sol b/evm/src/NttManager/NttManagerState.sol index 58c70cc16..57191fa9e 100644 --- a/evm/src/NttManager/NttManagerState.sol +++ b/evm/src/NttManager/NttManagerState.sol @@ -43,7 +43,6 @@ abstract contract NttManagerState is INttManager.Mode public immutable mode; uint16 public immutable chainId; uint256 immutable evmChainId; - uint8 public immutable tokenDecimals_; // =============== Setup ================================================================= @@ -55,7 +54,6 @@ abstract contract NttManagerState is bool _skipRateLimiting ) RateLimiter(_rateLimitDuration, _skipRateLimiting) { token = _token; - tokenDecimals_ = _initializeTokenDecimals(); mode = _mode; chainId = _chainId; evmChainId = block.chainid; @@ -70,7 +68,7 @@ abstract contract NttManagerState is } __PausedOwnable_init(msg.sender, msg.sender); __ReentrancyGuard_init(); - _setOutboundLimit(TrimmedAmountLib.max(tokenDecimals_)); + _setOutboundLimit(TrimmedAmountLib.max(tokenDecimals())); } function _initialize() internal virtual override { @@ -278,7 +276,8 @@ abstract contract NttManagerState is _getPeersStorage()[peerChainId].peerAddress = peerContract; _getPeersStorage()[peerChainId].tokenDecimals = decimals; - _setInboundLimit(inboundLimit.trim(tokenDecimals_, tokenDecimals_), peerChainId); + uint8 tokenDecimals = tokenDecimals(); + _setInboundLimit(inboundLimit.trim(tokenDecimals, tokenDecimals), peerChainId); emit PeerUpdated( peerChainId, oldPeer.peerAddress, oldPeer.tokenDecimals, peerContract, decimals @@ -287,12 +286,14 @@ abstract contract NttManagerState is /// @inheritdoc INttManagerState function setOutboundLimit(uint256 limit) external onlyOwner { - _setOutboundLimit(limit.trim(tokenDecimals_, tokenDecimals_)); + uint8 tokenDecimals = tokenDecimals(); + _setOutboundLimit(limit.trim(tokenDecimals, tokenDecimals)); } /// @inheritdoc INttManagerState function setInboundLimit(uint256 limit, uint16 chainId_) external onlyOwner { - _setInboundLimit(limit.trim(tokenDecimals_, tokenDecimals_), chainId_); + uint8 tokenDecimals = tokenDecimals(); + _setInboundLimit(limit.trim(tokenDecimals, tokenDecimals), chainId_); } // =============== Internal ============================================================== @@ -349,17 +350,11 @@ abstract contract NttManagerState is _getMessageSequenceStorage().num++; } - function _initializeTokenDecimals() internal view returns (uint8) { - (, bytes memory queriedDecimals) = token.staticcall(abi.encodeWithSignature("decimals()")); - return abi.decode(queriedDecimals, (uint8)); - } - /// ============== Invariants ============================================= /// @dev When we add new immutables, this function should be updated function _checkImmutables() internal view override { assert(this.token() == token); - assert(this.tokenDecimals_() == tokenDecimals_); assert(this.mode() == mode); assert(this.chainId() == chainId); assert(this.rateLimitDuration() == rateLimitDuration); From 3bf952de2dba842b33abfdc88e2ac6c9b68a6a45 Mon Sep 17 00:00:00 2001 From: Dirk Brink Date: Mon, 4 Mar 2024 15:48:47 -0800 Subject: [PATCH 2/3] Add tests --- evm/src/mocks/DummyToken.sol | 19 +++++- evm/test/NttManager.t.sol | 115 +++++++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+), 1 deletion(-) diff --git a/evm/src/mocks/DummyToken.sol b/evm/src/mocks/DummyToken.sol index 02d5c4cac..403d44c51 100644 --- a/evm/src/mocks/DummyToken.sol +++ b/evm/src/mocks/DummyToken.sol @@ -3,8 +3,9 @@ pragma solidity >=0.8.8 <0.9.0; import "openzeppelin-contracts/contracts/token/ERC20/ERC20.sol"; +import "openzeppelin-contracts/contracts/proxy/ERC1967/ERC1967Upgrade.sol"; -contract DummyToken is ERC20 { +contract DummyToken is ERC20, ERC1967Upgrade { constructor() ERC20("DummyToken", "DTKN") {} // NOTE: this is purposefully not called mint() to so we can test that in @@ -24,6 +25,10 @@ contract DummyToken is ERC20 { function burn(address, uint256) public virtual { revert("Locking nttManager should not call 'burn()'"); } + + function upgrade(address newImplementation) public { + _upgradeTo(newImplementation); + } } contract DummyTokenMintAndBurn is DummyToken { @@ -37,3 +42,15 @@ contract DummyTokenMintAndBurn is DummyToken { _burn(msg.sender, amount); } } + +contract DummyTokenDifferentDecimals is DummyTokenMintAndBurn { + uint8 private immutable _decimals; + + constructor(uint8 newDecimals) { + _decimals = newDecimals; + } + + function decimals() public view override returns (uint8) { + return _decimals; + } +} diff --git a/evm/test/NttManager.t.sol b/evm/test/NttManager.t.sol index 0f74f6107..0602b42aa 100644 --- a/evm/test/NttManager.t.sol +++ b/evm/test/NttManager.t.sol @@ -693,4 +693,119 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { assertEq(token.balanceOf(address(user_B)), transferAmount.untrim(token.decimals()) * 2); } + + function test_tokenUpgradedAndDecimalsChanged() public { + DummyToken dummy1 = new DummyTokenMintAndBurn(); + + // Make the token an upgradeable token + DummyTokenMintAndBurn t = + DummyTokenMintAndBurn(address(new ERC1967Proxy(address(dummy1), ""))); + + NttManager implementation = + new MockNttManagerContract(address(t), INttManager.Mode.LOCKING, chainId, 1 days, false); + + MockNttManagerContract newNttManager = + MockNttManagerContract(address(new ERC1967Proxy(address(implementation), ""))); + newNttManager.initialize(); + // register nttManager peer + bytes32 peer = toWormholeFormat(address(nttManager)); + newNttManager.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9); + + address user_A = address(0x123); + address user_B = address(0x456); + t.mintDummy(address(user_A), 5 * 10 ** t.decimals()); + + // Check that we can initiate a transfer + vm.startPrank(user_A); + t.approve(address(newNttManager), 3 * 10 ** t.decimals()); + newNttManager.transfer( + 1 * 10 ** t.decimals(), + TransceiverHelpersLib.SENDING_CHAIN_ID, + toWormholeFormat(user_B), + false, + new bytes(1) + ); + vm.stopPrank(); + + // Check that we can receive a transfer + (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(newNttManager); + newNttManager.setThreshold(1); + + bytes memory transceiverMessage; + bytes memory tokenTransferMessage; + + TrimmedAmount memory transferAmount = TrimmedAmount(100, 8); + + tokenTransferMessage = TransceiverStructs.encodeNativeTokenTransfer( + TransceiverStructs.NativeTokenTransfer({ + amount: transferAmount, + sourceToken: toWormholeFormat(address(t)), + to: toWormholeFormat(user_B), + toChain: chainId + }) + ); + + (, transceiverMessage) = TransceiverHelpersLib.buildTransceiverMessageWithNttManagerPayload( + 0, bytes32(0), peer, toWormholeFormat(address(newNttManager)), tokenTransferMessage + ); + + e1.receiveMessage(transceiverMessage); + uint256 userBBalanceBefore = t.balanceOf(address(user_B)); + assertEq(userBBalanceBefore, transferAmount.untrim(t.decimals())); + + // If the token decimals change to the same trimmed amount, we should safely receive the correct number of tokens + DummyTokenDifferentDecimals dummy2 = new DummyTokenDifferentDecimals(10); // 10 gets trimmed to 8 + t.upgrade(address(dummy2)); + + vm.startPrank(user_A); + newNttManager.transfer( + 1 * 10 ** 10, + TransceiverHelpersLib.SENDING_CHAIN_ID, + toWormholeFormat(user_B), + false, + new bytes(1) + ); + vm.stopPrank(); + + (, transceiverMessage) = TransceiverHelpersLib.buildTransceiverMessageWithNttManagerPayload( + bytes32("1"), + bytes32(0), + peer, + toWormholeFormat(address(newNttManager)), + tokenTransferMessage + ); + e1.receiveMessage(transceiverMessage); + assertEq( + t.balanceOf(address(user_B)), userBBalanceBefore + transferAmount.untrim(t.decimals()) + ); + + // Now if the token decimals change to a different trimmed amount, we shouldn't be able to send or receive + DummyTokenDifferentDecimals dummy3 = new DummyTokenDifferentDecimals(7); // 7 is 7 trimmed + t.upgrade(address(dummy3)); + + vm.startPrank(user_A); + vm.expectRevert( + abi.encodeWithSelector(TrimmedAmountLib.NumberOfDecimalsNotEqual.selector, 8, 7) + ); + newNttManager.transfer( + 1 * 10 ** 7, + TransceiverHelpersLib.SENDING_CHAIN_ID, + toWormholeFormat(user_B), + false, + new bytes(1) + ); + vm.stopPrank(); + + (, transceiverMessage) = TransceiverHelpersLib.buildTransceiverMessageWithNttManagerPayload( + bytes32("2"), + bytes32(0), + peer, + toWormholeFormat(address(newNttManager)), + tokenTransferMessage + ); + vm.expectRevert( + abi.encodeWithSelector(TrimmedAmountLib.NumberOfDecimalsNotEqual.selector, 8, 7) + ); + e1.receiveMessage(transceiverMessage); + } } From 201ecc4d64528653d9ac0717188f6da9999cd083 Mon Sep 17 00:00:00 2001 From: Dirk Brink Date: Wed, 6 Mar 2024 13:35:20 -0800 Subject: [PATCH 3/3] Rebasing fixes --- evm/src/NttManager/NttManager.sol | 10 +++++----- evm/test/NttManager.t.sol | 12 ++++-------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/evm/src/NttManager/NttManager.sol b/evm/src/NttManager/NttManager.sol index c47d0a340..325577e3f 100644 --- a/evm/src/NttManager/NttManager.sol +++ b/evm/src/NttManager/NttManager.sol @@ -155,9 +155,9 @@ contract NttManager is INttManager, NttManagerState { if (nativeTokenTransfer.toChain != chainId) { revert InvalidTargetChain(nativeTokenTransfer.toChain, chainId); } - uint8 tokenDecimals = tokenDecimals(); - TrimmedAmount memory nativeTransferAmount = - (nativeTokenTransfer.amount.untrim(tokenDecimals)).trim(tokenDecimals, tokenDecimals); + uint8 toDecimals = tokenDecimals(); + TrimmedAmount nativeTransferAmount = + (nativeTokenTransfer.amount.untrim(toDecimals)).trim(toDecimals, toDecimals); address transferRecipient = fromWormholeFormat(nativeTokenTransfer.to); @@ -314,8 +314,8 @@ contract NttManager is INttManager, NttManagerState { } // trim amount after burning to ensure transfer amount matches (amount - fee) - TrimmedAmount memory trimmedAmount = _trimTransferAmount(amount, recipientChain); - TrimmedAmount memory internalAmount = trimmedAmount.shift(tokenDecimals()); + TrimmedAmount trimmedAmount = _trimTransferAmount(amount, recipientChain); + TrimmedAmount internalAmount = trimmedAmount.shift(tokenDecimals()); // get the sequence for this transfer uint64 sequence = _useMessageSequence(); diff --git a/evm/test/NttManager.t.sol b/evm/test/NttManager.t.sol index 0602b42aa..d23e37474 100644 --- a/evm/test/NttManager.t.sol +++ b/evm/test/NttManager.t.sol @@ -709,7 +709,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { newNttManager.initialize(); // register nttManager peer bytes32 peer = toWormholeFormat(address(nttManager)); - newNttManager.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9); + newNttManager.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9, type(uint64).max); address user_A = address(0x123); address user_B = address(0x456); @@ -734,7 +734,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { bytes memory transceiverMessage; bytes memory tokenTransferMessage; - TrimmedAmount memory transferAmount = TrimmedAmount(100, 8); + TrimmedAmount transferAmount = packTrimmedAmount(100, 8); tokenTransferMessage = TransceiverStructs.encodeNativeTokenTransfer( TransceiverStructs.NativeTokenTransfer({ @@ -784,9 +784,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { t.upgrade(address(dummy3)); vm.startPrank(user_A); - vm.expectRevert( - abi.encodeWithSelector(TrimmedAmountLib.NumberOfDecimalsNotEqual.selector, 8, 7) - ); + vm.expectRevert(abi.encodeWithSelector(NumberOfDecimalsNotEqual.selector, 8, 7)); newNttManager.transfer( 1 * 10 ** 7, TransceiverHelpersLib.SENDING_CHAIN_ID, @@ -803,9 +801,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { toWormholeFormat(address(newNttManager)), tokenTransferMessage ); - vm.expectRevert( - abi.encodeWithSelector(TrimmedAmountLib.NumberOfDecimalsNotEqual.selector, 8, 7) - ); + vm.expectRevert(abi.encodeWithSelector(NumberOfDecimalsNotEqual.selector, 8, 7)); e1.receiveMessage(transceiverMessage); } }