Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SubwayFeed API key handling and update related tests #39

Merged
merged 1 commit into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,16 @@ Or if you'd like to live dangerously:
pip install git+https://github.com/nolanbconaway/underground.git#egg=underground
```

To request data from the MTA, you'll also need a free API key.
[Register here](https://api.mta.info/).

## Python API

Once you have your API key, use the Python API like:
Use the Python API like:

``` python
import os

from underground import metadata, SubwayFeed

API_KEY = os.getenv('MTA_API_KEY')
ROUTE = 'Q'
feed = SubwayFeed.get(ROUTE, api_key=API_KEY)

# request will read from $MTA_API_KEY if a key is not provided
feed = SubwayFeed.get(ROUTE)

# under the hood, the Q route is mapped to a URL. This call is equivalent:
Expand Down Expand Up @@ -93,9 +86,6 @@ Usage: underground feed [OPTIONS] ROUTE_OR_URL
underground feed $URL --json > feed_nrqw.json

Options:
--api-key TEXT MTA API key. Will be read from $MTA_API_KEY if not
provided.

--json Option to output the feed data as JSON. Otherwise
output will be bytes.

Expand All @@ -120,8 +110,6 @@ Options:
unix timestamp.
-r, --retries INTEGER Retry attempts in case of API connection failure.
Default 100.
--api-key TEXT MTA API key. Will be read from $MTA_API_KEY if not
provided.
-t, --timezone TEXT Output timezone. Ignored if --epoch. Default to NYC
time.
-s, --stalled-timeout INTEGER Number of seconds between the last movement
Expand All @@ -135,7 +123,6 @@ Options:
Stops are printed to stdout in the format `stop_id t1 t2 ... tn` .

