
Commit 5f2b65a

add __array_ufunc__ docs and remove redundant floordiv implementations
1 parent b19466d commit 5f2b65a

File tree

2 files changed: +69, -24 lines


src/mygrad/tensor_base.py

Lines changed: 59 additions & 14 deletions
@@ -560,7 +560,60 @@ class Tensor:
     def __array_ufunc__(
         self, ufunc: Type[np.ufunc], method: str, *inputs: ArrayLike, **kwargs
     ) -> Union["Tensor", np.ndarray]:
+        """An interface provided by NumPy to override the behavior of its ufuncs [1]_.
+
+        MyGrad implements its own ufuncs for all differentiable NumPy ufuncs.
+
+        Non-differentiable NumPy ufuncs simply get called on the underlying arrays of tensors
+        and return ndarrays.
+
+        The differentiability - or lack thereof - of a ufunc may not be obvious to end users.
+        Thus potentially ambiguous ufuncs (e.g. ``numpy.ceil``) are made to raise on
+        non-constant tensors so that the lack of differentiability is made obvious to the
+        user. This design decision is made in the same spirit as requiring that integer-dtype
+        tensors be constant.
+
+        References
+        ----------
+        .. [1] https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_ufunc__
+
+        Examples
+        --------
+        NumPy ufuncs that represent differentiable operations are overloaded by MyGrad
+        tensors so that they support backprop.
+
+        >>> import mygrad as mg
+        >>> import numpy as np
+
+        >>> x = mg.tensor([1., 2.])
+
+        This calls ``mygrad.sin`` under the hood.
+
+        >>> np.sin(x)  # returns a tensor
+        Tensor([0.84147098, 0.90929743])
+
+        >>> np.sin(x).backward()
+        >>> x.grad  # note: the derivative of sin(x) is cos(x)
+        array([ 0.54030231, -0.41614684])
+
+        Specifying a dtype, a ``where`` mask, and an in-place target (via ``out``) as an
+        array or a tensor are all supported.
+
+        >>> x = mg.tensor([1., 2.])
+        >>> y = mg.tensor([-1., -1.])
+        >>> np.exp(x, where=[False, True], out=y)
+        Tensor([-1.       ,  7.3890561])
+        >>> y.backward()
+        >>> x.grad
+        array([0.       , 7.3890561])
+
+        Non-differentiable NumPy ufuncs simply operate on the ndarrays that are wrapped
+        by MyGrad tensors; these return ndarrays, which will appropriately and explicitly
+        serve as constants elsewhere in a computational graph.
+
+        >>> x = mg.tensor([1., 2.])
+        >>> np.less_equal(x, 1)
+        array([ True, False])
+        """
         out = kwargs.pop("out", (None,))
         if len(out) > 1:  # pragma: no cover
             raise ValueError(
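
For context on the protocol that the new docstring documents: NumPy passes the ufunc object, the method name ("__call__", "reduce", etc.), the inputs, and any keyword arguments to ``__array_ufunc__`` on the participating object, which is then free to perform the computation itself. The following is a minimal sketch of that hook as described in reference [1] above, not MyGrad's actual implementation; ``LoggingArray`` is a hypothetical toy class.

import numpy as np

class LoggingArray:
    # Toy wrapper: every NumPy ufunc applied to it is intercepted below.
    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # unwrap any LoggingArray inputs down to plain ndarrays
        arrays = [i.data if isinstance(i, LoggingArray) else i for i in inputs]
        print(f"intercepted {ufunc.__name__}.{method}")
        # delegate to the requested ufunc method on the raw arrays
        return getattr(ufunc, method)(*arrays, **kwargs)

np.sin(LoggingArray([1.0, 2.0]))  # prints: intercepted sin.__call__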
@@ -576,9 +629,11 @@ def __array_ufunc__(
         except KeyError:
             pass
 
+        # non-differentiable ufuncs get called on numpy arrays stored by tensors
         if ufunc in _REGISTERED_BOOL_ONLY_UFUNC:
            caster = asarray
         elif ufunc in _REGISTERED_CONST_ONLY_UFUNC:
+            # the presence of non-constant tensors will raise
            caster = _as_constant_array
         else:  # pragma: no cover
            return NotImplemented
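
In practical terms, the two registries used above give behavior along the following lines. This is a rough sketch assuming ``numpy.ceil`` sits in the const-only registry, as the docstring's example suggests.

import mygrad as mg
import numpy as np

x = mg.tensor([1.5, 2.5])                 # ordinary (non-constant) tensor
c = mg.tensor([1.5, 2.5], constant=True)  # constant tensor

np.less_equal(x, 2.0)  # bool-only ufunc: applied to the underlying array, returns an ndarray
np.ceil(c)             # const-only ufunc: permitted, since the tensor is a constant
# np.ceil(x)           # expected to raise: ceil is not differentiable and x is non-constant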
@@ -1802,21 +1857,11 @@ def __truediv__(self, other: ArrayLike) -> "Tensor":
     def __rtruediv__(self, other: ArrayLike) -> "Tensor":
         return self._op(Divide, other, self)
 
-    def __floordiv__(self, other: ArrayLike) -> "Tensor":
-        if not self.constant:
-            raise ValueError(
-                "Floor division cannot involve non-constant mygrad tensors."
-            )
-        if isinstance(other, Tensor):
-            other = other.data
-        return type(self)(self.data.__floordiv__(other), constant=True)
+    def __floordiv__(self, other: ArrayLike) -> np.ndarray:
+        return np.floor_divide(self, other)
 
-    def __rfloordiv__(self, other: ArrayLike) -> "Tensor":
-        if not self.constant:
-            raise ValueError(
-                "Floor division cannot involve non-constant mygrad tensors."
-            )
-        return type(self)(self.data.__rfloordiv__(other), constant=True)
+    def __rfloordiv__(self, other: ArrayLike) -> np.ndarray:
+        return np.floor_divide(other, self)
 
     def __itruediv__(self, other: ArrayLike) -> "Tensor":
         self._in_place_op(Divide, self, other)
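
With the hand-written implementations removed above, ``//`` on a tensor now simply routes through ``np.floor_divide`` and therefore through the ``__array_ufunc__`` machinery documented earlier. A hedged sketch of the resulting behavior, assuming ``floor_divide`` is registered as a const-only ufunc:

import mygrad as mg
import numpy as np

c = mg.tensor([3.0, 7.0], constant=True)

c // 2   # dispatches to np.floor_divide(c, 2); per the new annotation, an ndarray is returned
2 // c   # __rfloordiv__ likewise delegates to np.floor_divide(2, c)

# x = mg.tensor([3.0, 7.0])
# x // 2  # expected to raise, mirroring the ValueError the removed code raised for non-constant tensors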

src/mygrad/ufuncs/_ufunc_creators.py

Lines changed: 10 additions & 10 deletions
@@ -325,50 +325,50 @@ def at(
         constant: Optional[bool] = None,
     ) -> Tensor:  # pragma: no cover
         """Not implemented"""
-        return NotImplementedError
+        raise NotImplementedError()
 
     def accumulate(
         array: ArrayLike,
         axis: int = 0,
-        dtype: DTypeLikeReals = None,
+        dtype: Optional[DTypeLikeReals] = None,
         out: Optional[Union[Tensor, np.ndarray]] = None,
         *,
         constant: Optional[bool] = None,
     ) -> Tensor:  # pragma: no cover
         """Not implemented"""
-        return NotImplementedError
+        raise NotImplementedError()
 
     def outer(
         a: ArrayLike,
         b: ArrayLike,
         *,
-        dtype: DTypeLikeReals,
-        out: Optional[Union[Tensor, np.ndarray]],
+        dtype: Optional[DTypeLikeReals] = None,
+        out: Optional[Union[Tensor, np.ndarray]] = None,
     ) -> Tensor:  # pragma: no cover
         """Not Implemented"""
-        return NotImplementedError
+        raise NotImplementedError()
 
     def reduce(
         a: ArrayLike,
         axis: Optional[Union[int, Tuple[int, ...]]] = 0,
-        dtype: DTypeLikeReals = None,
+        dtype: Optional[DTypeLikeReals] = None,
         out: Optional[Union[Tensor, np.ndarray]] = None,
         keepdims: bool = False,
         initial: Real = _NoValue,
         where: Mask = True,
     ) -> Tensor:  # pragma: no cover
         """Not Implemented"""
-        return NotImplementedError
+        raise NotImplementedError()
 
     def reduceat(
         a: ArrayLike,
         indices: ArrayLike,
         axis: Optional[Union[int, Tuple[int, ...]]] = 0,
-        dtype: DTypeLikeReals = None,
+        dtype: Optional[DTypeLikeReals] = None,
         out: Optional[Union[Tensor, np.ndarray]] = None,
     ) -> Tensor:  # pragma: no cover
         """Not Implemented"""
-        return NotImplementedError
+        raise NotImplementedError()
 
     if op.numpy_ufunc.nin == 1:
         MetaBuilder = MyGradUnaryUfunc
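
Aside on the ``return NotImplementedError`` -> ``raise NotImplementedError()`` fix in these stubs: returning the exception class hands the caller a class object instead of signalling an error, so the mistake passes silently. A small standalone illustration:

def stub_old():
    return NotImplementedError   # bug: nothing is raised; the class itself is returned

def stub_new():
    raise NotImplementedError()  # callers now get the expected exception

print(stub_old())                # <class 'NotImplementedError'> -- looks like a successful call
try:
    stub_new()
except NotImplementedError:
    print("raised as intended")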
