Skip to content

Commit 8c95b29

Browse files
author
Rahul Maganti
committed
evm: test additional invariants
1 parent bd05006 commit 8c95b29

File tree

1 file changed

+72
-19
lines changed

1 file changed

+72
-19
lines changed

evm/test/TrimmedAmount.t.sol

+72-19
Original file line numberDiff line numberDiff line change
@@ -221,25 +221,6 @@ contract TrimmingTest is Test {
221221
assertEq(expectedIsLt, isLt);
222222
}
223223

224-
// invariant: forall (x: TrimmedAmount, aDecimals: uint8, bDecimals: uint8),
225-
// (x.amount <= type(uint64).max)
226-
// => (trim(untrim(x)) == x)
227-
function testFuzz_trimIsLeftInverse(uint256 amount, uint8 aDecimals, uint8 bDecimals) public {
228-
// restrict inputs up to u64MAX. Inputs above u64 are tested elsewhere
229-
amount = bound(amount, 0, type(uint64).max);
230-
vm.assume(aDecimals <= 50);
231-
vm.assume(bDecimals <= 50);
232-
233-
// initialize TrimmedAmount
234-
TrimmedAmount trimmedAmount = amount.trim(aDecimals, bDecimals);
235-
236-
// trimming is the left inverse of trimming
237-
// e.g. trim(untrim(x)) == x
238-
TrimmedAmount amountRoundTrip = (trimmedAmount.untrim(bDecimals)).trim(bDecimals, aDecimals);
239-
240-
assertEq(trimmedAmount.getAmount(), amountRoundTrip.getAmount());
241-
}
242-
243224
// invariant: forall (TrimmedAmount a, TrimmedAmount b)
244225
// a.saturatingAdd(b).amount <= type(uint64).max
245226
function testFuzz_saturatingAddDoesNotOverflow(TrimmedAmount a, TrimmedAmount b) public {
@@ -316,4 +297,76 @@ contract TrimmingTest is Test {
316297
assertEq(expectedTrimmedSum.getAmount(), trimmedSum.getAmount());
317298
assertEq(expectedTrimmedSum.getDecimals(), trimmedSum.getDecimals());
318299
}
300+
301+
function testFuzz_trimmingInvariants(
302+
uint256 amount,
303+
uint256 amount2,
304+
uint8 fromDecimals,
305+
uint8 midDecimals,
306+
uint8 toDecimals
307+
) public {
308+
// restrict inputs up to u64MAX. Inputs above u64 are tested elsewhere
309+
amount = bound(amount, 0, type(uint64).max);
310+
amount2 = bound(amount, 0, type(uint64).max);
311+
vm.assume(fromDecimals <= 50);
312+
vm.assume(toDecimals <= 50);
313+
314+
TrimmedAmount trimmedAmt = amount.trim(fromDecimals, toDecimals);
315+
TrimmedAmount trimmedAmt2 = amount2.trim(fromDecimals, toDecimals);
316+
uint256 untrimmedAmt = trimmedAmt.untrim(fromDecimals);
317+
uint256 untrimmedAmt2 = trimmedAmt2.untrim(fromDecimals);
318+
319+
// trimming is the left inverse of trimming
320+
// invariant: forall (x: TrimmedAmount, fromDecimals: uint8, toDecimals: uint8),
321+
// (x.amount <= type(uint64).max)
322+
// => (trim(untrim(x)) == x)
323+
TrimmedAmount amountRoundTrip = untrimmedAmt.trim(fromDecimals, toDecimals);
324+
assertEq(trimmedAmt.getAmount(), amountRoundTrip.getAmount());
325+
326+
// trimming is a NOOP
327+
// invariant:
328+
// forall (x: uint256, y: uint8, z: uint8),
329+
// (y < z && (y < 8 || z < 8)), trim(x) == x
330+
if (fromDecimals <= toDecimals && (fromDecimals < 8 || toDecimals < 8)) {
331+
assertEq(trimmedAmt.getAmount(), uint64(amount));
332+
}
333+
334+
// invariant: source amount is always greater than or equal to the trimmed amount
335+
// this is also captured by the invariant above
336+
assertGe(amount, trimmedAmt.getAmount());
337+
338+
// invariant: trimmed amount must not exceed the untrimmed amount
339+
assertLe(trimmedAmt.getAmount(), untrimmedAmt);
340+
341+
// invariant: untrimmed amount must not exceed the source amount
342+
assertLe(untrimmedAmt, amount);
343+
344+
// invariant:
345+
// the number of decimals after trimming must not exceed
346+
// the number of decimals before trimming
347+
assertLe(trimmedAmt.getDecimals(), fromDecimals);
348+
349+
// invariant:
350+
// trimming and untrimming preserve ordering relations
351+
if (amount > amount2) {
352+
assertGt(untrimmedAmt, untrimmedAmt2);
353+
} else if (amount < amount2) {
354+
assertLt(untrimmedAmt, untrimmedAmt2);
355+
} else {
356+
assertEq(untrimmedAmt, untrimmedAmt2);
357+
}
358+
359+
// invariant: trimming and untrimming are commutative when
360+
// the number of decimals are the same and less than or equal to 8
361+
if (fromDecimals <= 8 && fromDecimals == toDecimals) {
362+
assertEq(amount, untrimmedAmt);
363+
}
364+
365+
// invariant: trimming and untrimming are associative
366+
// when there is no intermediate loss of precision
367+
vm.assume(midDecimals >= fromDecimals);
368+
TrimmedAmount trimmedAmtA = amount.trim(fromDecimals, midDecimals);
369+
TrimmedAmount trimmedAmtB = amount.trim(fromDecimals, toDecimals);
370+
assertEq(trimmedAmtA.untrim(toDecimals), trimmedAmtB.untrim(toDecimals));
371+
}
319372
}

0 commit comments

Comments
 (0)