Skip to content

Commit

Permalink
Merge branch 'main' into COPDS-2157-origin-in-job-response
Browse files Browse the repository at this point in the history
  • Loading branch information
mcucchi9 committed Nov 12, 2024
2 parents 98e1848 + 2a0e394 commit 0da1153
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 58 deletions.
62 changes: 50 additions & 12 deletions cads_processing_api_service/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,54 @@ def get_accepted_licences(auth_header: tuple[str, str]) -> set[tuple[str, int]]:
return accepted_licences


def format_missing_licences_message(
request_url: str,
process_id: str,
portal: str | None = None,
missing_licences_message_template: str = SETTINGS.missing_licences_message,
dataset_licences_url_template: str = SETTINGS.dataset_licences_url,
) -> str:
"""Format a message for the user indicating that some licences are missing.
Parameters
----------
request_url : str
Request URL.
process_id : str
Process identifier.
portal : str | None, optional
Dataset portal identifier.
missing_licences_message_template : str, optional
Template for the missing licences message.
dataset_licences_url_template : str, optional
Template for the dataset licences URL.
Returns
-------
str
Formatted message.
"""
parsed_request_url = urllib.parse.urlparse(request_url)
request_netloc = parsed_request_url.netloc
portal_netloc = request_netloc
if portal is not None:
portal_netloc = SETTINGS.portals.get(portal, request_netloc)
base_url = f"{parsed_request_url.scheme}://{portal_netloc}"
dataset_licences_url = dataset_licences_url_template.format(
base_url=base_url, process_id=process_id
)
missing_licences_message = missing_licences_message_template.format(
dataset_licences_url=dataset_licences_url
)
return missing_licences_message


