Skip to content

Commit c511cc5

Browse files
authored
Add geometry encoding and decoding functions. (#517)
* Add geometry encoding and decoding functions.
  These differ from `shapely_to_cf` and `cf_to_shapely` by returning all variables. Those functions simply encode and decode geometry-related variables.
* More geometry functionality
  1. Associate geometry vars as coordinates when we can
  2. Add `cf.geometries`
  3. Geometries in repr
  4. Allow indexing by `"geometry"` or any geometry type.
1 parent 0db5233 commit c511cc5

12 files changed

+550
-108
lines changed

cf_xarray/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@
99
from .options import set_options # noqa
1010
from .utils import _get_version
1111

12+
from . import geometry # noqa
13+
1214
__version__ = _get_version()

cf_xarray/accessor.py

+111-29
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@
2828
from . import sgrid
2929
from .criteria import (
3030
_DSG_ROLES,
31+
_GEOMETRY_TYPES,
3132
cf_role_criteria,
3233
coordinate_criteria,
34+
geometry_var_criteria,
3335
grid_mapping_var_criteria,
3436
regex,
3537
)
@@ -39,6 +41,7 @@
3941
_format_data_vars,
4042
_format_dsg_roles,
4143
_format_flags,
44+
_format_geometries,
4245
_format_sgrid,
4346
_maybe_panel,
4447
)
@@ -198,7 +201,9 @@ def _get_groupby_time_accessor(
198201

199202

200203
def _get_custom_criteria(
201-
obj: DataArray | Dataset, key: Hashable, criteria: Mapping | None = None
204+
obj: DataArray | Dataset,
205+
key: Hashable,
206+
criteria: Iterable[Mapping] | Mapping | None = None,
202207
) -> list[Hashable]:
203208
"""
204209
Translate from axis, coord, or custom name to variable name.
@@ -227,18 +232,16 @@ def _get_custom_criteria(
227232
except ImportError:
228233
from re import match as regex_match # type: ignore[no-redef]
229234

230-
if isinstance(obj, DataArray):
231-
obj = obj._to_temp_dataset()
232-
variables = obj._variables
233-
234235
if criteria is None:
235236
if not OPTIONS["custom_criteria"]:
236237
return []
237238
criteria = OPTIONS["custom_criteria"]
238239

239-
if criteria is not None:
240-
criteria_iter = always_iterable(criteria, allowed=(tuple, list, set))
240+
if isinstance(obj, DataArray):
241+
obj = obj._to_temp_dataset()
242+
variables = obj._variables
241243

244+
criteria_iter = always_iterable(criteria, allowed=(tuple, list, set))
242245
criteria_map = ChainMap(*criteria_iter)
243246
results: set = set()
244247
if key in criteria_map:
@@ -367,6 +370,21 @@ def _get_measure(obj: DataArray | Dataset, key: str) -> list[str]:
367370
return list(results)
368371

369372

373+
def _parse_related_geometry_vars(attrs: Mapping) -> tuple[Hashable]:
374+
names = itertools.chain(
375+
*[
376+
attrs.get(attr, "").split(" ")
377+
for attr in [
378+
"interior_ring",
379+
"node_coordinates",
380+
"node_count",
381+
"part_node_count",
382+
]
383+
]
384+
)
385+
return tuple(n for n in names if n)
386+
387+
370388
def _get_bounds(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
371389
"""
372390
Translate from key (either CF key or variable name) to its bounds' variable names.
@@ -470,8 +488,14 @@ def _get_all(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
470488
"""
471489
all_mappers: tuple[Mapper] = (
472490
_get_custom_criteria,
473-
functools.partial(_get_custom_criteria, criteria=cf_role_criteria), # type: ignore[assignment]
474-
functools.partial(_get_custom_criteria, criteria=grid_mapping_var_criteria),
491+
functools.partial(
492+
_get_custom_criteria,
493+
criteria=(
494+
cf_role_criteria,
495+
grid_mapping_var_criteria,
496+
geometry_var_criteria,
497+
),
498+
), # type: ignore[assignment]
475499
_get_axis_coord,
476500
_get_measure,
477501
_get_grid_mapping_name,
@@ -821,6 +845,23 @@ def check_results(names, key):
821845
successful[k] = bool(grid_mapping)
822846
if grid_mapping:
823847
varnames.extend(grid_mapping)
848+
elif "geometries" not in skip and (k == "geometry" or k in _GEOMETRY_TYPES):
849+
geometries = _get_all(obj, k)
850+
if geometries and k in _GEOMETRY_TYPES:
851+
new = itertools.chain(
852+
_parse_related_geometry_vars(
853+
ChainMap(obj[g].attrs, obj[g].encoding)
854+
)
855+
for g in geometries
856+
)
857+
geometries.extend(*new)
858+
if len(geometries) > 1 and scalar_key:
859+
raise ValueError(
860+
f"CF geometries must be represented by an Xarray Dataset. To request a Dataset in return please pass `[{k!r}]` instead."
861+
)
862+
successful[k] = bool(geometries)
863+
if geometries:
864+
varnames.extend(geometries)
824865
elif k in custom_criteria or k in cf_role_criteria:
825866
names = _get_all(obj, k)
826867
check_results(names, k)
@@ -1559,8 +1600,7 @@ def _generate_repr(self, rich=False):
15591600
_format_flags(self, rich), title="Flag Variable", rich=rich
15601601
)
15611602

1562-
roles = self.cf_roles
1563-
if roles:
1603+
if roles := self.cf_roles:
15641604
if any(role in roles for role in _DSG_ROLES):
15651605
yield _maybe_panel(
15661606
_format_dsg_roles(self, dims, rich),
@@ -1576,6 +1616,13 @@ def _generate_repr(self, rich=False):
15761616
rich=rich,
15771617
)
15781618

1619+
if self.geometries:
1620+
yield _maybe_panel(
1621+
_format_geometries(self, dims, rich),
1622+
title="Geometries",
1623+
rich=rich,
1624+
)
1625+
15791626
yield _maybe_panel(
15801627
_format_coordinates(self, dims, coords, rich),
15811628
title="Coordinates",
@@ -1755,12 +1802,42 @@ def cf_roles(self) -> dict[str, list[Hashable]]:
17551802

17561803
vardict: dict[str, list[Hashable]] = {}
17571804
for k, v in variables.items():
1758-
if "cf_role" in v.attrs:
1759-
role = v.attrs["cf_role"]
1805+
attrs_or_encoding = ChainMap(v.attrs, v.encoding)
1806+
if role := attrs_or_encoding.get("cf_role", None):
17601807
vardict[role] = vardict.setdefault(role, []) + [k]
17611808

17621809
return {role_: sort_maybe_hashable(v) for role_, v in vardict.items()}
17631810

1811+
@property
1812+
def geometries(self) -> dict[str, list[Hashable]]:
1813+
"""
1814+
Mapping geometry type names to variable names.
1815+
1816+
Returns
1817+
-------
1818+
dict
1819+
Dictionary mapping geometry names to variable names.
1820+
1821+
References
1822+
----------
1823+
Please refer to the CF conventions document : http://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#coordinates-metadata
1824+
"""
1825+
vardict: dict[str, list[Hashable]] = {}
1826+
1827+
if isinstance(self._obj, Dataset):
1828+
variables = self._obj._variables
1829+
elif isinstance(self._obj, DataArray):
1830+
variables = {"_": self._obj._variable}
1831+
1832+
for v in variables.values():
1833+
attrs_or_encoding = ChainMap(v.attrs, v.encoding)
1834+
if geometry := attrs_or_encoding.get("geometry", None):
1835+
gtype = self._obj[geometry].attrs["geometry_type"]
1836+
vardict.setdefault(gtype, [])
1837+
if geometry not in vardict[gtype]:
1838+
vardict[gtype] += [geometry]
1839+
return {type_: sort_maybe_hashable(v) for type_, v in vardict.items()}
1840+
17641841
def get_associated_variable_names(
17651842
self, name: Hashable, skip_bounds: bool = False, error: bool = True
17661843
) -> dict[str, list[Hashable]]:
@@ -1795,15 +1872,15 @@ def get_associated_variable_names(
17951872
"bounds",
17961873
"grid_mapping",
17971874
"grid",
1875+
"geometry",
17981876
]
17991877

18001878
coords: dict[str, list[Hashable]] = {k: [] for k in keys}
18011879
attrs_or_encoding = ChainMap(self._obj[name].attrs, self._obj[name].encoding)
18021880

1803-
coordinates = attrs_or_encoding.get("coordinates", None)
18041881
# Handles case where the coordinates attribute is None
18051882
# This is used to tell xarray to not write a coordinates attribute
1806-
if coordinates:
1883+
if coordinates := attrs_or_encoding.get("coordinates", None):
18071884
coords["coordinates"] = coordinates.split(" ")
18081885

18091886
if "cell_measures" in attrs_or_encoding:
@@ -1822,27 +1899,32 @@ def get_associated_variable_names(
18221899
)
18231900
coords["cell_measures"] = []
18241901

1825-
if (
1826-
isinstance(self._obj, Dataset)
1827-
and "ancillary_variables" in attrs_or_encoding
1902+
if isinstance(self._obj, Dataset) and (
1903+
anc := attrs_or_encoding.get("ancillary_variables", None)
18281904
):
1829-
coords["ancillary_variables"] = attrs_or_encoding[
1830-
"ancillary_variables"
1831-
].split(" ")
1905+
coords["ancillary_variables"] = anc.split(" ")
18321906

18331907
if not skip_bounds:
1834-
if "bounds" in attrs_or_encoding:
1835-
coords["bounds"] = [attrs_or_encoding["bounds"]]
1908+
if bounds := attrs_or_encoding.get("bounds", None):
1909+
coords["bounds"] = [bounds]
18361910
for dim in self._obj[name].dims:
1837-
dbounds = self._obj[dim].attrs.get("bounds", None)
1838-
if dbounds:
1911+
if dbounds := self._obj[dim].attrs.get("bounds", None):
18391912
coords["bounds"].append(dbounds)
18401913

1841-
if "grid" in attrs_or_encoding:
1842-
coords["grid"] = [attrs_or_encoding["grid"]]
1914+
for attrname in ["grid", "grid_mapping"]:
1915+
if maybe := attrs_or_encoding.get(attrname, None):
1916+
coords[attrname] = [maybe]
18431917

1844-
if "grid_mapping" in attrs_or_encoding:
1845-
coords["grid_mapping"] = [attrs_or_encoding["grid_mapping"]]
1918+
more: Sequence[Hashable] = ()
1919+
if geometry_var := attrs_or_encoding.get("geometry", None):
1920+
coords["geometry"] = [geometry_var]
1921+
_attrs = ChainMap(
1922+
self._obj[geometry_var].attrs, self._obj[geometry_var].encoding
1923+
)
1924+
more = _parse_related_geometry_vars(_attrs)
1925+
elif "geometry_type" in attrs_or_encoding:
1926+
more = _parse_related_geometry_vars(attrs_or_encoding)
1927+
coords["geometry"].extend(more)
18461928

18471929
allvars = itertools.chain(*coords.values())
18481930
missing = set(allvars) - set(self._maybe_to_dataset()._variables)

cf_xarray/criteria.py

+9
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from collections.abc import Mapping, MutableMapping
1313
from typing import Any
1414

15+
#: CF Roles understood by cf-xarray
1516
_DSG_ROLES = ["timeseries_id", "profile_id", "trajectory_id"]
17+
#: Geometry types understood by cf-xarray
18+
_GEOMETRY_TYPES = ("line", "point", "polygon")
1619

1720
cf_role_criteria: Mapping[str, Mapping[str, str]] = {
1821
k: {"cf_role": k}
@@ -31,6 +34,12 @@
3134
"grid_mapping": {"grid_mapping_name": re.compile(".")}
3235
}
3336

37+
# A geometry container is anything with a geometry_type attribute
38+
geometry_var_criteria: Mapping[str, Mapping[str, Any]] = {
39+
"geometry": {"geometry_type": re.compile(".")},
40+
**{k: {"geometry_type": k} for k in _GEOMETRY_TYPES},
41+
}
42+
3443
coordinate_criteria: MutableMapping[str, MutableMapping[str, tuple]] = {
3544
"latitude": {
3645
"standard_name": ("latitude",),

cf_xarray/datasets.py

+29
Original file line numberDiff line numberDiff line change
@@ -748,3 +748,32 @@ def _create_inexact_bounds():
748748
node_coordinates="node_lon node_lat node_elevation",
749749
),
750750
)
751+
752+
753+
def point_dataset():
754+
from shapely.geometry import MultiPoint, Point
755+
756+
da = xr.DataArray(
757+
[
758+
MultiPoint([(1.0, 2.0), (2.0, 3.0)]),
759+
Point(3.0, 4.0),
760+
Point(4.0, 5.0),
761+
Point(3.0, 4.0),
762+
],
763+
dims=("index",),
764+
name="geometry",
765+
)
766+
ds = da.to_dataset()
767+
return ds
768+
769+
770+
def encoded_point_dataset():
771+
from .geometry import encode_geometries
772+
773+
ds = encode_geometries(point_dataset())
774+
ds["data"] = (
775+
"index",
776+
np.arange(ds.sizes["index"]),
777+
{"geometry": "geometry_container"},
778+
)
779+
return ds

cf_xarray/formatting.py

+11
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,17 @@ def _format_dsg_roles(accessor, dims, rich):
295295
)
296296

297297

298+
def _format_geometries(accessor, dims, rich):
299+
yield make_text_section(
300+
accessor,
301+
"CF Geometries",
302+
"geometries",
303+
dims=dims,
304+
# valid_keys=_DSG_ROLES,
305+
rich=rich,
306+
)
307+
308+
298309
def _format_coordinates(accessor, dims, coords, rich):
299310
from .accessor import _AXIS_NAMES, _CELL_MEASURES, _COORD_NAMES
300311

0 commit comments

Comments
 (0)