From d56ccf782a48f09c408517d98acc50451d84b26a Mon Sep 17 00:00:00 2001 From: GPH Date: Tue, 24 Oct 2023 09:55:00 +0800 Subject: [PATCH] Update notebook_helpers.py Fix the download_models function to download models in the correct directories and names. --- notebook_helpers.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/notebook_helpers.py b/notebook_helpers.py index 5d0ebd7e..1d1cf5f8 100644 --- a/notebook_helpers.py +++ b/notebook_helpers.py @@ -23,14 +23,14 @@ def download_models(mode): url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1' url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1' - path_conf = 'logs/diffusion/superresolution_bsr/configs/project.yaml' - path_ckpt = 'logs/diffusion/superresolution_bsr/checkpoints/last.ckpt' + path_conf, name_conf = os.path.split('logs/diffusion/superresolution_bsr/configs/project.yaml') + path_ckpt, name_ckpt = os.path.split('logs/diffusion/superresolution_bsr/checkpoints/last.ckpt') - download_url(url_conf, path_conf) - download_url(url_ckpt, path_ckpt) + download_url(url_conf, path_conf, filename=name_conf) + download_url(url_ckpt, path_ckpt, filename=name_ckpt) - path_conf = path_conf + '/?dl=1' # fix it - path_ckpt = path_ckpt + '/?dl=1' # fix it + path_conf = os.path.join(path_conf, name_conf) + path_ckpt = os.path.join(path_ckpt, name_ckpt) return path_conf, path_ckpt else: @@ -267,4 +267,4 @@ def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, e log["sample"] = x_sample log["time"] = t1 - t0 - return log \ No newline at end of file + return log