Skip to content

Commit f21cecd

Browse files
committed
Add geometry encoding and decoding functions.
These differ from `shapely_to_cf` and `cf_to_shapely` by returning all variables; those functions simply encode and decode the geometry-related variables.
1 parent 52160a5 commit f21cecd

File tree

2 files changed

+220
-22
lines changed

2 files changed

+220
-22
lines changed

cf_xarray/geometry.py

+188-11
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,181 @@
66
import pandas as pd
77
import xarray as xr
88

9+
# Name of the geometry container variable that the encode/decode functions in this
# module look up in (and write to) a dataset.
GEOMETRY_CONTAINER_NAME = "geometry_container"
10+
11+
12+
def decode_geometries(encoded: xr.Dataset) -> xr.Dataset:
    """
    Decode CF-encoded geometries to a numpy object array
    containing shapely geometries.

    Parameters
    ----------
    encoded : Dataset
        An Xarray Dataset containing encoded geometries.

    Returns
    -------
    Dataset
        An Xarray Dataset containing decoded geometries. The geometry container
        variable and the helper variables it references are dropped, and the
        ``geometry`` attribute pointing at the container is removed from the
        remaining variables.

    Raises
    ------
    NotImplementedError
        If ``encoded`` has no variable named ``GEOMETRY_CONTAINER_NAME``.

    See Also
    --------
    shapely_to_cf
    cf_to_shapely
    encode_geometries
    """
    if GEOMETRY_CONTAINER_NAME not in encoded._variables:
        raise NotImplementedError(
            f"Currently only a single geometry variable named {GEOMETRY_CONTAINER_NAME!r} is supported. "
            "A variable by this name is not present in the provided dataset."
        )

    enc_geom_var = encoded[GEOMETRY_CONTAINER_NAME]
    geom_attrs = enc_geom_var.attrs
    # Grab the coordinates attribute
    # NOTE(review): this merges the encoding into the container's attrs in place;
    # presumably `cf_to_shapely(encoded)` below relies on seeing the merged
    # attributes -- confirm before replacing this with a copy.
    geom_attrs.update(enc_geom_var.encoding)

    geom_var = cf_to_shapely(encoded).variable

    # Every variable referenced by the container is an encoding detail that
    # must not survive in the decoded dataset.
    referenced = " ".join(
        geom_attrs.get(attr, "")
        for attr in [
            "interior_ring",
            "node_coordinates",
            "node_count",
            "part_node_count",
            "coordinates",
        ]
    )
    todrop = (GEOMETRY_CONTAINER_NAME,) + tuple(
        s for s in referenced.split(" ") if s
    )
    # Shallow-copy so that removing the "geometry" attribute further down cannot
    # modify variables that are still referenced by the caller's `encoded` dataset.
    decoded = encoded.drop_vars(todrop).copy(deep=False)

    name = geom_attrs.get("variable_name", None)
    if name in decoded.dims:
        # `name` is a dimension: attach the geometries as a coordinate without
        # building an index for them.
        decoded = decoded.assign_coords(
            xr.Coordinates(coords={name: geom_var}, indexes={})
        )
    else:
        decoded[name] = geom_var

    # Is this a good idea? We are deleting information.
    for var in decoded._variables.values():
        if var.attrs.get("geometry") == GEOMETRY_CONTAINER_NAME:
            var.attrs.pop("geometry")
    return decoded
