Add progress bar update capability to data transfers in iRODS #578

Merged (2 commits) on Aug 1, 2024
42 changes: 42 additions & 0 deletions README.md
@@ -232,6 +232,48 @@ will spawn a number of threads in order to optimize performance for
iRODS server versions 4.2.9+ and file sizes larger than a default
threshold value of 32 Megabytes.

Progress bars
-------------

The PRC now has support for progress bars, which function on the basis of an
"update" callback function. In the case of a tqdm progress bar (see
https://github.com/tqdm/tqdm), you can simply pass the `update` method of the
progress bar instance directly to the data object `put` or `get` method:

```python
pbar = tqdm.tqdm(total = data_obj.size)
session.data_objects.get(data_obj.path, file_name, updatables = pbar.update)
```

The `updatables` parameter can be a list or tuple of update-enabling objects and/or bound methods.
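
For instance, a single transfer can drive more than one progress bar at once by passing their
bound `update` methods in a list (a minimal sketch, reusing `data_obj` and `file_name` from the
example above):

```python
pbar_a = tqdm.tqdm(total = data_obj.size)
pbar_b = tqdm.tqdm(total = data_obj.size)
session.data_objects.get(data_obj.path, file_name, updatables = [pbar_a.update, pbar_b.update])
```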

Alternatively, the tqdm progress bar object itself can be passed in, if an adapting
function such as the following is first registered:

```python
import os
import threading
import tqdm
import irods.manager.data_object_manager

def adapt_tqdm(pbar, l = threading.Lock()):
    def _update(n):
        with l:
            pbar.update(n)
    return _update

irods.manager.data_object_manager.register_update_type( tqdm.tqdm, adapt_tqdm )
tqdm_1 = tqdm.tqdm(total = os.path.getsize(file))
tqdm_2 = tqdm.tqdm(total = os.path.getsize(file))
session.data_objects.put( file, logical_path, updatables = [tqdm_1, tqdm_2] ) # update two tqdm's simultaneously
```

Other progress bars may be included in an `updatables` parameter, but may require more extensive
adaptation. For example, the ProgressBar object (from the progressbar module) also has an update
method, but that method takes an up-to-date cumulative byte count, rather than the size of an
individual transfer in bytes, as its sole parameter. There can be other complications: for instance,
a ProgressBar instance does not allow a weak reference to itself to be formed, which interferes
with the Python iRODS Client's internal scheme of accounting for progress bar instances "still in
progress" while also preventing resource leaks.

In such cases, it is probably best to implement a wrapper for the progress bar in question and
submit the wrapper instance as the updatable. Whether a wrapper or the progress-bar object itself
is employed, it is recommended that the user take steps to ensure that the lifetime of the
updatable instance extends beyond the time needed for the transfer to complete.

See `irods/test/data_obj_test.py` for examples of these and other subtleties of progress-bar usage.

Working with collections
------------------------

102 changes: 88 additions & 14 deletions irods/manager/data_object_manager.py
@@ -1,6 +1,12 @@
from __future__ import absolute_import
import os
import ast
import collections
import io
import json
import logging
import os
import six
import weakref
from irods.models import DataObject, Collection
from irods.manager import Manager
from irods.manager._internal import _api_impl, _logical_path
@@ -17,11 +23,72 @@
import irods.keywords as kw
import irods.parallel as parallel
from irods.parallel import deferred_call
import six
import ast
import json
import logging

logger = logging.getLogger(__name__)

_update_types = []
_update_functions = weakref.WeakKeyDictionary()

def register_update_instance(object_, updater):
    """Record the update callable resolved for an individual updatable object."""
    _update_functions[object_] = updater

def register_update_type(type_, factory_):
"""
Create an entry corresponding to a type_ of instance to be allowed among updatables, with processing
based on the factory_ callable.

Parameters:
type_ : a type of instance to be allowed in the updatables parameter.
factory_ : a function accepting the instance passed in, and yielding an update callable.
If None, then remove the type from the list.
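
    Example (illustrative; MyBar stands in for any progress-bar class whose
    instances are to be accepted among the updatables, here assumed to expose
    an increment(n) method):

        register_update_type(MyBar, lambda bar: bar.increment)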
"""

# Delete if already present in list
z = tuple(zip(*_update_types))
if z and type_ in z[0]:
_update_types.pop(z[0].index(type_))
# Rewrite the list
# - with the new item introduced at the start of the list but otherwise in the same order, and
# - preserving only pairs that do not contain 'None' as the second member.
_update_types[:] = list((k,v) for k,v in collections.OrderedDict([(type_,factory_)] + _update_types).items() if v is not None)


def unregister_update_type(type_):
"""
Remove type_ from the list of recognized updatable types maintained by the PRC.
"""
register_update_type(type_, None)


def do_progress_updates(updatables, n, logging_function = logger.warning):
"""
Used internally by the Python iRODS Client's data transfer routines (put, get) to iterate through the updatables to be processed.
This, in turn, should cause the underlying corresponding progress bars or indicators to be updated.
"""
if not isinstance(updatables, (list,tuple)):
updatables = [updatables]

for object_ in updatables:
# If an updatable is directly callable, we set that up to be called without further ado.
if callable(object_):
update_func = object_
else:
# If not, we search for a registered type that matches object_ and register (or look up if previously registered) a factory-produced updater for that instance.
# Examine the unit tests for issue #574 in data_obj_test.py for factory examples.
update_func = _update_functions.get(object_)
if not update_func:
# search based on type
for class_,factory_ in _update_types:
if isinstance(object_,class_):
update_func = factory_(object_)
register_update_instance(object_, update_func)
break
else:
logging_function("Could not derive an update function for: %r",object_)
continue

# Do the update.
if update_func: update_func(n)


def call___del__if_exists(super_):
@@ -124,7 +191,7 @@ def should_parallelize_transfer( self,
open_options[kw.DATA_SIZE_KW] = size


def _download(self, obj, local_path, num_threads, **options):
def _download(self, obj, local_path, num_threads, updatables = (), **options):
"""Transfer the contents of a data object to a local file.

Called from get() when a local path is named.
@@ -145,14 +212,16 @@ def _download(self, obj, local_path, num_threads, **options):
f.close()
if not self.parallel_get( (obj,o), local_file, num_threads = num_threads,
target_resource_name = options.get(kw.RESC_NAME_KW,''),
data_open_returned_values = data_open_returned_values_):
data_open_returned_values = data_open_returned_values_,
updatables = updatables):
raise RuntimeError("parallel get failed")
else:
for chunk in chunks(o, self.READ_BUFFER_SIZE):
f.write(chunk)
do_progress_updates(updatables, len(chunk))


def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS, **options):
def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS, updatables = (), **options):
"""
Get a reference to the data object at the specified `path'.

@@ -163,7 +232,7 @@ def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS,

# TODO: optimize
if local_path:
self._download(path, local_path, num_threads = num_threads, **options)
self._download(path, local_path, num_threads = num_threads, updatables = updatables, **options)

query = self.sess.query(DataObject)\
.filter(DataObject.name == irods_basename(path))\
@@ -180,7 +249,7 @@ def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS,
return iRODSDataObject(self, parent, results)


def put(self, local_path, irods_path, return_data_object = False, num_threads = DEFAULT_NUMBER_OF_THREADS, **options):
def put(self, local_path, irods_path, return_data_object = False, num_threads = DEFAULT_NUMBER_OF_THREADS, updatables = (), **options):

if self.sess.collections.exists(irods_path):
obj = iRODSCollection.normalize_path(irods_path, os.path.basename(local_path))
@@ -195,7 +264,7 @@ def put(self, local_path, irods_path, return_data_object = False, num_threads =
if not self.parallel_put( local_path, (obj,o), total_bytes = sizelist[0], num_threads = num_threads,
target_resource_name = options.get(kw.RESC_NAME_KW,'') or
options.get(kw.DEST_RESC_NAME_KW,''),
open_options = options ):
open_options = options, updatables = updatables):
raise RuntimeError("parallel put failed")
else:
with self.open(obj, 'w', **options) as o:
@@ -204,6 +273,7 @@ def put(self, local_path, irods_path, return_data_object = False, num_threads =
options[kw.OPR_TYPE_KW] = 1 # PUT_OPR
for chunk in chunks(f, self.WRITE_BUFFER_SIZE):
o.write(chunk)
do_progress_updates(updatables, len(chunk))
if kw.ALL_KW in options:
repl_options = options.copy()
repl_options[kw.UPDATE_REPL_KW] = ''
@@ -259,7 +329,8 @@ def parallel_get(self,
num_threads = 0,
target_resource_name = '',
data_open_returned_values = None,
progressQueue = False):
progressQueue = False,
updatables = ()):
"""Call into the irods.parallel library for multi-1247 GET.

Called from a session.data_objects.get(...) (via the _download method) on
@@ -270,7 +341,8 @@
return parallel.io_main( self.sess, data_or_path_, parallel.Oper.GET | (parallel.Oper.NONBLOCKING if async_ else 0), file_,
num_threads = num_threads, target_resource_name = target_resource_name,
data_open_returned_values = data_open_returned_values,
queueLength = (DEFAULT_QUEUE_DEPTH if progressQueue else 0))
queueLength = (DEFAULT_QUEUE_DEPTH if progressQueue else 0),
updatables = updatables)

def parallel_put(self,
file_ ,
@@ -280,6 +352,7 @@ def parallel_put(self,
num_threads = 0,
target_resource_name = '',
open_options = {},
updatables = (),
progressQueue = False):
"""Call into the irods.parallel library for multi-1247 PUT.

@@ -290,7 +363,8 @@ def parallel_put(self,
return parallel.io_main( self.sess, data_or_path_, parallel.Oper.PUT | (parallel.Oper.NONBLOCKING if async_ else 0), file_,
num_threads = num_threads, total_bytes = total_bytes, target_resource_name = target_resource_name,
open_options = open_options,
queueLength = (DEFAULT_QUEUE_DEPTH if progressQueue else 0)
queueLength = (DEFAULT_QUEUE_DEPTH if progressQueue else 0),
updatables = updatables,
)


32 changes: 21 additions & 11 deletions irods/parallel.py
@@ -223,13 +223,14 @@ def _io_send_bytes_progress (queueObject, item):

COPY_BUF_SIZE = (1024 ** 2) * 4

def _copy_part( src, dst, length, queueObject, debug_info, mgr):
def _copy_part( src, dst, length, queueObject, debug_info, mgr, updatables = ()):
"""
The work-horse for performing the copy between file and data object.

It also helps determine whether there has been a large enough increment of
bytes to inform the progress bar of a need to update.
"""
from irods.manager.data_object_manager import do_progress_updates
bytecount = 0
accum = 0
while True and bytecount < length:
@@ -240,6 +241,7 @@ def _copy_part( src, dst, length, queueObject, debug_info, mgr):
bytecount += buf_len
accum += buf_len
if queueObject and accum and _io_send_bytes_progress(queueObject,accum): accum = 0
do_progress_updates(updatables, buf_len)
if verboseConnection:
print ("("+debug_info+")",end='',file=sys.stderr)
sys.stderr.flush()
@@ -301,7 +303,7 @@ def finalize(self):
self.initial_io.close()


def _io_part (objHandle, range_, file_, opr_, mgr_, thread_debug_id = '', queueObject = None ):
def _io_part (objHandle, range_, file_, opr_, mgr_, thread_debug_id = '', queueObject = None, updatables = None):
"""
Runs in a separate thread to manage the transfer of a range of bytes within the data object.

@@ -315,8 +317,8 @@ def _io_part (objHandle, range_, file_, opr_, mgr_, thread_debug_id = '', queueO
file_.seek(offset)
if thread_debug_id == '': # for more succinct thread identifiers while debugging.
thread_debug_id = str(threading.currentThread().ident)
return ( _copy_part (file_, objHandle, length, queueObject, thread_debug_id, mgr_) if Operation.isPut()
else _copy_part (objHandle, file_, length, queueObject, thread_debug_id, mgr_) )
return ( _copy_part (file_, objHandle, length, queueObject, thread_debug_id, mgr_, updatables) if Operation.isPut()
else _copy_part (objHandle, file_, length, queueObject, thread_debug_id, mgr_, updatables) )


def _io_multipart_threaded(operation_ , dataObj_and_IO, replica_token, hier_str, session, fname,
@@ -342,9 +344,9 @@ def bytes_range_for_thread( i, num_threads, total_bytes, chunk ):

logger.info(u"num_threads = %s ; bytes_per_thread = %s", num_threads, bytes_per_thread)

_queueLength = extra_options.get('_queueLength',0)
if _queueLength > 0:
queueObject = Queue(_queueLength)
queueLength = extra_options.get('queueLength',0)
if queueLength > 0:
queueObject = Queue(queueLength)
else:
queueObject = None

Expand All @@ -355,6 +357,11 @@ def bytes_range_for_thread( i, num_threads, total_bytes, chunk ):
counter = 1
gen_file_handle = lambda: open(fname, Operation.disk_file_mode(initial_open = (counter == 1)))
File = gen_file_handle()

thread_opts = { 'updatables' : extra_options.get('updatables',()),
'queueObject' : queueObject
}

for byte_range in ranges:
if Io is None:
Io = session.data_objects.open( Data_object.path, Operation.data_object_mode(initial_open = False),
@@ -366,12 +373,14 @@ def bytes_range_for_thread( i, num_threads, total_bytes, chunk ):
mgr.add_io( Io )
logger.debug(u'target_host = %s', Io.raw.session.pool.account.host)
if File is None: File = gen_file_handle()
futures.append(executor.submit( _io_part, Io, byte_range, File, Operation, mgr, str(counter), queueObject))
futures.append(executor.submit(_io_part, Io, byte_range, File, Operation, mgr,
thread_debug_id = str(counter),
**thread_opts))
counter += 1
Io = File = None

if Operation.isNonBlocking():
if _queueLength:
if queueLength:
return futures, queueObject, mgr
else:
return futures
@@ -395,7 +404,6 @@ def io_main( session, Data, opr_, fname, R='', **kwopt):
Operation = Oper(opr_)
d_path = None
Io = None

if isinstance(Data,tuple):
(Data, Io) = Data[:2]

Expand Down Expand Up @@ -466,9 +474,11 @@ def io_main( session, Data, opr_, fname, R='', **kwopt):
(replica_token , resc_hier) = rawfile.replica_access_info()

queueLength = kwopt.get('queueLength',0)

pass_thru_options = ('updatables','queueLength')
retval = _io_multipart_threaded (Operation, (Data, Io), replica_token, resc_hier, session, fname, total_bytes,
num_threads = num_threads,
_queueLength = queueLength)
**{k:v for k,v in kwopt.items() if k in pass_thru_options})

# SessionObject.data_objects.parallel_{put,get} will return:
# - immediately with an AsyncNotify instance, if Oper.NONBLOCKING flag is used.