@@ -221,25 +221,6 @@ contract TrimmingTest is Test {
221
221
assertEq (expectedIsLt, isLt);
222
222
}
223
223
224
- // invariant: forall (x: TrimmedAmount, aDecimals: uint8, bDecimals: uint8),
225
- // (x.amount <= type(uint64).max)
226
- // => (trim(untrim(x)) == x)
227
- function testFuzz_trimIsLeftInverse (uint256 amount , uint8 aDecimals , uint8 bDecimals ) public {
228
- // restrict inputs up to u64MAX. Inputs above u64 are tested elsewhere
229
- amount = bound (amount, 0 , type (uint64 ).max);
230
- vm.assume (aDecimals <= 50 );
231
- vm.assume (bDecimals <= 50 );
232
-
233
- // initialize TrimmedAmount
234
- TrimmedAmount trimmedAmount = amount.trim (aDecimals, bDecimals);
235
-
236
- // trimming is the left inverse of trimming
237
- // e.g. trim(untrim(x)) == x
238
- TrimmedAmount amountRoundTrip = (trimmedAmount.untrim (bDecimals)).trim (bDecimals, aDecimals);
239
-
240
- assertEq (trimmedAmount.getAmount (), amountRoundTrip.getAmount ());
241
- }
242
-
243
224
// invariant: forall (TrimmedAmount a, TrimmedAmount b)
244
225
// a.saturatingAdd(b).amount <= type(uint64).max
245
226
function testFuzz_saturatingAddDoesNotOverflow (TrimmedAmount a , TrimmedAmount b ) public {
@@ -316,4 +297,76 @@ contract TrimmingTest is Test {
316
297
assertEq (expectedTrimmedSum.getAmount (), trimmedSum.getAmount ());
317
298
assertEq (expectedTrimmedSum.getDecimals (), trimmedSum.getDecimals ());
318
299
}
300
+
301
+ function testFuzz_trimmingInvariants (
302
+ uint256 amount ,
303
+ uint256 amount2 ,
304
+ uint8 fromDecimals ,
305
+ uint8 midDecimals ,
306
+ uint8 toDecimals
307
+ ) public {
308
+ // restrict inputs up to u64MAX. Inputs above u64 are tested elsewhere
309
+ amount = bound (amount, 0 , type (uint64 ).max);
310
+ amount2 = bound (amount, 0 , type (uint64 ).max);
311
+ vm.assume (fromDecimals <= 50 );
312
+ vm.assume (toDecimals <= 50 );
313
+
314
+ TrimmedAmount trimmedAmt = amount.trim (fromDecimals, toDecimals);
315
+ TrimmedAmount trimmedAmt2 = amount2.trim (fromDecimals, toDecimals);
316
+ uint256 untrimmedAmt = trimmedAmt.untrim (fromDecimals);
317
+ uint256 untrimmedAmt2 = trimmedAmt2.untrim (fromDecimals);
318
+
319
+ // trimming is the left inverse of trimming
320
+ // invariant: forall (x: TrimmedAmount, fromDecimals: uint8, toDecimals: uint8),
321
+ // (x.amount <= type(uint64).max)
322
+ // => (trim(untrim(x)) == x)
323
+ TrimmedAmount amountRoundTrip = untrimmedAmt.trim (fromDecimals, toDecimals);
324
+ assertEq (trimmedAmt.getAmount (), amountRoundTrip.getAmount ());
325
+
326
+ // trimming is a NOOP
327
+ // invariant:
328
+ // forall (x: uint256, y: uint8, z: uint8),
329
+ // (y < z && (y < 8 || z < 8)), trim(x) == x
330
+ if (fromDecimals <= toDecimals && (fromDecimals < 8 || toDecimals < 8 )) {
331
+ assertEq (trimmedAmt.getAmount (), uint64 (amount));
332
+ }
333
+
334
+ // invariant: source amount is always greater than or equal to the trimmed amount
335
+ // this is also captured by the invariant above
336
+ assertGe (amount, trimmedAmt.getAmount ());
337
+
338
+ // invariant: trimmed amount must not exceed the untrimmed amount
339
+ assertLe (trimmedAmt.getAmount (), untrimmedAmt);
340
+
341
+ // invariant: untrimmed amount must not exceed the source amount
342
+ assertLe (untrimmedAmt, amount);
343
+
344
+ // invariant:
345
+ // the number of decimals after trimming must not exceed
346
+ // the number of decimals before trimming
347
+ assertLe (trimmedAmt.getDecimals (), fromDecimals);
348
+
349
+ // invariant:
350
+ // trimming and untrimming preserve ordering relations
351
+ if (amount > amount2) {
352
+ assertGt (untrimmedAmt, untrimmedAmt2);
353
+ } else if (amount < amount2) {
354
+ assertLt (untrimmedAmt, untrimmedAmt2);
355
+ } else {
356
+ assertEq (untrimmedAmt, untrimmedAmt2);
357
+ }
358
+
359
+ // invariant: trimming and untrimming are commutative when
360
+ // the number of decimals are the same and less than or equal to 8
361
+ if (fromDecimals <= 8 && fromDecimals == toDecimals) {
362
+ assertEq (amount, untrimmedAmt);
363
+ }
364
+
365
+ // invariant: trimming and untrimming are associative
366
+ // when there is no intermediate loss of precision
367
+ vm.assume (midDecimals >= fromDecimals);
368
+ TrimmedAmount trimmedAmtA = amount.trim (fromDecimals, midDecimals);
369
+ TrimmedAmount trimmedAmtB = amount.trim (fromDecimals, toDecimals);
370
+ assertEq (trimmedAmtA.untrim (toDecimals), trimmedAmtB.untrim (toDecimals));
371
+ }
319
372
}
0 commit comments