Merge pull request #672 from kierandidi/main
MMSeqsGPU support
martin-steinegger authored Jan 3, 2025
2 parents 00de5b4 + 24e79a5 commit acc0bf7
Showing 3 changed files with 121 additions and 31 deletions.
35 changes: 35 additions & 0 deletions README.md
@@ -108,6 +108,41 @@ In some cases using precomputed database can still be useful. For the following

If no index was created (`MMSEQS_NO_INDEX=1` was set), then `--db-load-mode` does not do anything and can be ignored.

### Generating MSAs on the GPU

Recently, [GPU-accelerated search for MMseqs2](https://www.biorxiv.org/content/10.1101/2024.11.13.623350v1) was introduced and is now supported in ColabFold. To use it, you will need to adjust the database setup and how you run `colabfold_search`.

#### GPU database setup

To set up the GPU databases, run the `setup_databases.sh` command with `GPU=1`:

```shell
GPU=1 ./setup_databases.sh /path/to/db_folder
```

This downloads and sets up the GPU databases in the specified folder. Note that we do not pass `MMSEQS_NO_INDEX=1` here: the indices are kept in GPU memory during the search, so they are worth building for the GPU setup.
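If your GPU nodes lack internet access, the updated `setup_databases.sh` also honors a `DOWNLOADS_ONLY` environment variable (added in this commit), so you can split downloading from GPU database construction. A minimal sketch, with `/path/to/db_folder` as a placeholder:

```shell
# Sketch: download archives first (e.g. on a login node with internet
# access), then build the GPU databases (e.g. on a GPU node).
DB_FOLDER=/path/to/db_folder
if [ -x ./setup_databases.sh ]; then
    DOWNLOADS_ONLY=1 ./setup_databases.sh "$DB_FOLDER"  # fetch archives only
    GPU=1 ./setup_databases.sh "$DB_FOLDER"             # build GPU databases
fi
echo "database folder: $DB_FOLDER"
```

The guard on `./setup_databases.sh` keeps the sketch safe to run outside a ColabFold checkout.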

#### GPU search with `colabfold_search`

To run the MSA search on the GPU, it is recommended (though not required) to start a GPU server before running the search; the server keeps the database indices resident in GPU memory and accelerates the search. To start a GPU server, run:

```shell
mmseqs gpuserver /path/to/db_folder/colabfold_envdb_202108_db --max-seqs 10000 --db-load-mode 0 --prefilter-mode 1 &
PID1=$!
mmseqs gpuserver /path/to/db_folder/uniref30_2302 --max-seqs 10000 --db-load-mode 0 --prefilter-mode 1 &
PID2=$!
```

By default, the server uses all available GPUs and splits the database evenly across them. To restrict the number of GPUs used, set the environment variable `CUDA_VISIBLE_DEVICES` to a specific GPU or set of GPUs, e.g., `CUDA_VISIBLE_DEVICES=0,1`. You can control how many sequences are loaded onto the GPU with the `--max-seqs` option. If the database is larger than the available GPU memory, the server swaps the required data in and out of GPU memory, overlapping data transfer with computation. The server runs in the background until you stop it explicitly, e.g., with `kill $PID1` and `kill $PID2`.
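Putting this together, here is a minimal sketch of restricting the server to two GPUs and shutting it down afterwards. Paths are placeholders, and the `mmseqs` call is guarded so the snippet is safe to run even without MMseqs2 installed:

```shell
# Restrict the GPU server (and any later search) to GPUs 0 and 1.
export CUDA_VISIBLE_DEVICES=0,1

if command -v mmseqs >/dev/null 2>&1; then
    # Start a server for one database; it stays resident in GPU memory.
    mmseqs gpuserver /path/to/db_folder/uniref30_2302 \
        --max-seqs 10000 --db-load-mode 0 --prefilter-mode 1 &
    PID=$!
    # ... run colabfold_search here ...
    kill "$PID"   # stop the server explicitly when done
fi
echo "visible GPUs: $CUDA_VISIBLE_DEVICES"
```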

You can then run `colabfold_search` with the `--gpu` and `--gpu-server` options enabled:

```shell
colabfold_search --mmseqs /path/to/bin/mmseqs --gpu 1 --gpu-server 1 input_sequences.fasta /path/to/db_folder msas
```

You can also run the search with only the `--gpu` option enabled if you do not want to start a GPU server, although using the GPU server is generally faster. As with the GPU server, you can control which GPUs are used for the search via the `CUDA_VISIBLE_DEVICES` environment variable.
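For one-off searches where a resident server is not worth starting, a sketch without the server looks like this (paths are placeholders and the `colabfold_search` call is guarded so the snippet runs anywhere):

```shell
# Run the GPU search directly, without a persistent gpuserver.
export CUDA_VISIBLE_DEVICES=0   # use only the first GPU
if command -v colabfold_search >/dev/null 2>&1; then
    colabfold_search --mmseqs /path/to/bin/mmseqs --gpu 1 \
        input_sequences.fasta /path/to/db_folder msas
fi
echo "search GPU: $CUDA_VISIBLE_DEVICES"
```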

### Tutorials & Presentations
- ColabFold Tutorial presented at the Boston Protein Design and Modeling Club. [[video]](https://www.youtube.com/watch?v=Rfw7thgGTwI) [[slides]](https://docs.google.com/presentation/d/1mnffk23ev2QMDzGZ5w1skXEadTe54l8-Uei6ACce8eI).

42 changes: 34 additions & 8 deletions colabfold/mmseqs/search.py
@@ -64,6 +64,8 @@ def mmseqs_search_monomer(
s: float = 8,
db_load_mode: int = 2,
threads: int = 32,
gpu: int = 0,
gpu_server: int = 0,
unpack: bool = True,
):
"""Run mmseqs with a local colabfold database set
@@ -106,11 +108,16 @@ def mmseqs_search_monomer(
dbSuffix3 = ".idx"

search_param = ["--num-iterations", "3", "--db-load-mode", str(db_load_mode), "-a", "-e", "0.1", "--max-seqs", "10000"]
search_param += ["--prefilter-mode", str(prefilter_mode)]
if s is not None:
search_param += ["-s", "{:.1f}".format(s)]
if gpu:
search_param += ["--gpu", str(gpu), "--prefilter-mode", "1"] # gpu version only supports ungapped prefilter currently
else:
search_param += ["--k-score", "'seq:96,prof:80'"]
search_param += ["--prefilter-mode", str(prefilter_mode)]
if s is not None:  # sensitivity can only be set for the non-GPU version; the GPU version runs at max sensitivity
search_param += ["-s", "{:.1f}".format(s)]
else:
search_param += ["--k-score", "'seq:96,prof:80'"]
if gpu_server:
search_param += ["--gpu-server", str(gpu_server)]

filter_param = ["--filter-msa", str(filter), "--filter-min-enable", "1000", "--diff", str(diff), "--qid", "0.0,0.2,0.4,0.6,0.8,1.0", "--qsc", "0", "--max-seq-id", "0.95",]
expand_param = ["--expansion-mode", "0", "-e", str(expand_eval), "--expand-filter-clusters", str(filter), "--max-seq-id", "0.95",]
@@ -207,6 +214,8 @@ def mmseqs_search_pair(
prefilter_mode: int = 0,
s: float = 8,
threads: int = 64,
gpu: bool = False,
gpu_server: bool = False,
db_load_mode: int = 2,
pairing_strategy: int = 0,
unpack: bool = True,
@@ -238,11 +247,16 @@
# fmt: off
# @formatter:off
search_param = ["--num-iterations", "3", "--db-load-mode", str(db_load_mode), "-a", "-e", "0.1", "--max-seqs", "10000",]
search_param += ["--prefilter-mode", str(prefilter_mode)]
if s is not None:
search_param += ["-s", "{:.1f}".format(s)]
if gpu:
search_param += ["--gpu", str(gpu), "--prefilter-mode", "1"] # gpu version only supports ungapped prefilter currently
else:
search_param += ["--k-score", "'seq:96,prof:80'"]
search_param += ["--prefilter-mode", str(prefilter_mode)]
if s is not None:  # sensitivity can only be set for the non-GPU version; the GPU version runs at max sensitivity
search_param += ["-s", "{:.1f}".format(s)]
else:
search_param += ["--k-score", "'seq:96,prof:80'"]
if gpu_server:
search_param += ["--gpu-server", str(gpu_server)]
expand_param = ["--expansion-mode", "0", "-e", "inf", "--expand-filter-clusters", "0", "--max-seq-id", "0.95",]
run_mmseqs(mmseqs, ["search", base.joinpath("qdb"), dbbase.joinpath(db), base.joinpath("res"), base.joinpath("tmp"), "--threads", str(threads),] + search_param,)
run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res"), dbbase.joinpath(f"{db}{dbSuffix2}"), base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), "--threads", str(threads),] + expand_param,)
@@ -373,6 +387,12 @@ def main():
parser.add_argument(
"--threads", type=int, default=64, help="Number of threads to use."
)
parser.add_argument(
"--gpu", type=int, default=0, choices=[0, 1], help="Whether to use GPU (1) or not (0). Control number of GPUs with CUDA_VISIBLE_DEVICES env var."
)
parser.add_argument(
"--gpu-server", type=int, default=0, choices=[0, 1], help="Whether to use GPU server (1) or not (0)"
)
args = parser.parse_args()

logging.basicConfig(level = logging.INFO)
@@ -446,6 +466,8 @@ def main():
s=args.s,
db_load_mode=args.db_load_mode,
threads=args.threads,
gpu=args.gpu,
gpu_server=args.gpu_server,
unpack=args.unpack,
)
if is_complex is True:
@@ -458,6 +480,8 @@ def main():
s=args.s,
db_load_mode=args.db_load_mode,
threads=args.threads,
gpu=args.gpu,
gpu_server=args.gpu_server,
pairing_strategy=args.pairing_strategy,
pair_env=False,
unpack=args.unpack,
@@ -473,6 +497,8 @@ def main():
s=args.s,
db_load_mode=args.db_load_mode,
threads=args.threads,
gpu=args.gpu,
gpu_server=args.gpu_server,
pairing_strategy=args.pairing_strategy,
pair_env=True,
unpack=args.unpack,
75 changes: 52 additions & 23 deletions setup_databases.sh
@@ -13,6 +13,9 @@ PDB_AWS_SNAPSHOT="20240101"

UNIREF30DB="uniref30_2302"
MMSEQS_NO_INDEX=${MMSEQS_NO_INDEX:-}
DOWNLOADS_ONLY=${DOWNLOADS_ONLY:-}
GPU=${GPU:-}
mkdir -p -- "${WORKDIR}"
cd "${WORKDIR}"

hasCommand () {
@@ -56,15 +59,51 @@ downloadFile() {
fail "Could not download $URL to $OUTPUT"
}

if [ ! -f DOWNLOADS_READY ]; then
downloadFile "https://wwwuser.gwdg.de/~compbiol/colabfold/${UNIREF30DB}.tar.gz" "${UNIREF30DB}.tar.gz"
downloadFile "https://wwwuser.gwdg.de/~compbiol/colabfold/colabfold_envdb_202108.tar.gz" "colabfold_envdb_202108.tar.gz"
downloadFile "https://wwwuser.gwdg.de/~compbiol/colabfold/pdb100_230517.fasta.gz" "pdb100_230517.fasta.gz"
downloadFile "https://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/pdb100_foldseek_230517.tar.gz" "pdb100_foldseek_230517.tar.gz"
touch DOWNLOADS_READY
fi

if [ ! -f PDB_MMCIF_READY ]; then
mkdir -p pdb/divided
mkdir -p pdb/obsolete
if [ -n "${PDB_AWS_DOWNLOAD}" ]; then
aws s3 cp --no-sign-request --recursive s3://pdbsnapshots/${PDB_AWS_SNAPSHOT}/pub/pdb/data/structures/divided/mmCIF/ pdb/divided/
aws s3 cp --no-sign-request --recursive s3://pdbsnapshots/${PDB_AWS_SNAPSHOT}/pub/pdb/data/structures/obsolete/mmCIF/ pdb/obsolete/
fi
rsync -rlpt -v -z --delete --port=${PDB_PORT} ${PDB_SERVER}/data/structures/divided/mmCIF/ pdb/divided
rsync -rlpt -v -z --delete --port=${PDB_PORT} ${PDB_SERVER}/data/structures/obsolete/mmCIF/ pdb/obsolete
touch PDB_MMCIF_READY
fi

if [ -n "$DOWNLOADS_ONLY" ]; then
exit 0
fi


# Make MMseqs2 merge the databases to avoid spamming the folder with files
export MMSEQS_FORCE_MERGE=1

GPU_PAR=""
GPU_INDEX_PAR=""
if [ -n "${GPU}" ]; then
GPU_PAR="--gpu 1"
GPU_INDEX_PAR=" --split 1 --index-subset 2"

if ! mmseqs --help | grep -q 'gpuserver'; then
echo "The installed MMseqs2 has no GPU support; update to at least release 16"
exit 1
fi
fi

if [ ! -f UNIREF30_READY ]; then
downloadFile "https://wwwuser.gwdg.de/~compbiol/colabfold/${UNIREF30DB}.tar.gz" "${UNIREF30DB}.tar.gz"
tar xzvf "${UNIREF30DB}.tar.gz"
mmseqs tsv2exprofiledb "${UNIREF30DB}" "${UNIREF30DB}_db"
mmseqs tsv2exprofiledb "${UNIREF30DB}" "${UNIREF30DB}_db" ${GPU_PAR}
if [ -z "$MMSEQS_NO_INDEX" ]; then
mmseqs createindex "${UNIREF30DB}_db" tmp1 --remove-tmp-files 1
mmseqs createindex "${UNIREF30DB}_db" tmp1 --remove-tmp-files 1 ${GPU_INDEX_PAR}
fi
if [ -e ${UNIREF30DB}_db_mapping ]; then
ln -sf ${UNIREF30DB}_db_mapping ${UNIREF30DB}_db.idx_mapping
@@ -76,40 +115,30 @@ if [ ! -f UNIREF30_READY ]; then
fi

if [ ! -f COLABDB_READY ]; then
downloadFile "https://wwwuser.gwdg.de/~compbiol/colabfold/colabfold_envdb_202108.tar.gz" "colabfold_envdb_202108.tar.gz"
tar xzvf "colabfold_envdb_202108.tar.gz"
mmseqs tsv2exprofiledb "colabfold_envdb_202108" "colabfold_envdb_202108_db"
mmseqs tsv2exprofiledb "colabfold_envdb_202108" "colabfold_envdb_202108_db" ${GPU_PAR}
# TODO: split memory value for createindex?
if [ -z "$MMSEQS_NO_INDEX" ]; then
mmseqs createindex "colabfold_envdb_202108_db" tmp2 --remove-tmp-files 1
mmseqs createindex "colabfold_envdb_202108_db" tmp2 --remove-tmp-files 1 ${GPU_INDEX_PAR}
fi
touch COLABDB_READY
fi

if [ ! -f PDB_READY ]; then
downloadFile "https://wwwuser.gwdg.de/~compbiol/colabfold/pdb100_230517.fasta.gz" "pdb100_230517.fasta.gz"
mmseqs createdb pdb100_230517.fasta.gz pdb100_230517
if [ -n "${GPU}" ]; then
mmseqs createdb pdb100_230517.fasta.gz pdb100_230517_tmp
mmseqs makepaddedseqdb pdb100_230517_tmp pdb100_230517
mmseqs rmdb pdb100_230517_tmp
else
mmseqs createdb pdb100_230517.fasta.gz pdb100_230517
fi
if [ -z "$MMSEQS_NO_INDEX" ]; then
mmseqs createindex pdb100_230517 tmp3 --remove-tmp-files 1
mmseqs createindex pdb100_230517 tmp3 --remove-tmp-files 1 ${GPU_INDEX_PAR}
fi
touch PDB_READY
fi


if [ ! -f PDB100_READY ]; then
downloadFile "https://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/pdb100_foldseek_230517.tar.gz" "pdb100_foldseek_230517.tar.gz"
tar xzvf pdb100_foldseek_230517.tar.gz pdb100_a3m.ffdata pdb100_a3m.ffindex
touch PDB100_READY
fi

if [ ! -f PDB_MMCIF_READY ]; then
mkdir -p pdb/divided
mkdir -p pdb/obsolete
if [ -n "${PDB_AWS_DOWNLOAD}" ]; then
aws s3 cp --no-sign-request --recursive s3://pdbsnapshots/${PDB_AWS_SNAPSHOT}/pub/pdb/data/structures/divided/mmCIF/ pdb/divided/
aws s3 cp --no-sign-request --recursive s3://pdbsnapshots/${PDB_AWS_SNAPSHOT}/pub/pdb/data/structures/obsolete/mmCIF/ pdb/obsolete/
fi
rsync -rlpt -v -z --delete --port=${PDB_PORT} ${PDB_SERVER}/data/structures/divided/mmCIF/ pdb/divided
rsync -rlpt -v -z --delete --port=${PDB_PORT} ${PDB_SERVER}/data/structures/obsolete/mmCIF/ pdb/obsolete
touch PDB_MMCIF_READY
fi
