From a22e978822ff7229f69b46d6a793aeef23448c8f Mon Sep 17 00:00:00 2001
From: Corvin-Petrut COBARZAN <127338708+ecmwf-cobarzan@users.noreply.github.com>
Date: Fri, 25 Oct 2024 11:09:19 +0300
Subject: [PATCH] Feature/cams parallel cache2 (#233)

* dev/cams
* Steps towards a first rate limiting solution
* Adapt hardcoded mapping
* Small bug fix (for testing code)
* Call the secondary adaptors (via a cdsapi request)
* Send all params to secondary adaptors
* Call retrieve_* from CAMSEuropeAirQualityForecastsAdaptorFor*Data
* Add mock create_result_file in retrieve_xxx
* Log how many req_groups
* Create random file names in retrieve_xxx (check if these are not already generated)
* Create random file names in new_retrieve_subrequest (check if these are not already generated)
* Log how many req_groups[0]['retrieved_files']
* Add process_grib_files
* MockResultFile to str
* MockResultFile to string (at the right time, for now)
* Conversions (netCDF, zip)
* Revert likely previously commented create_result_file
* Debug info
* Debug format
* Debug format
* Add create_temp_file to context
* Rate limiting: OK
* See URLs for local fields
* Test with hardcoded object store URL
* Rewrite temporary_location for object store
* Upload (for temporary files for now)
* Upload test (for both permanent and temporary)
* Fix model in the hardcoded mapping
* Read context.request['mapping'] from the config
* Debug
* Remove hardcoded mapping and use the one in the configuration
* Use an adaptor specific mapping function to apply the direct mapping from the configuration
* Debug
* Send the mapping from the adaptor to new_cams_regional_fc (as it is consumed by a config.pop in a superclass)
* Deal with no result cases under new_retrieve_subrequest
* Return None from CAMSEuropeAirQualityForecastsAdaptorFor*Data:retrieve if no data retrieved
* Intersect constraints
* Add _in_adaptor_no_cache to the list of recognised keys
* Pop _in_adaptor_no_cache after the cache is avoided (to not propagate it to sub-requests)
* Remove _in_adaptor_no_cache from the recognised keys list (likely not needed anymore)
* Debug
* Pop _in_adaptor_no_cache from the requests resulting from intersection
* All components working (polishing to do)
* Switch to using the standard apply_mapping (facilitated by mapping.remap and mapping.rename being populated as expected)
* Add tabulate as a dependency
* Load CDSAPI client credentials from environment
* Remove unused imports, move imports to methods
* Move tabulate from the list of mandatory dependencies to the complete version
* Assume that the dataset specific bucket exists
* Support netcdf_zip in cams_regional_fc (done by Luke)
* Solve small bug in reassign_missing_to_archive
* Bug fix in cams_regional_fc
* Roocs adaptor: set timeout to a generous 30 seconds
* Roocs adaptor: set timeout back to 3 seconds
* Roocs adaptor: read the URL download timeout from the dataset configuration
* Log request IDs of sub-requests done via the CDSAPI
* Cds 265 url area selector too small selection (#194)
* handle too small area selection
* fix (#197)
* mapped_requests
* Remove temporary function
* URL adaptor
* adaptors refactor
* Request type
* Request type
* mapped_requests
* mapped_requests
* debug
* recall normalise_request
* pre mapping mods
* handle formats
* handle embargo exceptions
* handle download_formats
* preserve _pre_retrieve so that adaptors can be updated gradually
* intersect_constraints_bool
* remove cyclical calling of intersect_constraints
* cads-obs to use normalise_requests
* cads-obs to use
normalise_requests * some tidying * TODO update * set_download_format, with nice error * set_download_format, with nice error * download format, remove bug and add assertion check * CamsSolarRad adaptor to use normalise_request * type checking * QA + typing * URL adaptor/tools: add support for "fail_on_timeout_for_any_part" * Short revert * Revert * schema is now called at start of normalise request to ensure that we duplicate all the testing and modifications * Small typo fix (cads_adaptors.tools.download_tools) * CAMS solar radiation adaptor: adjust to request["format"] now being a list * Revert mapped request to scalar values (a specific requirement for this adaptor, at least for now) * Linting * Download sub-request result in a loop (for the cases when it fails to download completely the first time) * URL tools: set default for fail_on_timeout_for_any_part to True * clean update * typo * Multi-threaded caching in CAMS regional adaptor * Multi-threaded caching in CAMS regional adaptor * Multi-threaded caching in CAMS regional adaptor * Auto-linting of cams_regional_fc/cacher.py * Auto-linting of cams_regional_fc/mem_safe_queue.py * Harmonise cads_adaptors/__init__.py, cads_adaptors/adaptors/__init__.py with main * Delete cams_regional_fc/remote_copier.py * Bug fix in cams_regional_fc/cacher.py * Small qa --------- Co-authored-by: Luke Jones Co-authored-by: Eddy Comyn-Platt <53045993+EddyCMWF@users.noreply.github.com> Co-authored-by: EddyCMWF --- .../adaptors/cams_regional_fc/cacher.py | 444 +++++++++--------- .../cams_regional_fc/mem_safe_queue.py | 97 ++++ .../cams_regional_fc/remote_copier.py | 212 --------- pyproject.toml | 1 + 4 files changed, 333 insertions(+), 421 deletions(-) create mode 100644 cads_adaptors/adaptors/cams_regional_fc/mem_safe_queue.py delete mode 100644 cads_adaptors/adaptors/cams_regional_fc/remote_copier.py diff --git a/cads_adaptors/adaptors/cams_regional_fc/cacher.py b/cads_adaptors/adaptors/cams_regional_fc/cacher.py index 1d875ecc..0a3ed9ba 100644 --- a/cads_adaptors/adaptors/cams_regional_fc/cacher.py +++ b/cads_adaptors/adaptors/cams_regional_fc/cacher.py @@ -1,131 +1,36 @@ +import concurrent.futures import io import os import re -import socket -import stat import threading import time -from os.path import dirname -from tempfile import NamedTemporaryFile import boto3 import jinja2 from cds_common.hcube_tools import count_fields, hcube_intdiff, hcubes_intdiff2 from cds_common.message_iterators import grib_bytes_iterator from cds_common.url2.caching import NotInCache -from eccodes import codes_get_message, codes_write +from eccodes import codes_get_message from .grib2request import grib2request +from .mem_safe_queue import MemSafeQueue -class Credentials: - def __init__(self, url, access, key): - self.url = url - self.access = access - self.key = key - - -DESTINATION_BUCKET = "cci2-cams-regional-fc" # "cci2-cams-regional-fc-test" -TRUST_THAT_BUCKET_EXISTS = True - - -def upload( - destination_credentials, destination_bucket, destination_filepath, data_to_transfer -): - client = boto3.client( - "s3", - aws_access_key_id=destination_credentials.access, - aws_secret_access_key=destination_credentials.key, - endpoint_url=destination_credentials.url, - ) - - if not TRUST_THAT_BUCKET_EXISTS: - resource = boto3.resource( - "s3", - aws_access_key_id=destination_credentials.access, - aws_secret_access_key=destination_credentials.key, - endpoint_url=destination_credentials.url, - ) - - _bucket = resource.Bucket(destination_bucket) - if not _bucket.creation_date: - 
_bucket = client.create_bucket(Bucket=destination_bucket) - retry = True - _n = 0 - t0 = time.time() - while retry: - try: - file_object = io.BytesIO(data_to_transfer) - - client.put_object( - Bucket=destination_bucket, - Key=destination_filepath, - Body=file_object.getvalue(), - ) - t1 = time.time() - retry = False - status = "uploaded" - except AssertionError: - status = "interrupted" - t1 = time.time() - break - except Exception as _err: - t1 = time.time() - print(_err) - _n += 1 - if _n >= 5: - retry = False - status = f"process ended in error: {_err}" - return {"status": status, "upload_time": t0 - t1, "upload_size": file_object.tell()} - - -class Cacher: - """Class to look after cache storage and retrieval.""" +class AbstractCacher: + """Abstract class for looking after cache storage and retrieval. This class + defines the interface. + """ def __init__(self, context, no_put=False): self.context = context - self.temp_cache_root = "/tmp/cams-europe-air-quality-forecasts/debug/" - os.makedirs(self.temp_cache_root, exist_ok=True) - self.remote_user = "cds" self.no_put = no_put - self.lock = threading.Lock() # Fields which should be cached permanently (on the datastore). All # other fields will be cached in temporary locations. self.permanent_fields = [{"model": ["ENS"], "level": ["0"]}] - # Get a list of the compute node names - self.compute_nodes = [] - with open("/etc/hosts") as f: - for x in [ - line.split()[1:] for line in f.readlines() if not line.startswith("#") - ]: - if x and x[0].startswith("compute-"): - self.compute_nodes.append(x[0].strip()) - self.compute_nodes = sorted(self.compute_nodes) - self.compute_dns = {n: n for n in self.compute_nodes} - - # For when testing/debugging on local desktop - if os.environ.get("CDS_UNIT_TESTING"): - self.compute_nodes = ["feldenak"] - self.compute_dns = {"feldenak": "feldenak.ecmwf.int:8080"} - self.temp_cache_root = os.environ["SCRATCH"] + "/test_ads_cacher" - self.remote_user = os.environ["USER"] - if not os.path.exists(self.temp_cache_root): - os.makedirs(self.temp_cache_root) - - # Compute node we're running on - self.host = socket.gethostname().split(".")[0] - - self._remote_copier = None - self.templates = {} - - context.debug("CACHER: host is " + self.host) - context.debug("CACHER: compute nodes are " + repr(self.compute_nodes)) - def done(self): - if self._remote_copier is not None: - self._remote_copier.done() + pass def __enter__(self): return self @@ -135,9 +40,6 @@ def __exit__(self, *args): def put(self, req): """Write grib fields from a request into the cache.""" - if self.no_put: - return - # Do not cache sub-area requests when the sub-area is done by the Meteo # France backend. 
With the current code below they would be cached # without the area specification in the path and would get confused for @@ -151,7 +53,12 @@ def put(self, req): data = req["data"].content() assert len(data) > 0 try: + count = 0 for msg in grib_bytes_iterator(data): + count += 1 + if count == 2: + break + # Figure out the request values that correspond to this field req1field = grib2request(msg) @@ -178,66 +85,20 @@ def put(self, req): if "no_cache" in req["req"]: req1field["no_cache"] = req["req"]["no_cache"] - # Write to cache - self._put_msg(msg, req1field) + # Convert the message to pure binary data and write to cache + self._write_field(codes_get_message(msg), req1field) + except Exception: # Temporary code for debugging - from datetime import datetime - from random import randint - - unique_string = datetime.now().strftime("%Y%m%d%H%M%S.") + str( - randint(0, 2**128) - ) - with open( - f"{self.temp_cache_root}/{unique_string}" ".actually_bad.grib", "wb" - ) as f: - f.write(data) + # from random import randint + # from datetime import datetime + # unique_string = datetime.now().strftime('%Y%m%d%H%M%S.') + \ + # str(randint(0,99999)) + # with open(f'{self.temp_cache_root}/{unique_string}' + # '.actually_bad.grib', 'wb') as f: + # f.write(data) raise - def _put_msg(self, msg, req1field): - """Write one grib message into the cache.""" - host, path, _ = self.cache_file_location(req1field) - - # It's easier if the cache host is the current host - if host == self.host: - self.context.debug("CACHER: writing to local file: " + path) - dname = dirname(path) - try: - os.makedirs(dname) - except FileExistsError: - pass - # Write to a temporary file and then atomically rename in case - # another process is trying to write the same file - with NamedTemporaryFile(dir=dname, delete=False) as tmpfile: - codes_write(msg, tmpfile) - os.chmod( - tmpfile.name, - stat.S_IRUSR - | stat.S_IWUSR - | stat.S_IRGRP - | stat.S_IWGRP - | stat.S_IROTH, - ) - os.rename(tmpfile.name, path) - else: - self.context.debug("CACHER: writing to remote file: " + repr((host, path))) - destination_url = os.environ["STORAGE_API_URL"] - destination_access = os.environ["STORAGE_ADMIN"] - destination_key = os.environ["STORAGE_PASSWORD"] - destination_credentials = Credentials( - destination_url, destination_access, destination_key - ) - - destination_bucket = DESTINATION_BUCKET - destination_filepath = path - - upload( - destination_credentials, - destination_bucket, - destination_filepath, - codes_get_message(msg), - ) - def get(self, req): """Get a file from the cache or raise NotInCache if it doesn't exist.""" # This is the method called by the URL2 code to see if the data is in @@ -248,8 +109,18 @@ def get(self, req): # that accesses the cache directly, so now we don't attempt it here. raise NotInCache() - def cache_file_location(self, field): - """Return the host, path and url of the cache file for the given field.""" + def cache_file_url(self, fieldinfo): + """Return the URL of the specified field in the cache.""" + raise Exception("Needs to be overloaded by a child class") + + def _write_field(self, msg, req1field): + """Write a field to the cache.""" + raise Exception("Needs to be overloaded by a child class") + + def _cache_permanently(self, field): + """Return True if this field should be put in the permanent cache, False + otherwise. + """ # Is this a field which should be stored in a permanent location? If # the field contains an area specification then it isn't because only # full-area fields are stored permanently. 
The "no_cache" key is set to @@ -262,33 +133,14 @@ def cache_file_location(self, field): else: permanent = [] - if permanent: - host, path, url = self.permanent_location(field) - else: - host, path, url = self.temporary_location(field) - - return (host, path, url) - - def permanent_location(self, field): - """Return the host, path and url of the permanent cache file for the given field.""" - host = "object-store.os-api.cci2.ecmwf.int" - bucket = DESTINATION_BUCKET - path = "permanent" + "/" + self.cache_field_path(field) - url = "https://" + host + "/" + bucket + "/" + path - - return (host, path, url) + return bool(permanent) - def temporary_location(self, field): - """Return the host, path and url of the temporary cache file for the given field.""" - host = "object-store.os-api.cci2.ecmwf.int" - bucket = DESTINATION_BUCKET - path = "temporary" + "/" + self.cache_field_path(field) - url = "https://" + host + "/" + bucket + "/" + path + def _cache_file_path(self, fieldinfo): + """Return a field-specific path or the given field. Can be used by a + child class to determine server-side cache location. + """ + dir = "permanent" if self._cache_permanently(fieldinfo) else "temporary" - return (host, path, url) - - def cache_field_path(self, field): - """Return the field-specific end part of the path of the cache file for the given field.""" # Set the order we'd like the keys to appear in the filename. Area # keys will be last. order1 = ["model", "type", "variable", "level", "time", "step"] @@ -303,38 +155,212 @@ def key_order(k): return k # Get a jinja2 template for these keys - keys = tuple(sorted(list(field.keys()))) - if keys not in self.templates: + keys = tuple(sorted(list(fieldinfo.keys()))) + if keys not in self._templates: # Form a Jinja2 template string for the cache files. "_backend" not # used; organised by date; area keys put at the end. - path_template = "{{ date }}/" + "_".join( - [ - "{k}={{{{ {k} }}}}".format(k=k) - for k in sorted(keys, key=key_order) - if k not in ["date", "_backend"] - ] + path_template = ( + dir + + "/{{ date }}/" + + "_".join( + [ + "{k}={{{{ {k} }}}}".format(k=k) + for k in sorted(keys, key=key_order) + if k not in ["date", "_backend"] + ] + ) ) - self.templates[keys] = jinja2.Template(path_template) + self._templates[keys] = jinja2.Template(path_template) # Safety check to make sure no dodgy characters end up in the filename regex = r"^[\w.:-]+$" - for k, v in field.items(): + for k, v in fieldinfo.items(): assert re.match(regex, k), "Bad characters in key: " + repr(k) + assert isinstance(v, (str, int)), f"Unexpected type for {k}: {type(v)}" assert re.match(regex, str(v)), ( "Bad characters in value for " + k + ": " + repr(v) ) - path = self.templates[keys].render(field) + return self._templates[keys].render(fieldinfo) + + +class AbstractAsyncCacher(AbstractCacher): + """Augment the AbstractCacher class to add asynchronous cache puts. This + class is still abstract since it does not do the actual data copy. It + can be sub-classed in order to give asynchronous, and optionally also + parallel, functionality to synchronous caching code. + """ + + def __init__( + self, + context, + *args, + nthreads=10, + max_mem=100000000, + tmpdir="/cache/tmp", + **kwargs, + ): + """The number of fields that will be written concurrently to the cache + is determined by nthreads. 
Note that even if nthreads=1 it will still + be the case that the fields will be cached asynchronously, even if + not concurrently, and so a cacher.put() will not hold up the thread + in which it is executed. + Fields will be buffered in memory while waiting to be written until + the memory usage exceeds max_mem bytes, at which point fields will be + temporarily written to disk (in tmpdir) to avoid excessive memory + usage. + """ + super().__init__(context, *args, **kwargs) + self.nthreads = nthreads + self._lock1 = threading.Lock() + self._lock2 = threading.Lock() + self._qclosed = False + self._templates = {} + self._futures = [] + self._start_time = None + self._queue = MemSafeQueue(max_mem, tmpdir, logger=context) + + def _start_copy_threads(self): + """Start the threads that will do the remote copies.""" + exr = concurrent.futures.ThreadPoolExecutor(max_workers=self.nthreads) + self._start_time = time.time() + self._futures = [exr.submit(self._copier) for _ in range(self.nthreads)] + exr.shutdown(wait=False) + + def done(self): + """Must be called once all files copied.""" + if self._futures: + # Close the queue + self._queue.put((b"", None)) + qclose_time = time.time() + + # Wait for each thread to complete and check if any raised an + # exception + for future in self._futures: + exc = future.exception(timeout=60) + if exc is not None: + raise exc from exc + + # Log a summary for performance monitoring + summary = self._queue.stats.copy() + iotime = summary.pop("iotime") + now = time.time() + summary["time_secs"] = { + "elapsed": now - self._start_time, + "drain": now - qclose_time, + "io": iotime, + } + self.context.info(f"MemSafeQueue summary: {summary!r}") - return path + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.done() + + def _write_field(self, data, fieldinfo): + """Asynchronously copy the bytes data to the specified file on + specified host. + """ + # Start the copying thread if not done already + with self._lock1: + if not self._futures: + self._start_copy_threads() + + self._queue.put((data, fieldinfo)) + + def _copier(self): + """Thread to actually copy the data.""" + while True: + # This lock is required so that only 1 thread marks the queue as + # closed + with self._lock2: + if self._qclosed: + break + data, fieldinfo = self._queue.get() + self._qclosed = fieldinfo is None + if self._qclosed: + break + self._write_field_sync(data, fieldinfo) + + n = self._queue.qsize() + if n > 0: + raise Exception(f"{n} unconsumed items in queue") + + +class CacherS3(AbstractAsyncCacher): + """Class to look after cache storage to, and retrieval from, an S3 + bucket. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._host = "object-store.os-api.cci2.ecmwf.int" + self._bucket = "cci2-cams-regional-fc" + self._credentials = dict( + endpoint_url=os.environ["STORAGE_API_URL"], + aws_access_key_id=os.environ["STORAGE_ADMIN"], + aws_secret_access_key=os.environ["STORAGE_PASSWORD"], + ) + self.client = boto3.client("s3", **self._credentials) + + def _write_field_sync(self, data, fieldinfo): + """Write the data described by fieldinfo to the appropriate cache + location. 
+ """ + local_object = io.BytesIO(data) + remote_path = self._cache_file_path(fieldinfo) + + self.context.debug( + f"CACHER: copying data to " f"{self._host}:{self._bucket}:{remote_path}" + ) + + # Uncomment this code if it can't be trusted that the bucket already + # exists + # resource = boto3.resource('s3', **self._credentials) + # bkt = resource.Bucket(self._bucket) + # if not bkt.creation_date: + # bkt = self.client.create_bucket(Bucket=self._bucket) + + attempt = 0 + t0 = time.time() + while True: + attempt += 1 + try: + if not self.no_put: + self.client.put_object( + Bucket=self._bucket, + Key=remote_path, + Body=local_object.getvalue(), + ) + status = "uploaded" + break + except Exception as exc: + self.context.error( + "Failed to upload to S3 bucket (attempt " f"#{attempt}): {exc!r}" + ) + status = f"process ended in error: {exc!r}" + if attempt >= 5: + break + t1 = time.time() + + return { + "status": status, + "upload_time": t0 - t1, + "upload_size": local_object.tell(), + } + + def cache_file_url(self, fieldinfo): + """Return the URL of the specified field in the cache.""" + return f"https://{self._host}/{self._bucket}/" + self._cache_file_path( + fieldinfo + ) - def cache_file_url(self, field): - _, _, url = self.cache_file_location(field) - return url + def delete(self, fieldinfo): + """Only used for testing at the time of writing.""" + remote_path = self._cache_file_path(fieldinfo) + self.client.delete_object(Bucket=self._bucket, Key=remote_path) -def default_to(value, default): - if value is None: - return default - else: - return value +Cacher = CacherS3 diff --git a/cads_adaptors/adaptors/cams_regional_fc/mem_safe_queue.py b/cads_adaptors/adaptors/cams_regional_fc/mem_safe_queue.py new file mode 100644 index 00000000..e97a8a56 --- /dev/null +++ b/cads_adaptors/adaptors/cams_regional_fc/mem_safe_queue.py @@ -0,0 +1,97 @@ +import logging +import os +import queue +import threading +import time +from tempfile import NamedTemporaryFile + + +class MemSafeQueue(queue.Queue): + """Subclass of Queue that holds queued items in memory until the queue size + hits a limit and then starts temporarily storing them on file instead. It + means the queue memory usage will not grow out of control. + """ + + def __init__( + self, nbytes_max, tmpdir, *args, logger=logging.getLogger(__name__), **kwargs + ): + super().__init__(*args, **kwargs) + self.nbytes_max = nbytes_max + self.nbytes = 0 + self.tmpdir = tmpdir + self.logger = logger + self._lock = threading.Lock() + + self.stats = {} + for k1 in ["queue", "mem", "file"]: + self.stats[k1] = {} + for k2 in ["current", "total", "max"]: + self.stats[k1][k2] = 0 + self.stats["iotime"] = 0.0 + + def put(self, item, **kwargs): + """Put an item in the queue.""" + data, fieldinfo = item + self.stats["queue"]["total"] += 1 + self.stats["queue"]["max"] = max(self.stats["queue"]["max"], self.qsize()) + + self.logger.debug( + f'MemSafeQueue: Queue nbytes={self.nbytes}, ' + f'in-mem size={self.stats["mem"]["current"]}, ' + f'total size={self.qsize()}' + ) + + # Keep the item in memory or write to file and replace with the path? 
+ self._lock.acquire() + if self.nbytes + len(data) <= self.nbytes_max: + self.nbytes += len(data) + self.stats["mem"]["total"] += 1 + self.stats["mem"]["current"] += 1 + self.stats["mem"]["max"] = max( + self.stats["mem"]["current"], self.stats["mem"]["max"] + ) + self._lock.release() + else: + self.stats["file"]["total"] += 1 + self.stats["file"]["current"] += 1 + self.stats["file"]["max"] = max( + self.stats["file"]["current"], self.stats["file"]["max"] + ) + self._lock.release() + self.logger.debug(f"MemSafeQueue: storing on disk: {fieldinfo!r}") + t = time.time() + with NamedTemporaryFile(dir=self.tmpdir, delete=False) as tmp: + tmp.write(data) + self.stats["iotime"] += time.time() - t + item = (tmp.name, fieldinfo) + + super().put(item, **kwargs) + + def put_nowait(self, item, **kwargs): + self.put(item, block=False) + + def get(self, **kwargs): + xx, fieldinfo = super().get(**kwargs) + + # Received data or a temporary file path? + if isinstance(xx, bytes): + data = xx + self.nbytes -= len(data) + self.stats["mem"]["current"] -= 1 + self.logger.debug( + f'MemSafeQueue: Queue nbytes={self.nbytes}, ' + f'in-mem size={self.stats["mem"]["current"]}, ' + f'total size={self.qsize()}' + ) + else: + self.stats["file"]["current"] -= 1 + t = time.time() + with open(xx, "rb") as tmp: + data = tmp.read() + os.remove(xx) + self.stats["iotime"] += time.time() - t + + return (data, fieldinfo) + + def get_nowait(self): + return self.get(block=False) diff --git a/cads_adaptors/adaptors/cams_regional_fc/remote_copier.py b/cads_adaptors/adaptors/cams_regional_fc/remote_copier.py deleted file mode 100644 index 2e20a886..00000000 --- a/cads_adaptors/adaptors/cams_regional_fc/remote_copier.py +++ /dev/null @@ -1,212 +0,0 @@ -import concurrent.futures -import io -import logging -import queue -import re -import socket -import subprocess -import tarfile -import threading -import time -from datetime import datetime -from os import environ, remove -from random import randint - - -class RemoteCopier: - """Class to allow fast asynchronous copying of lots of small files to remote - hosts from threaded applications. Transfers the files sequentially but - re-uses the same ssh connection for each. Useful when, if an scp or rsync - were used for each file, the total time spent establishing connections - would greater than any saving made by transferring in parallel. - """ - - def __init__(self, logger=logging.getLogger(__name__)): - self.executor = concurrent.futures.ThreadPoolExecutor(1) - self.queue = queue.Queue() - self.lock = threading.Lock() - self._local_host = socket.gethostname().split(".")[0] - self._dirs = set() - self._logger = logger - rand = randint(0, 10000000) - self._unique_string = datetime.now().strftime("%s%f") + "." + str(rand) - self._ssh_opts = [ - "-o", - "BatchMode=yes", - "-o", - "StrictHostKeyChecking=no", - "-o", - "ConnectTimeout=2", - "-o", - "ControlPath=" + "~/.ssh/master-%C." 
+ str(rand), - "-o", - "ControlMaster=auto", - "-o", - "ControlPersist=60", - "-o", - "LogLevel=ERROR", - ] - self._tf = {} - self._tf_count = {} - - # Temporary locations for tar file storage - self._temp_dirs = { - "local": "/cache/tmp", - "compute-.*": "/cache/tmp", - "datastore": "/scratch/tmp", - } - # For when testing/debugging on local desktop - if environ.get("CDS_UNIT_TESTING"): - self._temp_dirs["local"] = "/var/tmp/nal/DATA" - self._temp_dirs["feldenak.ecmwf.int:8080"] = environ["SCRATCH"] - - # Start the thread that will copy the tar files - self.future = self.executor.submit(self._copier) - self.executor.shutdown(wait=False) - - def done(self): - """Must be called once all files copied.""" - self.queue.put(None) - exc = self.future.exception() - if exc is not None: - raise exc from exc - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.done() - - def copy(self, data, user, host, remote_path): - """Asynchronously copy the bytes data to the specified file on specified host.""" - with self.lock: - # Make an object to hold tar member metadata - ti = tarfile.TarInfo(remote_path) - ti.mtime = time.time() - ti.size = len(data) - - # Write the data to the tar file - tf = self._get_tar_file(user, host) - with io.BytesIO(data) as f: - tf["fileobj"].addfile(ti, fileobj=f) - tf["size"] += len(data) - tf["nfiles"] += 1 - - # Copy to host and untar if above a certain size - if tf["size"] > 100000000: - uhost = user + "@" + host - self.queue.put(uhost) - - def _get_tar_file(self, user, host): - """Return a dict representing the tar file to write to.""" - uhost = user + "@" + host - - # Need to open a new tar file for this destination user and host? - if uhost not in self._tf: - # Get the remote temp directory for this host - for regex, tmpdir in self._temp_dirs.items(): - if re.match(regex, host): - remote_tmpdir = tmpdir - break - else: - raise Exception("Do not know tmpdir for " + host) - - # A new tar file for this uhost may be created while the previous - # one is still being copied. Use a count to prevent filename - # clashes - self._tf_count[uhost] = self._tf_count.get(uhost, 0) + 1 - name = "regional_fc.{}.{}.{}".format( - uhost, self._unique_string, self._tf_count[uhost] - ) - tf = { - "path": self._temp_dirs["local"] + f"/{name}.l.tar", - "remote_path": remote_tmpdir + f"/{name}.r.tar", - "size": 0, - "nfiles": 0, - } - tf["fileobj"] = tarfile.open(tf["path"], "w") - self._tf[uhost] = tf - else: - tf = self._tf[uhost] - - return tf - - def _copier(self): - """Thread to copy tar files to remote hosts and untar.""" - DO_COPY = False - # Copy any tar files that exceed the max size - while True: - uhost = self.queue.get() - if uhost is None: - break - if DO_COPY: - self._copy_tar(uhost) - # self._logger.debug('REMCOP Expecting no more copy calls. 
Remaining ' - # 'files are ' + - # ', '.join(v['path'] for v in self._tf.values())) - - # Copy any remaining tar files - for uhost in list(self._tf.keys()): - if DO_COPY: - self._copy_tar(uhost) - - def _copy_tar(self, uhost): - """Copy tar file to remote host and untar.""" - with self.lock: - tf = self._tf.pop(uhost, None) - if tf is None: - # The file has already been copied - return - tf["fileobj"].close() - self._logger.info( - f'Copying tar file containing {tf["nfiles"]}' f' members to {uhost}' - ) - self._exec( - ["scp"] + self._ssh_opts + [tf["path"], f'{uhost}:{tf["remote_path"]}'] - ) - self._ssh(uhost, ["tar", "xPf", tf["remote_path"]]) - self._ssh(uhost, ["rm", tf["remote_path"]]) - remove(tf["path"]) - - def _ssh(self, host, cmd, **kwargs): - """Execute an ssh command in a way that re-uses an existing connection, - if available. - """ - self._exec(["ssh"] + self._ssh_opts + [host] + cmd, **kwargs) - - def _exec(self, cmd, **kwargs): - # self._logger.info('Running command: ' + ' '.join(cmd)) - proc = subprocess.run( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - encoding="utf-8", - **kwargs, - ) - if proc.stdout: - self._logger.info("stdout from " + repr(cmd) + ": " + str(proc.stdout)) - if proc.stderr: - self._logger.warning("stderr from " + repr(cmd) + ": " + str(proc.stderr)) - if proc.returncode != 0: - self._logger.error("Command failed: " + " ".join(cmd)) - - -if __name__ == "__main__": - logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - t0 = time.time() - with RemoteCopier() as r: - for i in range(100): - file = "foo." + str(i) - print("Writing " + file) - with open(file, "w") as fw: - fw.write("This is file " + file + "\n") - fw.write("0" * 800000) - with open(file, "rb") as fr: - r.copy(fr.read(), "cds", "compute-0001", "/cache/downloads/" + file) - print("Waiting for copying to finish") - telapsed = time.time() - t0 - print("Copying took " + str(telapsed) + "s") diff --git a/pyproject.toml b/pyproject.toml index 93b91e9b..f508f229 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ readme = "README.md" complete = [ # Additional dependencies required by the worker image "aiohttp", + "boto3", "cacholote", "cads-mars-server@git+https://github.com/ecmwf-projects/cads-mars-server.git", "cdsapi",
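
Note for reviewers: the sketch below is not part of the patch. It is a minimal, self-contained illustration of how the new MemSafeQueue added in cads_adaptors/adaptors/cams_regional_fc/mem_safe_queue.py is expected to behave, assuming the package is importable; the 64-byte threshold and the fieldinfo dicts are invented for the example.

    import tempfile

    from cads_adaptors.adaptors.cams_regional_fc.mem_safe_queue import MemSafeQueue

    tmpdir = tempfile.mkdtemp()
    q = MemSafeQueue(64, tmpdir)  # keep at most ~64 buffered bytes in memory

    # The first item fits in memory; the second exceeds the limit, so put()
    # writes its payload to a NamedTemporaryFile in tmpdir and queues the path.
    q.put((b"x" * 10, {"model": "ENS", "level": "0"}))
    q.put((b"y" * 1000, {"model": "ENS", "level": "50"}))

    # get() is transparent: spilled items are read back from disk and the
    # temporary file is deleted, so consumers always see (bytes, fieldinfo).
    while q.qsize():
        data, fieldinfo = q.get()
        print(len(data), fieldinfo)
    print(q.stats["mem"]["total"], "kept in memory,",
          q.stats["file"]["total"], "spilled to disk")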
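A second sketch, also not part of the patch, shows the producer/consumer shape that AbstractAsyncCacher uses: a thread pool drains a queue, a single (b"", None) sentinel closes it, and a lock guarantees that only one thread consumes the sentinel. It is reduced to a plain queue.Queue so it runs standalone; the thread count, payloads and print call are invented stand-ins for the real _write_field()/_write_field_sync() methods.

    import concurrent.futures
    import queue
    import threading

    NTHREADS = 4
    q = queue.Queue()
    lock = threading.Lock()
    closed = False

    def copier():
        # Mirrors AbstractAsyncCacher._copier: the lock ensures only one
        # thread consumes the sentinel and marks the queue as closed.
        global closed
        while True:
            with lock:
                if closed:
                    return
                data, fieldinfo = q.get()
                closed = fieldinfo is None
                if closed:
                    return
            # Stand-in for _write_field_sync(): upload happens outside the lock
            print(f"would upload {len(data)} bytes for {fieldinfo}")

    with concurrent.futures.ThreadPoolExecutor(max_workers=NTHREADS) as exr:
        futures = [exr.submit(copier) for _ in range(NTHREADS)]
        for step in range(10):
            q.put((b"GRIB" * 100, {"step": str(step)}))  # what _write_field() enqueues
        q.put((b"", None))  # sentinel, as in done(): no more fields to cache
        for future in futures:
            exc = future.exception()
            if exc is not None:
                raise exc

Because q.get() happens inside the lock, at most one worker can be blocked waiting for a new item at any time, so a single sentinel is enough to shut down all workers; this is the role _lock2 plays in the patch.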