Skip to content

Commit

Permalink
repo-review round 2 (#485)
Browse files Browse the repository at this point in the history
* repo-review 2

* fix

* fixes

* more fix
  • Loading branch information
dcherian authored Dec 5, 2023
1 parent 2a1e4e1 commit f573ed7
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 74 deletions.
86 changes: 44 additions & 42 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import inspect
import itertools
import re
import warnings
from collections import ChainMap, namedtuple
from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence
from datetime import datetime
from typing import (
Any,
Callable,
Literal,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -48,6 +48,7 @@
_get_version,
_is_datetime_like,
always_iterable,
emit_user_level_warning,
invert_mappings,
parse_cell_methods_attr,
parse_cf_standard_name_table,
Expand Down Expand Up @@ -107,7 +108,7 @@ def apply_mapper(
"""

if not isinstance(key, Hashable):
if default is None:
if default is None: # type: ignore[unreachable]
raise ValueError(
"`default` must be provided when `key` is not not a valid DataArray name (of hashable type)."
)
Expand Down Expand Up @@ -224,7 +225,7 @@ def _get_custom_criteria(
try:
from regex import match as regex_match
except ImportError:
from re import match as regex_match # type: ignore
from re import match as regex_match # type: ignore[no-redef]

if isinstance(obj, DataArray):
obj = obj._to_temp_dataset()
Expand Down Expand Up @@ -363,8 +364,6 @@ def _get_measure(obj: DataArray | Dataset, key: str) -> list[str]:
if key in measures:
results.update([measures[key]])

if isinstance(results, str):
return [results]
return list(results)


Expand Down Expand Up @@ -471,7 +470,7 @@ def _get_all(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
"""
all_mappers: tuple[Mapper] = (
_get_custom_criteria,
functools.partial(_get_custom_criteria, criteria=cf_role_criteria), # type: ignore
functools.partial(_get_custom_criteria, criteria=cf_role_criteria), # type: ignore[assignment]
functools.partial(_get_custom_criteria, criteria=grid_mapping_var_criteria),
_get_axis_coord,
_get_measure,
Expand Down Expand Up @@ -653,10 +652,10 @@ def _getattr(
):
raise AttributeError(
f"{obj.__class__.__name__+'.cf'!r} object has no attribute {attr!r}"
)
) from None
raise AttributeError(
f"{attr!r} is not a valid attribute on the underlying xarray object."
)
) from None

if isinstance(attribute, Mapping):
if not attribute:
Expand All @@ -680,7 +679,7 @@ def _getattr(
newmap.update(dict.fromkeys(inverted[key], value))
newmap.update({key: attribute[key] for key in unused_keys})

skip: dict[str, list[Hashable] | None] = {
skip: dict[str, list[Literal["coords", "measures"]] | None] = {
"data_vars": ["coords"],
"coords": None,
}
Expand All @@ -689,7 +688,7 @@ def _getattr(
newmap[key] = _getitem(accessor, key, skip=skip[attr])
return newmap

elif isinstance(attribute, Callable): # type: ignore
elif isinstance(attribute, Callable): # type: ignore[arg-type]
func: Callable = attribute

else:
Expand Down Expand Up @@ -721,7 +720,7 @@ def wrapper(*args, **kwargs):
def _getitem(
accessor: CFAccessor,
key: Hashable,
skip: list[Hashable] | None = None,
skip: list[Literal["coords", "measures"]] | None = None,
) -> DataArray:
...

Expand All @@ -730,15 +729,15 @@ def _getitem(
def _getitem(
accessor: CFAccessor,
key: Iterable[Hashable],
skip: list[Hashable] | None = None,
skip: list[Literal["coords", "measures"]] | None = None,
) -> Dataset:
...


def _getitem(
accessor,
key,
skip=None,
accessor: CFAccessor,
key: Hashable | Iterable[Hashable],
skip: list[Literal["coords", "measures"]] | None = None,
):
"""
Index into obj using key. Attaches CF associated variables.
Expand Down Expand Up @@ -789,7 +788,7 @@ def check_results(names, key):
measures = accessor._get_all_cell_measures()
except ValueError:
measures = []
warnings.warn("Ignoring bad cell_measures attribute.", UserWarning)
emit_user_level_warning("Ignoring bad cell_measures attribute.", UserWarning)

if isinstance(obj, Dataset):
grid_mapping_names = list(accessor.grid_mapping_names)
Expand Down Expand Up @@ -852,6 +851,7 @@ def check_results(names, key):
)
coords.extend(itertools.chain(*extravars.values()))

ds: Dataset
if isinstance(obj, DataArray):
ds = obj._to_temp_dataset()
else:
Expand All @@ -860,7 +860,7 @@ def check_results(names, key):
if scalar_key:
if len(allnames) == 1:
(name,) = allnames
da: DataArray = ds.reset_coords()[name] # type: ignore
da: DataArray = ds.reset_coords()[name]
if name in coords:
coords.remove(name)
for k1 in coords:
Expand All @@ -877,26 +877,27 @@ def check_results(names, key):

ds = ds.reset_coords()[varnames + coords]
if isinstance(obj, DataArray):
if scalar_key and len(ds.variables) == 1:
# single dimension coordinates
assert coords
assert not varnames
if scalar_key:
if len(ds.variables) == 1: # type: ignore[unreachable]
# single dimension coordinates
assert coords
assert not varnames

return ds[coords[0]]
return ds[coords[0]]

elif scalar_key and len(ds.variables) > 1:
raise NotImplementedError(
"Not sure what to return when given scalar key for DataArray and it has multiple values. "
"Please open an issue."
)
else:
raise NotImplementedError(
"Not sure what to return when given scalar key for DataArray and it has multiple values. "
"Please open an issue."
)

return ds.set_coords(coords)

except KeyError:
raise KeyError(
f"{kind}.cf does not understand the key {k!r}. "
f"Use 'repr({kind}.cf)' (or '{kind}.cf' in a Jupyter environment) to see a list of key names that can be interpreted."
)
) from None


def _possible_x_y_plot(obj, key, skip=None):
Expand Down Expand Up @@ -1135,7 +1136,7 @@ def _assert_valid_other_comparison(self, other):
)
return flag_dict

def __eq__(self, other) -> DataArray: # type: ignore
def __eq__(self, other) -> DataArray: # type: ignore[override]
"""
Compare flag values against ``other``.
Expand All @@ -1155,7 +1156,7 @@ def __eq__(self, other) -> DataArray: # type: ignore
"""
return self._extract_flags([other])[other].rename(self._obj.name)

def __ne__(self, other) -> DataArray: # type: ignore
def __ne__(self, other) -> DataArray: # type: ignore[override]
"""
Compare flag values against ``other``.
Expand Down Expand Up @@ -1328,7 +1329,7 @@ def curvefit(
coords_iter = coords
coords = [
apply_mapper(
[_single(_get_coords)], self._obj, v, error=False, default=[v] # type: ignore
[_single(_get_coords)], self._obj, v, error=False, default=[v] # type: ignore[arg-type]
)[0]
for v in coords_iter
]
Expand All @@ -1339,7 +1340,7 @@ def curvefit(
reduce_dims_iter = list(reduce_dims)
reduce_dims = [
apply_mapper(
[_single(_get_dims)], self._obj, v, error=False, default=[v] # type: ignore
[_single(_get_dims)], self._obj, v, error=False, default=[v] # type: ignore[arg-type]
)[0]
for v in reduce_dims_iter
]
Expand Down Expand Up @@ -1435,7 +1436,7 @@ def _rewrite_values(

# allow multiple return values here.
# these are valid for .sel, .isel, .coarsen
all_mappers = ChainMap( # type: ignore
all_mappers = ChainMap( # type: ignore[misc]
key_mappers,
dict.fromkeys(var_kws, (_get_all,)),
)
Expand Down Expand Up @@ -1531,7 +1532,7 @@ def describe(self):
Print a string repr to screen.
"""

warnings.warn(
emit_user_level_warning(
"'obj.cf.describe()' will be removed in a future version. "
"Use instead 'repr(obj.cf)' or 'obj.cf' in a Jupyter environment.",
DeprecationWarning,
Expand Down Expand Up @@ -1695,10 +1696,9 @@ def cell_measures(self) -> dict[str, list[Hashable]]:
bad_vars = list(
as_dataset.filter_by_attrs(cell_measures=attr).data_vars.keys()
)
warnings.warn(
emit_user_level_warning(
f"Ignoring bad cell_measures attribute: {attr} on {bad_vars}.",
UserWarning,
stacklevel=2,
)
measures = {
key: self._drop_missing_variables(_get_all(self._obj, key)) for key in keys
Expand Down Expand Up @@ -1816,9 +1816,9 @@ def get_associated_variable_names(
except ValueError as e:
if error:
msg = e.args[0] + " Ignore this error by passing 'error=False'"
raise ValueError(msg)
raise ValueError(msg) from None
else:
warnings.warn(
emit_user_level_warning(
f"Ignoring bad cell_measures attribute: {attrs_or_encoding['cell_measures']}",
UserWarning,
)
Expand Down Expand Up @@ -1850,7 +1850,7 @@ def get_associated_variable_names(
missing = set(allvars) - set(self._maybe_to_dataset()._variables)
if missing:
if OPTIONS["warn_on_missing_variables"]:
warnings.warn(
emit_user_level_warning(
f"Variables {missing!r} not found in object but are referred to in the CF attributes.",
UserWarning,
)
Expand Down Expand Up @@ -1963,7 +1963,7 @@ def get_renamer_and_conflicts(keydict):

# Rename and warn
if conflicts:
warnings.warn(
emit_user_level_warning(
"Conflicting variables skipped:\n"
+ "\n".join(
[
Expand Down Expand Up @@ -2684,10 +2684,12 @@ def decode_vertical_coords(self, *, outnames=None, prefix=None):
try:
zname = outnames[dim]
except KeyError:
raise KeyError("Your `outnames` need to include a key of `dim`.")
raise KeyError(
"Your `outnames` need to include a key of `dim`."
) from None

else:
warnings.warn(
emit_user_level_warning(
"`prefix` is being deprecated; use `outnames` instead.",
DeprecationWarning,
)
Expand Down
4 changes: 2 additions & 2 deletions cf_xarray/criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
try:
import regex as re
except ImportError:
import re # type: ignore
import re # type: ignore[no-redef]

from collections.abc import Mapping, MutableMapping
from typing import Any
Expand Down Expand Up @@ -128,7 +128,7 @@
coordinate_criteria["time"] = coordinate_criteria["T"]

# "long_name" and "standard_name" criteria are the same. For convenience.
for coord, attrs in coordinate_criteria.items():
for coord in coordinate_criteria:
coordinate_criteria[coord]["long_name"] = coordinate_criteria[coord][
"standard_name"
]
Expand Down
24 changes: 9 additions & 15 deletions cf_xarray/formatting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import warnings
from collections.abc import Hashable, Iterable
from functools import partial
Expand All @@ -10,7 +12,7 @@
try:
from rich.table import Table
except ImportError:
Table = None # type: ignore
Table = None # type: ignore[assignment, misc]


def _format_missing_row(row: str, rich: bool) -> str:
Expand Down Expand Up @@ -41,7 +43,7 @@ def _format_cf_name(name: str, rich: bool) -> str:
def make_text_section(
accessor,
subtitle: str,
attr: str,
attr: str | dict,
dims=None,
valid_keys=None,
valid_values=None,
Expand Down Expand Up @@ -140,10 +142,10 @@ def _maybe_panel(textgen, title: str, rich: bool):
width=100,
)
if isinstance(textgen, Table):
return Panel(textgen, padding=(0, 20), **kwargs) # type: ignore
return Panel(textgen, padding=(0, 20), **kwargs) # type: ignore[arg-type]
else:
text = "".join(textgen)
return Panel(f"[color(241)]{text.rstrip()}[/color(241)]", **kwargs) # type: ignore
return Panel(f"[color(241)]{text.rstrip()}[/color(241)]", **kwargs) # type: ignore[arg-type]
else:
text = "".join(textgen)
return title + ":\n" + text
Expand Down Expand Up @@ -220,22 +222,14 @@ def _format_flags(accessor, rich):
table.add_column("Value", justify="right")
table.add_column("Bits", justify="center")

for val, bit, (key, (mask, value)) in zip(
value_text, bit_text, flag_dict.items()
):
table.add_row(
_format_cf_name(key, rich),
val,
bit,
)
for val, bit, key in zip(value_text, bit_text, flag_dict):
table.add_row(_format_cf_name(key, rich), val, bit)

return table

else:
rows = []
for val, bit, (key, (mask, value)) in zip(
value_text, bit_text, flag_dict.items()
):
for val, bit, key in zip(value_text, bit_text, flag_dict):
rows.append(f"{TAB}{_format_cf_name(key, rich)}: {TAB} {val} {bit}")
return _print_rows("Flag Meanings", rows, rich)

Expand Down
4 changes: 2 additions & 2 deletions cf_xarray/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class set_options: # numpydoc ignore=PR01,PR02

def __init__(self, **kwargs):
self.old = {}
for k, v in kwargs.items():
for k in kwargs:
if k not in OPTIONS:
raise ValueError(
f"argument name {k!r} is not in the set of valid options {set(OPTIONS)!r}"
Expand All @@ -58,7 +58,7 @@ def __init__(self, **kwargs):

def _apply_update(self, options_dict):
options_dict = copy.deepcopy(options_dict)
for k, v in options_dict.items():
for k in options_dict:
if k == "custom_criteria":
options_dict["custom_criteria"] = always_iterable(
options_dict["custom_criteria"], allowed=(tuple, list)
Expand Down
Loading

0 comments on commit f573ed7

Please sign in to comment.