From 74764f6ce9a13417ca769914a55b2feef1b431d7 Mon Sep 17 00:00:00 2001 From: Nolan Conaway Date: Sun, 10 Nov 2024 11:29:09 -0500 Subject: [PATCH] Refactor SubwayFeed API key handling and update related tests --- readme.md | 15 +-------------- src/underground/cli/feed.py | 9 +-------- src/underground/cli/stops.py | 10 ++-------- src/underground/feed.py | 30 ++++++------------------------ src/underground/models.py | 18 ++++-------------- test/test_cli.py | 19 +++++++------------ test/test_feed.py | 17 +++++------------ test/test_models.py | 2 +- 8 files changed, 27 insertions(+), 93 deletions(-) diff --git a/readme.md b/readme.md index 4264b07..420d7bd 100644 --- a/readme.md +++ b/readme.md @@ -20,23 +20,16 @@ Or if you'd like to live dangerously: pip install git+https://github.com/nolanbconaway/underground.git#egg=underground ``` -To request data from the MTA, you'll also need a free API key. -[Register here](https://api.mta.info/). - ## Python API -Once you have your API key, use the Python API like: +Use the Python API like: ``` python import os from underground import metadata, SubwayFeed -API_KEY = os.getenv('MTA_API_KEY') ROUTE = 'Q' -feed = SubwayFeed.get(ROUTE, api_key=API_KEY) - -# request will read from $MTA_API_KEY if a key is not provided feed = SubwayFeed.get(ROUTE) # under the hood, the Q route is mapped to a URL. This call is equivalent: @@ -93,9 +86,6 @@ Usage: underground feed [OPTIONS] ROUTE_OR_URL underground feed $URL --json > feed_nrqw.json Options: - --api-key TEXT MTA API key. Will be read from $MTA_API_KEY if not - provided. - --json Option to output the feed data as JSON. Otherwise output will be bytes. @@ -120,8 +110,6 @@ Options: unix timestamp. -r, --retries INTEGER Retry attempts in case of API connection failure. Default 100. - --api-key TEXT MTA API key. Will be read from $MTA_API_KEY if not - provided. -t, --timezone TEXT Output timezone. Ignored if --epoch. Default to NYC time. -s, --stalled-timeout INTEGER Number of seconds between the last movement @@ -135,7 +123,6 @@ Options: Stops are printed to stdout in the format `stop_id t1 t2 ... tn` . ``` sh -$ export MTA_API_KEY='...' $ underground stops Q | tail -2 Q05S 19:01 19:09 19:16 19:25 19:34 19:44 19:51 19:58 Q04S 19:03 19:11 19:18 19:27 19:36 19:46 19:53 20:00 diff --git a/src/underground/cli/feed.py b/src/underground/cli/feed.py index 317caf4..d169e6a 100644 --- a/src/underground/cli/feed.py +++ b/src/underground/cli/feed.py @@ -9,12 +9,6 @@ @click.command() @click.argument("route_or_url") -@click.option( - "--api-key", - "api_key", - default=None, - help="MTA API key. Will be read from $MTA_API_KEY if not provided.", -) @click.option( "--json", "output_json", @@ -29,7 +23,7 @@ type=int, help="Retry attempts in case of API connection failure. Default 100.", ) -def main(route_or_url, api_key, output_json, retries): +def main(route_or_url, output_json, retries): """Request an MTA feed via a route or URL. ROUTE_OR_URL may be either a feed URL or a route (which will be used to look up @@ -53,7 +47,6 @@ def main(route_or_url, api_key, output_json, retries): data = feed.request_robust( route_or_url=route_or_url, retries=retries, - api_key=api_key, return_dict=output_json, ) diff --git a/src/underground/cli/stops.py b/src/underground/cli/stops.py index d83e0ca..86f6bca 100644 --- a/src/underground/cli/stops.py +++ b/src/underground/cli/stops.py @@ -32,12 +32,6 @@ def datetime_to_epoch(dttm: datetime.datetime) -> int: type=int, help="Retry attempts in case of API connection failure. Default 100.", ) -@click.option( - "--api-key", - "api_key", - default=None, - help="MTA API key. Will be read from $MTA_API_KEY if not provided.", -) @click.option( "-t", "--timezone", @@ -54,10 +48,10 @@ def datetime_to_epoch(dttm: datetime.datetime) -> int: " update before considering a train stalled. Default is 90 as recommended" " by the MTA. Numbers less than 1 disable this check.", ) -def main(route, fmt, retries, api_key, timezone, stalled_timeout): +def main(route, fmt, retries, timezone, stalled_timeout): """Print out train departure times for all stops on a subway line.""" stops = ( - SubwayFeed.get(api_key=api_key, route_or_url=route, retries=retries) + SubwayFeed.get(route_or_url=route, retries=retries) .extract_stop_dict(timezone=timezone, stalled_timeout=stalled_timeout) .get(route, dict()) ) diff --git a/src/underground/feed.py b/src/underground/feed.py index 42c09b1..6a9202c 100644 --- a/src/underground/feed.py +++ b/src/underground/feed.py @@ -1,6 +1,5 @@ """Interact with the MTA GTFS api.""" -import os import time import typing @@ -22,7 +21,7 @@ def load_protobuf(protobuf_bytes: bytes) -> dict: Parameters ---------- protobuf_bytes : bytes - Protobuuf data, as returned from the raw request. + Protobuf data, as returned from the raw request. Returns ------- @@ -38,7 +37,7 @@ def load_protobuf(protobuf_bytes: bytes) -> dict: return feed_dict -def request(route_or_url: str, api_key: typing.Optional[str] = None) -> bytes: +def request(route_or_url: str) -> bytes: """Send a HTTP GET request to the MTA for realtime feed data. Occassionally a feed is requested as the MTA is writing updated data to the file, @@ -49,9 +48,6 @@ def request(route_or_url: str, api_key: typing.Optional[str] = None) -> bytes: ---------- route_or_url : str Route ID or feed url (per ``https://api.mta.info/#/subwayRealTimeFeeds``). - api_key : str - MTA API key. If not provided, it will be read from the $MTA_API_KEY env - variable. Returns ------- @@ -62,26 +58,15 @@ def request(route_or_url: str, api_key: typing.Optional[str] = None) -> bytes: # check feed url = metadata.resolve_url(route_or_url) - # get the API key. - api_key = api_key or os.getenv("MTA_API_KEY", None) - if api_key is None: - raise ValueError( - "No API key. pass to the called function " - "or set the $MTA_API_KEY environment variable." - ) - # make the request - res = requests.get(url, headers={"x-api-key": api_key}) + res = requests.get(url) res.raise_for_status() return res.content def request_robust( - route_or_url: str, - retries: int = 100, - api_key: typing.Optional[str] = None, - return_dict: bool = False, + route_or_url: str, retries: int = 100, return_dict: bool = False ) -> typing.Union[dict, bytes]: """Request feed data with validations and retries. @@ -97,9 +82,6 @@ def request_robust( retries : int Number of retry attempts, with 1 second timeout between attempts. Set to -1 for unlimited. Default 100. - api_key : str - MTA API key. If not provided, it will be read from the $MTA_API_KEY env - variable. return_dict : bool Option to return the process data as a dict rather than as raw protobuf data. This is equivalent to running ``load_protobuf(request_robust(...))``. @@ -112,7 +94,7 @@ def request_robust( """ # get protobuf bytes - protobuf_data = request(route_or_url=route_or_url, api_key=api_key) + protobuf_data = request(route_or_url=route_or_url) for attempt in range(retries + 1): try: feed_dict = load_protobuf(protobuf_data) @@ -125,6 +107,6 @@ def request_robust( # wait 1 second and then make new protobuf data time.sleep(1) # be cool to the MTA - protobuf_data = request(route_or_url=route_or_url, api_key=api_key) + protobuf_data = request(route_or_url=route_or_url) return feed_dict if return_dict else protobuf_data diff --git a/src/underground/models.py b/src/underground/models.py index 835bc5d..5b2b00d 100644 --- a/src/underground/models.py +++ b/src/underground/models.py @@ -178,10 +178,8 @@ class SubwayFeed(pydantic.BaseModel): header: FeedHeader entity: list[Entity] - @staticmethod - def get( - route_or_url: str, retries: int = 100, api_key: typing.Optional[str] = None - ) -> "SubwayFeed": + @classmethod + def get(cls, route_or_url: str, retries: int = 100) -> "SubwayFeed": """Request feed data from the MTA. Parameters @@ -193,9 +191,6 @@ def get( retries : int Number of retry attempts, with 1 second timeout between attempts. Set to -1 for unlimited. Default 100. - api_key : str - MTA API key. If not provided, it will be read from the $MTA_API_KEY env - variable. Returns ------- @@ -203,13 +198,8 @@ def get( An instance of the SubwayFeed class with the requested data. """ - return SubwayFeed( - **feed.request_robust( - route_or_url=route_or_url, - retries=retries, - api_key=api_key, - return_dict=True, - ) + return cls( + **feed.request_robust(route_or_url=route_or_url, retries=retries, return_dict=True) ) def extract_stop_dict( diff --git a/test/test_cli.py b/test/test_cli.py index bcdc1aa..0264153 100644 --- a/test/test_cli.py +++ b/test/test_cli.py @@ -1,4 +1,5 @@ """Test the CLI.""" + import io import json import os @@ -52,9 +53,7 @@ def test_stops_epoch(monkeypatch): }, ], } - monkeypatch.setattr( - "underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data) - ) + monkeypatch.setattr("underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data)) runner = CliRunner() result = runner.invoke(stops_cli.main, ["1", "-f", "epoch"]) assert result.exit_code == 0 @@ -87,9 +86,7 @@ def test_stops_format(monkeypatch): ], } - monkeypatch.setattr( - "underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data) - ) + monkeypatch.setattr("underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data)) runner = CliRunner() # year @@ -130,9 +127,7 @@ def test_stops_timezone(monkeypatch): ], } - monkeypatch.setattr( - "underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data) - ) + monkeypatch.setattr("underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data)) runner = CliRunner() # in hong kong time @@ -155,7 +150,7 @@ def test_feed_bytes(requests_mock, filename): requests_mock.get(requests_mock_any, content=file.read()) runner = CliRunner() - result = runner.invoke(feed_cli.main, ["1", "--api-key", "FAKE"]) + result = runner.invoke(feed_cli.main, ["1"]) assert result.exit_code == 0 assert "entity" in load_protobuf(result.stdout_bytes) @@ -167,7 +162,7 @@ def test_feed_json(requests_mock, filename): requests_mock.get(requests_mock_any, content=file.read()) runner = CliRunner() - result = runner.invoke(feed_cli.main, ["1", "--json", "--api-key", "FAKE"]) + result = runner.invoke(feed_cli.main, ["1", "--json"]) assert result.exit_code == 0 assert "entity" in json.loads(result.output) @@ -200,7 +195,7 @@ def test_stopstxt_json(monkeypatch, args): lambda: content, ) runner = CliRunner() - result = runner.invoke(findstops_cli.main, args + ["--json"]) + result = runner.invoke(findstops_cli.main, [*args, "--json"]) assert result.exit_code == 0 for stop in json.loads(result.output): diff --git a/test/test_feed.py b/test/test_feed.py index da18ec6..169e101 100644 --- a/test/test_feed.py +++ b/test/test_feed.py @@ -1,4 +1,5 @@ """Test the feed submodule.""" + import os import time @@ -37,7 +38,7 @@ def mock_load_protobuf(*a): time_1 = time.time() with pytest.raises(feed.EmptyFeedError): - feed.request_robust("1", retries=retries, api_key="FAKE") + feed.request_robust("1", retries=retries) elapsed = time.time() - time_1 assert elapsed >= retries @@ -62,21 +63,13 @@ def test_request_invalid_feed(): feed.request("NOT REAL") -def test_request_no_api_key(monkeypatch): - """Test that request raises value error when no api key is available.""" - monkeypatch.delenv("MTA_API_KEY", raising=False) - - with pytest.raises(ValueError): - feed.request(next(iter(metadata.VALID_FEED_URLS))) - - @pytest.mark.parametrize("ret_code", [200, 500]) def test_request_raise_status(requests_mock, ret_code): """Test the request raise status conditional.""" feed_url = next(iter(metadata.VALID_FEED_URLS)) - requests_mock.get(requests_mock_any, content="".encode(), status_code=ret_code) + requests_mock.get(requests_mock_any, content=b"", status_code=ret_code) if ret_code != 200: with pytest.raises(requests.HTTPError): - feed.request(feed_url, api_key="FAKE") + feed.request(feed_url) else: - feed.request(feed_url, api_key="FAKE") + feed.request(feed_url) diff --git a/test/test_models.py b/test/test_models.py index 32e6814..81b5fa0 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -88,7 +88,7 @@ def test_get(requests_mock, filename): return_value = file.read() requests_mock.get(requests_mock_any, content=return_value) - feed = SubwayFeed.get("1", api_key="FAKE") ## valid route but not used at all + feed = SubwayFeed.get("1") ## valid route but not used at all assert isinstance(feed, SubwayFeed) assert isinstance(feed.extract_stop_dict(), dict)