From ec1e508dd497dedb9ff11068c7f0e875b6e36329 Mon Sep 17 00:00:00 2001 From: d-w-moore Date: Fri, 5 Jul 2024 02:17:59 -0400 Subject: [PATCH] change tests after reworking --- irods/manager/data_object_manager.py | 40 +++++++++++++------- irods/test/data_obj_test.py | 56 ++++++++++++++++------------ 2 files changed, 59 insertions(+), 37 deletions(-) diff --git a/irods/manager/data_object_manager.py b/irods/manager/data_object_manager.py index 41caff05..49a97d18 100644 --- a/irods/manager/data_object_manager.py +++ b/irods/manager/data_object_manager.py @@ -27,24 +27,34 @@ logger = logging.getLogger(__name__) -_update_types = [ - #TODO : delete, as I'm pretty sure we no longer need this: - ##(typing.Callable, (lambda _:_)) # bare callables in the updatables list should be left as they are -] - +_update_types = [] _update_fns = weakref.WeakKeyDictionary() def register_update_instance(obj_, updater): # updater _update_fns[obj_] = updater -def register_update_type(type_, transform_): # transform is a factory for the updater func +def register_update_type(type_, factory_): # Create an entry corresponding to a type_ of instance to be allowed among updatables. + # The factory_, when called on an instance of type_, creates the updater for that instance. + """ + Parameters: + type_ : a type of instance to be allowed in the updatables parameter. + factory_ : a function accepting the instance passed in, and yielding an update callable. + If None, then remove the type from the list. + """ global _update_types - _update_types.insert(0, (type_,transform_)) - _update_types = list((k,v) for k,v in collections.OrderedDict(_update_types).items() if v is not None) + + # Delete if already present in list + z = tuple(zip(*_update_types)) + if z and type_ in z[0]: + _update_types.pop(z[0].index(type_)) + # Rewrite the list, in the same order, from all pairs not containing 'None' as second member. + _update_types[:] = list((k,v) for k,v in collections.OrderedDict([(type_,factory_)] + _update_types).items() if v is not None) + def unregister_update_type(type_): register_update_type(type_, None) + def do_progress_updates(updatables, n, logging_function = logger.warning): if not isinstance(updatables, (list,tuple)): @@ -52,20 +62,22 @@ def do_progress_updates(updatables, n, logging_function = logger.warning): for obj in updatables: if callable(obj): - stored_func = obj + update_func = obj else: - stored_func = _update_fns.get(obj) - if not stored_func: + update_func = _update_fns.get(obj) + if not update_func: # search based on type for cl,factory in _update_types: if isinstance(obj,cl): - stored_func = factory(obj) - register_update_instance(obj, stored_func) + update_func = factory(obj) + register_update_instance(obj, update_func) break else: logging_function("Could not derive an update function for: %r",obj) + continue + + if update_func: update_func(n) - if stored_func: stored_func(n) def call___del__if_exists(super_): """ diff --git a/irods/test/data_obj_test.py b/irods/test/data_obj_test.py index 888c1128..32d640ec 100644 --- a/irods/test/data_obj_test.py +++ b/irods/test/data_obj_test.py @@ -2136,42 +2136,50 @@ def update(self,n): def test_pbar_for_parallel_io_3(self): from irods.manager.data_object_manager import (register_update_type, unregister_update_type) - class ProgressBar_wrapper: - def start(self,*x):return self.p.start(*x) - def update(self,*x):return self.p.update(*x) - @property - def currval(self): return self.p.currval - def __init__(self,*args,**kw): - self.p = progressbar.ProgressBar(*args,**kw) - - def adapt_ProgressBar(pbar, total = weakref.WeakKeyDictionary(), # Any mutable parameters thus introduced to the function scope - l = threading.Lock()): # are one-time initialized & memoized for subsequent calls. - total.setdefault(pbar, 0) # Also effectively a one-time initialization + try: + progressbar_class = progressbar.ProgressBar + weakref.ref(progressbar_class(maxval = 10)) + except TypeError: + class ProgressBar_wrapper: + def start(self,*x):return self.p.start(*x) + def update(self,*x):return self.p.update(*x) + @property + def currval(self): return self.p.currval + def __init__(self,*args,**kw): + self.p = progressbar.ProgressBar(*args,**kw) + + progressbar_class = ProgressBar_wrapper + + def adapt_ProgressBar(pbar): + total = [0] + l = threading.Lock() + pbar.start() def _update(n): + # total was made a list-of-single-integer for the benefit of Python2 clients. + # If total were a bare integer we'd declare it by "nonlocal total" in Python3 and it would just work. with l: - total[pbar] += n - pbar.update(total[pbar]) + total[0] += n + pbar.update(total[0]) return _update try: - actualClass = ProgressBar_wrapper # weakref-compatible wrapper for progressbar.ProgressBar - register_update_type(actualClass, adapt_ProgressBar) + register_update_type(progressbar_class, adapt_ProgressBar) LEN = 1024**2*40 content = b'_'*LEN - ProgressBar_ = actualClass(maxval=len(content)) - ProgressBar_.start() - self._run_pbars_for_parallel_io(content, [ProgressBar_]) - self.assertEqual(ProgressBar_.currval, LEN) + pbar = progressbar_class(maxval=len(content)) + self._run_pbars_for_parallel_io(content, [pbar]) + self.assertEqual(pbar.currval, LEN) finally: - unregister_update_type(actualClass) + unregister_update_type(progressbar_class) @unittest.skipIf(tqdm is None, "tqdm is not installed") def test_pbar_for_parallel_io_2(self): from irods.manager.data_object_manager import (register_update_type, unregister_update_type) - def adapt_tqdm(pbar, l = threading.Lock()): + def adapt_tqdm(pbar): + l = threading.Lock() def _update(n): with l: pbar.update(n) @@ -2184,7 +2192,8 @@ def _update(n): content = b'_'*LEN tqdm_1 = tqdm.tqdm(total=len(content)) tqdm_2 = tqdm.tqdm(total=len(content)) - self._run_pbars_for_parallel_io(content, [tqdm_1,tqdm_2.update]) + self._run_pbars_for_parallel_io(content, [tqdm_1,tqdm_2]) # The bare the bound instance method itself, ie tqdm_2.update, + # could be passed in, if we knew it to be thread-safe. self.assertEqual(tqdm_1.n, LEN) self.assertEqual(tqdm_2.n, LEN) @@ -2215,7 +2224,8 @@ def __init__(self, total): def update(self,n): self.i += n - def thread_safe(update_fn, l = threading.Lock()): + def thread_safe(update_fn): + l = threading.Lock() def _update(n): with l: update_fn(n)