Skip to content

Commit 53143f5

Browse files
committed
init commit
1 parent d5f3875 commit 53143f5

File tree

7 files changed

+46
-20
lines changed

7 files changed

+46
-20
lines changed

README.md

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ FastFold provides a **high-performance implementation of Evoformer** with the fo
1313
1. Excellent kernel performance on GPU platform
1414
2. Supporting Dynamic Axial Parallelism(DAP)
1515
* Break the memory limit of single GPU and reduce the overall training time
16-
* Distributed inference can significantly speed up inference and make extremely long sequence inference possible
16+
* DAP can significantly speed up inference and make ultra-long sequence inference possible
1717
3. Ease of use
18-
* Replace a few lines and you can use FastFold in your project
18+
* Huge performance gains with a few lines changes
1919
* You don't need to care about how the parallel part is implemented
2020

2121
## Installation
@@ -38,6 +38,24 @@ cd FastFold
3838
python setup.py install --cuda_ext
3939
```
4040

41+
## Usage
42+
43+
You can use `Evoformer` as `nn.Module` in your project after `from fastfold.model import Evoformer`:
44+
45+
```python
46+
from fastfold.model import Evoformer
47+
evoformer_layer = Evoformer()
48+
```
49+
50+
If you want to use Dynamic Axial Parallelism, add a line of initialize with `fastfold.distributed.init_dap` after `torch.distributed.init_process_group`.
51+
52+
```python
53+
from fastfold.distributed import init_dap
54+
55+
torch.distributed.init_process_group(backend='nccl', init_method='env://')
56+
init_dap(args.dap_size)
57+
```
58+
4159
## Performance Benchmark
4260

4361
We have included a performance benchmark script in `./benchmark`. You can benchmark the performance of Evoformer using different settings.
@@ -47,6 +65,13 @@ cd ./benchmark
4765
torchrun --nproc_per_node=1 perf.py --msa-length 128 --res-length 256
4866
```
4967

