Skip to content

Commit

Permalink
change tests after reworking
Browse files Browse the repository at this point in the history
  • Loading branch information
d-w-moore committed Jul 5, 2024
1 parent 64e99d6 commit e17636e
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 37 deletions.
40 changes: 26 additions & 14 deletions irods/manager/data_object_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,45 +27,57 @@

logger = logging.getLogger(__name__)

_update_types = [
#TODO : delete, as I'm pretty sure we no longer need this:
##(typing.Callable, (lambda _:_)) # bare callables in the updatables list should be left as they are
]

_update_types = []
_update_fns = weakref.WeakKeyDictionary()

def register_update_instance(obj_, updater): # updater
_update_fns[obj_] = updater

def register_update_type(type_, transform_): # transform is a factory for the updater func
def register_update_type(type_, factory_): # Create an entry corresponding to a type_ of instance to be allowed among updatables.
# The factory_, when called on an instance of type_, creates the updater for that instance.
"""
Parameters:
type_ : a type of instance to be allowed in the updatables parameter.
factory_ : a function accepting the instance passed in, and yielding an update callable.
If None, then remove the type from the list.
"""
global _update_types
_update_types.insert(0, (type_,transform_))
_update_types = list((k,v) for k,v in collections.OrderedDict(_update_types).items() if v is not None)

# Delete if already present in list
z = tuple(zip(*_update_types))
if z and type_ in z[0]:
_update_types.pop(z[0].index(type_))
# Rewrite the list, in the same order, from all pairs not containing 'None' as second member.
_update_types[:] = list((k,v) for k,v in collections.OrderedDict([(type_,factory_)] + _update_types).items() if v is not None)


def unregister_update_type(type_):
register_update_type(type_, None)


def do_progress_updates(updatables, n, logging_function = logger.warning):

if not isinstance(updatables, (list,tuple)):
updatables = [updatables]

for obj in updatables:
if callable(obj):
stored_func = obj
update_func = obj
else:
stored_func = _update_fns.get(obj)
if not stored_func:
update_func = _update_fns.get(obj)
if not update_func:
# search based on type
for cl,factory in _update_types:
if isinstance(obj,cl):
stored_func = factory(obj)
register_update_instance(obj, stored_func)
update_func = factory(obj)
register_update_instance(obj, update_func)
break
else:
logging_function("Could not derive an update function for: %r",obj)
continue

if update_func: update_func(n)

if stored_func: stored_func(n)

def call___del__if_exists(super_):
"""
Expand Down
56 changes: 33 additions & 23 deletions irods/test/data_obj_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2136,42 +2136,50 @@ def update(self,n):
def test_pbar_for_parallel_io_3(self):
from irods.manager.data_object_manager import (register_update_type, unregister_update_type)

class ProgressBar_wrapper:
def start(self,*x):return self.p.start(*x)
def update(self,*x):return self.p.update(*x)
@property
def currval(self): return self.p.currval
def __init__(self,*args,**kw):
self.p = progressbar.ProgressBar(*args,**kw)

def adapt_ProgressBar(pbar, total = weakref.WeakKeyDictionary(), # Any mutable parameters thus introduced to the function scope
l = threading.Lock()): # are one-time initialized & memoized for subsequent calls.
total.setdefault(pbar, 0) # Also effectively a one-time initialization
try:
progressbar_class = progressbar.ProgressBar
weakref.ref(progressbar_class(maxval = 10))
except TypeError:
class ProgressBar_wrapper:
def start(self,*x):return self.p.start(*x)
def update(self,*x):return self.p.update(*x)
@property
def currval(self): return self.p.currval
def __init__(self,*args,**kw):
self.p = progressbar.ProgressBar(*args,**kw)

progressbar_class = ProgressBar_wrapper

def adapt_ProgressBar(pbar):
total = [0]
l = threading.Lock()
pbar.start()
def _update(n):
# total was made a list-of-single-integer for the benefit of Python2 clients.
# If total were a bare integer we'd declare it by "nonlocal total" in Python3 and it would just work.
with l:
total[pbar] += n
pbar.update(total[pbar])
total[0] += n
pbar.update(total[0])
return _update

try:
actualClass = ProgressBar_wrapper # weakref-compatible wrapper for progressbar.ProgressBar
register_update_type(actualClass, adapt_ProgressBar)
register_update_type(progressbar_class, adapt_ProgressBar)

LEN = 1024**2*40
content = b'_'*LEN
ProgressBar_ = actualClass(maxval=len(content))
ProgressBar_.start()
self._run_pbars_for_parallel_io(content, [ProgressBar_])
self.assertEqual(ProgressBar_.currval, LEN)
pbar = progressbar_class(maxval=len(content))
self._run_pbars_for_parallel_io(content, [pbar])
self.assertEqual(pbar.currval, LEN)

finally:
unregister_update_type(actualClass)
unregister_update_type(progressbar_class)

@unittest.skipIf(tqdm is None, "tqdm is not installed")
def test_pbar_for_parallel_io_2(self):
from irods.manager.data_object_manager import (register_update_type, unregister_update_type)

def adapt_tqdm(pbar, l = threading.Lock()):
def adapt_tqdm(pbar):
l = threading.Lock()
def _update(n):
with l:
pbar.update(n)
Expand All @@ -2184,7 +2192,8 @@ def _update(n):
content = b'_'*LEN
tqdm_1 = tqdm.tqdm(total=len(content))
tqdm_2 = tqdm.tqdm(total=len(content))
self._run_pbars_for_parallel_io(content, [tqdm_1,tqdm_2.update])
self._run_pbars_for_parallel_io(content, [tqdm_1,tqdm_2]) # The bare the bound instance method itself, ie tqdm_2.update,
# could be passed in, if we knew it to be thread-safe.
self.assertEqual(tqdm_1.n, LEN)
self.assertEqual(tqdm_2.n, LEN)

Expand Down Expand Up @@ -2215,7 +2224,8 @@ def __init__(self, total):
def update(self,n):
self.i += n

def thread_safe(update_fn, l = threading.Lock()):
def thread_safe(update_fn):
l = threading.Lock()
def _update(n):
with l:
update_fn(n)
Expand Down

0 comments on commit e17636e

Please sign in to comment.