forked from hpcaitech/PaLM-colossalai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
103 lines (74 loc) · 2.73 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import functools

import torch
import torch.distributed as dist
import torch.nn.functional as F
from colossalai.constants import NUM_PARTITIONS
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from torch import nn
def calc_local_model_size(model: torch.nn.Module):
    """Count the parameter elements held by ``model``.

    Args:
        model: Module whose ``parameters()`` are iterated.

    Returns:
        The sum of ``numel()`` over every parameter yielded by
        ``model.parameters()``; ``0`` when there are none.
    """
    return sum(param.numel() for param in model.parameters())
def calc_mem(x):
    """Return the number of bytes occupied by the tensors found in ``x``.

    ``x`` may be a tensor, or a ``dict`` / ``tuple`` / ``list`` that is walked
    recursively (only the values of a dict are visited). Anything else
    contributes ``0``.

    Args:
        x: Tensor, container of tensors, or arbitrary object.

    Returns:
        Sum of ``element_size() * numel()`` over every tensor reached.
    """
    if isinstance(x, torch.Tensor):
        return x.element_size() * x.numel()

    if isinstance(x, dict):
        children = x.values()
    elif isinstance(x, (tuple, list)):
        children = x
    else:
        # Not a tensor and not a supported container: nothing to count.
        return 0

    return sum(calc_mem(child) for child in children)
"""Autoregressive wrapper adapted from
https://github.com/lucidrains/PaLM-pytorch/blob/main/palm_pytorch/autoregressive_wrapper.py
"""
# helper functions
def exists(val):
    """Tell whether ``val`` holds a value, i.e. is anything other than ``None``."""
    if val is None:
        return False
    return True
def eval_decorator(fn):
    """Run ``fn`` with the model in eval mode, then restore its previous mode.

    The returned wrapper expects the model as its first positional argument.
    It records ``model.training``, calls ``model.eval()``, invokes ``fn`` and
    finally calls ``model.train(was_training)``.

    Args:
        fn: Callable taking the model as first argument.

    Returns:
        A wrapper with the same call signature and return value as ``fn``.
    """

    @functools.wraps(fn)  # keep the wrapped function's name and docstring
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # Restore the mode even when fn raises, so a failed call does not
            # leave the model silently stuck in eval mode.
            model.train(was_training)

    return inner
# top k filtering
def top_k(logits, thres=0.9):
    """Keep the largest logits along the last dimension; set the rest to ``-inf``.

    The number of entries kept is ``int((1 - thres) * logits.shape[-1])``,
    clamped to the range ``[1, logits.shape[-1]]``.

    Args:
        logits: Floating-point tensor of scores; filtering is applied along
            its last dimension.
        thres: Fraction of the last dimension to discard.

    Returns:
        A new tensor shaped like ``logits`` holding the kept values at their
        original positions and ``-inf`` everywhere else. ``logits`` itself is
        not modified.
    """
    vocab_size = logits.shape[-1]
    # Without the lower clamp a small vocabulary (or thres close to 1) rounds
    # down to k == 0, which masks every entry and yields NaN after softmax.
    k = min(max(int((1 - thres) * vocab_size), 1), vocab_size)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float("-inf"))
    # torch.topk selects along the last dimension by default, so the values
    # must be scattered back along that same dimension.
    probs.scatter_(-1, ind, val)
    return probs
class AutoregressiveWrapper(nn.Module):
    """Wrap a network with a token-by-token sampling ``generate`` method.

    Args:
        net: Module called as ``net(tokens, **kwargs)``. ``generate`` indexes
            its result with ``[:, -1, :]`` to obtain the scores for the next
            token.
        max_seq_len: Stored on the instance. NOTE(review): not read anywhere
            in this class, so the context passed to ``net`` is never cropped.
        pad_value: Value written over every position after the first EOS
            token of a row when generation stops early.
    """

    def __init__(self, net, max_seq_len=2048, pad_value=0):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.pad_value = pad_value
        self.net = net

    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token=None, temperature=1.0, filter_thres=0.9, **kwargs):
        """Sample up to ``seq_len`` new tokens that continue ``start_tokens``.

        Runs under ``torch.no_grad()`` with the module switched to eval mode
        for the duration of the call.

        Args:
            start_tokens: 2-D tensor of prompt tokens (unpacked as two dims).
            seq_len: Maximum number of tokens to append.
            eos_token: Optional end-of-sequence token. Generation stops as
                soon as every row of the running sequence (prompt included)
                contains it; everything after the first occurrence in each
                row is replaced with ``self.pad_value``.
            temperature: Divisor applied to the filtered logits before the
                softmax. ``0`` selects the highest-scoring token
                deterministically instead of sampling.
            filter_thres: Forwarded to ``top_k`` as ``thres``.
            **kwargs: Forwarded unchanged to ``self.net``.

        Returns:
            Only the newly generated tokens: the running sequence with its
            first ``start_tokens.shape[1]`` columns removed.
        """
        # Unpacking into exactly two names keeps the implicit 2-D check.
        _batch, t = start_tokens.shape
        out = start_tokens

        for _ in range(seq_len):
            logits = self.net(out, **kwargs)[:, -1, :]
            filtered_logits = top_k(logits, thres=filter_thres)

            if temperature == 0:
                # Dividing by zero would give NaN probabilities; the limit of
                # temperature -> 0 is greedy decoding.
                sample = filtered_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(filtered_logits / temperature, dim=-1)
                sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_token = out == eos_token

                if is_eos_token.any(dim=-1).all():
                    # Mask out everything after the eos tokens: shifting right
                    # by one keeps the eos token itself, and the running sum
                    # marks every position that follows it.
                    shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        # Drop the prompt so that only the generated tokens are returned.
        out = out[:, t:]
        return out

    def forward(self, x, **kwargs):
        """Delegate directly to the wrapped network and return its output."""
        return self.net(x, **kwargs)