75+
76+
77+
import copy
78+
79+
# NOTE(review): same name and value as the module-level assignment earlier in this
# file; this looks redundant -- confirm and keep only one definition.
GEOMETRY_CONTAINER_NAME = "geometry_container"
80+
81+
82+
def encode_geometries(ds: xr.Dataset) -> xr.Dataset:
    """
    Encodes any discovered geometry variables using the CF conventions.

    Practically speaking, geometry variables are numpy object arrays where the first
    element is a shapely geometry.

    .. warning::

       Only a single geometry variable is supported at present. Contributions to fix this
       are welcome.

    Parameters
    ----------
    ds : Dataset
        Dataset containing at least one geometry variable.

    Returns
    -------
    Dataset
        Where all geometry variables are encoded. If no geometry variable is
        found, ``ds`` is returned unchanged.

    Raises
    ------
    NotImplementedError
        If more than one geometry variable is detected.

    See Also
    --------
    shapely_to_cf
    cf_to_shapely
    """
    from shapely import (
        LineString,
        MultiLineString,
        MultiPoint,
        MultiPolygon,
        Point,
        Polygon,
    )

    SHAPELY_TYPES = (
        Point,
        LineString,
        Polygon,
        MultiPoint,
        MultiLineString,
        MultiPolygon,
    )

    geom_var_names = [
        name
        for name, var in ds._variables.items()
        # An empty object array has no first element to inspect, so it cannot
        # be recognised as a geometry variable; skip it instead of raising.
        if var.dtype == "O"
        and var.size > 0
        and isinstance(var.data.flat[0], SHAPELY_TYPES)
    ]
    if not geom_var_names:
        return ds

    if to_drop := set(geom_var_names) & set(ds._indexes):
        # e.g. xvec GeometryIndex
        ds = ds.drop_indexes(to_drop)

    if len(geom_var_names) > 1:
        raise NotImplementedError(
            "Multiple geometry variables are not supported at this time. "
            "Contributions to fix this are welcome. "
            f"Detected geometry variables are {geom_var_names!r}"
        )

    (name,) = geom_var_names
    variables = {}
    # If `name` is a dimension name, then we need to drop it. Otherwise we don't
    # So set errors="ignore"
    variables.update(
        shapely_to_cf(ds[name]).drop_vars(name, errors="ignore")._variables
    )

    geom_var = ds[name]
    # Loop-invariant: coordinates carried by the geometry variable itself.
    to_add = geom_var.attrs.get("coordinates", "")

    more_updates = {}
    for varname, var in ds._variables.items():
        if varname == name:
            continue
        if name in var.dims:
            # Work on a copy so the attrs of the caller's variables stay untouched.
            var = var.copy()
            var._attrs = copy.deepcopy(var._attrs)
            var.attrs["geometry"] = GEOMETRY_CONTAINER_NAME
            # The grid_mapping and coordinates attributes can be carried by the geometry container
            # variable provided they are also carried by the data variables associated with the container.
            if to_add:
                existing = var.attrs.get("coordinates", "")
                # "coordinates" is a space-separated list of names; join with a
                # space so an existing entry does not fuse with the added ones.
                var.attrs["coordinates"] = " ".join(
                    part for part in (existing, to_add) if part
                )
        more_updates[varname] = var
    variables.update(more_updates)

    # WARNING: cf-xarray specific convention.
    # For vector data cubes, `name` is a dimension name.
    # By encoding to CF, we have
    # encoded the information in that variable across many different
    # variables (e.g. node_count) with `name` as a dimension.
    # We have to record `name` somewhere so that we reconstruct
    # a geometry variable of the right name at decode-time.
    variables[GEOMETRY_CONTAINER_NAME].attrs["variable_name"] = name

    encoded = xr.Dataset(variables)

    return encoded
183+
9184

