28
28
from . import sgrid
29
29
from .criteria import (
30
30
_DSG_ROLES ,
31
+ _GEOMETRY_TYPES ,
31
32
cf_role_criteria ,
32
33
coordinate_criteria ,
34
+ geometry_var_criteria ,
33
35
grid_mapping_var_criteria ,
34
36
regex ,
35
37
)
39
41
_format_data_vars ,
40
42
_format_dsg_roles ,
41
43
_format_flags ,
44
+ _format_geometries ,
42
45
_format_sgrid ,
43
46
_maybe_panel ,
44
47
)
@@ -198,7 +201,9 @@ def _get_groupby_time_accessor(
198
201
199
202
200
203
def _get_custom_criteria (
201
- obj : DataArray | Dataset , key : Hashable , criteria : Mapping | None = None
204
+ obj : DataArray | Dataset ,
205
+ key : Hashable ,
206
+ criteria : Iterable [Mapping ] | Mapping | None = None ,
202
207
) -> list [Hashable ]:
203
208
"""
204
209
Translate from axis, coord, or custom name to variable name.
@@ -227,18 +232,16 @@ def _get_custom_criteria(
227
232
except ImportError :
228
233
from re import match as regex_match # type: ignore[no-redef]
229
234
230
- if isinstance (obj , DataArray ):
231
- obj = obj ._to_temp_dataset ()
232
- variables = obj ._variables
233
-
234
235
if criteria is None :
235
236
if not OPTIONS ["custom_criteria" ]:
236
237
return []
237
238
criteria = OPTIONS ["custom_criteria" ]
238
239
239
- if criteria is not None :
240
- criteria_iter = always_iterable (criteria , allowed = (tuple , list , set ))
240
+ if isinstance (obj , DataArray ):
241
+ obj = obj ._to_temp_dataset ()
242
+ variables = obj ._variables
241
243
244
+ criteria_iter = always_iterable (criteria , allowed = (tuple , list , set ))
242
245
criteria_map = ChainMap (* criteria_iter )
243
246
results : set = set ()
244
247
if key in criteria_map :
@@ -367,6 +370,21 @@ def _get_measure(obj: DataArray | Dataset, key: str) -> list[str]:
367
370
return list (results )
368
371
369
372
373
+ def _parse_related_geometry_vars (attrs : Mapping ) -> tuple [Hashable ]:
374
+ names = itertools .chain (
375
+ * [
376
+ attrs .get (attr , "" ).split (" " )
377
+ for attr in [
378
+ "interior_ring" ,
379
+ "node_coordinates" ,
380
+ "node_count" ,
381
+ "part_node_count" ,
382
+ ]
383
+ ]
384
+ )
385
+ return tuple (n for n in names if n )
386
+
387
+
370
388
def _get_bounds (obj : DataArray | Dataset , key : Hashable ) -> list [Hashable ]:
371
389
"""
372
390
Translate from key (either CF key or variable name) to its bounds' variable names.
@@ -470,8 +488,14 @@ def _get_all(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
470
488
"""
471
489
all_mappers : tuple [Mapper ] = (
472
490
_get_custom_criteria ,
473
- functools .partial (_get_custom_criteria , criteria = cf_role_criteria ), # type: ignore[assignment]
474
- functools .partial (_get_custom_criteria , criteria = grid_mapping_var_criteria ),
491
+ functools .partial (
492
+ _get_custom_criteria ,
493
+ criteria = (
494
+ cf_role_criteria ,
495
+ grid_mapping_var_criteria ,
496
+ geometry_var_criteria ,
497
+ ),
498
+ ), # type: ignore[assignment]
475
499
_get_axis_coord ,
476
500
_get_measure ,
477
501
_get_grid_mapping_name ,
@@ -821,6 +845,23 @@ def check_results(names, key):
821
845
successful [k ] = bool (grid_mapping )
822
846
if grid_mapping :
823
847
varnames .extend (grid_mapping )
848
+ elif "geometries" not in skip and (k == "geometry" or k in _GEOMETRY_TYPES ):
849
+ geometries = _get_all (obj , k )
850
+ if geometries and k in _GEOMETRY_TYPES :
851
+ new = itertools .chain (
852
+ _parse_related_geometry_vars (
853
+ ChainMap (obj [g ].attrs , obj [g ].encoding )
854
+ )
855
+ for g in geometries
856
+ )
857
+ geometries .extend (* new )
858
+ if len (geometries ) > 1 and scalar_key :
859
+ raise ValueError (
860
+ f"CF geometries must be represented by an Xarray Dataset. To request a Dataset in return please pass `[{ k !r} ]` instead."
861
+ )
862
+ successful [k ] = bool (geometries )
863
+ if geometries :
864
+ varnames .extend (geometries )
824
865
elif k in custom_criteria or k in cf_role_criteria :
825
866
names = _get_all (obj , k )
826
867
check_results (names , k )
@@ -1559,8 +1600,7 @@ def _generate_repr(self, rich=False):
1559
1600
_format_flags (self , rich ), title = "Flag Variable" , rich = rich
1560
1601
)
1561
1602
1562
- roles = self .cf_roles
1563
- if roles :
1603
+ if roles := self .cf_roles :
1564
1604
if any (role in roles for role in _DSG_ROLES ):
1565
1605
yield _maybe_panel (
1566
1606
_format_dsg_roles (self , dims , rich ),
@@ -1576,6 +1616,13 @@ def _generate_repr(self, rich=False):
1576
1616
rich = rich ,
1577
1617
)
1578
1618
1619
+ if self .geometries :
1620
+ yield _maybe_panel (
1621
+ _format_geometries (self , dims , rich ),
1622
+ title = "Geometries" ,
1623
+ rich = rich ,
1624
+ )
1625
+
1579
1626
yield _maybe_panel (
1580
1627
_format_coordinates (self , dims , coords , rich ),
1581
1628
title = "Coordinates" ,
@@ -1755,12 +1802,42 @@ def cf_roles(self) -> dict[str, list[Hashable]]:
1755
1802
1756
1803
vardict : dict [str , list [Hashable ]] = {}
1757
1804
for k , v in variables .items ():
1758
- if "cf_role" in v .attrs :
1759
- role = v . attrs [ "cf_role" ]
1805
+ attrs_or_encoding = ChainMap ( v .attrs , v . encoding )
1806
+ if role := attrs_or_encoding . get ( "cf_role" , None ):
1760
1807
vardict [role ] = vardict .setdefault (role , []) + [k ]
1761
1808
1762
1809
return {role_ : sort_maybe_hashable (v ) for role_ , v in vardict .items ()}
1763
1810
1811
+ @property
1812
+ def geometries (self ) -> dict [str , list [Hashable ]]:
1813
+ """
1814
+ Mapping geometry type names to variable names.
1815
+
1816
+ Returns
1817
+ -------
1818
+ dict
1819
+ Dictionary mapping geometry names to variable names.
1820
+
1821
+ References
1822
+ ----------
1823
+ Please refer to the CF conventions document : http://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#coordinates-metadata
1824
+ """
1825
+ vardict : dict [str , list [Hashable ]] = {}
1826
+
1827
+ if isinstance (self ._obj , Dataset ):
1828
+ variables = self ._obj ._variables
1829
+ elif isinstance (self ._obj , DataArray ):
1830
+ variables = {"_" : self ._obj ._variable }
1831
+
1832
+ for v in variables .values ():
1833
+ attrs_or_encoding = ChainMap (v .attrs , v .encoding )
1834
+ if geometry := attrs_or_encoding .get ("geometry" , None ):
1835
+ gtype = self ._obj [geometry ].attrs ["geometry_type" ]
1836
+ vardict .setdefault (gtype , [])
1837
+ if geometry not in vardict [gtype ]:
1838
+ vardict [gtype ] += [geometry ]
1839
+ return {type_ : sort_maybe_hashable (v ) for type_ , v in vardict .items ()}
1840
+
1764
1841
def get_associated_variable_names (
1765
1842
self , name : Hashable , skip_bounds : bool = False , error : bool = True
1766
1843
) -> dict [str , list [Hashable ]]:
@@ -1795,15 +1872,15 @@ def get_associated_variable_names(
1795
1872
"bounds" ,
1796
1873
"grid_mapping" ,
1797
1874
"grid" ,
1875
+ "geometry" ,
1798
1876
]
1799
1877
1800
1878
coords : dict [str , list [Hashable ]] = {k : [] for k in keys }
1801
1879
attrs_or_encoding = ChainMap (self ._obj [name ].attrs , self ._obj [name ].encoding )
1802
1880
1803
- coordinates = attrs_or_encoding .get ("coordinates" , None )
1804
1881
# Handles case where the coordinates attribute is None
1805
1882
# This is used to tell xarray to not write a coordinates attribute
1806
- if coordinates :
1883
+ if coordinates := attrs_or_encoding . get ( "coordinates" , None ) :
1807
1884
coords ["coordinates" ] = coordinates .split (" " )
1808
1885
1809
1886
if "cell_measures" in attrs_or_encoding :
@@ -1822,27 +1899,32 @@ def get_associated_variable_names(
1822
1899
)
1823
1900
coords ["cell_measures" ] = []
1824
1901
1825
- if (
1826
- isinstance (self ._obj , Dataset )
1827
- and "ancillary_variables" in attrs_or_encoding
1902
+ if isinstance (self ._obj , Dataset ) and (
1903
+ anc := attrs_or_encoding .get ("ancillary_variables" , None )
1828
1904
):
1829
- coords ["ancillary_variables" ] = attrs_or_encoding [
1830
- "ancillary_variables"
1831
- ].split (" " )
1905
+ coords ["ancillary_variables" ] = anc .split (" " )
1832
1906
1833
1907
if not skip_bounds :
1834
- if " bounds" in attrs_or_encoding :
1835
- coords ["bounds" ] = [attrs_or_encoding [ " bounds" ] ]
1908
+ if bounds := attrs_or_encoding . get ( "bounds" , None ) :
1909
+ coords ["bounds" ] = [bounds ]
1836
1910
for dim in self ._obj [name ].dims :
1837
- dbounds = self ._obj [dim ].attrs .get ("bounds" , None )
1838
- if dbounds :
1911
+ if dbounds := self ._obj [dim ].attrs .get ("bounds" , None ):
1839
1912
coords ["bounds" ].append (dbounds )
1840
1913
1841
- if "grid" in attrs_or_encoding :
1842
- coords ["grid" ] = [attrs_or_encoding ["grid" ]]
1914
+ for attrname in ["grid" , "grid_mapping" ]:
1915
+ if maybe := attrs_or_encoding .get (attrname , None ):
1916
+ coords [attrname ] = [maybe ]
1843
1917
1844
- if "grid_mapping" in attrs_or_encoding :
1845
- coords ["grid_mapping" ] = [attrs_or_encoding ["grid_mapping" ]]
1918
+ more : Sequence [Hashable ] = ()
1919
+ if geometry_var := attrs_or_encoding .get ("geometry" , None ):
1920
+ coords ["geometry" ] = [geometry_var ]
1921
+ _attrs = ChainMap (
1922
+ self ._obj [geometry_var ].attrs , self ._obj [geometry_var ].encoding
1923
+ )
1924
+ more = _parse_related_geometry_vars (_attrs )
1925
+ elif "geometry_type" in attrs_or_encoding :
1926
+ more = _parse_related_geometry_vars (attrs_or_encoding )
1927
+ coords ["geometry" ].extend (more )
1846
1928
1847
1929
allvars = itertools .chain (* coords .values ())
1848
1930
missing = set (allvars ) - set (self ._maybe_to_dataset ()._variables )
0 commit comments