Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Masking out-of-source-domain data points for nearest neighbour remapping #317

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 155 additions & 63 deletions doc/notebooks/Masking.ipynb

Large diffs are not rendered by default.

34 changes: 21 additions & 13 deletions xesmf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,27 +127,35 @@ def from_xarray(cls, lon, lat, periodic=False, mask=None):
lon_pointer[...] = lon
lat_pointer[...] = lat

# Follows SCRIP convention where 1 is unmasked and 0 is masked.
# See https://github.com/NCPP/ocgis/blob/61d88c60e9070215f28c1317221c2e074f8fb145/src/ocgis/regrid/base.py#L391-L404
# Set mask
if mask is not None:
# remove fractional values
mask = np.where(mask == 0, 0, 1)
# convert array type to integer (ESMF compat)
grid_mask = mask.astype(np.int32)
if not (grid_mask.shape == lon.shape):
raise ValueError(
'mask must have the same shape as the latitude/longitude'
'coordinates, got: mask.shape = %s, lon.shape = %s' % (mask.shape, lon.shape)
)
grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER, from_file=False)
grid.mask[0][:] = grid_mask
grid._append_mask(mask)

return grid

def get_shape(self, loc=ESMF.StaggerLoc.CENTER):
"""Return shape of grid for specified StaggerLoc"""
return tuple(self.size[loc])

def _append_mask(self, mask):
"""Append mask to existing ESMF.Grid object."""
# Follows SCRIP convention where 1 is unmasked and 0 is masked.
# See https://github.com/NCPP/ocgis/blob/61d88c60e9070215f28c1317221c2e074f8fb145/src/ocgis/regrid/base.py#L391-L404

# remove fractional values
mask = np.where(mask == 0, 0, 1)
# convert array type to integer (ESMF compat)
grid_mask = mask.astype(np.int32)
if not (grid_mask.shape == self._coords[0][0].shape):
raise ValueError(
'mask must have the same shape as the latitude/longitude'
'coordinates, got: mask.shape = %s, lon.shape = %s'
% (mask.shape, self._coords[0][0].shape)
)
if self.mask[0] is None:
self.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER, from_file=False)
self.mask[0][:] = grid_mask


class LocStream(ESMF.LocStream):
@classmethod
Expand Down
51 changes: 51 additions & 0 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
add_nans_to_weights,
apply_weights,
check_shapes,
gen_mask_from_weights,
read_weights,
)
from .util import LAT_CF_ATTRS, LON_CF_ATTRS, split_polygons_and_holes
Expand Down Expand Up @@ -757,6 +758,7 @@ def __init__(
locstream_out=False,
periodic=False,
parallel=False,
nearest_s2d_domain_mask=False,
**kwargs,
):
"""
Expand Down Expand Up @@ -850,6 +852,13 @@ def __init__(
If an output mask is defined, or regridding method is `nearest_s2d` or `nearest_d2s`,
this option has no effect.

nearest_s2d_domain_mask: boolean, optional
When remapping to a larger domain with the method `nearest_s2d`, data points outside the
original domain will - by the nature of this method - be extrapolated, i. e. have values
from the edge of the original domain. This option will apply a "domain mask" that is generated by creating
remapping weights with the `bilinear` method while having `unmapped_to_nan` activated.
This option is not supported for locstream formatted input or output.

Returns
-------
regridder : xESMF regridder object
Expand Down Expand Up @@ -902,6 +911,48 @@ def __init__(
else:
grid_out, shape_out, output_dims = ds_to_ESMFgrid(ds_out, need_bounds=need_bounds)

# nearest_s2d domain mask
if nearest_s2d_domain_mask:
if method == 'nearest_s2d':
if locstream_out or locstream_in:
raise ValueError(
'The option \'nearest_s2d_domain_mask\' is not supported for locstream input or output.'
)
if reuse_weights:
warnings.warn(
'The option \'nearest_s2d_domain_mask\' will have no effect when reusing weights. '
'Please make sure instead, that the weights to be read from disk have been created '
'with this option enabled.'
)
if parallel:
warnings.warn(
'The \'parallel\' setting will not affect the extra generation of bilinear '
'remapping weights that is required to generate the domain mask invoked by your '
'setting for \'nearest_s2d_domain_mask\'.'
)
# Create the BaseRegridder for the bilinear weights
DomainMaskRegridder = BaseRegridder(
grid_in=grid_in,
grid_out=grid_out,
method='bilinear',
input_dims=input_dims,
output_dims=output_dims,
parallel=False,
unmapped_to_nan=True,
ignore_degenerate=kwargs.get('ignore_degenerate', None),
)
# Generate the output mask out of these weights
mask_out = gen_mask_from_weights(
DomainMaskRegridder.weights, nlat=int(shape_out[0]), nlon=int(shape_out[1])
).T
# Append the mask to the output grid
grid_out._append_mask(mask_out)
else:
warnings.warn(
'The option \'nearest_s2d_domain_mask\' will only have an effect when using the '
'remapping method \'nearest_s2d\'.'
)

# Create the BaseRegridder
super().__init__(
grid_in,
Expand Down
29 changes: 29 additions & 0 deletions xesmf/smm.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,35 @@ def add_nans_to_weights(weights):
return weights


def gen_mask_from_weights(weights, nlat, nlon):
"""Generate a 2D mask from the regridding weights sparse matrix.

