From 6a270950ed8ad9ffbf979255f858c12de579cd8c Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Fri, 17 Jan 2025 11:44:13 -0800 Subject: [PATCH 01/13] Drop Python 3.9 support - Update `python >=3.10` constraint in conda-env yml files --- conda-env/ci.yml | 2 +- conda-env/dev.yml | 2 +- docs/examples/climatology-and-departures.ipynb | 7 +++++++ pyproject.toml | 3 +-- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/conda-env/ci.yml b/conda-env/ci.yml index cd5e8438..c1d3e0af 100644 --- a/conda-env/ci.yml +++ b/conda-env/ci.yml @@ -6,7 +6,7 @@ channels: dependencies: # Base - required for building the package. # ========================================= - - python >=3.9 + - python >=3.10 - cf_xarray >=0.9.1 - cftime - dask diff --git a/conda-env/dev.yml b/conda-env/dev.yml index a2542e5c..22bbdbf1 100644 --- a/conda-env/dev.yml +++ b/conda-env/dev.yml @@ -6,7 +6,7 @@ channels: dependencies: # Base - required for building the package. # ========================================= - - python >=3.9 + - python >=3.10 - cf_xarray >=0.9.1 - cftime - dask diff --git a/docs/examples/climatology-and-departures.ipynb b/docs/examples/climatology-and-departures.ipynb index 3768f461..04aef265 100644 --- a/docs/examples/climatology-and-departures.ipynb +++ b/docs/examples/climatology-and-departures.ipynb @@ -17,6 +17,13 @@ "- [xarray.Dataset.temporal.departures()](../generated/xarray.Dataset.temporal.departures.rst)\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/pyproject.toml b/pyproject.toml index 7d21b98e..b9fe508e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "xcdat" dynamic = ["version"] description = "Xarray Climate Data Analysis Tools" readme = "README.rst" -requires-python = ">=3.9" +requires-python = ">=3.10" license = { text = "Apache-2.0" } authors = [{ name = "xCDAT developers" }] classifiers = [ @@ -16,7 
+16,6 @@ classifiers = [ "License :: OSI Approved :: Apache-2.0 License", "Natural Language :: English", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", From 2d916a62c959d4d8e0b8af9d6e988e96d4df31d6 Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Fri, 17 Jan 2025 11:45:38 -0800 Subject: [PATCH 02/13] Remove Python 3.9 from build workflow --- .github/workflows/build_workflow.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index e3c2992b..13fed4dc 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -55,7 +55,7 @@ jobs: shell: bash -l {0} strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v3 From ddcee207e68eb0a3babb01b74f44c2178af8c13e Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Fri, 17 Jan 2025 11:46:11 -0800 Subject: [PATCH 03/13] Revert notebook change --- docs/examples/climatology-and-departures.ipynb | 7 ------- 1 file changed, 7 deletions(-) diff --git a/docs/examples/climatology-and-departures.ipynb b/docs/examples/climatology-and-departures.ipynb index 04aef265..3768f461 100644 --- a/docs/examples/climatology-and-departures.ipynb +++ b/docs/examples/climatology-and-departures.ipynb @@ -17,13 +17,6 @@ "- [xarray.Dataset.temporal.departures()](../generated/xarray.Dataset.temporal.departures.rst)\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "metadata": {}, From 5f082d7ce72061ac2640cbb66068c694c4173700 Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Fri, 17 Jan 2025 11:51:47 -0800 Subject: [PATCH 04/13] Fix `ruff` issue with `strict=False` in `zip()` --- 
tests/test_regrid.py | 4 ++-- xcdat/dataset.py | 2 +- xcdat/temporal.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 1161f960..df883a07 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -664,10 +664,10 @@ def test_map_latitude_coarse_to_fine(self): [[0.29289322]], ] - for x, y in zip(mapping, expected_mapping): + for x, y in zip(mapping, expected_mapping, strict=False): np.testing.assert_allclose(x, y) - for x2, y2 in zip(weights, expected_weigths): + for x2, y2 in zip(weights, expected_weigths, strict=False): np.testing.assert_allclose(x2, y2) def test_map_latitude_fine_to_coarse(self): diff --git a/xcdat/dataset.py b/xcdat/dataset.py index f5139cd2..21932216 100644 --- a/xcdat/dataset.py +++ b/xcdat/dataset.py @@ -652,7 +652,7 @@ def _get_cftime_coords(offsets: np.ndarray, units: str, calendar: str) -> np.nda # Convert offsets to `np.float64` to avoid "TypeError: unsupported type # for timedelta days component: numpy.int64". 
- flat_offsets = flat_offsets.astype("float") + flat_offsets = flat_offsets.astype("float") # type: ignore # We don't need to do calendar arithmetic here because the units and # offsets are in "months" or "years", which means leap days should not diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 29bb6ede..4dba9f6b 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -1946,7 +1946,7 @@ def _convert_df_to_dt(self, df: pd.DataFrame) -> np.ndarray: dates = [ self.date_type(year, month, day, hour) for year, month, day, hour in zip( - df_new.year, df_new.month, df_new.day, df_new.hour + df_new.year, df_new.month, df_new.day, df_new.hour, strict=False ) ] From 09fc920ae7273a2bd3bfd7a159fc096cbd74883c Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Fri, 17 Jan 2025 12:01:24 -0800 Subject: [PATCH 05/13] Update python version used in build workflow --- .github/workflows/build_workflow.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index 13fed4dc..dd5b4cd6 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -36,10 +36,10 @@ jobs: - name: Checkout Code Repository uses: actions/checkout@v3 - - name: Set up Python 3.10 + - name: Set up Python 3.11 uses: actions/setup-python@v3 with: - python-version: "3.10" + python-version: "3.11" - name: Install and Run Pre-commit uses: pre-commit/action@v3.0.1 From 1e206857a9827945f06fd3dcb278388d92a7512b Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Mon, 12 May 2025 10:58:26 -0700 Subject: [PATCH 06/13] Add support for Python 3.13 --- .github/workflows/build_workflow.yml | 6 +++--- pyproject.toml | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index dd5b4cd6..ef786a2b 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -36,10 +36,10 @@ jobs: - 
name: Checkout Code Repository uses: actions/checkout@v3 - - name: Set up Python 3.11 + - name: Set up Python 3.12 uses: actions/setup-python@v3 with: - python-version: "3.11" + python-version: "3.12" - name: Install and Run Pre-commit uses: pre-commit/action@v3.0.1 @@ -55,7 +55,7 @@ jobs: shell: bash -l {0} strategy: matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v3 diff --git a/pyproject.toml b/pyproject.toml index b9fe508e..c7afb5f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] keywords = ["xcdat"] dependencies = [ From 65ea8e90ca46566bdeb044bca2c1e07f78479f23 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Mon, 12 May 2025 11:02:07 -0700 Subject: [PATCH 07/13] Remove `from __future__ import annotations` - No longer needed with python >=3.10 --- xcdat/axis.py | 2 -- xcdat/bounds.py | 2 -- xcdat/dataset.py | 2 -- xcdat/regridder/accessor.py | 2 -- xcdat/regridder/grid.py | 2 -- xcdat/spatial.py | 2 -- xcdat/temporal.py | 2 -- xcdat/tutorial.py | 2 -- xcdat/utils.py | 2 -- 9 files changed, 18 deletions(-) diff --git a/xcdat/axis.py b/xcdat/axis.py index 09c1755b..4128dd91 100644 --- a/xcdat/axis.py +++ b/xcdat/axis.py @@ -3,8 +3,6 @@ coordinates.
""" -from __future__ import annotations - from typing import Dict, List, Literal, Optional, Tuple import numpy as np diff --git a/xcdat/bounds.py b/xcdat/bounds.py index a7a85e86..eee2e4c1 100644 --- a/xcdat/bounds.py +++ b/xcdat/bounds.py @@ -1,7 +1,5 @@ """Bounds module for functions related to coordinate bounds.""" -from __future__ import annotations - import collections import datetime import warnings diff --git a/xcdat/dataset.py b/xcdat/dataset.py index 21932216..85d3a38b 100644 --- a/xcdat/dataset.py +++ b/xcdat/dataset.py @@ -1,7 +1,5 @@ """Dataset module for functions related to an xarray.Dataset.""" -from __future__ import annotations - import os import pathlib from datetime import datetime diff --git a/xcdat/regridder/accessor.py b/xcdat/regridder/accessor.py index e8ac12f5..4c719c06 100644 --- a/xcdat/regridder/accessor.py +++ b/xcdat/regridder/accessor.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import Any, Dict, List, Literal, Tuple import xarray as xr diff --git a/xcdat/regridder/grid.py b/xcdat/regridder/grid.py index b9e1ae7f..6a3ebbcd 100644 --- a/xcdat/regridder/grid.py +++ b/xcdat/regridder/grid.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import Dict, List, Optional, Tuple, Union import numpy as np diff --git a/xcdat/spatial.py b/xcdat/spatial.py index 527fcf3e..8fdaee84 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -1,7 +1,5 @@ """Module containing geospatial averaging functions.""" -from __future__ import annotations - from functools import reduce from typing import ( Callable, diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 4dba9f6b..d6f2a362 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -1,7 +1,5 @@ """Module containing temporal functions.""" -from __future__ import annotations - import warnings from datetime import datetime from itertools import chain diff --git a/xcdat/tutorial.py b/xcdat/tutorial.py index 1d2e1380..b8dce6a7 100644 --- a/xcdat/tutorial.py +++ 
b/xcdat/tutorial.py @@ -5,8 +5,6 @@ repository. """ -from __future__ import annotations - import os import pathlib import sys diff --git a/xcdat/utils.py b/xcdat/utils.py index 8828272a..46227221 100644 --- a/xcdat/utils.py +++ b/xcdat/utils.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import importlib import json from typing import Dict, List, Optional, Union From 2600d17ead84bd0e786e4f1aa946b625d2045fbe Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Mon, 12 May 2025 11:43:41 -0700 Subject: [PATCH 08/13] Update all use of `typing` library to python 3.10+ syntax --- docs/conf.py | 3 +- xcdat/axis.py | 32 ++++++------ xcdat/bounds.py | 100 ++++++++++++++++++------------------ xcdat/dataset.py | 77 +++++++++++++-------------- xcdat/regridder/accessor.py | 16 +++--- xcdat/regridder/base.py | 16 +++--- xcdat/regridder/grid.py | 54 ++++++++++--------- xcdat/regridder/regrid2.py | 30 +++++------ xcdat/regridder/xesmf.py | 14 ++--- xcdat/regridder/xgcm.py | 23 +++++---- xcdat/spatial.py | 49 ++++++++---------- xcdat/temporal.py | 48 ++++++++--------- xcdat/tutorial.py | 7 ++- xcdat/utils.py | 13 +++-- 14 files changed, 234 insertions(+), 248 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index eb38be54..bafb2050 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -21,7 +21,6 @@ import pathlib import sys from pathlib import Path -from typing import Dict from textwrap import dedent, indent import sphinx_autosummary_accessors @@ -178,7 +177,7 @@ # -- Options for LaTeX output ------------------------------------------ -latex_elements: Dict[str, str] = { +latex_elements: dict[str, str] = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', diff --git a/xcdat/axis.py b/xcdat/axis.py index 4128dd91..02597bc4 100644 --- a/xcdat/axis.py +++ b/xcdat/axis.py @@ -3,7 +3,7 @@ coordinates. 
""" -from typing import Dict, List, Literal, Optional, Tuple +from typing import Literal import numpy as np import xarray as xr @@ -22,14 +22,14 @@ # we can fetch specific `cf_xarray` mapping tables such as `ds.cf.axes["X"]` # or `ds.cf.coordinates["longitude"]`. # More information: https://cf-xarray.readthedocs.io/en/latest/coord_axes.html -CF_ATTR_MAP: Dict[CFAxisKey, Dict[str, CFAxisKey | CFStandardNameKey]] = { +CF_ATTR_MAP: dict[CFAxisKey, dict[str, CFAxisKey | CFStandardNameKey]] = { "X": {"axis": "X", "coordinate": "longitude"}, "Y": {"axis": "Y", "coordinate": "latitude"}, "T": {"axis": "T", "coordinate": "time"}, "Z": {"axis": "Z", "coordinate": "vertical"}, } -COORD_DEFAULT_ATTRS: Dict[CFAxisKey, Dict[str, str | CFAxisKey | CFStandardNameKey]] = { +COORD_DEFAULT_ATTRS: dict[CFAxisKey, dict[str, str | CFAxisKey | CFStandardNameKey]] = { "X": dict(units="degrees_east", **CF_ATTR_MAP["X"]), "Y": dict(units="degrees_north", **CF_ATTR_MAP["Y"]), "T": dict(calendar="standard", **CF_ATTR_MAP["T"]), @@ -39,7 +39,7 @@ # A dictionary that maps common variable names to coordinate variables. This # map is used as fall-back when coordinate variables don't have CF attributes # set for ``cf_xarray`` to interpret using `CF_ATTR_MAP`. -VAR_NAME_MAP: Dict[CFAxisKey, List[str]] = { +VAR_NAME_MAP: dict[CFAxisKey, list[str]] = { "X": ["longitude", "lon"], "Y": ["latitude", "lat"], "T": ["time"], @@ -47,7 +47,7 @@ } -def get_dim_keys(obj: xr.Dataset | xr.DataArray, axis: CFAxisKey) -> str | List[str]: +def get_dim_keys(obj: xr.Dataset | xr.DataArray, axis: CFAxisKey) -> str | list[str]: """Gets the dimension key(s) for an axis. Each dimension should have a corresponding dimension coordinate variable, @@ -64,7 +64,7 @@ def get_dim_keys(obj: xr.Dataset | xr.DataArray, axis: CFAxisKey) -> str | List[ Returns ------- - str | List[str] + str | list[str] The dimension string or a list of dimensions strings for an axis. 
""" dims = sorted([str(dim) for dim in get_dim_coords(obj, axis).dims]) @@ -254,7 +254,7 @@ def center_times(dataset: xr.Dataset) -> xr.Dataset: def swap_lon_axis( - dataset: xr.Dataset, to: Tuple[float, float], sort_ascending: bool = True + dataset: xr.Dataset, to: tuple[float, float], sort_ascending: bool = True ) -> xr.Dataset: """Swaps the orientation of a dataset's longitude axis. @@ -272,7 +272,7 @@ def swap_lon_axis( ---------- dataset : xr.Dataset The Dataset containing a longitude axis. - to : Tuple[float, float] + to : tuple[float, float] The orientation to swap the Dataset's longitude axis to. Supported orientations include: @@ -317,7 +317,7 @@ def swap_lon_axis( return ds -def _get_all_coord_keys(obj: xr.Dataset | xr.DataArray, axis: CFAxisKey) -> List[str]: +def _get_all_coord_keys(obj: xr.Dataset | xr.DataArray, axis: CFAxisKey) -> list[str]: """Gets all dimension and non-dimension coordinate keys for an axis. This function uses ``cf_xarray`` to interpret CF axis and coordinate name @@ -336,7 +336,7 @@ def _get_all_coord_keys(obj: xr.Dataset | xr.DataArray, axis: CFAxisKey) -> List Returns ------- - List[str] + list[str] The axis coordinate variable keys. 
References @@ -346,7 +346,7 @@ def _get_all_coord_keys(obj: xr.Dataset | xr.DataArray, axis: CFAxisKey) -> List cf_attrs = CF_ATTR_MAP[axis] var_names = VAR_NAME_MAP[axis] - keys: List[str] = [] + keys: list[str] = [] try: keys = keys + obj.cf.axes[cf_attrs["axis"]] @@ -365,7 +365,7 @@ def _get_all_coord_keys(obj: xr.Dataset | xr.DataArray, axis: CFAxisKey) -> List return list(set(keys)) -def _swap_lon_bounds(ds: xr.Dataset, key: str, to: Tuple[float, float]): +def _swap_lon_bounds(ds: xr.Dataset, key: str, to: tuple[float, float]): bounds = ds[key].copy() new_bounds = _swap_lon_axis(bounds, to) @@ -386,14 +386,14 @@ def _swap_lon_bounds(ds: xr.Dataset, key: str, to: Tuple[float, float]): return ds -def _swap_lon_axis(coords: xr.DataArray, to: Tuple[float, float]) -> xr.DataArray: +def _swap_lon_axis(coords: xr.DataArray, to: tuple[float, float]) -> xr.DataArray: """Swaps the axis orientation for longitude coordinates. Parameters ---------- coords : xr.DataArray Coordinates on a longitude axis. - to : Tuple[float, float] + to : tuple[float, float] The new longitude axis orientation. Returns @@ -438,7 +438,7 @@ def _swap_lon_axis(coords: xr.DataArray, to: Tuple[float, float]) -> xr.DataArra return new_coords -def _get_prime_meridian_index(lon_bounds: xr.DataArray) -> Optional[np.ndarray]: +def _get_prime_meridian_index(lon_bounds: xr.DataArray) -> np.ndarray | None: """Gets the index of the prime meridian cell in the longitude bounds. A prime meridian cell can exist when converting the axis orientation @@ -451,7 +451,7 @@ def _get_prime_meridian_index(lon_bounds: xr.DataArray) -> Optional[np.ndarray]: Returns ------- - Optional[np.ndarray] + np.ndarray | None An array with a single element representing the index of the prime meridian index if it exists. Otherwise, None if the cell does not exist. 
diff --git a/xcdat/bounds.py b/xcdat/bounds.py index eee2e4c1..39599962 100644 --- a/xcdat/bounds.py +++ b/xcdat/bounds.py @@ -3,7 +3,7 @@ import collections import datetime import warnings -from typing import Dict, List, Literal, Optional, Tuple, Union +from typing import Literal import cf_xarray as cfxr # noqa: F401 import cftime @@ -88,7 +88,7 @@ def __init__(self, dataset: xr.Dataset): self._dataset: xr.Dataset = dataset @property - def map(self) -> Dict[str, Optional[xr.DataArray]]: + def map(self) -> dict[str, xr.DataArray | None]: """Returns a map of axis and coordinates keys to their bounds. The dictionary provides all valid CF compliant keys for axis and @@ -97,12 +97,12 @@ def map(self) -> Dict[str, Optional[xr.DataArray]]: Returns ------- - Dict[str, Optional[xr.DataArray]] + dict[str, xr.DataArray | None] Dictionary mapping axis and coordinate keys to their bounds. """ ds = self._dataset - bounds: Dict[str, Optional[xr.DataArray]] = {} + bounds: dict[str, xr.DataArray | None] = {} for axis, bounds_keys in ds.cf.bounds.items(): bound = ds.get(bounds_keys[0], None) bounds[axis] = bound @@ -110,12 +110,12 @@ def map(self) -> Dict[str, Optional[xr.DataArray]]: return collections.OrderedDict(sorted(bounds.items())) @property - def keys(self) -> List[str]: + def keys(self) -> list[str]: """Returns a list of keys for the bounds data variables in the Dataset. Returns ------- - List[str] + list[str] A list of sorted bounds data variable keys. """ return sorted( @@ -129,7 +129,7 @@ def keys(self) -> List[str]: ) def add_missing_bounds( # noqa: C901 - self, axes: List[CFAxisKey] | Tuple[CFAxisKey, ...] = ("X", "Y", "T") + self, axes: list[CFAxisKey] | tuple[CFAxisKey, ...] = ("X", "Y", "T") ) -> xr.Dataset: """Adds missing coordinate bounds for supported axes in the Dataset. @@ -158,7 +158,7 @@ def add_missing_bounds( # noqa: C901 Parameters ---------- - axes : List[CFAxesKey] | Tuple[CFAxisKey, ...] + axes : list[CFAxesKey] | tuple[CFAxisKey, ...] 
List of CF axes that function should operate on, by default ("X", "Y", "T"). Options include "X", "Y", "T", or "Z". @@ -202,15 +202,15 @@ def add_missing_bounds( # noqa: C901 return ds def get_bounds( - self, axis: CFAxisKey, var_key: Optional[str] = None - ) -> Union[xr.Dataset, xr.DataArray]: + self, axis: CFAxisKey, var_key: str | None = None + ) -> xr.Dataset | xr.DataArray: """Gets coordinate bounds. Parameters ---------- axis : CFAxisKey The CF axis key ("X", "Y", "T", "Z"). - var_key: Optional[str] + var_key: str | None The key of the coordinate or data variable to get axis bounds for. This parameter is useful if you only want the single bounds DataArray related to the axis on the variable (e.g., "tas" has @@ -218,7 +218,7 @@ def get_bounds( Returns ------- - Union[xr.Dataset, xr.DataArray] + xr.Dataset | xr.DataArray A Dataset of N bounds variables, or a single bounds variable DataArray. @@ -249,7 +249,7 @@ def get_bounds( "or `ds.bounds.add_bounds()`." ) - bounds: Union[xr.Dataset, xr.DataArray] = self._dataset[ + bounds: xr.Dataset | xr.DataArray = self._dataset[ bounds_keys if len(bounds_keys) > 1 else bounds_keys[0] ].copy() @@ -289,9 +289,7 @@ def add_bounds(self, axis: CFAxisKey) -> xr.Dataset: ds = self._dataset.copy() self._validate_axis_arg(axis) - coord_vars: Union[xr.DataArray, xr.Dataset] = get_dim_coords( - self._dataset, axis - ) + coord_vars: xr.DataArray | xr.Dataset = get_dim_coords(self._dataset, axis) # In xarray, ancillary singleton coordinates that aren't related to the # axis can still be attached to dimension coordinates. For example, # if the "height" singleton exists, it will be attached to "time". 
@@ -317,8 +315,8 @@ def add_bounds(self, axis: CFAxisKey) -> xr.Dataset: def add_time_bounds( self, method: Literal["freq", "midpoint"], - freq: Optional[Literal["year", "month", "day", "hour"]] = None, - daily_subfreq: Optional[Literal[1, 2, 3, 4, 6, 8, 12, 24]] = None, + freq: Literal["year", "month", "day", "hour"] | None = None, + daily_subfreq: Literal[1, 2, 3, 4, 6, 8, 12, 24] | None = None, end_of_month: bool = False, ) -> xr.Dataset: """Add bounds for an axis using its coordinate points. @@ -389,7 +387,7 @@ def add_time_bounds( The dataset with time bounds added. """ ds = self._dataset.copy() - coord_vars: Union[xr.DataArray, xr.Dataset] = get_dim_coords(self._dataset, "T") + coord_vars: xr.DataArray | xr.Dataset = get_dim_coords(self._dataset, "T") # In xarray, ancillary singleton coordinates that aren't related to axis # can still be attached to dimension coordinates (e.g., "height" is # attached to "time"). We ignore these singleton coordinates to avoid @@ -418,8 +416,8 @@ def add_time_bounds( return ds def _drop_ancillary_singleton_coords( - self, coord_vars: Union[xr.Dataset, xr.DataArray] - ) -> Union[xr.Dataset, xr.DataArray]: + self, coord_vars: xr.Dataset | xr.DataArray + ) -> xr.Dataset | xr.DataArray: """Drop ancillary singleton coordinates from dimension coordinates. Xarray coordinate variables retain all coordinates from the parent @@ -441,13 +439,13 @@ def _drop_ancillary_singleton_coords( Parameters ---------- - coord_vars : Union[xr.Dataset, xr.DataArray] + coord_vars : xr.Dataset | xr.DataArray The dimension coordinate variables with ancillary coordinates (if they exist). Returns ------- - Union[xr.Dataset, xr.DataArray] + xr.Dataset | xr.DataArray The dimension coordinate variables with ancillary coordinates dropped (if they exist). 
@@ -465,7 +463,7 @@ def _drop_ancillary_singleton_coords( return coord_vars - def _get_bounds_keys(self, axis: CFAxisKey) -> List[str]: + def _get_bounds_keys(self, axis: CFAxisKey) -> list[str]: """Get bounds keys for an axis's coordinate variables in the dataset. This function attempts to map bounds to an axis using ``cf_xarray`` @@ -478,13 +476,13 @@ def _get_bounds_keys(self, axis: CFAxisKey) -> List[str]: Returns ------- - List[str] + list[str] The axis bounds key(s). """ cf_method = self._dataset.cf.bounds cf_attrs = CF_ATTR_MAP[axis] - keys: List[str] = [] + keys: list[str] = [] try: keys = keys + cf_method[cf_attrs["axis"]] @@ -503,7 +501,7 @@ def _get_bounds_keys(self, axis: CFAxisKey) -> List[str]: def _get_bounds_from_attr( self, obj: xr.DataArray | xr.Dataset, axis: CFAxisKey - ) -> List[str]: + ) -> list[str]: """Retrieve bounds attribute keys from the given xarray object. This method extracts the "bounds" attribute keys from the coordinates @@ -518,12 +516,12 @@ def _get_bounds_from_attr( Returns: -------- - List[str] + list[str] A list of bounds attribute keys found in the coordinates of the specified axis. Otherwise, an empty list is returned. 
""" coords_obj = get_dim_coords(obj, axis) - bounds_keys: List[str] = [] + bounds_keys: list[str] = [] if isinstance(coords_obj, xr.DataArray): bounds_keys = self._extract_bounds_key(coords_obj, bounds_keys) @@ -534,8 +532,8 @@ def _get_bounds_from_attr( return bounds_keys def _extract_bounds_key( - self, coords_obj: xr.DataArray, bounds_keys: List[str] - ) -> List[str]: + self, coords_obj: xr.DataArray, bounds_keys: list[str] + ) -> list[str]: bnds_key = coords_obj.attrs.get("bounds") if bnds_key is not None: @@ -546,8 +544,8 @@ def _extract_bounds_key( def _create_time_bounds( # noqa: C901 self, time: xr.DataArray, - freq: Optional[Literal["year", "month", "day", "hour"]] = None, - daily_subfreq: Optional[Literal[1, 2, 3, 4, 6, 8, 12, 24]] = None, + freq: Literal["year", "month", "day", "hour"] | None = None, + daily_subfreq: Literal[1, 2, 3, 4, 6, 8, 12, 24] | None = None, end_of_month: bool = False, ) -> xr.DataArray: """Creates time bounds for each timestep of the time coordinate axis. @@ -673,8 +671,8 @@ def _create_time_bounds( # noqa: C901 def _create_yearly_time_bounds( self, timesteps: np.ndarray, - obj_type: Union[cftime.datetime, pd.Timestamp], - ) -> List[Union[cftime.datetime, pd.Timestamp]]: + obj_type: cftime.datetime | pd.Timestamp, + ) -> list[cftime.datetime | pd.Timestamp]: """Creates time bounds for each timestep with the start and end of the year. Bounds for each timestep correspond to Jan. 1 00:00:00 of the year of the @@ -686,16 +684,16 @@ def _create_yearly_time_bounds( An array of timesteps, represented as either `cftime.datetime` or `pd.Timestamp` (casted from `np.datetime64[ns]` to support pandas time/date components). - obj_type : Union[cftime.datetime, pd.Timestamp] + obj_type : cftime.datetime | pd.Timestamp The object type for time bounds based on the dtype of ``time_values``. Returns ------- - List[Union[cftime.datetime, pd.Timestamp]] + list[cftime.datetime | pd.Timestamp] A list of time bound values. 
""" - time_bnds: List[cftime.datetime] = [] + time_bnds: list[cftime.datetime] = [] for step in timesteps: year = step.year @@ -710,9 +708,9 @@ def _create_yearly_time_bounds( def _create_monthly_time_bounds( self, timesteps: np.ndarray, - obj_type: Union[cftime.datetime, pd.Timestamp], + obj_type: cftime.datetime | pd.Timestamp, end_of_month: bool = False, - ) -> List[Union[cftime.datetime, pd.Timestamp]]: + ) -> list[cftime.datetime | pd.Timestamp]: """Creates time bounds for each timestep with the start and end of the month. Bounds for each timestep correspond to 00:00:00 on the first of the month @@ -724,7 +722,7 @@ def _create_monthly_time_bounds( An array of timesteps, represented as either `cftime.datetime` or `pd.Timestamp` (casted from `np.datetime64[ns]` to support pandas time/date components). - obj_type : Union[cftime.datetime, pd.Timestamp] + obj_type : cftime.datetime | pd.Timestamp The object type for time bounds based on the dtype of ``time_values``. end_of_month : bool, optional @@ -733,7 +731,7 @@ def _create_monthly_time_bounds( Returns ------- - List[Union[cftime.datetime, pd.Timestamp]] + list[cftime.datetime | pd.Timestamp] A list of time bound values. Note @@ -763,10 +761,10 @@ def _create_monthly_time_bounds( def _add_months_to_timestep( self, - timestep: Union[cftime.datetime, pd.Timestamp], - obj_type: Union[cftime.datetime, pd.Timestamp], + timestep: cftime.datetime | pd.Timestamp, + obj_type: cftime.datetime | pd.Timestamp, delta: int, - ) -> Union[cftime.datetime, pd.Timestamp]: + ) -> cftime.datetime | pd.Timestamp: """Adds delta month(s) to a timestep. The delta value can be positive or negative (for subtraction). Refer to @@ -776,7 +774,7 @@ def _add_months_to_timestep( ---------- timestep : Union[cftime.datime, pd.Timestamp] A timestep represented as ``cftime.datetime`` or ``pd.Timestamp``. 
- obj_type : Union[cftime.datetime, pd.Timestamp] + obj_type : cftime.datetime | pd.Timestamp The object type for time bounds based on the dtype of ``timestep``. delta : int @@ -784,7 +782,7 @@ def _add_months_to_timestep( Returns ------- - Union[cftime.datetime, pd.Timestamp] + cftime.datetime | pd.Timestamp References ---------- @@ -814,9 +812,9 @@ def _add_months_to_timestep( def _create_daily_time_bounds( self, timesteps: np.ndarray, - obj_type: Union[cftime.datetime, pd.Timestamp], + obj_type: cftime.datetime | pd.Timestamp, freq: Literal[1, 2, 3, 4, 6, 8, 12, 24] = 1, - ) -> List[Union[cftime.datetime, pd.Timestamp]]: + ) -> list[cftime.datetime | pd.Timestamp]: """Creates time bounds for each timestep with the start and end of the day. Bounds for each timestep corresponds to 00:00:00 timepoint on the @@ -839,7 +837,7 @@ def _create_daily_time_bounds( An array of timesteps, represented as either `cftime.datetime` or `pd.Timestamp` (casted from `np.datetime64[ns]` to support pandas time/date components). - obj_type : Union[cftime.datetime, pd.Timestamp] + obj_type : cftime.datetime | pd.Timestamp The object type for time bounds based on the dtype of ``time_values``. freq : {1, 2, 3, 4, 6, 8, 12, 24}, optional @@ -855,7 +853,7 @@ def _create_daily_time_bounds( Returns ------- - List[Union[cftime.datetime, pd.Timestamp]] + list[cftime.datetime | pd.Timestamp] A list of time bound values. Raises diff --git a/xcdat/dataset.py b/xcdat/dataset.py index 85d3a38b..bd51539e 100644 --- a/xcdat/dataset.py +++ b/xcdat/dataset.py @@ -2,10 +2,11 @@ import os import pathlib +from collections.abc import Callable from datetime import datetime from functools import partial from io import BufferedIOBase -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Literal import numpy as np import xarray as xr @@ -26,27 +27,27 @@ logger = _setup_custom_logger(__name__) #: List of non-CF compliant time units. 
-NON_CF_TIME_UNITS: List[str] = ["month", "months", "year", "years"] +NON_CF_TIME_UNITS: list[str] = ["month", "months", "year", "years"] # Type annotation for the `paths` arg. -Paths = Union[ - str, - pathlib.Path, - List[str], - List[pathlib.Path], - List[List[str]], - List[List[pathlib.Path]], -] +Paths = ( + str + | pathlib.Path + | list[str] + | list[pathlib.Path] + | list[list[str]] + | list[list[pathlib.Path]] +) def open_dataset( path: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, - data_var: Optional[str] = None, - add_bounds: List[CFAxisKey] | Tuple[CFAxisKey, ...] | None = ("X", "Y"), + data_var: str | None = None, + add_bounds: list[CFAxisKey] | tuple[CFAxisKey, ...] | None = ("X", "Y"), decode_times: bool = True, center_times: bool = False, - lon_orient: Optional[Tuple[float, float]] = None, - **kwargs: Dict[str, Any], + lon_orient: tuple[float, float] | None = None, + **kwargs: dict[str, Any], ) -> xr.Dataset: """Wraps ``xarray.open_dataset()`` with post-processing options. @@ -58,10 +59,10 @@ def open_dataset( ends with .gz, in which case the file is gunzipped and opened with scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). - data_var: Optional[str], optional + data_var: str | None, optional The key of the non-bounds data variable to keep in the Dataset, alongside any existing bounds data variables, by default None. - add_bounds: List[CFAxisKey] | Tuple[CFAxisKey, ...] | None + add_bounds: list[CFAxisKey] | tuple[CFAxisKey, ...] | None List of CF axes to try to add bounds for (if missing), by default ("X", "Y"). Set to None to not add any missing bounds. Please note that bounds are required for many xCDAT features. @@ -81,7 +82,7 @@ def open_dataset( If True, attempt to center time coordinates using the midpoint between its upper and lower bounds. Otherwise, use the provided time coordinates, by default False. 
- lon_orient: Optional[Tuple[float, float]], optional + lon_orient: tuple[float, float] | None, optional The orientation to use for the Dataset's longitude axis (if it exists). Either `(-180, 180)` or `(0, 360)`, by default None. Supported options include: @@ -89,7 +90,7 @@ def open_dataset( * None: use the current orientation (if the longitude axis exists) * (-180, 180): represents [-180, 180) in math notation * (0, 360): represents [0, 360) in math notation - **kwargs : Dict[str, Any] + **kwargs : dict[str, Any] Additional arguments passed on to ``xarray.open_dataset``. Refer to the [1]_ xarray docs for accepted keyword arguments. @@ -124,14 +125,14 @@ def open_dataset( def open_mfdataset( paths: str | NestedSequence[str | os.PathLike], - data_var: Optional[str] = None, - add_bounds: List[CFAxisKey] | Tuple[CFAxisKey, ...] | None = ("X", "Y"), + data_var: str | None = None, + add_bounds: list[CFAxisKey] | tuple[CFAxisKey, ...] | None = ("X", "Y"), decode_times: bool = True, center_times: bool = False, - lon_orient: Optional[Tuple[float, float]] = None, - data_vars: Literal["minimal", "different", "all"] | List[str] = "minimal", - preprocess: Optional[Callable] = None, - **kwargs: Dict[str, Any], + lon_orient: tuple[float, float] | None = None, + data_vars: Literal["minimal", "different", "all"] | list[str] = "minimal", + preprocess: Callable | None = None, + **kwargs: dict[str, Any], ) -> xr.Dataset: """Wraps ``xarray.open_mfdataset()`` with post-processing options. @@ -150,7 +151,7 @@ def open_mfdataset( If concatenation along more than one dimension is desired, then ``paths`` must be a nested list-of-lists (see [2]_ ``xarray.combine_nested`` for details). - add_bounds: List[CFAxisKey] | Tuple[CFAxisKey, ...] | None + add_bounds: list[CFAxisKey] | tuple[CFAxisKey, ...] | None List of CF axes to try to add bounds for (if missing), by default ("X", "Y"). Set to None to not add any missing bounds. Please note that bounds are required for many xCDAT features. 
@@ -161,7 +162,7 @@ def open_mfdataset( of the coordinates. If desired, refer to :py:func:`xarray.Dataset.bounds.add_time_bounds` if you require more granular configuration for how "T" bounds are generated. - data_var: Optional[str], optional + data_var: str | None, optional The key of the data variable to keep in the Dataset, by default None. decode_times: bool, optional If True, attempt to decode times encoded in the standard NetCDF @@ -172,7 +173,7 @@ def open_mfdataset( If True, attempt to center time coordinates using the midpoint between its upper and lower bounds. Otherwise, use the provided time coordinates, by default False. - lon_orient: Optional[Tuple[float, float]], optional + lon_orient: tuple[float, float] | None, optional The orientation to use for the Dataset's longitude axis (if it exists), by default None. Supported options include: @@ -199,11 +200,11 @@ def open_mfdataset( such as "lat_bnds" or "lon_bnds". ``data_vars="minimal"`` is required for some xCDAT functions, including spatial averaging where a reduction is performed using the lat/lon bounds. - preprocess : Optional[Callable], optional + preprocess : Callable | None, optional If provided, call this function on each dataset prior to concatenation. You can find the file-name from which each dataset was loaded in ``ds.encoding["source"]``. - **kwargs : Dict[str, Any] + **kwargs : dict[str, Any] Additional arguments passed on to ``xarray.open_mfdataset``. Refer to the [3]_ xarray docs for accepted keyword arguments. @@ -411,7 +412,7 @@ def _parse_dir_for_nc_glob(dir_path: str | pathlib.Path) -> str: def _preprocess( - ds: xr.Dataset, decode_times: Optional[bool], callable: Optional[Callable] = None + ds: xr.Dataset, decode_times: bool | None, callable: Callable | None = None ) -> xr.Dataset: """Preprocesses each dataset passed to ``open_mfdataset()``. @@ -436,7 +437,7 @@ def _preprocess( ---------- ds : xr.Dataset The Dataset. 
- callable : Optional[Callable], optional + callable : Callable | None, optional A user specified optional callable function for preprocessing. Returns @@ -460,10 +461,10 @@ def _preprocess( def _postprocess_dataset( dataset: xr.Dataset, - data_var: Optional[str] = None, + data_var: str | None = None, center_times: bool = False, - add_bounds: List[CFAxisKey] | Tuple[CFAxisKey, ...] | None = ("X", "Y"), - lon_orient: Optional[Tuple[float, float]] = None, + add_bounds: list[CFAxisKey] | tuple[CFAxisKey, ...] | None = ("X", "Y"), + lon_orient: tuple[float, float] | None = None, ) -> xr.Dataset: """Post-processes a Dataset object. @@ -471,13 +472,13 @@ def _postprocess_dataset( ---------- dataset : xr.Dataset The dataset. - data_var: Optional[str], optional + data_var: str | None, optional The key of the data variable to keep in the Dataset, by default None. center_times: bool, optional If True, center time coordinates using the midpoint between its upper and lower bounds. Otherwise, use the provided time coordinates, by default False. - add_bounds: List[CFAxisKey] | Tuple[CFAxisKey, ...] | None + add_bounds: list[CFAxisKey] | tuple[CFAxisKey, ...] | None List of CF axes to try to add bounds for (if missing), default ("X", "Y"). Set to None to not add any missing bounds. @@ -487,7 +488,7 @@ def _postprocess_dataset( * If desired, use :py:func:`xarray.Dataset.bounds.add_time_bounds` if you require more granular configuration for how "T" bounds are generated - lon_orient: Optional[Tuple[float, float]], optional + lon_orient: tuple[float, float] | None, optional The orientation to use for the Dataset's longitude axis (if it exists), by default None. @@ -650,7 +651,7 @@ def _get_cftime_coords(offsets: np.ndarray, units: str, calendar: str) -> np.nda # Convert offsets to `np.float64` to avoid "TypeError: unsupported type # for timedelta days component: numpy.int64". 
- flat_offsets = flat_offsets.astype("float") # type: ignore + flat_offsets = flat_offsets.astype("float") # We don't need to do calendar arithmetic here because the units and # offsets are in "months" or "years", which means leap days should not diff --git a/xcdat/regridder/accessor.py b/xcdat/regridder/accessor.py index 4c719c06..d7306b1a 100644 --- a/xcdat/regridder/accessor.py +++ b/xcdat/regridder/accessor.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal, Tuple +from typing import Any, Literal import xarray as xr @@ -79,11 +79,11 @@ def grid(self) -> xr.Dataset: >>> grid = ds.regridder.grid """ - axis_names: List[CFAxisKey] = ["X", "Y", "Z"] + axis_names: list[CFAxisKey] = ["X", "Y", "Z"] - axis_coords: Dict[str, xr.DataArray] = {} - axis_bounds: Dict[str, xr.DataArray] = {} - axis_has_bounds: Dict[CFAxisKey, bool] = {} + axis_coords: dict[str, xr.DataArray] = {} + axis_bounds: dict[str, xr.DataArray] = {} + axis_has_bounds: dict[CFAxisKey, bool] = {} with xr.set_options(keep_attrs=True): for axis in axis_names: @@ -119,7 +119,7 @@ def grid(self) -> xr.Dataset: def _get_axis_coord_and_bounds( self, axis: CFAxisKey - ) -> Tuple[xr.DataArray | None, xr.DataArray | None]: + ) -> tuple[xr.DataArray | None, xr.DataArray | None]: try: coord_var = get_coords_by_name(self._ds, axis) if coord_var.size == 1: @@ -310,7 +310,7 @@ def vertical( return output_ds -def _get_input_grid(ds: xr.Dataset, data_var: str, dup_check_dims: List[CFAxisKey]): +def _get_input_grid(ds: xr.Dataset, data_var: str, dup_check_dims: list[CFAxisKey]): """ Extract the grid from ``ds``. @@ -324,7 +324,7 @@ def _get_input_grid(ds: xr.Dataset, data_var: str, dup_check_dims: List[CFAxisKe Dataset to extract grid from. data_var : str Name of target data variable. - dup_check_dims : List[CFAxisKey] + dup_check_dims : list[CFAxisKey] List of dimensions to check for duplicates. 
Returns diff --git a/xcdat/regridder/base.py b/xcdat/regridder/base.py index 27860d65..252e71e7 100644 --- a/xcdat/regridder/base.py +++ b/xcdat/regridder/base.py @@ -1,5 +1,5 @@ import abc -from typing import Any, List, Tuple, Union +from typing import Any import numpy as np import xarray as xr @@ -10,16 +10,16 @@ logger = _setup_custom_logger(__name__) -Coord = Union[np.ndarray, xr.DataArray] +Coord = np.ndarray | xr.DataArray -CoordOptionalBnds = Union[Coord, Tuple[Coord, Coord]] +CoordOptionalBnds = Coord | tuple[Coord, Coord] def _preserve_bounds( input_ds: xr.Dataset, output_grid: xr.Dataset, output_ds: xr.Dataset, - drop_axis: List[CFAxisKey], + drop_axis: list[CFAxisKey], ) -> xr.Dataset: """Preserves existing bounds from datasets. @@ -33,7 +33,7 @@ def _preserve_bounds( Output grid Dataset used for regridding. output_ds : xr.Dataset Dataset bounds will be copied to. - drop_axis : List[CFAxisKey] + drop_axis : list[CFAxisKey] Axis or axes to drop from `input_ds`, which drops the related coords and bounds. For example, dropping the "Y" axis in `input_ds` ensures that the "Y" axis in `output_grid` is referenced for bounds. @@ -58,14 +58,14 @@ def _preserve_bounds( return output_ds -def _drop_axis(ds: xr.Dataset, axis: List[CFAxisKey]) -> xr.Dataset: +def _drop_axis(ds: xr.Dataset, axis: list[CFAxisKey]) -> xr.Dataset: """Drops an axis or axes in a dataset. Parameters ---------- ds : xr.Dataset The dataset. - axis : List[CFAxisKey] + axis : list[CFAxisKey] The axis or axes to drop. Returns @@ -73,7 +73,7 @@ def _drop_axis(ds: xr.Dataset, axis: List[CFAxisKey]) -> xr.Dataset: xr.Daatset The dataset with axis or axes dropped. 
""" - dims: List[str] = [] + dims: list[str] = [] for ax in axis: try: diff --git a/xcdat/regridder/grid.py b/xcdat/regridder/grid.py index 6a3ebbcd..6297dd0f 100644 --- a/xcdat/regridder/grid.py +++ b/xcdat/regridder/grid.py @@ -1,5 +1,3 @@ -from typing import Dict, List, Optional, Tuple, Union - import numpy as np import xarray as xr @@ -96,7 +94,7 @@ def create_gaussian_grid(nlats: int) -> xr.Dataset: return create_grid(x=create_axis("lon", lon), y=(lat, lat_bnds)) -def _create_gaussian_axis(nlats: int) -> Tuple[xr.DataArray, xr.DataArray]: +def _create_gaussian_axis(nlats: int) -> tuple[xr.DataArray, xr.DataArray]: """Create Gaussian axis. Creates a Gaussian axis of `nlats`. @@ -156,7 +154,7 @@ def _create_gaussian_axis(nlats: int) -> Tuple[xr.DataArray, xr.DataArray]: return bounds_da, points_da -def _gaussian_axis(mid: int, nlats: int) -> Tuple[np.ndarray, np.ndarray]: +def _gaussian_axis(mid: int, nlats: int) -> tuple[np.ndarray, np.ndarray]: """Calculates the bounds and weights for a Guassian axis. @@ -174,7 +172,7 @@ def _gaussian_axis(mid: int, nlats: int) -> Tuple[np.ndarray, np.ndarray]: Returns ------- - Tuple[np.ndarray, np.ndarray] + tuple[np.ndarray, np.ndarray] First `np.ndarray` contains the angles of the bounds and the second contains the weights. """ points = _bessel_function_zeros(mid + 1) @@ -247,7 +245,7 @@ def _bessel_function_zeros(n: int) -> np.ndarray: return values -def _legendre_polinomial(bessel_zero: int, nlats: int) -> Tuple[float, float, float]: +def _legendre_polinomial(bessel_zero: int, nlats: int) -> tuple[float, float, float]: """Legendre_polynomials. Calculates the third legendre polynomial. @@ -267,7 +265,7 @@ def _legendre_polinomial(bessel_zero: int, nlats: int) -> Tuple[float, float, fl Returns ------- - Tuple[float, float, float] + tuple[float, float, float] First, second and third legendre polynomial. 
""" zero_poly = np.cos(bessel_zero / np.sqrt(np.power(nlats + 0.5, 2) + ESTIMATE_CONST)) @@ -427,7 +425,7 @@ def create_zonal_grid(grid: xr.Dataset) -> xr.Dataset: lat_bnds = grid.bounds.get_bounds("Y", var_key=lat.name) # Ignore `Argument 1 to "create_grid" has incompatible type - # "Union[Dataset, DataArray]"; expected "Union[ndarray[Any, Any], DataArray]" + # "Dataset | DataArray"; expected "ndarray[Any, Any] | DataArray" # mypy(error)` because this arg is validated to be a DataArray beforehand. return create_grid( x=create_axis("lon", out_lon_data, bounds=lon_bnds), @@ -436,25 +434,25 @@ def create_zonal_grid(grid: xr.Dataset) -> xr.Dataset: def create_grid( - x: xr.DataArray | Tuple[xr.DataArray, xr.DataArray | None] | None = None, - y: xr.DataArray | Tuple[xr.DataArray, xr.DataArray | None] | None = None, - z: xr.DataArray | Tuple[xr.DataArray, xr.DataArray | None] | None = None, - attrs: Optional[Dict[str, str]] = None, + x: xr.DataArray | tuple[xr.DataArray, xr.DataArray | None] | None = None, + y: xr.DataArray | tuple[xr.DataArray, xr.DataArray | None] | None = None, + z: xr.DataArray | tuple[xr.DataArray, xr.DataArray | None] | None = None, + attrs: dict[str, str] | None = None, ) -> xr.Dataset: """Creates a grid dataset using the specified axes. Parameters ---------- - x : xr.DataArray | Tuple[xr.DataArray, xr.DataArray | None] | None + x : xr.DataArray | tuple[xr.DataArray, xr.DataArray | None] | None An optional dataarray or tuple of a datarray with optional bounds to use for the "X" axis, by default None. - y : xr.DataArray | Tuple[xr.DataArray, xr.DataArray | None] | None = None, + y : xr.DataArray | tuple[xr.DataArray, xr.DataArray | None] | None = None, An optional dataarray or tuple of a datarray with optional bounds to use for the "Y" axis, by default None. 
- z : xr.DataArray | Tuple[xr.DataArray, xr.DataArray | None] | None + z : xr.DataArray | tuple[xr.DataArray, xr.DataArray | None] | None An optional dataarray or tuple of a datarray with optional bounds to use for the "Z" axis, by default None. - attrs : Optional[Dict[str, str]] + attrs : dict[str, str] | None Custom attributes to be added to the generated `xr.Dataset`. Returns @@ -535,11 +533,11 @@ def create_grid( def create_axis( name: str, - data: Union[List[Union[int, float]], np.ndarray], - bounds: Optional[Union[List[List[Union[int, float]]], np.ndarray]] = None, - generate_bounds: Optional[bool] = True, - attrs: Optional[Dict[str, str]] = None, -) -> Tuple[xr.DataArray, Optional[xr.DataArray]]: + data: list[int | float] | np.ndarray, + bounds: list[list[int | float]] | np.ndarray | None = None, + generate_bounds: bool = True, + attrs: dict[str, str] | None = None, +) -> tuple[xr.DataArray, xr.DataArray | None]: """Creates an axis and optional bounds. @@ -549,14 +547,14 @@ def create_axis( The CF standard name for the axis (e.g., "longitude", "latitude", "height"). xCDAT also accepts additional names such as "lon", "lat", and "lev". Refer to ``xcdat.axis.VAR_NAME_MAP`` for accepted names. - data : Union[List[Union[int, float]], np.ndarray] + data : list[int | float] | np.ndarray 1-D axis data consisting of integers or floats. - bounds : Optional[Union[List[List[Union[int, float]]], np.ndarray]] + bounds : list[list[int | float]] | np.ndarray | None 2-D axis bounds data consisting of integers or floats, defaults to None. Must have a shape of n x 2, where n is the length of ``data``. - generate_bounds : Optiona[bool] + generate_bounds : bool Generate bounds for the axis if ``bounds`` is None, by default True. - attrs : Optional[Dict[str, str]] + attrs : dict[str, str] | None Custom attributes to be added to the generated `xr.DataArray` axis, by default None. 
@@ -566,7 +564,7 @@ def create_axis( Returns ------- - Tuple[xr.DataArray, Optional[xr.DataArray]] + tuple[xr.DataArray, xr.DataArray | None] A DataArray containing the axis data and optional bounds. Raises @@ -636,7 +634,7 @@ def create_axis( def _validate_grid_has_single_axis_dim( - axis: CFAxisKey, coord_var: Union[xr.DataArray, xr.Dataset] + axis: CFAxisKey, coord_var: xr.DataArray | xr.Dataset ): """Validates that the grid's axis has a single dimension. @@ -648,7 +646,7 @@ def _validate_grid_has_single_axis_dim( ---------- axis : CFAxisKey The CF axis key ("X", "Y", "T", or "Z"). - coord_var : Union[xr.DataArray, xr.Dataset] + coord_var : xr.DataArray | xr.Dataset The dimension coordinate variable(s) for the axis. Raises diff --git a/xcdat/regridder/regrid2.py b/xcdat/regridder/regrid2.py index 710313d1..67fda714 100644 --- a/xcdat/regridder/regrid2.py +++ b/xcdat/regridder/regrid2.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import numpy as np import xarray as xr @@ -119,7 +119,7 @@ def _regrid( src_lon_bnds: np.ndarray, dst_lat_bnds: np.ndarray, dst_lon_bnds: np.ndarray, - src_mask: Optional[np.ndarray], + src_mask: np.ndarray | None, omitted=None, unmapped_to_nan=True, ) -> np.ndarray: @@ -272,7 +272,7 @@ def _build_dataset( def _get_output_coords( dv_input: xr.DataArray, output_grid: xr.Dataset -) -> Dict[str, xr.DataArray]: +) -> dict[str, xr.DataArray]: """ Generate the output coordinates for regridding based on the input data variable and output grid. @@ -286,12 +286,12 @@ def _get_output_coords( Returns ------- - Dict[str, xr.DataArray] + dict[str, xr.DataArray] A dictionary where keys are coordinate names and values are the corresponding coordinates from the output grid or input data variable, aligned with the dimensions of the input data variable. """ - output_coords: Dict[str, xr.DataArray] = {} + output_coords: dict[str, xr.DataArray] = {} # First get the X and Y axes from the output grid. 
for key in ["X", "Y"]: @@ -313,7 +313,7 @@ def _get_output_coords( def _map_latitude( src: np.ndarray, dst: np.ndarray -) -> Tuple[List[np.ndarray], List[np.ndarray]]: +) -> tuple[list[np.ndarray], list[np.ndarray]]: """ Map source to destination latitude. @@ -335,7 +335,7 @@ def _map_latitude( Returns ------- - Tuple[List[np.ndarray], List[np.ndarray]] + tuple[list[np.ndarray], list[np.ndarray]] A tuple of cell mappings and cell weights. """ src_south, src_north = _extract_bounds(src) @@ -366,8 +366,8 @@ def _map_latitude( def _get_latitude_weights( - bounds: List[Tuple[np.ndarray, np.ndarray]], -) -> List[np.ndarray]: + bounds: list[tuple[np.ndarray, np.ndarray]], +) -> list[np.ndarray]: weights = [] for x, y in bounds: @@ -379,7 +379,7 @@ def _get_latitude_weights( return weights -def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: +def _map_longitude(src: np.ndarray, dst: np.ndarray) -> tuple[list, list]: """ Map source to destination longitude. @@ -404,7 +404,7 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: Returns ------- - Tuple[List, List] + tuple[list, list] A tuple of cell mappings and cell weights. """ src_west, src_east = _extract_bounds(src) @@ -456,7 +456,7 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: return mapping, weights -def _extract_bounds(bounds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: +def _extract_bounds(bounds: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ Extract lower and upper bounds from an axis. @@ -467,7 +467,7 @@ def _extract_bounds(bounds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: Returns ------- - Tuple[np.ndarray, np.ndarray] + tuple[np.ndarray, np.ndarray] A tuple containing the lower and upper bounds for the axis. 
""" if bounds[0, 0] < bounds[0, 1]: @@ -484,7 +484,7 @@ def _align_axis( src_west: np.ndarray, src_east: np.ndarray, dst_west: np.ndarray, -) -> Tuple[np.ndarray, np.ndarray, int]: +) -> tuple[np.ndarray, np.ndarray, int]: """ Aligns a source and destination longitude axis. @@ -499,7 +499,7 @@ def _align_axis( Returns ------- - Tuple[np.ndarray, np.ndarray, int] + tuple[np.ndarray, np.ndarray, int] A tuple containing the shifted western source bounds, the shifted eastern source bounds, and the number of places shifted to align axis. """ diff --git a/xcdat/regridder/xesmf.py b/xcdat/regridder/xesmf.py index 7ad9dba0..e29396e5 100644 --- a/xcdat/regridder/xesmf.py +++ b/xcdat/regridder/xesmf.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import xarray as xr import xesmf as xe @@ -24,9 +24,9 @@ def __init__( output_grid: xr.Dataset, method: str, periodic: bool = False, - extrap_method: Optional[str] = None, - extrap_dist_exponent: Optional[float] = None, - extrap_num_src_pnts: Optional[int] = None, + extrap_method: str | None = None, + extrap_dist_exponent: float | None = None, + extrap_num_src_pnts: int | None = None, ignore_degenerate: bool = True, unmapped_to_nan: bool = True, **options: Any, @@ -61,12 +61,12 @@ def __init__( The regridding method to apply, defaults to "bilinear". periodic : bool Treat longitude as periodic, used for global grids. - extrap_method : Optional[str] + extrap_method : str | None Extrapolation method, useful when moving from a fine to coarse grid. - extrap_dist_exponent : Optional[float] + extrap_dist_exponent : float | None The exponent to raise the distance to when calculating weights for the extrapolation method. - extrap_num_src_pnts : Optional[int] + extrap_num_src_pnts : int | None The number of source points to use for the extrapolation methods that use more than one source point. 
ignore_degenerate : bool diff --git a/xcdat/regridder/xgcm.py b/xcdat/regridder/xgcm.py index 71de999c..94b4ffb7 100644 --- a/xcdat/regridder/xgcm.py +++ b/xcdat/regridder/xgcm.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Hashable, Literal, Optional, Union, get_args +from collections.abc import Hashable +from typing import Any, Literal, get_args import xarray as xr from xgcm import Grid @@ -18,10 +19,10 @@ def __init__( input_grid: xr.Dataset, output_grid: xr.Dataset, method: XGCMVerticalMethods = "linear", - target_data: Optional[Union[str, xr.DataArray]] = None, - grid_positions: Optional[Dict[str, str]] = None, + target_data: str | xr.DataArray | None = None, + grid_positions: dict[str, str] | None = None, periodic: bool = False, - extra_init_options: Optional[Dict[str, Any]] = None, + extra_init_options: dict[str, Any] | None = None, **options: Any, ): """ @@ -57,18 +58,18 @@ def __init__( - linear (default) - log - conservative - target_data : Optional[Union[str, xr.DataArray]] + target_data : str | xr.DataArray | None Data to transform target data onto, either the key of a variable in the input dataset or an ``xr.DataArray``, by default None. - grid_positions : Optional[Dict[str, str]] + grid_positions : dict[str, str] | None Mapping of dimension positions, by default None. If ``None`` then an attempt is made to derive this argument. - periodic : Optional[bool] + periodic : bool Whether the grid is periodic, by default False. - extra_init_options : Optional[Dict[str, Any]] + extra_init_options : dict[str, Any] | None Extra options passed to the ``xgcm.Grid`` constructor, by default None. - options : Optional[Dict[str, Any]] + options : dict[str, Any] | None Extra options passed to the ``xgcm.Grid.transform`` method. 
Raises @@ -161,7 +162,7 @@ def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: grid = Grid(ds, coords=grid_coords, **self._extra_init_options) - target_data: Union[str, xr.DataArray, None] = None + target_data: str | xr.DataArray | None = None try: target_data = ds[self._target_data] @@ -209,7 +210,7 @@ def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: return output_ds - def _get_grid_positions(self) -> Dict[str, Union[Any, Hashable]]: + def _get_grid_positions(self) -> dict[str, Any | Hashable]: if self._method == "conservative": raise RuntimeError( "Conservative regridding requires a second point position, pass these " diff --git a/xcdat/spatial.py b/xcdat/spatial.py index 8fdaee84..50dd4dc2 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -1,17 +1,8 @@ """Module containing geospatial averaging functions.""" +from collections.abc import Callable, Hashable from functools import reduce -from typing import ( - Callable, - Dict, - Hashable, - List, - Literal, - Optional, - Tuple, - TypedDict, - get_args, -) +from typing import Literal, TypedDict, get_args import cf_xarray # noqa: F401 import numpy as np @@ -31,12 +22,12 @@ ) #: Type alias for a dictionary of axis keys mapped to their bounds. -AxisWeights = Dict[Hashable, xr.DataArray] +AxisWeights = dict[Hashable, xr.DataArray] #: Type alias for supported spatial axis keys. SpatialAxis = Literal["X", "Y"] -SPATIAL_AXES: Tuple[SpatialAxis, ...] = get_args(SpatialAxis) +SPATIAL_AXES: tuple[SpatialAxis, ...] = get_args(SpatialAxis) #: Type alias for a tuple of floats/ints for the regional selection bounds. -RegionAxisBounds = Tuple[float, float] +RegionAxisBounds = tuple[float, float] @xr.register_dataset_accessor("spatial") @@ -72,7 +63,7 @@ def __init__(self, dataset: xr.Dataset): def average( self, data_var: str, - axis: List[SpatialAxis] | Tuple[SpatialAxis, ...] = ("X", "Y"), + axis: list[SpatialAxis] | tuple[SpatialAxis, ...] 
= ("X", "Y"), weights: Literal["generate"] | xr.DataArray = "generate", keep_weights: bool = False, lat_bounds: RegionAxisBounds | None = None, @@ -105,7 +96,7 @@ def average( data_var: str The name of the data variable inside the dataset to spatially average. - axis : List[SpatialAxis] + axis : list[SpatialAxis] List of axis dimensions to average over, by default ("X", "Y"). Valid axis keys include "X" and "Y". weights : {"generate", xr.DataArray}, optional @@ -226,7 +217,7 @@ def average( def get_weights( self, - axis: List[SpatialAxis] | Tuple[SpatialAxis, ...], + axis: list[SpatialAxis] | tuple[SpatialAxis, ...], lat_bounds: RegionAxisBounds | None = None, lon_bounds: RegionAxisBounds | None = None, data_var: str | None = None, @@ -246,7 +237,7 @@ def get_weights( Parameters ---------- - axis : List[SpatialAxis] | Tuple[SpatialAxis, ...] + axis : list[SpatialAxis] | tuple[SpatialAxis, ...] List of axis dimensions to average over. lat_bounds : RegionAxisBounds | None Tuple of latitude boundaries for regional selection, by default @@ -275,10 +266,10 @@ def get_weights( and pressure). """ Bounds = TypedDict( - "Bounds", {"weights_method": Callable, "region": Optional[np.ndarray]} + "Bounds", {"weights_method": Callable, "region": np.ndarray | None} ) - axis_bounds: Dict[SpatialAxis, Bounds] = { + axis_bounds: dict[SpatialAxis, Bounds] = { "X": { "weights_method": self._get_longitude_weights, "region": np.array(lon_bounds, dtype="float") @@ -315,13 +306,13 @@ def get_weights( return weights - def _validate_axis_arg(self, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...]): + def _validate_axis_arg(self, axis: list[SpatialAxis] | tuple[SpatialAxis, ...]): """ Validates that the ``axis`` dimension(s) exists in the dataset. Parameters ---------- - axis : List[SpatialAxis] | Tuple[SpatialAxis, ...] + axis : list[SpatialAxis] | tuple[SpatialAxis, ...] List of axis dimensions to average over. 
Raises @@ -673,7 +664,7 @@ def _combine_weights(self, axis_weights: AxisWeights) -> xr.DataArray: return region_weights def _validate_weights( - self, data_var: xr.DataArray, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...] + self, data_var: xr.DataArray, axis: list[SpatialAxis] | tuple[SpatialAxis, ...] ): """Validates the ``weights`` arg based on a set of criteria. @@ -686,7 +677,7 @@ def _validate_weights( ---------- data_var : xr.DataArray The data variable used for validation with user supplied weights. - axis : List[SpatialAxis] | Tuple[SpatialAxis, ...] + axis : list[SpatialAxis] | tuple[SpatialAxis, ...] List of axes dimension(s) average over. weights : xr.DataArray A DataArray containing the region area weights for averaging. @@ -725,7 +716,7 @@ def _validate_weights( def _averager( self, data_var: xr.DataArray, - axis: List[SpatialAxis] | Tuple[SpatialAxis, ...], + axis: list[SpatialAxis] | tuple[SpatialAxis, ...], skipna: bool | None = None, min_weight: float = 0.0, ) -> xr.DataArray: @@ -744,7 +735,7 @@ def _averager( ---------- data_var : xr.DataArray Data variable inside a Dataset. - axis : List[SpatialAxis] | Tuple[SpatialAxis, ...] + axis : list[SpatialAxis] | tuple[SpatialAxis, ...] List of axis dimensions to average over. skipna : bool | None, optional If True, skip missing values (as marked by NaN). By default, only @@ -776,7 +767,7 @@ def _averager( dv = data_var.copy() weights = self._weights.fillna(0) - dim: List[str] = [] + dim: list[str] = [] for key in axis: dim.append(get_dim_keys(dv, key)) # type: ignore @@ -794,7 +785,7 @@ def _mask_var_with_weight_threshold( self, dv: xr.DataArray, dv_mean: xr.DataArray, - dim: List[str], + dim: list[str], weights: xr.DataArray, min_weight: float, ) -> xr.DataArray: @@ -813,7 +804,7 @@ def _mask_var_with_weight_threshold( The weighted variable used for getting masked weights. dv_mean : xr.DataArray The average of the weighted variable. 
- dim: List[str]: + dim: list[str]: List of axis dimensions to average over. weights : xr.DataArray A DataArray containing either the regional weights used for weighted diff --git a/xcdat/temporal.py b/xcdat/temporal.py index d6f2a362..223ec474 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -3,7 +3,7 @@ import warnings from datetime import datetime from itertools import chain -from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union, get_args +from typing import Literal, TypedDict, get_args import cf_xarray # noqa: F401 import cftime @@ -37,7 +37,7 @@ DateTimeComponent = Literal["year", "season", "month", "day", "hour"] #: A dictionary mapping temporal averaging mode and frequency to the time groups. -TIME_GROUPS: Dict[Mode, Dict[Frequency, Tuple[DateTimeComponent, ...]]] = { +TIME_GROUPS: dict[Mode, dict[Frequency, tuple[DateTimeComponent, ...]]] = { "average": { "year": ("year",), "month": ("month",), @@ -71,7 +71,7 @@ "drop_incomplete_djf": bool, "drop_incomplete_seasons": bool, "dec_mode": Literal["DJF", "JFD"], - "custom_seasons": Optional[List[List[str]]], + "custom_seasons": list[list[str]] | None, }, total=False, ) @@ -83,7 +83,7 @@ "drop_incomplete_djf": bool, "drop_incomplete_seasons": bool, "dec_mode": Literal["DJF", "JFD"], - "custom_seasons": Optional[Dict[str, List[str]]], + "custom_seasons": dict[str, list[str]] | None, }, total=False, ) @@ -97,7 +97,7 @@ } #: A dictionary mapping month integers to their equivalent 3-letter string. -MONTH_INT_TO_STR: Dict[int, str] = { +MONTH_INT_TO_STR: dict[int, str] = { 1: "Jan", 2: "Feb", 3: "Mar", @@ -116,7 +116,7 @@ # A dictionary mapping pre-defined seasons to their middle month. This # dictionary is used during the creation of datetime objects, which don't # support season values. 
-SEASON_TO_MONTH: Dict[str, int] = {"DJF": 1, "MAM": 4, "JJA": 7, "SON": 10} +SEASON_TO_MONTH: dict[str, int] = {"DJF": 1, "MAM": 4, "JJA": 7, "SON": 10} @xr.register_dataset_accessor("temporal") @@ -347,7 +347,7 @@ def group_average( Xarray labels the season with December as "DJF", but it is actually "JFD". - * "custom_seasons" ([List[List[str]]], by default None) + * "custom_seasons" ([list[list[str]]], by default None) List of sublists containing month strings, with each sublist representing a custom season. @@ -454,7 +454,7 @@ def climatology( freq: Frequency, weighted: bool = True, keep_weights: bool = False, - reference_period: Optional[Tuple[str, str]] = None, + reference_period: tuple[str, str] | None = None, season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG, skipna: bool | None = None, ): @@ -504,7 +504,7 @@ def climatology( keep_weights : bool, optional If calculating averages using weights, keep the weights in the final dataset output, by default False. - reference_period : Optional[Tuple[str, str]], optional + reference_period : tuple[str, str] | None, optional The climatological reference period, which is a subset of the entire time series. This parameter accepts a tuple of strings in the format 'yyyy-mm-dd'. For example, ``('1850-01-01', '1899-12-31')``. If no @@ -547,7 +547,7 @@ def climatology( Xarray labels the season with December as "DJF", but it is actually "JFD". - * "custom_seasons" ([List[List[str]]], by default None) + * "custom_seasons" ([list[list[str]]], by default None) List of sublists containing month strings, with each sublist representing a custom season. 
@@ -659,7 +659,7 @@ def departures( freq: Frequency, weighted: bool = True, keep_weights: bool = False, - reference_period: Optional[Tuple[str, str]] = None, + reference_period: tuple[str, str] | None = None, season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG, skipna: bool | None = None, ) -> xr.Dataset: @@ -715,7 +715,7 @@ def departures( keep_weights : bool, optional If calculating averages using weights, keep the weights in the final dataset output, by default False. - reference_period : Optional[Tuple[str, str]], optional + reference_period : tuple[str, str] | None, optional The climatological reference period, which is a subset of the entire time series and used for calculating departures. This parameter accepts a tuple of strings in the format 'yyyy-mm-dd'. For example, @@ -763,7 +763,7 @@ def departures( Configs for custom seasons: - * "custom_seasons" ([List[List[str]]], by default None) + * "custom_seasons" ([list[list[str]]], by default None) List of sublists containing month strings, with each sublist representing a custom season. @@ -896,7 +896,7 @@ def _averager( freq: Frequency, weighted: bool = True, keep_weights: bool = False, - reference_period: Optional[Tuple[str, str]] = None, + reference_period: tuple[str, str] | None = None, season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG, skipna: bool | None = None, ) -> xr.Dataset: @@ -981,7 +981,7 @@ def _set_arg_attrs( mode: Mode, freq: Frequency, weighted: bool, - reference_period: Optional[Tuple[str, str]] = None, + reference_period: tuple[str, str] | None = None, season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG, ): """Validates method arguments and sets them as object attributes. 
@@ -1072,7 +1072,7 @@ def _set_season_config_attr(self, season_config: SeasonConfigInput): self._season_config["drop_incomplete_djf"] = drop_incomplete_djf - def _is_valid_reference_period(self, reference_period: Tuple[str, str]): + def _is_valid_reference_period(self, reference_period: tuple[str, str]): try: datetime.strptime(reference_period[0], "%Y-%m-%d") datetime.strptime(reference_period[1], "%Y-%m-%d") @@ -1083,7 +1083,7 @@ def _is_valid_reference_period(self, reference_period: Tuple[str, str]): "'1899-12-31')." ) from e - def _form_seasons(self, custom_seasons: List[List[str]]) -> Dict[str, List[str]]: + def _form_seasons(self, custom_seasons: list[list[str]]) -> dict[str, list[str]]: """Forms custom seasons from a nested list of months. This method concatenates the strings in each sublist to form a @@ -1091,13 +1091,13 @@ def _form_seasons(self, custom_seasons: List[List[str]]) -> Dict[str, List[str]] Parameters ---------- - custom_seasons : List[List[str]] + custom_seasons : list[list[str]] List of sublists containing month strings, with each sublist representing a custom season. Returns ------- - Dict[str, List[str]] + dict[str, list[str]] A dictionary with the keys being the custom season and the values being the corresponding list of months. @@ -1200,7 +1200,7 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: return ds def _subset_coords_for_custom_seasons( - self, ds: xr.Dataset, months: List[str] + self, ds: xr.Dataset, months: list[str] ) -> xr.Dataset: """Subsets time coordinates to the months included in custom seasons. @@ -1208,7 +1208,7 @@ def _subset_coords_for_custom_seasons( ---------- ds : xr.Dataset The dataset. - months : List[str] + months : list[str] A list of months included in custom seasons. 
Example: ["Nov", "Dec", "Jan"] @@ -1265,7 +1265,7 @@ def _shift_custom_season_years(self, ds: xr.Dataset) -> xr.Dataset: # Identify months that span across years in custom seasons by getting # the months before "Jan" if "Jan" is not the first month of the season. # Note: Only one custom season can span the calendar year. - span_months: List[int] = [] + span_months: list[int] = [] for months in custom_seasons.values(): # type: ignore month_ints = [MONTH_STR_TO_INT[month] for month in months] @@ -2139,7 +2139,7 @@ def _contains_datetime_like_objects(var: xr.DataArray) -> bool: def _get_datetime_like_type( var: xr.DataArray, -) -> Union[np.datetime64, np.timedelta64, cftime.datetime]: +) -> np.datetime64 | np.timedelta64 | cftime.datetime: """Get the DataArray's object type if they are datetime-like. A variable contains datetime-like objects if they are either @@ -2157,7 +2157,7 @@ def _get_datetime_like_type( Returns ------- - Union[np.datetime64, np.timedelta64, cftime.datetime]: + np.datetime64 | np.timedelta64 | cftime.datetime: """ var_obj = xr.as_variable(var) dtype = var.dtype diff --git a/xcdat/tutorial.py b/xcdat/tutorial.py index b8dce6a7..463122e7 100644 --- a/xcdat/tutorial.py +++ b/xcdat/tutorial.py @@ -8,7 +8,6 @@ import os import pathlib import sys -from typing import Dict, List, Tuple import xarray as xr from xarray.tutorial import _construct_cache_dir, file_formats @@ -21,7 +20,7 @@ version = "main" XARRAY_DATASETS = list(file_formats.keys()) + ["era5-2mt-2019-03-uk.grib"] -XCDAT_DATASETS: Dict[str, str] = { +XCDAT_DATASETS: dict[str, str] = { # Monthly precipitation data from the ACCESS-ESM1-5 model. "pr_amon_access": "pr_Amon_ACCESS-ESM1-5_historical_r10i1p1f1_gn_185001-201412_subset.nc", # Monthly ocean salinity data from the CESM2 model. @@ -45,7 +44,7 @@ def open_dataset( name: str, cache: bool = True, cache_dir: None | str | os.PathLike = DEFAULT_CACHE_DIR_NAME, - add_bounds: List[CFAxisKey] | Tuple[CFAxisKey, ...] 
| None = ("X", "Y"), + add_bounds: list[CFAxisKey] | tuple[CFAxisKey, ...] | None = ("X", "Y"), **kargs, ) -> xr.Dataset: """Open a dataset from the online repository (requires internet). @@ -74,7 +73,7 @@ The directory in which to search for and write cached data. cache : bool, optional If True, then cache data locally for use on subsequent calls - add_bounds : List[CFAxisKey] | Tuple[CFAxisKey] | None, optional + add_bounds : list[CFAxisKey] | tuple[CFAxisKey, ...] | None, optional List or tuple of axis keys for which to add bounds, by default ("X", "Y"). **kargs : dict, optional diff --git a/xcdat/utils.py b/xcdat/utils.py index 46227221..04edf36c 100644 --- a/xcdat/utils.py +++ b/xcdat/utils.py @@ -1,12 +1,11 @@ import importlib import json -from typing import Dict, List, Optional, Union import xarray as xr from dask.array.core import Array -def compare_datasets(ds1: xr.Dataset, ds2: xr.Dataset) -> Dict[str, List[str]]: +def compare_datasets(ds1: xr.Dataset, ds2: xr.Dataset) -> dict[str, list[str]]: """Compares the keys and values of two datasets. This utility function is especially useful for debugging tests that @@ -30,7 +29,7 @@ def compare_datasets(ds1: xr.Dataset, ds2: xr.Dataset) -> Dict[str, List[str]]: Returns ------- - Dict[str, Union[List[str]]] + dict[str, list[str]] A dictionary mapping unique, non-identical, and non-equal keys in both Datasets. """ @@ -111,8 +110,8 @@ def _has_module(modname: str) -> bool: # pragma: no cover def _if_multidim_dask_array_then_load( - obj: Union[xr.DataArray, xr.Dataset], -) -> Optional[Union[xr.DataArray, xr.Dataset]]: + obj: xr.DataArray | xr.Dataset, +) -> xr.DataArray | xr.Dataset | None: """ If the underlying array for an xr.DataArray or xr.Dataset is a multidimensional, lazy Dask Array, load it into an in-memory NumPy array. 
@@ -124,9 +123,9 @@ def _if_multidim_dask_array_then_load( Parameters ---------- - obj : Union[xr.DataArray, xr.Dataset] + obj : xr.DataArray | xr.Dataset The xr.DataArray or xr.Dataset. If the xarray object is chunked, - the underlying array will be a Dask Array. + the underlying array will be a Dask Array. Otherwise, return None. """ if isinstance(obj.data, Array) and obj.ndim > 1: return obj.load() From 694f451d9da8a941f8d4429b3a27c170b77f9f80 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Mon, 12 May 2025 11:46:40 -0700 Subject: [PATCH 09/13] Update .github/workflows/build_workflow.yml Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .github/workflows/build_workflow.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index ef786a2b..cdad7f65 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -39,7 +39,7 @@ jobs: - name: Set up Python 3.12 uses: actions/setup-python@v3 with: - python-version: "3.12 + python-version: "3.12" - name: Install and Run Pre-commit uses: pre-commit/action@v3.0.1 From f031653adcae3d861a2a804fe1edb83f07cca901 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Mon, 12 May 2025 11:47:37 -0700 Subject: [PATCH 10/13] Fix type annotation spelling mistake in `_drop_axis()` --- xcdat/regridder/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xcdat/regridder/base.py b/xcdat/regridder/base.py index 252e71e7..0458abad 100644 --- a/xcdat/regridder/base.py +++ b/xcdat/regridder/base.py @@ -70,7 +70,7 @@ def _drop_axis(ds: xr.Dataset, axis: list[CFAxisKey]) -> xr.Dataset: Returns ------- - xr.Daatset + xr.Dataset The dataset with axis or axes dropped. 
""" dims: list[str] = [] From fd14f6b06c53cdd5da12415e53d9f193abf5fed5 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Mon, 12 May 2025 12:02:56 -0700 Subject: [PATCH 11/13] Remove numpy constraint in build workflow --- .github/workflows/build_workflow.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index cdad7f65..61d66ebf 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -93,8 +93,7 @@ jobs: run: | conda env update -n xcdat_ci -f conda-env/ci.yml # Make sure the Python version in the env matches the current matrix version. - # Make sure numpy is not > 2.0. - conda install -c conda-forge python=${{ matrix.python-version }} "numpy>=1.23.0,<2.0" + conda install -c conda-forge python=${{ matrix.python-version }} - name: Install xcdat # Source: https://github.com/conda/conda-build/issues/4251#issuecomment-1053460542 From 645ff36aab54586559292b83022da47de62dad18 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Mon, 12 May 2025 12:03:31 -0700 Subject: [PATCH 12/13] Update logic to check if index_with_360 is > 0 - Based on GitHub Copilot suggestion --- xcdat/axis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xcdat/axis.py b/xcdat/axis.py index 02597bc4..65782fce 100644 --- a/xcdat/axis.py +++ b/xcdat/axis.py @@ -423,7 +423,7 @@ def _swap_lon_axis(coords: xr.DataArray, to: tuple[float, float]) -> xr.DataArra # Example with 360 coords: [60, 150, 0] -> [60, 150, 360] index_with_360 = np.where(coords == 360) - if len(index_with_360) > 0: + if index_with_360[0].size > 0: _if_multidim_dask_array_then_load(new_coords) new_coords[index_with_360] = 360 From 6f2d00cd32ce37a4d5561a349f4840cc07cfcc81 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Mon, 12 May 2025 12:09:59 -0700 Subject: [PATCH 13/13] Add `setuptools` as dependencies in conda envs --- conda-env/ci.yml | 1 + conda-env/dev.yml | 1 + 2 files changed, 2 insertions(+) 
diff --git a/conda-env/ci.yml b/conda-env/ci.yml index c1d3e0af..8906542d 100644 --- a/conda-env/ci.yml +++ b/conda-env/ci.yml @@ -7,6 +7,7 @@ dependencies: # Base - required for building the package. # ========================================= - python >=3.10 + - setuptools - cf_xarray >=0.9.1 - cftime - dask diff --git a/conda-env/dev.yml b/conda-env/dev.yml index 22bbdbf1..79d5cdd5 100644 --- a/conda-env/dev.yml +++ b/conda-env/dev.yml @@ -7,6 +7,7 @@ dependencies: # Base - required for building the package. # ========================================= - python >=3.10 + - setuptools - cf_xarray >=0.9.1 - cftime - dask