diff --git a/cads_broker/entry_points.py b/cads_broker/entry_points.py index 26081977..e52aace6 100644 --- a/cads_broker/entry_points.py +++ b/cads_broker/entry_points.py @@ -1,10 +1,11 @@ """Module for entry points.""" import datetime +import enum import os import random import uuid -from enum import Enum +from pathlib import Path from typing import Any, Optional import sqlalchemy as sa @@ -65,13 +66,7 @@ def add_dummy_requests( def requests_cleaner( connection_string: Optional[str] = None, older_than_days: Optional[int] = 365 ) -> None: - """Remove records from the system_requests table older than `older_than_days`. - - Parameters - ---------- - connection_string: something like 'postgresql://user:password@netloc:port/dbname' - older_than_days: minimum age (in days) to consider a record to be removed - """ + """Remove records from the system_requests table older than `older_than_days`.""" if not connection_string: dbsettings = config.ensure_settings(config.dbsettings) connection_string = dbsettings.connection_string @@ -112,7 +107,24 @@ def requests_cleaner( raise -class RequestStatus(str, Enum): +@app.command() +def list_request_uids( + query: str, + output_file: Annotated[Path, typer.Argument(file_okay=True, dir_okay=False)] + | None = Path("request_uids.txt"), +) -> None: + """List request_uids from the system_requests table.""" + with database.ensure_session_obj(None)() as session: + result = session.execute( + sa.text(f"select request_uid from system_requests where {query}") + ) + with output_file.open("w") as f: + for row in result: + f.write(str(row[0]) + "\n") + print(f"successfully wrote {result.rowcount} request_uids to {output_file}") + + +class RequestStatus(str, enum.Enum): """Enum for request status.""" running = "running" @@ -124,6 +136,10 @@ def delete_requests( status: RequestStatus = RequestStatus.running, user_uid: Optional[str] = None, request_uid: Optional[str] = None, + request_uids_file: Annotated[ + Path, typer.Argument(exists=True, file_okay=True, dir_okay=False) + ] + | None = None, connection_string: Optional[str] = None, minutes: float = 0, seconds: float = 0, @@ -132,12 +148,7 @@ def delete_requests( message: Optional[str] = "The request has been dismissed by the administrator.", skip_confirmation: Annotated[bool, typer.Option("--yes", "-y")] = False, ) -> None: - """Set the status of records in the system_requests table to 'dismissed'. - - Parameters - ---------- - connection_string: something like 'postgresql://user:password@netloc:port/dbname' - """ + """Set the status of records in the system_requests table to 'dismissed'.""" if not connection_string: dbsettings = config.ensure_settings(config.dbsettings) connection_string = dbsettings.connection_string @@ -145,11 +156,18 @@ def delete_requests( minutes=minutes, seconds=seconds, hours=hours, days=days ) with database.ensure_session_obj(None)() as session: - statement = ( - sa.update(database.SystemRequest) - .where(database.SystemRequest.status == status) - .where(database.SystemRequest.created_at < timestamp) - ) + if request_uids_file: + with request_uids_file.open() as f: + request_uids = f.read().splitlines() + statement = sa.update(database.SystemRequest).where( + database.SystemRequest.request_uid.in_(request_uids) + ) + else: + statement = ( + sa.update(database.SystemRequest) + .where(database.SystemRequest.status == status) + .where(database.SystemRequest.created_at < timestamp) + ) if user_uid: statement = statement.where(database.SystemRequest.user_uid == user_uid) if request_uid: @@ -169,7 +187,7 @@ def delete_requests( number_of_requests = session.execute(statement).rowcount if not skip_confirmation: if not typer.confirm( - f"Setting status to 'dismissed' for {number_of_requests} {status} requests. " + f"Setting status to 'dismissed' for {number_of_requests} requests. " "Do you want to continue?", abort=True, default=True,