diff --git a/acto/engine.py b/acto/engine.py index 5ad5021346..bc78aa851b 100644 --- a/acto/engine.py +++ b/acto/engine.py @@ -9,6 +9,8 @@ import tempfile import threading import time +import threading +import signal from copy import deepcopy from typing import Callable, Optional @@ -216,6 +218,8 @@ class TrialRunner: """Test driver of Acto. One TrialRunner is one worker for Acto, and each TrialRunner is independant. """ + alarms_before_termination = 10 ** 10 # large placeholder value + alarm_decrement_lock = threading.Lock() def __init__( self, @@ -599,7 +603,6 @@ def run_testcases( # field_node.get_testcases().pop() # finish testcase testcase_patches.append((group, testcase_with_path, patch)) - run_result = TrialRunner.run_and_check( runner=runner, checker=checker, @@ -686,6 +689,17 @@ def run_and_check( is_revert=revert, ) run_result.dump(trial_dir=runner.trial_dir) + is_alarm = ( + not run_result.is_invalid_input() + and run_result.oracle_result.is_error() + ) + if is_alarm: + with TrialRunner.alarm_decrement_lock: + TrialRunner.alarms_before_termination -= 1 + if TrialRunner.alarms_before_termination <= 0: + msg = "\nAlarm Limit Reached!" + print(f"\033[1m{msg}\033[0m""") + os._exit(1) return run_result def run_recovery( @@ -765,9 +779,10 @@ def __init__( mount: Optional[list] = None, focus_fields: Optional[list] = None, acto_namespace: int = 0, + time_limit: int = 0, + alarm_limit: int = 0 ) -> None: logger = get_thread_logger(with_prefix=False) - try: with open( operator_config.seed_custom_resource, "r", encoding="utf-8" @@ -857,6 +872,16 @@ def __init__( encoding="utf-8", ) as plan_file: json.dump(self.test_plan, plan_file, cls=ActoEncoder, indent=4) + if time_limit: + print("Triggering timeout after %d seconds" % time_limit) + def signal_handler(signum, frame): + msg = "\nTime limit reached!" + print(f"\033[1m{msg}\033[0m""") + os._exit(1) + signal.signal(signal.SIGALRM, signal_handler) + signal.alarm(time_limit) # Trigger after time_limit seconds + if alarm_limit: + TrialRunner.alarms_before_termination = alarm_limit def __learn(self, context_file, helper_crd, analysis_only=False): logger = get_thread_logger(with_prefix=False) diff --git a/acto/input/input.py b/acto/input/input.py index f7271f28f0..335f13e5f8 100644 --- a/acto/input/input.py +++ b/acto/input/input.py @@ -8,6 +8,7 @@ import threading from functools import reduce from typing import List, Optional, Tuple +from tqdm import tqdm import pydantic import yaml @@ -167,6 +168,8 @@ def __init__( for example_doc in example_docs: self.root_schema.load_examples(example_doc) + self.p_bar_intialized = False + self.p_bar = None self.num_workers = num_workers self.num_cases = num_cases # number of test cases to run at a time @@ -405,7 +408,13 @@ def next_test( Tuple of (new value, if this is a setup) """ logger = get_thread_logger(with_prefix=True) - + if self.p_bar_intialized: + self.p_bar.update(1) + self.p_bar.refresh() + else: + self.p_bar = tqdm(total=self.metadata.number_of_run_test_cases, initial=1, position=0, leave=True) + self.p_bar_intialized = True + logger.info("Progress [%d] cases left", len(self.thread_vars.test_plan)) selected_group: TestGroup = self.thread_vars.test_plan.next_group()