``` sh
$ export MTA_API_KEY='...'
$ underground stops Q | tail -2
Q05S 19:01 19:09 19:16 19:25 19:34 19:44 19:51 19:58
Q04S 19:03 19:11 19:18 19:27 19:36 19:46 19:53 20:00
Expand Down
9 changes: 1 addition & 8 deletions src/underground/cli/feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@

@click.command()
@click.argument("route_or_url")
@click.option(
"--api-key",
"api_key",
default=None,
help="MTA API key. Will be read from $MTA_API_KEY if not provided.",
)
@click.option(
"--json",
"output_json",
Expand All @@ -29,7 +23,7 @@
type=int,
help="Retry attempts in case of API connection failure. Default 100.",
)
def main(route_or_url, api_key, output_json, retries):
def main(route_or_url, output_json, retries):
"""Request an MTA feed via a route or URL.

ROUTE_OR_URL may be either a feed URL or a route (which will be used to look up
Expand All @@ -53,7 +47,6 @@ def main(route_or_url, api_key, output_json, retries):
data = feed.request_robust(
route_or_url=route_or_url,
retries=retries,
api_key=api_key,
return_dict=output_json,
)

Expand Down
10 changes: 2 additions & 8 deletions src/underground/cli/stops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ def datetime_to_epoch(dttm: datetime.datetime) -> int:
type=int,
help="Retry attempts in case of API connection failure. Default 100.",
)
@click.option(
"--api-key",
"api_key",
default=None,
help="MTA API key. Will be read from $MTA_API_KEY if not provided.",
)
@click.option(
"-t",
"--timezone",
Expand All @@ -54,10 +48,10 @@ def datetime_to_epoch(dttm: datetime.datetime) -> int:
" update before considering a train stalled. Default is 90 as recommended"
" by the MTA. Numbers less than 1 disable this check.",
)
def main(route, fmt, retries, api_key, timezone, stalled_timeout):
def main(route, fmt, retries, timezone, stalled_timeout):
"""Print out train departure times for all stops on a subway line."""
stops = (
SubwayFeed.get(api_key=api_key, route_or_url=route, retries=retries)
SubwayFeed.get(route_or_url=route, retries=retries)
.extract_stop_dict(timezone=timezone, stalled_timeout=stalled_timeout)
.get(route, dict())
)
Expand Down
30 changes: 6 additions & 24 deletions src/underground/feed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Interact with the MTA GTFS api."""

import os
import time
import typing

Expand All @@ -22,7 +21,7 @@ def load_protobuf(protobuf_bytes: bytes) -> dict:
Parameters
----------
protobuf_bytes : bytes
Protobuuf data, as returned from the raw request.
Protobuf data, as returned from the raw request.

Returns
-------
Expand All @@ -38,7 +37,7 @@ def load_protobuf(protobuf_bytes: bytes) -> dict:
return feed_dict


def request(route_or_url: str, api_key: typing.Optional[str] = None) -> bytes:
def request(route_or_url: str) -> bytes:
"""Send a HTTP GET request to the MTA for realtime feed data.

Occassionally a feed is requested as the MTA is writing updated data to the file,
Expand All @@ -49,9 +48,6 @@ def request(route_or_url: str, api_key: typing.Optional[str] = None) -> bytes:
----------
route_or_url : str
Route ID or feed url (per ``https://api.mta.info/#/subwayRealTimeFeeds``).
api_key : str
MTA API key. If not provided, it will be read from the $MTA_API_KEY env
variable.

Returns
-------
Expand All @@ -62,26 +58,15 @@ def request(route_or_url: str, api_key: typing.Optional[str] = None) -> bytes:
# check feed
url = metadata.resolve_url(route_or_url)

# get the API key.
api_key = api_key or os.getenv("MTA_API_KEY", None)
if api_key is None:
raise ValueError(
"No API key. pass to the called function "
"or set the $MTA_API_KEY environment variable."
)

# make the request
res = requests.get(url, headers={"x-api-key": api_key})
res = requests.get(url)
res.raise_for_status()

return res.content


def request_robust(
route_or_url: str,
retries: int = 100,
api_key: typing.Optional[str] = None,
return_dict: bool = False,
route_or_url: str, retries: int = 100, return_dict: bool = False
) -> typing.Union[dict, bytes]:
"""Request feed data with validations and retries.

Expand All @@ -97,9 +82,6 @@ def request_robust(
retries : int
Number of retry attempts, with 1 second timeout between attempts.
Set to -1 for unlimited. Default 100.
api_key : str
MTA API key. If not provided, it will be read from the $MTA_API_KEY env
variable.
return_dict : bool
Option to return the process data as a dict rather than as raw protobuf data.
This is equivalent to running ``load_protobuf(request_robust(...))``.
Expand All @@ -112,7 +94,7 @@ def request_robust(

"""
# get protobuf bytes
protobuf_data = request(route_or_url=route_or_url, api_key=api_key)
protobuf_data = request(route_or_url=route_or_url)
for attempt in range(retries + 1):
try:
feed_dict = load_protobuf(protobuf_data)
Expand All @@ -125,6 +107,6 @@ def request_robust(

# wait 1 second and then make new protobuf data
time.sleep(1) # be cool to the MTA
protobuf_data = request(route_or_url=route_or_url, api_key=api_key)
protobuf_data = request(route_or_url=route_or_url)

return feed_dict if return_dict else protobuf_data
18 changes: 4 additions & 14 deletions src/underground/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,8 @@ class SubwayFeed(pydantic.BaseModel):
header: FeedHeader
entity: list[Entity]

@staticmethod
def get(
route_or_url: str, retries: int = 100, api_key: typing.Optional[str] = None
) -> "SubwayFeed":
@classmethod
def get(cls, route_or_url: str, retries: int = 100) -> "SubwayFeed":
"""Request feed data from the MTA.

Parameters
Expand All @@ -193,23 +191,15 @@ def get(
retries : int
Number of retry attempts, with 1 second timeout between attempts.
Set to -1 for unlimited. Default 100.
api_key : str
MTA API key. If not provided, it will be read from the $MTA_API_KEY env
variable.

Returns
-------
SubwayFeed
An instance of the SubwayFeed class with the requested data.

"""
return SubwayFeed(
**feed.request_robust(
route_or_url=route_or_url,
retries=retries,
api_key=api_key,
return_dict=True,
)
return cls(
**feed.request_robust(route_or_url=route_or_url, retries=retries, return_dict=True)
)

def extract_stop_dict(
Expand Down
19 changes: 7 additions & 12 deletions test/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test the CLI."""

import io
import json
import os
Expand Down Expand Up @@ -52,9 +53,7 @@ def test_stops_epoch(monkeypatch):
},
],
}
monkeypatch.setattr(
"underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data)
)
monkeypatch.setattr("underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data))
runner = CliRunner()
result = runner.invoke(stops_cli.main, ["1", "-f", "epoch"])
assert result.exit_code == 0
Expand Down Expand Up @@ -87,9 +86,7 @@ def test_stops_format(monkeypatch):
],
}

monkeypatch.setattr(
"underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data)
)
monkeypatch.setattr("underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data))
runner = CliRunner()

# year
Expand Down Expand Up @@ -130,9 +127,7 @@ def test_stops_timezone(monkeypatch):
],
}

monkeypatch.setattr(
"underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data)
)
monkeypatch.setattr("underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data))
runner = CliRunner()

# in hong kong time
Expand All @@ -155,7 +150,7 @@ def test_feed_bytes(requests_mock, filename):
requests_mock.get(requests_mock_any, content=file.read())

runner = CliRunner()
result = runner.invoke(feed_cli.main, ["1", "--api-key", "FAKE"])
result = runner.invoke(feed_cli.main, ["1"])
assert result.exit_code == 0
assert "entity" in load_protobuf(result.stdout_bytes)

Expand All @@ -167,7 +162,7 @@ def test_feed_json(requests_mock, filename):
requests_mock.get(requests_mock_any, content=file.read())

runner = CliRunner()
result = runner.invoke(feed_cli.main, ["1", "--json", "--api-key", "FAKE"])
result = runner.invoke(feed_cli.main, ["1", "--json"])
assert result.exit_code == 0
assert "entity" in json.loads(result.output)

Expand Down Expand Up @@ -200,7 +195,7 @@ def test_stopstxt_json(monkeypatch, args):
lambda: content,
)
runner = CliRunner()
result = runner.invoke(findstops_cli.main, args + ["--json"])
result = runner.invoke(findstops_cli.main, [*args, "--json"])
assert result.exit_code == 0

for stop in json.loads(result.output):
Expand Down
17 changes: 5 additions & 12 deletions test/test_feed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test the feed submodule."""

import os
import time

Expand Down Expand Up @@ -37,7 +38,7 @@ def mock_load_protobuf(*a):

time_1 = time.time()
with pytest.raises(feed.EmptyFeedError):
feed.request_robust("1", retries=retries, api_key="FAKE")
feed.request_robust("1", retries=retries)
elapsed = time.time() - time_1

assert elapsed >= retries
Expand All @@ -62,21 +63,13 @@ def test_request_invalid_feed():
feed.request("NOT REAL")


def test_request_no_api_key(monkeypatch):
"""Test that request raises value error when no api key is available."""
monkeypatch.delenv("MTA_API_KEY", raising=False)

with pytest.raises(ValueError):
feed.request(next(iter(metadata.VALID_FEED_URLS)))


@pytest.mark.parametrize("ret_code", [200, 500])
def test_request_raise_status(requests_mock, ret_code):
"""Test the request raise status conditional."""
feed_url = next(iter(metadata.VALID_FEED_URLS))
requests_mock.get(requests_mock_any, content="".encode(), status_code=ret_code)
requests_mock.get(requests_mock_any, content=b"", status_code=ret_code)
if ret_code != 200:
with pytest.raises(requests.HTTPError):
feed.request(feed_url, api_key="FAKE")
feed.request(feed_url)
else:
feed.request(feed_url, api_key="FAKE")
feed.request(feed_url)
2 changes: 1 addition & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_get(requests_mock, filename):
return_value = file.read()

requests_mock.get(requests_mock_any, content=return_value)
feed = SubwayFeed.get("1", api_key="FAKE") ## valid route but not used at all
feed = SubwayFeed.get("1") ## valid route but not used at all

assert isinstance(feed, SubwayFeed)
assert isinstance(feed.extract_stop_dict(), dict)
Expand Down
Loading