Commit 9188f3f

Merge pull request #363 from rsokl/ufunc-overload
Implements Tensor.__array_ufunc__
2 parents 720a0b0 + 5f2b65a

File tree: 6 files changed (+283, -31 lines)

src/mygrad/tensor_base.py

Lines changed: 166 additions & 19 deletions
@@ -6,6 +6,7 @@
 
 from numbers import Integral, Number
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -66,6 +67,10 @@
 
 __all__ = ["Tensor", "asarray", "astensor"]
 
+if TYPE_CHECKING:  # pragma: no cover
+    from mygrad.ufuncs._ufunc_creators import ufunc as mygrad_ufunc
+
+
 CONSTANT_ONLY_DTYPES = (np.integer, np.bool_)
 
 
@@ -349,6 +354,53 @@ def astensor(
     return tensor(t, dtype=dtype, constant=constant, copy=False, ndmin=0)
 
 
+_REGISTERED_UFUNC: Dict[np.ufunc, Type["mygrad_ufunc"]] = {}
+_REGISTERED_BOOL_ONLY_UFUNC: Set[np.ufunc] = {
+    np.isnan,
+    np.isfinite,
+    np.isinf,
+    np.isnat,
+    np.signbit,
+    np.logical_not,
+    np.logical_and,
+    np.logical_or,
+    np.logical_xor,
+    np.greater,
+    np.greater_equal,
+    np.less,
+    np.less_equal,
+    np.equal,
+}
+
+# These are ufuncs that users might mistake for being differentiable functions;
+# for this reason we make explicit the fact that only constant tensors are permitted
+# in these operations.
+_REGISTERED_CONST_ONLY_UFUNC = {
+    np.floor_divide,
+    np.remainder,
+    np.mod,
+    np.fmod,
+    np.divmod,
+    np.rint,
+    np.sign,
+    np.floor,
+    np.ceil,
+    np.trunc,
+}
+
+
+class _ConstantOnly(ValueError):
+    pass
+
+
+def _as_constant_array(t: Union["Tensor", np.ndarray]) -> np.ndarray:
+    if isinstance(t, Tensor):
+        if t.constant is False:
+            raise _ConstantOnly()
+        return t.data
+    return t
+
+
 class Tensor:
     """A numpy-array-like object capable of serving as a node in a computational
     graph that supports back-propagation of derivatives via the chain rule.
@@ -505,6 +557,97 @@ class Tensor:
 
     __array_priority__ = 15.0
 
+    def __array_ufunc__(
+        self, ufunc: Type[np.ufunc], method: str, *inputs: ArrayLike, **kwargs
+    ) -> Union["Tensor", np.ndarray]:
+        """An interface provided by NumPy to override the behavior of its ufuncs [1]_.
+
+        MyGrad implements its own ufuncs for all differentiable NumPy ufuncs.
+
+        Non-differentiable NumPy ufuncs simply get called on the underlying arrays of tensors
+        and return ndarrays.
+
+        The differentiability - or lack thereof - of ufuncs may not be obvious to end users.
+        Thus potentially ambiguous ufuncs (e.g. `numpy.ceil`) are made to raise on non-constant
+        tensors so that the lack of differentiability is made obvious to users. This design
+        decision is made in the same spirit as requiring that integer-dtype tensors be constant.
+
+        References
+        ----------
+        .. [1] https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_ufunc__
+
+        Examples
+        --------
+        NumPy ufuncs that represent differentiable operations are overloaded by MyGrad tensors
+        so that they support backprop
+
+        >>> import mygrad as mg
+        >>> import numpy as np
+
+        >>> x = mg.tensor([1., 2.])
+
+        This calls ``mygrad.sin`` under the hood.
+
+        >>> np.sin(x)  # returns a tensor
+        Tensor([0.84147098, 0.90929743])
+
+        >>> np.sin(x).backward()
+        >>> x.grad  # note: the derivative of sin(x) is cos(x)
+        array([ 0.54030231, -0.41614684])
+
+        Specifying a dtype, a ``where`` mask, and an in-place target (via ``out``) as an array
+        or a tensor are all supported.
+
+        >>> x = mg.tensor([1., 2.])
+        >>> y = mg.tensor([-1., -1.])
+        >>> np.exp(x, where=[False, True], out=y)
+        Tensor([-1.       ,  7.3890561])
+        >>> y.backward()
+        >>> x.grad
+        array([0.       , 7.3890561])
+
+        Non-differentiable NumPy ufuncs simply operate on the ndarrays that are wrapped
+        by MyGrad tensors; these return ndarrays, which will appropriately and explicitly
+        serve as constants elsewhere in a computational graph.
+
+        >>> x = mg.tensor([1., 2.])
+        >>> np.less_equal(x, 1)
+        array([ True, False])
+        """
+        out = kwargs.pop("out", (None,))
+        if len(out) > 1:  # pragma: no cover
+            raise ValueError(
+                "mygrad does not support in-place operations with more than one target"
+            )
+        (out,) = out
+
+        out: Optional[Union[np.ndarray, "Tensor"]]
+
+        try:
+            # differentiable ufunc implemented by mygrad
+            return getattr(_REGISTERED_UFUNC[ufunc], method)(*inputs, **kwargs, out=out)
+        except KeyError:
+            pass
+
+        # non-differentiable ufuncs get called on numpy arrays stored by tensors
+        if ufunc in _REGISTERED_BOOL_ONLY_UFUNC:
+            caster = asarray
+        elif ufunc in _REGISTERED_CONST_ONLY_UFUNC:
+            # the presence of non-constant tensors will raise
+            caster = _as_constant_array
+        else:  # pragma: no cover
+            return NotImplemented
+
+        try:
+            if out is not None:
+                kwargs["out"] = caster(out)
+            # returns ndarray
+            return getattr(ufunc, method)(*(caster(t) for t in inputs), **kwargs)
+        except _ConstantOnly:
+            raise ValueError(
+                f"{repr(ufunc)} cannot involve non-constant mygrad tensors."
+            )
+
     def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
         return np.array(self.data, dtype=dtype, copy=False)
 
@@ -787,11 +930,25 @@ def _op(
         -------
         mygrad.Tensor
             The tensor-result of the operation's forward-pass."""
-        if out is not None and isinstance(out, Tensor):
-            out._in_place_op(
-                Op, *input_vars, op_args=op_args, op_kwargs=op_kwargs, constant=constant
-            )
-            return out
+        if out is not None:
+            if isinstance(out, tuple):
+                if len(out) > 1:  # pragma: no cover
+                    raise ValueError(
+                        "mygrad does not support in-place operations with more than one target"
+                    )
+                (out,) = out
+
+            if isinstance(out, Tensor):
+                out._in_place_op(
+                    Op,
+                    *input_vars,
+                    op_args=op_args,
+                    op_kwargs=op_kwargs,
+                    constant=constant,
+                )
+                return out
+
+        out: Optional[np.ndarray]
 
         _uniques_bases_then_arrs = ()
 
@@ -1700,21 +1857,11 @@ def __truediv__(self, other: ArrayLike) -> "Tensor":
     def __rtruediv__(self, other: ArrayLike) -> "Tensor":
         return self._op(Divide, other, self)
 
-    def __floordiv__(self, other: ArrayLike) -> "Tensor":
-        if not self.constant:
-            raise ValueError(
-                "Floor division cannot involve non-constant mygrad tensors."
-            )
-        if isinstance(other, Tensor):
-            other = other.data
-        return type(self)(self.data.__floordiv__(other), constant=True)
+    def __floordiv__(self, other: ArrayLike) -> np.ndarray:
+        return np.floor_divide(self, other)
 
-    def __rfloordiv__(self, other: ArrayLike) -> "Tensor":
-        if not self.constant:
-            raise ValueError(
-                "Floor division cannot involve non-constant mygrad tensors."
-            )
-        return type(self)(self.data.__rfloordiv__(other), constant=True)
+    def __rfloordiv__(self, other: ArrayLike) -> np.ndarray:
+        return np.floor_divide(other, self)
 
     def __itruediv__(self, other: ArrayLike) -> "Tensor":
         self._in_place_op(Divide, self, other)
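
To make the new const-only behavior concrete, here is a hand-run sketch of how the dispatch above plays out, assuming only what is visible in this diff: bool-only ufuncs fall through to the wrapped ndarray, while const-only ufuncs such as ``np.floor`` reject non-constant tensors.

>>> import mygrad as mg
>>> import numpy as np
>>> x = mg.tensor([1.5, 2.5])  # non-constant (differentiable) by default
>>> np.isfinite(x)             # bool-only ufunc: computed on x.data, returns an ndarray
array([ True,  True])
>>> np.floor(x)                # const-only ufunc: raises for non-constant tensors
Traceback (most recent call last):
    ...
ValueError: <ufunc 'floor'> cannot involve non-constant mygrad tensors.
>>> np.floor(mg.tensor([1.5, 2.5], constant=True))  # constant tensors are permitted
array([1., 2.])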

src/mygrad/ufuncs/_ufunc_creators.py

Lines changed: 11 additions & 9 deletions
@@ -16,8 +16,8 @@
 
 import numpy as np
 
-from mygrad import Tensor
 from mygrad.operation_base import BinaryUfunc, Operation, Ufunc, UnaryUfunc, _NoValue
+from mygrad.tensor_base import _REGISTERED_UFUNC, Tensor
 from mygrad.typing import ArrayLike, DTypeLikeReals, Index, Mask, Real
 
 __all__ = ["ufunc_creator"]

@@ -316,7 +316,7 @@ def _create_ufunc(
     outer_op=None,
     reduce_op=None,
     reduceat_op=None,
-):
+) -> Type[ufunc]:
     def at(
         a: ArrayLike,
         indices: Union[ArrayLike, Index, Tuple[ArrayLike, Index]],

@@ -330,7 +330,7 @@ def at(
     def accumulate(
         array: ArrayLike,
         axis: int = 0,
-        dtype: DTypeLikeReals = None,
+        dtype: Optional[DTypeLikeReals] = None,
         out: Optional[Union[Tensor, np.ndarray]] = None,
         *,
         constant: Optional[bool] = None,

@@ -342,16 +342,16 @@ def outer(
         a: ArrayLike,
         b: ArrayLike,
         *,
-        dtype: DTypeLikeReals,
-        out: Optional[Union[Tensor, np.ndarray]],
+        dtype: Optional[DTypeLikeReals] = None,
+        out: Optional[Union[Tensor, np.ndarray]] = None,
     ) -> Tensor:  # pragma: no cover
         """Not Implemented"""
         raise NotImplementedError()
 
     def reduce(
         a: ArrayLike,
         axis: Optional[Union[int, Tuple[int, ...]]] = 0,
-        dtype: DTypeLikeReals = None,
+        dtype: Optional[DTypeLikeReals] = None,
         out: Optional[Union[Tensor, np.ndarray]] = None,
         keepdims: bool = False,
         initial: Real = _NoValue,

@@ -364,7 +364,7 @@ def reduceat(
         a: ArrayLike,
         indices: ArrayLike,
         axis: Optional[Union[int, Tuple[int, ...]]] = 0,
-        dtype: DTypeLikeReals = None,
+        dtype: Optional[DTypeLikeReals] = None,
         out: Optional[Union[Tensor, np.ndarray]] = None,
     ) -> Tensor:  # pragma: no cover
         """Not Implemented"""

@@ -378,7 +378,7 @@
         )
     else:  # pragma: no cover
         raise NotImplementedError(
-            "MyGrad Internal: `mygrad._utils.op_creator` only supports unary and binary ufuncs currently"
+            "MyGrad Internal: `mygrad._utils.op_creator` only supports unary and binary ufuncs presently"
        )
 
     # filter out non-real dtypes

@@ -487,7 +487,7 @@ def __init__(
         self.reduceat_op = reduceat_op
 
     def __call__(self, decorated_func: T) -> T:
-        return _create_ufunc(
+        out_ufunc = _create_ufunc(
             self.op,
             decorated_func=decorated_func,
             at_op=self.at_op,

@@ -496,3 +496,5 @@ def __call__(self, decorated_func: T) -> T:
             reduce_op=self.reduce_op,
             reduceat_op=self.reduceat_op,
         )
+        _REGISTERED_UFUNC[getattr(np, out_ufunc.__name__)] = out_ufunc
+        return out_ufunc
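
The registration added to ``__call__`` above is the producer side of the mechanism that ``Tensor.__array_ufunc__`` consumes: each decorated mygrad ufunc is keyed in ``_REGISTERED_UFUNC`` by the NumPy ufunc of the same name. Below is a toy, self-contained model of this dispatch-table pattern; every name in it is a hypothetical stand-in, not mygrad's real API.

import numpy as np

REGISTRY = {}  # np.ufunc -> overriding implementation

def register(numpy_ufunc):
    def decorator(fn):
        REGISTRY[numpy_ufunc] = fn  # keyed by the numpy ufunc object itself
        return fn
    return decorator

@register(np.sin)
def my_sin(x):
    return np.sin(x)  # stand-in for a differentiable implementation

class Box:
    # minimal array-like that reroutes numpy ufuncs through REGISTRY
    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if method == "__call__" and ufunc in REGISTRY:
            unboxed = [i.data if isinstance(i, Box) else i for i in inputs]
            return REGISTRY[ufunc](*unboxed, **kwargs)
        return NotImplemented

print(np.sin(Box([0.0, np.pi / 2])))  # dispatches through REGISTRY -> [0. 1.]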

tests/tensor_base/test_operator_override.py

Lines changed: 4 additions & 2 deletions
@@ -104,14 +104,16 @@ def test_arithmetic_operators_between_array_and_tensor_cast_to_tensor(
     "f1, f2",
     [
         (constant_tensor, lambda x: x),
-        (lambda x: x, constant_tensor),
+        (
+            lambda x: x.tolist(),
+            constant_tensor,
+        ),  # `list/tensor` ensures __rfloordiv__ gets called
         (constant_tensor, constant_tensor),
     ],
 )
 def test_floor_div(arr1, arr2, f1, f2):
     desired = arr1 // arr2
     actual = f1(arr1) // f2(arr2)
-    assert actual.constant is True
     assert actual.dtype == desired.dtype
     assert_array_equal(desired, actual)
 
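
The dropped ``actual.constant`` assertion reflects the behavior change in ``tensor_base.py``: floor division now routes through ``np.floor_divide``, so the result is an ndarray (which has no ``constant`` attribute) rather than a constant Tensor. A hand-run illustration of the ``list // tensor`` case the new parametrization exercises:

>>> import mygrad as mg
>>> [4., 6.] // mg.tensor([2., 4.], constant=True)  # resolves via Tensor.__rfloordiv__
array([2., 1.])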

tests/ufuncs/test_fwd_prop_and_backprop.py

Lines changed: 5 additions & 1 deletion
@@ -216,15 +216,17 @@ def test_ufunc_fwd(
 @pytest.mark.parametrize(
     "ufunc", [u for u in ufuncs if u not in DOES_NOT_SUPPORT_COMPLEX_DOMAIN]
 )
-@given(data=st.data())
+@given(data=st.data(), use_numpy_overload=st.booleans())
 def test_ufunc_bkwd(
     data: st.DataObject,
     ufunc: Union[MyGradUnaryUfunc, MyGradBinaryUfunc],
+    use_numpy_overload: bool,
 ):
     """
     Checks:
     - backprop matches numerical gradient
     - backprop doesn't mutate grad
+    - that calling op through numpy overload works identically
     """
     args = data.draw(
         populates_ufunc(

@@ -237,6 +239,8 @@ def test_ufunc_bkwd(
     )
     args.make_array_based_args_read_only()  # guards against mutation
 
+    if use_numpy_overload:
+        ufunc = getattr(np, ufunc.__name__)
     mygrad_out = ufunc(*args.args, **args.kwargs)
 
     # Draw upstream gradient to be backpropped
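
For intuition about what the ``use_numpy_overload`` branch exercises: ``getattr(np, ufunc.__name__)`` swaps a mygrad ufunc for its NumPy namesake, and ``Tensor.__array_ufunc__`` routes the NumPy call back to mygrad, so both paths should produce identical tensors. A hand-run sketch of that equivalence, assuming mygrad's public ``tensor``/``exp``:

>>> import numpy as np
>>> import mygrad as mg
>>> x = mg.tensor([0.5, 1.0])
>>> out_direct = mg.exp(x)     # call the mygrad ufunc directly
>>> out_overload = np.exp(x)   # routed through Tensor.__array_ufunc__
>>> isinstance(out_overload, mg.Tensor)
True
>>> np.array_equal(out_direct.data, out_overload.data)
True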
