Skip to content

Commit

Permalink
Add app command cleanup, abstract methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Lxstr committed Feb 7, 2024
1 parent cb36bde commit ec78cb6
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 47 deletions.
6 changes: 4 additions & 2 deletions src/flask_session/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ class Defaults:
SESSION_PERMANENT = True
SESSION_SID_LENGTH = 3

# Clean up settings for non TTL backends (SQL, PostgreSQL, etc.)
SESSION_CLEANUP_N_REQUESTS = None
SESSION_CLEANUP_N_SECONDS = None

# Redis settings
SESSION_REDIS = None

Expand All @@ -31,5 +35,3 @@ class Defaults:
SESSION_SQLALCHEMY_SEQUENCE = None
SESSION_SQLALCHEMY_SCHEMA = None
SESSION_SQLALCHEMY_BIND_KEY = None
SESSION_CLEANUP_N_REQUESTS = None
SESSION_CLEANUP_N_SECONDS = None
115 changes: 70 additions & 45 deletions src/flask_session/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,25 @@ def __init__(
use_signer: bool = Defaults.SESSION_USE_SIGNER,
permanent: bool = Defaults.SESSION_PERMANENT,
sid_length: int = Defaults.SESSION_SID_LENGTH,
cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS,
cleanup_n_seconds: Optional[int] = Defaults.SESSION_CLEANUP_N_SECONDS,
):
self.app = app
self.key_prefix = key_prefix
self.use_signer = use_signer
self.permanent = permanent
self.sid_length = sid_length
self.has_same_site_capability = hasattr(self, "get_cookie_samesite")
self.cleanup_n_requests = cleanup_n_requests
self.cleanup_n_seconds = cleanup_n_seconds

# Cleanup settings for non-TTL databases only
if self.ttl:
self._register_cleanup_app_command()
if self.cleanup_n_seconds:
self._start_cleanup_thread(self.cleanup_n_seconds)
if self.cleanup_n_requests:
self.app.before_request(self._cleanup_per_requests)

def save_session(
self, app: Flask, session: ServerSideSession, response: Response
Expand Down Expand Up @@ -179,6 +191,42 @@ def open_session(self, app: Flask, request: Request) -> ServerSideSession:
sid = self._generate_sid(self.sid_length)
return self.session_class(sid=sid, permanent=self.permanent)

# CLEANUP METHODS FOR NON TTL DATABASES

def _register_cleanup_app_command(self):
"""
Register a custom Flask CLI command for cleaning up expired sessions.
Run the command with `flask session_cleanup`. Run with a cron job
or scheduler such as Heroku Scheduler to automatically clean up expired sessions.
"""

@self.app.cli.command("session_cleanup")
def session_cleanup():
with self.app.app_context():
self._delete_expired_sessions()

def _cleanup_n_requests(self) -> None:
"""Delete expired sessions approximately every N requests."""
if self.cleanup_n_seconds or (
self.cleanup_n_requests and random.randint(0, self.cleanup_n_requests) == 0
):
self._delete_expired_sessions()

def _start_cleanup_thread(self, cleanup_n_seconds: int) -> None:
"""Start a background thread to delete expired sessions approximately every N seconds."""

def cleanup():
with self.app.app_context():
while True:
self._delete_expired_sessions()
time.sleep(cleanup_n_seconds)

thread = Thread(target=cleanup, daemon=True)
thread.start()

# METHODS TO BE IMPLEMENTED BY SUBCLASSES

def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
raise NotImplementedError()

Expand All @@ -190,6 +238,10 @@ def _upsert_session(
) -> None:
raise NotImplementedError()

def _delete_expired_sessions(self) -> None:
"""Delete expired sessions from the backend storage. Only required for non-TTL databases."""
pass


class RedisSessionInterface(ServerSideSessionInterface):
"""Uses the Redis key-value store as a session backend. (`redis-py` required)
Expand All @@ -209,6 +261,7 @@ class RedisSessionInterface(ServerSideSessionInterface):

serializer = pickle
session_class = RedisSession
ttl = True

def __init__(
self,
Expand Down Expand Up @@ -275,6 +328,7 @@ class MemcachedSessionInterface(ServerSideSessionInterface):

serializer = pickle
session_class = MemcachedSession
ttl = True

def __init__(
self,
Expand Down Expand Up @@ -367,6 +421,8 @@ class FileSystemSessionInterface(ServerSideSessionInterface):
"""

session_class = FileSystemSession
serializer = None
ttl = True

def __init__(
self,
Expand Down Expand Up @@ -427,6 +483,7 @@ class MongoDBSessionInterface(ServerSideSessionInterface):

serializer = pickle
session_class = MongoDBSession
ttl = True

def __init__(
self,
Expand Down Expand Up @@ -532,6 +589,7 @@ class SqlAlchemySessionInterface(ServerSideSessionInterface):

serializer = pickle
session_class = SqlAlchemySession
non_ttl = True

def __init__(
self,
Expand All @@ -548,6 +606,7 @@ def __init__(
cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS,
cleanup_n_seconds: Optional[int] = Defaults.SESSION_CLEANUP_N_SECONDS,
):
self.app = app
if db is None:
from flask_sqlalchemy import SQLAlchemy

Expand All @@ -556,8 +615,6 @@ def __init__(
self.sequence = sequence
self.schema = schema
self.bind_key = bind_key
self.cleanup_n_requests = cleanup_n_requests
self.cleanup_n_seconds = cleanup_n_seconds
super().__init__(app, key_prefix, use_signer, permanent, sid_length)

# Create the Session database model
Expand Down Expand Up @@ -596,51 +653,19 @@ def __repr__(self):

self.sql_session_model = Session

# Start the cleanup thread
if self.cleanup_n_seconds:
self._start_cleanup_thread(cleanup_n_seconds)

def _clean_up(self) -> None:
# Delete expired sessions approximately every N requests
if self.cleanup_n_seconds or (
self.cleanup_n_requests and random.randint(0, self.cleanup_n_requests) == 0
):
self.app.logger.info("Deleting expired sessions")
try:
self.db.session.query(self.sql_session_model).filter(
self.sql_session_model.expiry <= datetime.utcnow()
).delete(synchronize_session=False)
self.db.session.commit()
except Exception as e:
self.app.logger.exception(
e, "Failed to delete expired sessions. Skipping..."
)

def _start_cleanup_thread(self, cleanup_n_seconds: int) -> None:
def cleanup():
with self.app.app_context():
while True:
try:
self.app.logger.info("Deleting expired sessions")
self.db.session.query(self.sql_session_model).filter(
self.sql_session_model.expiry <= datetime.utcnow()
).delete(synchronize_session=False)
self.db.session.commit()
except Exception as e:
self.app.logger.exception(
e, "Failed to delete expired sessions. Skipping..."
)
# Wait for a specified interval (e.g., 3600 seconds = 1 hour) before the next cleanup
time.sleep(cleanup_n_seconds)

# Create and start the cleanup thread
thread = Thread(target=cleanup, daemon=True)
thread.start()
def _delete_expired_sessions(self) -> None:
try:
self.db.session.query(self.sql_session_model).filter(
self.sql_session_model.expiry <= datetime.utcnow()
).delete(synchronize_session=False)
self.db.session.commit()
self.app.logger.info("Deleted expired sessions")
except Exception as e:
self.app.logger.exception(
e, "Failed to delete expired sessions. Skipping..."
)

def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
if self.cleanup_n_requests:
self._clean_up()

# Get the saved session (record) from the database
record = self.sql_session_model.query.filter_by(session_id=store_id).first()

Expand Down

0 comments on commit ec78cb6

Please sign in to comment.