68+
Benchmark Dynamic Axial Parallelism with 2 GPUs:
69+
70+
```shell
71+
cd ./benchmark
72+
torchrun --nproc_per_node=2 perf.py --msa-length 128 --res-length 256 --dap-size 2
73+
```
74+
5075
If you want to benchmark with [OpenFold](https://github.com/aqlaboratory/openfold), you need to install OpenFold first and benchmark with option `--openfold`:
5176

5277
```shell

benchmark/perf.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,34 @@
44
import torch
55
import torch.nn as nn
66

7-
from fastfold.distributed import init_shadowcore
7+
from fastfold.distributed import init_dap
88
from fastfold.model import Evoformer
99

1010

1111
def main():
1212

13-
parser = argparse.ArgumentParser(description='MSA Attention Standalone Perf Benchmark')
14-
parser.add_argument("--dap-size", default=1, type=int)
13+
parser = argparse.ArgumentParser(description='Evoformer Standalone Perf Benchmark')
14+
parser.add_argument("--dap-size", default=1, type=int, help='batch size')
1515
parser.add_argument('--batch-size', default=1, type=int, help='batch size')
16-
parser.add_argument('--msa-length', default=132, type=int, help='Sequence Length of Input')
16+
parser.add_argument('--msa-length', default=132, type=int, help='Sequence Length of MSA')
1717
parser.add_argument('--res-length',
1818
default=256,
1919
type=int,
20-
help='Start Range of Number of Sequences')
20+
help='Sequence Length of Residues')
2121
parser.add_argument('--trials', default=50, type=int, help='Number of Trials to Execute')
2222
parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
2323
parser.add_argument('--layers',
2424
default=12,
2525
type=int,
26-
help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')
26+
help='Evoformer Layers to Execute')
2727
parser.add_argument('--cm', default=256, type=int, help='MSA hidden dimension')
2828
parser.add_argument('--cz', default=128, type=int, help='Pair hidden dimension')
2929
parser.add_argument('--heads', default=8, type=int, help='Number of Multihead Attention heads')
3030
parser.add_argument('--openfold',
3131
action='store_true',
32-
help='torch.nn.MultitheadAttention Version.')
32+
help='Benchmark with Evoformer Implementation from OpenFold.')
3333
parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
34-
parser.add_argument('--prof', action='store_true', help='Only execute Fwd Pass.')
34+
parser.add_argument('--prof', action='store_true', help='run with profiler.')
3535

3636
args = parser.parse_args()
3737

@@ -48,10 +48,10 @@ def main():
4848
print(
4949
'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
5050
% (args.global_rank, args.world_size))
51-
init_shadowcore(args.tensor_model_parallel_size)
51+
init_dap(args.dap_size)
5252

5353
precision = torch.bfloat16
54-
if args.tensor_model_parallel_size > 1:
54+
if args.dap_size > 1:
5555
# (PyTorch issue) Currently All2All communication does not support the Bfloat16 datatype in PyTorch
5656
precision = torch.float16
5757

@@ -111,13 +111,13 @@ def forward(self, node, pair, node_mask, pair_mask):
111111
stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))
112112

113113
inputs_node = torch.randn(args.batch_size,
114-
args.msa_length // args.tensor_model_parallel_size,
114+
args.msa_length // args.dap_size,
115115
args.res_length,
116116
args.cm,
117117
dtype=precision,
118118
device=torch.device("cuda")).requires_grad_(True)
119119
inputs_pair = torch.randn(args.batch_size,
120-
args.res_length // args.tensor_model_parallel_size,
120+
args.res_length // args.dap_size,
121121
args.res_length,
122122
args.cz,
123123
dtype=precision,

fastfold/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VERSION = "0.1.0-beta"

fastfold/distributed/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from .core import (init_shadowcore, shadowcore_is_initialized, get_tensor_model_parallel_group,
1+
from .core import (init_dap, dap_is_initialized, get_tensor_model_parallel_group,
22
get_data_parallel_group, get_tensor_model_parallel_world_size,
33
get_tensor_model_parallel_rank, get_data_parallel_world_size,
44
get_data_parallel_rank, get_tensor_model_parallel_src_rank)
55
from .comm import (_reduce, _split, _gather, copy, scatter, reduce, gather, col_to_row, row_to_col)
66

77
__all__ = [
8-
'init_shadowcore', 'shadowcore_is_initialized', 'get_tensor_model_parallel_group',
8+
'init_dap', 'dap_is_initialized', 'get_tensor_model_parallel_group',
99
'get_data_parallel_group', 'get_tensor_model_parallel_world_size',
1010
'get_tensor_model_parallel_rank', 'get_data_parallel_world_size', 'get_data_parallel_rank',
1111
'get_tensor_model_parallel_src_rank', '_reduce', '_split', '_gather', 'copy', 'scatter',

fastfold/distributed/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def ensure_divisibility(numerator, denominator):
1515
assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)
1616

1717

18-
def init_shadowcore(tensor_model_parallel_size_=1):
18+
def init_dap(tensor_model_parallel_size_=1):
1919

2020
assert dist.is_initialized()
2121

@@ -51,7 +51,7 @@ def init_shadowcore(tensor_model_parallel_size_=1):
5151
print('> initialize data parallel with size {}'.format(data_parallel_size_))
5252

5353

54-
def shadowcore_is_initialized():
54+
def dap_is_initialized():
5555
"""Check if model and data parallel groups are initialized."""
5656
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
5757
_DATA_PARALLEL_GROUP is None:

fastfold/model/evoformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
class Evoformer(nn.Module):
88

9-
def __init__(self, d_node, d_pair):
9+
def __init__(self, d_node=256, d_pair=128):
1010
super(Evoformer, self).__init__()
1111

1212
self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def cuda_ext_helper(name, sources, extra_cuda_flags):
141141

142142
setup(
143143
name='fastfold',
144-
version='0.0.1-beta',
144+
version='0.1.0-beta',
145145
packages=find_packages(exclude=(
146146
'assets',
147147
'benchmark',

0 commit comments

Comments
 (0)