Skip to content

Commit cfc92e3

Browse files
committed
wip
1 parent abca00d commit cfc92e3

File tree

9 files changed

+959
-9
lines changed

9 files changed

+959
-9
lines changed
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
"""Example showing how to use SQLSpec session backend with Litestar."""
2+
3+
from litestar import Litestar, get, post
4+
from litestar.config.session import SessionConfig
5+
from litestar.datastructures import State
6+
7+
from sqlspec.adapters.sqlite.config import SqliteConfig
8+
from sqlspec.extensions.litestar import SQLSpec, SQLSpecSessionBackend
9+
10+
# Configure SQLSpec with SQLite database
11+
sqlite_config = SqliteConfig(
12+
pool_config={"database": "sessions.db"},
13+
migration_config={"script_location": "migrations", "version_table_name": "sqlspec_migrations"},
14+
)
15+
16+
# Create SQLSpec plugin
17+
sqlspec_plugin = SQLSpec(sqlite_config)
18+
19+
# Create session backend using SQLSpec
20+
session_backend = SQLSpecSessionBackend(
21+
config=sqlite_config,
22+
table_name="user_sessions",
23+
session_lifetime=3600, # 1 hour
24+
)
25+
26+
# Configure session middleware
27+
session_config = SessionConfig(
28+
backend=session_backend,
29+
cookie_https_only=False, # Set to True in production
30+
cookie_secure=False, # Set to True in production with HTTPS
31+
cookie_domain="localhost",
32+
cookie_path="/",
33+
cookie_max_age=3600,
34+
cookie_same_site="lax",
35+
cookie_http_only=True,
36+
session_cookie_name="sqlspec_session",
37+
)
38+
39+
40+
@get("/")
41+
async def index() -> dict[str, str]:
42+
"""Homepage route."""
43+
return {"message": "SQLSpec Session Example"}
44+
45+
46+
@get("/login")
47+
async def login_form() -> str:
48+
"""Simple login form."""
49+
return """
50+
<html>
51+
<body>
52+
<h2>Login</h2>
53+
<form method="post" action="/login">
54+
<input type="text" name="username" placeholder="Username" required>
55+
<input type="password" name="password" placeholder="Password" required>
56+
<button type="submit">Login</button>
57+
</form>
58+
</body>
59+
</html>
60+
"""
61+
62+
63+
@post("/login")
64+
async def login(data: dict[str, str], request) -> dict[str, str]:
65+
"""Handle login and create session."""
66+
username = data.get("username")
67+
password = data.get("password")
68+
69+
# Simple authentication (use proper auth in production)
70+
if username == "admin" and password == "secret":
71+
# Store user data in session
72+
request.set_session(
73+
{"user_id": 1, "username": username, "login_time": "2024-01-01T12:00:00Z", "roles": ["admin", "user"]}
74+
)
75+
return {"message": f"Welcome, {username}!"}
76+
77+
return {"error": "Invalid credentials"}
78+
79+
80+
@get("/profile")
81+
async def profile(request) -> dict[str, str]:
82+
"""User profile route - requires session."""
83+
session_data = request.session
84+
85+
if not session_data or "user_id" not in session_data:
86+
return {"error": "Not logged in"}
87+
88+
return {
89+
"user_id": session_data["user_id"],
90+
"username": session_data["username"],
91+
"login_time": session_data["login_time"],
92+
"roles": session_data["roles"],
93+
}
94+
95+
96+
@post("/logout")
97+
async def logout(request) -> dict[str, str]:
98+
"""Logout and clear session."""
99+
request.clear_session()
100+
return {"message": "Logged out successfully"}
101+
102+
103+
@get("/admin/sessions")
104+
async def admin_sessions(request, state: State) -> dict[str, any]:
105+
"""Admin route to view all active sessions."""
106+
session_data = request.session
107+
108+
if not session_data or "admin" not in session_data.get("roles", []):
109+
return {"error": "Admin access required"}
110+
111+
# Get session backend from state
112+
backend = session_backend
113+
session_ids = await backend.get_all_session_ids()
114+
115+
return {
116+
"active_sessions": len(session_ids),
117+
"session_ids": session_ids[:10], # Limit to first 10 for display
118+
}
119+
120+
121+
@post("/admin/cleanup")
122+
async def cleanup_sessions(request, state: State) -> dict[str, str]:
123+
"""Admin route to clean up expired sessions."""
124+
session_data = request.session
125+
126+
if not session_data or "admin" not in session_data.get("roles", []):
127+
return {"error": "Admin access required"}
128+
129+
# Clean up expired sessions
130+
backend = session_backend
131+
await backend.delete_expired_sessions()
132+
133+
return {"message": "Expired sessions cleaned up"}
134+
135+
136+
# Create Litestar application
137+
app = Litestar(
138+
route_handlers=[index, login_form, login, profile, logout, admin_sessions, cleanup_sessions],
139+
plugins=[sqlspec_plugin],
140+
session_config=session_config,
141+
debug=True,
142+
)
143+
144+
145+
if __name__ == "__main__":
146+
import uvicorn
147+
148+
print("Starting SQLSpec Session Example...")
149+
print("Visit http://localhost:8000 to view the application")
150+
print("Login with username 'admin' and password 'secret'")
151+
152+
uvicorn.run(app, host="0.0.0.0", port=8000)

