Skip to content

Commit 49e2818

Browse files
authored
Fix: importance_sampling=None produces error (#427)
* fix: importance sampling handling causing error when chosen method is "none" or None - Moved importance sampling logic from `multipath_pathfinder` to `fit_pathfinder` to fix error method is "none" or None - Update docstrings to clarify importance sampling method behavior - Use match statement for method selection in importance_sampling
1 parent 4fbbfeb commit 49e2818

File tree

3 files changed

+96
-44
lines changed

3 files changed

+96
-44
lines changed

pymc_extras/inference/pathfinder/importance_sampling.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@ class ImportanceSamplingResult:
2020
samples: NDArray
2121
pareto_k: float | None = None
2222
warnings: list[str] = field(default_factory=list)
23-
method: str = "none"
23+
method: str = "psis"
2424

2525

2626
def importance_sampling(
2727
samples: NDArray,
2828
logP: NDArray,
2929
logQ: NDArray,
3030
num_draws: int,
31-
method: Literal["psis", "psir", "identity", "none"] | None,
31+
method: Literal["psis", "psir", "identity"] | None,
3232
random_seed: int | None = None,
3333
) -> ImportanceSamplingResult:
3434
"""Pareto Smoothed Importance Resampling (PSIR)
@@ -44,8 +44,15 @@ def importance_sampling(
4444
log probability values of proposal distribution, shape (L, M)
4545
num_draws : int
4646
number of draws to return where num_draws <= samples.shape[0]
47-
method : str, optional
48-
importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
47+
method : str, None, optional
48+
Method to apply sampling based on log importance weights (logP - logQ).
49+
Options are:
50+
"psis" : Pareto Smoothed Importance Sampling (default)
51+
Recommended for more stable results.
52+
"psir" : Pareto Smoothed Importance Resampling
53+
Less stable than PSIS.
54+
"identity" : Applies log importance weights directly without resampling.
55+
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
4956
random_seed : int | None
5057
5158
Returns
@@ -71,11 +78,11 @@ def importance_sampling(
7178
warnings = []
7279
num_paths, _, N = samples.shape
7380

74-
if method == "none":
81+
if method is None:
7582
warnings.append(
7683
"Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
7784
)
78-
return ImportanceSamplingResult(samples=samples, warnings=warnings)
85+
return ImportanceSamplingResult(samples=samples, warnings=warnings, method=method)
7986
else:
8087
samples = samples.reshape(-1, N)
8188
logP = logP.ravel()
@@ -91,17 +98,16 @@ def importance_sampling(
9198
_warnings.filterwarnings(
9299
"ignore", category=RuntimeWarning, message="overflow encountered in exp"
93100
)
94-
if method == "psis":
95-
replace = False
96-
logiw, pareto_k = az.psislw(logiw)
97-
elif method == "psir":
98-
replace = True
99-
logiw, pareto_k = az.psislw(logiw)
100-
elif method == "identity":
101-
replace = False
102-
pareto_k = None
103-
else:
104-
raise ValueError(f"Invalid importance sampling method: {method}")
101+
match method:
102+
case "psis":
103+
replace = False
104+
logiw, pareto_k = az.psislw(logiw)
105+
case "psir":
106+
replace = True
107+
logiw, pareto_k = az.psislw(logiw)
108+
case "identity":
109+
replace = False
110+
pareto_k = None
105111

106112
# NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
107113
# Pareto k may not be a good diagnostic for Pathfinder.

pymc_extras/inference/pathfinder/pathfinder.py

+33-17
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def convert_flat_trace_to_idata(
156156
postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
157157
inference_backend: Literal["pymc", "blackjax"] = "pymc",
158158
model: Model | None = None,
159-
importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
159+
importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
160160
) -> az.InferenceData:
161161
"""convert flattened samples to arviz InferenceData format.
162162
@@ -181,7 +181,7 @@ def convert_flat_trace_to_idata(
181181
arviz inference data object
182182
"""
183183

184-
if importance_sampling == "none":
184+
if importance_sampling is None:
185185
# samples.ndim == 3 in this case, otherwise ndim == 2
186186
num_paths, num_pdraws, N = samples.shape
187187
samples = samples.reshape(-1, N)
@@ -220,7 +220,7 @@ def convert_flat_trace_to_idata(
220220
fn.trust_input = True
221221
result = fn(*list(trace.values()))
222222

223-
if importance_sampling == "none":
223+
if importance_sampling is None:
224224
result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]
225225

226226
elif inference_backend == "blackjax":
@@ -1189,7 +1189,7 @@ class MultiPathfinderResult:
11891189
elbo_argmax: NDArray | None = None
11901190
lbfgs_status: Counter = field(default_factory=Counter)
11911191
path_status: Counter = field(default_factory=Counter)
1192-
importance_sampling: str = "none"
1192+
importance_sampling: str | None = "psis"
11931193
warnings: list[str] = field(default_factory=list)
11941194
pareto_k: float | None = None
11951195

@@ -1258,7 +1258,7 @@ def with_warnings(self, warnings: list[str]) -> Self:
12581258
def with_importance_sampling(
12591259
self,
12601260
num_draws: int,
1261-
method: Literal["psis", "psir", "identity", "none"] | None,
1261+
method: Literal["psis", "psir", "identity"] | None,
12621262
random_seed: int | None = None,
12631263
) -> Self:
12641264
"""perform importance sampling"""
@@ -1424,7 +1424,7 @@ def multipath_pathfinder(
14241424
num_elbo_draws: int,
14251425
jitter: float,
14261426
epsilon: float,
1427-
importance_sampling: Literal["psis", "psir", "identity", "none"] | None,
1427+
importance_sampling: Literal["psis", "psir", "identity"] | None,
14281428
progressbar: bool,
14291429
concurrent: Literal["thread", "process"] | None,
14301430
random_seed: RandomSeed,
@@ -1460,8 +1460,14 @@ def multipath_pathfinder(
14601460
Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
14611461
epsilon: float
14621462
value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
1463-
importance_sampling : str, optional
1464-
importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
1463+
importance_sampling : str, None, optional
1464+
Method to apply sampling based on log importance weights (logP - logQ).
1465+
"psis" : Pareto Smoothed Importance Sampling (default)
1466+
Recommended for more stable results.
1467+
"psir" : Pareto Smoothed Importance Resampling
1468+
Less stable than PSIS.
1469+
"identity" : Applies log importance weights directly without resampling.
1470+
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
14651471
progressbar : bool, optional
14661472
Whether to display a progress bar (default is False). Setting this to True will likely increase the computation time.
14671473
random_seed : RandomSeed, optional
@@ -1483,12 +1489,6 @@ def multipath_pathfinder(
14831489
The result containing samples and other information from the Multi-Path Pathfinder algorithm.
14841490
"""
14851491

1486-
valid_importance_sampling = ["psis", "psir", "identity", "none", None]
1487-
if importance_sampling is None:
1488-
importance_sampling = "none"
1489-
if importance_sampling.lower() not in valid_importance_sampling:
1490-
raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
1491-
14921492
*path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1)
14931493

14941494
pathfinder_config = PathfinderConfig(
@@ -1622,7 +1622,7 @@ def fit_pathfinder(
16221622
num_elbo_draws: int = 10, # K
16231623
jitter: float = 2.0,
16241624
epsilon: float = 1e-8,
1625-
importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
1625+
importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
16261626
progressbar: bool = True,
16271627
concurrent: Literal["thread", "process"] | None = None,
16281628
random_seed: RandomSeed | None = None,
@@ -1662,8 +1662,15 @@ def fit_pathfinder(
16621662
Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
16631663
epsilon: float
16641664
value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
1665-
importance_sampling : str, optional
1666-
importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
1665+
importance_sampling : str, None, optional
1666+
Method to apply sampling based on log importance weights (logP - logQ).
1667+
Options are:
1668+
"psis" : Pareto Smoothed Importance Sampling (default)
1669+
Recommended for more stable results.
1670+
"psir" : Pareto Smoothed Importance Resampling
1671+
Less stable than PSIS.
1672+
"identity" : Applies log importance weights directly without resampling.
1673+
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
16671674
progressbar : bool, optional
16681675
Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
16691676
random_seed : RandomSeed, optional
@@ -1690,6 +1697,15 @@ def fit_pathfinder(
16901697
"""
16911698

16921699
model = modelcontext(model)
1700+
1701+
valid_importance_sampling = {"psis", "psir", "identity", None}
1702+
1703+
if importance_sampling is not None:
1704+
importance_sampling = importance_sampling.lower()
1705+
1706+
if importance_sampling not in valid_importance_sampling:
1707+
raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
1708+
16931709
N = DictToArrayBijection.map(model.initial_point()).data.shape[0]
16941710

16951711
if maxcor is None:

tests/test_pathfinder.py

+40-10
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def reference_idata():
4444
with model:
4545
idata = pmx.fit(
4646
method="pathfinder",
47-
num_paths=50,
48-
jitter=10.0,
47+
num_paths=10,
48+
jitter=12.0,
4949
random_seed=41,
5050
inference_backend="pymc",
5151
)
@@ -62,15 +62,15 @@ def test_pathfinder(inference_backend, reference_idata):
6262
with model:
6363
idata = pmx.fit(
6464
method="pathfinder",
65-
num_paths=50,
66-
jitter=10.0,
65+
num_paths=10,
66+
jitter=12.0,
6767
random_seed=41,
6868
inference_backend=inference_backend,
6969
)
7070
else:
7171
idata = reference_idata
72-
np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.6)
73-
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.5)
72+
np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=0.95)
73+
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.35)
7474

7575
assert idata.posterior["mu"].shape == (1, 1000)
7676
assert idata.posterior["tau"].shape == (1, 1000)
@@ -83,8 +83,8 @@ def test_concurrent_results(reference_idata, concurrent):
8383
with model:
8484
idata_conc = pmx.fit(
8585
method="pathfinder",
86-
num_paths=50,
87-
jitter=10.0,
86+
num_paths=10,
87+
jitter=12.0,
8888
random_seed=41,
8989
inference_backend="pymc",
9090
concurrent=concurrent,
@@ -108,15 +108,15 @@ def test_seed(reference_idata):
108108
with model:
109109
idata_41 = pmx.fit(
110110
method="pathfinder",
111-
num_paths=50,
111+
num_paths=4,
112112
jitter=10.0,
113113
random_seed=41,
114114
inference_backend="pymc",
115115
)
116116

117117
idata_123 = pmx.fit(
118118
method="pathfinder",
119-
num_paths=50,
119+
num_paths=4,
120120
jitter=10.0,
121121
random_seed=123,
122122
inference_backend="pymc",
@@ -171,3 +171,33 @@ def test_bfgs_sample():
171171
assert gamma.eval().shape == (L, 2 * J, 2 * J)
172172
assert phi.eval().shape == (L, num_samples, N)
173173
assert logq.eval().shape == (L, num_samples)
174+
175+
176+
@pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])
177+
def test_pathfinder_importance_sampling(importance_sampling):
178+
model = eight_schools_model()
179+
180+
num_paths = 4
181+
num_draws_per_path = 300
182+
num_draws = 750
183+
184+
with model:
185+
idata = pmx.fit(
186+
method="pathfinder",
187+
num_paths=num_paths,
188+
num_draws_per_path=num_draws_per_path,
189+
num_draws=num_draws,
190+
maxiter=5,
191+
random_seed=41,
192+
inference_backend="pymc",
193+
importance_sampling=importance_sampling,
194+
)
195+
196+
if importance_sampling is None:
197+
assert idata.posterior["mu"].shape == (num_paths, num_draws_per_path)
198+
assert idata.posterior["tau"].shape == (num_paths, num_draws_per_path)
199+
assert idata.posterior["theta"].shape == (num_paths, num_draws_per_path, 8)
200+
else:
201+
assert idata.posterior["mu"].shape == (1, num_draws)
202+
assert idata.posterior["tau"].shape == (1, num_draws)
203+
assert idata.posterior["theta"].shape == (1, num_draws, 8)

0 commit comments

Comments
 (0)