diff --git a/src/nrtk_explorer/app/core.py b/src/nrtk_explorer/app/core.py index 8f13b3c..da5fdb8 100644 --- a/src/nrtk_explorer/app/core.py +++ b/src/nrtk_explorer/app/core.py @@ -6,6 +6,8 @@ from nrtk_explorer.library.filtering import FilterProtocol from nrtk_explorer.library.dataset import get_dataset, expand_hugging_face_datasets from nrtk_explorer.library.debounce import debounce +from nrtk_explorer.library.app_config import process_config + from nrtk_explorer.app.images.images import Images from nrtk_explorer.app.embeddings import EmbeddingsApp @@ -38,28 +40,34 @@ # Engine class # --------------------------------------------------------- +config_options = { + "dataset": { + "flags": ["--dataset"], + "params": { + "nargs": "+", + "default": DEFAULT_DATASETS, + "help": "Path to the JSON file describing the image dataset", + }, + }, + "download": { + "flags": ["--download"], + "params": { + "action": "store_true", + "default": False, + "help": "Download Hugging Face Hub datasets instead of streaming them", + }, + }, +} + class Engine(Applet): - def __init__(self, server=None): + def __init__(self, server=None, **kwargs): super().__init__(server) - self.server.cli.add_argument( - "--dataset", - nargs="+", - default=DEFAULT_DATASETS, - help="Path to the JSON file describing the image dataset", - ) - - self.server.cli.add_argument( - "--download", - action="store_true", - default=False, - help="Download Hugging Face Hub datasets instead of streaming them", - ) + config = process_config(self.server.cli, config_options, **kwargs) - known_args, _ = self.server.cli.parse_known_args() dataset_identifiers = expand_hugging_face_datasets( - known_args.dataset, not known_args.download + config["dataset"], not config["download"] ) self.input_paths = dataset_identifiers self.state.current_dataset = self.input_paths[0] @@ -67,7 +75,7 @@ def __init__(self, server=None): images = Images(server=self.server) self._transforms_app = TransformsApp( - server=self.server.create_child_server(), images=images + server=self.server.create_child_server(), images=images, **kwargs ) self._embeddings_app = EmbeddingsApp( diff --git a/src/nrtk_explorer/app/transforms.py b/src/nrtk_explorer/app/transforms.py index e16633b..e0aabd5 100644 --- a/src/nrtk_explorer/app/transforms.py +++ b/src/nrtk_explorer/app/transforms.py @@ -11,14 +11,15 @@ import nrtk_explorer.library.nrtk_transforms as nrtk_trans import nrtk_explorer.library.yaml_transforms as nrtk_yaml from nrtk_explorer.library import object_detector -from nrtk_explorer.app.applet import Applet -from nrtk_explorer.app.parameters import ParametersApp -from nrtk_explorer.app.images.image_meta import update_image_meta, dataset_id_to_meta +from nrtk_explorer.library.app_config import process_config from nrtk_explorer.library.scoring import ( compute_score, ) -from nrtk_explorer.app.trame_utils import change_checker, delete_state +from nrtk_explorer.app.applet import Applet +from nrtk_explorer.app.parameters import ParametersApp +from nrtk_explorer.app.images.image_meta import update_image_meta, dataset_id_to_meta +from nrtk_explorer.app.trame_utils import change_checker, delete_state from nrtk_explorer.app.images.image_ids import ( dataset_id_to_image_id, dataset_id_to_transformed_image_id, @@ -86,6 +87,18 @@ def on_change_feature_enabled(self, **kwargs): self.enabled_callback() +config_options = { + "models": { + "flags": ["--models"], + "params": { + "nargs": "+", + "default": INFERENCE_MODELS_DEFAULT, + "help": "Space separated list of inference models", + }, + }, +} + + class TransformsApp(Applet): def __init__( self, @@ -94,18 +107,12 @@ def __init__( ground_truth_annotations=None, original_detection_annotations=None, transformed_detection_annotations=None, + **kwargs, ): super().__init__(server) - self.server.cli.add_argument( - "--models", - nargs="+", - default=INFERENCE_MODELS_DEFAULT, - help="Space separated list of inference models", - ) - - known_args, _ = self.server.cli.parse_known_args() - self.state.inference_models = known_args.models + config = process_config(self.server.cli, config_options, **kwargs) + self.state.inference_models = config["models"] self.state.inference_model = self.state.inference_models[0] self.state.setdefault("image_list_ids", []) self.state.setdefault("dataset_ids", []) diff --git a/src/nrtk_explorer/library/app_config.py b/src/nrtk_explorer/library/app_config.py new file mode 100644 index 0000000..5dcef77 --- /dev/null +++ b/src/nrtk_explorer/library/app_config.py @@ -0,0 +1,12 @@ +def process_config(cli, config_options, **kwargs): + for opt in config_options.values(): + cli.add_argument(*opt["flags"], **opt["params"]) + known_args, _ = cli.parse_known_args() + + config = {} + for name in config_options: + if name in kwargs: + config[name] = kwargs[name] + else: + config[name] = getattr(known_args, name) + return config