Skip to content

Commit

Permalink
fix(math): fix edge cases of ln and pow
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewWestberg committed Oct 26, 2024
1 parent 0dd9bcd commit 4d879da
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 7 deletions.
154 changes: 154 additions & 0 deletions pallas-math/src/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,4 +479,158 @@ mod tests {
assert_eq!(res.iterations.to_string(), expected_iterations);
}
}

#[test]
#[should_panic(expected = "ln of a value in (-inf,0] is undefined")]
fn ln_of_0_should_be_undefined() {
let zero: FixedDecimal = FixedDecimal::from(0u64);
zero.ln();
}

#[test]
#[should_panic(expected = "ln of a value in (-inf,0] is undefined")]
fn ln_of_negative_should_be_undefined() {
let minus_one: FixedDecimal = FixedDecimal::from(-1i64);
minus_one.ln();
}

#[test]
fn pow_of_zero_to_any_positive_power_should_be_zero() {
let zero: FixedDecimal = FixedDecimal::from(0u64);
let two: FixedDecimal = FixedDecimal::from(2u64);
let three: FixedDecimal = FixedDecimal::from(3u64);
assert_eq!(zero.pow(&two), zero);
assert_eq!(zero.pow(&three), zero);
}

#[test]
#[should_panic(expected = "zero to a negative power is undefined")]
fn pow_of_zero_to_neg_power_should_be_undefined() {
let zero: FixedDecimal = FixedDecimal::from(0u64);
let minus_one: FixedDecimal = FixedDecimal::from(-1i64);
zero.pow(&minus_one);
}

#[test]
fn pow_of_any_to_power_0_should_be_1() {
let zero: FixedDecimal = FixedDecimal::from(0u64);
let one: FixedDecimal = FixedDecimal::from(1u64);
let neg_one: FixedDecimal = FixedDecimal::from(-1i64);
assert_eq!(one.pow(&zero), one);
assert_eq!(neg_one.pow(&zero), one);
}

#[test]
fn pow_of_any_to_power_1_should_be_same() {
let zero: FixedDecimal = FixedDecimal::from(0u64);
let one: FixedDecimal = FixedDecimal::from(1u64);
let neg_one: FixedDecimal = FixedDecimal::from(-1i64);
let two: FixedDecimal = FixedDecimal::from(2u64);
let three: FixedDecimal = FixedDecimal::from(3u64);
assert_eq!(zero.pow(&one), zero);
assert_eq!(one.pow(&one), one);
assert_eq!(neg_one.pow(&one), neg_one);
assert_eq!(two.pow(&one), two);
assert_eq!(three.pow(&one), three);
}

#[test]
fn pow_of_negative_to_even_pos_power_should_be_positive() {
let minus_five: FixedDecimal = FixedDecimal::from(-5i64);
let two: FixedDecimal = FixedDecimal::from(2u64);
let four: FixedDecimal = FixedDecimal::from(4u64);
let six: FixedDecimal = FixedDecimal::from(6u64);
assert_eq!(
minus_five.pow(&two),
FixedDecimal::from_str("249999999999999999999999909295186150", 34).unwrap()
);
assert_eq!(
minus_five.pow(&four),
FixedDecimal::from_str("6249999999999999999999997600722222790", 34).unwrap()
);
assert_eq!(
minus_five.pow(&six),
FixedDecimal::from_str("156249999999999999999999625575135657250", 34).unwrap()
);
}

#[test]
fn pow_of_negative_to_odd_pos_power_should_be_negative() {
let minus_five: FixedDecimal = FixedDecimal::from(-5i64);
let three: FixedDecimal = FixedDecimal::from(3u64);
let five: FixedDecimal = FixedDecimal::from(5u64);
let seven: FixedDecimal = FixedDecimal::from(7u64);
assert_eq!(
minus_five.pow(&three),
FixedDecimal::from_str("-1249999999999999999999998502300542629", 34).unwrap()
);
assert_eq!(
minus_five.pow(&five),
FixedDecimal::from_str("-31249999999999999999999996346529171901", 34).unwrap()
);
assert_eq!(
minus_five.pow(&seven),
FixedDecimal::from_str("-781249999999999999999998980932253570390", 34).unwrap()
);
}

#[test]
fn pow_of_negative_to_even_neg_power_should_be_inverted() {
let minus_five: FixedDecimal = FixedDecimal::from(-5i64);
let minus_two: FixedDecimal = FixedDecimal::from(-2i64);
let minus_four: FixedDecimal = FixedDecimal::from(-4i64);
let minus_six: FixedDecimal = FixedDecimal::from(-6i64);
assert_eq!(
minus_five.pow(&minus_two),
FixedDecimal::from_str("400000000000000000000000145127702", 34).unwrap()
);
assert_eq!(
minus_five.pow(&minus_four),
FixedDecimal::from_str("16000000000000000000000006142151", 34).unwrap()
);
assert_eq!(
minus_five.pow(&minus_six),
FixedDecimal::from_str("640000000000000000000001533644", 34).unwrap()
);
}

