From 22652d12a0e276a513bf673dad99989c0b64d9bd Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Sun, 12 Jun 2022 13:26:41 +0200 Subject: [PATCH] Upgrade FAB to 4.1.1 The Flask Application Builder has been updated recently to support a number of newer dependencies. This PR is the attempt to migrate FAB to newer version. This includes: * update setup.py and setup.cfg upper and lower bounds to account for proper version of dependencies that FAB < 4.0.0 was blocking from upgrade * added typed Flask application retrieval with a custom application fields available for MyPy typing checks. * fix typing to account for typing hints added in multiple upgraded libraries optional values and content of request returned as Mapping * switch to PyJWT 2.* by using non-deprecated "required" claim as list rather than separate fields * add possibility to install providers without constraints so that we could avoid errors on conflicting constraints when upgrade-to-newer-dependencies is used * add pre-commit to check that 2.4+ only get_airflow_app is not used in providers * avoid Bad Request in case the request sent to Flask 2.0 is not JSON content type * switch imports of internal classes to direct packages where classes are available rather than from "airflow.models" to satisfy MyPY * synchronize changes of FAB Security Manager 4.1.1 with our copy of the Security Manager. 
* add error handling for a few "None" cases detected by MyPY * corrected test cases that were broken by immutability of Flask 2 objects and better escaping done by Flask 2 * updated test cases to account for redirection to "path" rather than full URL by Flask2 Fixes: #22397 --- .github/workflows/ci.yml | 4 + Dockerfile.ci | 36 +++-- airflow/api/auth/backend/basic_auth.py | 5 +- .../api_connexion/endpoints/dag_endpoint.py | 9 +- .../endpoints/dag_run_endpoint.py | 26 ++-- .../endpoints/extra_link_endpoint.py | 4 +- .../api_connexion/endpoints/log_endpoint.py | 11 +- .../endpoints/mapping_from_request.py | 24 +++ .../api_connexion/endpoints/pool_endpoint.py | 16 +- .../endpoints/role_and_permission_endpoint.py | 15 +- .../api_connexion/endpoints/task_endpoint.py | 7 +- .../endpoints/task_instance_endpoint.py | 17 ++- .../api_connexion/endpoints/user_endpoint.py | 13 +- .../endpoints/variable_endpoint.py | 7 +- .../api_connexion/endpoints/xcom_endpoint.py | 5 +- airflow/api_connexion/schemas/dag_schema.py | 3 +- .../schemas/task_instance_schema.py | 2 +- airflow/api_connexion/security.py | 7 +- airflow/models/abstractoperator.py | 1 - airflow/operators/trigger_dagrun.py | 5 +- .../common/auth_backend/google_openid.py | 2 +- airflow/sensors/external_task.py | 6 +- airflow/utils/airflow_flask_app.py | 37 +++++ airflow/utils/jwt_signer.py | 4 +- airflow/www/api/experimental/endpoints.py | 3 +- airflow/www/auth.py | 6 +- .../www/extensions/init_wsgi_middlewares.py | 2 +- airflow/www/fab_security/manager.py | 37 +++-- airflow/www/views.py | 115 +++++++------- .../commands/release_management_commands.py | 8 + .../src/airflow_breeze/params/shell_params.py | 1 + .../utils/docker_command_utils.py | 1 + images/breeze/output-commands-hash.txt | 2 +- .../output-verify-provider-packages.svg | 140 +++++++++--------- scripts/ci/docker-compose/_docker.env | 1 + scripts/ci/docker-compose/base.yml | 1 + scripts/ci/docker-compose/devcontainer.env | 1 + 
.../pre_commit_check_2_2_compatibility.py | 40 +++-- scripts/docker/entrypoint_ci.sh | 36 +++-- scripts/in_container/_in_container_utils.sh | 29 ++-- setup.cfg | 52 ++----- setup.py | 10 -- .../endpoints/test_dag_endpoint.py | 3 +- .../endpoints/test_xcom_endpoint.py | 6 +- .../api_connexion/schemas/test_dag_schema.py | 3 +- .../remote_user_api_auth_backend.py | 6 +- tests/utils/test_serve_logs.py | 8 +- tests/www/views/test_views.py | 25 ++-- tests/www/views/test_views_decorators.py | 6 +- tests/www/views/test_views_log.py | 2 +- tests/www/views/test_views_mount.py | 4 +- 51 files changed, 489 insertions(+), 325 deletions(-) create mode 100644 airflow/api_connexion/endpoints/mapping_from_request.py create mode 100644 airflow/utils/airflow_flask_app.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c112a6da727a7..96211320211d0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -801,6 +801,8 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" run: > breeze verify-provider-packages --use-airflow-version wheel --use-packages-from-dist --package-format wheel + env: + SKIP_CONSTRAINTS: "${{ needs.build-info.outputs.upgradeToNewerDependencies }}" - name: "Remove airflow package and replace providers with 2.2-compliant versions" run: | rm -vf dist/apache_airflow-*.whl \ @@ -878,6 +880,8 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" run: > breeze verify-provider-packages --use-airflow-version sdist --use-packages-from-dist --package-format sdist + env: + SKIP_CONSTRAINTS: "${{ needs.build-info.outputs.upgradeToNewerDependencies }}" - name: "Fix ownership" run: breeze fix-ownership if: always() diff --git a/Dockerfile.ci b/Dockerfile.ci index 7bf6257f5343c..e9f3f4b428a09 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -686,9 +686,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo "${COLOR_BLUE}Uninstalling airflow and providers" echo uninstall_airflow_and_providers - echo "${COLOR_BLUE}Install airflow from 
wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "none" + else + echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi uninstall_providers elif [[ ${USE_AIRFLOW_VERSION} == "sdist" ]]; then echo @@ -696,9 +702,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo uninstall_airflow_and_providers echo - echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "none" + else + echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi uninstall_providers else echo @@ -706,9 +718,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo uninstall_airflow_and_providers echo - echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - 
echo - install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_released_airflow_version "${USE_AIRFLOW_VERSION}" "none" + else + echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi fi if [[ ${USE_PACKAGES_FROM_DIST=} == "true" ]]; then echo diff --git a/airflow/api/auth/backend/basic_auth.py b/airflow/api/auth/backend/basic_auth.py index 397a722a98cf2..12f00b435fe11 100644 --- a/airflow/api/auth/backend/basic_auth.py +++ b/airflow/api/auth/backend/basic_auth.py @@ -18,10 +18,11 @@ from functools import wraps from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast -from flask import Response, current_app, request +from flask import Response, request from flask_appbuilder.const import AUTH_LDAP from flask_login import login_user +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.www.fab_security.sqla.models import User CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = None @@ -40,7 +41,7 @@ def auth_current_user() -> Optional[User]: if auth is None or not auth.username or not auth.password: return None - ab_security_manager = current_app.appbuilder.sm + ab_security_manager = get_airflow_app().appbuilder.sm user = None if ab_security_manager.auth_type == AUTH_LDAP: user = ab_security_manager.auth_user_ldap(auth.username, auth.password) diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 0505f864ee333..7940a25c8f9fb 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ 
b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -19,7 +19,7 @@ from typing import Collection, Optional from connexion import NoContent -from flask import current_app, g, request +from flask import g, request from marshmallow import ValidationError from sqlalchemy.orm import Session from sqlalchemy.sql.expression import or_ @@ -38,6 +38,7 @@ from airflow.exceptions import AirflowException, DagNotFound from airflow.models.dag import DagModel, DagTag from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session @@ -56,7 +57,7 @@ def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) def get_dag_details(*, dag_id: str) -> APIResponse: """Get details of DAG.""" - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id} was not found") return dag_detail_schema.dump(dag) @@ -83,7 +84,7 @@ def get_dags( if dag_id_pattern: dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%')) - readable_dags = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) dags_query = dags_query.filter(DagModel.dag_id.in_(readable_dags)) if tags: @@ -143,7 +144,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat if dag_id_pattern == '~': dag_id_pattern = '%' dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%')) - editable_dags = current_app.appbuilder.sm.get_editable_dag_ids(g.user) + editable_dags = get_airflow_app().appbuilder.sm.get_editable_dag_ids(g.user) dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags)) if tags: diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py 
b/airflow/api_connexion/endpoints/dag_run_endpoint.py index e510126534b12..38159b0454285 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -19,7 +19,7 @@ import pendulum from connexion import NoContent -from flask import current_app, g, request +from flask import g from marshmallow import ValidationError from sqlalchemy import or_ from sqlalchemy.orm import Query, Session @@ -30,6 +30,7 @@ set_dag_run_state_to_success, ) from airflow.api_connexion import security +from airflow.api_connexion.endpoints.mapping_from_request import get_mapping_from_request from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound from airflow.api_connexion.parameters import apply_sorting, check_limit, format_datetime, format_parameters from airflow.api_connexion.schemas.dag_run_schema import ( @@ -47,6 +48,7 @@ from airflow.api_connexion.types import APIResponse from airflow.models import DagModel, DagRun from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType @@ -167,7 +169,7 @@ def get_dag_runs( # This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs. 
if dag_id == "~": - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder query = query.filter(DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user))) else: query = query.filter(DagRun.dag_id == dag_id) @@ -199,13 +201,13 @@ def get_dag_runs( @provide_session def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse: """Get list of DAG Runs""" - body = request.get_json() + body = get_mapping_from_request() try: data = dagruns_batch_form_schema.load(body) except ValidationError as err: raise BadRequest(detail=str(err.messages)) - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user) query = session.query(DagRun) if data.get("dag_ids"): @@ -252,7 +254,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: detail=f"DAG with dag_id: '{dag_id}' has import errors", ) try: - post_body = dagrun_schema.load(request.json, session=session) + post_body = dagrun_schema.load(get_mapping_from_request(), session=session) except ValidationError as err: raise BadRequest(detail=str(err)) @@ -268,7 +270,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: ) if not dagrun_instance: try: - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_run = dag.create_dagrun( run_type=DagRunType.MANUAL, run_id=run_id, @@ -277,7 +279,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: state=DagRunState.QUEUED, conf=post_body.get("conf"), external_trigger=True, - dag_hash=current_app.dag_bag.dags_hash.get(dag_id), + dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id), ) return dagrun_schema.dump(dag_run) except ValueError as ve: @@ -310,12 +312,12 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' raise 
NotFound(error_message) try: - post_body = set_dagrun_state_form_schema.load(request.json) + post_body = set_dagrun_state_form_schema.load(get_mapping_from_request()) except ValidationError as err: raise BadRequest(detail=str(err)) state = post_body['state'] - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if state == DagRunState.SUCCESS: set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True) elif state == DagRunState.QUEUED: @@ -339,15 +341,15 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() ) if dag_run is None: - error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' + error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' raise NotFound(error_message) try: - post_body = clear_dagrun_form_schema.load(request.json) + post_body = clear_dagrun_form_schema.load(get_mapping_from_request()) except ValidationError as err: raise BadRequest(detail=str(err)) dry_run = post_body.get('dry_run', False) - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) start_date = dag_run.logical_date end_date = dag_run.logical_date diff --git a/airflow/api_connexion/endpoints/extra_link_endpoint.py b/airflow/api_connexion/endpoints/extra_link_endpoint.py index 3e9535603bda3..94b36928bfd0c 100644 --- a/airflow/api_connexion/endpoints/extra_link_endpoint.py +++ b/airflow/api_connexion/endpoints/extra_link_endpoint.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-from flask import current_app from sqlalchemy.orm.session import Session from airflow import DAG @@ -25,6 +24,7 @@ from airflow.exceptions import TaskNotFound from airflow.models.dagbag import DagBag from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session @@ -46,7 +46,7 @@ def get_extra_links( """Get extra links for task instance""" from airflow.models.taskinstance import TaskInstance - dagbag: DagBag = current_app.dag_bag + dagbag: DagBag = get_airflow_app().dag_bag dag: DAG = dagbag.get_dag(dag_id) if not dag: raise NotFound("DAG not found", detail=f'DAG with ID = "{dag_id}" not found') diff --git a/airflow/api_connexion/endpoints/log_endpoint.py b/airflow/api_connexion/endpoints/log_endpoint.py index f1335fe527451..171cacb076e7c 100644 --- a/airflow/api_connexion/endpoints/log_endpoint.py +++ b/airflow/api_connexion/endpoints/log_endpoint.py @@ -14,10 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- from typing import Any, Optional -from flask import Response, current_app, request +from flask import Response, request from itsdangerous.exc import BadSignature from itsdangerous.url_safe import URLSafeSerializer from sqlalchemy.orm.session import Session @@ -29,6 +28,7 @@ from airflow.exceptions import TaskNotFound from airflow.models import TaskInstance from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.log.log_reader import TaskLogReader from airflow.utils.session import NEW_SESSION, provide_session @@ -52,7 +52,7 @@ def get_log( session: Session = NEW_SESSION, ) -> APIResponse: """Get logs for specific task instance""" - key = current_app.config["SECRET_KEY"] + key = get_airflow_app().config["SECRET_KEY"] if not token: metadata = {} else: @@ -87,7 +87,7 @@ def get_log( metadata['end_of_log'] = True raise NotFound(title="TaskInstance not found") - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if dag: try: ti.task = dag.get_task(ti.task_id) @@ -101,7 +101,8 @@ def get_log( if return_type == 'application/json' or return_type is None: # default logs, metadata = task_log_reader.read_log_chunks(ti, task_try_number, metadata) logs = logs[0] if task_try_number is not None else logs - token = URLSafeSerializer(key).dumps(metadata) + # we must have token here, so we can safely ignore it + token = URLSafeSerializer(key).dumps(metadata) # type: ignore[assignment] return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs)) # text/plain. 
Stream logs = task_log_reader.read_log_stream(ti, task_try_number, metadata) diff --git a/airflow/api_connexion/endpoints/mapping_from_request.py b/airflow/api_connexion/endpoints/mapping_from_request.py new file mode 100644 index 0000000000000..72ba6729c5ef4 --- /dev/null +++ b/airflow/api_connexion/endpoints/mapping_from_request.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, Mapping, cast + + +def get_mapping_from_request() -> Mapping[str, Any]: + from flask import request + + return cast(Mapping[str, Any], request.get_json()) diff --git a/airflow/api_connexion/endpoints/pool_endpoint.py b/airflow/api_connexion/endpoints/pool_endpoint.py index 1d24fea63d756..992b37df73f5d 100644 --- a/airflow/api_connexion/endpoints/pool_endpoint.py +++ b/airflow/api_connexion/endpoints/pool_endpoint.py @@ -18,13 +18,14 @@ from http import HTTPStatus from typing import Optional -from flask import Response, request +from flask import Response from marshmallow import ValidationError from sqlalchemy import func from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from airflow.api_connexion import security +from airflow.api_connexion.endpoints.mapping_from_request import get_mapping_from_request from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.pool_schema import PoolCollection, pool_collection_schema, pool_schema @@ -87,7 +88,10 @@ def patch_pool( """Update a pool""" # Only slots can be modified in 'default_pool' try: - if pool_name == Pool.DEFAULT_POOL_NAME and request.json["name"] != Pool.DEFAULT_POOL_NAME: + if ( + pool_name == Pool.DEFAULT_POOL_NAME + and get_mapping_from_request()["name"] != Pool.DEFAULT_POOL_NAME + ): if update_mask and len(update_mask) == 1 and update_mask[0].strip() == "slots": pass else: @@ -100,7 +104,7 @@ def patch_pool( raise NotFound(detail=f"Pool with name:'{pool_name}' not found") try: - patch_body = pool_schema.load(request.json) + patch_body = pool_schema.load(get_mapping_from_request()) except ValidationError as err: raise BadRequest(detail=str(err.messages)) @@ -121,7 +125,7 @@ def patch_pool( else: required_fields = {"name", "slots"} - fields_diff = required_fields - set(request.json.keys()) + fields_diff = 
required_fields - set(get_mapping_from_request().keys()) if fields_diff: raise BadRequest(detail=f"Missing required property(ies): {sorted(fields_diff)}") @@ -136,12 +140,12 @@ def patch_pool( def post_pool(*, session: Session = NEW_SESSION) -> APIResponse: """Create a pool""" required_fields = {"name", "slots"} # Pool would require both fields in the post request - fields_diff = required_fields - set(request.json.keys()) + fields_diff = required_fields - set(get_mapping_from_request().keys()) if fields_diff: raise BadRequest(detail=f"Missing required property(ies): {sorted(fields_diff)}") try: - post_body = pool_schema.load(request.json, session=session) + post_body = pool_schema.load(get_mapping_from_request(), session=session) except ValidationError as err: raise BadRequest(detail=str(err.messages)) diff --git a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py index 88a68341c129e..1b25769af7737 100644 --- a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py +++ b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py @@ -19,7 +19,7 @@ from typing import List, Optional, Tuple from connexion import NoContent -from flask import current_app, request +from flask import request from marshmallow import ValidationError from sqlalchemy import asc, desc, func @@ -35,6 +35,7 @@ ) from airflow.api_connexion.types import APIResponse, UpdateMask from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.www.fab_security.sqla.models import Action, Role from airflow.www.security import AirflowSecurityManager @@ -55,7 +56,7 @@ def _check_action_and_resource(sm: AirflowSecurityManager, perms: List[Tuple[str @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE)]) def get_role(*, role_name: str) -> APIResponse: """Get role""" - ab_security_manager = current_app.appbuilder.sm + ab_security_manager 
= get_airflow_app().appbuilder.sm role = ab_security_manager.find_role(name=role_name) if not role: raise NotFound(title="Role not found", detail=f"Role with name {role_name!r} was not found") @@ -66,7 +67,7 @@ def get_role(*, role_name: str) -> APIResponse: @format_parameters({"limit": check_limit}) def get_roles(*, order_by: str = "name", limit: int, offset: Optional[int] = None) -> APIResponse: """Get roles""" - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder session = appbuilder.get_session total_entries = session.query(func.count(Role.id)).scalar() direction = desc if order_by.startswith("-") else asc @@ -90,7 +91,7 @@ def get_roles(*, order_by: str = "name", limit: int, offset: Optional[int] = Non @format_parameters({'limit': check_limit}) def get_permissions(*, limit: int, offset: Optional[int] = None) -> APIResponse: """Get permissions""" - session = current_app.appbuilder.get_session + session = get_airflow_app().appbuilder.get_session total_entries = session.query(func.count(Action.id)).scalar() query = session.query(Action) actions = query.offset(offset).limit(limit).all() @@ -100,7 +101,7 @@ def get_permissions(*, limit: int, offset: Optional[int] = None) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ROLE)]) def delete_role(*, role_name: str) -> APIResponse: """Delete a role""" - ab_security_manager = current_app.appbuilder.sm + ab_security_manager = get_airflow_app().appbuilder.sm role = ab_security_manager.find_role(name=role_name) if not role: raise NotFound(title="Role not found", detail=f"Role with name {role_name!r} was not found") @@ -111,7 +112,7 @@ def delete_role(*, role_name: str) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_ROLE)]) def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse: """Update a role""" - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder 
security_manager = appbuilder.sm body = request.json try: @@ -145,7 +146,7 @@ def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse @security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_ROLE)]) def post_role() -> APIResponse: """Create a new role""" - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder security_manager = appbuilder.sm body = request.json try: diff --git a/airflow/api_connexion/endpoints/task_endpoint.py b/airflow/api_connexion/endpoints/task_endpoint.py index 28c39b000c28d..74b6e7e9ee8ed 100644 --- a/airflow/api_connexion/endpoints/task_endpoint.py +++ b/airflow/api_connexion/endpoints/task_endpoint.py @@ -16,8 +16,6 @@ # under the License. from operator import attrgetter -from flask import current_app - from airflow import DAG from airflow.api_connexion import security from airflow.api_connexion.exceptions import BadRequest, NotFound @@ -25,6 +23,7 @@ from airflow.api_connexion.types import APIResponse from airflow.exceptions import TaskNotFound from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app @security.requires_access( @@ -35,7 +34,7 @@ ) def get_task(*, dag_id: str, task_id: str) -> APIResponse: """Get simplified representation of a task.""" - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: raise NotFound("DAG not found") @@ -54,7 +53,7 @@ def get_task(*, dag_id: str, task_id: str) -> APIResponse: ) def get_tasks(*, dag_id: str, order_by: str = "task_id") -> APIResponse: """Get tasks for DAG""" - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: raise NotFound("DAG not found") tasks = dag.tasks diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index c2416ab0d9d44..2c6c08fb9cfce 100644 --- 
a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -16,7 +16,6 @@ # under the License. from typing import Any, Iterable, List, Optional, Tuple, TypeVar -from flask import current_app, request from marshmallow import ValidationError from sqlalchemy import and_, func, or_ from sqlalchemy.exc import MultipleResultsFound @@ -25,6 +24,7 @@ from sqlalchemy.sql import ClauseElement from airflow.api_connexion import security +from airflow.api_connexion.endpoints.mapping_from_request import get_mapping_from_request from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import format_datetime, format_parameters from airflow.api_connexion.schemas.task_instance_schema import ( @@ -42,6 +42,7 @@ from airflow.models.dagrun import DagRun as DR from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState, State @@ -188,7 +189,7 @@ def get_mapped_task_instances( # 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404 if base_query.with_entities(func.count('*')).scalar() == 0: - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: error_message = f"DAG {dag_id} not found" raise NotFound(error_message) @@ -364,7 +365,7 @@ def get_task_instances( @provide_session def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: """Get list of task instances.""" - body = request.get_json() + body = get_mapping_from_request() try: data = task_instance_batch_form.load(body) except ValidationError as err: @@ -423,20 +424,20 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: @provide_session def 
post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Clear task instances.""" - body = request.get_json() + body = get_mapping_from_request() try: data = clear_task_instance_form.load(body) except ValidationError as err: raise BadRequest(detail=str(err.messages)) - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: error_message = f"Dag id {dag_id} not found" raise NotFound(error_message) reset_dag_runs = data.pop('reset_dag_runs') dry_run = data.pop('dry_run') # We always pass dry_run here, otherwise this would try to confirm on the terminal! - task_instances = dag.clear(dry_run=True, dag_bag=current_app.dag_bag, **data) + task_instances = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, **data) if not dry_run: clear_task_instances( task_instances.all(), @@ -460,14 +461,14 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> @provide_session def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Set a state of task instances.""" - body = request.get_json() + body = get_mapping_from_request() try: data = set_task_instance_state_form.load(body) except ValidationError as err: raise BadRequest(detail=str(err.messages)) error_message = f"Dag ID {dag_id} not found" - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: raise NotFound(error_message) diff --git a/airflow/api_connexion/endpoints/user_endpoint.py b/airflow/api_connexion/endpoints/user_endpoint.py index 6b4e984a69559..3ab476e219cb9 100644 --- a/airflow/api_connexion/endpoints/user_endpoint.py +++ b/airflow/api_connexion/endpoints/user_endpoint.py @@ -18,7 +18,7 @@ from typing import List, Optional from connexion import NoContent -from flask import current_app, request +from flask import request from marshmallow import ValidationError from sqlalchemy import asc, desc, func from 
werkzeug.security import generate_password_hash @@ -34,13 +34,14 @@ ) from airflow.api_connexion.types import APIResponse, UpdateMask from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.www.fab_security.sqla.models import Role, User @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_USER)]) def get_user(*, username: str) -> APIResponse: """Get a user""" - ab_security_manager = current_app.appbuilder.sm + ab_security_manager = get_airflow_app().appbuilder.sm user = ab_security_manager.find_user(username=username) if not user: raise NotFound(title="User not found", detail=f"The User with username `{username}` was not found") @@ -51,7 +52,7 @@ def get_user(*, username: str) -> APIResponse: @format_parameters({"limit": check_limit}) def get_users(*, limit: int, order_by: str = "id", offset: Optional[str] = None) -> APIResponse: """Get users""" - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder session = appbuilder.get_session total_entries = session.query(func.count(User.id)).scalar() direction = desc if order_by.startswith("-") else asc @@ -87,7 +88,7 @@ def post_user() -> APIResponse: except ValidationError as e: raise BadRequest(detail=str(e.messages)) - security_manager = current_app.appbuilder.sm + security_manager = get_airflow_app().appbuilder.sm username = data["username"] email = data["email"] @@ -130,7 +131,7 @@ def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse: except ValidationError as e: raise BadRequest(detail=str(e.messages)) - security_manager = current_app.appbuilder.sm + security_manager = get_airflow_app().appbuilder.sm user = security_manager.find_user(username=username) if user is None: @@ -194,7 +195,7 @@ def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_USER)]) def delete_user(*, username: 
str) -> APIResponse: """Delete a user""" - security_manager = current_app.appbuilder.sm + security_manager = get_airflow_app().appbuilder.sm user = security_manager.find_user(username=username) if user is None: diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py index 487d2cc486c83..a751c02ec2bbb 100644 --- a/airflow/api_connexion/endpoints/variable_endpoint.py +++ b/airflow/api_connexion/endpoints/variable_endpoint.py @@ -17,12 +17,13 @@ from http import HTTPStatus from typing import Optional -from flask import Response, request +from flask import Response from marshmallow import ValidationError from sqlalchemy import func from sqlalchemy.orm import Session from airflow.api_connexion import security +from airflow.api_connexion.endpoints.mapping_from_request import get_mapping_from_request from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.variable_schema import variable_collection_schema, variable_schema @@ -79,7 +80,7 @@ def get_variables( def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Response: """Update a variable by key""" try: - data = variable_schema.load(request.json) + data = variable_schema.load(get_mapping_from_request()) except ValidationError as err: raise BadRequest("Invalid Variable schema", detail=str(err.messages)) @@ -100,7 +101,7 @@ def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Resp def post_variables() -> Response: """Create a variable""" try: - data = variable_schema.load(request.json) + data = variable_schema.load(get_mapping_from_request()) except ValidationError as err: raise BadRequest("Invalid Variable schema", detail=str(err.messages)) diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py index 
9cc6b6d79a933..62c7262f7ed2c 100644 --- a/airflow/api_connexion/endpoints/xcom_endpoint.py +++ b/airflow/api_connexion/endpoints/xcom_endpoint.py @@ -16,7 +16,7 @@ # under the License. from typing import Optional -from flask import current_app, g +from flask import g from sqlalchemy import and_ from sqlalchemy.orm import Session @@ -27,6 +27,7 @@ from airflow.api_connexion.types import APIResponse from airflow.models import DagRun as DR, XCom from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session @@ -52,7 +53,7 @@ def get_xcom_entries( """Get all XCom values""" query = session.query(XCom) if dag_id == '~': - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user) query = query.filter(XCom.dag_id.in_(readable_dag_ids)) query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id)) diff --git a/airflow/api_connexion/schemas/dag_schema.py b/airflow/api_connexion/schemas/dag_schema.py index 2f369113290d9..a75fda8dd0441 100644 --- a/airflow/api_connexion/schemas/dag_schema.py +++ b/airflow/api_connexion/schemas/dag_schema.py @@ -83,7 +83,8 @@ def get_owners(obj: DagModel): @staticmethod def get_token(obj: DagModel): """Return file token""" - serializer = URLSafeSerializer(conf.get('webserver', 'secret_key')) + # the secret key is always available, so we can ignore Optional here + serializer = URLSafeSerializer(conf.get('webserver', 'secret_key')) # type: ignore[arg-type] return serializer.dumps(obj.fileloc) diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py index 37005256f6cdc..74824dbaf87c6 100644 --- a/airflow/api_connexion/schemas/task_instance_schema.py +++ b/airflow/api_connexion/schemas/task_instance_schema.py @@ -60,7 +60,7 @@ class Meta: pid = auto_field() executor_config = 
auto_field() sla_miss = fields.Nested(SlaMissSchema, dump_default=None) - rendered_fields = JsonObjectField(default={}) + rendered_fields = JsonObjectField(dump_default={}) def get_attribute(self, obj, attr, default): if attr == "sla_miss": diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index 3562c98eb4b35..6c84181f91bd3 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -18,16 +18,17 @@ from functools import wraps from typing import Callable, Optional, Sequence, Tuple, TypeVar, cast -from flask import Response, current_app +from flask import Response from airflow.api_connexion.exceptions import PermissionDenied, Unauthenticated +from airflow.utils.airflow_flask_app import get_airflow_app T = TypeVar("T", bound=Callable) def check_authentication() -> None: """Checks that the request has valid authorization information.""" - for auth in current_app.api_auth: + for auth in get_airflow_app().api_auth: response = auth.requires_authentication(Response)() if response.status_code == 200: return @@ -38,7 +39,7 @@ def check_authentication() -> None: def requires_access(permissions: Optional[Sequence[Tuple[str, str]]] = None) -> Callable[[T], T]: """Factory for decorator that checks current user's permissions against required permissions.""" - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder appbuilder.sm.sync_resource_permissions(permissions) def requires_access_decorator(func: T): diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index cb566ea6f9557..0645066b0a3aa 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -307,7 +307,6 @@ def get_extra_links(self, ti: "TaskInstance", link_name: str) -> Optional[str]: return link.get_link(self, ti_key=ti.key) else: return link.get_link(self, ti.dag_run.logical_date) # type: ignore[misc] - return None def render_template_fields( self, diff --git 
a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index 0689f14c56261..4578fd2df818b 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -23,7 +23,10 @@ from airflow.api.common.trigger_dag import trigger_dag from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists -from airflow.models import BaseOperator, BaseOperatorLink, DagBag, DagModel, DagRun +from airflow.models.baseoperator import BaseOperator, BaseOperatorLink +from airflow.models.dag import DagModel +from airflow.models.dagbag import DagBag +from airflow.models.dagrun import DagRun from airflow.models.xcom import XCom from airflow.utils import timezone from airflow.utils.context import Context diff --git a/airflow/providers/google/common/auth_backend/google_openid.py b/airflow/providers/google/common/auth_backend/google_openid.py index 496ac29616686..a267c0e63a1ca 100644 --- a/airflow/providers/google/common/auth_backend/google_openid.py +++ b/airflow/providers/google/common/auth_backend/google_openid.py @@ -88,7 +88,7 @@ def _verify_id_token(id_token: str) -> Optional[str]: def _lookup_user(user_email: str): - security_manager = current_app.appbuilder.sm + security_manager = current_app.appbuilder.sm # type: ignore[attr-defined] user = security_manager.find_user(email=user_email) if not user: diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index 9bb074d47c7fc..327fed6db0929 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -25,7 +25,11 @@ from sqlalchemy import func from airflow.exceptions import AirflowException -from airflow.models import BaseOperatorLink, DagBag, DagModel, DagRun, TaskInstance +from airflow.models.baseoperator import BaseOperatorLink +from airflow.models.dag import DagModel +from airflow.models.dagbag import DagBag +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance from 
airflow.operators.empty import EmptyOperator from airflow.sensors.base import BaseSensorOperator from airflow.utils.helpers import build_airflow_url_with_query diff --git a/airflow/utils/airflow_flask_app.py b/airflow/utils/airflow_flask_app.py new file mode 100644 index 0000000000000..a14ff99398d21 --- /dev/null +++ b/airflow/utils/airflow_flask_app.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, List, cast + +from flask import Flask + +from airflow.models.dagbag import DagBag +from airflow.www.extensions.init_appbuilder import AirflowAppBuilder + + +class AirflowApp(Flask): + """Airflow Flask Application""" + + appbuilder: AirflowAppBuilder + dag_bag: DagBag + api_auth: List[Any] + + +def get_airflow_app() -> AirflowApp: + from flask import current_app + + return cast(AirflowApp, current_app) diff --git a/airflow/utils/jwt_signer.py b/airflow/utils/jwt_signer.py index 941a3d05981ce..e767997ebeb78 100644 --- a/airflow/utils/jwt_signer.py +++ b/airflow/utils/jwt_signer.py @@ -73,9 +73,7 @@ def verify_token(self, token: str) -> Dict[str, Any]: algorithms=[self._algorithm], options={ "verify_signature": True, - "require_exp": True, - "require_iat": True, - "require_nbf": True, + "require": ["exp", "iat", "nbf"], }, audience=self._audience, ) diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py index 898988db81c50..d317bc1f2eef1 100644 --- a/airflow/www/api/experimental/endpoints.py +++ b/airflow/www/api/experimental/endpoints.py @@ -70,7 +70,8 @@ def add_deprecation_headers(response: Response): return response -api_experimental.after_request(add_deprecation_headers) +# This is really experimental. 
We do not care too much about typing here +api_experimental.after_request(add_deprecation_headers) # type: ignore[arg-type] @api_experimental.route('/dags/<string:dag_id>/dag_runs', methods=['POST']) diff --git a/airflow/www/auth.py b/airflow/www/auth.py index e95b4f323fda2..1dec47b43832c 100644 --- a/airflow/www/auth.py +++ b/airflow/www/auth.py @@ -37,7 +37,11 @@ def decorated(*args, **kwargs): appbuilder = current_app.appbuilder dag_id = ( - request.args.get("dag_id") or request.form.get("dag_id") or (request.json or {}).get("dag_id") + request.args.get("dag_id") + or request.form.get("dag_id") + or request.is_json + and request.json.get("dag_id") + or None ) if appbuilder.sm.check_authorization(permissions, dag_id): return func(*args, **kwargs) diff --git a/airflow/www/extensions/init_wsgi_middlewares.py b/airflow/www/extensions/init_wsgi_middlewares.py index 0ed78073e92f5..00c04006ff68e 100644 --- a/airflow/www/extensions/init_wsgi_middlewares.py +++ b/airflow/www/extensions/init_wsgi_middlewares.py @@ -37,7 +37,7 @@ def init_wsgi_middleware(flask_app: Flask): base_url = "" if base_url: flask_app.wsgi_app = DispatcherMiddleware( # type: ignore - _root_app, mounts={base_url: flask_app.wsgi_app} + _root_app, mounts={base_url: flask_app.wsgi_app} # type: ignore ) # Apply ProxyFix middleware diff --git a/airflow/www/fab_security/manager.py b/airflow/www/fab_security/manager.py index 50e36cfa99cc1..8399f11df367e 100644 --- a/airflow/www/fab_security/manager.py +++ b/airflow/www/fab_security/manager.py @@ -291,7 +291,7 @@ def create_jwt_manager(self, app) -> JWTManager: """ jwt_manager = JWTManager() jwt_manager.init_app(app) - jwt_manager.user_loader_callback_loader(self.load_user_jwt) + jwt_manager.user_lookup_loader(self.load_user_jwt) return jwt_manager def create_builtin_roles(self): @@ -654,6 +654,18 @@ def get_oauth_user_info(self, provider, resp): "email": data.get("email", ""), "role_keys": data.get("groups", []), } + # for Keycloak + if provider in ["keycloak", 
"keycloak_before_17"]: + me = self.appbuilder.sm.oauth_remotes[provider].get("openid-connect/userinfo") + me.raise_for_status() + data = me.json() + log.debug("User info from Keycloak: %s", data) + return { + "username": data.get("preferred_username", ""), + "first_name": data.get("given_name", ""), + "last_name": data.get("family_name", ""), + "email": data.get("email", ""), + } else: return {} @@ -1028,12 +1040,6 @@ def auth_user_ldap(self, username, password): try: # LDAP certificate settings - if self.auth_ldap_allow_self_signed: - ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_ALLOW) - ldap.set_option(ldap.OPT_X_TLS_NEWCTX, 0) - elif self.auth_ldap_tls_demand: - ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) - ldap.set_option(ldap.OPT_X_TLS_NEWCTX, 0) if self.auth_ldap_tls_cacertdir: ldap.set_option(ldap.OPT_X_TLS_CACERTDIR, self.auth_ldap_tls_cacertdir) if self.auth_ldap_tls_cacertfile: @@ -1042,6 +1048,12 @@ def auth_user_ldap(self, username, password): ldap.set_option(ldap.OPT_X_TLS_CERTFILE, self.auth_ldap_tls_certfile) if self.auth_ldap_tls_keyfile: ldap.set_option(ldap.OPT_X_TLS_KEYFILE, self.auth_ldap_tls_keyfile) + if self.auth_ldap_allow_self_signed: + ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_ALLOW) + ldap.set_option(ldap.OPT_X_TLS_NEWCTX, 0) + elif self.auth_ldap_tls_demand: + ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND) + ldap.set_option(ldap.OPT_X_TLS_NEWCTX, 0) # Initialise LDAP connection con = ldap.initialize(self.auth_ldap_server) @@ -1355,7 +1367,10 @@ def get_user_menu_access(self, menu_names: Optional[List[str]] = None) -> Set[st return self._get_user_permission_resources(g.user, "menu_access", resource_names=menu_names) elif current_user_jwt: return self._get_user_permission_resources( - current_user_jwt, "menu_access", resource_names=menu_names + # the current_user_jwt is a lazy proxy, so we need to ignore type checking + current_user_jwt, # type: ignore[arg-type] 
+ "menu_access", + resource_names=menu_names, ) else: return self._get_user_permission_resources(None, "menu_access", resource_names=menu_names) @@ -1661,9 +1676,9 @@ def load_user(self, user_id): """Load user by ID""" return self.get_user_by_id(int(user_id)) - def load_user_jwt(self, user_id): - """Load user JWT""" - user = self.load_user(user_id) + def load_user_jwt(self, _jwt_header, jwt_data): + identity = jwt_data["sub"] + user = self.load_user(identity) # Set flask g.user to JWT user, we can't do it on before request g.user = user return user diff --git a/airflow/www/views.py b/airflow/www/views.py index c780b08fdd077..aca9152a35e21 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -44,7 +44,6 @@ Response, abort, before_render_template, - current_app, flash, g, jsonify, @@ -120,6 +119,7 @@ from airflow.timetables.base import DataInterval, TimeRestriction from airflow.timetables.interval import CronDataIntervalTimetable from airflow.utils import json as utils_json, timezone, yaml +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.dates import infer_time_unit, scale_time_units from airflow.utils.docs import get_doc_url_for_provider, get_docs_url from airflow.utils.helpers import alchemy_to_dict @@ -723,13 +723,13 @@ def add_user_permissions_to_dag(sender, template, context, **extra): """ if 'dag' in context: dag = context['dag'] - can_create_dag_run = current_app.appbuilder.sm.has_access( + can_create_dag_run = get_airflow_app().appbuilder.sm.has_access( permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN ) - dag.can_edit = current_app.appbuilder.sm.can_edit_dag(dag.dag_id) + dag.can_edit = get_airflow_app().appbuilder.sm.can_edit_dag(dag.dag_id) dag.can_trigger = dag.can_edit and can_create_dag_run - dag.can_delete = current_app.appbuilder.sm.can_delete_dag(dag.dag_id) + dag.can_delete = get_airflow_app().appbuilder.sm.can_delete_dag(dag.dag_id) context['dag'] = dag @@ -816,7 +816,7 @@ def index(self): end = 
start + dags_per_page # Get all the dag id the user could access - filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) with create_session() as session: # read orm_dags from the db @@ -925,7 +925,7 @@ def index(self): ) dashboard_alerts = [ - fm for fm in settings.DASHBOARD_UIALERTS if fm.should_show(current_app.appbuilder.sm) + fm for fm in settings.DASHBOARD_UIALERTS if fm.should_show(get_airflow_app().appbuilder.sm) ] def _iter_parsed_moved_data_table_names(): @@ -1006,7 +1006,7 @@ def dag_stats(self, session=None): """Dag statistics.""" dr = models.DagRun - allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) dag_state_stats = session.query(dr.dag_id, dr.state, sqla.func.count(dr.state)).group_by( dr.dag_id, dr.state @@ -1051,7 +1051,7 @@ def dag_stats(self, session=None): @provide_session def task_stats(self, session=None): """Task Statistics""" - allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) if not allowed_dag_ids: return wwwutils.json_response({}) @@ -1160,7 +1160,7 @@ def task_stats(self, session=None): @provide_session def last_dagruns(self, session=None): """Last DAG runs""" - allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id} @@ -1284,7 +1284,7 @@ def legacy_dag_details(self): @provide_session def dag_details(self, dag_id, session=None): """Get Dag details.""" - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id, session=session) title 
= "DAG Details" @@ -1360,7 +1360,7 @@ def rendered_templates(self, session): root = request.args.get('root', '') logging.info("Retrieving rendered templates.") - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) dag_run = dag.get_dagrun(execution_date=dttm, session=session) raw_task = dag.get_task(task_id).prepare_for_execution() @@ -1463,7 +1463,10 @@ def rendered_k8s(self, session: Session = NEW_SESSION): map_index = request.args.get('map_index', -1, type=int) logging.info("Retrieving rendered templates.") - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) + if task_id is None: + logging.warning("Task id not passed in the request") + abort(400) task = dag.get_task(task_id) dag_run = dag.get_dagrun(execution_date=dttm, session=session) ti = dag_run.get_task_instance(task_id=task.task_id, map_index=map_index, session=session) @@ -1568,7 +1571,7 @@ def get_logs_with_metadata(self, session=None): ) try: - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if dag: ti.task = dag.get_task(ti.task_id) @@ -1699,7 +1702,7 @@ def task(self, session): map_index = request.args.get('map_index', -1, type=int) form = DateTimeForm(data={'execution_date': dttm}) root = request.args.get('root', '') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag or task_id not in dag.task_ids: flash(f"Task [{dag_id}.{task_id}] doesn't seem to exist at the moment", "error") @@ -1878,7 +1881,7 @@ def run(self, session=None): dag_run_id = request.form.get('dag_run_id') map_index = request.args.get('map_index', -1, type=int) origin = get_safe_url(request.form.get('origin')) - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) task = dag.get_task(task_id) ignore_all_deps = request.form.get('ignore_all_deps') == "true" @@ -1979,7 +1982,7 @@ def 
trigger(self, session=None): request_conf = request.values.get('conf') request_execution_date = request.values.get('execution_date', default=timezone.utcnow().isoformat()) is_dag_run_conf_overrides_params = conf.getboolean('core', 'dag_run_conf_overrides_params') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_orm = session.query(models.DagModel).filter(models.DagModel.dag_id == dag_id).first() if not dag_orm: flash(f"Cannot find dag {dag_id}") @@ -2080,7 +2083,7 @@ def trigger(self, session=None): state=State.QUEUED, conf=run_conf, external_trigger=True, - dag_hash=current_app.dag_bag.dags_hash.get(dag_id), + dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id), run_id=run_id, ) except (ValueError, ParamValidationError) as ve: @@ -2162,7 +2165,7 @@ def clear(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') origin = get_safe_url(request.form.get('origin')) - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if 'map_index' not in request.form: map_indexes: Optional[List[int]] = None @@ -2223,7 +2226,7 @@ def dagrun_clear(self): dag_run_id = request.form.get('dag_run_id') confirmed = request.form.get('confirmed') == "true" - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dr = dag.get_dagrun(run_id=dag_run_id) start_date = dr.logical_date end_date = dr.logical_date @@ -2247,7 +2250,7 @@ def dagrun_clear(self): @provide_session def blocked(self, session=None): """Mark Dag Blocked.""" - allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id} @@ -2270,7 +2273,7 @@ def blocked(self, session=None): payload = [] for dag_id, active_dag_runs in dags: max_active_runs = 0 - dag = 
current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if dag: # TODO: Make max_active_runs a column so we can query for it directly max_active_runs = dag.max_active_runs @@ -2287,7 +2290,7 @@ def _mark_dagrun_state_as_failed(self, dag_id, dag_run_id, confirmed): if not dag_run_id: return {'status': 'error', 'message': 'Invalid dag_run_id'} - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: return {'status': 'error', 'message': f'Cannot find DAG: {dag_id}'} @@ -2305,7 +2308,7 @@ def _mark_dagrun_state_as_success(self, dag_id, dag_run_id, confirmed): if not dag_run_id: return {'status': 'error', 'message': 'Invalid dag_run_id'} - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: return {'status': 'error', 'message': f'Cannot find DAG: {dag_id}'} @@ -2323,7 +2326,7 @@ def _mark_dagrun_state_as_queued(self, dag_id: str, dag_run_id: str, confirmed: if not dag_run_id: return {'status': 'error', 'message': 'Invalid dag_run_id'} - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: return {'status': 'error', 'message': f'Cannot find DAG: {dag_id}'} @@ -2397,7 +2400,7 @@ def dagrun_details(self, session=None): dag_id = request.args.get("dag_id") run_id = request.args.get("run_id") - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_run: Optional[DagRun] = ( session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one_or_none() ) @@ -2448,7 +2451,7 @@ def _mark_task_instance_state( past: bool, state: TaskInstanceState, ): - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) if not run_id: flash(f"Cannot mark tasks as {state}, seem that DAG {dag_id} has never run", "error") @@ -2496,7 +2499,7 @@ def confirm(self): past = to_boolean(args.get('past')) origin = origin or 
url_for('Airflow.index') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: msg = f'DAG {dag_id} not found' return redirect_or_json(origin, msg, status='error', status_code=404) @@ -2685,7 +2688,7 @@ def tree(self): @provide_session def grid(self, dag_id, session=None): """Get Dag's grid view.""" - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id) if not dag: flash(f'DAG "{dag_id}" seems to be missing from DagBag.', "error") @@ -2794,7 +2797,7 @@ def _convert_to_date(session, column): else: return func.date(column) - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id) if not dag: flash(f'DAG "{dag_id}" seems to be missing from DagBag.', "error") @@ -2918,7 +2921,7 @@ def legacy_graph(self): @provide_session def graph(self, dag_id, session=None): """Get DAG as Graph.""" - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id) if not dag: flash(f'DAG "{dag_id}" seems to be missing.', "error") @@ -3030,7 +3033,7 @@ def duration(self, dag_id, session=None): default_dag_run = conf.getint('webserver', 'default_dag_run_display_number') dag_model = DagModel.get_dagmodel(dag_id) - dag: Optional[DAG] = current_app.dag_bag.get_dag(dag_id) + dag: Optional[DAG] = get_airflow_app().dag_bag.get_dag(dag_id) if dag is None: flash(f'DAG "{dag_id}" seems to be missing.', "error") return redirect(url_for('Airflow.index')) @@ -3182,7 +3185,7 @@ def legacy_tries(self): def tries(self, dag_id, session=None): """Shows all tries.""" default_dag_run = conf.getint('webserver', 'default_dag_run_display_number') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id) base_date = request.args.get('base_date') num_runs = 
request.args.get('num_runs', default=default_dag_run, type=int) @@ -3272,7 +3275,7 @@ def legacy_landing_times(self): def landing_times(self, dag_id, session=None): """Shows landing times.""" default_dag_run = conf.getint('webserver', 'default_dag_run_display_number') - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id) base_date = request.args.get('base_date') num_runs = request.args.get('num_runs', default=default_dag_run, type=int) @@ -3389,7 +3392,7 @@ def legacy_gantt(self): @provide_session def gantt(self, dag_id, session=None): """Show GANTT chart.""" - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_model = DagModel.get_dagmodel(dag_id) root = request.args.get('root') @@ -3517,7 +3520,7 @@ def extra_links(self, session: "Session" = NEW_SESSION): execution_date = request.args.get('execution_date') link_name = request.args.get('link_name') dttm = _safe_parse_datetime(execution_date) - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag or task_id not in dag.task_ids: response = jsonify( @@ -3541,6 +3544,10 @@ def extra_links(self, session: "Session" = NEW_SESSION): response = jsonify({'url': None, 'error': 'Task Instances not found'}) response.status_code = 404 return response + if link_name is None: + response = jsonify({'url': None, 'error': 'Link name not passed'}) + response.status_code = 400 + return response try: url = task.get_extra_links(ti, link_name) except ValueError as err: @@ -3567,7 +3574,7 @@ def extra_links(self, session: "Session" = NEW_SESSION): def task_instances(self): """Shows task instances.""" dag_id = request.args.get('dag_id') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dttm = request.args.get('execution_date') if dttm: @@ -3595,7 +3602,7 @@ def task_instances(self): def grid_data(self): 
"""Returns grid data""" dag_id = request.args.get('dag_id') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: response = jsonify({'error': f"can't find dag {dag_id}"}) @@ -3649,7 +3656,7 @@ def robots(self): of the risk associated with exposing Airflow to the public internet, however it does not address the real security risks associated with such a deployment. """ - return send_from_directory(current_app.static_folder, 'robots.txt') + return send_from_directory(get_airflow_app().static_folder, 'robots.txt') @expose('/audit_log') @auth.has_access( @@ -3661,7 +3668,7 @@ def robots(self): @provide_session def audit_log(self, session=None): dag_id = request.args.get('dag_id') - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) included_events = conf.get('webserver', 'audit_view_included_events', fallback=None) excluded_events = conf.get('webserver', 'audit_view_excluded_events', fallback=None) @@ -3765,9 +3772,9 @@ class DagFilter(BaseFilter): """Filter using DagIDs""" def apply(self, query, func): - if current_app.appbuilder.sm.has_all_dags_access(g.user): + if get_airflow_app().appbuilder.sm.has_all_dags_access(g.user): return query - filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) return query.filter(self.model.dag_id.in_(filter_dag_ids)) @@ -3790,7 +3797,7 @@ class AirflowPrivilegeVerifierModelView(AirflowModelView): @staticmethod def validate_dag_edit_access(item: Union[DagRun, TaskInstance]): """Validates whether the user has 'can_edit' access for this specific DAG.""" - if not current_app.appbuilder.sm.can_edit_dag(item.dag_id): + if not get_airflow_app().appbuilder.sm.can_edit_dag(item.dag_id): raise AirflowException(f"Access denied for dag_id {item.dag_id}") def pre_add(self, item: Union[DagRun, TaskInstance]): @@ -3821,7 +3828,7 @@ def 
check_dag_edit_acl_for_actions( items: Optional[Union[List[TaskInstance], List[DagRun], TaskInstance, DagRun]], *args, **kwargs, - ) -> None: + ) -> Callable: if items is None: dag_ids: Set[str] = set() elif isinstance(items, list): @@ -3836,7 +3843,7 @@ def check_dag_edit_acl_for_actions( ) for dag_id in dag_ids: - if not current_app.appbuilder.sm.can_edit_dag(dag_id): + if not get_airflow_app().appbuilder.sm.can_edit_dag(dag_id): flash(f"Access denied for dag_id {dag_id}", "danger") logging.warning("User %s tried to modify %s without having access.", g.user.username, dag_id) return redirect(self.get_default_url()) @@ -4439,7 +4446,9 @@ def fqueued_slots(self): def _can_create_variable() -> bool: - return current_app.appbuilder.sm.has_access(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE) + return get_airflow_app().appbuilder.sm.has_access( + permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE + ) class VariableModelView(AirflowModelView): @@ -4783,7 +4792,10 @@ def action_set_failed(self, drs: List[DagRun], session=None): for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id for dagrun in drs])).all(): count += 1 altered_tis += set_dag_run_state_to_failed( - dag=current_app.dag_bag.get_dag(dr.dag_id), run_id=dr.run_id, commit=True, session=session + dag=get_airflow_app().dag_bag.get_dag(dr.dag_id), + run_id=dr.run_id, + commit=True, + session=session, ) altered_ti_count = len(altered_tis) flash(f"{count} dag runs and {altered_ti_count} task instances were set to failed") @@ -4808,7 +4820,10 @@ def action_set_success(self, drs: List[DagRun], session=None): for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id for dagrun in drs])).all(): count += 1 altered_tis += set_dag_run_state_to_success( - dag=current_app.dag_bag.get_dag(dr.dag_id), run_id=dr.run_id, commit=True, session=session + dag=get_airflow_app().dag_bag.get_dag(dr.dag_id), + run_id=dr.run_id, + commit=True, + session=session, ) altered_ti_count = 
len(altered_tis) flash(f"{count} dag runs and {altered_ti_count} task instances were set to success") @@ -4828,7 +4843,7 @@ def action_clear(self, drs: List[DagRun], session=None): dag_to_tis: Dict[DAG, List[TaskInstance]] = {} for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id for dagrun in drs])).all(): count += 1 - dag = current_app.dag_bag.get_dag(dr.dag_id) + dag = get_airflow_app().dag_bag.get_dag(dr.dag_id) tis_to_clear = dag_to_tis.setdefault(dag, []) tis_to_clear += dr.get_task_instances() @@ -5121,7 +5136,7 @@ def action_clear(self, task_instances, session=None): dag_to_tis = collections.defaultdict(list) for ti in task_instances: - dag = current_app.dag_bag.get_dag(ti.dag_id) + dag = get_airflow_app().dag_bag.get_dag(ti.dag_id) dag_to_tis[dag].append(ti) for dag, task_instances_list in dag_to_tis.items(): @@ -5237,7 +5252,7 @@ def autocomplete(self, session=None): dag_ids_query = dag_ids_query.filter(DagModel.is_paused) owners_query = owners_query.filter(DagModel.is_paused) - filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) dag_ids_query = dag_ids_query.filter(DagModel.dag_id.in_(filter_dag_ids)) owners_query = owners_query.filter(DagModel.dag_id.in_(filter_dag_ids)) diff --git a/dev/breeze/src/airflow_breeze/commands/release_management_commands.py b/dev/breeze/src/airflow_breeze/commands/release_management_commands.py index a5333a08ab0b6..b649e814b9194 100644 --- a/dev/breeze/src/airflow_breeze/commands/release_management_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/release_management_commands.py @@ -511,6 +511,12 @@ def generate_constraints( @option_use_airflow_version @option_airflow_extras @option_airflow_constraints_reference +@click.option( + "--skip-constraints", + is_flag=True, + help="Do not use constraints when installing providers.", + envvar='SKIP_CONSTRAINTS', +) @option_use_packages_from_dist 
@option_installation_package_format @option_verbose @@ -522,6 +528,7 @@ def verify_provider_packages( dry_run: bool, use_airflow_version: Optional[str], airflow_constraints_reference: str, + skip_constraints: bool, airflow_extras: str, use_packages_from_dist: bool, debug: bool, @@ -538,6 +545,7 @@ def verify_provider_packages( airflow_extras=airflow_extras, airflow_constraints_reference=airflow_constraints_reference, use_packages_from_dist=use_packages_from_dist, + skip_constraints=skip_constraints, package_format=package_format, ) rebuild_or_pull_ci_image_if_needed(command_params=shell_params, dry_run=dry_run, verbose=verbose) diff --git a/dev/breeze/src/airflow_breeze/params/shell_params.py b/dev/breeze/src/airflow_breeze/params/shell_params.py index 8b908aefa77d8..1b5d6925dee9c 100644 --- a/dev/breeze/src/airflow_breeze/params/shell_params.py +++ b/dev/breeze/src/airflow_breeze/params/shell_params.py @@ -81,6 +81,7 @@ class ShellParams: postgres_version: str = ALLOWED_POSTGRES_VERSIONS[0] python: str = ALLOWED_PYTHON_MAJOR_MINOR_VERSIONS[0] skip_environment_initialization: bool = False + skip_constraints: bool = False start_airflow: str = "false" use_airflow_version: Optional[str] = None use_packages_from_dist: bool = False diff --git a/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py b/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py index 45fd17b8580ef..aca7a4219d744 100644 --- a/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py +++ b/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py @@ -572,6 +572,7 @@ def update_expected_environment_variables(env: Dict[str, str]) -> None: "POSTGRES_VERSION": "postgres_version", "SQLITE_URL": "sqlite_url", "START_AIRFLOW": "start_airflow", + "SKIP_CONSTRAINTS": "skip_constraints", "SKIP_ENVIRONMENT_INITIALIZATION": "skip_environment_initialization", "USE_AIRFLOW_VERSION": "use_airflow_version", "USE_PACKAGES_FROM_DIST": "use_packages_from_dist", diff --git 
a/images/breeze/output-commands-hash.txt b/images/breeze/output-commands-hash.txt index 70ac9b305de2f..e9dc92c91c72f 100644 --- a/images/breeze/output-commands-hash.txt +++ b/images/breeze/output-commands-hash.txt @@ -1 +1 @@ -8b4116c1808c84d491961283a4ddbec2 +1a99b4b0bb09b4214384971a1121124a diff --git a/images/breeze/output-verify-provider-packages.svg b/images/breeze/output-verify-provider-packages.svg index 12853b46a203f..7b9afca2b4c53 100644 --- a/images/breeze/output-verify-provider-packages.svg +++ b/images/breeze/output-verify-provider-packages.svg @@ -1,4 +1,4 @@ - + - - + + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + + + + - Command: verify-provider-packages + Command: verify-provider-packages - + - - -Usage: breeze verify-provider-packages [OPTIONS] - -Verifies if all provider code is following expectations for providers. - -╭─ Provider verification flags ────────────────────────────────────────────────────────────────────────────────────────╮ ---use-airflow-versionUse (reinstall at entry) Airflow version from PyPI. It can also be `none`,        -`wheel`, or `sdist` if Airflow should be removed, installed from wheel packages   -or sdist packages available in dist folder respectively. Implies --mount-sources -`remove`.                                                                         -(none | wheel | sdist | <airflow_version>)                                        ---airflow-constraints-referenceConstraint reference to use. Useful with --use-airflow-version parameter to       -specify constraints for the installed version and to find newer dependencies      -(TEXT)                                                                            ---airflow-extrasAirflow extras to install when --use-airflow-version is used(TEXT) ---use-packages-from-distInstall all found packages (--package-format determines type) from 'dist' folder  -when entering breeze.                                           
                  ---package-formatFormat of packages that should be installed from dist.(wheel | sdist) -[default: wheel]                                       ---debugDrop user in shell instead of running the command. Useful for debugging. -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ ---verbose-vPrint verbose information about performed steps. ---dry-run-DIf dry-run is set, commands are only printed, not executed. ---github-repository-gGitHub repository used to pull, push run images.(TEXT)[default: apache/airflow] ---help-hShow this message and exit. -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ + + +Usage: breeze verify-provider-packages [OPTIONS] + +Verifies if all provider code is following expectations for providers. + +╭─ Provider verification flags ────────────────────────────────────────────────────────────────────────────────────────╮ +--use-airflow-versionUse (reinstall at entry) Airflow version from PyPI. It can also be `none`,        +`wheel`, or `sdist` if Airflow should be removed, installed from wheel packages   +or sdist packages available in dist folder respectively. Implies --mount-sources +`remove`.                                                                         +(none | wheel | sdist | <airflow_version>)                                        +--airflow-constraints-referenceConstraint reference to use. 
Useful with --use-airflow-version parameter to       +specify constraints for the installed version and to find newer dependencies      +(TEXT)                                                                            +--airflow-extrasAirflow extras to install when --use-airflow-version is used(TEXT) +--use-packages-from-distInstall all found packages (--package-format determines type) from 'dist' folder  +when entering breeze.                                                             +--package-formatFormat of packages that should be installed from dist.(wheel | sdist) +[default: wheel]                                       +--debugDrop user in shell instead of running the command. Useful for debugging. +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +--skip-constraintsDo not use constraints when installing providers. +--verbose-vPrint verbose information about performed steps. +--dry-run-DIf dry-run is set, commands are only printed, not executed. +--github-repository-gGitHub repository used to pull, push run images.(TEXT)[default: apache/airflow] +--help-hShow this message and exit. 
+╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ diff --git a/scripts/ci/docker-compose/_docker.env b/scripts/ci/docker-compose/_docker.env index 4edc849b57b93..b33cfea3602e5 100644 --- a/scripts/ci/docker-compose/_docker.env +++ b/scripts/ci/docker-compose/_docker.env @@ -59,6 +59,7 @@ RUN_TESTS LIST_OF_INTEGRATION_TESTS_TO_RUN RUN_SYSTEM_TESTS START_AIRFLOW +SKIP_CONSTRAINTS SKIP_ENVIRONMENT_INITIALIZATION SKIP_SSH_SETUP TEST_TYPE diff --git a/scripts/ci/docker-compose/base.yml b/scripts/ci/docker-compose/base.yml index 48e4d3df9606e..c1285eda88193 100644 --- a/scripts/ci/docker-compose/base.yml +++ b/scripts/ci/docker-compose/base.yml @@ -72,6 +72,7 @@ services: - LIST_OF_INTEGRATION_TESTS_TO_RUN=${LIST_OF_INTEGRATION_TESTS_TO_RUN} - RUN_SYSTEM_TESTS=${RUN_SYSTEM_TESTS} - START_AIRFLOW=${START_AIRFLOW} + - SKIP_CONSTRAINTS=${SKIP_CONSTRAINTS} - SKIP_ENVIRONMENT_INITIALIZATION=${SKIP_ENVIRONMENT_INITIALIZATION} - SKIP_SSH_SETUP=${SKIP_SSH_SETUP} - TEST_TYPE=${TEST_TYPE} diff --git a/scripts/ci/docker-compose/devcontainer.env b/scripts/ci/docker-compose/devcontainer.env index 1c4b27b36af67..ae51b204436ac 100644 --- a/scripts/ci/docker-compose/devcontainer.env +++ b/scripts/ci/docker-compose/devcontainer.env @@ -57,6 +57,7 @@ RUN_TESTS="false" LIST_OF_INTEGRATION_TESTS_TO_RUN="" RUN_SYSTEM_TESTS="" START_AIRFLOW="false" +SKIP_CONSTRAINTS="false" SKIP_SSH_SETUP="true" SKIP_ENVIRONMENT_INITIALIZATION="false" TEST_TYPE= diff --git a/scripts/ci/pre_commit/pre_commit_check_2_2_compatibility.py b/scripts/ci/pre_commit/pre_commit_check_2_2_compatibility.py index 8d72c251d54ae..d9562029078f6 100755 --- a/scripts/ci/pre_commit/pre_commit_check_2_2_compatibility.py +++ b/scripts/ci/pre_commit/pre_commit_check_2_2_compatibility.py @@ -35,6 +35,7 @@ TRY_NUM_MATCHER = re.compile(r".*context.*\[[\"']try_number[\"']].*") GET_MANDATORY_MATCHER = re.compile(r".*conf\.get_mandatory_value") 
+GET_AIRFLOW_APP_MATCHER = re.compile(r".*get_airflow_app\(\)") def _check_file(_file: Path): @@ -45,39 +46,48 @@ def _check_file(_file: Path): if "if ti_key is not None:" not in lines[index - 1]: errors.append( f"[red]In {_file}:{index} there is a forbidden construct " - f"(Airflow 2.3.0 only):[/]\n\n" + "(Airflow 2.3.0 only):[/]\n\n" f"{lines[index-1]}\n{lines[index]}\n\n" - f"[yellow]When you use XCom.get_value( in providers, it should be in the form:[/]\n\n" - f"if ti_key is not None:\n" - f" value = XCom.get_value(...., ti_key=ti_key)\n\n" - f"See: https://airflow.apache.org/docs/apache-airflow-providers/" - f"howto/create-update-providers.html#using-providers-with-dynamic-task-mapping\n" + "[yellow]When you use XCom.get_value( in providers, it should be in the form:[/]\n\n" + "if ti_key is not None:\n" + " value = XCom.get_value(...., ti_key=ti_key)\n\n" + "See: https://airflow.apache.org/docs/apache-airflow-providers/" + "howto/create-update-providers.html#using-providers-with-dynamic-task-mapping\n" ) if "ti.map_index" in line: errors.append( f"[red]In {_file}:{index} there is a forbidden construct " - f"(Airflow 2.3+ only):[/]\n\n" + "(Airflow 2.3+ only):[/]\n\n" f"{lines[index]}\n\n" - f"[yellow]You should not use map_index field in providers " - f"as it is not available in Airflow 2.2[/]" + "[yellow]You should not use map_index field in providers " + "as it is only available in Airflow 2.3+[/]" ) if TRY_NUM_MATCHER.match(line): errors.append( f"[red]In {_file}:{index} there is a forbidden construct " - f"(Airflow 2.3+ only):[/]\n\n" + "(Airflow 2.3+ only):[/]\n\n" f"{lines[index]}\n\n" - f"[yellow]You should not expect try_number field for context in providers " - f"as it is not available in Airflow 2.2[/]" + "[yellow]You should not expect try_number field for context in providers " + "as it is only available in Airflow 2.3+[/]" ) if GET_MANDATORY_MATCHER.match(line): errors.append( f"[red]In {_file}:{index} there is a forbidden construct " - f"(Airflow 
2.3+ only):[/]\n\n" + "(Airflow 2.3+ only):[/]\n\n" f"{lines[index]}\n\n" - f"[yellow]You should not use conf.get_mandatory_value " - f"as it is not available in Airflow 2.2[/]" + "[yellow]You should not use conf.get_mandatory_value in providers " + "as it is only available in Airflow 2.3+[/]" + ) + + if GET_AIRFLOW_APP_MATCHER.match(line): + errors.append( + f"[red]In {_file}:{index} there is a forbidden construct " + "(Airflow 2.4+ only):[/]\n\n" + f"{lines[index]}\n\n" + "[yellow]You should not use airflow.utils.airflow_flask_app.get_airflow_app() in providers " + "as it is not available in Airflow 2.4+. Use current_app instead.[/]" ) diff --git a/scripts/docker/entrypoint_ci.sh b/scripts/docker/entrypoint_ci.sh index f5198a556c21c..2994604fc9274 100755 --- a/scripts/docker/entrypoint_ci.sh +++ b/scripts/docker/entrypoint_ci.sh @@ -94,9 +94,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo "${COLOR_BLUE}Uninstalling airflow and providers" echo uninstall_airflow_and_providers - echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "none" + else + echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi uninstall_providers elif [[ ${USE_AIRFLOW_VERSION} == "sdist" ]]; then echo @@ -104,9 +110,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo uninstall_airflow_and_providers echo - echo 
"${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "none" + else + echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi uninstall_providers else echo @@ -114,9 +126,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo uninstall_airflow_and_providers echo - echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_released_airflow_version "${USE_AIRFLOW_VERSION}" "none" + else + echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi fi if [[ ${USE_PACKAGES_FROM_DIST=} == "true" ]]; then echo diff --git a/scripts/in_container/_in_container_utils.sh b/scripts/in_container/_in_container_utils.sh index d6a637e5c348b..66f2e6b083499 100644 --- a/scripts/in_container/_in_container_utils.sh +++ 
b/scripts/in_container/_in_container_utils.sh @@ -224,8 +224,12 @@ function install_airflow_from_wheel() { >&2 echo exit 4 fi - pip install "${airflow_package}${extras}" --constraint \ - "https://raw.githubusercontent.com/apache/airflow/${constraints_reference}/constraints-${PYTHON_MAJOR_MINOR_VERSION}.txt" + if [[ ${constraints_reference} == "none" ]]; then + pip install "${airflow_package}${extras}" + else + pip install "${airflow_package}${extras}" --constraint \ + "https://raw.githubusercontent.com/apache/airflow/${constraints_reference}/constraints-${PYTHON_MAJOR_MINOR_VERSION}.txt" + fi } function install_airflow_from_sdist() { @@ -250,8 +254,12 @@ function install_airflow_from_sdist() { >&2 echo exit 4 fi - pip install "${airflow_package}${extras}" --constraint \ - "https://raw.githubusercontent.com/apache/airflow/${constraints_reference}/constraints-${PYTHON_MAJOR_MINOR_VERSION}.txt" + if [[ ${constraints_reference} == "none" ]]; then + pip install "${airflow_package}${extras}" + else + pip install "${airflow_package}${extras}" --constraint \ + "https://raw.githubusercontent.com/apache/airflow/${constraints_reference}/constraints-${PYTHON_MAJOR_MINOR_VERSION}.txt" + fi } function uninstall_airflow() { @@ -278,17 +286,20 @@ function uninstall_airflow_and_providers() { function install_released_airflow_version() { local version="${1}" - echo - echo "Installing released ${version} version of airflow with extras: ${AIRFLOW_EXTRAS} and constraints constraints-${version}" - echo + local constraints_reference + constraints_reference="${2:-}" rm -rf "${AIRFLOW_SOURCES}"/*.egg-info if [[ ${AIRFLOW_EXTRAS} != "" ]]; then BRACKETED_AIRFLOW_EXTRAS="[${AIRFLOW_EXTRAS}]" else BRACKETED_AIRFLOW_EXTRAS="" fi - pip install "apache-airflow${BRACKETED_AIRFLOW_EXTRAS}==${version}" \ - --constraint "https://raw.githubusercontent.com/${CONSTRAINTS_GITHUB_REPOSITORY}/constraints-${version}/constraints-${PYTHON_MAJOR_MINOR_VERSION}.txt" + if [[ ${constraints_reference} == "none" 
]]; then + pip install "${airflow_package}${extras}" + else + pip install "apache-airflow${BRACKETED_AIRFLOW_EXTRAS}==${version}" \ + --constraint "https://raw.githubusercontent.com/${CONSTRAINTS_GITHUB_REPOSITORY}/constraints-${version}/constraints-${PYTHON_MAJOR_MINOR_VERSION}.txt" + fi } function install_local_airflow_with_eager_upgrade() { diff --git a/setup.cfg b/setup.cfg index e0976f9ba3b9b..c8dc97a5b6bc3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -101,48 +101,24 @@ install_requires = cryptography>=0.9.3 deprecated>=1.2.13 dill>=0.2.2 - # Flask and all related libraries are limited to below 2.0.0 because we expect it to introduce - # Serious breaking changes. Flask 2.0 has been introduced in May 2021 and 2.0.2 version is available - # now (Feb 2022): TODO: we should attempt to migrate to Flask 2 and all below flask libraries soon. - flask>=1.1.0, <2.0 - # We are tightly coupled with FAB version because we vendored in part of FAB code related to security manager - # This is done as part of preparation to removing FAB as dependency, but we are not ready for it yet - # Every time we update FAB version here, please make sure that you review the classes and models in - # `airflow/www/fab_security` with their upstream counterparts. In particular, make sure any breaking changes, - # for example any new methods, are accounted for. - flask-appbuilder==3.4.5 - flask-caching>=1.5.0, <2.0.0 - flask-login>=0.3, <0.5 - # Strict upper-bound on the latest release of flask-session, - # as any schema changes will require a migration. 
- flask-session>=0.3.1, <=0.4.0 - flask-wtf>=0.14.3, <0.15 + flask>=2.0 + flask-appbuilder==4.1.1 + flask-caching>=1.5.0 + flask-login>=0.5 + flask-session>=0.4.0 + flask-wtf>=0.14.3 graphviz>=0.12 gunicorn>=20.1.0 httpx importlib_metadata>=1.7;python_version<"3.9" importlib_resources>=5.2;python_version<"3.9" - # Logging is broken with itsdangerous > 2 - likely due to changed serializing support - # https://itsdangerous.palletsprojects.com/en/2.0.x/changes/#version-2-0-0 - # itsdangerous 2 has been released in May 2020 - # TODO: we should attempt to upgrade to line 2 of itsdangerous - itsdangerous>=1.1.0, <2.0 - # Jinja2 3.1 will remove the 'autoescape' and 'with' extensions, which would - # break Flask 1.x, so we limit this for future compatibility. Remove this - # when bumping Flask to >=2. - jinja2>=2.10.1,<3.1 - # Because connexion upper-bound is 5.0.0 and we depend on connexion, - # we pin to the same upper-bound as connexion. - jsonschema>=3.2.0, <5.0 + itsdangerous>=2.0 + jinja2>=2.10.1 + jsonschema>=3.2.0 lazy-object-proxy lockfile>=0.12.2 markdown>=3.0 - # Markupsafe 2.1.0 breaks with error: import name 'soft_unicode' from 'markupsafe'. 
- # This should be removed when either this issue is closed: - # https://github.com/pallets/markupsafe/issues/284 - # or when we will be able to upgrade JINJA to newer version (currently limited due to Flask and - # Flask Application Builder) - markupsafe>=1.1.1,<2.1.0 + markupsafe>=1.1.1 marshmallow-oneofschema>=2.0.1 packaging>=14.0 pathspec~=0.9.0 @@ -150,8 +126,7 @@ install_requires = pluggy>=1.0 psutil>=4.2.0 pygments>=2.0.1 - # python daemon crashes with 'socket operation on non-socket' for python 3.8+ in version < 2.2.4 - # https://pagure.io/python-daemon/issue/34 + pyjwt>=2.0.0 python-daemon>=2.2.4 python-dateutil>=2.3 python-nvd3>=0.15.0 @@ -169,10 +144,7 @@ install_requires = termcolor>=1.1.0 typing-extensions>=3.7.4 unicodecsv>=0.14.1 - # Werkzeug is known to cause breaking changes and it is very closely tied with FlaskAppBuilder and other - # Flask dependencies and the limit to 1.* line should be reviewed when we upgrade Flask and remove - # FlaskAppBuilder. - werkzeug~=1.0, >=1.0.1 + werkzeug>=2.0 [options.packages.find] include = diff --git a/setup.py b/setup.py index 9d3243f2b9dfd..0089f921a0ab1 100644 --- a/setup.py +++ b/setup.py @@ -618,16 +618,6 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'flake8-implicit-str-concat', 'flaky', 'freezegun', - # Github3 version 3.1.2 requires PyJWT>=2.3.0 which clashes with Flask App Builder where PyJWT is <2.0.0 - # Actually GitHub3.1.0 already introduced PyJWT>=2.3.0 but so far `pip` was able to resolve it without - # getting into a long backtracking loop and figure out that github3 3.0.0 version is the right version - # similarly limiting it to 3.1.2 causes pip not to enter the backtracking loop. Apparently when there - # are 3 versions with PyJWT>=2.3.0 (3.1.0, 3.1.1 an 3.1.2) pip enters into backtrack loop and fails - # to resolve that github3 3.0.0 is the right version to use. 
- # This limitation could be removed if PyJWT limitation < 2.0.0 is dropped from FAB or when - # pip resolution is improved to handle the case. The issue which describes this PIP behaviour - # and hopefully allowing to improve it is tracked in https://github.com/pypa/pip/issues/10924 - 'github3.py<3.1.0', 'gitpython', 'ipdb', 'jira', diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index d95d4c38549df..9a98a7f61bf9f 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -34,7 +34,8 @@ from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags -SERIALIZER = URLSafeSerializer(conf.get('webserver', 'secret_key')) +# the secret key is always available, so we can ignore Optional here +SERIALIZER = URLSafeSerializer(conf.get('webserver', 'secret_key')) # type: ignore[arg-type] FILE_TOKEN = SERIALIZER.dumps(__file__) DAG_ID = "test_dag" TASK_ID = "op1" diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py index 7c4452dc2bb63..efcba3271188d 100644 --- a/tests/api_connexion/endpoints/test_xcom_endpoint.py +++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py @@ -185,7 +185,7 @@ def test_should_respond_200(self): response_data = response.json for xcom_entry in response_data['xcom_entries']: xcom_entry['timestamp'] = "TIMESTAMP" - assert response.json == { + assert response_data == { 'xcom_entries': [ { 'dag_id': dag_id, @@ -227,7 +227,7 @@ def test_should_respond_200_with_tilde_and_access_to_all_dags(self): response_data = response.json for xcom_entry in response_data['xcom_entries']: xcom_entry['timestamp'] = "TIMESTAMP" - assert response.json == { + assert response_data == { 'xcom_entries': [ { 'dag_id': dag_id_1, @@ -283,7 +283,7 @@ def 
test_should_respond_200_with_tilde_and_granular_dag_access(self): response_data = response.json for xcom_entry in response_data['xcom_entries']: xcom_entry['timestamp'] = "TIMESTAMP" - assert response.json == { + assert response_data == { 'xcom_entries': [ { 'dag_id': dag_id_1, diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py index ca7e04ae89226..21f8e84d554f4 100644 --- a/tests/api_connexion/schemas/test_dag_schema.py +++ b/tests/api_connexion/schemas/test_dag_schema.py @@ -30,7 +30,8 @@ from airflow.configuration import conf from airflow.models import DagModel, DagTag -SERIALIZER = URLSafeSerializer(conf.get('webserver', 'SECRET_KEY')) +# the secret key is always available, so we can ignore Optional here +SERIALIZER = URLSafeSerializer(conf.get('webserver', 'secret_key')) # type: ignore[arg-type] class TestDagSchema(unittest.TestCase): diff --git a/tests/test_utils/remote_user_api_auth_backend.py b/tests/test_utils/remote_user_api_auth_backend.py index 1e6a0c70adf6d..187f57a7fd114 100644 --- a/tests/test_utils/remote_user_api_auth_backend.py +++ b/tests/test_utils/remote_user_api_auth_backend.py @@ -20,10 +20,12 @@ from functools import wraps from typing import Callable, Optional, Tuple, TypeVar, Union, cast -from flask import Response, current_app, request +from flask import Response, request from flask_login import login_user from requests.auth import AuthBase +from airflow.utils.airflow_flask_app import get_airflow_app + log = logging.getLogger(__name__) CLIENT_AUTH: Optional[Union[Tuple[str, str], AuthBase]] = None @@ -37,7 +39,7 @@ def init_app(_): def _lookup_user(user_email_or_username: str): - security_manager = current_app.appbuilder.sm + security_manager = get_airflow_app().appbuilder.sm user = security_manager.find_user(email=user_email_or_username) or security_manager.find_user( username=user_email_or_username ) diff --git a/tests/utils/test_serve_logs.py b/tests/utils/test_serve_logs.py 
index f8d38817592b8..db0b849c6a26e 100644 --- a/tests/utils/test_serve_logs.py +++ b/tests/utils/test_serve_logs.py @@ -31,6 +31,8 @@ LOG_DATA = "Airflow log data" * 20 +secret_key: str = conf.get('webserver', 'secret_key') # type: ignore + @pytest.fixture def client(tmpdir): @@ -51,7 +53,7 @@ def sample_log(tmpdir): @pytest.fixture def signer(): return JWTSigner( - secret_key=conf.get('webserver', 'secret_key'), + secret_key=secret_key, expiration_time_in_seconds=30, audience="task-instance-logs", ) @@ -60,7 +62,7 @@ def signer(): @pytest.fixture def different_audience(): return JWTSigner( - secret_key=conf.get('webserver', 'secret_key'), + secret_key=secret_key, expiration_time_in_seconds=30, audience="different-audience", ) @@ -191,7 +193,7 @@ def test_missing_claims(self, claim_to_remove: str, client: "FlaskClient"): jwt_dict.update({"filename": 'sample.log'}) token = jwt.encode( jwt_dict, - conf.get('webserver', 'secret_key'), + secret_key, algorithm="HS512", ) assert ( diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py index 887bd4898a0a6..fa79e145cba6c 100644 --- a/tests/www/views/test_views.py +++ b/tests/www/views/test_views.py @@ -375,52 +375,55 @@ def test_get_task_stats_from_query(): assert data == expected_data +INVALID_DATETIME_RESPONSE = "Invalid datetime: 'invalid'" + + @pytest.mark.parametrize( "url, content", [ ( '/rendered-templates?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( '/log?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( '/redirect_to_external_log?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( '/task?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'dags/example_bash_operator/graph?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'dags/example_bash_operator/graph?execution_date=invalid', - 
"Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'dags/example_bash_operator/duration?base_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'dags/example_bash_operator/tries?base_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'dags/example_bash_operator/landing-times?base_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'dags/example_bash_operator/gantt?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ( 'extra_links?execution_date=invalid', - "Invalid datetime: 'invalid'", + INVALID_DATETIME_RESPONSE, ), ], ) diff --git a/tests/www/views/test_views_decorators.py b/tests/www/views/test_views_decorators.py index 0e4fc12857a8f..1de80c1214a28 100644 --- a/tests/www/views/test_views_decorators.py +++ b/tests/www/views/test_views_decorators.py @@ -213,9 +213,9 @@ def test_action_has_dag_edit_access(create_task_instance, class_type, no_instanc else: test_items = tis if class_type == TaskInstance else [ti.get_dagrun() for ti in tis] test_items = test_items[0] if len(test_items) == 1 else test_items - - with app.create_app(testing=True).app_context(): - with mock.patch("airflow.www.views.current_app.appbuilder.sm.can_edit_dag") as mocked_can_edit: + application = app.create_app(testing=True) + with application.app_context(): + with mock.patch.object(application.appbuilder.sm, "can_edit_dag") as mocked_can_edit: mocked_can_edit.return_value = True assert not isinstance(test_items, list) or len(test_items) == no_instances assert some_view_action_which_requires_dag_edit_access(None, test_items) is True diff --git a/tests/www/views/test_views_log.py b/tests/www/views/test_views_log.py index 82b30f9d218e5..fd136351cf1ad 100644 --- a/tests/www/views/test_views_log.py +++ b/tests/www/views/test_views_log.py @@ -461,7 +461,7 @@ def test_redirect_to_external_log_with_local_log_handler(log_admin_client, task_ ) response = 
log_admin_client.get(url) assert 302 == response.status_code - assert 'http://localhost/home' == response.headers['Location'] + assert '/home' == response.headers['Location'] class _ExternalHandler(ExternalLoggingMixin): diff --git a/tests/www/views/test_views_mount.py b/tests/www/views/test_views_mount.py index a9fb8746657df..3f504e9b0f168 100644 --- a/tests/www/views/test_views_mount.py +++ b/tests/www/views/test_views_mount.py @@ -36,7 +36,7 @@ def factory(): @pytest.fixture() def client(app): - return werkzeug.test.Client(app, werkzeug.wrappers.BaseResponse) + return werkzeug.test.Client(app, werkzeug.wrappers.response.Response) def test_mount(client): @@ -54,4 +54,4 @@ def test_not_found(client): def test_index(client): resp = client.get('/test/') assert resp.status_code == 302 - assert resp.headers['Location'] == 'http://localhost/test/home' + assert resp.headers['Location'] == '/test/home'