-
Notifications
You must be signed in to change notification settings - Fork 551
/
Copy pathgenerate.py
208 lines (177 loc) · 7.83 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import sys
import time
from typing import Any, Dict, List
import torch
from omegaconf import DictConfig
from torch import nn
from torchtune import config, generation, training, utils
from torchtune.data import Message, Role
from torchtune.training import FullModelTorchTuneCheckpointer
# Module-level logger used by the recipe below; obtained from torchtune's
# ``utils.get_logger`` with the "DEBUG" argument.
logger = utils.get_logger("DEBUG")
class InferenceRecipe:
    """
    Recipe for generating tokens from a dense Transformer-based LLM.

    Currently this recipe supports single-GPU generation only. Speculative
    decoding is not supported.

    For more details on how to use this recipe for generation, please see our
    tutorial: https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#generation

    For using this recipe with a quantized model, please see the following section of
    the above tutorial:
    https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#speeding-up-generation-using-quantization
    """

    def __init__(self, cfg: DictConfig) -> None:
        """
        Resolve device, dtype and quantizer from the config and set the seed.

        Args:
            cfg (DictConfig): recipe config; ``device``, ``dtype``, ``quantizer`` and
                ``seed`` are read here, plus the optional ``cudnn_deterministic_mode``.
        """
        self._device = utils.get_device(device=cfg.device)
        self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
        self._quantizer = config.instantiate(cfg.quantizer)
        self._quantization_mode = training.get_quantizer_mode(self._quantizer)

        training.set_seed(
            seed=cfg.seed, debug_mode=cfg.get("cudnn_deterministic_mode", None)
        )

    def setup(self, cfg: DictConfig) -> None:
        """
        Load the checkpoint, build the model and instantiate the tokenizer.

        Args:
            cfg (DictConfig): recipe config; ``checkpointer``, ``model`` and
                ``tokenizer`` are read here.

        Raises:
            ValueError: if a quantization mode is set and the checkpointer is not a
                ``FullModelTorchTuneCheckpointer``, or if the quantization mode
                contains "qat".
        """
        checkpointer = config.instantiate(cfg.checkpointer)

        if self._quantization_mode is not None:
            if not isinstance(checkpointer, FullModelTorchTuneCheckpointer):
                raise ValueError(
                    "Quantization is only supported for models quantized and saved with the "
                    "FullModelTorchTuneCheckpointer - please ensure you have quantized your "
                    "model and are using the quantized weights!"
                )
            if "qat" in self._quantization_mode:
                raise ValueError(
                    "You have specified a quantizer with 'QAT' - "
                    "QAT quantizers should only be used during quantization aware training "
                    "and when quantizing models. Please use the corresponding post-training "
                    "quantizer e.g. Int8DynActInt4WeightQuantizer for Int8DynActInt4WeightQATQuantizer."
                )

        if self._quantization_mode is None:
            ckpt_dict = checkpointer.load_checkpoint()
        else:
            # weights_only needs to be False when loading a quantized model
            ckpt_dict = checkpointer.load_checkpoint(weights_only=False)

        self._model = self._setup_model(
            model_cfg=cfg.model,
            model_state_dict=ckpt_dict[training.MODEL_KEY],
        )
        self._tokenizer = config.instantiate(cfg.tokenizer)

    def _setup_model(
        self,
        model_cfg: DictConfig,
        model_state_dict: Dict[str, Any],
    ) -> nn.Module:
        """
        Instantiate the model and load ``model_state_dict`` into it.

        When a quantization mode is set, the model is quantized and moved to the
        recipe's device/dtype first, and every entry of ``model_state_dict`` is
        moved to the device in place before being loaded with ``assign=True``.

        Args:
            model_cfg (DictConfig): config used to instantiate the model.
            model_state_dict (Dict[str, Any]): weights to load. Mutated in place
                on the quantized path.

        Returns:
            nn.Module: the model with weights loaded.
        """
        with training.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(model_cfg)

        if self._quantization_mode is not None:
            model = self._quantizer.quantize(model)
            model = model.to(device=self._device, dtype=self._dtype)
            # Only existing keys are reassigned, so iterating while assigning is safe.
            for k, v in model_state_dict.items():
                model_state_dict[k] = v.to(self._device)
            model.load_state_dict(model_state_dict, assign=True)
        else:
            model.load_state_dict(model_state_dict)

        # Validate model was loaded in with the expected dtype.
        training.validate_expected_param_dtype(
            model.named_parameters(), dtype=self._dtype
        )
        logger.info(f"Model is initialized with precision {self._dtype}.")

        return model

    def convert_prompt_to_tokens(
        self,
        prompt: Dict[Role, str],
    ) -> List[int]:
        """
        Convert the prompt string to a user message with optional system messages
        and tokenize using the prompt template defined on the tokenizer.

        Args:
            prompt (Dict[Role, str]): mapping with a required "user" entry and an
                optional "system" entry; a "system" value of ``None`` is ignored.

        Returns:
            List[int]: the "tokens" entry produced by the tokenizer.
        """
        messages = []
        if "system" in prompt and prompt["system"] is not None:
            messages.append(Message(role="system", content=prompt["system"]))
        messages.extend(
            [
                Message(role="user", content=prompt["user"]),
                # Empty assistant message to kick-start generation
                Message(role="assistant", content=""),
            ]
        )
        return self._tokenizer({"messages": messages}, inference=True)["tokens"]

    @torch.inference_mode()
    def generate(self, cfg: DictConfig):
        """
        Tokenize ``cfg.prompt``, generate up to ``cfg.max_new_tokens`` tokens and
        log the decoded output together with timing, bandwidth and memory stats.

        Args:
            cfg (DictConfig): recipe config; ``prompt``, ``enable_kv_cache``,
                ``max_new_tokens``, ``temperature`` and ``top_k`` are read here.
        """
        tokens = self.convert_prompt_to_tokens(
            cfg.prompt,
        )
        prompt = torch.tensor(tokens, dtype=torch.int, device=self._device)

        custom_generate_next_token = None

        # Ensure the cache is setup on the right device, with only as many tokens as we need
        if cfg.enable_kv_cache:
            with self._device:
                self._model.setup_caches(
                    batch_size=1,
                    dtype=self._dtype,
                    decoder_max_seq_len=prompt.numel() + cfg.max_new_tokens,
                )

        # since quantized model uses torch.compile to get speedup, it needs a warm up / prefill run
        # to get the accurate performance measurement
        if self._quantization_mode is not None:
            logger.info("Starting compilation to improve generation performance ...")
            custom_generate_next_token = torch.compile(
                generation.generate_next_token, mode="max-autotune", fullgraph=True
            )
            warmup_start = time.perf_counter()
            _ = generation.generate(
                model=self._model,
                prompt=prompt,
                max_generated_tokens=2,
                temperature=cfg.temperature,
                top_k=cfg.top_k,
                stop_tokens=self._tokenizer.stop_tokens,
                custom_generate_next_token=custom_generate_next_token,
            )
            warmup_elapsed = time.perf_counter() - warmup_start
            logger.info(
                f"Warmup run for quantized model takes: {warmup_elapsed:.02f} sec"
            )
            # Caches only exist when they were set up above; resetting caches that
            # were never created would fail, so skip the reset in that case.
            if cfg.enable_kv_cache:
                self._model.reset_caches()

        t0 = time.perf_counter()
        generated_tokens, _ = generation.generate(
            model=self._model,
            prompt=prompt,
            max_generated_tokens=cfg.max_new_tokens,
            pad_id=self._tokenizer.pad_id,
            temperature=cfg.temperature,
            top_k=cfg.top_k,
            stop_tokens=self._tokenizer.stop_tokens,
            custom_generate_next_token=custom_generate_next_token,
        )
        # Converting to a Python list happens before the timer is stopped, so the
        # conversion is part of the measured time.
        generated_tokens = generated_tokens.tolist()
        t = time.perf_counter() - t0

        logger.info(self._tokenizer.decode(generated_tokens[0]))

        # Total size in bytes of all parameters and buffers.
        model_size = sum(
            p.numel() * p.dtype.itemsize
            for p in itertools.chain(self._model.parameters(), self._model.buffers())
        )

        tokens_generated = len(generated_tokens[0]) - prompt.size(0)
        tokens_sec = tokens_generated / t
        logger.info(
            f"Time for inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
        )
        logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
        if self._device.type != "cpu":
            torch_device = utils.get_torch_device_namespace()
            logger.info(
                f"Memory used: {torch_device.max_memory_allocated() / 1e9:.02f} GB"
            )
@config.parse
def main(cfg: DictConfig) -> None:
    """
    Log the config, then construct, set up and run the ``InferenceRecipe``.

    Args:
        cfg (DictConfig): recipe config, passed unchanged to every step.
    """
    config.log_config(cfg=cfg, recipe_name="InferenceRecipe")

    inference_recipe = InferenceRecipe(cfg=cfg)
    inference_recipe.setup(cfg=cfg)
    inference_recipe.generate(cfg=cfg)
if __name__ == "__main__":
    # Run the recipe and hand whatever main() returns to sys.exit.
    exit_status = main()
    sys.exit(exit_status)