3
3
from typing import TYPE_CHECKING , Any , Optional , Union
4
4
5
5
from litestar .middleware .session .base import BaseSessionBackend
6
+ from litestar .types import Scopes
6
7
7
- from sqlspec .extensions .litestar .store import SessionStore
8
+ from sqlspec .extensions .litestar .store import SQLSpecSessionStore
8
9
from sqlspec .utils .logging import get_logger
9
10
10
11
if TYPE_CHECKING :
11
12
from litestar .connection import ASGIConnection
13
+ from litestar .types import Message , ScopeSession
12
14
13
15
from sqlspec .config import AsyncConfigT , DatabaseConfigProtocol , SyncConfigT
14
16
15
17
logger = get_logger ("extensions.litestar.session" )
16
18
17
- __all__ = ("SQLSpecSessionBackend" ,)
19
+ __all__ = ("SQLSpecSessionBackend" , "SQLSpecSessionConfig" )
20
+
21
+
22
+ class SQLSpecSessionConfig :
23
+ """Configuration for SQLSpec session backend."""
24
+
25
+ def __init__ (
26
+ self ,
27
+ key : str = "session" ,
28
+ max_age : int = 1209600 , # 14 days
29
+ path : str = "/" ,
30
+ domain : Optional [str ] = None ,
31
+ secure : bool = False ,
32
+ httponly : bool = True ,
33
+ samesite : str = "lax" ,
34
+ exclude : Optional [Union [str , list [str ]]] = None ,
35
+ exclude_opt_key : str = "skip_session" ,
36
+ scopes : Scopes = frozenset ({"http" , "websocket" }),
37
+ ) -> None :
38
+ """Initialize session configuration.
39
+
40
+ Args:
41
+ key: Cookie key name
42
+ max_age: Cookie max age in seconds
43
+ path: Cookie path
44
+ domain: Cookie domain
45
+ secure: Require HTTPS for cookie
46
+ httponly: Make cookie HTTP-only
47
+ samesite: SameSite policy for cookie
48
+ exclude: Patterns to exclude from session middleware
49
+ exclude_opt_key: Key to opt out of session middleware
50
+ scopes: Scopes where session middleware applies
51
+ """
52
+ self .key = key
53
+ self .max_age = max_age
54
+ self .path = path
55
+ self .domain = domain
56
+ self .secure = secure
57
+ self .httponly = httponly
58
+ self .samesite = samesite
59
+ self .exclude = exclude
60
+ self .exclude_opt_key = exclude_opt_key
61
+ self .scopes = scopes
18
62
19
63
20
64
class SQLSpecSessionBackend (BaseSessionBackend ):
@@ -24,7 +68,7 @@ class SQLSpecSessionBackend(BaseSessionBackend):
24
68
middleware, providing transparent session management with database persistence.
25
69
"""
26
70
27
- __slots__ = ("_session_id_generator" , "_session_lifetime" , "_store" )
71
+ __slots__ = ("_session_id_generator" , "_session_lifetime" , "_store" , "config" )
28
72
29
73
def __init__ (
30
74
self ,
@@ -36,6 +80,7 @@ def __init__(
36
80
expires_at_column : str = "expires_at" ,
37
81
created_at_column : str = "created_at" ,
38
82
session_lifetime : int = 24 * 60 * 60 , # 24 hours
83
+ session_config : Optional [SQLSpecSessionConfig ] = None ,
39
84
) -> None :
40
85
"""Initialize the session backend.
41
86
@@ -47,17 +92,19 @@ def __init__(
47
92
expires_at_column: Name of the expires at column
48
93
created_at_column: Name of the created at column
49
94
session_lifetime: Default session lifetime in seconds
95
+ session_config: Session configuration for middleware
50
96
"""
51
- self ._store = SessionStore (
97
+ self ._store = SQLSpecSessionStore (
52
98
config ,
53
99
table_name = table_name ,
54
100
session_id_column = session_id_column ,
55
101
data_column = data_column ,
56
102
expires_at_column = expires_at_column ,
57
103
created_at_column = created_at_column ,
58
104
)
59
- self ._session_id_generator = SessionStore .generate_session_id
105
+ self ._session_id_generator = SQLSpecSessionStore .generate_session_id
60
106
self ._session_lifetime = session_lifetime
107
+ self .config = session_config or SQLSpecSessionConfig ()
61
108
62
109
async def load_from_connection (self , connection : "ASGIConnection[Any, Any, Any, Any]" ) -> dict [str , Any ]:
63
110
"""Load session data from the connection.
@@ -110,9 +157,112 @@ def get_session_id(self, connection: "ASGIConnection[Any, Any, Any, Any]") -> Op
110
157
Returns:
111
158
Session identifier if found
112
159
"""
113
- # Look for session ID in cookies
114
- session_cookie_name = getattr (connection .app .session_config , "session_cookie_name" , "session" ) # type: ignore[union-attr]
115
- return connection .cookies .get (session_cookie_name )
160
+ # Try to get session ID from cookies using the config key
161
+ session_id = connection .cookies .get (self .config .key )
162
+ if session_id and session_id != "null" :
163
+ return session_id
164
+
165
+ # Fallback to getting session ID from connection state
166
+ session_id = connection .get_session_id ()
167
+ if session_id :
168
+ return session_id
169
+
170
+ return None
171
+
172
+ async def store_in_message (
173
+ self , scope_session : "ScopeSession" , message : "Message" , connection : "ASGIConnection[Any, Any, Any, Any]"
174
+ ) -> None :
175
+ """Store session information in the outgoing message.
176
+
177
+ For server-side sessions, this method sets a cookie containing the session ID.
178
+ If the session is empty, a null-cookie will be set to clear any existing session.
179
+
180
+ Args:
181
+ scope_session: Current session data to store
182
+ message: Outgoing ASGI message to modify
183
+ connection: ASGI connection instance
184
+ """
185
+ if message ["type" ] != "http.response.start" :
186
+ return
187
+
188
+ cookie_key = self .config .key
189
+
190
+ # If session is empty, set a null cookie to clear any existing session
191
+ if not scope_session :
192
+ cookie_value = self ._build_cookie_value (
193
+ key = cookie_key ,
194
+ value = "null" ,
195
+ max_age = 0 ,
196
+ path = self .config .path ,
197
+ domain = self .config .domain ,
198
+ secure = self .config .secure ,
199
+ httponly = self .config .httponly ,
200
+ samesite = self .config .samesite ,
201
+ )
202
+ self ._add_cookie_to_message (message , cookie_value )
203
+ return
204
+
205
+ # Get or generate session ID
206
+ session_id = self .get_session_id (connection )
207
+ if not session_id :
208
+ session_id = self ._session_id_generator ()
209
+
210
+ # Store session data in the backend
211
+ try :
212
+ await self ._store .set (session_id , scope_session , expires_in = self ._session_lifetime )
213
+ except Exception :
214
+ logger .exception ("Failed to store session data for session %s" , session_id )
215
+ # Don't set the cookie if we failed to store the data
216
+ return
217
+
218
+ # Set the session ID cookie
219
+ cookie_value = self ._build_cookie_value (
220
+ key = cookie_key ,
221
+ value = session_id ,
222
+ max_age = self .config .max_age ,
223
+ path = self .config .path ,
224
+ domain = self .config .domain ,
225
+ secure = self .config .secure ,
226
+ httponly = self .config .httponly ,
227
+ samesite = self .config .samesite ,
228
+ )
229
+ self ._add_cookie_to_message (message , cookie_value )
230
+
231
+ def _build_cookie_value (
232
+ self ,
233
+ key : str ,
234
+ value : str ,
235
+ max_age : Optional [int ] = None ,
236
+ path : Optional [str ] = None ,
237
+ domain : Optional [str ] = None ,
238
+ secure : bool = False ,
239
+ httponly : bool = False ,
240
+ samesite : Optional [str ] = None ,
241
+ ) -> str :
242
+ """Build a cookie value string with attributes."""
243
+ cookie_parts = [f"{ key } ={ value } " ]
244
+
245
+ if path :
246
+ cookie_parts .append (f"Path={ path } " )
247
+ if domain :
248
+ cookie_parts .append (f"Domain={ domain } " )
249
+ if max_age is not None :
250
+ cookie_parts .append (f"Max-Age={ max_age } " )
251
+ if secure :
252
+ cookie_parts .append ("Secure" )
253
+ if httponly :
254
+ cookie_parts .append ("HttpOnly" )
255
+ if samesite :
256
+ cookie_parts .append (f"SameSite={ samesite } " )
257
+
258
+ return "; " .join (cookie_parts )
259
+
260
+ def _add_cookie_to_message (self , message : "Message" , cookie_value : str ) -> None :
261
+ """Add a Set-Cookie header to the ASGI message."""
262
+ if message ["type" ] == "http.response.start" :
263
+ headers = list (message .get ("headers" , []))
264
+ headers .append ([b"set-cookie" , cookie_value .encode ()])
265
+ message ["headers" ] = headers
116
266
117
267
async def delete_session (self , session_id : str ) -> None :
118
268
"""Delete a session.
@@ -152,7 +302,7 @@ async def get_all_session_ids(self) -> list[str]:
152
302
return session_ids
153
303
154
304
@property
155
- def store (self ) -> SessionStore :
305
+ def store (self ) -> SQLSpecSessionStore :
156
306
"""Get the underlying session store.
157
307
158
308
Returns:
0 commit comments