Skip to content

Commit 0391f61

Browse files
authored
Update pymc to latest as-at Aug 2024. Likely many breaking changes throughout (#116)
* + update version number + update pymc, god help us all * + updated local env * + removed fastprogress.progress_bar from model_pymc.calc.compute_log_likelihood_for_potential because this package is no longer part of pymc/arviz. Lots of other equivalent changes in the pymc/arviz source function too, will need to return to compute_log_likelihood_for_potential to massively update it + minor improvement to eda.eda_io.display_image_file * + added explicit sizing for sns.set_theme 'figure.dpi':72 * + Added a useful get_cr94 to eda.describe * + improved docstrings in plot * + improved facetplot_single with common signature + improved facetplot_single and plot_posterior with transform and [second option name lost in the original message] * + refactored display_image_file into new function figio.read for ease and consistency * + included log_prior in sample * + added logx kwarg to plot_ppc * + now initting PandasExcelIO * + typo * + updated create_dfcmb to allow F() in fcat * + added version to PYMCIO fn * + improved model fn version * + probably time for a PR!
1 parent 6805bfd commit 0391f61

19 files changed

+397
-295
lines changed

.flake8

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[flake8]
2-
ignore = E203, E266, W291, W293, F401, F403, E501, W503, W605, C901
2+
ignore = E203, E266, W291, W293, F401, F403, E501, W503, W605, C901, E712
33
max-line-length = 88
44
max-doc-length = 144
55
max-complexity = 18

.pre-commit-config.yaml

+6-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ default_language_version:
44
default_stages: [commit, push]
55
repos:
66
- repo: https://github.com/pre-commit/pre-commit-hooks # general checks
7-
rev: v4.5.0
7+
rev: v4.6.0
88
hooks:
99
- id: check-added-large-files
1010
args: ['--maxkb=1024']
@@ -31,7 +31,7 @@ repos:
3131
- id: no-print-statements
3232
files: ^oreum_core/
3333
- repo: https://github.com/psf/black # black formatter
34-
rev: 23.12.1
34+
rev: 24.8.0
3535
hooks:
3636
- id: black
3737
files: ^oreum_core/
@@ -41,26 +41,26 @@ repos:
4141
- id: isort
4242
files: ^oreum_core/
4343
- repo: https://github.com/pycqa/flake8 # flake8 linter
44-
rev: 7.0.0
44+
rev: 7.1.0
4545
hooks:
4646
- id: flake8
4747
files: ^oreum_core/
4848
- repo: https://github.com/pycqa/bandit # basic security checks for python code
49-
rev: 1.7.6
49+
rev: 1.7.9
5050
hooks:
5151
- id: bandit
5252
files: ^oreum_core/
5353
args: ["--config", "pyproject.toml"]
5454
additional_dependencies: ["bandit[toml]"]
5555
- repo: https://github.com/econchick/interrogate # check for docstrings
56-
rev: 1.5.0
56+
rev: 1.7.0
5757
hooks:
5858
- id: interrogate
5959
files: ^oreum_core/
6060
args: [--config, pyproject.toml]
6161
pass_filenames: false # see https://github.com/econchick/interrogate/issues/60#issuecomment-1180262851
6262
- repo: https://gitlab.com/iam-cms/pre-commit-hooks # apply Apache2 header
63-
rev: v0.4.0
63+
rev: v0.6.0
6464
hooks:
6565
- id: apache-license
6666
files: ^oreum_core/

LICENSES_THIRD_PARTY.md

+176-171
Large diffs are not rendered by default.

assets/img/interrogate_badge.svg

+4-4
Loading

oreum_core/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Core tools for use on projects by Oreum Industries"""
1616
import logging
1717

18-
__version__ = "0.8.1"
18+
__version__ = "0.9.0"
1919

2020
# logger goes to null handler by default
2121
# packages that import oreum_core can override this and direct elsewhere

oreum_core/curate/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414

1515
# curate/
1616
"""Various classes & functions for data curation"""
17-
from .data_io import PandasCSVIO, PandasParquetIO, SimpleStringIO, copy_csv2md
17+
from .data_io import (
18+
PandasCSVIO,
19+
PandasExcelIO,
20+
PandasParquetIO,
21+
SimpleStringIO,
22+
copy_csv2md,
23+
)
1824
from .data_transform import (
1925
DatasetReshaper,
2026
DatatypeConverter,

oreum_core/curate/data_io.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def __init__(self, *args, **kwargs):
102102

103103
def read(self, fn: str, *args, **kwargs) -> pd.DataFrame:
104104
"""Read excel fn from rootdir, pass args kwargs to pd.read_excel"""
105-
fn = Path(fn).with_suffix('.xslx')
105+
fn = Path(fn).with_suffix('.xlsx')
106106
fqn = self.get_path_read(fn)
107107
_log.info(f'Read from {str(fqn.resolve())}')
108108
return pd.read_excel(str(fqn), *args, **kwargs)

oreum_core/curate/data_transform.py

+1
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def create_dfcmb(self, df: pd.DataFrame, ftsd: dict) -> pd.DataFrame:
228228
dfcmb = pd.DataFrame(index=[0])
229229
fts_factor = ftsd.get('fcat', []) + ftsd.get('fbool', [])
230230
for ft in fts_factor:
231+
ft = ft[2:-1] if ft[:2] == 'F(' else ft
231232
colnames_pre = list(dfcmb.columns.values)
232233
s = pd.Series(np.unique(df[ft]), name=ft)
233234
dfcmb = pd.concat([dfcmb, s], axis=1, join='outer', ignore_index=True)

oreum_core/eda/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
tril_nan,
2828
)
2929
from .describe import describe, display_fw, display_ht, get_fts_by_dtype
30-
from .eda_io import FigureIO, display_image_file, output_data_dict
30+
from .eda_io import FigureIO, output_data_dict
3131
from .plot import ( # plot_umap,; plot_r2_range,; plot_r2_range_pair,
3232
plot_accuracy,
3333
plot_binary_performance,

oreum_core/eda/describe.py

+24-18
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def describe(
3636
limit: int = 50, # MB
3737
get_mode: bool = False,
3838
get_counts: bool = True,
39+
get_cr94: bool = False,
3940
reset_index: bool = True,
4041
return_df: bool = False,
4142
**kwargs,
@@ -68,7 +69,12 @@ def describe(
6869
df = df.reset_index()
6970

7071
# start with pandas describe, add on dtypes
71-
dfdesc = df.describe(include='all').T
72+
quantiles = [0.25, 0.5, 0.75] # the default
73+
percentile_names = ['25%', '50%', '75%']
74+
if get_cr94:
75+
quantiles = [0.03] + quantiles + [0.97]
76+
percentile_names = ['3%'] + percentile_names + ['97%']
77+
dfdesc = df.describe(include='all', percentiles=quantiles).T
7278

7379
dfout = pd.concat((dfdesc, df.dtypes), axis=1, join='outer', sort=False)
7480
dfout = dfout.loc[df.columns.values]
@@ -100,23 +106,23 @@ def describe(
100106
dfout.loc[ft, 'min'] = df[ft].value_counts().index.min()
101107
dfout.loc[ft, 'max'] = df[ft].value_counts().index.max()
102108

103-
fts_out_all = [
104-
'dtype',
105-
'count_null',
106-
'count_inf',
107-
'count_zero',
108-
'count_unique',
109-
'top',
110-
'freq',
111-
'sum',
112-
'mean',
113-
'std',
114-
'min',
115-
'25%',
116-
'50%',
117-
'75%',
118-
'max',
119-
]
109+
fts_out_all = (
110+
[
111+
'dtype',
112+
'count_null',
113+
'count_inf',
114+
'count_zero',
115+
'count_unique',
116+
'top',
117+
'freq',
118+
'sum',
119+
'mean',
120+
'std',
121+
'min',
122+
]
123+
+ percentile_names
124+
+ ['max']
125+
)
120126
fts_out = [f for f in fts_out_all if f in dfout.columns.values]
121127

122128
# add mode and mode count WARNING takes forever for large arrays (>10k row)

oreum_core/eda/eda_io.py

+51-36
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,24 @@
2121
import matplotlib.pyplot as plt
2222
import numpy as np
2323
import pandas as pd
24+
import seaborn as sns
2425
from matplotlib import figure
2526

2627
from ..curate.data_io import PandasExcelIO
2728
from ..utils.file_io import BaseFileIO
2829
from .describe import describe, get_fts_by_dtype
2930

30-
__all__ = ['FigureIO', 'display_image_file', 'output_data_dict']
31+
__all__ = ['FigureIO', 'output_data_dict']
3132

3233
_log = logging.getLogger(__name__)
3334

35+
sns.set_theme(
36+
style='darkgrid',
37+
palette='muted',
38+
context='notebook',
39+
rc={'figure.dpi': 72, 'savefig.dpi': 144, 'figure.figsize': (12, 4)},
40+
)
41+
3442

3543
class FigureIO(BaseFileIO):
3644
"""Helper class to save matplotlib.figure.Figure objects to image file"""
@@ -47,41 +55,48 @@ def write(self, f: figure.Figure, fn: str, *args, **kwargs) -> Path:
4755
_log.info(f'Written to {str(fqn.resolve())}')
4856
return fqn
4957

50-
51-
def display_image_file(
52-
fqn: str, title: str = None, figsize: tuple = (12, 6)
53-
) -> figure.Figure:
54-
"""Hacky way to display pre-created image file in a Notebook
55-
such that nbconvert can see it and render to PDF
56-
Force to max width 16 inches, for fullwidth render in live Notebook and PDF
57-
58-
NOTE:
59-
Alternatives are bad
60-
1. This one is entirely missed by nbconvert at render to PDF
61-
# <img src="img.jpg" style="float:center; width:900px" />
62-
63-
2. This one causes following markdown to render monospace in PDF
64-
# from IPython.display import Image
65-
# Image("./assets/img/oreum_eloss_blueprint3.jpg", retina=True)
66-
"""
67-
img = mpimg.imread(fqn)
68-
f, axs = plt.subplots(1, 1, figsize=figsize)
69-
_ = axs.imshow(img)
70-
ax = plt.gca()
71-
_ = ax.grid(False)
72-
_ = ax.set_frame_on(False)
73-
_ = plt.tick_params(
74-
top=False,
75-
bottom=False,
76-
left=False,
77-
right=False,
78-
labelleft=False,
79-
labelbottom=False,
80-
)
81-
if title is not None:
82-
_ = f.suptitle(f'{title}', y=1.0)
83-
_ = f.tight_layout()
84-
return f
58+
def read(
59+
self,
60+
fqn: Path = None,
61+
fn: str = None,
62+
extension: str = '.png',
63+
title: str = None,
64+
figsize: tuple = (12, 4),
65+
) -> figure.Figure:
66+
"""Hacky way to display pre-created image file in a Notebook such that
67+
nbconvert can see it and render to PDF
68+
If don't supply fqn, then this will build fn according to get_path_read
69+
Render according to usual rcParams (set at module-level)
70+
NOTE:
71+
All the alternatives are bad
72+
1. This one is entirely missed by nbconvert at render to PDF
73+
# <img src="img.jpg" style="float:center; width:900px" />
74+
75+
2. This one causes following markdown to render monospace in PDF
76+
# from IPython.display import Image
77+
# Image("./assets/img/oreum_eloss_blueprint3.jpg", retina=True)
78+
"""
79+
if fn is not None:
80+
fqn = self.get_path_read(Path(self.snl.clean(fn)).with_suffix(extension))
81+
img = mpimg.imread(fqn)
82+
f, axs = plt.subplots(1, 1, figsize=figsize)
83+
_ = axs.imshow(img)
84+
ax = plt.gca()
85+
_ = ax.grid(False)
86+
_ = ax.set_frame_on(False)
87+
_ = plt.tick_params(
88+
top=False,
89+
bottom=False,
90+
left=False,
91+
right=False,
92+
labelleft=False,
93+
labelbottom=False,
94+
)
95+
if title is not None:
96+
_ = f.suptitle(f'{title}', fontsize=12, y=1.0)
97+
_ = f.tight_layout()
98+
_log.info(f'Read image from {str(fqn.resolve())}')
99+
return f
85100

86101

87102
def output_data_dict(

oreum_core/eda/plot.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# eda.plot.py
1616
"""EDA Plotting"""
1717
import logging
18-
from textwrap import wrap
1918
from typing import Literal
2019

2120
import matplotlib.pyplot as plt
@@ -66,6 +65,13 @@
6665
RSD = 42
6766
rng = np.random.default_rng(seed=RSD)
6867

68+
sns.set_theme(
69+
style='darkgrid',
70+
palette='muted',
71+
context='notebook',
72+
rc={'figure.dpi': 72, 'savefig.dpi': 144, 'figure.figsize': (12, 4)},
73+
)
74+
6975

7076
def _get_kws_styling() -> dict:
7177
"""Common styling kws for plots"""
@@ -867,8 +873,9 @@ def plot_estimate(
867873
arr_overplot: np.array = None,
868874
**kwargs,
869875
) -> figure.Figure:
870-
"""Plot distribution for estimates, either PPC or bootstrapped, no grouping
871-
Optional overplot bootstrapped dfboot"""
876+
"""Plot distribution for univariate estimates, either PPC or bootstrapped
877+
no grouping. Optionally overplot bootstrapped dfboot"""
878+
# TODO: Extend this to multivariate grouping
872879
txtadd = kwargs.pop('txtadd', None)
873880
sty = _get_kws_styling()
874881
clr = color if color is not None else sns.color_palette()[0]

oreum_core/model_pymc/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from .describe import (
3333
describe_dist,
3434
extract_yobs_yhat,
35+
get_mdlvt_specific_nm,
3536
get_summary,
3637
model_desc,
3738
print_rvs,

oreum_core/model_pymc/base.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ def __init__(self, **kwargs):
6363
chains=4,
6464
cores=4,
6565
target_accept=0.8,
66-
idata_kwargs={
67-
"log_likelihood": True, # usually useful
68-
## TODO only in 5.16 "log_prior": True, # possibly useful?
69-
},
66+
idata_kwargs=dict(
67+
log_likelihood=True, # usually useful
68+
log_prior=True, # possibly useful?
69+
),
7070
progressbar=True,
7171
)
7272
self.rvs_for_posterior_plots = []
@@ -263,7 +263,8 @@ def update_idata(self, idata: az.InferenceData, replace: bool = False) -> None:
263263

264264
def debug(self):
265265
"""Convenience to run debug on logp and random, and
266-
assert no MeasurableVariable nodes in the graph"""
266+
assert no MeasurableVariable nodes in the graph
267+
TODO catch these outputs in the log"""
267268
if self.model is not None:
268269
assert_no_rvs(self.model.logp())
269270
_ = self.model.debug(fn='logp', verbose=True)

oreum_core/model_pymc/calc.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
import pytensor.gradient as tg
2525
import pytensor.tensor as pt
2626
from arviz import InferenceData, dict_to_dataset
27-
from fastprogress import progress_bar
27+
28+
# from fastprogress import progress_bar
2829
from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata
2930
from pymc.model import Model, modelcontext
3031
from pymc.pytensorf import PointFunc
@@ -459,6 +460,10 @@ def compute_log_likelihood_for_potential(
459460
orig: https://github.com/pymc-devs/pymc/blob/92278278d4a8b78f17ed0f101eb29d0d9982eb45/pymc/stats/log_likelihood.py#L29C1-L128C31
460461
discussion: https://discourse.pymc.io/t/using-a-random-variable-as-observed/7184/10
461462
463+
IMPORTANT NOTE 2024-08-04 in the intervening time, the source function that
464+
this copies / modifies has changed hugely - it's going to cause substantial
465+
pain to update :S
466+
462467
---
463468
464469
Compute elemwise log_likelihood of model given InferenceData with posterior group
@@ -529,8 +534,8 @@ def compute_log_likelihood_for_potential(
529534
n_pts = len(posterior_pts)
530535
loglike_dict = _DefaultTrace(n_pts)
531536
indices = range(n_pts)
532-
if progressbar:
533-
indices = progress_bar(indices, total=n_pts, display=progressbar)
537+
# if progressbar:
538+
# indices = progress_bar(indices, total=n_pts, display=progressbar)
534539

535540
for idx in indices:
536541
loglikes_pts = elemwise_loglike_fn(posterior_pts[idx])

0 commit comments

Comments
 (0)