-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgpt_model.py
214 lines (181 loc) · 6.7 KB
/
gpt_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from einops import rearrange
from torch.utils.checkpoint import checkpoint
from new_checkpoint import checkpoint as new_checkpoint
from checkpoint import save_on_cpu, get_selective_offloading_checkpoint_modes2
import json
from collections import namedtuple
from nvme_ds.partitionrd_activation_swapper import AsyncPartitionedActivationSwapper
# from flash_attn import flash_attn_qkvpacked_func
from flash_attn.flash_attention import FlashAttention
act_stream = torch.cuda.Stream()
chp_id = [0]
chp_list = []
act_swapper = None
is_swap_and_recompute = 0
def set_training(args):
global act_stream, chp_id, chp_list, act_swapper, is_swap_and_recompute
for i in range(2 * args.num_layers):
packed = torch.empty(
(args.max_seq_len, args.batch_size, args.hidden_dim),
dtype=torch.float16,
pin_memory=True)
# packed = torch.ones(
# 1,
# dtype=torch.float16,
# pin_memory=True)
chp_list.append(packed)
def json_object_hook(d):
return namedtuple('X', d.keys())(*d.values())
with open(args.sb_config) as f:
ds_config = json.load(f, object_hook=json_object_hook)
act_swapper = AsyncPartitionedActivationSwapper(ds_config, torch.float16)
is_swap_and_recompute = args.is_swap_and_recompute
@dataclass
class GPT2Config:
dim: int = 5120
hidden_dim: int = 5120
num_heads: int = 40
num_layers: int = 40
dim_head: int = 128
max_seq_len: int = 1024
attn_pdrop: float = 0.1
dropout: float = 0.1
vocab_size: int = 50257
layer_norm_epsilon: float = 1e-5
"""
{
"activation_function": "gelu_new",
"architectures": [
"GPT2LMHeadModel"
],
"attn_pdrop": 0.1,
"bos_token_id": 50256,
"embd_pdrop": 0.1,
"eos_token_id": 50256,
"initializer_range": 0.02,
"layer_norm_epsilon": 1e-05,
"model_type": "gpt2",
"n_ctx": 1024,
"n_embd": 768,
"n_head": 12,
"n_layer": 12,
"n_positions": 1024,
"resid_pdrop": 0.1,
"summary_activation": null,
"summary_first_dropout": 0.1,
"summary_proj_to_labels": true,
"summary_type": "cls_index",
"summary_use_proj": true,
"task_specific_params": {
"text-generation": {
"do_sample": true,
"max_length": 50
}
},
"vocab_size": 50257
}
"""
class NewGELUActivation(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
return (
0.5 * x * (1.0 + torch.tanh(
math.sqrt(2.0 / math.pi)
* (x + 0.044715 * torch.pow(x, 3.0))
))
)
class GPT2MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.hidden_dim, config.hidden_dim * 4)
self.c_proj = nn.Linear(config.hidden_dim * 4, config.hidden_dim)
# self.act = NewGELUActivation()
self.act = nn.GELU()
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
x = self.c_fc(x)
x = self.act(x)
x = self.c_proj(x)
x = self.dropout(x)
return x
class GPT2Attention(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.c_attn = nn.Linear(config.hidden_dim, 3 * config.hidden_dim)
self.c_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
self.attention = FlashAttention()
self.resid_dropout = nn.Dropout(config.dropout)
def forward(self, x):
qkv = self.c_attn(x)
QKV = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.config.num_heads)
# flash_attn_out = flash_attn_qkvpacked_func(QKV)
flash_attn_out, _ = self.attention(QKV)
out = rearrange(flash_attn_out, 'b n h d -> b n (h d)')
attn_out = self.c_proj(out)
attn_out = self.resid_dropout(attn_out)
return attn_out
class GPT2Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln_1 = nn.LayerNorm(config.hidden_dim, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config)
self.ln_2 = nn.LayerNorm(config.hidden_dim, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(config)
def forward(self, x, swap_list, act_pack):
def func1(x):
residual = x
x = self.ln_1(x)
x = self.attn(x) + residual
return x
# x = new_checkpoint(func1, x, context_fn=get_selective_offloading_checkpoint_modes2)
# x = checkpoint(func1, x)
with save_on_cpu(pin_memory=True, act_stream=act_stream, chp_id = chp_id, chp_list = chp_list, act_swapper=act_swapper):
if is_swap_and_recompute:
x = new_checkpoint(func1, x, context_fn=get_selective_offloading_checkpoint_modes2, swap_list=swap_list, act_pack=act_pack)
else:
x = checkpoint(func1, x)
def func2(x):
residual = x
x = self.ln_2(x)
x = self.mlp(x) + residual
return x
# x = new_checkpoint(func2, x, context_fn=get_selective_offloading_checkpoint_modes1)
# x = checkpoint(func2, x)
with save_on_cpu(pin_memory=True, act_stream=act_stream, chp_id = chp_id, chp_list = chp_list, act_swapper=act_swapper):
if is_swap_and_recompute:
x = new_checkpoint(func2, x, context_fn=get_selective_offloading_checkpoint_modes2, swap_list=swap_list, act_pack=act_pack)
else:
x = checkpoint(func2, x)
return x
class GPT2Model(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.token_emb = nn.Embedding(config.vocab_size, config.hidden_dim)
self.pos_emb = nn.Embedding(config.max_seq_len, config.hidden_dim)
self.drop = nn.Dropout(config.dropout)
self.layers = nn.ModuleList([])
for _ in range(config.num_layers):
self.layers.append(GPT2Block(config))
self.ln_f = nn.LayerNorm(config.hidden_dim, eps=config.layer_norm_epsilon)
# self.fc = nn.Linear(config.hidden_dim, config.vocab_size)
def forward(self, x, swap_list, act_pack):
x = self.token_emb(x)
x = x + self.pos_emb(torch.arange(x.shape[1], device=x.device))
x = self.drop(x)
for block in self.layers:
x = block(x, swap_list, act_pack)
x = self.ln_f(x)
# def func_fc(x):
# x = self.fc(x)
# return x
# x = checkpoint(func_fc, x)
return x