Commit f7e13b4

feat: apply 4bit quantization
1 parent 994cece commit f7e13b4

File tree

1 file changed: +31 −11 lines changed


HakaseCore/llm/llama3.py

+31 −11
@@ -2,7 +2,12 @@
 import os.path
 
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    TextStreamer,
+)
 
 
 class LLama3(object):
@@ -24,10 +29,20 @@ def __init__(self, accelerate_engine: str = "cuda", debug: bool = False) -> None
                 f"{accelerate_engine} is not a valid accelerate_engine"
             )
 
-        self.model = AutoModelForCausalLM.from_pretrained(self.model_id).to(
-            self.accelerate_engine
+        self.bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_compute_dtype=torch.bfloat16,
         )
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_id, quantization_config=self.bnb_config, device_map="auto"
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_id, add_special_tokens=True
+        )
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.padding_side = "right"
         self.streamer = TextStreamer(
             self.tokenizer, skip_prompt=True, skip_special_tokens=True
         )
@@ -48,13 +63,18 @@ def generate_instruction(self, instruction: str) -> None:
 
     def generate_text(self, instruction: str) -> str:
         self.generate_instruction(instruction=instruction)
-        inputs = self.tokenizer.apply_chat_template(
-            self.prompt, tokenize=True, return_tensors="pt"
-        ).to(self.accelerate_engine)
+        prompt = self.tokenizer.apply_chat_template(
+            self.prompt, tokenize=False, add_generation_prompt=True
+        )
         outputs = self.model.generate(
-            inputs,
+            prompt,
             streamer=self.streamer,
-            max_new_tokens=1024,
-            pad_token_id=self.tokenizer.eos_token_id,
+            do_sample=True,
+            temperature=0.4,
+            top_p=0.9,
+            eos_token_id=[
+                self.tokenizer.eos_token_id,
+                self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+            ],
         )
-        print(outputs)
+        print(outputs[0]["generated_text"][len(prompt) :])
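
For context, a minimal standalone sketch of the 4-bit NF4 loading and generation pattern this commit applies. It is not part of the commit: the model id below is a placeholder, and it assumes transformers, accelerate, and bitsandbytes are installed with a CUDA GPU available. Unlike the diff above, the sketch tokenizes the chat-template output before calling generate(), since generate() consumes token ids rather than a prompt string.

    # Sketch: load a causal LM in 4-bit NF4 and generate from a chat prompt.
    # Assumes transformers, accelerate, and bitsandbytes are installed and a GPU is available.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder, not taken from this repo

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # quantize weights to 4 bits at load time
        bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
        bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
        bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bf16
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=bnb_config, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    messages = [{"role": "user", "content": "Say hello in one sentence."}]
    # generate() expects token ids, so tokenize the chat-template output here.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    output_ids = model.generate(
        input_ids,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.4,
        top_p=0.9,
    )
    print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))

With NF4 plus double quantization, the weights of an 8B-parameter model occupy roughly 4–5 GB of GPU memory, compared with about 16 GB in fp16.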