sqlspec/_typing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,8 @@ async def insert_returning(self, conn: Any, query_name: str, sql: str, parameter
608608
FSSPEC_INSTALLED = bool(find_spec("fsspec"))
609609
OBSTORE_INSTALLED = bool(find_spec("obstore"))
610610
PGVECTOR_INSTALLED = bool(find_spec("pgvector"))
611-
611+
UUID_UTILS_INSTALLED = bool(find_spec("uuid_utils"))
612+
NANOID_INSTALLED = bool(find_spec("fastnanoid"))
612613

613614
__all__ = (
614615
"AIOSQL_INSTALLED",
@@ -617,6 +618,7 @@ async def insert_returning(self, conn: Any, query_name: str, sql: str, parameter
617618
"FSSPEC_INSTALLED",
618619
"LITESTAR_INSTALLED",
619620
"MSGSPEC_INSTALLED",
621+
"NANOID_INSTALLED",
620622
"OBSTORE_INSTALLED",
621623
"OPENTELEMETRY_INSTALLED",
622624
"PGVECTOR_INSTALLED",
@@ -625,6 +627,7 @@ async def insert_returning(self, conn: Any, query_name: str, sql: str, parameter
625627
"PYDANTIC_INSTALLED",
626628
"UNSET",
627629
"UNSET_STUB",
630+
"UUID_UTILS_INSTALLED",
628631
"AiosqlAsyncProtocol",
629632
"AiosqlParamType",
630633
"AiosqlProtocol",

sqlspec/adapters/oracledb/driver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sqlspec.core.statement import StatementConfig
1313
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
1414
from sqlspec.exceptions import SQLParsingError, SQLSpecError
15+
from sqlspec.utils.serializers import to_json
1516

1617
if TYPE_CHECKING:
1718
from contextlib import AbstractAsyncContextManager, AbstractContextManager
@@ -38,7 +39,7 @@
3839
supported_parameter_styles={ParameterStyle.NAMED_COLON, ParameterStyle.POSITIONAL_COLON, ParameterStyle.QMARK},
3940
default_execution_parameter_style=ParameterStyle.POSITIONAL_COLON,
4041
supported_execution_parameter_styles={ParameterStyle.NAMED_COLON, ParameterStyle.POSITIONAL_COLON},
41-
type_coercion_map={},
42+
type_coercion_map={dict: to_json, list: to_json},
4243
has_native_list_expansion=False,
4344
needs_static_script_compilation=True,
4445
preserve_parameter_format=True,

sqlspec/adapters/oracledb/migrations.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class OracleMigrationTrackerMixin:
2626
__slots__ = ()
2727

2828
version_table: str
29+
_table_initialized: bool
2930

3031
def _get_create_table_sql(self) -> CreateTable:
3132
"""Get Oracle-specific SQL builder for creating the tracking table.
@@ -52,16 +53,28 @@ def _get_create_table_sql(self) -> CreateTable:
5253
class OracleSyncMigrationTracker(OracleMigrationTrackerMixin, BaseMigrationTracker["SyncDriverAdapterBase"]):
5354
"""Oracle-specific sync migration tracker."""
5455

55-
__slots__ = ()
56+
__slots__ = ("_table_initialized",)
57+
58+
def __init__(self, version_table_name: str = "ddl_migrations") -> None:
59+
"""Initialize the Oracle sync migration tracker.
60+
61+
Args:
62+
version_table_name: Name of the table to track migrations.
63+
"""
64+
super().__init__(version_table_name)
65+
self._table_initialized = False
5666

5767
def ensure_tracking_table(self, driver: "SyncDriverAdapterBase") -> None:
5868
"""Create the migration tracking table if it doesn't exist.
5969
60-
Oracle doesn't support IF NOT EXISTS, so we check for table existence first.
70+
Uses caching to avoid repeated database queries for table existence.
71+
This is critical for performance in ASGI frameworks where this might be called on every request.
6172
6273
Args:
6374
driver: The database driver to use.
6475
"""
76+
if self._table_initialized:
77+
return
6578

6679
check_sql = (
6780
sql.select(sql.count().as_("table_count"))
@@ -74,6 +87,8 @@ def ensure_tracking_table(self, driver: "SyncDriverAdapterBase") -> None:
7487
driver.execute(self._get_create_table_sql())
7588
self._safe_commit(driver)
7689

90+
self._table_initialized = True
91+
7792
def get_current_version(self, driver: "SyncDriverAdapterBase") -> "Optional[str]":
7893
"""Get the latest applied migration version.
7994
@@ -156,16 +171,28 @@ def _safe_commit(self, driver: "SyncDriverAdapterBase") -> None:
156171
class OracleAsyncMigrationTracker(OracleMigrationTrackerMixin, BaseMigrationTracker["AsyncDriverAdapterBase"]):
157172
"""Oracle-specific async migration tracker."""
158173

159-
__slots__ = ()
174+
__slots__ = ("_table_initialized",)
175+
176+
def __init__(self, version_table_name: str = "ddl_migrations") -> None:
177+
"""Initialize the Oracle async migration tracker.
178+
179+
Args:
180+
version_table_name: Name of the table to track migrations.
181+
"""
182+
super().__init__(version_table_name)
183+
self._table_initialized = False
160184

161185
async def ensure_tracking_table(self, driver: "AsyncDriverAdapterBase") -> None:
162186
"""Create the migration tracking table if it doesn't exist.
163187
164-
Oracle doesn't support IF NOT EXISTS, so we check for table existence first.
188+
Uses caching to avoid repeated database queries for table existence.
189+
This is critical for performance in ASGI frameworks where this might be called on every request.
165190
166191
Args:
167192
driver: The database driver to use.
168193
"""
194+
if self._table_initialized:
195+
return
169196

170197
check_sql = (
171198
sql.select(sql.count().as_("table_count"))
@@ -178,6 +205,8 @@ async def ensure_tracking_table(self, driver: "AsyncDriverAdapterBase") -> None:
178205
await driver.execute(self._get_create_table_sql())
179206
await self._safe_commit_async(driver)
180207

208+
self._table_initialized = True
209+
181210
async def get_current_version(self, driver: "AsyncDriverAdapterBase") -> "Optional[str]":
182211
"""Get the latest applied migration version.
183212

sqlspec/extensions/litestar/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,16 @@
22
from sqlspec.extensions.litestar.cli import database_group
33
from sqlspec.extensions.litestar.config import DatabaseConfig
44
from sqlspec.extensions.litestar.plugin import SQLSpec
5+
from sqlspec.extensions.litestar.session import SQLSpecSessionBackend
6+
from sqlspec.extensions.litestar.store import SQLSpecSessionStore, SQLSpecSessionStoreError
57

6-
__all__ = ("DatabaseConfig", "SQLSpec", "database_group", "handlers", "providers")
8+
__all__ = (
9+
"DatabaseConfig",
10+
"SQLSpec",
11+
"SQLSpecSessionBackend",
12+
"SQLSpecSessionStore",
13+
"SQLSpecSessionStoreError",
14+
"database_group",
15+
"handlers",
16+
"providers",
17+
)

0 commit comments

Comments
 (0)