This function will generate a 2D binary mask out of a regridding weights sparse matrix.

Parameters
----------
weights : DataArray backed by a sparse.COO array
Sparse weights matrix.

Returns
-------
numpy.ndarray of type numpy.int32 and of shape (nlat, nlon)
Binary mask.
"""
# Taken from @trondkr and adapted by @raphaeldussin to use `lil`.
# lil matrix is better than CSR when changing sparsity
m = weights.data.to_scipy_sparse().tolil()

# Create mask ndarray of ones and fill with 0-elements
mask = np.ones((nlat, nlon), dtype=np.int32).ravel()
for krow in range(len(m.rows)):
if any([np.isnan(x) for x in m.data[krow]]):
mask[krow] = 0

# Reshape and return
return mask.reshape((nlat, nlon))


def _combine_weight_multipoly(weights, areas, indexes):
"""Reduce a weight sparse matrix (csc format) by combining (adding) columns.

Expand Down
26 changes: 26 additions & 0 deletions xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,32 @@ def test_build_regridder_with_masks():
assert method in str(regridder)


def test_nearest_s2d_domain_mask():
# Create input and output grid
ds_in = xe.util.grid_2d(50, 60, 1, 40, 50, 1)
ds_out = xe.util.grid_2d(40, 70, 1, 30, 50, 1)
# Create input data
ds_in['data'] = xr.DataArray(data=np.ones((10, 10), dtype=np.float32), dims=['lat', 'lon'])

# Create remapping weights
regridder = xe.Regridder(ds_in, ds_out, method='nearest_s2d', nearest_s2d_domain_mask=True)
ds_out['data'] = regridder(ds_in['data'])

# Create expected output mask
mask = np.ones((20, 30), dtype=np.int32)
mask[:, 0:10] = 0
mask[:, 20:30] = 0
mask[0:10, :] = 0
# Generate mask from weights
maskwgts = xe.smm.gen_mask_from_weights(regridder.weights, 20, 30)
# Generate mask from remapped data
maskdata = np.where(np.isnan(ds_out['data'].data), 0, 1)

# Assert equality of the masks
assert np.array_equal(maskdata, maskwgts, equal_nan=False)
assert np.array_equal(maskdata, mask, equal_nan=False)


def test_regrid_dataset_from_locstream():
# xarray.Dataset containing in-memory numpy array

Expand Down
20 changes: 20 additions & 0 deletions xesmf/tests/test_smm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,23 @@ def test_add_nans_to_weights():

Matout = xe.smm.add_nans_to_weights(xr.DataArray(Matin, dims=('in', 'out')))
assert np.allclose(Matin.todense(), Matout.data.todense())


def test_gen_mask_from_weights():
"""testing creating mask out of weight matrix Nans"""
# Create input and output Dataset
ds_in = xe.util.grid_2d(20, 40, 1, 20, 30, 1)
ds_out = xe.util.grid_2d(20, 40, 2, 20, 30, 2)

# Create random mask for ds_out
mask = np.random.randint(low=0, high=2, size=(5, 10), dtype=np.int32)
ds_out['mask'] = xr.DataArray(data=mask, dims=['lat', 'lon'])

# Create remapping weights
Weights = xe.Regridder(ds_in, ds_out, method='bilinear').weights

# Generate mask from weights
maskwgts = xe.smm.gen_mask_from_weights(Weights, 5, 10)

# Assert equality between both masks
assert np.array_equal(mask, maskwgts, equal_nan=False)
Loading