Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

evm: Always fetch token decimals rather than caching #238

Merged
merged 3 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions evm/src/NttManager/NttManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,9 @@ contract NttManager is INttManager, NttManagerState {
if (nativeTokenTransfer.toChain != chainId) {
revert InvalidTargetChain(nativeTokenTransfer.toChain, chainId);
}
uint8 toDecimals = tokenDecimals();
TrimmedAmount nativeTransferAmount =
(nativeTokenTransfer.amount.untrim(tokenDecimals_)).trim(tokenDecimals_, tokenDecimals_);
(nativeTokenTransfer.amount.untrim(toDecimals)).trim(toDecimals, toDecimals);

address transferRecipient = fromWormholeFormat(nativeTokenTransfer.to);

Expand Down Expand Up @@ -314,7 +315,7 @@ contract NttManager is INttManager, NttManagerState {

// trim amount after burning to ensure transfer amount matches (amount - fee)
TrimmedAmount trimmedAmount = _trimTransferAmount(amount, recipientChain);
TrimmedAmount internalAmount = trimmedAmount.shift(tokenDecimals_);
TrimmedAmount internalAmount = trimmedAmount.shift(tokenDecimals());

// get the sequence for this transfer
uint64 sequence = _useMessageSequence();
Expand Down Expand Up @@ -415,7 +416,7 @@ contract NttManager is INttManager, NttManagerState {
uint16 destinationChain = recipientChain;

emit TransferSent(
recipient, amt.untrim(tokenDecimals_), totalPriceQuote, destinationChain, seq
recipient, amt.untrim(tokenDecimals()), totalPriceQuote, destinationChain, seq
);

// return the sequence number
Expand All @@ -429,7 +430,7 @@ contract NttManager is INttManager, NttManagerState {
) internal {
// calculate proper amount of tokens to unlock/mint to recipient
// untrim the amount
uint256 untrimmedAmount = amount.untrim(tokenDecimals_);
uint256 untrimmedAmount = amount.untrim(tokenDecimals());

emit TransferRedeemed(digest);

Expand All @@ -444,9 +445,9 @@ contract NttManager is INttManager, NttManagerState {
}
}

/// @inheritdoc INttManager
function tokenDecimals() public view override(INttManager, RateLimiter) returns (uint8) {
return tokenDecimals_;
(, bytes memory queriedDecimals) = token.staticcall(abi.encodeWithSignature("decimals()"));
return abi.decode(queriedDecimals, (uint8));
}

// ==================== Internal Helpers ===============================================
Expand All @@ -473,9 +474,10 @@ contract NttManager is INttManager, NttManagerState {

TrimmedAmount trimmedAmount;
{
trimmedAmount = amount.trim(tokenDecimals_, toDecimals);
uint8 fromDecimals = tokenDecimals();
trimmedAmount = amount.trim(fromDecimals, toDecimals);
// don't deposit dust that can not be bridged due to the decimal shift
uint256 newAmount = trimmedAmount.untrim(tokenDecimals_);
uint256 newAmount = trimmedAmount.untrim(fromDecimals);
if (amount != newAmount) {
revert TransferAmountHasDust(amount, amount - newAmount);
}
Expand Down
19 changes: 7 additions & 12 deletions evm/src/NttManager/NttManagerState.sol
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ abstract contract NttManagerState is
INttManager.Mode public immutable mode;
uint16 public immutable chainId;
uint256 immutable evmChainId;
uint8 public immutable tokenDecimals_;

// =============== Setup =================================================================

Expand All @@ -55,7 +54,6 @@ abstract contract NttManagerState is
bool _skipRateLimiting
) RateLimiter(_rateLimitDuration, _skipRateLimiting) {
token = _token;
tokenDecimals_ = _initializeTokenDecimals();
mode = _mode;
chainId = _chainId;
evmChainId = block.chainid;
Expand All @@ -70,7 +68,7 @@ abstract contract NttManagerState is
}
__PausedOwnable_init(msg.sender, msg.sender);
__ReentrancyGuard_init();
_setOutboundLimit(TrimmedAmountLib.max(tokenDecimals_));
_setOutboundLimit(TrimmedAmountLib.max(tokenDecimals()));
}

function _initialize() internal virtual override {
Expand Down Expand Up @@ -278,7 +276,8 @@ abstract contract NttManagerState is
_getPeersStorage()[peerChainId].peerAddress = peerContract;
_getPeersStorage()[peerChainId].tokenDecimals = decimals;

_setInboundLimit(inboundLimit.trim(tokenDecimals_, tokenDecimals_), peerChainId);
uint8 tokenDecimals = tokenDecimals();
_setInboundLimit(inboundLimit.trim(tokenDecimals, tokenDecimals), peerChainId);

emit PeerUpdated(
peerChainId, oldPeer.peerAddress, oldPeer.tokenDecimals, peerContract, decimals
Expand All @@ -287,12 +286,14 @@ abstract contract NttManagerState is

/// @inheritdoc INttManagerState
function setOutboundLimit(uint256 limit) external onlyOwner {
_setOutboundLimit(limit.trim(tokenDecimals_, tokenDecimals_));
uint8 tokenDecimals = tokenDecimals();
_setOutboundLimit(limit.trim(tokenDecimals, tokenDecimals));
}

/// @inheritdoc INttManagerState
function setInboundLimit(uint256 limit, uint16 chainId_) external onlyOwner {
_setInboundLimit(limit.trim(tokenDecimals_, tokenDecimals_), chainId_);
uint8 tokenDecimals = tokenDecimals();
_setInboundLimit(limit.trim(tokenDecimals, tokenDecimals), chainId_);
}

// =============== Internal ==============================================================
Expand Down Expand Up @@ -349,17 +350,11 @@ abstract contract NttManagerState is
_getMessageSequenceStorage().num++;
}

function _initializeTokenDecimals() internal view returns (uint8) {
(, bytes memory queriedDecimals) = token.staticcall(abi.encodeWithSignature("decimals()"));
return abi.decode(queriedDecimals, (uint8));
}

/// ============== Invariants =============================================

/// @dev When we add new immutables, this function should be updated
function _checkImmutables() internal view override {
assert(this.token() == token);
assert(this.tokenDecimals_() == tokenDecimals_);
assert(this.mode() == mode);
assert(this.chainId() == chainId);
assert(this.rateLimitDuration() == rateLimitDuration);
Expand Down
19 changes: 18 additions & 1 deletion evm/src/mocks/DummyToken.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
pragma solidity >=0.8.8 <0.9.0;

import "openzeppelin-contracts/contracts/token/ERC20/ERC20.sol";
import "openzeppelin-contracts/contracts/proxy/ERC1967/ERC1967Upgrade.sol";

contract DummyToken is ERC20 {
contract DummyToken is ERC20, ERC1967Upgrade {
constructor() ERC20("DummyToken", "DTKN") {}

// NOTE: this is purposefully not called mint() to so we can test that in
Expand All @@ -24,6 +25,10 @@ contract DummyToken is ERC20 {
function burn(address, uint256) public virtual {
revert("Locking nttManager should not call 'burn()'");
}

function upgrade(address newImplementation) public {
_upgradeTo(newImplementation);
}
}

contract DummyTokenMintAndBurn is DummyToken {
Expand All @@ -37,3 +42,15 @@ contract DummyTokenMintAndBurn is DummyToken {
_burn(msg.sender, amount);
}
}

contract DummyTokenDifferentDecimals is DummyTokenMintAndBurn {
uint8 private immutable _decimals;

constructor(uint8 newDecimals) {
_decimals = newDecimals;
}

function decimals() public view override returns (uint8) {
return _decimals;
}
}
111 changes: 111 additions & 0 deletions evm/test/NttManager.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -693,4 +693,115 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents {

assertEq(token.balanceOf(address(user_B)), transferAmount.untrim(token.decimals()) * 2);
}

function test_tokenUpgradedAndDecimalsChanged() public {
DummyToken dummy1 = new DummyTokenMintAndBurn();

// Make the token an upgradeable token
DummyTokenMintAndBurn t =
DummyTokenMintAndBurn(address(new ERC1967Proxy(address(dummy1), "")));

NttManager implementation =
new MockNttManagerContract(address(t), INttManager.Mode.LOCKING, chainId, 1 days, false);

MockNttManagerContract newNttManager =
MockNttManagerContract(address(new ERC1967Proxy(address(implementation), "")));
newNttManager.initialize();
// register nttManager peer
bytes32 peer = toWormholeFormat(address(nttManager));
newNttManager.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9, type(uint64).max);

address user_A = address(0x123);
address user_B = address(0x456);
t.mintDummy(address(user_A), 5 * 10 ** t.decimals());

// Check that we can initiate a transfer
vm.startPrank(user_A);
t.approve(address(newNttManager), 3 * 10 ** t.decimals());
newNttManager.transfer(
1 * 10 ** t.decimals(),
TransceiverHelpersLib.SENDING_CHAIN_ID,
toWormholeFormat(user_B),
false,
new bytes(1)
);
vm.stopPrank();

// Check that we can receive a transfer
(DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(newNttManager);
newNttManager.setThreshold(1);

bytes memory transceiverMessage;
bytes memory tokenTransferMessage;

TrimmedAmount transferAmount = packTrimmedAmount(100, 8);

tokenTransferMessage = TransceiverStructs.encodeNativeTokenTransfer(
TransceiverStructs.NativeTokenTransfer({
amount: transferAmount,
sourceToken: toWormholeFormat(address(t)),
to: toWormholeFormat(user_B),
toChain: chainId
})
);

(, transceiverMessage) = TransceiverHelpersLib.buildTransceiverMessageWithNttManagerPayload(
0, bytes32(0), peer, toWormholeFormat(address(newNttManager)), tokenTransferMessage
);

e1.receiveMessage(transceiverMessage);
uint256 userBBalanceBefore = t.balanceOf(address(user_B));
assertEq(userBBalanceBefore, transferAmount.untrim(t.decimals()));

// If the token decimals change to the same trimmed amount, we should safely receive the correct number of tokens
DummyTokenDifferentDecimals dummy2 = new DummyTokenDifferentDecimals(10); // 10 gets trimmed to 8
t.upgrade(address(dummy2));

vm.startPrank(user_A);
newNttManager.transfer(
1 * 10 ** 10,
TransceiverHelpersLib.SENDING_CHAIN_ID,
toWormholeFormat(user_B),
false,
new bytes(1)
);
vm.stopPrank();

(, transceiverMessage) = TransceiverHelpersLib.buildTransceiverMessageWithNttManagerPayload(
bytes32("1"),
bytes32(0),
peer,
toWormholeFormat(address(newNttManager)),
tokenTransferMessage
);
e1.receiveMessage(transceiverMessage);
assertEq(
t.balanceOf(address(user_B)), userBBalanceBefore + transferAmount.untrim(t.decimals())
);

// Now if the token decimals change to a different trimmed amount, we shouldn't be able to send or receive
DummyTokenDifferentDecimals dummy3 = new DummyTokenDifferentDecimals(7); // 7 is 7 trimmed
t.upgrade(address(dummy3));

vm.startPrank(user_A);
vm.expectRevert(abi.encodeWithSelector(NumberOfDecimalsNotEqual.selector, 8, 7));
newNttManager.transfer(
1 * 10 ** 7,
TransceiverHelpersLib.SENDING_CHAIN_ID,
toWormholeFormat(user_B),
false,
new bytes(1)
);
vm.stopPrank();

(, transceiverMessage) = TransceiverHelpersLib.buildTransceiverMessageWithNttManagerPayload(
bytes32("2"),
bytes32(0),
peer,
toWormholeFormat(address(newNttManager)),
tokenTransferMessage
);
vm.expectRevert(abi.encodeWithSelector(NumberOfDecimalsNotEqual.selector, 8, 7));
e1.receiveMessage(transceiverMessage);
}
}
Loading