
Commit 3d3d20b

typing improvements
1 parent 1890be9 commit 3d3d20b

11 files changed, +86 -66 lines changed


lint.sh (+1 -1)

@@ -2,6 +2,6 @@
 result=0
 flake8 shub_workflow/ tests/ --application-import-names=shub_workflow --import-order-style=pep8
 result=$(($result | $?))
-mypy --ignore-missing-imports --disable-error-code=method-assign shub_workflow/ tests/
+mypy --ignore-missing-imports --disable-error-code=method-assign --check-untyped-defs shub_workflow/ tests/
 result=$(($result | $?))
 exit $result
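
The added --check-untyped-defs flag makes mypy also check the bodies of functions that carry no annotations, which it skips by default. A minimal illustration of the effect (hypothetical function, not from this repo):

# Hypothetical example: without the flag mypy does not look inside an
# unannotated function, so the int + str bug below passes silently.
def describe(counts):          # unannotated: body is skipped without the flag
    total: int = 0
    for _ in counts:
        total += 1
    return total + " items"    # reported only with --check-untyped-defs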

shub_workflow/clone_job.py (+1 -1)

@@ -128,7 +128,7 @@ def add_argparser_options(self):
 
     def run(self):
         if self.args.key:
-            keys = filter(lambda x: not self.is_cloned_by_jobkey(x), self.args.key)
+            keys = list(filter(lambda x: not self.is_cloned_by_jobkey(x), self.args.key))
         elif self.args.tag_spider:
             keys = []
             project_id, tag, spider = self.args.tag_spider.split("/")
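
Wrapping filter() in list() matters for typing because filter() returns a one-shot Iterator[str] rather than a list, while the other branch assigns a plain list. A small hypothetical example of the difference:

# Hypothetical snippet: filter() yields a single-pass iterator with no len();
# list() materialises it into a List[str] that matches the list branch.
keys = ["123/4/5", "999/1/1", "123/4/6"]
selected = list(filter(lambda k: k.startswith("123/"), keys))
print(len(selected), selected)  # 2 ['123/4/5', '123/4/6']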

shub_workflow/crawl.py (+3)

@@ -240,6 +240,9 @@ def __init__(self):
         self.__next_job_seq = 1
         self._jobuids = self.create_dupe_filter()
 
+    def get_delayed_jobs(self) -> List[FullJobParams]:
+        return deepcopy(self.__delayed_jobs)
+
     @classmethod
     def create_dupe_filter(cls) -> DupesFilterProtocol:
         return BloomFilter(max_elements=1e6, error_rate=1e-6)
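
The new accessor returns a deep copy, so callers can inspect the delayed jobs without mutating the manager's private state. A rough sketch of the pattern with made-up class and field names:

# Sketch with hypothetical names: only the copy is exposed, the private
# list stays untouched by whatever the caller does with the result.
from copy import deepcopy
from typing import Dict, List


class Manager:
    def __init__(self) -> None:
        self.__delayed_jobs: List[Dict[str, str]] = [{"spider": "example"}]

    def get_delayed_jobs(self) -> List[Dict[str, str]]:
        return deepcopy(self.__delayed_jobs)


m = Manager()
m.get_delayed_jobs().clear()       # empties only the copy
print(len(m.get_delayed_jobs()))   # still 1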

shub_workflow/deliver/futils.py (+2 -2)

@@ -180,13 +180,13 @@ def upload_file(path, dest, aws_key=None, aws_secret=None, aws_token=None, **kwa
     gcstorage.upload_file(path, dest)
 
 
-def get_glob(path, aws_key=None, aws_secret=None, aws_token=None, **kwargs):
+def get_glob(path, aws_key=None, aws_secret=None, aws_token=None, **kwargs) -> List[str]:
     region = kwargs.pop("region", None)
     if check_s3_path(path):
         fs = S3FileSystem(**s3_credentials(aws_key, aws_secret, aws_token, region), **kwargs)
         fp = [_S3_ATTRIBUTE + p for p in fs.glob(s3_path(path))]
     else:
-        fp = iglob(path)
+        fp = list(iglob(path))
 
     return fp
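
iglob() yields paths lazily as an Iterator[str] while the S3 branch builds a list comprehension, so list(iglob(...)) is what lets both branches satisfy the new List[str] return annotation. A hypothetical reduced version of the local branch:

# Hypothetical helper, not from the repo: materialising iglob() gives a
# concrete List[str] instead of a lazy iterator.
from glob import iglob
from typing import List


def local_glob(pattern: str) -> List[str]:
    return list(iglob(pattern))


print(local_glob("*.py"))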

shub_workflow/deliver/gcstorage.py (+4 -4)

@@ -26,12 +26,12 @@ def set_credential_file_environ(module, resource, check_exists=True):
     os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credfile
 
 
-def upload_file(src_path, dest_path):
+def upload_file(src_path: str, dest_path: str):
     storage_client = storage.Client()
-    try:
-        bucket_name, destination_blob_name = _GS_FOLDER_RE.match(dest_path).groups()
-    except AttributeError:
+    m = _GS_FOLDER_RE.match(dest_path)
+    if m is None:
         raise ValueError(f"Invalid destination {dest_path}")
+    bucket_name, destination_blob_name = m.groups()
     bucket = storage_client.bucket(bucket_name)
     blob = bucket.blob(destination_blob_name)
     blob.upload_from_filename(src_path, retry=storage.retry.DEFAULT_RETRY)
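
Replacing the try/except AttributeError with an explicit None check is the mypy-friendly way to handle re.match(), whose return type is Optional[re.Match[str]]; the check narrows the type before .groups() is called. A self-contained sketch of the same pattern with a hypothetical regex:

# Hypothetical regex and helper illustrating the narrowing pattern.
import re
from typing import Tuple

_GS_RE = re.compile(r"gs://([^/]+)/(.+)")


def split_gs_path(dest_path: str) -> Tuple[str, str]:
    m = _GS_RE.match(dest_path)
    if m is None:
        raise ValueError(f"Invalid destination {dest_path}")
    bucket_name, blob_name = m.groups()
    return bucket_name, blob_name


print(split_gs_path("gs://my-bucket/some/file.jl.gz"))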

shub_workflow/graph/__init__.py (+3 -1)

@@ -264,7 +264,7 @@ def run_job(self, job: TaskId, is_retry=False) -> Optional[JobKey]:
         if task is not None:
             idx = jobconf["index"]
             return task.run(self, is_retry, index=idx)
-        return None
+        raise RuntimeError(f"Failed to run task {job}")
 
     def _must_wait_time(self, job: TaskId) -> bool:
         status = self.__pending_jobs[job]
@@ -300,6 +300,7 @@ def run_pending_jobs(self):
            if job_can_run:
                try:
                    jobid = self.run_job(task_id, status["is_retry"])
+                   assert jobid is not None, f"Failed to run task {task_id}"
                except Exception:
                    self._release_resources(task_id)
                    raise
@@ -330,6 +331,7 @@ def run_pending_jobs(self):
            if job_can_run:
                try:
                    jobid = self.run_job(task_id, status["is_retry"])
+                   assert jobid is not None, f"Failed to run task {task_id}"
                except Exception:
                    self._release_resources(task_id)
                    raise
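
run_job() now raises instead of returning None when no task matches, and the new asserts state that a job key was actually produced; the assert also narrows the Optional return type for mypy. A hypothetical standalone example of that narrowing:

# Hypothetical sketch: after the assert, mypy treats jobid as str, so
# string operations type check without an Optional complaint.
from typing import Optional


def run_job(task_id: str) -> Optional[str]:
    return f"999/1/{len(task_id)}"


jobid = run_job("taskA")
assert jobid is not None, "Failed to run task taskA"
project, spider, job = jobid.split("/")  # jobid narrowed to str here
print(project, spider, job)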

shub_workflow/graph/task.py (+4 -1)

@@ -43,6 +43,9 @@ class JobGraphDict(TypedDict):
     origin: NotRequired[TaskId]
     index: NotRequired[int]
 
+    spider: NotRequired[str]
+    spider_args: NotRequired[Dict[str, str]]
+
 
 class BaseTask(abc.ABC):
     def __init__(
@@ -283,7 +286,7 @@ def get_spider_args(self):
         spider_args.update({"job_settings": self.__job_settings})
         return spider_args
 
-    def as_jobgraph_dict(self):
+    def as_jobgraph_dict(self) -> JobGraphDict:
         jdict = super().as_jobgraph_dict()
         jdict.update({"spider": self.spider, "spider_args": self.get_spider_args()})
         return jdict
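
The new keys are declared with NotRequired, so spider-specific fields participate in type checking without becoming mandatory for every job graph entry. A reduced, hypothetical TypedDict showing the idea:

# Hypothetical reduced TypedDict (fields are illustrative, not the repo's):
from typing import Dict
from typing_extensions import NotRequired, TypedDict


class JobGraphEntry(TypedDict):
    command: str
    spider: NotRequired[str]
    spider_args: NotRequired[Dict[str, str]]


script_job: JobGraphEntry = {"command": "py:deliver.py"}
spider_job: JobGraphEntry = {
    "command": "crawl",
    "spider": "example",
    "spider_args": {"country": "us"},
}
print(script_job, spider_job)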

shub_workflow/script.py (+8)

@@ -516,6 +516,14 @@ def _run_loops(self) -> Generator[bool, None, None]:
     def base_loop_tasks(self):
         ...
 
+    @abc.abstractmethod
+    def _on_start(self):
+        ...
+
+    @abc.abstractmethod
+    def _close(self):
+        ...
+
 
 class BaseLoopScript(BaseScript, BaseLoopScriptProtocol):
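
Declaring _on_start() and _close() as abstract methods on the protocol-style base tells the type checker those hooks exist on every conforming script, so code written against the protocol may call them. A minimal sketch with hypothetical names:

# Hypothetical sketch of the abstract-hook pattern.
import abc


class LoopScriptProtocol(abc.ABC):
    @abc.abstractmethod
    def _on_start(self):
        ...

    @abc.abstractmethod
    def _close(self):
        ...


class MyLoopScript(LoopScriptProtocol):
    def _on_start(self):
        print("acquiring resources")

    def _close(self):
        print("releasing resources")


script = MyLoopScript()
script._on_start()
script._close()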

shub_workflow/utils/sesemail.py (+3 -1)

@@ -11,6 +11,8 @@
 import boto3
 from botocore.client import Config
 
+from shub_workflow.script import BaseScriptProtocol
+
 logger = logging.getLogger(__name__)
@@ -100,7 +102,7 @@ def build_email_message(
     return msg
 
 
-class SESMailSenderMixin:
+class SESMailSenderMixin(BaseScriptProtocol):
     """Use this mixin for enabling ses email sending capabilities on your script class"""
 
     def __init__(self):
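
Basing the mixin on BaseScriptProtocol gives mypy a concrete type for self inside the mixin, so attributes supplied by the host script class resolve instead of being flagged as unknown. A hypothetical reduced example of the same idea:

# Hypothetical sketch: the protocol declares what the mixin expects from
# its host class (here a made-up `name` attribute).
from typing import Protocol


class ScriptProtocol(Protocol):
    name: str


class MailSenderMixin(ScriptProtocol):
    def send_report(self) -> None:
        print(f"report sent by {self.name}")


class MyScript(MailSenderMixin):
    name = "my_script"


MyScript().send_report()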

tests/test_base_manager.py (+2 -2)

@@ -55,7 +55,7 @@ def workflow_loop(self):
         self.assertEqual(manager.name, "my_fantasy_name")
 
         manager._on_start()
-        self.assertFalse(manager._check_resume_workflow.called)
+        self.assertFalse(mocked_check_resume_workflow.called)
 
     @patch("shub_workflow.base.WorkFlowManager._check_resume_workflow")
     def test_check_resume_workflow_called(
@@ -72,7 +72,7 @@ def workflow_loop(self):
         self.assertEqual(manager.name, "my_fantasy_name")
 
         manager._on_start()
-        self.assertTrue(manager._check_resume_workflow.called)
+        self.assertTrue(mocked_check_resume_workflow.called)
 
     def test_project_id_override(self, mocked_update_metadata, mocked_get_job_tags):
         class TestManager(WorkFlowManager):
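
The assertions now use the mock that @patch injects into the test method instead of reading .called from the instance attribute; the injected MagicMock has a .called attribute in its declared type, while the real method does not, so this form is what the type checker can follow. A hypothetical standalone test showing the injected-mock style:

# Hypothetical self-contained test, unrelated to the repo's classes.
from unittest import TestCase, main
from unittest.mock import patch


class Greeter:
    def greet(self) -> str:
        return "hello"


class GreeterTest(TestCase):
    @patch("__main__.Greeter.greet")
    def test_greet_called(self, mocked_greet):
        Greeter().greet()
        self.assertTrue(mocked_greet.called)


if __name__ == "__main__":
    main()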