#[test]
fn pow_of_positive_to_even_neg_power_should_be_inverted() {
let five: FixedDecimal = FixedDecimal::from(5u64);
let minus_two: FixedDecimal = FixedDecimal::from(-2i64);
let minus_four: FixedDecimal = FixedDecimal::from(-4i64);
let minus_six: FixedDecimal = FixedDecimal::from(-6i64);
assert_eq!(
five.pow(&minus_two),
FixedDecimal::from_str("400000000000000000000000145127702", 34).unwrap()
);
assert_eq!(
five.pow(&minus_four),
FixedDecimal::from_str("16000000000000000000000006142151", 34).unwrap()
);
assert_eq!(
five.pow(&minus_six),
FixedDecimal::from_str("640000000000000000000001533644", 34).unwrap()
);
}

#[test]
fn pow_of_positive_to_odd_neg_power_should_be_inverted() {
let five: FixedDecimal = FixedDecimal::from(5u64);
let minus_three: FixedDecimal = FixedDecimal::from(-3i64);
let minus_five: FixedDecimal = FixedDecimal::from(-5i64);
let minus_seven: FixedDecimal = FixedDecimal::from(-7i64);
assert_eq!(
five.pow(&minus_three),
FixedDecimal::from_str("80000000000000000000000095852765", 34).unwrap()
);
assert_eq!(
five.pow(&minus_five),
FixedDecimal::from_str("3200000000000000000000000374115", 34).unwrap()
);
assert_eq!(
five.pow(&minus_seven),
FixedDecimal::from_str("128000000000000000000000166964", 34).unwrap()
);
}
}
54 changes: 47 additions & 7 deletions pallas-math/src/math_malachite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use malachite::num::basic::traits::One;
use malachite::platform_64::Limb;
use malachite::rounding_modes::RoundingMode;
use malachite::{Integer, Natural};
use malachite_base::num::arithmetic::traits::Sign;
use malachite_base::num::arithmetic::traits::{Parity, Sign};
use once_cell::sync::Lazy;
use regex::Regex;
use std::cmp::Ordering;
Expand Down Expand Up @@ -310,8 +310,11 @@ impl FixedPrecision for Decimal {

fn ln(&self) -> Self {
let mut ln_x = Decimal::new(self.precision);
ref_ln(&mut ln_x.data, &self.data);
ln_x
if ref_ln(&mut ln_x.data, &self.data) {
ln_x
} else {
panic!("ln of a value in (-inf,0] is undefined")
}
}

fn pow(&self, rhs: &Self) -> Self {
Expand Down Expand Up @@ -677,10 +680,47 @@ fn ref_ln(rop: &mut Integer, x: &Integer) -> bool {
fn ref_pow(rop: &mut Integer, base: &Integer, exponent: &Integer) {
/* x^y = exp(y * ln x) */
let mut tmp: Integer = Integer::from(0);
ref_ln(&mut tmp, base);
tmp *= exponent;
scale(&mut tmp);
ref_exp(rop, &tmp);

if exponent == &tmp || base == &ONE.value {
// any base to the power of zero is one, or 1 to any power is 1
*rop = ONE.value.clone();
return;
}
if exponent == &ONE.value {
// any base to the power of one is the base
*rop = base.clone();
return;
}
if base == &tmp && exponent > &tmp {
// zero to any positive power is zero
*rop = Integer::from(0) * &PRECISION.value;
return;
}
if base == &tmp && exponent < &tmp {
panic!("zero to a negative power is undefined");
}
if base < &tmp {
// negate the base and calculate the power
let neg_base = base.neg();
let ref_ln = ref_ln(&mut tmp, &neg_base);
debug_assert!(ref_ln);
tmp *= exponent;
scale(&mut tmp);
let mut tmp_rop = Integer::from(0);
ref_exp(&mut tmp_rop, &tmp);
*rop = if (exponent / &PRECISION.value).even() {
tmp_rop
} else {
-tmp_rop
};
} else {
// base is positive, ref_ln result is valid
let ref_ln = ref_ln(&mut tmp, base);
debug_assert!(ref_ln);
tmp *= exponent;
scale(&mut tmp);
ref_exp(rop, &tmp);
}
}

/// `bound_x` is the bound for exp in the interval x is chosen from
Expand Down

0 comments on commit 4d879da

Please sign in to comment.