
Commit e83dd4c

fix: quantization
1 parent 99d7fa0 commit e83dd4c

File tree

1 file changed: +12 −4 lines changed


HakaseCore/llm/llama3.py

@@ -2,7 +2,12 @@
 import os.path
 
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    pipeline,
+)
 
 
 class LLama3(object):
@@ -30,14 +35,17 @@ def __init__(self, accelerate_engine: str = "cuda", debug: bool = False) -> None
             bnb_4bit_use_double_quant=True,
             bnb_4bit_compute_dtype=torch.bfloat16,
         )
-        self.model = AutoModelForCausalLM.from_pretrained(
+        self.model_4bit = AutoModelForCausalLM.from_pretrained(
             self.model_id, quantization_config=bnb_config, device_map="auto"
         )
         self.tokenizer = AutoTokenizer.from_pretrained(
             self.model_id, add_special_tokens=True
         )
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.tokenizer.padding_side = "right"
+        self.pipe = pipeline(
+            "text-generation", model=self.model_4bit, tokenizer=self.tokenizer
+        )
 
     def load_prompt(self) -> list[dict[str, str]]:
         # Get Hakase Project Path
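Read together, the two hunks above leave __init__ loading the model in 4-bit through bitsandbytes and wrapping it in a Hugging Face text-generation pipeline. Below is a minimal standalone sketch of that loading path; the model id is a placeholder, and load_in_4bit plus the nf4 quant type are assumptions, since the diff context shows only the last two BitsAndBytesConfig fields.

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder, not from the diff

# 4-bit quantization config; the last two fields match the diff context,
# while load_in_4bit and the nf4 quant type are assumed.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized model and its tokenizer, mirroring __init__.
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=bnb_config, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id, add_special_tokens=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Wrap both in a text-generation pipeline, as the new self.pipe does.
pipe = pipeline("text-generation", model=model_4bit, tokenizer=tokenizer)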
@@ -55,10 +63,10 @@ def generate_instruction(self, instruction: str) -> None:
 
     def generate_text(self, instruction: str) -> str:
         self.generate_instruction(instruction=instruction)
-        prompt = self.tokenizer.apply_chat_template(
+        prompt = self.pipe.tokenizer.apply_chat_template(
            self.prompt, tokenize=False, add_generation_prompt=True
         )
-        outputs = self.model.generate(
+        outputs = self.pipe(
             prompt,
             do_sample=True,
             temperature=0.4,
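The last hunk routes generation through the pipeline instead of calling self.model.generate on raw text. A hedged usage sketch of that path, assuming the class's prompt is a chat-style message list and that the kwargs truncated from the diff include a token limit such as max_new_tokens:

# Build the prompt string exactly as generate_text now does.
messages = [{"role": "user", "content": "Who are you?"}]  # example messages
prompt = pipe.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Sampling settings mirror the diff; max_new_tokens is an assumption,
# since the hunk is cut off after temperature=0.4.
outputs = pipe(
    prompt,
    do_sample=True,
    temperature=0.4,
    max_new_tokens=256,
)
print(outputs[0]["generated_text"])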
