diff --git a/languages/python/sqlalchemy-oso/sqlalchemy_oso/compat.py b/languages/python/sqlalchemy-oso/sqlalchemy_oso/compat.py index 58ef06168c..8b1d0b4f72 100644 --- a/languages/python/sqlalchemy-oso/sqlalchemy_oso/compat.py +++ b/languages/python/sqlalchemy-oso/sqlalchemy_oso/compat.py @@ -3,6 +3,7 @@ Keep us compatible with multiple SQLAlchemy versions by implementing wrappers when needed here. """ + import sqlalchemy from packaging.version import parse diff --git a/languages/python/sqlalchemy-oso/sqlalchemy_oso/session.py b/languages/python/sqlalchemy-oso/sqlalchemy_oso/session.py index 61986d12fb..b0636adcbf 100644 --- a/languages/python/sqlalchemy-oso/sqlalchemy_oso/session.py +++ b/languages/python/sqlalchemy-oso/sqlalchemy_oso/session.py @@ -1,4 +1,5 @@ """SQLAlchemy session classes and factories for oso.""" + import logging from typing import Any, Callable, Dict, Optional, Type diff --git a/languages/python/sqlalchemy-oso/sqlalchemy_oso/sqlalchemy_utils.py b/languages/python/sqlalchemy-oso/sqlalchemy_oso/sqlalchemy_utils.py index 776dc56fba..ff9c9f9e53 100644 --- a/languages/python/sqlalchemy-oso/sqlalchemy_oso/sqlalchemy_utils.py +++ b/languages/python/sqlalchemy-oso/sqlalchemy_oso/sqlalchemy_utils.py @@ -5,6 +5,7 @@ We must detect all entities properly to apply authorization. """ + import sqlalchemy from sqlalchemy import inspect from sqlalchemy.orm.util import AliasedClass, AliasedInsp @@ -20,10 +21,12 @@ def to_class(entity): else: return entity + if USING_SQLAlchemy_v1_3: # unsupported for <= 1.3 def all_entities_in_statement(statement): raise NotImplementedError("Unsupported on SQLAlchemy < 1.4") + else: if USING_SQLAlchemy_v2_0: @@ -78,6 +81,7 @@ def get_joinedload_entities(stmt): entities.add(loadopt.path[-1].entity) return entities + else: # Start POC code from @zzzeek (Mike Bayer) # TODO: Still needs to be generalized & support other options. @@ -123,7 +127,10 @@ def get_joinedload_entities(stmt): elif hasattr(opt, "context"): # these options are called Load for key, loadopt in opt.context.items(): - if key[0] == "loader" and ("lazy", "joined") in loadopt.strategy: + if ( + key[0] == "loader" + and ("lazy", "joined") in loadopt.strategy + ): # the "path" is a tuple showing the entity/relationships # being targeted @@ -218,4 +225,3 @@ class A(Base): default_entities.add(rel.mapper) return default_entities - diff --git a/languages/python/sqlalchemy-oso/tests/models.py b/languages/python/sqlalchemy-oso/tests/models.py index a4102d87d0..3e76780a66 100644 --- a/languages/python/sqlalchemy-oso/tests/models.py +++ b/languages/python/sqlalchemy-oso/tests/models.py @@ -2,6 +2,7 @@ from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String from sqlalchemy.orm import relationship + if USING_SQLAlchemy_v1_3: from sqlalchemy.ext.declarative import declarative_base else: diff --git a/languages/python/sqlalchemy-oso/tests/test_advanced_queries_14.py b/languages/python/sqlalchemy-oso/tests/test_advanced_queries_14.py index 4cb8e1acc4..4f8f5ec1a4 100644 --- a/languages/python/sqlalchemy-oso/tests/test_advanced_queries_14.py +++ b/languages/python/sqlalchemy-oso/tests/test_advanced_queries_14.py @@ -134,7 +134,6 @@ def test_get_joinedload_entities(stmt, o): assert set(map(to_class, get_joinedload_entities(stmt))) == o - def test_default_loader_strategies_all_entities_in_statement(): """Test that all_entitites_in_statement finds default "joined" entities.""" Base2 = declarative_base() diff --git a/languages/python/sqlalchemy-oso/tests/test_post_relationship.py b/languages/python/sqlalchemy-oso/tests/test_post_relationship.py index f450cf5c38..a2910cd58e 100644 --- a/languages/python/sqlalchemy-oso/tests/test_post_relationship.py +++ b/languages/python/sqlalchemy-oso/tests/test_post_relationship.py @@ -3,6 +3,7 @@ Tests come from the relationship document & operations laid out there https://www.notion.so/osohq/Relationships-621b884edbc6423f93d29e6066e58d16. """ + import pytest from sqlalchemy_oso.auth import authorize_model @@ -183,7 +184,7 @@ def tag_test_fixture(session): # HACK! objects = {} - for (name, local) in locals().items(): + for name, local in locals().items(): if name != "session" and name != "objects": session.add(local) @@ -259,7 +260,7 @@ def tag_nested_test_fixture(session): # HACK! objects = {} - for (name, local) in locals().items(): + for name, local in locals().items(): if name != "session" and name != "objects": session.add(local) @@ -357,7 +358,7 @@ def tag_nested_many_many_test_fixture(session): # HACK! objects = {} - for (name, local) in locals().items(): + for name, local in locals().items(): if name != "session" and name != "objects": session.add(local) diff --git a/languages/python/sqlalchemy-oso/tests/test_sqlalchemy.py b/languages/python/sqlalchemy-oso/tests/test_sqlalchemy.py index 132b9bf1a6..e5de07a39e 100644 --- a/languages/python/sqlalchemy-oso/tests/test_sqlalchemy.py +++ b/languages/python/sqlalchemy-oso/tests/test_sqlalchemy.py @@ -1,4 +1,5 @@ """Test hooks & SQLAlchemy API integrations.""" + import pytest from sqlalchemy.orm import aliased, joinedload