Skip to content

Commit a3a2c66

Browse files
committed
More geometry functionality
1. Associate geometry vars as coordinates when we can 2. Add `cf.geometries` 3. Geometries in repr 4. Allow indexing by `"geometry"` or any geometry type.
1 parent aa41dce commit a3a2c66

9 files changed

+282
-95
lines changed

cf_xarray/accessor.py

+106-28
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@
2828
from . import sgrid
2929
from .criteria import (
3030
_DSG_ROLES,
31+
_GEOMETRY_TYPES,
3132
cf_role_criteria,
3233
coordinate_criteria,
34+
geometry_var_criteria,
3335
grid_mapping_var_criteria,
3436
regex,
3537
)
@@ -39,6 +41,7 @@
3941
_format_data_vars,
4042
_format_dsg_roles,
4143
_format_flags,
44+
_format_geometries,
4245
_format_sgrid,
4346
_maybe_panel,
4447
)
@@ -227,18 +230,16 @@ def _get_custom_criteria(
227230
except ImportError:
228231
from re import match as regex_match # type: ignore[no-redef]
229232

230-
if isinstance(obj, DataArray):
231-
obj = obj._to_temp_dataset()
232-
variables = obj._variables
233-
234233
if criteria is None:
235234
if not OPTIONS["custom_criteria"]:
236235
return []
237236
criteria = OPTIONS["custom_criteria"]
238237

239-
if criteria is not None:
240-
criteria_iter = always_iterable(criteria, allowed=(tuple, list, set))
238+
if isinstance(obj, DataArray):
239+
obj = obj._to_temp_dataset()
240+
variables = obj._variables
241241

242+
criteria_iter = always_iterable(criteria, allowed=(tuple, list, set))
242243
criteria_map = ChainMap(*criteria_iter)
243244
results: set = set()
244245
if key in criteria_map:
@@ -367,6 +368,21 @@ def _get_measure(obj: DataArray | Dataset, key: str) -> list[str]:
367368
return list(results)
368369

369370

371+
def _parse_related_geometry_vars(attrs: Mapping) -> tuple[Hashable]:
372+
names = itertools.chain(
373+
*[
374+
attrs.get(attr, "").split(" ")
375+
for attr in [
376+
"interior_ring",
377+
"node_coordinates",
378+
"node_count",
379+
"part_node_count",
380+
]
381+
]
382+
)
383+
return tuple(n for n in names if n)
384+
385+
370386
def _get_bounds(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
371387
"""
372388
Translate from key (either CF key or variable name) to its bounds' variable names.
@@ -470,8 +486,12 @@ def _get_all(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
470486
"""
471487
all_mappers: tuple[Mapper] = (
472488
_get_custom_criteria,
473-
functools.partial(_get_custom_criteria, criteria=cf_role_criteria), # type: ignore[assignment]
474-
functools.partial(_get_custom_criteria, criteria=grid_mapping_var_criteria),
489+
functools.partial(
490+
_get_custom_criteria,
491+
criteria=ChainMap(
492+
cf_role_criteria, grid_mapping_var_criteria, geometry_var_criteria
493+
),
494+
),
475495
_get_axis_coord,
476496
_get_measure,
477497
_get_grid_mapping_name,
@@ -821,6 +841,23 @@ def check_results(names, key):
821841
successful[k] = bool(grid_mapping)
822842
if grid_mapping:
823843
varnames.extend(grid_mapping)
844+
elif "geometries" not in skip and (k == "geometry" or k in _GEOMETRY_TYPES):
845+
geometries = _get_all(obj, k)
846+
if geometries and k in _GEOMETRY_TYPES:
847+
new = itertools.chain(
848+
_parse_related_geometry_vars(
849+
ChainMap(obj[g].attrs, obj[g].encoding)
850+
)
851+
for g in geometries
852+
)
853+
geometries.extend(*new)
854+
if len(geometries) > 1 and scalar_key:
855+
raise ValueError(
856+
f"CF geometries must be represented by an Xarray Dataset. To request a Dataset in return please pass `[{k!r}]` instead."
857+
)
858+
successful[k] = bool(geometries)
859+
if geometries:
860+
varnames.extend(geometries)
824861
elif k in custom_criteria or k in cf_role_criteria:
825862
names = _get_all(obj, k)
826863
check_results(names, k)
@@ -1559,8 +1596,7 @@ def _generate_repr(self, rich=False):
15591596
_format_flags(self, rich), title="Flag Variable", rich=rich
15601597
)
15611598

1562-
roles = self.cf_roles
1563-
if roles:
1599+
if roles := self.cf_roles:
15641600
if any(role in roles for role in _DSG_ROLES):
15651601
yield _maybe_panel(
15661602
_format_dsg_roles(self, dims, rich),
@@ -1576,6 +1612,13 @@ def _generate_repr(self, rich=False):
15761612
rich=rich,
15771613
)
15781614

1615+
if self.geometries:
1616+
yield _maybe_panel(
1617+
_format_geometries(self, dims, rich),
1618+
title="Geometries",
1619+
rich=rich,
1620+
)
1621+
15791622
yield _maybe_panel(
15801623
_format_coordinates(self, dims, coords, rich),
15811624
title="Coordinates",
@@ -1755,12 +1798,42 @@ def cf_roles(self) -> dict[str, list[Hashable]]:
17551798

17561799
vardict: dict[str, list[Hashable]] = {}
17571800
for k, v in variables.items():
1758-
if "cf_role" in v.attrs:
1759-
role = v.attrs["cf_role"]
1801+
attrs_or_encoding = ChainMap(v.attrs, v.encoding)
1802+
if role := attrs_or_encoding.get("cf_role", None):
17601803
vardict[role] = vardict.setdefault(role, []) + [k]
17611804

17621805
return {role_: sort_maybe_hashable(v) for role_, v in vardict.items()}
17631806

1807+
@property
1808+
def geometries(self) -> dict[str, list[Hashable]]:
1809+
"""
1810+
Mapping geometry type names to variable names.
1811+
1812+
Returns
1813+
-------
1814+
dict
1815+
Dictionary mapping geometry names to variable names.
1816+
1817+
References
1818+
----------
1819+
Please refer to the CF conventions document : http://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#coordinates-metadata
1820+
"""
1821+
vardict: dict[str, list[Hashable]] = {}
1822+
1823+
if isinstance(self._obj, Dataset):
1824+
variables = self._obj._variables
1825+
elif isinstance(self._obj, DataArray):
1826+
variables = {"_": self._obj._variable}
1827+
1828+
for v in variables.values():
1829+
attrs_or_encoding = ChainMap(v.attrs, v.encoding)
1830+
if geometry := attrs_or_encoding.get("geometry", None):
1831+
gtype = self._obj[geometry].attrs["geometry_type"]
1832+
vardict.setdefault(gtype, [])
1833+
if geometry not in vardict[gtype]:
1834+
vardict[gtype] += [geometry]
1835+
return {type_: sort_maybe_hashable(v) for type_, v in vardict.items()}
1836+
17641837
def get_associated_variable_names(
17651838
self, name: Hashable, skip_bounds: bool = False, error: bool = True
17661839
) -> dict[str, list[Hashable]]:
@@ -1795,15 +1868,15 @@ def get_associated_variable_names(
17951868
"bounds",
17961869
"grid_mapping",
17971870
"grid",
1871+
"geometry",
17981872
]
17991873

18001874
coords: dict[str, list[Hashable]] = {k: [] for k in keys}
18011875
attrs_or_encoding = ChainMap(self._obj[name].attrs, self._obj[name].encoding)
18021876

1803-
coordinates = attrs_or_encoding.get("coordinates", None)
18041877
# Handles case where the coordinates attribute is None
18051878
# This is used to tell xarray to not write a coordinates attribute
1806-
if coordinates:
1879+
if coordinates := attrs_or_encoding.get("coordinates", None):
18071880
coords["coordinates"] = coordinates.split(" ")
18081881

18091882
if "cell_measures" in attrs_or_encoding:
@@ -1822,27 +1895,32 @@ def get_associated_variable_names(
18221895
)
18231896
coords["cell_measures"] = []
18241897

1825-
if (
1826-
isinstance(self._obj, Dataset)
1827-
and "ancillary_variables" in attrs_or_encoding
1898+
if isinstance(self._obj, Dataset) and (
1899+
anc := attrs_or_encoding.get("ancillary_variables", None)
18281900
):
1829-
coords["ancillary_variables"] = attrs_or_encoding[
1830-
"ancillary_variables"
1831-
].split(" ")
1901+
coords["ancillary_variables"] = anc.split(" ")
18321902

18331903
if not skip_bounds:
1834-
if "bounds" in attrs_or_encoding:
1835-
coords["bounds"] = [attrs_or_encoding["bounds"]]
1904+
if bounds := attrs_or_encoding.get("bounds", None):
1905+
coords["bounds"] = [bounds]
18361906
for dim in self._obj[name].dims:
1837-
dbounds = self._obj[dim].attrs.get("bounds", None)
1838-
if dbounds:
1907+
if dbounds := self._obj[dim].attrs.get("bounds", None):
18391908
coords["bounds"].append(dbounds)
18401909

1841-
if "grid" in attrs_or_encoding:
1842-
coords["grid"] = [attrs_or_encoding["grid"]]
1910+
for attrname in ["grid", "grid_mapping"]:
1911+
if maybe := attrs_or_encoding.get(attrname, None):
1912+
coords[attrname] = [maybe]
18431913

1844-
if "grid_mapping" in attrs_or_encoding:
1845-
coords["grid_mapping"] = [attrs_or_encoding["grid_mapping"]]
1914+
more = []
1915+
if geometry_var := attrs_or_encoding.get("geometry", None):
1916+
coords["geometry"] = [geometry_var]
1917+
_attrs = ChainMap(
1918+
self._obj[geometry_var].attrs, self._obj[geometry_var].encoding
1919+
)
1920+
more = _parse_related_geometry_vars(_attrs)
1921+
elif "geometry_type" in attrs_or_encoding:
1922+
more = _parse_related_geometry_vars(attrs_or_encoding)
1923+
coords["geometry"].extend(more)
18461924

18471925
allvars = itertools.chain(*coords.values())
18481926
missing = set(allvars) - set(self._maybe_to_dataset()._variables)

cf_xarray/criteria.py

+10
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from collections.abc import Mapping, MutableMapping
1313
from typing import Any
1414

15+
#: CF Roles understood by cf-xarray
1516
_DSG_ROLES = ["timeseries_id", "profile_id", "trajectory_id"]
17+
#: Geometry types understood by cf-xarray
18+
_GEOMETRY_TYPES = ("line", "point", "polygon")
1619

1720
cf_role_criteria: Mapping[str, Mapping[str, str]] = {
1821
k: {"cf_role": k}
@@ -31,6 +34,13 @@
3134
"grid_mapping": {"grid_mapping_name": re.compile(".")}
3235
}
3336

37+
# A geometry container is anything with a geometry_type attribute
38+
geometry_var_criteria: Mapping[str, Mapping[str, Any]] = {
39+
"geometry": {"geometry_type": re.compile(".")},
40+
}
41+
# And we allow indexing by geometry_type
42+
geometry_var_criteria.update({k: {"geometry_type": k} for k in _GEOMETRY_TYPES})
43+
3444
coordinate_criteria: MutableMapping[str, MutableMapping[str, tuple]] = {
3545
"latitude": {
3646
"standard_name": ("latitude",),

cf_xarray/datasets.py

+29
Original file line numberDiff line numberDiff line change
@@ -748,3 +748,32 @@ def _create_inexact_bounds():
748748
node_coordinates="node_lon node_lat node_elevation",
749749
),
750750
)
751+
752+
753+
def point_dataset():
754+
from shapely.geometry import MultiPoint, Point
755+
756+
da = xr.DataArray(
757+
[
758+
MultiPoint([(1.0, 2.0), (2.0, 3.0)]),
759+
Point(3.0, 4.0),
760+
Point(4.0, 5.0),
761+
Point(3.0, 4.0),
762+
],
763+
dims=("index",),
764+
name="geometry",
765+
)
766+
ds = da.to_dataset()
767+
return ds
768+
769+
770+
def encoded_point_dataset():
771+
from .geometry import encode_geometries
772+
773+
ds = encode_geometries(point_dataset())
774+
ds["data"] = (
775+
"index",
776+
np.arange(ds.sizes["index"]),
777+
{"geometry": "geometry_container"},
778+
)
779+
return ds

cf_xarray/formatting.py

+11
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,17 @@ def _format_dsg_roles(accessor, dims, rich):
295295
)
296296

297297

298+
def _format_geometries(accessor, dims, rich):
299+
yield make_text_section(
300+
accessor,
301+
"CF Geometries",
302+
"geometries",
303+
dims=dims,
304+
# valid_keys=_DSG_ROLES,
305+
rich=rich,
306+
)
307+
308+
298309
def _format_coordinates(accessor, dims, coords, rich):
299310
from .accessor import _AXIS_NAMES, _CELL_MEASURES, _COORD_NAMES
300311

cf_xarray/tests/conftest.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import numpy as np
2+
import pytest
3+
import xarray as xr
4+
5+
6+
@pytest.fixture(scope="session")
7+
def geometry_ds():
8+
pytest.importorskip("shapely")
9+
10+
from shapely.geometry import MultiPoint, Point
11+
12+
# empty/fill workaround to avoid numpy deprecation(warning) due to the array interface of shapely geometries.
13+
geoms = np.empty(4, dtype=object)
14+
geoms[:] = [
15+
MultiPoint([(1.0, 2.0), (2.0, 3.0)]),
16+
Point(3.0, 4.0),
17+
Point(4.0, 5.0),
18+
Point(3.0, 4.0),
19+
]
20+
21+
ds = xr.Dataset(
22+
{
23+
"data": xr.DataArray(
24+
range(len(geoms)),
25+
dims=("index",),
26+
attrs={
27+
"geometry": "geometry_container",
28+
"coordinates": "crd_x crd_y",
29+
},
30+
),
31+
"time": xr.DataArray([0, 0, 0, 1], dims=("index",)),
32+
}
33+
)
34+
shp_ds = ds.assign(geometry=xr.DataArray(geoms, dims=("index",)))
35+
36+
cf_ds = ds.assign(
37+
x=xr.DataArray([1.0, 2.0, 3.0, 4.0, 3.0], dims=("node",), attrs={"axis": "X"}),
38+
y=xr.DataArray([2.0, 3.0, 4.0, 5.0, 4.0], dims=("node",), attrs={"axis": "Y"}),
39+
node_count=xr.DataArray([2, 1, 1, 1], dims=("index",)),
40+
crd_x=xr.DataArray([1.0, 3.0, 4.0, 3.0], dims=("index",), attrs={"nodes": "x"}),
41+
crd_y=xr.DataArray([2.0, 4.0, 5.0, 4.0], dims=("index",), attrs={"nodes": "y"}),
42+
geometry_container=xr.DataArray(
43+
attrs={
44+
"geometry_type": "point",
45+
"node_count": "node_count",
46+
"node_coordinates": "x y",
47+
"coordinates": "crd_x crd_y",
48+
}
49+
),
50+
)
51+
52+
cf_ds = cf_ds.set_coords(["x", "y", "crd_x", "crd_y"])
53+
54+
return cf_ds, shp_ds

cf_xarray/tests/test_accessor.py

+22
Original file line numberDiff line numberDiff line change
@@ -2076,3 +2076,25 @@ def test_ancillary_variables_extra_dim():
20762076
}
20772077
)
20782078
assert_identical(ds.cf["X"], ds["x"])
2079+
2080+
2081+
def test_geometry_association(geometry_ds):
2082+
cf_ds, _ = geometry_ds
2083+
actual = cf_ds.cf[["data"]]
2084+
for name in ["geometry_container", "x", "y", "node_count", "crd_x", "crd_y"]:
2085+
assert name in actual.coords
2086+
2087+
actual = cf_ds.cf["data"]
2088+
for name in ["geometry_container", "node_count", "crd_x", "crd_y"]:
2089+
assert name in actual.coords
2090+
2091+
assert cf_ds.cf.geometries == {"point": ["geometry_container"]}
2092+
assert_identical(cf_ds.cf["geometry"], cf_ds["geometry_container"])
2093+
with pytest.raises(ValueError):
2094+
cf_ds.cf["point"]
2095+
2096+
expected = cf_ds[["geometry_container", "node_count", "x", "y", "crd_x", "crd_y"]]
2097+
assert_identical(
2098+
cf_ds.cf[["point"]],
2099+
expected.set_coords(["node_count", "x", "y", "crd_x", "crd_y"]),
2100+
)

0 commit comments

Comments
 (0)