train.py (forked from e-p-armstrong/amadeus)
from datasets import Dataset
import os
import torch
import torch.nn as nn
import datasets
import bitsandbytes as bb
from transformers import AutoTokenizer, LlamaForCausalLM, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import json
from peft import LoraConfig, get_peft_model
import transformers
from make_card_evanchat import make_card_evanchat_daru, make_card_evanchat_faris, make_card_evanchat_kurisu, make_card_evanchat_luka, make_card_evanchat_mayuri, make_card_evanchat_okabe, make_card_evanchat_suzuha  # Evanchat is my own take on character cards: instead of PLists we use normal English, and we also list the character archetypes at the top of the card.
import random
from determine_perspective import determine_perspective

os.environ["TOKENIZERS_PARALLELISM"] = "false"
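
# Each entry in final_dataset.json is expected to look roughly like the following
# (shape inferred from the parsing below; the values here are made up):
# {
#     "history": "Okabe: some line\nKurisu: another line",
#     "completion": "the reply the model should learn to produce",
#     "scenario": "short description of the scene",
#     "speaker": "Kurisu"
# }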
def parse_json_file(json_file_path):
    # Initialize an empty list to store the reformatted dictionaries
    reformatted_list = []
    # Read the JSON file
    with open(json_file_path, 'r') as f:
        data = json.load(f)
    # Iterate through each dictionary in the list
    for entry in data:
        # Initialize an empty list to store the parsed 'history' field
        parsed_history = []
        # Split the 'history' string by the newline character to get each line
        lines = entry['history'].split('\n')
        # Iterate through each line to parse the speaker and the line
        for line in lines:
            if ': ' in line:  # Checking if the line actually contains dialogue
                speaker, dialogue = line.split(': ', 1)  # Split at the first occurrence of ': '
                parsed_history.append((speaker, dialogue))
        # Create a new dictionary with the parsed 'history' field
        new_entry = {
            'history': parsed_history,
            'completion': entry['completion'],
            'scenario': entry['scenario'],
            'speaker': entry['speaker']
        }
        # Append the new dictionary to the reformatted list
        reformatted_list.append(new_entry)
    return reformatted_list
json_file_path = "final_dataset.json"
reformatted_data = parse_json_file(json_file_path)
print(reformatted_data[7])
# New dataset code:
def format_chat_history(chat_history, speaker):
    return '\n'.join([f'### Response:\n#### {speaker}: {line}' if s == speaker else f'### Instruction:\n#### {s}: {line}' for s, line in chat_history])
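
# Example (hypothetical lines): for speaker "Kurisu", the history
# [("Okabe", "Hey, assistant."), ("Kurisu", "Don't call me that.")] is rendered as:
# ### Instruction:
# #### Okabe: Hey, assistant.
# ### Response:
# #### Kurisu: Don't call me that.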

card_dataset = []
for ex in reformatted_data:  # make a version of the dataset with the first card
    fp = "first person" == determine_perspective(ex["speaker"], ex["completion"])
    if ex["speaker"] == "Kurisu":
        card_dataset.append(make_card_evanchat_kurisu(ex["scenario"], format_chat_history(ex["history"], ex["speaker"]), ex["completion"], fp))
    elif ex["speaker"] == "Itaru":
        card_dataset.append(make_card_evanchat_daru(ex["scenario"], format_chat_history(ex["history"], ex["speaker"]), ex["completion"], fp))
    elif ex["speaker"] == "Mayuri":
        card_dataset.append(make_card_evanchat_mayuri(ex["scenario"], format_chat_history(ex["history"], ex["speaker"]), ex["completion"], fp))
    elif ex["speaker"] == "Faris":
        card_dataset.append(make_card_evanchat_faris(ex["scenario"], format_chat_history(ex["history"], ex["speaker"]), ex["completion"], fp))
    elif ex["speaker"] == "Okabe":
        card_dataset.append(make_card_evanchat_okabe(ex["scenario"], format_chat_history(ex["history"], ex["speaker"]), ex["completion"], fp))
    elif ex["speaker"] == "Luka":
        card_dataset.append(make_card_evanchat_luka(ex["scenario"], format_chat_history(ex["history"], ex["speaker"]), ex["completion"], fp))
    elif ex["speaker"] == "Suzuha":
        card_dataset.append(make_card_evanchat_suzuha(ex["scenario"], format_chat_history(ex["history"], ex["speaker"]), ex["completion"], fp))
    else:
        print("\n\n\nERROR unrecognized char: " + ex["speaker"] + "\nFIX THIS\n\n\n")
# for ex in reformatted_data: # make a version with the second, so that the model doesn't learn to predict correctly when given only a very specific type of card (experiment)
# card_dataset.append(make_card_bullets(ex["scenario"], format_chat_history(ex["history"]), ex["completion"]))
# Convert the list of card strings into a Huggingface Dataset
dataset = Dataset.from_list(card_dataset)
print(dataset,"\n\n\n")
# Sort datasets by length so that if longer examples cause memory issues, it'll happen first, and we can fix it without wasting time
# dataset = dataset.map(lambda example: {"text": example["text"], "length": len(example["text"])})
# dataset = dataset.sort("length", reverse=True)
tokenizer = AutoTokenizer.from_pretrained("Gryphe/MythoMax-L2-13b", max_length=4000, padding_side="right")
# tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # Note, do not do this, it will break the embedding and cause a hard-to-fix error
tokenizer.pad_token_id = tokenizer.eos_token_id
# add eos token to training data
dataset = dataset.map(lambda example: {"text": example["text"] + tokenizer.eos_token})
dataset = dataset.train_test_split(test_size=0.05)
print(dataset)
print(dataset["train"][0]["text"])
# don't forget pip install -U git+https://github.com/lvwerra/trl
# Model time!
# Sillytavern response template: "### Response (2 paragraphs, engaging, natural, authentic, descriptive, creative):
# ####"
response_template = [
    2277, 29937, 13291, 313, 29906, 14880, 29879, 29892, 3033, 6751, 29892,
    5613, 29892, 15585, 29892, 29037, 573, 29892, 907, 1230, 1125, 13, 4136,
]
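# The hard-coded IDs above are the Sillytavern response header (see the comment before the
# list) as tokenized by the MythoMax / Llama-2 tokenizer; the collator uses them to locate
# where the completion starts. If you swap tokenizers, re-derive them and compare by eye,
# e.g. with the commented-out helper below (note that "###" may tokenize differently at the
# very start of a string than it does mid-text, so the first ID or two can differ):
# print(tokenizer.encode("### Response (2 paragraphs, engaging, natural, authentic, descriptive, creative):\n####", add_special_tokens=False))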
# print("\n\n\n====================\n\n\n")
# print(type(response_template), response_template)
# print("\n\n\n====================\n\n\n")
# Uncomment this and the corresponding line in the SFTTrainer below to do completion-only training
# This is the only problem besides OOM, which will be solved by using vast.ai
# No prompt dropout this time, because I want to vary only one thing at a time
collator = DataCollatorForCompletionOnlyLM(
    # instruction_template="You are an expert roleplaying model",  # If I have a response template I don't think I *need* this part. Probably.
    response_template=response_template,
    tokenizer=tokenizer,
    mlm=False
)
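# The completion-only collator masks the labels of every token up to and including the
# response_template match, so the loss is computed only on the character's reply rather than
# on the card, scenario, or chat history.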
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
    bnb_4bit_compute_dtype=torch.float16,
)
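# Rough memory estimate (an assumption, not a measurement): 13B parameters at 4 bits is about
# 6.5 GB of weights; double quantization also compresses the quantization constants, and the
# actual matmuls run in fp16 per bnb_4bit_compute_dtype.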
base_model = LlamaForCausalLM.from_pretrained(
    "Gryphe/MythoMax-L2-13b",
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True,
    force_download=True,
    resume_download=False
)
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"
                    # "rotary_emb"  # idk what this even is, so I'm hesitant to LoRA it. Try it later?
                    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",  # the weird index issue was solved by correctly specifying the task type in CAPS
)
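# With r=64 and lora_alpha=16, peft's default scaling for the LoRA update is alpha / r = 0.25,
# and adapters are attached to every attention and MLP projection listed in target_modules.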
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
model.enable_input_require_grads() # sometimes prevents an error for some reason
# model.gradient_checkpointing_enable()
training_args = TrainingArguments(
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=32,
    gradient_checkpointing=True,
    learning_rate=1e-4,
    num_train_epochs=2,
    save_strategy="epoch",
    # save_steps=len(reformatted_data),  # save every time we go through the dataset once, not through the dataset 2x
    logging_steps=1,
    fp16=True,
    output_dir="outputs",
    per_device_train_batch_size=2,
)
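# Effective batch size: per_device_train_batch_size * gradient_accumulation_steps = 2 * 32 = 64
# sequences per optimizer step, per device.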
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    # data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    data_collator=collator,
    max_seq_length=4000,
    dataset_text_field="text",
)
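# Note: passing tokenizer, max_seq_length, and dataset_text_field straight to SFTTrainer
# matches the trl version this script was written against; newer trl releases move several of
# these into an SFTConfig, so pin trl if you hit unexpected-keyword errors.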
trainer.train()
trainer.save_model("Kakkokari-13b-mythomax")