-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
42 lines (35 loc) · 1.62 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

from config import Config
from dpo import DPO
from model.model import Model
from model.reference_model import ReferenceModel
from utils.data_load import CustomDataset
class TrainDpo:
    """Driver for DPO (Direct Preference Optimization) training.

    Wires together the policy model, a frozen reference model, the
    preference dataset, and the optimizer, then runs the training loop
    and checkpoints the LoRA adapter weights after each epoch.
    """

    def __init__(self):
        self.config = Config()
        # Policy model being optimized (exposes the tokenizer as well)
        self.model = Model(self.config)
        self.tokenizer = self.model.tokenizer
        # Optimizer over the policy parameters; LoRA is used here, so only
        # the adapter weights are trainable — the base weights stay frozen
        self.model_opt = Adam(self.model.parameters(), lr=self.config.lr)
        # Reference model: frozen baseline the DPO loss compares against
        self.reference_model = ReferenceModel(self.config)
        # Training data (preference pairs; collation handled by the dataset)
        dataset = CustomDataset(self.config.data_path, self.tokenizer)
        self.data_loader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True,
                                      collate_fn=dataset.collate_fn)
        self.dpo = DPO(self.model, self.model_opt, self.config)

    def train_dpo(self):
        """Run the DPO training loop for ``config.epochs`` epochs.

        Saves the LoRA adapter weights at the end of every epoch.
        """
        for epoch in range(self.config.epochs):
            for batch_data in self.data_loader:
                # The reference logits are constants in the DPO loss, so no
                # gradients are needed here. Fix: the original ran this
                # forward pass with autograd enabled, building an unused
                # computation graph through the reference model every step.
                with torch.no_grad():
                    ref_logits = self.reference_model(batch_data["inputs_ids"],
                                                      batch_data["inputs_masks"])
                self.dpo.train(batch_data["inputs_ids"], batch_data["inputs_masks"], ref_logits,
                               batch_data["labels_mask"])
            # Checkpoint after each epoch
            self.save_model()

    def save_model(self):
        """Persist only the LoRA adapter weights to ``config.save_lora_path``."""
        self.model.model.save_pretrained(self.config.save_lora_path, safe_serialization=False)
if __name__ == '__main__':
    # Script entry point: build the trainer and kick off DPO training.
    trainer = TrainDpo()
    trainer.train_dpo()