14
14
15
15
16
16
from collections import namedtuple
17
- from typing import Sequence , Tuple , Union
17
+ from collections . abc import Sequence
18
18
19
19
import numpy as np
20
20
import pymc as pm
26
26
def _psivar2musigma (
27
27
psi : pt .TensorVariable ,
28
28
explained_var : pt .TensorVariable ,
29
- psi_mask : Union [ pt .TensorLike , None ] ,
30
- ) -> Tuple [pt .TensorVariable , pt .TensorVariable ]:
29
+ psi_mask : pt .TensorLike | None ,
30
+ ) -> tuple [pt .TensorVariable , pt .TensorVariable ]:
31
31
sign = pt .sign (psi - 0.5 )
32
32
if psi_mask is not None :
33
33
# any computation might be ignored for ~psi_mask
@@ -55,7 +55,7 @@ def _R2D2M2CP_beta(
55
55
psi : pt .TensorVariable ,
56
56
* ,
57
57
psi_mask ,
58
- dims : Union [ str , Sequence [str ] ],
58
+ dims : str | Sequence [str ],
59
59
centered = False ,
60
60
) -> pt .TensorVariable :
61
61
"""R2D2M2CP beta prior.
@@ -120,7 +120,7 @@ def _R2D2M2CP_beta(
120
120
def _broadcast_as_dims (
121
121
* values : np .ndarray ,
122
122
dims : Sequence [str ],
123
- ) -> Union [ Tuple [ np .ndarray , ...], np .ndarray ] :
123
+ ) -> tuple [ np .ndarray , ...] | np .ndarray :
124
124
model = pm .modelcontext (None )
125
125
shape = [len (model .coords [d ]) for d in dims ]
126
126
ret = tuple (np .broadcast_to (v , shape ) for v in values )
@@ -135,7 +135,7 @@ def _psi_masked(
135
135
positive_probs_std : pt .TensorLike ,
136
136
* ,
137
137
dims : Sequence [str ],
138
- ) -> Tuple [ Union [ pt .TensorLike , None ] , pt .TensorVariable ]:
138
+ ) -> tuple [ pt .TensorLike | None , pt .TensorVariable ]:
139
139
if not (
140
140
isinstance (positive_probs , pt .Constant ) and isinstance (positive_probs_std , pt .Constant )
141
141
):
@@ -172,10 +172,10 @@ def _psi_masked(
172
172
173
173
def _psi (
174
174
positive_probs : pt .TensorLike ,
175
- positive_probs_std : Union [ pt .TensorLike , None ] ,
175
+ positive_probs_std : pt .TensorLike | None ,
176
176
* ,
177
177
dims : Sequence [str ],
178
- ) -> Tuple [ Union [ pt .TensorLike , None ] , pt .TensorVariable ]:
178
+ ) -> tuple [ pt .TensorLike | None , pt .TensorVariable ]:
179
179
if positive_probs_std is not None :
180
180
mask , psi = _psi_masked (
181
181
positive_probs = pt .as_tensor (positive_probs ),
@@ -194,9 +194,9 @@ def _psi(
194
194
195
195
196
196
def _phi (
197
- variables_importance : Union [ pt .TensorLike , None ] ,
198
- variance_explained : Union [ pt .TensorLike , None ] ,
199
- importance_concentration : Union [ pt .TensorLike , None ] ,
197
+ variables_importance : pt .TensorLike | None ,
198
+ variance_explained : pt .TensorLike | None ,
199
+ importance_concentration : pt .TensorLike | None ,
200
200
* ,
201
201
dims : Sequence [str ],
202
202
) -> pt .TensorVariable :
@@ -210,15 +210,15 @@ def _phi(
210
210
variables_importance = pt .as_tensor (variables_importance )
211
211
if importance_concentration is not None :
212
212
variables_importance *= importance_concentration
213
- return pm .Dirichlet ("phi" , variables_importance , dims = broadcast_dims + [ dim ])
213
+ return pm .Dirichlet ("phi" , variables_importance , dims = [ * broadcast_dims , dim ])
214
214
elif variance_explained is not None :
215
215
if len (model .coords [dim ]) <= 1 :
216
216
raise TypeError ("Can't use variance explained with less than two variables" )
217
217
phi = pt .as_tensor (variance_explained )
218
218
else :
219
219
phi = _broadcast_as_dims (1.0 , dims = dims )
220
220
if importance_concentration is not None :
221
- return pm .Dirichlet ("phi" , importance_concentration * phi , dims = broadcast_dims + [ dim ])
221
+ return pm .Dirichlet ("phi" , importance_concentration * phi , dims = [ * broadcast_dims , dim ])
222
222
else :
223
223
return phi
224
224
@@ -233,12 +233,12 @@ def R2D2M2CP(
233
233
* ,
234
234
dims : Sequence [str ],
235
235
r2 : pt .TensorLike ,
236
- variables_importance : Union [ pt .TensorLike , None ] = None ,
237
- variance_explained : Union [ pt .TensorLike , None ] = None ,
238
- importance_concentration : Union [ pt .TensorLike , None ] = None ,
239
- r2_std : Union [ pt .TensorLike , None ] = None ,
240
- positive_probs : Union [ pt .TensorLike , None ] = 0.5 ,
241
- positive_probs_std : Union [ pt .TensorLike , None ] = None ,
236
+ variables_importance : pt .TensorLike | None = None ,
237
+ variance_explained : pt .TensorLike | None = None ,
238
+ importance_concentration : pt .TensorLike | None = None ,
239
+ r2_std : pt .TensorLike | None = None ,
240
+ positive_probs : pt .TensorLike | None = 0.5 ,
241
+ positive_probs_std : pt .TensorLike | None = None ,
242
242
centered : bool = False ,
243
243
) -> R2D2M2CPOut :
244
244
"""R2D2M2CP Prior.
@@ -413,7 +413,7 @@ def R2D2M2CP(
413
413
year = {2023}
414
414
}
415
415
"""
416
- if not isinstance (dims , ( list , tuple ) ):
416
+ if not isinstance (dims , list | tuple ):
417
417
dims = (dims ,)
418
418
* broadcast_dims , dim = dims
419
419
input_sigma = pt .as_tensor (input_sigma )
@@ -438,7 +438,7 @@ def R2D2M2CP(
438
438
r2 ,
439
439
phi ,
440
440
psi ,
441
- dims = broadcast_dims + [ dim ],
441
+ dims = [ * broadcast_dims , dim ],
442
442
centered = centered ,
443
443
psi_mask = mask ,
444
444
)
0 commit comments