10185
def reshape_unique_geometries(
11186
ds: xr.Dataset,
@@ -119,13 +294,15 @@ def shapely_to_cf(geometries: xr.DataArray | Sequence, grid_mapping: str | None
119294
f"Mixed geometry types are not supported in CF-compliant datasets. Got {types}"
120295
)
121296

297+
ds[GEOMETRY_CONTAINER_NAME].attrs.update(coordinates="crd_x crd_y")
298+
122299
# Special treatment of selected grid mappings
123300
if grid_mapping == "longitude_latitude":
124301
# Special case for longitude_latitude grid mapping
125302
ds = ds.rename(crd_x="lon", crd_y="lat")
126303
ds.lon.attrs.update(units="degrees_east", standard_name="longitude")
127304
ds.lat.attrs.update(units="degrees_north", standard_name="latitude")
128-
ds.geometry_container.attrs.update(coordinates="lon lat")
305+
ds[GEOMETRY_CONTAINER_NAME].attrs.update(coordinates="lon lat")
129306
ds.x.attrs.update(units="degrees_east", standard_name="longitude")
130307
ds.y.attrs.update(units="degrees_north", standard_name="latitude")
131308
elif grid_mapping is not None:
@@ -157,7 +334,7 @@ def cf_to_shapely(ds: xr.Dataset):
157334
----------
158335
Please refer to the CF conventions document: http://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#geometries
159336
"""
160-
geom_type = ds.geometry_container.attrs["geometry_type"]
337+
geom_type = ds[GEOMETRY_CONTAINER_NAME].attrs["geometry_type"]
161338
if geom_type == "point":
162339
geometries = cf_to_points(ds)
163340
elif geom_type == "line":
@@ -235,7 +412,7 @@ def points_to_cf(pts: xr.DataArray | Sequence):
235412
# Special case when we have no MultiPoints
236413
if (ds.node_count == 1).all():
237414
ds = ds.drop_vars("node_count")
238-
del ds.geometry_container.attrs["node_count"]
415+
del ds[GEOMETRY_CONTAINER_NAME].attrs["node_count"]
239416
return ds
240417

241418

@@ -259,18 +436,18 @@ def cf_to_points(ds: xr.Dataset):
259436
from shapely.geometry import MultiPoint, Point
260437

261438
# Shorthand for convenience
262-
geo = ds.geometry_container.attrs
439+
geo = ds[GEOMETRY_CONTAINER_NAME].attrs
263440

264441
# The features dimension name, defaults to the one of 'node_count' or the dimension of the coordinates, if present.
265442
feat_dim = None
266443
if "coordinates" in geo and feat_dim is None:
267444
xcoord_name, _ = geo["coordinates"].split(" ")
268445
(feat_dim,) = ds[xcoord_name].dims
269446

270-
x_name, y_name = ds.geometry_container.attrs["node_coordinates"].split(" ")
447+
x_name, y_name = ds[GEOMETRY_CONTAINER_NAME].attrs["node_coordinates"].split(" ")
271448
xy = np.stack([ds[x_name].values, ds[y_name].values], axis=-1)
272449

273-
node_count_name = ds.geometry_container.attrs.get("node_count")
450+
node_count_name = ds[GEOMETRY_CONTAINER_NAME].attrs.get("node_count")
274451
if node_count_name is None:
275452
# No node_count means all geometries are single points (node_count = 1)
276453
# And if we had no coordinates, then the dimension defaults to "features"
@@ -363,7 +540,7 @@ def lines_to_cf(lines: xr.DataArray | Sequence):
363540
# Special case when we have no MultiLines
364541
if len(ds.part_node_count) == len(ds.node_count):
365542
ds = ds.drop_vars("part_node_count")
366-
del ds.geometry_container.attrs["part_node_count"]
543+
del ds[GEOMETRY_CONTAINER_NAME].attrs["part_node_count"]
367544
return ds
368545

369546

@@ -387,7 +564,7 @@ def cf_to_lines(ds: xr.Dataset):
387564
from shapely import GeometryType, from_ragged_array
388565

389566
# Shorthand for convenience
390-
geo = ds.geometry_container.attrs
567+
geo = ds[GEOMETRY_CONTAINER_NAME].attrs
391568

392569
# The features dimension name, defaults to the one of 'node_count'
393570
# or the dimension of the coordinates, if present.
@@ -503,12 +680,12 @@ def polygons_to_cf(polygons: xr.DataArray | Sequence):
503680
# Special case when we have no MultiPolygons and no holes
504681
if len(ds.part_node_count) == len(ds.node_count):
505682
ds = ds.drop_vars("part_node_count")
506-
del ds.geometry_container.attrs["part_node_count"]
683+
del ds[GEOMETRY_CONTAINER_NAME].attrs["part_node_count"]
507684

508685
# Special case when we have no holes
509686
if (ds.interior_ring == 0).all():
510687
ds = ds.drop_vars("interior_ring")
511-
del ds.geometry_container.attrs["interior_ring"]
688+
del ds[GEOMETRY_CONTAINER_NAME].attrs["interior_ring"]
512689
return ds
513690

514691

@@ -532,7 +709,7 @@ def cf_to_polygons(ds: xr.Dataset):
532709
from shapely import GeometryType, from_ragged_array
533710

534711
# Shorthand for convenience
535-
geo = ds.geometry_container.attrs
712+
geo = ds[GEOMETRY_CONTAINER_NAME].attrs
536713

537714
# The features dimension name, defaults to the one of 'part_node_count'
538715
# or the dimension of the coordinates, if present.

cf_xarray/tests/test_geometry.py

+32-11
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,23 @@
44

55
import cf_xarray as cfxr
66

7+
from ..geometry import decode_geometries, encode_geometries
78
from . import requires_shapely
89

910

11+
@pytest.fixture
def polygon_geometry() -> xr.DataArray:
    """Provide a 1D DataArray named "geometry" holding two hole-free polygons."""
    from shapely.geometry import Polygon

    polygons = (
        Polygon(([50, 0], [40, 15], [30, 0])),
        Polygon(([70, 50], [60, 65], [50, 50])),
    )

    # Allocate the object array first and fill it element by element: this avoids
    # the numpy deprecation (warning) triggered by the array interface of shapely
    # geometries.
    geoms = np.empty(len(polygons), dtype=object)
    for position, polygon in enumerate(polygons):
        geoms[position] = polygon

    return xr.DataArray(geoms, dims=("index",), name="geometry")
22+
23+
1024
@pytest.fixture
1125
def geometry_ds():
1226
from shapely.geometry import MultiPoint, Point
@@ -127,18 +141,9 @@ def geometry_line_without_multilines_ds():
127141

128142

129143
@pytest.fixture
130-
def geometry_polygon_without_holes_ds():
131-
from shapely.geometry import Polygon
132-
133-
# empty/fill workaround to avoid numpy deprecation(warning) due to the array interface of shapely geometries.
134-
geoms = np.empty(2, dtype=object)
135-
geoms[:] = [
136-
Polygon(([50, 0], [40, 15], [30, 0])),
137-
Polygon(([70, 50], [60, 65], [50, 50])),
138-
]
139-
144+
def geometry_polygon_without_holes_ds(polygon_geometry):
145+
shp_da = polygon_geometry
140146
ds = xr.Dataset()
141-
shp_da = xr.DataArray(geoms, dims=("index",), name="geometry")
142147

143148
cf_ds = ds.assign(
144149
x=xr.DataArray(
@@ -521,3 +526,19 @@ def test_reshape_unique_geometries(geometry_ds):
521526
in_ds = in_ds.assign(geometry=geoms)
522527
with pytest.raises(ValueError, match="The geometry variable must be 1D"):
523528
cfxr.geometry.reshape_unique_geometries(in_ds)
529+
530+
531+
@requires_shapely
def test_encode_decode(geometry_ds, polygon_geometry):
    """Check that encoding then decoding geometries gives back an identical dataset."""
    # A dataset where the geometries are a dimension coordinate without an index.
    geoms_coords = xr.Coordinates(
        coords={"geoms": xr.Variable("geoms", polygon_geometry.variable)},
        indexes={},
    )
    geom_dim_ds = xr.Dataset().assign_coords(geoms_coords)
    geom_dim_ds = geom_dim_ds.assign({"foo": ("geoms", [1, 2])})

    candidates = (geometry_ds[1], polygon_geometry.to_dataset(), geom_dim_ds)
    for ds in candidates:
        encoded = encode_geometries(ds)
        roundtripped = decode_geometries(encoded)
        xr.testing.assert_identical(ds, roundtripped)

0 commit comments

Comments
 (0)