
Commit 89bbe19

Update pre-commit (#368)
* Update pre-commit to use ruff
* Run ruff with new configuration
* Additional fixes requested by ruff
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7ce4ac8 commit 89bbe19

(Large commit: some file diffs are hidden by default, so only a subset of the 61 changed files appears below.)

61 files changed: +448 −370 lines

.pre-commit-config.yaml

+9-28
```diff
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.5.0
+    rev: v4.6.0
     hooks:
       - id: check-merge-conflict
       - id: check-toml
@@ -10,37 +10,18 @@ repos:
       - id: no-commit-to-branch
         args: [--branch, main]
       - id: trailing-whitespace
-  - repo: https://github.com/PyCQA/isort
-    rev: 5.13.2
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.5.5
     hooks:
-      - id: isort
-        name: isort
-  - repo: https://github.com/asottile/pyupgrade
-    rev: v3.15.0
-    hooks:
-      - id: pyupgrade
-        args: [--py37-plus]
-  - repo: https://github.com/psf/black
-    rev: 24.1.1
-    hooks:
-      - id: black
-      - id: black-jupyter
-  - repo: https://github.com/PyCQA/pylint
-    rev: v3.0.3
-    hooks:
-      - id: pylint
-        args: [--rcfile=.pylintrc]
-        files: ^pymc_experimental/
+      - id: ruff
+        args: [ --fix, --unsafe-fixes, --exit-non-zero-on-fix ]
+      - id: ruff-format
+        types_or: [ python, pyi, jupyter ]
+
   - repo: https://github.com/MarcoGorelli/madforhooks
     rev: 0.4.1
     hooks:
       - id: no-print-statements
         exclude: _version.py
         files: ^pymc_experimental/
-  - repo: local
-    hooks:
-      - id: no-relative-imports
-        name: No relative imports
-        entry: from \.[\.\w]* import
-        types: [python]
-        language: pygrep
```
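The net effect: four separate tools (isort, pyupgrade, black, pylint) plus a local pygrep hook are replaced by two ruff hooks. `ruff` lints and auto-fixes (`--exit-non-zero-on-fix` makes the hook fail whenever it rewrote something, so the fixes get re-staged), and `ruff-format` takes over formatting for Python files and notebooks. Below is a minimal sketch, not taken from this commit, of the kind of rewrite the fixer performs in the files that follow, assuming pyupgrade-style rules (UP) are enabled in the project's ruff configuration (that configuration lives in a file not shown in this excerpt):

```python
# Hypothetical before/after, for illustration only.

# Before (typing-module generics and PEP 484 unions):
#     from typing import Dict, Optional
#     def lookup(table: Dict[str, int], key: str) -> Optional[int]:
#         return table.get(key)

# After `ruff --fix` under rules UP006/UP007 (PEP 585 builtin
# generics and PEP 604 union syntax):
def lookup(table: dict[str, int], key: str) -> int | None:
    return table.get(key)


print(lookup({"a": 1}, "a"))  # -> 1
```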

notebooks/SARMA Example.ipynb

+2-2
```diff
@@ -1554,7 +1554,7 @@
 " hdi_forecast.coords[\"time\"].values,\n",
 " *hdi_forecast.isel(observed_state=0).values.T,\n",
 " alpha=0.25,\n",
-" color=\"tab:blue\"\n",
+" color=\"tab:blue\",\n",
 " )\n",
 "ax.set_title(\"Porcupine Graph of 10-Period Forecasts (parameters estimated on all data)\")\n",
 "plt.show()"
@@ -2692,7 +2692,7 @@
 " *forecast_hdi.values.T,\n",
 " label=\"Forecast 94% HDI\",\n",
 " color=\"tab:orange\",\n",
-" alpha=0.25\n",
+" alpha=0.25,\n",
 ")\n",
 "ax.legend()\n",
 "plt.show()"
```

notebooks/Structural Timeseries Modeling.ipynb

+3-3
```diff
@@ -1657,7 +1657,7 @@
 " nile.index,\n",
 " *component_hdi.smoothed_posterior.sel(state=state).values.T,\n",
 " color=\"tab:blue\",\n",
-" alpha=0.15\n",
+" alpha=0.15,\n",
 " )\n",
 " axis.set_title(state.title())"
 ]
@@ -1706,7 +1706,7 @@
 " *hdi.smoothed_posterior.sum(dim=\"state\").values.T,\n",
 " color=\"tab:blue\",\n",
 " alpha=0.15,\n",
-" label=\"HDI 94%\"\n",
+" label=\"HDI 94%\",\n",
 ")\n",
 "ax.legend()\n",
 "plt.show()"
@@ -2750,7 +2750,7 @@
 "ax.fill_between(\n",
 " blossom_data.index,\n",
 " *hdi_post.predicted_posterior_observed.isel(observed_state=0).values.T,\n",
-" alpha=0.25\n",
+" alpha=0.25,\n",
 ")\n",
 "blossom_data.plot(ax=ax)"
 ]
```
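Every notebook change in this commit is the same one-character fix: ruff-format, like black before it, follows the "magic trailing comma" convention and appends a comma to the last argument of any call already split across lines, which keeps the call in one-argument-per-line form and makes future additions single-line diffs. A minimal sketch:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# ruff-format appends the trailing comma after the final argument of a
# multi-line call; the comma also pins the call open, one argument per line.
ax.fill_between(
    [0, 1, 2],
    [0.0, 0.5, 0.0],
    [1.0, 1.5, 1.0],
    alpha=0.25,  # <- the comma added throughout the notebooks above
)
plt.show()
```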

pymc_experimental/__init__.py

+15-4
```diff
@@ -13,6 +13,10 @@
 # limitations under the License.
 import logging
 
+from pymc_experimental import distributions, gp, statespace, utils
+from pymc_experimental.inference.fit import fit
+from pymc_experimental.model.marginal_model import MarginalModel
+from pymc_experimental.model.model_api import as_model
 from pymc_experimental.version import __version__
 
 _log = logging.getLogger("pmx")
@@ -23,7 +27,14 @@
         handler = logging.StreamHandler()
         _log.addHandler(handler)
 
-from pymc_experimental import distributions, gp, statespace, utils
-from pymc_experimental.inference.fit import fit
-from pymc_experimental.model.marginal_model import MarginalModel
-from pymc_experimental.model.model_api import as_model
+
+__all__ = [
+    "distributions",
+    "gp",
+    "statespace",
+    "utils",
+    "fit",
+    "MarginalModel",
+    "as_model",
+    "__version__",
+]
```
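Two standard lint rules likely drive this reshuffle: module-level imports move above the logging setup (E402), and the new `__all__` declares them as intentional re-exports so the unused-import rule (F401) stops flagging them. A minimal sketch of the pattern, with hypothetical names:

```python
# Hypothetical __init__.py showing the same structure:
import logging

from json import dumps  # stands in for a subpackage re-export

_log = logging.getLogger("pkg")
_log.addHandler(logging.StreamHandler())

# __all__ defines the public (star-importable) surface; linters then treat
# the otherwise "unused" import above as a deliberate re-export.
__all__ = ["dumps"]
```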

pymc_experimental/distributions/continuous.py

+4-5
```diff
@@ -19,10 +19,9 @@
 The imports from pymc are not fully replicated here: add imports as necessary.
 """
 
-from typing import Tuple, Union
-
 import numpy as np
 import pytensor.tensor as pt
+
 from pymc import ChiSquared, CustomDist
 from pymc.distributions import transforms
 from pymc.distributions.dist_math import check_parameters
@@ -39,19 +38,19 @@ class GenExtremeRV(RandomVariable):
     name: str = "Generalized Extreme Value"
     signature = "(),(),()->()"
     dtype: str = "floatX"
-    _print_name: Tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}")
+    _print_name: tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}")
 
     def __call__(self, mu=0.0, sigma=1.0, xi=0.0, size=None, **kwargs) -> TensorVariable:
         return super().__call__(mu, sigma, xi, size=size, **kwargs)
 
     @classmethod
     def rng_fn(
         cls,
-        rng: Union[np.random.RandomState, np.random.Generator],
+        rng: np.random.RandomState | np.random.Generator,
         mu: np.ndarray,
         sigma: np.ndarray,
         xi: np.ndarray,
-        size: Tuple[int, ...],
+        size: tuple[int, ...],
     ) -> np.ndarray:
         # Notice negative here, since remainder of GenExtreme is based on Coles parametrization
         return stats.genextreme.rvs(c=-xi, loc=mu, scale=sigma, random_state=rng, size=size)
```

pymc_experimental/distributions/discrete.py

+1
```diff
@@ -14,6 +14,7 @@
 
 import numpy as np
 import pymc as pm
+
 from pymc.distributions.dist_math import betaln, check_parameters, factln, logpow
 from pymc.distributions.shape_utils import rv_size_is_none
 from pytensor import tensor as pt
```

pymc_experimental/distributions/histogram_utils.py

+6-7
```diff
@@ -13,18 +13,17 @@
 # limitations under the License.
 
 
-from typing import Dict
-
 import numpy as np
 import pymc as pm
+
 from numpy.typing import ArrayLike
 
 __all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"]
 
 
 def quantile_histogram(
     data: ArrayLike, n_quantiles=1000, zero_inflation=False
-) -> Dict[str, ArrayLike]:
+) -> dict[str, ArrayLike]:
     try:
         import xhistogram.core
     except ImportError as e:
@@ -34,7 +33,7 @@ def quantile_histogram(
         import dask.dataframe
     except ImportError:
         dask = None
-    if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
+    if dask and isinstance(data, dask.dataframe.Series | dask.dataframe.DataFrame):
         data = data.to_dask_array(lengths=True)
     if zero_inflation:
         zeros = (data == 0).sum(0)
@@ -67,7 +66,7 @@
     return result
 
 
-def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]:
+def discrete_histogram(data: ArrayLike, min_count=None) -> dict[str, ArrayLike]:
     try:
         import xhistogram.core
     except ImportError as e:
@@ -78,7 +77,7 @@ def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]:
     except ImportError:
         dask = None
 
-    if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
+    if dask and isinstance(data, dask.dataframe.Series | dask.dataframe.DataFrame):
         data = data.to_dask_array(lengths=True)
     mid, count_uniq = np.unique(data, return_counts=True)
     if min_count is not None:
@@ -153,7 +152,7 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
         import dask.dataframe
     except ImportError:
         dask = None
-    if dask and isinstance(observed, (dask.dataframe.Series, dask.dataframe.DataFrame)):
+    if dask and isinstance(observed, dask.dataframe.Series | dask.dataframe.DataFrame):
         observed = observed.to_dask_array(lengths=True)
     if np.issubdtype(observed.dtype, np.integer):
         histogram = discrete_histogram(observed, **h_kwargs)
```
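The `isinstance` rewrites in this file follow ruff's UP038: on Python 3.10+, `isinstance` accepts a PEP 604 union (`X | Y`), which evaluates to a `types.UnionType` at runtime, in place of a tuple of types. Unlike the annotation-only changes, this edits a runtime expression, which is presumably why the hook runs with `--unsafe-fixes`. A minimal sketch:

```python
value = 3.14

# Tuple form: works on every Python version.
assert isinstance(value, (int, float))

# Union form, as rewritten by UP038: requires Python >= 3.10, where
# `int | float` is a types.UnionType object usable in isinstance checks.
assert isinstance(value, int | float)
```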
pymc_experimental/distributions/multivariate/__init__.py

+2

```diff
@@ -1 +1,3 @@
 from pymc_experimental.distributions.multivariate.r2d2m2cp import R2D2M2CP
+
+__all__ = ["R2D2M2CP"]
```

pymc_experimental/distributions/multivariate/r2d2m2cp.py

+21-21
```diff
@@ -14,7 +14,7 @@
 
 
 from collections import namedtuple
-from typing import Sequence, Tuple, Union
+from collections.abc import Sequence
 
 import numpy as np
 import pymc as pm
@@ -26,8 +26,8 @@
 def _psivar2musigma(
     psi: pt.TensorVariable,
     explained_var: pt.TensorVariable,
-    psi_mask: Union[pt.TensorLike, None],
-) -> Tuple[pt.TensorVariable, pt.TensorVariable]:
+    psi_mask: pt.TensorLike | None,
+) -> tuple[pt.TensorVariable, pt.TensorVariable]:
     sign = pt.sign(psi - 0.5)
     if psi_mask is not None:
         # any computation might be ignored for ~psi_mask
@@ -55,7 +55,7 @@ def _R2D2M2CP_beta(
     psi: pt.TensorVariable,
     *,
     psi_mask,
-    dims: Union[str, Sequence[str]],
+    dims: str | Sequence[str],
     centered=False,
 ) -> pt.TensorVariable:
     """R2D2M2CP beta prior.
@@ -120,7 +120,7 @@ def _R2D2M2CP_beta(
 def _broadcast_as_dims(
     *values: np.ndarray,
     dims: Sequence[str],
-) -> Union[Tuple[np.ndarray, ...], np.ndarray]:
+) -> tuple[np.ndarray, ...] | np.ndarray:
     model = pm.modelcontext(None)
     shape = [len(model.coords[d]) for d in dims]
     ret = tuple(np.broadcast_to(v, shape) for v in values)
@@ -135,7 +135,7 @@ def _psi_masked(
     positive_probs_std: pt.TensorLike,
     *,
     dims: Sequence[str],
-) -> Tuple[Union[pt.TensorLike, None], pt.TensorVariable]:
+) -> tuple[pt.TensorLike | None, pt.TensorVariable]:
     if not (
         isinstance(positive_probs, pt.Constant) and isinstance(positive_probs_std, pt.Constant)
     ):
@@ -172,10 +172,10 @@ def _psi_masked(
 
 def _psi(
     positive_probs: pt.TensorLike,
-    positive_probs_std: Union[pt.TensorLike, None],
+    positive_probs_std: pt.TensorLike | None,
     *,
     dims: Sequence[str],
-) -> Tuple[Union[pt.TensorLike, None], pt.TensorVariable]:
+) -> tuple[pt.TensorLike | None, pt.TensorVariable]:
     if positive_probs_std is not None:
         mask, psi = _psi_masked(
             positive_probs=pt.as_tensor(positive_probs),
@@ -194,9 +194,9 @@ def _psi(
 
 
 def _phi(
-    variables_importance: Union[pt.TensorLike, None],
-    variance_explained: Union[pt.TensorLike, None],
-    importance_concentration: Union[pt.TensorLike, None],
+    variables_importance: pt.TensorLike | None,
+    variance_explained: pt.TensorLike | None,
+    importance_concentration: pt.TensorLike | None,
     *,
     dims: Sequence[str],
 ) -> pt.TensorVariable:
@@ -210,15 +210,15 @@ def _phi(
         variables_importance = pt.as_tensor(variables_importance)
         if importance_concentration is not None:
             variables_importance *= importance_concentration
-        return pm.Dirichlet("phi", variables_importance, dims=broadcast_dims + [dim])
+        return pm.Dirichlet("phi", variables_importance, dims=[*broadcast_dims, dim])
     elif variance_explained is not None:
         if len(model.coords[dim]) <= 1:
             raise TypeError("Can't use variance explained with less than two variables")
         phi = pt.as_tensor(variance_explained)
     else:
         phi = _broadcast_as_dims(1.0, dims=dims)
     if importance_concentration is not None:
-        return pm.Dirichlet("phi", importance_concentration * phi, dims=broadcast_dims + [dim])
+        return pm.Dirichlet("phi", importance_concentration * phi, dims=[*broadcast_dims, dim])
     else:
         return phi
 
@@ -233,12 +233,12 @@ def R2D2M2CP(
     *,
     dims: Sequence[str],
     r2: pt.TensorLike,
-    variables_importance: Union[pt.TensorLike, None] = None,
-    variance_explained: Union[pt.TensorLike, None] = None,
-    importance_concentration: Union[pt.TensorLike, None] = None,
-    r2_std: Union[pt.TensorLike, None] = None,
-    positive_probs: Union[pt.TensorLike, None] = 0.5,
-    positive_probs_std: Union[pt.TensorLike, None] = None,
+    variables_importance: pt.TensorLike | None = None,
+    variance_explained: pt.TensorLike | None = None,
+    importance_concentration: pt.TensorLike | None = None,
+    r2_std: pt.TensorLike | None = None,
+    positive_probs: pt.TensorLike | None = 0.5,
+    positive_probs_std: pt.TensorLike | None = None,
     centered: bool = False,
 ) -> R2D2M2CPOut:
     """R2D2M2CP Prior.
@@ -413,7 +413,7 @@ def R2D2M2CP(
     year = {2023}
     }
     """
-    if not isinstance(dims, (list, tuple)):
+    if not isinstance(dims, list | tuple):
         dims = (dims,)
     *broadcast_dims, dim = dims
     input_sigma = pt.as_tensor(input_sigma)
@@ -438,7 +438,7 @@
         r2,
         phi,
         psi,
-        dims=broadcast_dims + [dim],
+        dims=[*broadcast_dims, dim],
         centered=centered,
         psi_mask=mask,
     )
```
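Besides the `Union`-to-`|` conversions, two rewrites here match ruff's RUF005 (prefer iterable unpacking over concatenation): `broadcast_dims + [dim]` becomes `[*broadcast_dims, dim]`. Both build the same list, but unpacking skips the throwaway single-element list and accepts any iterable on the left. A minimal sketch:

```python
broadcast_dims = ["chain", "draw"]
dim = "feature"

# Concatenation (flagged by RUF005) and unpacking give the same result:
assert broadcast_dims + [dim] == [*broadcast_dims, dim]

# Unpacking also works when the left operand is a non-list iterable:
dims_tuple = ("chain", "draw")
print([*dims_tuple, dim])  # -> ['chain', 'draw', 'feature']
```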

pymc_experimental/distributions/timeseries.py

+3-3
```diff
@@ -1,10 +1,10 @@
 import warnings
-from typing import List, Union
 
 import numpy as np
 import pymc as pm
 import pytensor
 import pytensor.tensor as pt
+
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import (
     Distribution,
@@ -26,7 +26,7 @@
 from pytensor.tensor.random.op import RandomVariable
 
 
-def _make_outputs_info(n_lags: int, init_dist: Distribution) -> List[Union[Distribution, dict]]:
+def _make_outputs_info(n_lags: int, init_dist: Distribution) -> list[Distribution | dict]:
     """
     Two cases are needed for outputs_info in the scans used by DiscreteMarkovRv. If n_lags = 1, we need to throw away
     the first dimension of init_dist_ or else markov_chain will have shape (steps, 1, *batch_size) instead of
@@ -142,7 +142,7 @@ def dist(cls, P=None, logit_P=None, steps=None, init_dist=None, n_lags=1, **kwar
 
         if init_dist is not None:
             if not isinstance(init_dist, TensorVariable) or not isinstance(
-                init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+                init_dist.owner.op, RandomVariable | SymbolicRandomVariable
             ):
                 raise ValueError(
                     f"Init dist must be a distribution created via the `.dist()` API, "
```

pymc_experimental/gp/__init__.py

+2
```diff
@@ -14,3 +14,5 @@
 
 
 from pymc_experimental.gp.latent_approx import KarhunenLoeveExpansion, ProjectedProcess
+
+__all__ = ["KarhunenLoeveExpansion", "ProjectedProcess"]
```
