Skip to content

Commit

Permalink
Old CDS-ADS fashion constraints (#72)
Browse files Browse the repository at this point in the history
* Use old CDS-ADS fashion constraints
* Add the DateTimeRange dependency
  • Loading branch information
ecmwf-cobarzan authored Dec 12, 2023
1 parent 1508b4e commit 9fb29b9
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .cruft.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"template": "https://github.com/ecmwf-projects/cookiecutter-conda-package",
"commit": "ce9afbb8510935c0206746d26f05f6b80e9d0087",
"commit": "c6665306749b5dd3b4ec0fdcf1cb31d18fe23511",
"checkout": null,
"context": {
"cookiecutter": {
Expand Down
186 changes: 174 additions & 12 deletions cads_adaptors/constraints.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""Main module of the request-constraints API."""

import copy
import re
from typing import Any

from datetimerange import DateTimeRange

from . import translators


Expand Down Expand Up @@ -101,19 +105,20 @@ def apply_constraints(
:return: a dictionary containing all values that should be left
active for selection, in JSON format
"""
always_valid = get_always_valid_params(form, constraints)
constraint_keys = get_keys(constraints)
always_valid = get_always_valid_params(form, constraint_keys)

form = copy.deepcopy(form)
selection = copy.deepcopy(selection)
for key, value in form.copy().items():
if key not in get_keys(constraints):
for key in form.copy():
if key not in constraint_keys:
form.pop(key, None)
selection.pop(key, None)

result = get_form_state(form, selection, constraints)
result.update(always_valid)
result = apply_constraints_in_old_cds_fashion(form, selection, constraints)
result.update(format_to_json(always_valid))

return format_to_json(result)
return result


def get_possible_values(
Expand Down Expand Up @@ -185,6 +190,126 @@ def get_possible_values(
return result


def apply_constraints_in_old_cds_fashion(
form: dict[str, set[Any]],
selection: dict[str, set[Any]],
constraints: list[dict[str, set[Any]]],
) -> dict[str, list[Any]]:
result: dict[str, set[Any]] = {}

# if the selection is empty, return the entire form
if len(selection) == 0:
return format_to_json(form)

for constraint in constraints:
# the per-selected-widget result is the union of:
# - all constraints containing the selected widget with at least one
# value/option in common with the selected values/options (Category 1)
# - all constraints NOT containing the selected widget (Category 2)

# loop over the widgets in the selection
# as a general rule, a widget cannot decide for itself (but only for others)
# only other widgets can enable/disable options/values in the "current" widget
per_constraint_result: dict[str, dict[str, set[Any]]] = {}
for selected_widget_name, selected_widget_options in selection.items():
if selected_widget_name in constraint:
constraint_is_intersected = False
if selected_widget_name == "date_range":
assert (
len(selected_widget_options) == 1
), "More than one selected date range!"
selected_range = gen_time_range_from_string(
next(iter(selected_widget_options))
)
valid_ranges = [
gen_time_range_from_string(valid_range)
for valid_range in constraint[selected_widget_name]
]
if temporal_intersection_between(selected_range, valid_ranges):
constraint_is_intersected = True
else:
constraint_selection_intersection = (
selected_widget_options & constraint[selected_widget_name]
)
if len(constraint_selection_intersection):
constraint_is_intersected = True
if constraint_is_intersected:
# factoring in Category 1 constraints
if selected_widget_name not in per_constraint_result:
per_constraint_result[selected_widget_name] = {}
for widget_name in form:
if widget_name != selected_widget_name:
per_constraint_result[selected_widget_name][
widget_name
] = set()
for widget_name, widget_options in constraint.items():
if widget_name != selected_widget_name:
if (
widget_name
in per_constraint_result[selected_widget_name]
):
per_constraint_result[selected_widget_name][
widget_name
] |= set(widget_options)
else:
per_constraint_result[selected_widget_name][
widget_name
] = set(widget_options)
else:
# factoring in Category 2 constraints
if selected_widget_name not in per_constraint_result:
per_constraint_result[selected_widget_name] = {}
for widget_name in form:
if widget_name != selected_widget_name:
per_constraint_result[selected_widget_name][
widget_name
] = set()
for widget_name, widget_options in constraint.items():
if widget_name in per_constraint_result[selected_widget_name]:
per_constraint_result[selected_widget_name][widget_name] |= set(
widget_options
)
else:
per_constraint_result[selected_widget_name][widget_name] = set(
widget_options
)

for widget_name in form:
per_constraint_result_agg: set[Any] = set()
for selected_widget_name in selection:
if widget_name != selected_widget_name:
if selected_widget_name in per_constraint_result:
if per_constraint_result_agg:
per_constraint_result_agg &= per_constraint_result[
selected_widget_name
][widget_name]
else:
per_constraint_result_agg = per_constraint_result[
selected_widget_name
][widget_name]
else:
per_constraint_result_agg = set()
break
if widget_name in result:
result[widget_name] |= per_constraint_result_agg
else:
result[widget_name] = per_constraint_result_agg

for widget_name in form:
if widget_name not in result:
result[widget_name] = set()

# as a general rule, a widget cannot decide for itself (but only for others)
# only other widgets can enable/disable options/values in the "current" widget
# when the selection contains only one widget, we need to enable all options for that widget
# (as an exception from the general rule)
if len(selection) == 1:
only_widget_in_selection = next(iter(selection))
result[only_widget_in_selection] = form[only_widget_in_selection]

return format_to_json(result)


def format_to_json(result: dict[str, set[Any]]) -> dict[str, list[Any]]:
"""
Convert dict[str, set[Any]] into dict[str, list[Any]].
Expand Down Expand Up @@ -253,7 +378,7 @@ def get_form_state(

def get_always_valid_params(
form: dict[str, set[Any]],
constraints: list[dict[str, set[Any]]],
constraint_keys: set[str],
) -> dict[str, set[Any]]:
"""
Get always valid field and values.
Expand All @@ -267,8 +392,7 @@ def get_always_valid_params(
}
:type: dict[str, set[Any]]:
:param constraints: a list of dictionaries representing
all constraints for a specific dataset
:param constraint_keys: a set of strings representing all constraints keys for a specific dataset
e.g. constraints = [
{"level": {"500"}, "param": {"Z", "T"}, "step": {"24", "36", "48"}},
{"level": {"1000"}, "param": {"Z"}, "step": {"24", "48"}},
Expand All @@ -282,7 +406,7 @@ def get_always_valid_params(
"""
result: dict[str, set[Any]] = {}
for field_name, field_values in form.items():
if field_name not in get_keys(constraints):
if field_name not in constraint_keys:
result.setdefault(field_name, field_values)
return result

Expand All @@ -308,8 +432,24 @@ def parse_form(raw_form: list[Any] | dict[str, Any] | None) -> dict[str, set[Any
ogc_form[field_name]["schema_"]["items"]["enum"]
)
else:
# FIXME: temporarely fix for making constraints working from UI
form[field_name] = [] # type: ignore
handled = False
if ogc_form[field_name]["schema_"].get("default", None):
if ogc_form[field_name]["schema_"]["default"].get(
"defaultStart", None
) and ogc_form[field_name]["schema_"]["default"].get(
"defaultEnd", None
):
defaultStart = ogc_form[field_name]["schema_"]["default"][
"defaultStart"
]
defaultEnd = ogc_form[field_name]["schema_"]["default"][
"defaultEnd"
]
form[field_name] = set([f"{defaultStart}/{defaultEnd}"])
handled = True
if not handled:
# FIXME: temporarely fix for making constraints working from UI
form[field_name] = set() # type: ignore
else:
form[field_name] = set(ogc_form[field_name]["schema_"]["enum"])
except KeyError:
Expand All @@ -336,3 +476,25 @@ def get_keys(constraints: list[dict[str, Any]]) -> set[str]:
for constraint in constraints:
keys |= set(constraint.keys())
return keys


def temporal_intersection_between(
selected: DateTimeRange, ranges: list[DateTimeRange]
) -> bool:
for valid in ranges:
if selected.intersection(valid).is_valid_timerange():
return True
return False


def gen_time_range_from_string(string: str) -> DateTimeRange:
dates = re.split("[;/]", string)
if len(dates) == 1:
dates *= 2
time_range = DateTimeRange(dates[0], dates[1])
time_range.start_time_format = "%Y-%m-%d"
time_range.end_time_format = "%Y-%m-%d"
if time_range.is_valid_timerange():
return time_range
else:
raise ValueError("Start date must be before end date")
3 changes: 2 additions & 1 deletion cads_adaptors/costing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def estimate_granules(
_constraints: list[dict[str, set[str]]],
safe: bool = True,
) -> int:
always_valid = constraints.get_always_valid_params(form, _constraints)
constraint_keys = constraints.get_keys(_constraints)
always_valid = constraints.get_always_valid_params(form, constraint_keys)
selected_but_always_valid = {
k: v for k, v in selection.items() if k in always_valid.keys()
}
Expand Down
15 changes: 15 additions & 0 deletions cads_adaptors/translators.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,26 @@ def translate_geographic_extent_map(input_cds_schema: dict[str, Any]) -> dict[st
return input_ogc_schema


def translate_date_range(input_cds_schema: dict[str, Any]) -> dict[str, Any]:
input_ogc_schema = {
"type": "array",
"minItems": 2,
"maxItems": 2,
"items": {"type": "string"},
"default": {
"defaultStart": input_cds_schema["details"].get("defaultStart", None),
"defaultEnd": input_cds_schema["details"].get("defaultEnd", None),
},
}
return input_ogc_schema


SCHEMA_TRANSLATORS = {
"StringListWidget": translate_string_list,
"StringListArrayWidget": translate_string_list_array,
"StringChoiceWidget": translate_string_choice,
"GeographicExtentWidget": translate_geographic_extent_map,
"DateRangeWidget": translate_date_range,
}


Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ dependencies:
- wget
- multiurl>=0.2.3.2
- pyyaml
- DateTimeRange
- pip:
- rooki
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ dependencies = [
"multiurl>=0.2.3.2",
"rooki",
"python-dateutil",
"pyyaml"
"pyyaml",
"DateTimeRange"
]
description = "CADS data retrieve utilities to be used by adaptors"
dynamic = ["version"]
Expand Down

0 comments on commit 9fb29b9

Please sign in to comment.