def verify_licences(
accepted_licences: set[tuple[str, int]] | list[tuple[str, int]],
required_licences: set[tuple[str, int]] | list[tuple[str, int]],
api_request_url: str,
request_url: str,
process_id: str,
dataset_portal: str | None = None,
) -> set[tuple[str, int]]:
"""
Verify if all the licences required for the process submission have been accepted.
Expand All @@ -233,10 +276,12 @@ def verify_licences(
Licences accepted by a user stored in the Extended Profiles database.
required_licences : set[tuple[str, int]] | list[tuple[str, int]],
Licences bound to the required process/dataset.
api_request_url : str
API request URL, required to generate the URL to the dataset licences page.
request_url : str
Request URL, required to generate the URL to the dataset licences page.
process_id : str
Process identifier, required to generate the URL to the dataset licences page.
dataset_portal : str | None, optional
Dataset portal identifier.
Returns
-------
Expand All @@ -254,15 +299,8 @@ def verify_licences(
required_licences = set(required_licences)
missing_licences = required_licences - accepted_licences
if not len(missing_licences) == 0:
missing_licences_message_template = SETTINGS.missing_licences_message
dataset_licences_url_template = SETTINGS.dataset_licences_url
parsed_api_request_url = urllib.parse.urlparse(api_request_url)
base_url = f"{parsed_api_request_url.scheme}://{parsed_api_request_url.netloc}"
dataset_licences_url = dataset_licences_url_template.format(
base_url=base_url, process_id=process_id
)
missing_licences_message = missing_licences_message_template.format(
dataset_licences_url=dataset_licences_url
missing_licences_message = format_missing_licences_message(
request_url, process_id, dataset_portal
)
raise exceptions.PermissionDenied(
title="required licences not accepted", detail=missing_licences_message
Expand Down
6 changes: 5 additions & 1 deletion cads_processing_api_service/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,11 @@ def post_process_execution(
accepted_licences = auth.get_accepted_licences(auth_info.auth_header)
request_url = str(request.url)
_ = auth.verify_licences(
accepted_licences, required_licences, request_url, process_id
accepted_licences,
required_licences,
request_url,
process_id,
dataset.portal,
)
job_message = None
else:
Expand Down
26 changes: 23 additions & 3 deletions cads_processing_api_service/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
"to accept the required licence(s)."
)

DATASET_LICENCES_URL = "{base_url}/datasets/{process_id}?tab=download#manage-licences"

RATE_LIMITS_STORAGE = limits.storage.MemoryStorage()
RATE_LIMITS_LIMITER = limits.strategies.FixedWindowRateLimiter(RATE_LIMITS_STORAGE)

Expand Down Expand Up @@ -153,6 +155,18 @@ def load_rate_limits(rate_limits_file: str | None) -> RateLimitsConfig:
return rate_limits


def load_portals(portals_file: str | None) -> dict[str, str]:
portals = {}
if portals_file is not None:
try:
with open(portals_file, "r") as file:
loaded_portals = yaml.safe_load(file)
portals = loaded_portals
except OSError:
logger.exception("Failed to read portals file", portals_file=portals_file)
return portals


class Settings(pydantic_settings.BaseSettings):
"""General API settings."""

Expand Down Expand Up @@ -181,18 +195,24 @@ def profiles_api_url(self) -> str:
anonymous_licences_message: str = ANONYMOUS_LICENCES_MESSAGE
deprecation_warning_message: str = DEPRECATION_WARNING_MESSAGE
missing_licences_message: str = MISSING_LICENCES_MESSAGE
dataset_licences_url: str = (
"{base_url}/datasets/{process_id}?tab=download#manage-licences"
)
dataset_licences_url: str = DATASET_LICENCES_URL

rate_limits_file: str | None = None
rate_limits: RateLimitsConfig = pydantic.Field(default=RateLimitsConfig())

portals_file: str | None = None
portals: dict[str, str] = pydantic.Field(default={})

@pydantic.model_validator(mode="after") # type: ignore
def load_rate_limits(self) -> pydantic_settings.BaseSettings:
self.rate_limits: RateLimitsConfig = load_rate_limits(self.rate_limits_file)
return self

@pydantic.model_validator(mode="after") # type: ignore
def load_portals(self) -> pydantic_settings.BaseSettings:
self.portals: dict[str, str] = load_portals(self.portals_file)
return self


settings = Settings()

Expand Down
1 change: 1 addition & 0 deletions ci/environment-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies:
- sphinx-autoapi
# DO NOT EDIT ABOVE THIS LINE, ADD DEPENDENCIES BELOW
- pip
- pytest-mock
- mypy != 1.11.0
- mypy != 1.11.1
- types-cachetools
Expand Down
76 changes: 59 additions & 17 deletions tests/test_30_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,68 @@

# mypy: ignore-errors

import unittest.mock

import cads_broker
import pytest

from cads_processing_api_service import auth, exceptions, models


def test_format_missing_licences_message(mocker) -> None:
request_url = "http://base_url/api/v1/processes/process_id/execution"
process_id = "test_process_id"
missing_licences_message_template = "{dataset_licences_url}"
res = auth.format_missing_licences_message(
request_url,
process_id,
missing_licences_message_template=missing_licences_message_template,
)
exp = "http://base_url/datasets/test_process_id?tab=download#manage-licences"
assert res == exp

request_url = "https://base_url/api/v1/processes/process_id/execution"
process_id = "test_process_id"
missing_licences_message_template = "{dataset_licences_url}"
portal_id = "missing_test_portal_id"
mocker.patch(
"cads_processing_api_service.auth.SETTINGS.portals",
{"test_portal_id": "test_portal_netloc"},
)
res = auth.format_missing_licences_message(
request_url, process_id, portal_id, missing_licences_message_template
)
exp = "https://base_url/datasets/test_process_id?tab=download#manage-licences"
assert res == exp

request_url = "https://base_url/api/v1/processes/process_id/execution"
process_id = "test_process_id"
missing_licences_message_template = "{dataset_licences_url}"
portal_id = "test_portal_id"
mocker.patch(
"cads_processing_api_service.auth.SETTINGS.portals",
{"test_portal_id": "test_portal_netloc"},
)
res = auth.format_missing_licences_message(
request_url, process_id, portal_id, missing_licences_message_template
)
exp = "https://test_portal_netloc/datasets/test_process_id?tab=download#manage-licences"
assert res == exp


def test_verify_licences() -> None:
accepted_licences = {("licence_1", 1), ("licence_2", 2), ("licence_3", 3)}
required_licences = {("licence_1", 1), ("licence_2", 2)}
api_request_url = "http://base_url/api/v1/processes/process_id/execution"
request_url = "http://base_url/api/v1/processes/process_id/execution"
process_id = "process_id"
missing_licences = auth.verify_licences(
accepted_licences, required_licences, api_request_url, process_id
accepted_licences, required_licences, request_url, process_id
)
assert len(missing_licences) == 0

accepted_licences = {("licence_1", 1), ("licence_2", 1)}
required_licences = {("licence_1", 1), ("licence_2", 2)}
with pytest.raises(exceptions.PermissionDenied):
missing_licences = auth.verify_licences(
accepted_licences, required_licences, api_request_url, process_id
accepted_licences, required_licences, request_url, process_id
)


Expand Down Expand Up @@ -68,20 +107,23 @@ def test_verify_if_disabled() -> None:
auth.verify_if_disabled(test_disabled_reason, test_user_role)


def test_verify_cost() -> None:
with unittest.mock.patch(
"cads_processing_api_service.costing.compute_costing"
) as mock_compute_costing:
mock_compute_costing.return_value = models.CostingInfo(
def test_verify_cost(mocker) -> None:
mocker.patch(
"cads_processing_api_service.costing.compute_costing",
return_value=models.CostingInfo(
costs={"cost_id_1": 10.0, "cost_id_2": 10.0},
limits={"cost_id_1": 20.0, "cost_id_2": 20.0},
)
costs = auth.verify_cost({}, {}, "api")
assert costs == {"cost_id_1": 10.0, "cost_id_2": 10.0}
),
)
costs = auth.verify_cost({}, {}, "api")
assert costs == {"cost_id_1": 10.0, "cost_id_2": 10.0}

mock_compute_costing.return_value = models.CostingInfo(
mocker.patch(
"cads_processing_api_service.costing.compute_costing",
return_value=models.CostingInfo(
costs={"cost_id_1": 10.0, "cost_id_2": 10.0},
limits={"cost_id_1": 5.0, "cost_id_2": 20.0},
)
with pytest.raises(exceptions.PermissionDenied):
auth.verify_cost({}, {}, "api")
),
)
with pytest.raises(exceptions.PermissionDenied):
auth.verify_cost({}, {}, "api")
46 changes: 21 additions & 25 deletions tests/test_30_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

# mypy: ignore-errors
import datetime
import unittest.mock
import uuid
from typing import Any

Expand Down Expand Up @@ -220,26 +219,23 @@ def test_dictify_job() -> None:
assert all([key in res_job and res_job[key] == exp_job[key] for key in exp_job])


def test_get_job_from_broker_db() -> None:
def test_get_job_from_broker_db(mocker) -> None:
test_job_id = "1234"
mock_session = unittest.mock.Mock(spec=sqlalchemy.orm.Session)
with unittest.mock.patch("cads_broker.database.get_request") as mock_get_request:
mock_get_request.return_value = cads_broker.database.SystemRequest(
request_uid=test_job_id
)
job = utils.get_job_from_broker_db(test_job_id, session=mock_session)
mock_session = mocker.Mock(spec=sqlalchemy.orm.Session)
mocker.patch(
"cads_broker.database.get_request",
return_value=cads_broker.database.SystemRequest(request_uid=test_job_id),
)
job = utils.get_job_from_broker_db(test_job_id, session=mock_session)
assert isinstance(job, cads_broker.SystemRequest)
assert job.request_uid == test_job_id

with unittest.mock.patch("cads_broker.database.get_request") as mock_get_request:
mock_get_request.side_effect = cads_broker.database.NoResultFound()
with pytest.raises(ogc_api_processes_fastapi.exceptions.NoSuchJob):
job = utils.get_job_from_broker_db(test_job_id, session=mock_session)

with unittest.mock.patch("cads_broker.database.get_request") as mock_get_request:
mock_get_request.side_effect = cads_broker.database.NoResultFound()
with pytest.raises(ogc_api_processes_fastapi.exceptions.NoSuchJob):
job = utils.get_job_from_broker_db("1234", session=mock_session)
mocker.patch(
"cads_broker.database.get_request",
side_effect=cads_broker.database.NoResultFound(),
)
with pytest.raises(ogc_api_processes_fastapi.exceptions.NoSuchJob):
job = utils.get_job_from_broker_db(test_job_id, session=mock_session)


def test_update_results_href() -> None:
Expand All @@ -256,8 +252,8 @@ def test_update_results_href() -> None:
assert updated_href == exp_updated_href


def test_get_results_from_job(prepare_env_for_download_nodes) -> None:
mock_session = unittest.mock.Mock(spec=sqlalchemy.orm.Session)
def test_get_results_from_job(prepare_env_for_download_nodes, mocker) -> None:
mock_session = mocker.Mock(spec=sqlalchemy.orm.Session)
job = cads_broker.SystemRequest(
**{
"status": "successful",
Expand Down Expand Up @@ -293,14 +289,14 @@ def test_get_results_from_job(prepare_env_for_download_nodes) -> None:
}
)
with pytest.raises(ogc_api_processes_fastapi.exceptions.JobResultsFailed) as exc:
with unittest.mock.patch(
"cads_processing_api_service.utils.get_job_events"
) as mock_get_job_events:
mock_get_job_events.return_value = [
mocker.patch(
"cads_processing_api_service.utils.get_job_events",
return_value=[
"2024-01-01T16:20:12.175021",
"error message",
]
results = utils.get_results_from_job(job, session=mock_session)
],
)
results = utils.get_results_from_job(job, session=mock_session)
assert exc.value.traceback == "error message"

job = cads_broker.SystemRequest(**{"status": "accepted", "request_uid": "1234"})
Expand Down

0 comments on commit 0da1153

Please sign in to comment.