diff --git a/.gitignore b/.gitignore index 9f9928d8..07bec449 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,7 @@ pydispatcher.egg-info .coverage* dbconfig.json .vscode/* +!.vscode/settings.json .idea */_version.py */_date.py diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..44e7fd69 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "ruff.lineLength": 88 +} diff --git a/dpdispatcher/machine.py b/dpdispatcher/machine.py index 01a51557..d6dbf785 100644 --- a/dpdispatcher/machine.py +++ b/dpdispatcher/machine.py @@ -82,6 +82,7 @@ def __init__( local_root=None, remote_root=None, remote_profile={}, + retry_count=3, *, context=None, ): @@ -96,6 +97,7 @@ def __init__( else: pass self.bind_context(context=context) + self.retry_count = retry_count def bind_context(self, context): self.context = context @@ -148,7 +150,8 @@ def load_from_dict(cls, machine_dict): base.check_value(machine_dict, strict=False) context = BaseContext.load_from_dict(machine_dict) - machine = machine_class(context=context) + retry_count = machine_dict.get("retry_count", 3) + machine = machine_class(context=context, retry_count=retry_count) return machine def serialize(self, if_empty_remote_profile=False): @@ -161,6 +164,7 @@ def serialize(self, if_empty_remote_profile=False): machine_dict["remote_profile"] = self.context.remote_profile else: machine_dict["remote_profile"] = {} + machine_dict["retry_count"] = self.retry_count # normalize the dict base = self.arginfo() machine_dict = base.normalize_value(machine_dict, trim_pattern="_*") @@ -396,6 +400,7 @@ def arginfo(cls): doc_clean_asynchronously = ( "Clean the remote directory asynchronously after the job finishes." ) + doc_retry_count = "Number of retries to resubmit failed jobs." machine_args = [ Argument("batch_type", str, optional=False, doc=doc_batch_type), @@ -413,6 +418,7 @@ def arginfo(cls): default=False, doc=doc_clean_asynchronously, ), + Argument("retry_count", int, optional=True, default=3, doc=doc_retry_count), ] context_variant = Variant( diff --git a/dpdispatcher/machines/dp_cloud_server.py b/dpdispatcher/machines/dp_cloud_server.py index b4719bfe..001a17fe 100644 --- a/dpdispatcher/machines/dp_cloud_server.py +++ b/dpdispatcher/machines/dp_cloud_server.py @@ -19,7 +19,8 @@ class Bohrium(Machine): alias = ("Lebesgue", "DpCloudServer") - def __init__(self, context): + def __init__(self, context, **kwargs): + super().__init__(context=context, **kwargs) self.context = context self.input_data = context.remote_profile["input_data"].copy() self.api_version = 2 @@ -32,7 +33,6 @@ def __init__(self, context): phone = context.remote_profile.get("phone", None) username = context.remote_profile.get("username", None) password = context.remote_profile.get("password", None) - self.retry_count = context.remote_profile.get("retry_count", 3) self.ignore_exit_code = context.remote_profile.get("ignore_exit_code", True) ticket = os.environ.get("BOHR_TICKET", None) diff --git a/dpdispatcher/machines/openapi.py b/dpdispatcher/machines/openapi.py index e5514dce..64c57c4c 100644 --- a/dpdispatcher/machines/openapi.py +++ b/dpdispatcher/machines/openapi.py @@ -29,7 +29,8 @@ def unzip_file(zip_file, out_dir="./"): class OpenAPI(Machine): - def __init__(self, context): + def __init__(self, context, **kwargs): + super().__init__(context=context, **kwargs) if not found_bohriumsdk: raise ModuleNotFoundError( "bohriumsdk not installed. Install dpdispatcher with `pip install dpdispatcher[bohrium]`" @@ -38,7 +39,6 @@ def __init__(self, context): self.remote_profile = context.remote_profile.copy() self.grouped = self.remote_profile.get("grouped", True) - self.retry_count = self.remote_profile.get("retry_count", 3) self.ignore_exit_code = context.remote_profile.get("ignore_exit_code", True) access_key = ( diff --git a/dpdispatcher/machines/pbs.py b/dpdispatcher/machines/pbs.py index 35ef4c44..9942cb89 100644 --- a/dpdispatcher/machines/pbs.py +++ b/dpdispatcher/machines/pbs.py @@ -17,6 +17,9 @@ class PBS(Machine): + def __init__(self, **kwargs): + super().__init__(**kwargs) + def gen_script(self, job): pbs_script = super().gen_script(job) return pbs_script @@ -188,24 +191,8 @@ def gen_script_header(self, job): class SGE(PBS): - def __init__( - self, - batch_type=None, - context_type=None, - local_root=None, - remote_root=None, - remote_profile={}, - *, - context=None, - ): - super(PBS, self).__init__( - batch_type, - context_type, - local_root, - remote_root, - remote_profile, - context=context, - ) + def __init__(self, **kwargs): + super().__init__(**kwargs) def gen_script_header(self, job): ### Ref:https://softpanorama.org/HPC/PBS_and_derivatives/Reference/pbs_command_vs_sge_commands.shtml diff --git a/tests/test_argcheck.py b/tests/test_argcheck.py index b87f39fc..637c5254 100644 --- a/tests/test_argcheck.py +++ b/tests/test_argcheck.py @@ -27,6 +27,7 @@ def test_machine_argcheck(self): "symlink": True, }, "clean_asynchronously": False, + "retry_count": 3, } self.assertDictEqual(norm_dict, expected_dict)