Skip to content

Commit

Permalink
Merge pull request #178 from ecmwf-projects/COPDS-1563-cost-based-pro…
Browse files Browse the repository at this point in the history
…cess-blocking

Implement cost-based request blocking
  • Loading branch information
mcucchi9 authored Mar 6, 2024
2 parents 3466ff3 + 0b84468 commit 7376895
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 24 deletions.
8 changes: 6 additions & 2 deletions cads_processing_api_service/adaptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,13 @@ def make_system_job_kwargs(


def instantiate_adaptor(
dataset: cads_catalogue.database.Resource,
dataset: cads_catalogue.database.Resource | None = None,
adaptor_properties: dict[str, Any] | None = None,
) -> cads_adaptors.AbstractAdaptor:
adaptor_properties = get_adaptor_properties(dataset)
if not adaptor_properties:
if dataset is None:
raise ValueError("Either adaptor_properties or dataset must be provided")
adaptor_properties = get_adaptor_properties(dataset)
adaptor_class = cads_adaptors.get_adaptor_class(
entry_point=adaptor_properties["entry_point"],
setup_code=adaptor_properties["setup_code"],
Expand Down
31 changes: 30 additions & 1 deletion cads_processing_api_service/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import fastapi
import requests

from . import config, exceptions
from . import config, costing, exceptions

VERIFICATION_ENDPOINT = {
"PRIVATE-TOKEN": "/account/verification/pat",
Expand Down Expand Up @@ -267,3 +267,32 @@ def verify_if_disabled(disabled_reason: str | None, user_role: str | None) -> No
)
else:
return


def verify_cost(request: dict[str, Any], adaptor_properties: dict[str, Any]) -> None:
"""Verify if the cost of a process execution request is within the allowed limits.
Parameters
----------
request : dict[str, Any]
Process execution request.
adaptor_properties : dict[str, Any]
Adaptor properties.
Raises
------
exceptions.PermissionDenied
Raised if the cost of the process execution request exceeds the allowed limits.
"""
costing_info = costing.compute_costing(request, adaptor_properties)
max_costs_exceeded = costing_info.max_costs_exceeded
if max_costs_exceeded:
raise exceptions.PermissionDenied(
title="cost limits exceeded",
detail=(
"the cost of the submitted request exceeds the allowed limits; "
f"the following limits have been exceeded: {max_costs_exceeded}"
),
)
else:
return
20 changes: 11 additions & 9 deletions cads_processing_api_service/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,25 +194,27 @@ def post_process_execution(
user_uid, user_role = auth.authenticate_user(auth_header, portal_header)
structlog.contextvars.bind_contextvars(user_uid=user_uid)
accepted_licences = auth.get_accepted_licences(auth_header)
execution_content = execution_content.model_dump()
request = execution_content.model_dump()
catalogue_sessionmaker = db_utils.get_catalogue_sessionmaker(
db_utils.ConnectionMode.read
)
with catalogue_sessionmaker() as catalogue_session:
resource: cads_catalogue.database.Resource = utils.lookup_resource_by_id(
dataset: cads_catalogue.database.Resource = utils.lookup_resource_by_id(
resource_id=process_id,
table=self.process_table,
session=catalogue_session,
load_messages=True,
)
auth.verify_if_disabled(resource.disabled_reason, user_role)
adaptor = adaptors.instantiate_adaptor(resource)
licences = adaptor.get_licences(execution_content)
auth.verify_if_disabled(dataset.disabled_reason, user_role)
adaptor_properties = adaptors.get_adaptor_properties(dataset)
auth.verify_cost(request, adaptor_properties)
adaptor = adaptors.instantiate_adaptor(adaptor_properties=adaptor_properties)
licences = adaptor.get_licences(request)
auth.validate_licences(accepted_licences, licences)
job_id = str(uuid.uuid4())
structlog.contextvars.bind_contextvars(job_id=job_id)
job_kwargs = adaptors.make_system_job_kwargs(
resource, execution_content, adaptor.resources
dataset, request, adaptor.resources
)
compute_sessionmaker = db_utils.get_compute_sessionmaker(
mode=db_utils.ConnectionMode.write
Expand All @@ -224,8 +226,8 @@ def post_process_execution(
origin=auth.REQUEST_ORIGIN[auth_header[0]],
user_uid=user_uid,
process_id=process_id,
portal=resource.portal,
qos_tags=resource.qos_tags,
portal=dataset.portal,
qos_tags=dataset.qos_tags,
**job_kwargs,
)
dataset_messages = [
Expand All @@ -234,7 +236,7 @@ def post_process_execution(
severity=message.severity,
content=message.content,
)
for message in resource.messages
for message in dataset.messages
]
status_info = utils.make_status_info(
job, dataset_metadata={"messages": dataset_messages}
Expand Down
46 changes: 35 additions & 11 deletions cads_processing_api_service/costing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
"""Requests' cost estimation endpoint."""

# Copyright 2022, European Union.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

from typing import Any

import cads_adaptors
import cads_adaptors.constraints
import cads_catalogue
import fastapi
import ogc_api_processes_fastapi.models

from . import adaptors, db_utils, models, utils
from . import adaptors, costing, db_utils, models, utils


def estimate_costs(
Expand All @@ -21,18 +36,27 @@ def estimate_costs(
dataset = utils.lookup_resource_by_id(
resource_id=process_id, table=table, session=catalogue_session
)
adaptor_configuration: dict[
str, Any
] = dataset.resource_data.adaptor_configuration # type: ignore
costing_config: dict[str, Any] = adaptor_configuration.get("costing", {})
adaptor_properties = adaptors.get_adaptor_properties(dataset)
costing_info = costing.compute_costing(request.model_dump(), adaptor_properties)
return costing_info


def compute_costing(
request: dict[str, Any],
adaptor_properties: dict[str, Any],
) -> models.Costing:
adaptor: cads_adaptors.AbstractAdaptor = adaptors.instantiate_adaptor(
adaptor_properties=adaptor_properties
)
costs: dict[str, float] = adaptor.estimate_costs(request=request)
costing_config: dict[str, Any] = adaptor_properties["config"].get("costing", {})
max_costs: dict[str, Any] = costing_config.get("max_costs", {})
adaptor: cads_adaptors.AbstractAdaptor = adaptors.instantiate_adaptor(dataset)
costs: dict[str, float] = adaptor.estimate_costs(request=request.model_dump())
max_costs_exceeded = {}
for max_cost_id, max_cost_value in max_costs.items():
if max_cost_id in costs.keys():
if costs[max_cost_id] > max_cost_value:
max_costs_exceeded[max_cost_id] = max_cost_value
costing = models.Costing(costs=costs, max_costs_exceeded=max_costs_exceeded)

return costing
costing_info = models.Costing(
costs=costs, max_costs=max_costs, max_costs_exceeded=max_costs_exceeded
)
return costing_info
1 change: 1 addition & 0 deletions cads_processing_api_service/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,5 @@ class Exception(ogc_api_processes_fastapi.models.Exception):

class Costing(pydantic.BaseModel):
costs: dict[str, float] | None = None
max_costs: dict[str, float] | None = None
max_costs_exceeded: dict[str, float] | None = None
16 changes: 16 additions & 0 deletions tests/test_20_adaptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import cads_adaptors.adaptors.url
import cads_catalogue.database
import pytest

from cads_processing_api_service import adaptors

Expand Down Expand Up @@ -98,5 +99,20 @@ def test_instantiate_adaptor() -> None:
),
)
adaptor = adaptors.instantiate_adaptor(dataset)
assert isinstance(adaptor, cads_adaptors.adaptors.url.UrlCdsAdaptor)

adaptor_properties = {
"entry_point": "cads_adaptors:UrlCdsAdaptor",
"setup_code": None,
"form": form_data,
"config": {
"constraints": constraints_data,
"mapping": mapping,
"licences": licences,
},
}
adaptor = adaptors.instantiate_adaptor(adaptor_properties=adaptor_properties)
assert isinstance(adaptor, cads_adaptors.adaptors.url.UrlCdsAdaptor)

with pytest.raises(ValueError):
adaptors.instantiate_adaptor(adaptor_properties={})
22 changes: 21 additions & 1 deletion tests/test_30_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

# mypy: ignore-errors

import unittest.mock

import cads_broker
import pytest

from cads_processing_api_service import auth, exceptions
from cads_processing_api_service import auth, exceptions, models


def test_check_licences() -> None:
Expand Down Expand Up @@ -58,3 +60,21 @@ def test_verify_if_disabled() -> None:
test_disabled_reason = None
test_user_role = None
auth.verify_if_disabled(test_disabled_reason, test_user_role)


def test_verify_cost() -> None:
with unittest.mock.patch(
"cads_processing_api_service.costing.compute_costing"
) as mock_compute_costing:
mock_compute_costing.return_value = models.Costing(
costs={"cost_1": 1.0, "cost_2": 2.0},
max_costs_exceeded={},
)
auth.verify_cost({}, {})

mock_compute_costing.return_value = models.Costing(
costs={"cost_1": 1.0, "cost_2": 2.0},
max_costs_exceeded={"cost_1": 0},
)
with pytest.raises(exceptions.PermissionDenied):
auth.verify_cost({}, {})

0 comments on commit 7376895

Please sign in to comment.