@@ -1707,6 +1707,36 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl
1707
1707
self .fields = [field .replace_table (current_table , new_table ) for field in self .fields ]
1708
1708
1709
1709
1710
+ class ForeignKey :
1711
+ """Represents a foreign key constraint."""
1712
+
1713
+ def __init__ (
1714
+ self ,
1715
+ columns : List [Column ],
1716
+ reference_table : Union [str , Table ],
1717
+ reference_columns : List [Column ],
1718
+ on_delete : ReferenceOption = None ,
1719
+ on_update : ReferenceOption = None ,
1720
+ ) -> None :
1721
+ self .columns = columns
1722
+ self .reference_table = reference_table
1723
+ self .reference_columns = reference_columns
1724
+ self .on_delete = on_delete
1725
+ self .on_update = on_update
1726
+
1727
+ def get_sql (self , ** kwargs : Any ) -> str :
1728
+ foreign_key_sql = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})" .format (
1729
+ columns = "," .join (column .get_name_sql (** kwargs ) for column in self .columns ),
1730
+ table_name = self .reference_table .get_sql (** kwargs ),
1731
+ reference_columns = "," .join (column .get_name_sql (** kwargs ) for column in self .reference_columns ),
1732
+ )
1733
+ if self .on_delete :
1734
+ foreign_key_sql += " ON DELETE " + self .on_delete .value
1735
+ if self .on_update :
1736
+ foreign_key_sql += " ON UPDATE " + self .on_update .value
1737
+ return foreign_key_sql
1738
+
1739
+
1710
1740
class CreateQueryBuilder :
1711
1741
"""
1712
1742
Query builder used to build CREATE queries.
@@ -1729,11 +1759,7 @@ def __init__(self, dialect: Optional[Dialects] = None) -> None:
1729
1759
self ._uniques = []
1730
1760
self ._if_not_exists = False
1731
1761
self .dialect = dialect
1732
- self ._foreign_key = None
1733
- self ._foreign_key_reference_table = None
1734
- self ._foreign_key_reference = None
1735
- self ._foreign_key_on_update : ReferenceOption = None
1736
- self ._foreign_key_on_delete : ReferenceOption = None
1762
+ self ._foreign_keys = []
1737
1763
1738
1764
def _set_kwargs_defaults (self , kwargs : dict ) -> None :
1739
1765
kwargs .setdefault ("quote_char" , self .QUOTE_CHAR )
@@ -1908,19 +1934,19 @@ def foreign_key(
1908
1934
1909
1935
Update option.
1910
1936
1911
- :raises AttributeError:
1912
- If the foreign key is already defined.
1913
-
1914
1937
:return:
1915
1938
CreateQueryBuilder.
1916
1939
"""
1917
- if self ._foreign_key :
1918
- raise AttributeError ("'Query' object already has attribute foreign_key" )
1919
- self ._foreign_key = self ._prepare_columns_input (columns )
1920
- self ._foreign_key_reference_table = reference_table
1921
- self ._foreign_key_reference = self ._prepare_columns_input (reference_columns )
1922
- self ._foreign_key_on_delete = on_delete
1923
- self ._foreign_key_on_update = on_update
1940
+
1941
+ self ._foreign_keys .append (
1942
+ ForeignKey (
1943
+ columns = self ._prepare_columns_input (columns ),
1944
+ reference_table = reference_table ,
1945
+ reference_columns = self ._prepare_columns_input (reference_columns ),
1946
+ on_delete = on_delete ,
1947
+ on_update = on_update ,
1948
+ )
1949
+ )
1924
1950
1925
1951
@builder
1926
1952
def as_select (self , query_builder : QueryBuilder ) -> "CreateQueryBuilder" :
@@ -2017,28 +2043,17 @@ def _primary_key_clause(self, **kwargs) -> str:
2017
2043
columns = "," .join (column .get_name_sql (** kwargs ) for column in self ._primary_key )
2018
2044
)
2019
2045
2020
- def _foreign_key_clause (self , ** kwargs ) -> str :
2021
- clause = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})" .format (
2022
- columns = "," .join (column .get_name_sql (** kwargs ) for column in self ._foreign_key ),
2023
- table_name = self ._foreign_key_reference_table .get_sql (** kwargs ),
2024
- reference_columns = "," .join (column .get_name_sql (** kwargs ) for column in self ._foreign_key_reference ),
2025
- )
2026
- if self ._foreign_key_on_delete :
2027
- clause += " ON DELETE " + self ._foreign_key_on_delete .value
2028
- if self ._foreign_key_on_update :
2029
- clause += " ON UPDATE " + self ._foreign_key_on_update .value
2030
-
2031
- return clause
2046
+ def _foreign_key_clauses (self , ** kwargs ) -> str :
2047
+ return [foreign_key .get_sql (** kwargs ) for foreign_key in self ._foreign_keys ]
2032
2048
2033
2049
def _body_sql (self , ** kwargs ) -> str :
2034
2050
clauses = self ._column_clauses (** kwargs )
2035
2051
clauses += self ._period_for_clauses (** kwargs )
2036
2052
clauses += self ._unique_key_clauses (** kwargs )
2053
+ clauses += self ._foreign_key_clauses (** kwargs )
2037
2054
2038
2055
if self ._primary_key :
2039
2056
clauses .append (self ._primary_key_clause (** kwargs ))
2040
- if self ._foreign_key :
2041
- clauses .append (self ._foreign_key_clause (** kwargs ))
2042
2057
2043
2058
return "," .join (clauses )
2044
2059
0 commit comments