-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmerge_lora_weights_and_save_hf_model.py
executable file
·147 lines (127 loc) · 5.22 KB
/
merge_lora_weights_and_save_hf_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import argparse
import os
import sys
import torch
import transformers
from peft import LoraConfig, get_peft_model
from model.VideoLISA import VideoLISAForCausalLM
from utils.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN
def parse_args(args):
    """Parse command-line options for merging LoRA weights into a VideoLISA model.

    Parameters
    ----------
    args : list of str
        Argument strings, typically ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        The parsed options. ``--weight`` and ``--save_path`` are mandatory.

    Notes
    -----
    ``--vision_pretrained``, ``--image_size``, ``--local-rank`` and
    ``--conv_type`` are accepted but not read by ``main`` as written;
    presumably they are kept so the training command line can be reused
    verbatim -- confirm before removing any of them.
    """
    parser = argparse.ArgumentParser(description="merge lora weights and save model with hf format")
    parser.add_argument("--version", default="MBZUAI/LLaVA-Phi-3-mini-4k-instruct")
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
    parser.add_argument("--out_dim", default=256, type=int)
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=2048, type=int)
    parser.add_argument("--vision-tower", default="openai/clip-vit-large-patch14-336", type=str)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument("--lora_alpha", default=16, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    # The original definition (store_true with default=True) made this flag a
    # no-op that could never be disabled. Keep the positive flag for backward
    # compatibility and add an explicit off switch sharing the same dest.
    parser.add_argument(
        "--train_mask_decoder",
        action="store_true",
        default=True,
        help="build the model with train_mask_decoder=True (default)",
    )
    parser.add_argument(
        "--no_train_mask_decoder",
        dest="train_mask_decoder",
        action="store_false",
        help="build the model with train_mask_decoder=False",
    )
    parser.add_argument("--use_mm_start_end", action="store_true", default=False)
    parser.add_argument(
        "--conv_type",
        default="phi3_instruct",
        type=str
    )
    # Required options: a default would never be used, so none is given.
    parser.add_argument(
        "--weight",
        type=str,
        required=True,
        help="checkpoint file loaded with torch.load and applied via load_state_dict",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="output directory for the merged model and tokenizer",
    )
    return parser.parse_args(args)
def _resolve_torch_dtype(precision):
    """Map a ``--precision`` choice ("fp32" / "bf16" / "fp16") to a torch dtype."""
    if precision == "bf16":
        return torch.bfloat16
    if precision == "fp16":
        return torch.half
    return torch.float32


def _find_linear_layers(model, lora_target_modules):
    """Return the sorted names of ``torch.nn.Linear`` modules that LoRA should wrap.

    A module is selected when its qualified name contains at least one of the
    substrings in *lora_target_modules* and none of the excluded sub-module
    names below.
    """
    # Linear layers under these sub-modules are never LoRA targets.
    excluded = ("visual_model", "vision_tower", "mm_projector", "text_hidden_fcs")
    lora_module_names = {
        name
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and not any(x in name for x in excluded)
        and any(x in name for x in lora_target_modules)
    }
    return sorted(lora_module_names)


def main(args):
    """Rebuild the model, load a fine-tuned checkpoint, merge LoRA, and save it.

    Steps: build the tokenizer and model exactly as configured by the CLI,
    wrap the model with LoRA adapters when ``--lora_r > 0``, load the
    checkpoint given by ``--weight`` (strict), fold the adapters into the base
    weights, and save model + tokenizer to ``--save_path`` with
    ``save_pretrained``. Parameters whose name contains "vision_tower" are
    left out of the saved state dict.

    Parameters
    ----------
    args : list of str
        Raw command-line arguments, parsed with ``parse_args``.
    """
    args = parse_args(args)

    # Tokenizer: the [SEG] (and optional image start/end) tokens are added so
    # the vocabulary size matches the embedding matrix after resizing below.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[-1]
    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "seg_token_idx": args.seg_token_idx,
        "vision_tower": args.vision_tower,
        "use_mm_start_end": args.use_mm_start_end,
    }
    torch_dtype = _resolve_torch_dtype(args.precision)
    model = VideoLISAForCausalLM.from_pretrained(
        args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)
    model.get_model().initialize_lisa_modules(model.get_model().config)

    use_lora = args.lora_r > 0
    if use_lora:
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=_find_linear_layers(
                model, args.lora_target_modules.split(",")
            ),
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # Resize after wrapping, in the same order as before, so parameter names
    # match the checkpoint layout for the strict load below.
    model.resize_token_embeddings(len(tokenizer))

    # NOTE(security): torch.load unpickles the file; only pass checkpoints
    # from a trusted source.
    state_dict = torch.load(args.weight, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    # The weights now live in the model; drop the CPU copy before merging.
    del state_dict

    # merge_and_unload only exists on a PEFT-wrapped model; with --lora_r <= 0
    # the previous unconditional call raised AttributeError.
    if use_lora:
        model = model.merge_and_unload()

    # Leave out parameters whose name contains "vision_tower".
    state_dict = {
        k: v for k, v in model.state_dict().items() if "vision_tower" not in k
    }
    model.save_pretrained(args.save_path, state_dict=state_dict)
    tokenizer.save_pretrained(args.save_path)
if __name__ == "__main__":
    # Script entry point: hand everything after the program name to main().
    cli_args = sys.argv[1:]
    main(args=cli_args)