Skip to content

Commit

Permalink
[Fix] Fetching the Git-LFS tokenizer files (#1954)
Browse files Browse the repository at this point in the history
Prior to this PR, when running commands like
```shell
python3 -m mlc_chat chat HF://mlc-ai/gemma-7b-it-q4f16_2-MLC
```
only the binary weight files are downloaded, among all the Git LFS
files.

For models like Gemma whose tokenizer is large and also in Git LFS
file, the tokenizer files are not effectively downloaded automatically.
For example, the cloned Gemma `tokenizer.json` file has content
```
version https://git-lfs.github.com/spec/v1
oid sha256:05e97791a5e007260de1db7e1692e53150e08cea481e2bf25435553380c147ee
size 17477929
```
and this content is never realized to the actual tokenizer. This will
lead to the issue of #1913.

This PR fixes the issue by pulling all the Git LFS files that are not
binary files.
  • Loading branch information
MasterJH5574 authored Mar 14, 2024
1 parent 8d192ef commit c0b2ccd
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions python/mlc_llm/support/download.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Common utilities for downloading files from HuggingFace or other URLs online."""

import concurrent.futures as cf
import hashlib
import json
Expand All @@ -7,7 +8,7 @@
import subprocess
import tempfile
from pathlib import Path
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import requests # pylint: disable=import-error

Expand Down Expand Up @@ -56,7 +57,7 @@ def git_clone(url: str, destination: Path, ignore_lfs: bool) -> None:
) from error


def git_lfs_pull(repo_dir: Path) -> None:
def git_lfs_pull(repo_dir: Path, ignore_extensions: Optional[List[str]] = None) -> None:
"""Pull files with Git LFS."""
filenames = (
subprocess.check_output(
Expand All @@ -66,6 +67,12 @@ def git_lfs_pull(repo_dir: Path) -> None:
.decode("utf-8")
.splitlines()
)
if ignore_extensions is not None:
filenames = [
filename
for filename in filenames
if not any(filename.endswith(extension) for extension in ignore_extensions)
]
logger.info("[Git LFS] Downloading %d files with Git LFS: %s", len(filenames), filenames)
with tqdm.redirect():
for file in tqdm.tqdm(filenames):
Expand Down Expand Up @@ -127,6 +134,7 @@ def download_mlc_weights( # pylint: disable=too-many-locals
tmp_dir = Path(tmp_dir_prefix) / "tmp"
git_url = git_url_template.format(user=user, repo=repo)
git_clone(git_url, tmp_dir, ignore_lfs=True)
git_lfs_pull(tmp_dir, ignore_extensions=[".bin"])
shutil.rmtree(tmp_dir / ".git", ignore_errors=True)
with (tmp_dir / "ndarray-cache.json").open(encoding="utf-8") as in_file:
param_metadata = json.load(in_file)["records"]
Expand Down

0 comments on commit c0b2ccd

Please sign in to comment.