Skip to content

Commit

Permalink
Add email in request_metadata and in QoS rules, add request origin in…
Browse files Browse the repository at this point in the history
… dynamic_priority (#144)

* introduce origin in user_resource_used

* add priority info

* add metadata_contains_all and metadata_contains_any

* add get and user_data functions

* add default to get function

* do not crash if match crashes

* fix

* rename adaptor to entry_point in TestRequest class

* qa
  • Loading branch information
francesconazzaro authored Jan 8, 2025
1 parent 3869ac6 commit 5c3ce08
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 44 deletions.
10 changes: 9 additions & 1 deletion cads_broker/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ def get_users_queue_from_processing_time(
session: sa.orm.Session,
interval_stop: datetime.datetime,
interval: datetime.timedelta = datetime.timedelta(hours=24),
origin: str | None = None,
) -> dict[str, int]:
"""Build the queue of the users from the processing time."""
interval_start = interval_stop - interval
Expand All @@ -474,6 +475,8 @@ def get_users_queue_from_processing_time(
SystemRequest.started_at.is_not(None),
)
where_clause = sa.sql.or_(interval_clause, SystemRequest.status == "running")
if origin:
where_clause = sa.sql.and_(where_clause, SystemRequest.origin == origin)

statement = (
sa.sql.select(SystemRequest.user_uid, user_cost)
Expand All @@ -482,7 +485,7 @@ def get_users_queue_from_processing_time(
.order_by("user_cost")
)

running_user_costs = dict(session.execute(statement).all())
running_user_costs: dict[str, int] = dict(session.execute(statement).all())

queue_users = session.execute(
sa.select(SystemRequest.user_uid)
Expand Down Expand Up @@ -540,6 +543,7 @@ def user_resource_used(
user_uid: str,
session: sa.orm.Session,
interval: int,
origin: str | None = None,
) -> int:
"""Return the amount of resource used by a user."""
global QOS_FUNCTIONS_CACHE
Expand All @@ -550,6 +554,7 @@ def user_resource_used(
session=session,
interval_stop=datetime.datetime.now(),
interval=datetime.timedelta(hours=interval / 60 / 60),
origin=origin,
)
QOS_FUNCTIONS_CACHE["users_resources"] = users_resources

Expand Down Expand Up @@ -718,6 +723,7 @@ def set_request_status(
error_message: str | None = None,
error_reason: str | None = None,
resubmit: bool | None = None,
priority: float | None = None,
) -> SystemRequest:
"""Set the status of a request."""
statement = sa.select(SystemRequest).where(SystemRequest.request_uid == request_uid)
Expand All @@ -730,6 +736,8 @@ def set_request_status(
{"resubmit_number": request.request_metadata.get("resubmit_number", 0) + 1}
)
request.request_metadata = metadata
if priority is not None:
request.request_metadata["priority"] = priority
if status == "successful":
request.finished_at = sa.func.now()
elif status == "failed":
Expand Down
5 changes: 4 additions & 1 deletion cads_broker/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,10 @@ def submit_request(
) -> None:
"""Submit the request to the dask scheduler and update the qos rules accordingly."""
request = db.set_request_status(
request_uid=request.request_uid, status="running", session=session
request_uid=request.request_uid,
status="running",
priority=priority,
session=session,
)
self.qos.notify_start_of_request(request, scheduler=self.internal_scheduler)
self.queue.pop(request.request_uid)
Expand Down
45 changes: 38 additions & 7 deletions cads_broker/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,42 @@ def tagged(context, value):
return True


def request_contains_all(context, key, values):
request_values = context.request.request_body.get("request").get(key)
def contains_all(request_values, values):
    """Return True when every element of ``values`` occurs in ``request_values``.

    ``request_values`` that is not a list or tuple (a scalar, or ``None``) is
    treated as a one-element collection. Both arguments are converted with
    ``set()``, so their elements must be hashable, duplicates are ignored,
    and a bare string passed as ``values`` is read as its individual
    characters. An empty ``values`` always yields True.
    """
    if not isinstance(request_values, (list, tuple)):
        request_values = [request_values]
    available = set(request_values)
    required = set(values)
    # Subset test: equivalent to len(available & required) == len(required).
    return required <= available


def request_contains_any(context, key, values):
request_values = context.request.request_body.get("request").get(key)
def contains_any(request_values, values):
    """Return True when ``request_values`` and ``values`` share an element.

    ``request_values`` that is not a list or tuple (a scalar, or ``None``) is
    treated as a one-element collection. Both arguments are converted with
    ``set()``, so their elements must be hashable, and a bare string passed
    as ``values`` is read as its individual characters. An empty ``values``
    always yields False.
    """
    if not isinstance(request_values, (list, tuple)):
        request_values = [request_values]
    available = set(request_values)
    candidates = set(values)
    # Non-empty intersection: equivalent to len(available & candidates) > 0.
    return not available.isdisjoint(candidates)


def request_contains_all(context, key, values):
    """Check that the submitted request's ``key`` entry holds all of ``values``.

    The entry is looked up under the ``"request"`` mapping of
    ``context.request.request_body``; the comparison itself is delegated to
    ``contains_all``.
    """
    submitted = context.request.request_body.get("request")
    return contains_all(submitted.get(key), values)


def request_contains_any(context, key, values):
    """Check that the submitted request's ``key`` entry holds any of ``values``.

    The entry is looked up under the ``"request"`` mapping of
    ``context.request.request_body``; the comparison itself is delegated to
    ``contains_any``.

    The signature is ``(context, key, values)``, the same as
    ``request_contains_all``. A stray extra positional parameter here would
    shift the arguments of a ``(key, values)`` call and make it fail with a
    missing-argument ``TypeError``.
    """
    request_values = context.request.request_body.get("request").get(key)
    return contains_any(request_values, values)


def metadata_contains_all(context, key, values):
    """Check that the request metadata's ``key`` entry holds all of ``values``.

    The entry is read from ``context.request.request_metadata``; the
    comparison itself is delegated to ``contains_all``.
    """
    metadata = context.request.request_metadata
    return contains_all(metadata.get(key), values)


def metadata_contains_any(context, key, values):
    """Check that the request metadata's ``key`` entry holds any of ``values``.

    The entry is read from ``context.request.request_metadata``; the
    comparison itself is delegated to ``contains_any``.
    """
    metadata = context.request.request_metadata
    return contains_any(metadata.get(key), values)


def register_functions():
expressions.FunctionFactory.FunctionFactory.register_function(
"dataset",
Expand Down Expand Up @@ -68,9 +86,10 @@ def register_functions():
)
expressions.FunctionFactory.FunctionFactory.register_function(
"user_resource_used",
lambda context, interval=24 * 60 * 60: database.user_resource_used(
lambda context, interval=24 * 60 * 60, origin=None: database.user_resource_used(
user_uid=context.request.user_uid,
interval=interval,
origin=origin,
session=context.environment.session,
),
)
Expand All @@ -83,8 +102,14 @@ def register_functions():
),
)
expressions.FunctionFactory.FunctionFactory.register_function(
"request_age",
lambda context: context.request.age
"request_age", lambda context: context.request.age
)
expressions.FunctionFactory.FunctionFactory.register_function(
"user_data",
lambda context: context.request.request_metadata.get("user_data", {}),
)
expressions.FunctionFactory.FunctionFactory.register_function(
"get", lambda context, object, key, default=None: object.get(key, default)
)

expressions.FunctionFactory.FunctionFactory.register_function("tagged", tagged)
Expand All @@ -94,3 +119,9 @@ def register_functions():
expressions.FunctionFactory.FunctionFactory.register_function(
"request_contains_any", request_contains_any
)
expressions.FunctionFactory.FunctionFactory.register_function(
"metadata_contains_all", metadata_contains_all
)
expressions.FunctionFactory.FunctionFactory.register_function(
"metadata_contains_any", metadata_contains_any
)
5 changes: 2 additions & 3 deletions cads_broker/qos/QoS.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# nor does it submit to any jurisdiction.
#

import collections
import threading
from functools import wraps

Expand Down Expand Up @@ -180,9 +179,9 @@ def dynamic_priority(self, request):
def priority(self, request):
"""Compute the priority of a request."""
# The priority of a request increases with time
return self._properties(
return self._properties(request).starting_priority + self.dynamic_priority(
request
).starting_priority + self.dynamic_priority(request)
)

def dump(self, out=print):
self.rules.dump(out)
Expand Down
14 changes: 12 additions & 2 deletions cads_broker/qos/Rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,15 @@ def evaluate(self, request):
return self.conclusion.evaluate(Context(request, self.environment))

def match(self, request):
return self.condition.evaluate(Context(request, self.environment))
try:
ret_value = self.condition.evaluate(Context(request, self.environment))
except Exception as e:
print(
f"Error evaluating condition {self.condition} for request {request.request_uid}"
)
print(e)
return False
return ret_value

def dump(self, out):
out(self)
Expand Down Expand Up @@ -191,7 +199,9 @@ def add_priority(self, environment, info, condition, conclusion):
self.priorities.append(Priority(environment, info, condition, conclusion))

def add_dynamic_priority(self, environment, info, condition, conclusion):
self.dynamic_priorities.append(DynamicPriority(environment, info, condition, conclusion))
self.dynamic_priorities.append(
DynamicPriority(environment, info, condition, conclusion)
)

def add_permission(self, environment, info, condition, conclusion):
self.permissions.append(Permission(environment, info, condition, conclusion))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_01_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class TestRequest:
user_uid = "david"
dataset = "dataset-1"
adaptor = "adaptor1"
entry_point = "adaptor1"
cost = (1024 * 1024, 60 * 60 * 24)


Expand Down
36 changes: 26 additions & 10 deletions tests/test_02_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def mock_system_request(
request_metadata: dict | None = None,
adaptor_properties_hash: str = "adaptor_properties_hash",
entry_point: str = "entry_point",
origin: str = "api",
) -> db.SystemRequest:
system_request = db.SystemRequest(
request_uid=request_uid or str(uuid.uuid4()),
Expand All @@ -62,6 +63,7 @@ def mock_system_request(
request_metadata=request_metadata or {},
adaptor_properties_hash=adaptor_properties_hash,
entry_point=entry_point,
origin=origin,
)
return system_request

Expand Down Expand Up @@ -707,50 +709,58 @@ def test_get_users_queue_from_processing_time(session_obj: sa.orm.sessionmaker)
status="successful",
adaptor_properties_hash=adaptor_properties.hash,
user_uid="user1",
origin="api",
started_at=datetime.datetime.now() - datetime.timedelta(hours=10),
finished_at=datetime.datetime.now() - datetime.timedelta(hours=5),
)
request_2 = mock_system_request(
status="successful",
adaptor_properties_hash=adaptor_properties.hash,
user_uid="user1",
origin="ui",
started_at=datetime.datetime.now() - datetime.timedelta(hours=20),
finished_at=datetime.datetime.now() - datetime.timedelta(hours=10),
)
request_3 = mock_system_request(
status="successful",
adaptor_properties_hash=adaptor_properties.hash,
user_uid="user2",
origin="api",
started_at=datetime.datetime.now() - datetime.timedelta(hours=20),
finished_at=datetime.datetime.now() - datetime.timedelta(hours=10),
)
request_4 = mock_system_request(
status="running",
adaptor_properties_hash=adaptor_properties.hash,
user_uid="user2",
origin="api",
started_at=datetime.datetime.now() - datetime.timedelta(hours=20),
)
request_5 = mock_system_request(
status="accepted",
adaptor_properties_hash=adaptor_properties.hash,
user_uid="user2",
origin="api",
)
request_6 = mock_system_request(
status="accepted",
adaptor_properties_hash=adaptor_properties.hash,
user_uid="user3",
origin="api",
)
request_7 = mock_system_request(
status="failed",
adaptor_properties_hash=adaptor_properties.hash,
user_uid="user3",
origin="api",
started_at=None,
finished_at=datetime.datetime.now() - datetime.timedelta(hours=10),
)
request_8 = mock_system_request(
status="deleted",
adaptor_properties_hash=adaptor_properties.hash,
user_uid="user2",
origin="api",
started_at=datetime.datetime.now() - datetime.timedelta(hours=15),
finished_at=datetime.datetime.now() - datetime.timedelta(hours=10),
)
Expand All @@ -769,9 +779,17 @@ def test_get_users_queue_from_processing_time(session_obj: sa.orm.sessionmaker)
users_cost = db.get_users_queue_from_processing_time(
session, interval_stop=datetime.datetime.now()
)
assert users_cost["user3"] == 0
assert users_cost["user1"] == 15 * 60 * 60
assert users_cost["user2"] == (10 + 20 + 5) * 60 * 60
users_cost_api = db.get_users_queue_from_processing_time(
session, interval_stop=datetime.datetime.now(), origin="api"
)
users_cost_ui = db.get_users_queue_from_processing_time(
session, interval_stop=datetime.datetime.now(), origin="ui"
)
assert users_cost["user3"] == users_cost_api["user3"] == users_cost_ui["user3"] == 0
assert users_cost["user1"] == (5 + 10) * 60 * 60
assert users_cost["user2"] == users_cost_api["user2"] == (10 + 20 + 5) * 60 * 60
assert users_cost_api["user1"] == 5 * 60 * 60
assert users_cost_ui["user1"] == 10 * 60 * 60


def test_users_last_finished_at(session_obj: sa.orm.sessionmaker) -> None:
Expand Down Expand Up @@ -807,7 +825,9 @@ def test_users_last_finished_at(session_obj: sa.orm.sessionmaker) -> None:
session.add(request_2)
session.add(request_3)
session.commit()
users_last_finished_at = db.users_last_finished_at(session=session, after=now - datetime.timedelta(hours=24))
users_last_finished_at = db.users_last_finished_at(
session=session, after=now - datetime.timedelta(hours=24)
)
assert finished_at == users_last_finished_at["user1"]
assert "user2" not in users_last_finished_at

Expand Down Expand Up @@ -843,19 +863,15 @@ def test_user_last_completed_request(session_obj: sa.orm.sessionmaker) -> None:
session.add(request_1)
session.add(request_3)
session.commit()
assert (
now - finished_at
).seconds == db.user_last_completed_request(
assert (now - finished_at).seconds == db.user_last_completed_request(
session=session, user_uid="user1", interval=60 * 60 * 24
)
assert 60 * 60 * 24 == db.user_last_completed_request(
session=session, user_uid="user2", interval=60 * 60 * 24
)
session.add(request_2)
session.commit()
assert (
now - finished_at
).seconds == db.user_last_completed_request(
assert (now - finished_at).seconds == db.user_last_completed_request(
session=session, user_uid="user1", interval=60 * 60 * 24
)
# invalidate cache
Expand Down
Loading

0 comments on commit 5c3ce08

Please sign in to comment.