Skip to content

Commit 243e3b9

Browse files
committed
Allow for multiple foreign_key in CreateQueryBuilder
1 parent ded17eb commit 243e3b9

File tree

2 files changed

+69
-35
lines changed

2 files changed

+69
-35
lines changed

pypika/queries.py

+44-29
Original file line numberDiff line numberDiff line change
@@ -1707,6 +1707,36 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl
17071707
self.fields = [field.replace_table(current_table, new_table) for field in self.fields]
17081708

17091709

1710+
class ForeignKey:
1711+
"""Represents a foreign key constraint."""
1712+
1713+
def __init__(
1714+
self,
1715+
columns: List[Column],
1716+
reference_table: Union[str, Table],
1717+
reference_columns: List[Column],
1718+
on_delete: ReferenceOption = None,
1719+
on_update: ReferenceOption = None,
1720+
) -> None:
1721+
self.columns = columns
1722+
self.reference_table = reference_table
1723+
self.reference_columns = reference_columns
1724+
self.on_delete = on_delete
1725+
self.on_update = on_update
1726+
1727+
def get_sql(self, **kwargs: Any) -> str:
1728+
foreign_key_sql = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format(
1729+
columns=",".join(column.get_name_sql(**kwargs) for column in self.columns),
1730+
table_name=self.reference_table.get_sql(**kwargs),
1731+
reference_columns=",".join(column.get_name_sql(**kwargs) for column in self.reference_columns),
1732+
)
1733+
if self.on_delete:
1734+
foreign_key_sql += " ON DELETE " + self.on_delete.value
1735+
if self.on_update:
1736+
foreign_key_sql += " ON UPDATE " + self.on_update.value
1737+
return foreign_key_sql
1738+
1739+
17101740
class CreateQueryBuilder:
17111741
"""
17121742
Query builder used to build CREATE queries.
@@ -1729,11 +1759,7 @@ def __init__(self, dialect: Optional[Dialects] = None) -> None:
17291759
self._uniques = []
17301760
self._if_not_exists = False
17311761
self.dialect = dialect
1732-
self._foreign_key = None
1733-
self._foreign_key_reference_table = None
1734-
self._foreign_key_reference = None
1735-
self._foreign_key_on_update: ReferenceOption = None
1736-
self._foreign_key_on_delete: ReferenceOption = None
1762+
self._foreign_keys = []
17371763

17381764
def _set_kwargs_defaults(self, kwargs: dict) -> None:
17391765
kwargs.setdefault("quote_char", self.QUOTE_CHAR)
@@ -1908,19 +1934,19 @@ def foreign_key(
19081934
19091935
Update option.
19101936
1911-
:raises AttributeError:
1912-
If the foreign key is already defined.
1913-
19141937
:return:
19151938
CreateQueryBuilder.
19161939
"""
1917-
if self._foreign_key:
1918-
raise AttributeError("'Query' object already has attribute foreign_key")
1919-
self._foreign_key = self._prepare_columns_input(columns)
1920-
self._foreign_key_reference_table = reference_table
1921-
self._foreign_key_reference = self._prepare_columns_input(reference_columns)
1922-
self._foreign_key_on_delete = on_delete
1923-
self._foreign_key_on_update = on_update
1940+
1941+
self._foreign_keys.append(
1942+
ForeignKey(
1943+
columns=self._prepare_columns_input(columns),
1944+
reference_table=reference_table,
1945+
reference_columns=self._prepare_columns_input(reference_columns),
1946+
on_delete=on_delete,
1947+
on_update=on_update,
1948+
)
1949+
)
19241950

19251951
@builder
19261952
def as_select(self, query_builder: QueryBuilder) -> "CreateQueryBuilder":
@@ -2017,28 +2043,17 @@ def _primary_key_clause(self, **kwargs) -> str:
20172043
columns=",".join(column.get_name_sql(**kwargs) for column in self._primary_key)
20182044
)
20192045

2020-
def _foreign_key_clause(self, **kwargs) -> str:
2021-
clause = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format(
2022-
columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key),
2023-
table_name=self._foreign_key_reference_table.get_sql(**kwargs),
2024-
reference_columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key_reference),
2025-
)
2026-
if self._foreign_key_on_delete:
2027-
clause += " ON DELETE " + self._foreign_key_on_delete.value
2028-
if self._foreign_key_on_update:
2029-
clause += " ON UPDATE " + self._foreign_key_on_update.value
2030-
2031-
return clause
2046+
def _foreign_key_clauses(self, **kwargs) -> str:
2047+
return [foreign_key.get_sql(**kwargs) for foreign_key in self._foreign_keys]
20322048

20332049
def _body_sql(self, **kwargs) -> str:
20342050
clauses = self._column_clauses(**kwargs)
20352051
clauses += self._period_for_clauses(**kwargs)
20362052
clauses += self._unique_key_clauses(**kwargs)
2053+
clauses += self._foreign_key_clauses(**kwargs)
20372054

20382055
if self._primary_key:
20392056
clauses.append(self._primary_key_clause(**kwargs))
2040-
if self._foreign_key:
2041-
clauses.append(self._foreign_key_clause(**kwargs))
20422057

20432058
return ",".join(clauses)
20442059

pypika/tests/test_create.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,31 @@ def test_create_table_with_columns(self):
9999
str(q),
100100
)
101101

102+
with self.subTest("with multiple foreign key constrains"):
103+
secondary_table = Table("secondary_table")
104+
cref, dref, eref = Columns(("c", "INT"), ("d", "VARCHAR(100)"), ("e", "INT"))
105+
q = (
106+
Query.create_table(self.new_table)
107+
.columns(self.foo, self.bar)
108+
.foreign_key([self.foo], self.existing_table, [cref])
109+
.foreign_key(
110+
[self.bar],
111+
secondary_table,
112+
[eref],
113+
on_delete=ReferenceOption.cascade,
114+
on_update=ReferenceOption.restrict,
115+
)
116+
)
117+
118+
self.assertEqual(
119+
'CREATE TABLE "abc" ('
120+
'"a" INT,'
121+
'"b" VARCHAR(100),'
122+
'FOREIGN KEY ("a") REFERENCES "efg" ("c"),'
123+
'FOREIGN KEY ("b") REFERENCES "secondary_table" ("e") ON DELETE CASCADE ON UPDATE RESTRICT)',
124+
str(q),
125+
)
126+
102127
with self.subTest("with unique keys"):
103128
q = (
104129
Query.create_table(self.new_table)
@@ -156,12 +181,6 @@ def test_create_table_with_select_and_columns_fails(self):
156181
with self.assertRaises(AttributeError):
157182
Query.create_table(self.new_table).as_select(select).columns(self.foo, self.bar)
158183

159-
with self.subTest("repeated foreign key"):
160-
with self.assertRaises(AttributeError):
161-
Query.create_table(self.new_table).foreign_key([self.foo], self.existing_table, [self.bar]).foreign_key(
162-
[self.foo], self.existing_table, [self.bar]
163-
)
164-
165184
def test_create_table_as_select_not_query_raises_error(self):
166185
with self.assertRaises(TypeError):
167186
Query.create_table(self.new_table).as_select("abc")

0 commit comments

Comments
 (0)