Skip to content

Commit

Permalink
Move methods out of main file
Browse files Browse the repository at this point in the history
Also seperated the rate limit and etag header logic out of the do_request
  • Loading branch information
tjorim committed Jan 2, 2025
1 parent 97bb682 commit a840397
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 44 deletions.
56 changes: 56 additions & 0 deletions pyrail/api_endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Dict, List, Optional, Any

class Endpoint:
def __init__(self, name: str, required_params: Optional[List[str]] = None, optional_params: Optional[List[str]] = None, xor_params: Optional[List[str]] = None):
self.name = name
self.required_params = required_params or []
self.optional_params = optional_params or []
self.xor_params = xor_params or []

def validate(self, args: Dict[str, Any]) -> bool:
"""Validate required parameters and XOR conditions."""

# Check required parameters
for param in self.required_params:
if param not in args or args[param] is None:
return False

# Check XOR logic (only one of xor_params can be set)
if self.xor_params:
xor_values = [args.get(param) is not None for param in self.xor_params]
if sum(xor_values) != 1: # Exactly one must be true
return False

return True

# Define endpoints
endpoints = {
'stations': Endpoint(
name='stations',
optional_params=['format', 'lang']
),
'liveboard': Endpoint(
name='liveboard',
xor_params=['station', 'id'], # XOR parameters
optional_params=['date', 'time', 'arrdep', 'alerts', 'format', 'lang']
),
'connections': Endpoint(
name='connections',
required_params=['from', 'to'],
optional_params=['date', 'time', 'timesel', 'typeOfTransport', 'alerts', 'results', 'format', 'lang']
),
'vehicle': Endpoint(
name='vehicle',
required_params=['id'],
optional_params=['date', 'alerts', 'format', 'lang']
),
'composition': Endpoint(
name='composition',
required_params=['id'],
optional_params=['data', 'format', 'lang']
),
'disturbances': Endpoint(
name='disturbances',
optional_params=['lineBreakCharacter', 'format', 'lang']
)
}
81 changes: 37 additions & 44 deletions pyrail/irail.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from threading import Lock
import time
from .api_endpoints import endpoints

import requests

Expand All @@ -11,14 +12,6 @@

base_url = 'https://api.irail.be/{}/'

methods = {
'stations': [],
'liveboard': ['id', 'station', 'date', 'time', 'arrdep', 'alerts'],
'connections': ['from', 'to', 'date', 'time', 'timesel', 'typeOfTransport', 'alerts', 'results'],
'vehicle': ['id', 'date', 'alerts'],
'disturbances': []
}

headers = {'user-agent': 'pyRail (tielemans.jorim@gmail.com)'}

"""
Expand Down Expand Up @@ -79,49 +72,52 @@ def _refill_tokens(self):
elapsed = current_time - self.last_request_time
self.last_request_time = current_time

# Refill tokens based on elapsed time
self.tokens += elapsed * 3 # 3 tokens per second
if self.tokens > 3:
self.tokens = 3
# Refill tokens, 3 tokens per second, cap tokens at 3
self.tokens = min(3, self.tokens + elapsed * 3)
# Refill burst tokens, 3 burst tokens per second, cap burst tokens at 5
self.burst_tokens = min(5, self.burst_tokens + elapsed * 3) # Cap burst tokens at 5

def _handle_rate_limit(self):
"""Handles rate limiting by refilling tokens or waiting."""
self._refill_tokens()

if self.tokens < 1:
if self.burst_tokens >= 1:
self.burst_tokens -= 1
else:
logger.warning("Rate limiting active, waiting for tokens")
time.sleep(1 - (time.time() - self.last_request_time))
self._refill_tokens()
self.tokens -= 1
else:
self.tokens -= 1

# Refill burst tokens
self.burst_tokens += elapsed * 3 # 3 burst tokens per second
if self.burst_tokens > 5:
self.burst_tokens = 5
def _add_etag_header(self, method):
"""Adds ETag header if a cached ETag exists."""
headers = {}
if method in self.etag_cache:
logger.debug("Adding If-None-Match header with value: %s", self.etag_cache[method])
headers['If-None-Match'] = self.etag_cache[method]
return headers

def do_request(self, method, args=None):
logger.info("Starting request to endpoint: %s", method)
with self.lock:
self._refill_tokens()

if self.tokens < 1:
if self.burst_tokens >= 1:
self.burst_tokens -= 1
else:
logger.warning("Rate limiting, waiting for tokens")
time.sleep(1 - (time.time() - self.last_request_time))
self._refill_tokens()
self.tokens -= 1
else:
self.tokens -= 1
self._handle_rate_limit()

if method in methods:
if method in endpoints:
url = base_url.format(method)
params = {'format': self.format, 'lang': self.lang}
if args:
params.update(args)
headers = {}

# Add If-None-Match header if we have a cached ETag
if method in self.etag_cache:
logger.debug("Adding If-None-Match header with value: %s", self.etag_cache[method])
headers['If-None-Match'] = self.etag_cache[method]
request_headers = self._add_etag_header(method)

try:
response = session.get(url, params=params, headers=headers)
response = session.get(url, params=params, headers=request_headers)
if response.status_code == 429:
logger.warning("Rate limited, waiting for retry-after header")
retry_after = int(response.headers.get("Retry-After", 1))
logger.warning("Rate limited, retrying after %d seconds", retry_after)
time.sleep(retry_after)
return self.do_request(method, args)
if response.status_code == 200:
Expand All @@ -132,6 +128,7 @@ def do_request(self, method, args=None):
json_data = response.json()
return json_data
except ValueError:
logger.error("Failed to parse JSON response")
return -1
elif response.status_code == 304:
logger.info("Data not modified, using cached data")
Expand All @@ -152,23 +149,19 @@ def do_request(self, method, args=None):

def get_stations(self):
"""Retrieve a list of all stations."""
json_data = self.do_request('stations')
return json_data
return self.do_request('stations')

def get_liveboard(self, station=None, id=None):
if bool(station) ^ bool(id):
extra_params = {'station': station, 'id': id}
json_data = self.do_request('liveboard', extra_params)
return json_data
return self.do_request('liveboard', extra_params)

def get_connections(self, from_station=None, to_station=None):
if from_station and to_station:
extra_params = {'from': from_station, 'to': to_station}
json_data = self.do_request('connections', extra_params)
return json_data
return self.do_request('connections', extra_params)

def get_vehicle(self, id=None):
if id:
extra_params = {'id': id}
json_data = self.do_request('vehicle', extra_params)
return json_data
return self.do_request('vehicle', extra_params)

0 comments on commit a840397

Please sign in to comment.