Skip to content

Commit 082185b

Browse files
committed
wip
1 parent cfc92e3 commit 082185b

File tree

42 files changed

+9605
-37
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+9605
-37
lines changed

sqlspec/extensions/litestar/session.py

Lines changed: 159 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,62 @@
33
from typing import TYPE_CHECKING, Any, Optional, Union
44

55
from litestar.middleware.session.base import BaseSessionBackend
6+
from litestar.types import Scopes
67

7-
from sqlspec.extensions.litestar.store import SessionStore
8+
from sqlspec.extensions.litestar.store import SQLSpecSessionStore
89
from sqlspec.utils.logging import get_logger
910

1011
if TYPE_CHECKING:
1112
from litestar.connection import ASGIConnection
13+
from litestar.types import Message, ScopeSession
1214

1315
from sqlspec.config import AsyncConfigT, DatabaseConfigProtocol, SyncConfigT
1416

1517
logger = get_logger("extensions.litestar.session")
1618

17-
__all__ = ("SQLSpecSessionBackend",)
19+
__all__ = ("SQLSpecSessionBackend", "SQLSpecSessionConfig")
20+
21+
22+
class SQLSpecSessionConfig:
23+
"""Configuration for SQLSpec session backend."""
24+
25+
def __init__(
26+
self,
27+
key: str = "session",
28+
max_age: int = 1209600, # 14 days
29+
path: str = "/",
30+
domain: Optional[str] = None,
31+
secure: bool = False,
32+
httponly: bool = True,
33+
samesite: str = "lax",
34+
exclude: Optional[Union[str, list[str]]] = None,
35+
exclude_opt_key: str = "skip_session",
36+
scopes: Scopes = frozenset({"http", "websocket"}),
37+
) -> None:
38+
"""Initialize session configuration.
39+
40+
Args:
41+
key: Cookie key name
42+
max_age: Cookie max age in seconds
43+
path: Cookie path
44+
domain: Cookie domain
45+
secure: Require HTTPS for cookie
46+
httponly: Make cookie HTTP-only
47+
samesite: SameSite policy for cookie
48+
exclude: Patterns to exclude from session middleware
49+
exclude_opt_key: Key to opt out of session middleware
50+
scopes: Scopes where session middleware applies
51+
"""
52+
self.key = key
53+
self.max_age = max_age
54+
self.path = path
55+
self.domain = domain
56+
self.secure = secure
57+
self.httponly = httponly
58+
self.samesite = samesite
59+
self.exclude = exclude
60+
self.exclude_opt_key = exclude_opt_key
61+
self.scopes = scopes
1862

1963

