Commit 7d94c95

feat: support multi lora adapters and TP (NVIDIA#3885)
* support multi lora, tp

Signed-off-by: Shahar Mor <17088876+shaharmor98@users.noreply.github.com>
1 parent 99313af

18 files changed: +274 -175 lines

examples/pytorch/quickstart_lora.py (-38)

This file was deleted.

tensorrt_llm/_torch/model_config.py (+3 -7)

@@ -168,22 +168,18 @@ def from_pretrained(cls,
                    quant_config_dict=layer_quant_config,
                    **kwargs)
 
-    def get_bindings_model_config(
-            self,
-            tensor_parallelism: int = 1,
-            context_parallelism: int = 1) -> "ModelConfigCpp":
+    def get_bindings_model_config(self) -> "ModelConfigCpp":
         """
         This method is used to construct the bindings config for the model.
         Currently it adheres to gptJsonConfig.cpp::createModelConfig, which assumes
         that an engine has been created.
         """
         # TODO smor- this isn't robust, and currently tested for LlamaConfig only
-        # TODO smor- currently parallelism is not supported, set default to 1
         # TODO smor- currently assuming no rnn layers, no MOE
         from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 
         num_heads = self.pretrained_config.num_attention_heads // (
-            tensor_parallelism * context_parallelism)
+            self.mapping.tp_size * self.mapping.cp_size)
 
         model_config_cpp = ModelConfigCpp(
             vocab_size=self.pretrained_config.vocab_size,
@@ -195,7 +191,7 @@ def get_bindings_model_config(
             data_type=torch_dtype_to_binding(
                 self.pretrained_config.torch_dtype))
 
-        mlp_hidden_size = self.pretrained_config.intermediate_size // tensor_parallelism
+        mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size
         if "head_size" in self.pretrained_config:
             head_size = self.pretrained_config.head_size
         else:
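
Note on the hunk above: parallelism sizes now come from self.mapping instead of explicit arguments, so each rank derives its own shard sizes. A minimal sketch of the per-rank arithmetic, using a hypothetical SimpleMapping stand-in for tensorrt_llm.mapping.Mapping (only the fields used here):

from dataclasses import dataclass


@dataclass
class SimpleMapping:
    # Hypothetical stand-in for tensorrt_llm.mapping.Mapping; the real class has many more fields.
    tp_size: int = 1
    cp_size: int = 1


def per_rank_shapes(num_attention_heads: int, intermediate_size: int,
                    mapping: SimpleMapping) -> tuple:
    # Attention heads are divided across tensor- and context-parallel ranks.
    num_heads = num_attention_heads // (mapping.tp_size * mapping.cp_size)
    # The MLP hidden dimension is divided across tensor-parallel ranks only.
    mlp_hidden_size = intermediate_size // mapping.tp_size
    return num_heads, mlp_hidden_size


# Llama-7B-like shapes with tp_size=2, cp_size=1 -> (16, 5504)
print(per_rank_shapes(32, 11008, SimpleMapping(tp_size=2, cp_size=1)))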

tensorrt_llm/_torch/models/modeling_llama.py (+16 -6)

@@ -13,6 +13,7 @@
                                   AllReduceParams, DeepseekAllReduce)
 from tensorrt_llm._torch.pipeline_interface import PipelineInterface
 from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.models.convert_utils import split_matrix_tp
 
 from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
                        register_input_processor)
@@ -773,13 +774,14 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
         self.padding_idx = config.pad_token_id
 
         vocab_size = config.vocab_size
-        # TODO smor- hack
-        if hasattr(model_config,
-                   'lora_config') and model_config.lora_config is not None:
+        # TODO smor- we load manually only if there is a single lora dir, need to come up with a better solution
+        if hasattr(
+                model_config,
+                'lora_config') and model_config.lora_config is not None and len(
+                    model_config.lora_config.lora_dir) == 1:
             from tensorrt_llm.lora_manager import HfLoraLoader
             lora_loader = HfLoraLoader(model_config.lora_config.lora_dir)
             weight = lora_loader.embed_tokens
-            # TODO smor - need to split tp matrix here
             vocab_size = lora_loader.vocab_size
 
         self.embed_tokens = Embedding(
@@ -791,9 +793,17 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
             gather_output=True,
         )
 
-        if hasattr(model_config,
-                   'lora_config') and model_config.lora_config is not None:
+        if hasattr(
+                model_config,
+                'lora_config') and model_config.lora_config is not None and len(
+                    model_config.lora_config.lora_dir) == 1:
             with torch.no_grad():
+                if model_config.mapping.tp_size > 1:
+                    weight = split_matrix_tp(
+                        weight,
+                        model_config.mapping.tp_size,
+                        model_config.mapping.tp_rank,
+                        dim=0)  # split by vocabulary dimension
                 x = weight.to(self.embed_tokens.dtype)
                 self.embed_tokens.weight.data.copy_(x)
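
The split_matrix_tp call added above resolves the earlier "# TODO smor - need to split tp matrix here": when the embedding table comes from the LoRA checkpoint, each tensor-parallel rank keeps only its slice along the vocabulary dimension (dim 0). A minimal sketch of that slicing with plain torch.chunk, assuming the vocabulary divides evenly across ranks (illustrative only, not the convert_utils implementation):

import torch


def shard_vocab_dim(weight: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # weight: [vocab_size, hidden_size]; rank r keeps rows
    # [r * vocab_size / tp_size, (r + 1) * vocab_size / tp_size).
    assert weight.shape[0] % tp_size == 0, "vocab_size must divide evenly across TP ranks"
    return torch.chunk(weight, tp_size, dim=0)[tp_rank].contiguous()


full = torch.randn(32000, 4096)                      # embedding table from the LoRA dir
local = shard_vocab_dim(full, tp_size=2, tp_rank=1)  # this rank's shard
print(local.shape)                                   # torch.Size([16000, 4096])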

tensorrt_llm/_torch/models/modeling_utils.py (+21 -9)

@@ -11,6 +11,9 @@
 from torch.utils._pytree import tree_any_only
 from tqdm import tqdm
 
+from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.models.convert_utils import split_matrix_tp
+
 from ...logger import logger
 from ...mapping import Mapping
 from ...models.modeling_utils import QuantConfig
@@ -240,7 +243,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         position_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        lora_params: Optional = None,  # TODO smor add type hint
+        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -357,9 +360,9 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig],
         # TODO(zhenhuanc): Currently lm_head Linear will not accept QuantConfig
         # will considering per layer QuantConfig in the future.
 
-        # TODO smor- hack
-        if hasattr(config,
-                   'lora_config') and config.lora_config is not None:
+        if hasattr(config, 'lora_config'
+                   ) and config.lora_config is not None and len(
+                       config.lora_config.lora_dir) == 1:
             from tensorrt_llm.lora_manager import HfLoraLoader
             lora_loader = HfLoraLoader(config.lora_config.lora_dir)
             weight = lora_loader.lm_head
@@ -374,9 +377,16 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig],
             gather_output=True,
         )
 
-        if hasattr(config,
-                   'lora_config') and config.lora_config is not None:
+        if hasattr(config, 'lora_config'
+                   ) and config.lora_config is not None and len(
+                       config.lora_config.lora_dir) == 1:
             with torch.no_grad():
+                if config.mapping.tp_size > 1:
+                    weight = split_matrix_tp(
+                        weight,
+                        config.mapping.tp_size,
+                        config.mapping.tp_rank,
+                        dim=0)  # split by vocabulary dimension
                 x = weight.to(self.lm_head.dtype).cuda()
                 self.lm_head.weight.data.copy_(x)
 
@@ -475,7 +485,7 @@ def forward(
         pipeline_interface: Optional[PipelineInterface] = None,
         return_context_logits: bool = False,
         spec_metadata: Optional[SpecMetadata] = None,
-        lora_params: Optional = None,  # TODO smor add type hint
+        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
         if self._supports_pp and self.pp_size > 1:
@@ -657,8 +667,10 @@ def filter_weights(prefix, weights: Dict):
 
         # Skip loading weights for embedding and lm_head if LoRA is enabled
         if hasattr(model.model_config, 'lora_config'
-                   ) and model.model_config.lora_config is not None and (
-                       name == "model.embed_tokens" or name == "lm_head"):
+                   ) and model.model_config.lora_config is not None and len(
+                       model.model_config.lora_config.lora_dir) == 1 and (
+                           name == "model.embed_tokens"
+                           or name == "lm_head"):
             continue
 
         # Skip if parameter belongs to a missing layer
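
The same condition, a lora_config with exactly one entry in lora_dir, now gates the embedding override, the lm_head override, and the weight-loading skip above. A hypothetical helper (not part of the patch) makes the predicate explicit:

def lora_overrides_vocab(config) -> bool:
    # Hypothetical helper for illustration: True when a single LoRA directory is
    # configured, in which case embed_tokens / lm_head are taken from the LoRA
    # checkpoint (and split per TP rank) instead of from the base model weights.
    lora_config = getattr(config, 'lora_config', None)
    return lora_config is not None and len(lora_config.lora_dir) == 1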

tensorrt_llm/_torch/modules/attention.py (+7 -7)

@@ -88,6 +88,9 @@ def __init__(
             quant_config=config.get_quant_config(),
             skip_create_weights_in_init=config.skip_create_weights_in_init,
         )
+        self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
+                                [self.hidden_size])
+
         self.o_proj = Linear(
             tp_size * self.q_size,
             self.hidden_size,
@@ -97,6 +100,7 @@ def __init__(
             tensor_parallel_mode=TensorParallelMode.ROW,
             quant_config=config.get_quant_config(),
             skip_create_weights_in_init=config.skip_create_weights_in_init,
+            lora=self.o_lora,
         )
         self.quant_config = config.get_quant_config()
         self.attn_backend = config.attn_backend
@@ -229,13 +233,9 @@ def forward(
             mrope_config=mrope_config)
         hidden_states = attn_output
         attn_output = self.o_proj(attn_output,
-                                  all_reduce_params=all_reduce_params)
-        if bool(lora_params):
-            attn_lora_output = self.o_lora(hidden_states, lora_params,
-                                           self.layer_idx)
-            if attn_lora_output is not None:
-                attn_output = attn_output + attn_lora_output
-
+                                  all_reduce_params=all_reduce_params,
+                                  lora_params=lora_params,
+                                  layer_idx=self.layer_idx)
         return attn_output
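
The forward change above moves LoRA application out of the attention module and into the o_proj Linear itself (via the new lora=self.o_lora argument), so the LoRA delta is added before the row-parallel all-reduce rather than after it. A rough sketch of that pattern, assuming a generic LoRA-like module; this is not the tensorrt_llm Linear implementation:

from typing import Optional

import torch
import torch.nn as nn


class LinearWithLora(nn.Module):
    """Sketch: a projection that owns an optional LoRA layer and applies it inline."""

    def __init__(self, in_features: int, out_features: int,
                 lora: Optional[nn.Module] = None):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.lora = lora

    def forward(self, x: torch.Tensor, lora_params: Optional[dict] = None,
                layer_idx: Optional[int] = None) -> torch.Tensor:
        out = self.base(x)
        if self.lora is not None and lora_params:
            delta = self.lora(x, lora_params, layer_idx)
            if delta is not None:
                out = out + delta
        # A row-parallel implementation would all-reduce `out` here, so the LoRA
        # delta is summed across TP ranks together with the base projection.
        return out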

tensorrt_llm/_torch/modules/gated_mlp.py (+19 -31)

@@ -76,6 +76,9 @@ def __init__(self,
             reduce_output=False,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
         )
+        self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
+                                   [self.hidden_size])
+
         self.down_proj = Linear(
             self.intermediate_size,
             self.hidden_size,
@@ -86,18 +89,20 @@ def __init__(self,
             quant_config=config.get_quant_config(),
             reduce_output=reduce_output,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
+            lora=self.down_lora,
         )
 
         # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
         # but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora
         # handles them as a single fused operation.
         self.splitted_gate_up_lora = LoraLayer(
-            [LoraModuleType.MLP_H_TO_4H, LoraModuleType.MLP_GATE],
-            [self.intermediate_size, self.intermediate_size])
-        self.fused_gate_up_lora = LoraLayer([LoraModuleType.MLP_GATE_UP],
-                                            [2 * self.intermediate_size])
-        self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
-                                   [self.hidden_size])
+            [LoraModuleType.MLP_H_TO_4H, LoraModuleType.MLP_GATE], [
+                self.intermediate_size // mapping.tp_size,
+                self.intermediate_size // mapping.tp_size
+            ])
+        self.fused_gate_up_lora = LoraLayer(
+            [LoraModuleType.MLP_GATE_UP],
+            [2 * self.intermediate_size // mapping.tp_size])
 
     def forward(
         self,
@@ -107,33 +112,17 @@ def forward(
         lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
-        if lora_params is not None:
+        if bool(lora_params):
             return self.forward_lora(x, all_rank_num_tokens,
                                      final_all_reduce_params, lora_params)
 
         if self.activation == F.silu:
             h1 = self.gate_up_proj(x)
-            if bool(lora_params):
-                assert self.layer_idx is not None, "layer_idx is required for lora"
-                h1_lora = self.splitted_gate_up_lora(x, lora_params,
-                                                     self.layer_idx)
-                if h1_lora is not None:
-                    h1 = h1 + h1_lora
-
-                h1_lora = self.fused_gate_up_lora(x, lora_params,
-                                                  self.layer_idx)
-
-                if h1_lora is not None:
-                    h1 = h1 + h1_lora
 
             h2 = swiglu(h1)
             output = self.down_proj(h2,
-                                    all_reduce_params=final_all_reduce_params)
-            if bool(lora_params):
-                output_lora = self.down_lora(h2, lora_params, self.layer_idx)
-                if output_lora is not None:
-                    output = output + output_lora
-
+                                    all_reduce_params=final_all_reduce_params,
+                                    layer_idx=self.layer_idx)
             return output
         else:
             raise NotImplementedError(
@@ -154,19 +143,18 @@ def forward_lora(
         h1 = self.gate_up_proj(x)
 
         h1_lora = self.splitted_gate_up_lora(x, lora_params, self.layer_idx)
+
        if h1_lora is not None:
             h1 = h1 + h1_lora
 
         h1_lora = self.fused_gate_up_lora(x, lora_params, self.layer_idx)
-
         if h1_lora is not None:
             h1 = h1 + h1_lora
 
         h2 = swiglu(h1)
-        output = self.down_proj(h2, all_reduce_params=final_all_reduce_params)
-
-        output_lora = self.down_lora(h2, lora_params, self.layer_idx)
-        if output_lora is not None:
-            output = output + output_lora
+        output = self.down_proj(h2,
+                                all_reduce_params=final_all_reduce_params,
+                                lora_params=lora_params,
+                                layer_idx=self.layer_idx)
 
         return output
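
Because gate_up_proj is column-parallel, the gate/up LoRA layers above now size their outputs per rank: intermediate_size // mapping.tp_size for the split variant and 2 * intermediate_size // mapping.tp_size for the fused variant, while down_lora keeps the full hidden_size of the row-parallel output. A small worked example of those shapes (illustrative values only):

# Per-rank LoRA output sizes for a Llama-7B-like MLP (hidden_size=4096,
# intermediate_size=11008) under tensor parallelism.
hidden_size, intermediate_size = 4096, 11008

for tp_size in (1, 2, 4):
    split_out = [intermediate_size // tp_size, intermediate_size // tp_size]
    fused_out = [2 * intermediate_size // tp_size]
    down_out = [hidden_size]  # down_proj output stays full-size; TP reduces over it
    print(f"tp_size={tp_size}: splitted_gate_up={split_out}, "
          f"fused_gate_up={fused_out}, down={down_out}")
# tp_size=2 -> splitted_gate_up=[5504, 5504], fused_gate_up=[11008], down=[4096]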
