Skip to content

Commit

Permalink
feat: Extend hypergeometric distribution PMF for non-integral arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
fpelliccioni committed Feb 4, 2025
1 parent a5c0625 commit a4fe870
Showing 1 changed file with 62 additions and 5 deletions.
67 changes: 62 additions & 5 deletions include/boost/math/distributions/hypergeometric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
#include <boost/math/distributions/detail/hypergeometric_pdf.hpp>
#include <boost/math/distributions/detail/hypergeometric_cdf.hpp>
#include <boost/math/distributions/detail/hypergeometric_quantile.hpp>
#include <boost/math/interpolators/cubic_hermite.hpp>
#include <boost/math/special_functions/fpclassify.hpp>
#include <boost/math/special_functions/round.hpp>
#include <cstdint>

namespace boost { namespace math {
Expand Down Expand Up @@ -136,14 +138,69 @@ namespace boost { namespace math {
{
BOOST_MATH_STD_USING
static const char* function = "boost::math::pdf(const hypergeometric_distribution<%1%>&, const %1%&)";
RealType r = static_cast<RealType>(x);
auto u = static_cast<std::uint64_t>(lltrunc(r, typename policies::normalise<Policy, policies::rounding_error<policies::ignore_error> >::type()));
if(u != r)
const RealType x_real = static_cast<RealType>(x);
const auto u = static_cast<std::uint64_t>(lltrunc(x_real, typename policies::normalise<Policy, policies::rounding_error<policies::ignore_error> >::type()));

// If x is an integer, call the PDF directly
if(u == x_real)
{
return pdf(dist, u);
}

if (x_real < 0)
{
return pdf(dist, static_cast<std::uint64_t>(-1));
}

const auto r = dist.defective();
const auto n = dist.sample_count();
const auto N = dist.total();
const std::int64_t max_valid_x = std::min(r, n);

if (x_real > max_valid_x) {
return pdf(dist, max_valid_x + 1);
}

// If x is not an integer, perform cubic Hermite interpolation
const std::int64_t x_rounded = static_cast<std::int64_t>(round(x_real));
if (max_valid_x < 2) {
return boost::math::policies::raise_domain_error<RealType>(
function, "Random variable out of range: must be an integer but got %1%", r, Policy());
function, "Not enough points available for interpolation, we got %1% points and we need at least 3", x, Policy());
}

std::int64_t lower_x = x_rounded - 1;
if (lower_x < 0) {
lower_x = 0;
}
std::int64_t upper_x = lower_x + 2;
if (upper_x > max_valid_x) {
upper_x = max_valid_x;
--lower_x;
}
return pdf(dist, u);

std::vector<RealType> x_vals;
std::vector<RealType> y_vals;
for (std::int64_t xi = lower_x; xi <= upper_x; ++xi) {
const auto pdf_val = pdf(dist, xi);
x_vals.push_back(static_cast<RealType>(xi));
y_vals.push_back(pdf_val);
}

std::vector<RealType> dydx_vals;
for (size_t i = 1; i < x_vals.size() - 1; ++i) {
const RealType deriv = (y_vals[i + 1] - y_vals[i - 1]) / (x_vals[i + 1] - x_vals[i - 1]);
dydx_vals.push_back(deriv);
}

dydx_vals.insert(dydx_vals.begin(), (y_vals[1] - y_vals[0]) / (x_vals[1] - x_vals[0]));
dydx_vals.push_back((y_vals[y_vals.size() - 1] - y_vals[y_vals.size() - 2]) /
(x_vals[y_vals.size() - 1] - x_vals[y_vals.size() - 2]));

using boost::math::interpolators::cubic_hermite;
const auto interpolator = cubic_hermite<std::vector<RealType>>(
std::move(x_vals), std::move(y_vals), std::move(dydx_vals)
);
return interpolator(x);
}

template <class RealType, class Policy>
Expand Down

0 comments on commit a4fe870

Please sign in to comment.