2064
class SQLSpecSessionBackend(BaseSessionBackend):
@@ -24,7 +68,7 @@ class SQLSpecSessionBackend(BaseSessionBackend):
2468
middleware, providing transparent session management with database persistence.
2569
"""
2670

27-
__slots__ = ("_session_id_generator", "_session_lifetime", "_store")
71+
__slots__ = ("_session_id_generator", "_session_lifetime", "_store", "config")
2872

2973
def __init__(
3074
self,
@@ -36,6 +80,7 @@ def __init__(
3680
expires_at_column: str = "expires_at",
3781
created_at_column: str = "created_at",
3882
session_lifetime: int = 24 * 60 * 60, # 24 hours
83+
session_config: Optional[SQLSpecSessionConfig] = None,
3984
) -> None:
4085
"""Initialize the session backend.
4186
@@ -47,17 +92,19 @@ def __init__(
4792
expires_at_column: Name of the expires at column
4893
created_at_column: Name of the created at column
4994
session_lifetime: Default session lifetime in seconds
95+
session_config: Session configuration for middleware
5096
"""
51-
self._store = SessionStore(
97+
self._store = SQLSpecSessionStore(
5298
config,
5399
table_name=table_name,
54100
session_id_column=session_id_column,
55101
data_column=data_column,
56102
expires_at_column=expires_at_column,
57103
created_at_column=created_at_column,
58104
)
59-
self._session_id_generator = SessionStore.generate_session_id
105+
self._session_id_generator = SQLSpecSessionStore.generate_session_id
60106
self._session_lifetime = session_lifetime
107+
self.config = session_config or SQLSpecSessionConfig()
61108

62109
async def load_from_connection(self, connection: "ASGIConnection[Any, Any, Any, Any]") -> dict[str, Any]:
63110
"""Load session data from the connection.
@@ -110,9 +157,112 @@ def get_session_id(self, connection: "ASGIConnection[Any, Any, Any, Any]") -> Op
110157
Returns:
111158
Session identifier if found
112159
"""
113-
# Look for session ID in cookies
114-
session_cookie_name = getattr(connection.app.session_config, "session_cookie_name", "session") # type: ignore[union-attr]
115-
return connection.cookies.get(session_cookie_name)
160+
# Try to get session ID from cookies using the config key
161+
session_id = connection.cookies.get(self.config.key)
162+
if session_id and session_id != "null":
163+
return session_id
164+
165+
# Fallback to getting session ID from connection state
166+
session_id = connection.get_session_id()
167+
if session_id:
168+
return session_id
169+
170+
return None
171+
172+
async def store_in_message(
173+
self, scope_session: "ScopeSession", message: "Message", connection: "ASGIConnection[Any, Any, Any, Any]"
174+
) -> None:
175+
"""Store session information in the outgoing message.
176+
177+
For server-side sessions, this method sets a cookie containing the session ID.
178+
If the session is empty, a null-cookie will be set to clear any existing session.
179+
180+
Args:
181+
scope_session: Current session data to store
182+
message: Outgoing ASGI message to modify
183+
connection: ASGI connection instance
184+
"""
185+
if message["type"] != "http.response.start":
186+
return
187+
188+
cookie_key = self.config.key
189+
190+
# If session is empty, set a null cookie to clear any existing session
191+
if not scope_session:
192+
cookie_value = self._build_cookie_value(
193+
key=cookie_key,
194+
value="null",
195+
max_age=0,
196+
path=self.config.path,
197+
domain=self.config.domain,
198+
secure=self.config.secure,
199+
httponly=self.config.httponly,
200+
samesite=self.config.samesite,
201+
)
202+
self._add_cookie_to_message(message, cookie_value)
203+
return
204+
205+
# Get or generate session ID
206+
session_id = self.get_session_id(connection)
207+
if not session_id:
208+
session_id = self._session_id_generator()
209+
210+
# Store session data in the backend
211+
try:
212+
await self._store.set(session_id, scope_session, expires_in=self._session_lifetime)
213+
except Exception:
214+
logger.exception("Failed to store session data for session %s", session_id)
215+
# Don't set the cookie if we failed to store the data
216+
return
217+
218+
# Set the session ID cookie
219+
cookie_value = self._build_cookie_value(
220+
key=cookie_key,
221+
value=session_id,
222+
max_age=self.config.max_age,
223+
path=self.config.path,
224+
domain=self.config.domain,
225+
secure=self.config.secure,
226+
httponly=self.config.httponly,
227+
samesite=self.config.samesite,
228+
)
229+
self._add_cookie_to_message(message, cookie_value)
230+
231+
def _build_cookie_value(
232+
self,
233+
key: str,
234+
value: str,
235+
max_age: Optional[int] = None,
236+
path: Optional[str] = None,
237+
domain: Optional[str] = None,
238+
secure: bool = False,
239+
httponly: bool = False,
240+
samesite: Optional[str] = None,
241+
) -> str:
242+
"""Build a cookie value string with attributes."""
243+
cookie_parts = [f"{key}={value}"]
244+
245+
if path:
246+
cookie_parts.append(f"Path={path}")
247+
if domain:
248+
cookie_parts.append(f"Domain={domain}")
249+
if max_age is not None:
250+
cookie_parts.append(f"Max-Age={max_age}")
251+
if secure:
252+
cookie_parts.append("Secure")
253+
if httponly:
254+
cookie_parts.append("HttpOnly")
255+
if samesite:
256+
cookie_parts.append(f"SameSite={samesite}")
257+
258+
return "; ".join(cookie_parts)
259+
260+
def _add_cookie_to_message(self, message: "Message", cookie_value: str) -> None:
261+
"""Add a Set-Cookie header to the ASGI message."""
262+
if message["type"] == "http.response.start":
263+
headers = list(message.get("headers", []))
264+
headers.append([b"set-cookie", cookie_value.encode()])
265+
message["headers"] = headers
116266

117267
async def delete_session(self, session_id: str) -> None:
118268
"""Delete a session.
@@ -152,7 +302,7 @@ async def get_all_session_ids(self) -> list[str]:
152302
return session_ids
153303

154304
@property
155-
def store(self) -> SessionStore:
305+
def store(self) -> SQLSpecSessionStore:
156306
"""Get the underlying session store.
157307
158308
Returns:

sqlspec/extensions/litestar/store.py

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222

2323
logger = get_logger("extensions.litestar.store")
2424

25-
__all__ = ("SessionStore", "SessionStoreError")
25+
__all__ = ("SQLSpecSessionStore", "SQLSpecSessionStoreError")
2626

2727

28-
class SessionStoreError(SQLSpecError):
28+
class SQLSpecSessionStoreError(SQLSpecError):
2929
"""Exception raised by session store operations."""
3030

3131

32-
class SessionStore(Store):
32+
class SQLSpecSessionStore(Store):
3333
"""SQLSpec-based session store for Litestar.
3434
3535
This store uses SQLSpec's builder API to create dialect-aware SQL operations
@@ -129,7 +129,7 @@ async def _ensure_table_exists(self, driver: Union[SyncDriverAdapterBase, AsyncD
129129
except Exception as e:
130130
msg = f"Failed to create session table: {e}"
131131
logger.exception("Failed to create session table %s", self._table_name)
132-
raise SessionStoreError(msg) from e
132+
raise SQLSpecSessionStoreError(msg) from e
133133

134134
def _get_dialect_upsert_sql(self, dialect: str, session_id: str, data: str, expires_at: datetime) -> Any:
135135
"""Generate dialect-specific upsert SQL using SQL builder API.
@@ -152,12 +152,10 @@ def _get_dialect_upsert_sql(self, dialect: str, session_id: str, data: str, expi
152152
.columns(self._session_id_column, self._data_column, self._expires_at_column, self._created_at_column)
153153
.values(session_id, data, expires_at, current_time)
154154
.on_conflict(self._session_id_column)
155-
.do_update(
156-
**{
157-
self._data_column: sql.raw("EXCLUDED." + self._data_column),
158-
self._expires_at_column: sql.raw("EXCLUDED." + self._expires_at_column),
159-
}
160-
)
155+
.do_update(**{
156+
self._data_column: sql.raw("EXCLUDED." + self._data_column),
157+
self._expires_at_column: sql.raw("EXCLUDED." + self._expires_at_column),
158+
})
161159
)
162160

163161
if dialect in {"mysql", "mariadb"}:
@@ -166,12 +164,10 @@ def _get_dialect_upsert_sql(self, dialect: str, session_id: str, data: str, expi
166164
sql.insert(self._table_name)
167165
.columns(self._session_id_column, self._data_column, self._expires_at_column, self._created_at_column)
168166
.values(session_id, data, expires_at, current_time)
169-
.on_duplicate_key_update(
170-
**{
171-
self._data_column: sql.raw(f"VALUES({self._data_column})"),
172-
self._expires_at_column: sql.raw(f"VALUES({self._expires_at_column})"),
173-
}
174-
)
167+
.on_duplicate_key_update(**{
168+
self._data_column: sql.raw(f"VALUES({self._data_column})"),
169+
self._expires_at_column: sql.raw(f"VALUES({self._expires_at_column})"),
170+
})
175171
)
176172

177173
if dialect == "sqlite":
@@ -181,12 +177,10 @@ def _get_dialect_upsert_sql(self, dialect: str, session_id: str, data: str, expi
181177
.columns(self._session_id_column, self._data_column, self._expires_at_column, self._created_at_column)
182178
.values(session_id, data, expires_at, current_time)
183179
.on_conflict(self._session_id_column)
184-
.do_update(
185-
**{
186-
self._data_column: sql.raw("EXCLUDED." + self._data_column),
187-
self._expires_at_column: sql.raw("EXCLUDED." + self._expires_at_column),
188-
}
189-
)
180+
.do_update(**{
181+
self._data_column: sql.raw("EXCLUDED." + self._data_column),
182+
self._expires_at_column: sql.raw("EXCLUDED." + self._expires_at_column),
183+
})
190184
)
191185

192186
if dialect == "oracle":
@@ -363,7 +357,7 @@ async def _set_session_data(
363357
except Exception as e:
364358
msg = f"Failed to store session: {e}"
365359
logger.exception("Failed to store session %s", key)
366-
raise SessionStoreError(msg) from e
360+
raise SQLSpecSessionStoreError(msg) from e
367361

368362
async def delete(self, key: str) -> None:
369363
"""Delete session data.
@@ -392,7 +386,7 @@ async def _delete_session_data(
392386
except Exception as e:
393387
msg = f"Failed to delete session: {e}"
394388
logger.exception("Failed to delete session %s", key)
395-
raise SessionStoreError(msg) from e
389+
raise SQLSpecSessionStoreError(msg) from e
396390

397391
async def exists(self, key: str) -> bool:
398392
"""Check if a session exists and is not expired.
@@ -469,11 +463,9 @@ async def expires_in(self, key: str) -> int:
469463
delta = expires_at - current_time
470464
return max(0, int(delta.total_seconds()))
471465

472-
return 0
473-
474466
except Exception:
475467
logger.exception("Failed to get expires_in for session %s", key)
476-
return 0
468+
return 0
477469

478470
async def delete_all(self, pattern: str = "*") -> None:
479471
"""Delete all sessions matching pattern.
@@ -499,7 +491,7 @@ async def _delete_all_sessions(self, driver: Union[SyncDriverAdapterBase, AsyncD
499491
except Exception as e:
500492
msg = f"Failed to delete all sessions: {e}"
501493
logger.exception("Failed to delete all sessions")
502-
raise SessionStoreError(msg) from e
494+
raise SQLSpecSessionStoreError(msg) from e
503495

504496
async def delete_expired(self) -> None:
505497
"""Delete expired sessions."""
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import pytest
2+
3+
pytestmark = [pytest.mark.adbc, pytest.mark.postgres]

0 commit comments

Comments
 (0)