diff --git a/src/refiners/conversion/utils.py b/src/refiners/conversion/utils.py index 86352e37..dcc581e0 100644 --- a/src/refiners/conversion/utils.py +++ b/src/refiners/conversion/utils.py @@ -26,9 +26,10 @@ def download_file_url(url: str, destination: Path) -> None: logging.debug(f"Downloading {url} to {destination}") # get the size of the file - response = requests.get(url, stream=True) + response = requests.get(url=url, stream=True) response.raise_for_status() total = int(response.headers.get("content-length", 0)) + chunk_size = 1024 * 1000 # 1,024,000 bytes (~1 MB) # create a progress bar bar = tqdm( @@ -45,7 +46,7 @@ def download_file_url(url: str, destination: Path) -> None: with destination.open("wb") as f: with requests.get(url, stream=True) as r: r.raise_for_status() - for chunk in r.iter_content(chunk_size=1024 * 1000): + for chunk in r.iter_content(chunk_size=chunk_size): size = f.write(chunk) bar.update(size) bar.close() @@ -63,8 +64,8 @@ def __init__( self, repo_id: str, filename: str, - expected_sha256: str, revision: str = "main", + expected_sha256: str | None = None, download_url: str | None = None, ) -> None: """Initialize the HubPath. @@ -73,14 +74,14 @@ def __init__( repo_id: The repository identifier on the hub. filename: The filename of the file in the repository. revision: The revision of the file on the hf hub. - expected_sha256: The sha256 hash of the file. + expected_sha256: The expected sha256 hash of the file, used to check the local or remote hash (optional, but strongly recommended). download_url: The url to download the file from, if not from the huggingface hub. 
""" self.repo_id = repo_id self.filename = filename self.revision = revision - self.expected_sha256 = expected_sha256.lower() - self.override_download_url = download_url + self.expected_sha256 = expected_sha256.lower() if expected_sha256 is not None else None + self.download_url = download_url @staticmethod def hub_location(): @@ -90,16 +91,22 @@ def hub_location(): @property def hf_url(self) -> str: """Return the url to the file on the hf hub.""" - assert self.override_download_url is None, f"{self.repo_id}/{self.filename} is not available on the hub" + assert self.download_url is None, f"{self.repo_id}/{self.filename} is not available on the hub" return hf_hub_url( repo_id=self.repo_id, filename=self.filename, revision=self.revision, ) + @property + def hf_metadata(self) -> HfFileMetadata: + """Return the metadata of the file on the hf hub.""" + return get_hf_file_metadata(self.hf_url) + @property def hf_cache_path(self) -> Path: """Download the file from the hf hub and return its path in the local hf cache.""" + assert self.download_url is None, f"{self.repo_id}/{self.filename} is not available on the hub" return Path( hf_hub_download( repo_id=self.repo_id, @@ -108,11 +115,6 @@ def hf_cache_path(self) -> Path: ), ) - @property - def hf_metadata(self) -> HfFileMetadata: - """Return the metadata of the file on the hf hub.""" - return get_hf_file_metadata(self.hf_url) - @property def hf_sha256_hash(self) -> str: """Return the sha256 hash of the file on the hf hub.""" @@ -127,7 +129,7 @@ def local_path(self) -> Path: return self.hub_location() / self.repo_id / self.filename @property - def local_hash(self) -> str: + def local_sha256_hash(self) -> str: """Return the sha256 hash of the file in the local hub.""" assert self.local_path.is_file(), f"{self.local_path} does not exist" # TODO: use https://docs.python.org/3/library/hashlib.html#hashlib.file_digest when support python >= 3.11 @@ -135,16 +137,24 @@ def local_hash(self) -> str: def check_local_hash(self) -> 
bool: """Check if the sha256 hash of the file in the local hub is correct.""" - if self.expected_sha256 != self.local_hash: - logging.warning(f"{self.local_path} local sha256 mismatch, {self.local_hash} != {self.expected_sha256}") + if self.expected_sha256 is None: + logging.warning(f"{self.repo_id}/{self.filename} has no expected sha256 hash, skipping check") + return True + elif self.expected_sha256 != self.local_sha256_hash: + logging.warning( + f"{self.local_path} local sha256 mismatch, {self.local_sha256_hash} != {self.expected_sha256}" + ) return False else: - logging.debug(f"{self.local_path} local sha256 is correct ({self.local_hash})") + logging.debug(f"{self.local_path} local sha256 is correct ({self.local_sha256_hash})") return True def check_remote_hash(self) -> bool: """Check if the sha256 hash of the file on the hf hub is correct.""" - if self.expected_sha256 != self.hf_sha256_hash: + if self.expected_sha256 is None: + logging.warning(f"{self.repo_id}/{self.filename} has no expected sha256 hash, skipping check") + return True + elif self.expected_sha256 != self.hf_sha256_hash: logging.warning( f"{self.local_path} remote sha256 mismatch, {self.hf_sha256_hash} != {self.expected_sha256}" ) @@ -154,14 +164,14 @@ def check_remote_hash(self) -> bool: return True def download(self) -> None: - """Download the file from the hf hub or from the override download url.""" - self.local_path.parent.mkdir(parents=True, exist_ok=True) + """Download the file from the hf hub or from the override download url, and save it to the local hub.""" if self.local_path.is_file(): logging.warning(f"{self.local_path} already exists") - elif self.override_download_url is not None: - download_file_url(url=self.override_download_url, destination=self.local_path) + elif self.download_url is not None: + self.local_path.parent.mkdir(parents=True, exist_ok=True) + download_file_url(url=self.download_url, destination=self.local_path) else: - # TODO: pas assez de message de log quand 
local_path existe pas et que ça vient du hf cache + self.local_path.parent.mkdir(parents=True, exist_ok=True) self.local_path.symlink_to(self.hf_cache_path) assert self.check_local_hash() diff --git a/tests/weight_paths.py b/tests/weight_paths.py index 88cd12e8..3f0bd17f 100644 --- a/tests/weight_paths.py +++ b/tests/weight_paths.py @@ -31,7 +31,7 @@ def get_path(hub: Hub, use_local_weights: bool) -> Path: if use_local_weights: path = hub.local_path else: - if hub.override_download_url is not None: + if hub.download_url is not None: pytest.skip(f"{hub.filename} is not available on Hugging Face Hub") try: