Skip to content

Commit fd2f698

Browse files
committed
update convert script for fp16
1 parent 661a80b commit fd2f698

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

scripts/convert_fp32_to_fp16.bin

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""
2+
This script provides an example of wrapping a TencentPretrain language model to convert its weights to fp16.
3+
It loads a model checkpoint, casts the model to half precision, and saves the resulting state dict.
4+
"""
5+
import sys
6+
import os
7+
import argparse
8+
import torch
9+
import torch.nn.functional as F
10+
11+
tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
12+
sys.path.append(tencentpretrain_dir)
13+
14+
from tencentpretrain.embeddings import *
15+
from tencentpretrain.encoders import *
16+
from tencentpretrain.targets import *
17+
from tencentpretrain.utils.constants import *
18+
from tencentpretrain.utils import *
19+
from tencentpretrain.utils.config import load_hyperparam
20+
from tencentpretrain.model_loader import *
21+
from tencentpretrain.opts import model_opts, tokenizer_opts
22+
23+
24+
class GenerateLm(torch.nn.Module):
    """Language model made of an embedding stack, an encoder and an LM target.

    The constructor builds one sub-embedding per name listed in
    ``args.embedding``, the encoder selected by ``args.encoder``, and a
    ``Target`` container holding a single ``LmTarget`` registered as ``"lm"``.
    """

    def __init__(self, args):
        super().__init__()

        # Embedding container, filled with each configured sub-embedding
        # under its own name.
        self.embedding = Embedding(args)
        for name in args.embedding:
            self.embedding.update(
                str2embedding[name](args, len(args.tokenizer.vocab)),
                name,
            )

        self.encoder = str2encoder[args.encoder](args)

        # Target container with the language-modeling head stored as "lm".
        self.target = Target()
        self.target.update(LmTarget(args, len(args.tokenizer.vocab)), "lm")

    def forward(self, src, seg):
        """Run embedding -> encoder -> LM output layer and return the result.

        Args:
            src: Input passed to the embedding together with ``seg``.
            seg: Segment input, passed to both the embedding and the encoder.

        Returns:
            Whatever ``self.target.lm.output_layer`` produces for the
            encoder output.
        """
        hidden = self.encoder(self.embedding(src, seg), seg)
        return self.target.lm.output_layer(hidden)
40+
41+
42+
if __name__ == '__main__':
    # Command-line entry point: build the LM wrapper, load a checkpoint into
    # it, cast the model to half precision and save the resulting state dict.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    model_opts(parser)

    # Both model paths are mandatory. Previously a missing path was only
    # noticed after the tokenizer and model had been built, when None reached
    # load_model() / torch.save(); argparse now rejects the call up front.
    parser.add_argument("--load_model_path", type=str, required=True,
                        help="Path of the input model.")
    parser.add_argument("--config_path", type=str, required=True,
                        help="Path of the config file.")
    parser.add_argument("--output_model_path", type=str, required=True,
                        help="Path of the output model.")

    tokenizer_opts(parser)

    args = parser.parse_args()

    # Fixed settings for this script; set before load_hyperparam(args) runs.
    args.target = "lm"
    args.batch_size = 1

    args = load_hyperparam(args)

    # Replace the tokenizer name with the tokenizer instance it selects;
    # GenerateLm reads args.tokenizer.vocab to size its layers.
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    model = GenerateLm(args)
    model = load_model(model, args.load_model_path)

    # Cast floating-point parameters and buffers to half precision in place.
    model.half()

    torch.save(model.state_dict(), args.output_model_path)
70+

0 commit comments

Comments
 (0)