28
28
from . import sgrid
29
29
from .criteria import (
30
30
_DSG_ROLES ,
31
+ _GEOMETRY_TYPES ,
31
32
cf_role_criteria ,
32
33
coordinate_criteria ,
34
+ geometry_var_criteria ,
33
35
grid_mapping_var_criteria ,
34
36
regex ,
35
37
)
39
41
_format_data_vars ,
40
42
_format_dsg_roles ,
41
43
_format_flags ,
44
+ _format_geometries ,
42
45
_format_sgrid ,
43
46
_maybe_panel ,
44
47
)
@@ -227,18 +230,16 @@ def _get_custom_criteria(
227
230
except ImportError :
228
231
from re import match as regex_match # type: ignore[no-redef]
229
232
230
- if isinstance (obj , DataArray ):
231
- obj = obj ._to_temp_dataset ()
232
- variables = obj ._variables
233
-
234
233
if criteria is None :
235
234
if not OPTIONS ["custom_criteria" ]:
236
235
return []
237
236
criteria = OPTIONS ["custom_criteria" ]
238
237
239
- if criteria is not None :
240
- criteria_iter = always_iterable (criteria , allowed = (tuple , list , set ))
238
+ if isinstance (obj , DataArray ):
239
+ obj = obj ._to_temp_dataset ()
240
+ variables = obj ._variables
241
241
242
+ criteria_iter = always_iterable (criteria , allowed = (tuple , list , set ))
242
243
criteria_map = ChainMap (* criteria_iter )
243
244
results : set = set ()
244
245
if key in criteria_map :
@@ -367,6 +368,21 @@ def _get_measure(obj: DataArray | Dataset, key: str) -> list[str]:
367
368
return list (results )
368
369
369
370
371
+ def _parse_related_geometry_vars (attrs : Mapping ) -> tuple [Hashable ]:
372
+ names = itertools .chain (
373
+ * [
374
+ attrs .get (attr , "" ).split (" " )
375
+ for attr in [
376
+ "interior_ring" ,
377
+ "node_coordinates" ,
378
+ "node_count" ,
379
+ "part_node_count" ,
380
+ ]
381
+ ]
382
+ )
383
+ return tuple (n for n in names if n )
384
+
385
+
370
386
def _get_bounds (obj : DataArray | Dataset , key : Hashable ) -> list [Hashable ]:
371
387
"""
372
388
Translate from key (either CF key or variable name) to its bounds' variable names.
@@ -470,8 +486,12 @@ def _get_all(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
470
486
"""
471
487
all_mappers : tuple [Mapper ] = (
472
488
_get_custom_criteria ,
473
- functools .partial (_get_custom_criteria , criteria = cf_role_criteria ), # type: ignore[assignment]
474
- functools .partial (_get_custom_criteria , criteria = grid_mapping_var_criteria ),
489
+ functools .partial (
490
+ _get_custom_criteria ,
491
+ criteria = ChainMap (
492
+ cf_role_criteria , grid_mapping_var_criteria , geometry_var_criteria
493
+ ),
494
+ ),
475
495
_get_axis_coord ,
476
496
_get_measure ,
477
497
_get_grid_mapping_name ,
@@ -821,6 +841,23 @@ def check_results(names, key):
821
841
successful [k ] = bool (grid_mapping )
822
842
if grid_mapping :
823
843
varnames .extend (grid_mapping )
844
+ elif "geometries" not in skip and (k == "geometry" or k in _GEOMETRY_TYPES ):
845
+ geometries = _get_all (obj , k )
846
+ if geometries and k in _GEOMETRY_TYPES :
847
+ new = itertools .chain (
848
+ _parse_related_geometry_vars (
849
+ ChainMap (obj [g ].attrs , obj [g ].encoding )
850
+ )
851
+ for g in geometries
852
+ )
853
+ geometries .extend (* new )
854
+ if len (geometries ) > 1 and scalar_key :
855
+ raise ValueError (
856
+ f"CF geometries must be represented by an Xarray Dataset. To request a Dataset in return please pass `[{ k !r} ]` instead."
857
+ )
858
+ successful [k ] = bool (geometries )
859
+ if geometries :
860
+ varnames .extend (geometries )
824
861
elif k in custom_criteria or k in cf_role_criteria :
825
862
names = _get_all (obj , k )
826
863
check_results (names , k )
@@ -1559,8 +1596,7 @@ def _generate_repr(self, rich=False):
1559
1596
_format_flags (self , rich ), title = "Flag Variable" , rich = rich
1560
1597
)
1561
1598
1562
- roles = self .cf_roles
1563
- if roles :
1599
+ if roles := self .cf_roles :
1564
1600
if any (role in roles for role in _DSG_ROLES ):
1565
1601
yield _maybe_panel (
1566
1602
_format_dsg_roles (self , dims , rich ),
@@ -1576,6 +1612,13 @@ def _generate_repr(self, rich=False):
1576
1612
rich = rich ,
1577
1613
)
1578
1614
1615
+ if self .geometries :
1616
+ yield _maybe_panel (
1617
+ _format_geometries (self , dims , rich ),
1618
+ title = "Geometries" ,
1619
+ rich = rich ,
1620
+ )
1621
+
1579
1622
yield _maybe_panel (
1580
1623
_format_coordinates (self , dims , coords , rich ),
1581
1624
title = "Coordinates" ,
@@ -1755,12 +1798,42 @@ def cf_roles(self) -> dict[str, list[Hashable]]:
1755
1798
1756
1799
vardict : dict [str , list [Hashable ]] = {}
1757
1800
for k , v in variables .items ():
1758
- if "cf_role" in v .attrs :
1759
- role = v . attrs [ "cf_role" ]
1801
+ attrs_or_encoding = ChainMap ( v .attrs , v . encoding )
1802
+ if role := attrs_or_encoding . get ( "cf_role" , None ):
1760
1803
vardict [role ] = vardict .setdefault (role , []) + [k ]
1761
1804
1762
1805
return {role_ : sort_maybe_hashable (v ) for role_ , v in vardict .items ()}
1763
1806
1807
@property
def geometries(self) -> dict[str, list[Hashable]]:
    """
    Mapping of geometry type names to geometry container variable names.

    Every variable of the wrapped object is inspected for a ``geometry``
    entry in its attrs (falling back to its encoding). The variable named
    by that entry is looked up on the wrapped object and grouped under the
    value of its ``geometry_type`` attribute.

    Returns
    -------
    dict
        Dictionary mapping geometry type names to the names of the geometry
        container variables of that type. Each name is listed once.

    References
    ----------
    Please refer to the CF conventions document : http://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#geometries
    """
    if isinstance(self._obj, Dataset):
        variables = self._obj._variables
    elif isinstance(self._obj, DataArray):
        # A DataArray contributes only its own variable; the key is a placeholder.
        variables = {"_": self._obj._variable}

    by_type: dict[str, list[Hashable]] = {}
    for var in variables.values():
        geometry = ChainMap(var.attrs, var.encoding).get("geometry", None)
        if not geometry:
            continue
        gtype = self._obj[geometry].attrs["geometry_type"]
        names = by_type.setdefault(gtype, [])
        # Several variables may point at the same container; record it once.
        if geometry not in names:
            names.append(geometry)

    return {
        type_: sort_maybe_hashable(found) for type_, found in by_type.items()
    }
1836
+
1764
1837
def get_associated_variable_names (
1765
1838
self , name : Hashable , skip_bounds : bool = False , error : bool = True
1766
1839
) -> dict [str , list [Hashable ]]:
@@ -1795,15 +1868,15 @@ def get_associated_variable_names(
1795
1868
"bounds" ,
1796
1869
"grid_mapping" ,
1797
1870
"grid" ,
1871
+ "geometry" ,
1798
1872
]
1799
1873
1800
1874
coords : dict [str , list [Hashable ]] = {k : [] for k in keys }
1801
1875
attrs_or_encoding = ChainMap (self ._obj [name ].attrs , self ._obj [name ].encoding )
1802
1876
1803
- coordinates = attrs_or_encoding .get ("coordinates" , None )
1804
1877
# Handles case where the coordinates attribute is None
1805
1878
# This is used to tell xarray to not write a coordinates attribute
1806
- if coordinates :
1879
+ if coordinates := attrs_or_encoding . get ( "coordinates" , None ) :
1807
1880
coords ["coordinates" ] = coordinates .split (" " )
1808
1881
1809
1882
if "cell_measures" in attrs_or_encoding :
@@ -1822,27 +1895,32 @@ def get_associated_variable_names(
1822
1895
)
1823
1896
coords ["cell_measures" ] = []
1824
1897
1825
- if (
1826
- isinstance (self ._obj , Dataset )
1827
- and "ancillary_variables" in attrs_or_encoding
1898
+ if isinstance (self ._obj , Dataset ) and (
1899
+ anc := attrs_or_encoding .get ("ancillary_variables" , None )
1828
1900
):
1829
- coords ["ancillary_variables" ] = attrs_or_encoding [
1830
- "ancillary_variables"
1831
- ].split (" " )
1901
+ coords ["ancillary_variables" ] = anc .split (" " )
1832
1902
1833
1903
if not skip_bounds :
1834
- if " bounds" in attrs_or_encoding :
1835
- coords ["bounds" ] = [attrs_or_encoding [ " bounds" ] ]
1904
+ if bounds := attrs_or_encoding . get ( "bounds" , None ) :
1905
+ coords ["bounds" ] = [bounds ]
1836
1906
for dim in self ._obj [name ].dims :
1837
- dbounds = self ._obj [dim ].attrs .get ("bounds" , None )
1838
- if dbounds :
1907
+ if dbounds := self ._obj [dim ].attrs .get ("bounds" , None ):
1839
1908
coords ["bounds" ].append (dbounds )
1840
1909
1841
- if "grid" in attrs_or_encoding :
1842
- coords ["grid" ] = [attrs_or_encoding ["grid" ]]
1910
+ for attrname in ["grid" , "grid_mapping" ]:
1911
+ if maybe := attrs_or_encoding .get (attrname , None ):
1912
+ coords [attrname ] = [maybe ]
1843
1913
1844
- if "grid_mapping" in attrs_or_encoding :
1845
- coords ["grid_mapping" ] = [attrs_or_encoding ["grid_mapping" ]]
1914
+ more = []
1915
+ if geometry_var := attrs_or_encoding .get ("geometry" , None ):
1916
+ coords ["geometry" ] = [geometry_var ]
1917
+ _attrs = ChainMap (
1918
+ self ._obj [geometry_var ].attrs , self ._obj [geometry_var ].encoding
1919
+ )
1920
+ more = _parse_related_geometry_vars (_attrs )
1921
+ elif "geometry_type" in attrs_or_encoding :
1922
+ more = _parse_related_geometry_vars (attrs_or_encoding )
1923
+ coords ["geometry" ].extend (more )
1846
1924
1847
1925
allvars = itertools .chain (* coords .values ())
1848
1926
missing = set (allvars ) - set (self ._maybe_to_dataset ()._variables )
0 commit comments