Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pole_kind to grids #236

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
What's new
==========

0.7.1 (unreleased)
0.8.0 (unreleased)
------------------

New features
~~~~~~~~~~~
- Expose ESMF capability to use ``pole_kind`` to specify monopolar or bipolar grid types, useful for regridding tripolar ocean grids. By `Benjamin Cash <https://github.com/benjamin-cash>`_.

Bug fixes
~~~~~~~~~
- Fix ``Mesh.from_polygons`` to support ``shapely`` 2.0. By `Pascal Bourgault <https://github.com/aulemahal>`_.
Expand Down
25 changes: 22 additions & 3 deletions xesmf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def warn_lat_range(lat):

class Grid(ESMF.Grid):
@classmethod
def from_xarray(cls, lon, lat, periodic=False, mask=None):
def from_xarray(cls, lon, lat, periodic=False, mask=None, pole_kind=None):
"""
Create an ESMF.Grid object, for constructing ESMF.Field and ESMF.Regrid.

Expand All @@ -83,6 +83,17 @@ def from_xarray(cls, lon, lat, periodic=False, mask=None):
Shape should be ``(Nlon, Nlat)`` for rectilinear grid,
or ``(Nx, Ny)`` for general quadrilateral grid.

pole_kind : [int, int] or None
Two item list which specifies the type of connection which occurs at the pole.
The first value specifies the connection that occurs at the minimum end of the
pole dimension. The second value specifies the connection that occurs at the
maximum end of the pole dimension. Options are 0 (no connections at pole),
1 (monopole, this edge is connected to itself. Given that the edge is n elements long,
then element i is connected to element i+n/2), and 2 (bipole, this edge is connected
to itself. Given that the edge is n elements long, element i is connected to element n-i-1.
If None, defaults to [1,1] for monopole connections. See :attr:`ESMF.api.constants.PoleKind`.
Requires ESMF >= 8.0.1

Returns
-------
grid : ESMF.Grid object
Expand Down Expand Up @@ -111,13 +122,21 @@ def from_xarray(cls, lon, lat, periodic=False, mask=None):
# they will be set to default values (CENTER and SPH_DEG).
# However, they actually need to be set explicitly,
# otherwise grid._coord_sys and grid._staggerloc will still be None.
grid = cls(
np.array(lon.shape),
kwds = dict(
staggerloc=staggerloc,
coord_sys=ESMF.CoordSys.SPH_DEG,
num_peri_dims=num_peri_dims,
pole_kind=pole_kind,
)

# `pole_kind` option supported since 8.0.1
if ESMF.__version__ < '8.0.1':
if pole_kind is not None:
raise ValueError('The `pole_kind` option requires esmpy >= 8.0.1')
kwds.pop('pole_kind')

grid = cls(np.array(lon.shape), **kwds)

# The grid object points to the underlying Fortran arrays in ESMF.
# To modify lat/lon coordinates, need to get pointers to them
lon_pointer = grid.get_coords(coord_dim=0, staggerloc=staggerloc)
Expand Down
9 changes: 7 additions & 2 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,16 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):
else:
mask = None

if 'pole_kind' in ds:
pole_kind = np.asarray(ds['pole_kind'])
else:
pole_kind = None

# tranpose the arrays so they become Fortran-ordered
if mask is not None:
grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=mask.T)
grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=mask.T, pole_kind=pole_kind)
else:
grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=None)
grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=None, pole_kind=pole_kind)

if need_bounds:
lon_b, lat_b = _get_lon_lat_bounds(ds)
Expand Down
46 changes: 46 additions & 0 deletions xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,3 +854,49 @@ def test_spatial_averager_mask():
savg = xe.SpatialAverager(dsm, [poly], geom_dim_name='my_geom')
out = savg(dsm.abc)
assert_allclose(out, 2, rtol=1e-3)


def test_regrid_polekind():

# Open tripole SST file
ds_in = xr.open_dataset('mom6_tripole_SST.nc')

# Open input grid specification
ds_ingrid = xr.open_dataset('grid_spec.nc')
ds_sst_grid = ds_ingrid.rename({'geolat': 'lat', 'geolon': 'lon'})
ds_sst_grid['mask'] = ds_ingrid['wet']

# Get MOM6 mask
ds_ingrid['mask'] = ds_ingrid['wet']

# Open output grid specification
ds_outgrid = xr.open_dataset('C384_gaussian_grid.nc')

# Get C384 land-sea mask
ds_outgrid['mask'] = 1 - ds_outgrid['land'].where(ds_outgrid['land'] < 2.0).squeeze()

# Create regridder without specifying pole kind
base_regrid = xe.Regridder(ds_sst_grid, ds_outgrid, 'bilinear', periodic=True)
base_result = base_regrid(ds_in['SST'])

# Add monopole grid information. 1 denotes monopole, 2 bipole
ds_sst_grid['pole_kind'] = np.array([1, 1])
ds_outgrid['pole_kind'] = np.array([1, 1])

monopole_regrid = xe.Regridder(ds_sst_grid, ds_outgrid, 'bilinear', periodic=True)
monopole_result = monopole_regrid(ds_in['SST'])

# Check behavior unchanged
assert monopole_result.equals(base_result)

# Add bipole grid information
ds_sst_grid['pole_kind'] = np.array([1, 2], np.int32)
bipole_regrid = xe.Regridder(ds_sst_grid, ds_outgrid, 'bilinear', periodic=True)
bipole_result = bipole_regrid(ds_in['SST'])

# Confirm results have changed
assert not bipole_result.equals(monopole_result)

# Confirm results match saved values
verif_in = xr.open_dataset('verify_bipole_regrid_SST.nc')['SST']
assert bipole_result.equals(verif_in)