Skip to content

Commit 9cd0fcf

Browse files
committed
MAINT: Create a function to update units for differentiate.
This should simplify adding the same function to integrate and cumulative_integrate.
1 parent 98404ad commit 9cd0fcf

File tree

1 file changed

+30
-20
lines changed

1 file changed

+30
-20
lines changed

cf_xarray/accessor.py

+30-20
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,35 @@ def _get_possible(accessor, criteria):
953953
return _get_possible(obj.cf, y_criteria)
954954

955955

956+
def _update_data_units(
957+
result: DataArray | Dataset,
958+
source: DataArray | Dataset,
959+
coord_name: str,
960+
new_unit_template: str
961+
) -> DataArray | Dataset:
962+
try:
963+
coord_units = source[coord_name].attrs["units"]
964+
except KeyError:
965+
return result
966+
967+
if isinstance(source, DataArray):
968+
try:
969+
result.attrs["units"] = new_unit_template.format(
970+
source.attrs["units"], source[coord_name].attrs["units"]
971+
)
972+
except KeyError:
973+
pass
974+
else:
975+
for name in result.data_vars:
976+
try:
977+
result[name].attrs["units"] = new_unit_template.format(
978+
source[name].attrs["units"], coord_units
979+
)
980+
except KeyError:
981+
pass
982+
return result
983+
984+
956985
class _CFWrappedClass(SupportsArithmetic):
957986
"""
958987
This class is used to wrap any class in _WRAPPED_CLASSES.
@@ -2102,26 +2131,7 @@ def differentiate(
21022131
(_single(_get_coords),), self._obj, coord, error=False, default=[coord]
21032132
)[0]
21042133
result = self._obj.differentiate(coord, *xr_args, **xr_kwargs)
2105-
if isinstance(self._obj, DataArray):
2106-
try:
2107-
result.attrs["units"] = "{:s} / ({:s})".format(
2108-
self._obj.attrs["units"], self._obj[coord].attrs["units"]
2109-
)
2110-
except KeyError:
2111-
pass
2112-
else:
2113-
try:
2114-
coord_units = self._obj[coord].attrs["units"]
2115-
except KeyError:
2116-
pass
2117-
else:
2118-
for name in result.data_vars:
2119-
try:
2120-
result[name].attrs["units"] = "{:s} / ({:s})".format(
2121-
self._obj[name].attrs["units"], coord_units
2122-
)
2123-
except KeyError:
2124-
pass
2134+
result = _update_data_units(result, self._obj, coord, "{:s} / ({:s})")
21252135
if positive_upward:
21262136
coord = self._obj[coord]
21272137
attrs = coord.attrs

0 commit comments

Comments
 (0)