Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for multiple foreign_key in CreateQueryBuilder #762

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 44 additions & 29 deletions pypika/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1707,6 +1707,36 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl
self.fields = [field.replace_table(current_table, new_table) for field in self.fields]


class ForeignKey:
"""Represents a foreign key constraint."""

def __init__(
self,
columns: List[Column],
reference_table: Union[str, Table],
reference_columns: List[Column],
on_delete: ReferenceOption = None,
on_update: ReferenceOption = None,
) -> None:
self.columns = columns
self.reference_table = reference_table
self.reference_columns = reference_columns
self.on_delete = on_delete
self.on_update = on_update

def get_sql(self, **kwargs: Any) -> str:
foreign_key_sql = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format(
columns=",".join(column.get_name_sql(**kwargs) for column in self.columns),
table_name=self.reference_table.get_sql(**kwargs),
reference_columns=",".join(column.get_name_sql(**kwargs) for column in self.reference_columns),
)
if self.on_delete:
foreign_key_sql += " ON DELETE " + self.on_delete.value
if self.on_update:
foreign_key_sql += " ON UPDATE " + self.on_update.value
return foreign_key_sql


class CreateQueryBuilder:
"""
Query builder used to build CREATE queries.
Expand All @@ -1729,11 +1759,7 @@ def __init__(self, dialect: Optional[Dialects] = None) -> None:
self._uniques = []
self._if_not_exists = False
self.dialect = dialect
self._foreign_key = None
self._foreign_key_reference_table = None
self._foreign_key_reference = None
self._foreign_key_on_update: ReferenceOption = None
self._foreign_key_on_delete: ReferenceOption = None
self._foreign_keys = []

def _set_kwargs_defaults(self, kwargs: dict) -> None:
kwargs.setdefault("quote_char", self.QUOTE_CHAR)
Expand Down Expand Up @@ -1908,19 +1934,19 @@ def foreign_key(

Update option.

:raises AttributeError:
If the foreign key is already defined.

:return:
CreateQueryBuilder.
"""
if self._foreign_key:
raise AttributeError("'Query' object already has attribute foreign_key")
self._foreign_key = self._prepare_columns_input(columns)
self._foreign_key_reference_table = reference_table
self._foreign_key_reference = self._prepare_columns_input(reference_columns)
self._foreign_key_on_delete = on_delete
self._foreign_key_on_update = on_update

self._foreign_keys.append(
ForeignKey(
columns=self._prepare_columns_input(columns),
reference_table=reference_table,
reference_columns=self._prepare_columns_input(reference_columns),
on_delete=on_delete,
on_update=on_update,
)
)

@builder
def as_select(self, query_builder: QueryBuilder) -> "CreateQueryBuilder":
Expand Down Expand Up @@ -2017,28 +2043,17 @@ def _primary_key_clause(self, **kwargs) -> str:
columns=",".join(column.get_name_sql(**kwargs) for column in self._primary_key)
)

def _foreign_key_clause(self, **kwargs) -> str:
clause = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format(
columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key),
table_name=self._foreign_key_reference_table.get_sql(**kwargs),
reference_columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key_reference),
)
if self._foreign_key_on_delete:
clause += " ON DELETE " + self._foreign_key_on_delete.value
if self._foreign_key_on_update:
clause += " ON UPDATE " + self._foreign_key_on_update.value

return clause
def _foreign_key_clauses(self, **kwargs) -> str:
return [foreign_key.get_sql(**kwargs) for foreign_key in self._foreign_keys]

def _body_sql(self, **kwargs) -> str:
clauses = self._column_clauses(**kwargs)
clauses += self._period_for_clauses(**kwargs)
clauses += self._unique_key_clauses(**kwargs)
clauses += self._foreign_key_clauses(**kwargs)

if self._primary_key:
clauses.append(self._primary_key_clause(**kwargs))
if self._foreign_key:
clauses.append(self._foreign_key_clause(**kwargs))

return ",".join(clauses)

Expand Down
31 changes: 25 additions & 6 deletions pypika/tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,31 @@ def test_create_table_with_columns(self):
str(q),
)

with self.subTest("with multiple foreign key constrains"):
secondary_table = Table("secondary_table")
cref, dref = Columns(("c", "INT"), ("d", "VARCHAR(100)"))
q = (
Query.create_table(self.new_table)
.columns(self.foo, self.bar)
.foreign_key([self.foo], self.existing_table, [cref])
.foreign_key(
[self.bar],
secondary_table,
[dref],
on_delete=ReferenceOption.cascade,
on_update=ReferenceOption.restrict,
)
)

self.assertEqual(
'CREATE TABLE "abc" ('
'"a" INT,'
'"b" VARCHAR(100),'
'FOREIGN KEY ("a") REFERENCES "efg" ("c"),'
'FOREIGN KEY ("b") REFERENCES "secondary_table" ("d") ON DELETE CASCADE ON UPDATE RESTRICT)',
str(q),
)

with self.subTest("with unique keys"):
q = (
Query.create_table(self.new_table)
Expand Down Expand Up @@ -156,12 +181,6 @@ def test_create_table_with_select_and_columns_fails(self):
with self.assertRaises(AttributeError):
Query.create_table(self.new_table).as_select(select).columns(self.foo, self.bar)

with self.subTest("repeated foreign key"):
with self.assertRaises(AttributeError):
Query.create_table(self.new_table).foreign_key([self.foo], self.existing_table, [self.bar]).foreign_key(
[self.foo], self.existing_table, [self.bar]
)

def test_create_table_as_select_not_query_raises_error(self):
with self.assertRaises(TypeError):
Query.create_table(self.new_table).as_select("abc")
Expand Down