Allow masking without primary keys (#5575)

galvana authored Jan 10, 2025
1 parent a009dbb commit 53944d5
Showing 27 changed files with 619 additions and 538 deletions.
552 changes: 148 additions & 404 deletions data/dataset/bigquery_enterprise_test_dataset.yml

Large diffs are not rendered by default.

18 changes: 0 additions & 18 deletions data/dataset/bigquery_example_test_dataset.yml
@@ -13,8 +13,6 @@ dataset:
             data_categories: [user.contact.address.street]
           - name: id
             data_categories: [system.operations]
-            fides_meta:
-              primary_key: True
           - name: state
             data_categories: [user.contact.address.state]
           - name: street
@@ -53,8 +51,6 @@ dataset:
               data_type: string
           - name: id
             data_categories: [user.unique_id]
-            fides_meta:
-              primary_key: True
           - name: name
             data_categories: [user.name]
             fides_meta:
@@ -80,8 +76,6 @@ dataset:
               data_type: string
           - name: id
             data_categories: [user.unique_id]
-            fides_meta:
-              primary_key: True
           - name: name
             data_categories: [user.name]
             fides_meta:
@@ -98,8 +92,6 @@ dataset:
                   direction: from
           - name: id
             data_categories: [system.operations]
-            fides_meta:
-              primary_key: True
           - name: time
             data_categories: [user.sensor]

@@ -114,8 +106,6 @@ dataset:
                   direction: from
           - name: id
             data_categories: [system.operations]
-            fides_meta:
-              primary_key: True
           - name: shipping_address_id
             data_categories: [system.operations]
             fides_meta:
@@ -166,8 +156,6 @@ dataset:
                   direction: from
           - name: id
             data_categories: [system.operations]
-            fides_meta:
-              primary_key: True
           - name: name
             data_categories: [user.financial]
           - name: preferred
@@ -177,8 +165,6 @@ dataset:
         fields:
           - name: id
             data_categories: [system.operations]
-            fides_meta:
-              primary_key: True
           - name: name
             data_categories: [system.operations]
           - name: price
@@ -193,8 +179,6 @@ dataset:
               data_type: string
           - name: id
             data_categories: [system.operations]
-            fides_meta:
-              primary_key: True
           - name: month
             data_categories: [system.operations]
           - name: name
@@ -227,8 +211,6 @@ dataset:
                   direction: from
           - name: id
             data_categories: [system.operations]
-            fides_meta:
-              primary_key: True
           - name: opened
             data_categories: [system.operations]

11 changes: 11 additions & 0 deletions src/fides/api/service/connectors/base_connector.py
@@ -132,3 +132,14 @@ def execute_standalone_retrieval_query(
         raise NotImplementedError(
             "execute_standalone_retrieval_query must be implemented in a concrete subclass"
         )
+
+    @property
+    def requires_primary_keys(self) -> bool:
+        """
+        Indicates if datasets linked to this connector require primary keys for erasures.
+        Defaults to True.
+        """
+
+        # Defaulting to true for now so we can keep the default behavior and
+        # incrementally determine the need for primary keys across all connectors
+        return True
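For illustration, here is one way downstream code might consume this flag when deciding whether a collection can be erased. The helper below is a hypothetical sketch, not part of this commit:

# Hypothetical caller-side check: connectors whose requires_primary_keys is
# False (e.g. BigQuery and Postgres below) may mask rows matched on any
# non-empty reference field instead of a primary key.
def validate_erasure_target(connector, collection) -> None:
    if not connector.requires_primary_keys:
        return
    if not any(field.primary_key for field in collection.fields):
        raise ValueError(
            f"Collection '{collection.name}' needs a primary key for erasures"
        )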
5 changes: 5 additions & 0 deletions src/fides/api/service/connectors/bigquery_connector.py
@@ -33,6 +33,11 @@ class BigQueryConnector(SQLConnector):
 
     secrets_schema = BigQuerySchema
 
+    @property
+    def requires_primary_keys(self) -> bool:
+        """BigQuery does not have the concept of primary keys so they're not required for erasures."""
+        return False
+
     # Overrides BaseConnector.build_uri
     def build_uri(self) -> str:
         """Build URI of format"""
5 changes: 5 additions & 0 deletions src/fides/api/service/connectors/postgres_connector.py
@@ -19,6 +19,11 @@ class PostgreSQLConnector(SQLConnector):
 
     secrets_schema = PostgreSQLSchema
 
+    @property
+    def requires_primary_keys(self) -> bool:
+        """Postgres allows arbitrary columns in the WHERE clause for updates so primary keys are not required."""
+        return False
+
     def build_uri(self) -> str:
         """Build URI of format postgresql://[user[:password]@][netloc][:port][/dbname]"""
         config = self.secrets_schema(**self.configuration.secrets or {})
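To see what this enables, consider the shape of the masking statement once primary keys are optional. The snippet below is illustrative only; "customer" and "email" stand in for whatever collection and matched field the traversal actually produced:

# Illustrative: assemble the UPDATE the same way get_update_stmt does,
# matching on a reference field rather than a primary key.
update_clauses = ["name = :masked_name"]
where_clauses = ["email = :email"]
query_str = (
    f"UPDATE customer SET {', '.join(update_clauses)} "
    f"WHERE {' AND '.join(where_clauses)}"
)
params = {"masked_name": None, "email": "customer-1@example.com"}
print(query_str)
# UPDATE customer SET name = :masked_name WHERE email = :email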
src/fides/api/service/connectors/query_configs/bigquery_query_config.py
@@ -123,15 +123,15 @@ def generate_update(
         TODO: DRY up this method and `generate_delete` a bit
         """
         update_value_map: Dict[str, Any] = self.update_value_map(row, policy, request)
-        non_empty_primary_keys: Dict[str, Field] = filter_nonempty_values(
+        non_empty_reference_field_keys: Dict[str, Field] = filter_nonempty_values(
             {
                 fpath.string_path: fld.cast(row[fpath.string_path])
-                for fpath, fld in self.primary_key_field_paths.items()
+                for fpath, fld in self.reference_field_paths.items()
                 if fpath.string_path in row
             }
         )
 
-        valid = len(non_empty_primary_keys) > 0 and update_value_map
+        valid = len(non_empty_reference_field_keys) > 0 and update_value_map
         if not valid:
             logger.warning(
                 "There is not enough data to generate a valid update statement for {}",
@@ -140,8 +140,8 @@
             return []
 
         table = Table(self._generate_table_name(), MetaData(bind=client), autoload=True)
-        pk_clauses: List[ColumnElement] = [
-            getattr(table.c, k) == v for k, v in non_empty_primary_keys.items()
+        where_clauses: List[ColumnElement] = [
+            getattr(table.c, k) == v for k, v in non_empty_reference_field_keys.items()
         ]
 
         if self.partitioning:
@@ -153,13 +153,13 @@
             for partition_clause in partition_clauses:
                 partitioned_queries.append(
                     table.update()
-                    .where(*(pk_clauses + [text(partition_clause)]))
+                    .where(*(where_clauses + [text(partition_clause)]))
                     .values(**update_value_map)
                 )
 
             return partitioned_queries
 
-        return [table.update().where(*pk_clauses).values(**update_value_map)]
+        return [table.update().where(*where_clauses).values(**update_value_map)]
 
     def generate_delete(self, row: Row, client: Engine) -> List[Delete]:
         """Returns a List of SQLAlchemy DELETE statements for BigQuery. Does not actually execute the delete statement.
@@ -172,15 +172,15 @@ def generate_delete(self, row: Row, client: Engine) -> List[Delete]:
         TODO: DRY up this method and `generate_update` a bit
         """
 
-        non_empty_primary_keys: Dict[str, Field] = filter_nonempty_values(
+        non_empty_reference_field_keys: Dict[str, Field] = filter_nonempty_values(
             {
                 fpath.string_path: fld.cast(row[fpath.string_path])
-                for fpath, fld in self.primary_key_field_paths.items()
+                for fpath, fld in self.reference_field_paths.items()
                 if fpath.string_path in row
             }
         )
 
-        valid = len(non_empty_primary_keys) > 0
+        valid = len(non_empty_reference_field_keys) > 0
         if not valid:
             logger.warning(
                 "There is not enough data to generate a valid DELETE statement for {}",
@@ -189,8 +189,8 @@
             return []
 
         table = Table(self._generate_table_name(), MetaData(bind=client), autoload=True)
-        pk_clauses: List[ColumnElement] = [
-            getattr(table.c, k) == v for k, v in non_empty_primary_keys.items()
+        where_clauses: List[ColumnElement] = [
+            getattr(table.c, k) == v for k, v in non_empty_reference_field_keys.items()
         ]
 
         if self.partitioning:
@@ -202,9 +202,9 @@
 
             for partition_clause in partition_clauses:
                 partitioned_queries.append(
-                    table.delete().where(*(pk_clauses + [text(partition_clause)]))
+                    table.delete().where(*(where_clauses + [text(partition_clause)]))
                 )
 
             return partitioned_queries
 
-        return [table.delete().where(*pk_clauses)]
+        return [table.delete().where(*where_clauses)]
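For context, the clause-building pattern above can be reproduced standalone. The sketch below uses a toy SQLAlchemy table in place of the reflected BigQuery table (autoload=True in the real code); table, column, and value names are illustrative:

from sqlalchemy import Column, MetaData, String, Table

# Toy table standing in for the reflected BigQuery table.
customer = Table(
    "customer", MetaData(), Column("email", String), Column("name", String)
)

# One equality clause per non-empty reference field found in the retrieved
# row, mirroring generate_update/generate_delete above.
non_empty_reference_field_keys = {"email": "customer-1@example.com"}
where_clauses = [
    getattr(customer.c, k) == v for k, v in non_empty_reference_field_keys.items()
]

print(customer.update().where(*where_clauses).values(name=None))
# UPDATE customer SET name=:name WHERE customer.email = :email_1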
69 changes: 46 additions & 23 deletions src/fides/api/service/connectors/query_configs/query_config.py
@@ -100,6 +100,15 @@ def primary_key_field_paths(self) -> Dict[FieldPath, Field]:
             if field.primary_key
         }
 
+    @property
+    def reference_field_paths(self) -> Dict[FieldPath, Field]:
+        """Mapping of FieldPaths to Fields that have incoming identity or dataset references"""
+        return {
+            field_path: field
+            for field_path, field in self.field_map().items()
+            if field_path in {edge.f2.field_path for edge in self.node.incoming_edges}
+        }
+
     def query_sources(self) -> Dict[str, List[CollectionAddress]]:
         """Display the input collection(s) for each query key for display purposes.
@@ -412,14 +421,16 @@ def generate_query_without_tuples( # pylint: disable=R0914
     def get_update_stmt(
         self,
         update_clauses: List[str],
-        pk_clauses: List[str],
+        where_clauses: List[str],
     ) -> str:
         """Returns a SQL UPDATE statement to fit SQL syntax."""
-        return f"UPDATE {self.node.address.collection} SET {', '.join(update_clauses)} WHERE {' AND '.join(pk_clauses)}"
+        return f"UPDATE {self.node.address.collection} SET {', '.join(update_clauses)} WHERE {' AND '.join(where_clauses)}"
 
     @abstractmethod
     def get_update_clauses(
-        self, update_value_map: Dict[str, Any], non_empty_primary_keys: Dict[str, Field]
+        self,
+        update_value_map: Dict[str, Any],
+        where_clause_fields: Dict[str, Field],
     ) -> List[str]:
         """Returns a list of update clauses for the update statement."""
 
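To make the reference_field_paths property added above concrete: a collection reached via customer.id -> orders.customer_id would report customer_id here. The mini-model below uses hypothetical stand-ins for the real graph types, just to show the set-membership check:

from dataclasses import dataclass

@dataclass(frozen=True)
class FieldPath:
    string_path: str

@dataclass(frozen=True)
class Edge:
    f2_field_path: FieldPath  # destination field of an incoming edge (stand-in for edge.f2.field_path)

incoming_edges = [Edge(FieldPath("customer_id"))]
field_map = {FieldPath("id"): "id field", FieldPath("customer_id"): "fk field"}

reference_field_paths = {
    fp: f
    for fp, f in field_map.items()
    if fp in {e.f2_field_path for e in incoming_edges}
}
print([fp.string_path for fp in reference_field_paths])  # ['customer_id']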
@@ -428,46 +439,57 @@ def format_query_stmt(self, query_str: str, update_value_map: Dict[str, Any]) ->
         """Returns a formatted update statement in the appropriate dialect."""
 
     @abstractmethod
-    def format_key_map_for_update_stmt(self, fields: List[str]) -> List[str]:
+    def format_key_map_for_update_stmt(self, param_map: Dict[str, Any]) -> List[str]:
         """Adds the appropriate formatting for update statements in this datastore."""
 
     def generate_update_stmt(
         self, row: Row, policy: Policy, request: PrivacyRequest
     ) -> Optional[T]:
         """Returns an update statement in generic SQL-ish dialect."""
         update_value_map: Dict[str, Any] = self.update_value_map(row, policy, request)
-        non_empty_primary_keys: Dict[str, Field] = filter_nonempty_values(
+
+        non_empty_primary_key_fields: Dict[str, Field] = filter_nonempty_values(
             {
                 fpath.string_path: fld.cast(row[fpath.string_path])
                 for fpath, fld in self.primary_key_field_paths.items()
                 if fpath.string_path in row
             }
         )
 
+        non_empty_reference_fields: Dict[str, Field] = filter_nonempty_values(
+            {
+                fpath.string_path: fld.cast(row[fpath.string_path])
+                for fpath, fld in self.reference_field_paths.items()
+                if fpath.string_path in row
+            }
+        )
+
+        # Create parameter mappings with masked_ prefix for SET values
+        param_map = {
+            **{f"masked_{k}": v for k, v in update_value_map.items()},
+            **non_empty_primary_key_fields,
+            **non_empty_reference_fields,
+        }
+
         update_clauses = self.get_update_clauses(
-            update_value_map, non_empty_primary_keys
+            {k: f"masked_{k}" for k in update_value_map},
+            non_empty_primary_key_fields or non_empty_reference_fields,
         )
-        pk_clauses = self.format_key_map_for_update_stmt(
-            list(non_empty_primary_keys.keys())
+        where_clauses = self.format_key_map_for_update_stmt(
+            {k: k for k in non_empty_primary_key_fields or non_empty_reference_fields}
         )
 
-        for k, v in non_empty_primary_keys.items():
-            update_value_map[k] = v
-
-        valid = len(pk_clauses) > 0 and len(update_clauses) > 0
+        valid = len(where_clauses) > 0 and len(update_clauses) > 0
         if not valid:
             logger.warning(
                 "There is not enough data to generate a valid update statement for {}",
                 self.node.address,
             )
             return None
 
-        query_str = self.get_update_stmt(
-            update_clauses,
-            pk_clauses,
-        )
-        logger.info("query = {}, params = {}", Pii(query_str), Pii(update_value_map))
-        return self.format_query_stmt(query_str, update_value_map)
+        query_str = self.get_update_stmt(update_clauses, where_clauses)
+        logger.info("query = {}, params = {}", Pii(query_str), Pii(param_map))
+        return self.format_query_stmt(query_str, param_map)
 
 
 class SQLQueryConfig(SQLLikeQueryConfig[Executable]):
@@ -538,16 +560,17 @@ def generate_query(
         )
         return None
 
-    def format_key_map_for_update_stmt(self, fields: List[str]) -> List[str]:
+    def format_key_map_for_update_stmt(self, param_map: Dict[str, Any]) -> List[str]:
         """Adds the appropriate formatting for update statements in this datastore."""
-        fields.sort()
-        return [f"{k} = :{k}" for k in fields]
+        return [f"{k} = :{v}" for k, v in sorted(param_map.items())]
 
     def get_update_clauses(
-        self, update_value_map: Dict[str, Any], non_empty_primary_keys: Dict[str, Field]
+        self,
+        update_value_map: Dict[str, Any],
+        where_clause_fields: Dict[str, Field],
    ) -> List[str]:
         """Returns a list of update clauses for the update statement."""
-        return self.format_key_map_for_update_stmt(list(update_value_map.keys()))
+        return self.format_key_map_for_update_stmt(update_value_map)
 
     def format_query_stmt(
         self, query_str: str, update_value_map: Dict[str, Any]
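The net effect of the param_map change: SET parameters are namespaced with a masked_ prefix, so a column being masked can also appear in the WHERE clause without the two values colliding in one parameter map. A quick walkthrough with made-up values, mirroring the formatting helpers above:

# Made-up values; "customer" and its columns are illustrative only.
update_value_map = {"email": None, "name": None}
non_empty_reference_fields = {"email": "customer-1@example.com"}

param_map = {
    **{f"masked_{k}": v for k, v in update_value_map.items()},
    **non_empty_reference_fields,
}
# {'masked_email': None, 'masked_name': None, 'email': 'customer-1@example.com'}

update_clauses = [f"{k} = :masked_{k}" for k in sorted(update_value_map)]
where_clauses = [f"{k} = :{k}" for k in sorted(non_empty_reference_fields)]
print(f"UPDATE customer SET {', '.join(update_clauses)} WHERE {' AND '.join(where_clauses)}")
# UPDATE customer SET email = :masked_email, name = :masked_name WHERE email = :email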
src/fides/api/service/connectors/query_configs/snowflake_query_config.py
@@ -59,15 +59,14 @@ def get_formatted_query_string(
         """Returns a query string with double quotation mark formatting as required by Snowflake syntax."""
         return f'SELECT {field_list} FROM {self._generate_table_name()} WHERE ({" OR ".join(clauses)})'
 
-    def format_key_map_for_update_stmt(self, fields: List[str]) -> List[str]:
+    def format_key_map_for_update_stmt(self, param_map: Dict[str, Any]) -> List[str]:
         """Adds the appropriate formatting for update statements in this datastore."""
-        fields.sort()
-        return [f'"{k}" = :{k}' for k in fields]
+        return [f'"{k}" = :{v}' for k, v in sorted(param_map.items())]
 
     def get_update_stmt(
         self,
         update_clauses: List[str],
-        pk_clauses: List[str],
+        where_clauses: List[str],
     ) -> str:
         """Returns a parameterized update statement in Snowflake dialect."""
-        return f'UPDATE {self._generate_table_name()} SET {", ".join(update_clauses)} WHERE {" AND ".join(pk_clauses)}'
+        return f'UPDATE {self._generate_table_name()} SET {", ".join(update_clauses)} WHERE {" AND ".join(where_clauses)}'
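The Snowflake variant is the same mapping with double-quoted identifiers, which Snowflake needs for case-sensitive names. A standalone replica of the one-liner above, with sample data:

# Sample data only; shows the quoted-identifier output shape.
param_map = {"NAME": "masked_NAME", "EMAIL": "masked_EMAIL"}
clauses = [f'"{k}" = :{v}' for k, v in sorted(param_map.items())]
print(clauses)
# ['"EMAIL" = :masked_EMAIL', '"NAME" = :masked_NAME']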
5 changes: 5 additions & 0 deletions src/fides/api/service/connectors/scylla_connector.py
@@ -28,6 +28,11 @@ class ScyllaConnectorMissingKeyspace(Exception):
 class ScyllaConnector(BaseConnector[Cluster]):
     """Scylla Connector"""
 
+    @property
+    def requires_primary_keys(self) -> bool:
+        """ScyllaDB requires primary keys for erasures."""
+        return True
+
     def build_uri(self) -> str:
         """
         Builds URI - Not yet implemented