Skip to content

Commit b2cbb20

Browse files
committed
fix: quantization
1 parent e1536be commit b2cbb20

File tree

2 files changed

+6
-13
lines changed

2 files changed

+6
-13
lines changed

HakaseCore/llm/llama3.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22
import os.path
33

44
import torch
5-
from transformers import AutoModelForCausalLM
6-
from transformers import AutoTokenizer
7-
from transformers import BitsAndBytesConfig
8-
from transformers import TextStreamer
5+
from transformers import AutoModelForCausalLM, AutoTokenizer
96

107

118
class LLama3(object):
@@ -27,14 +24,11 @@ def __init__(self, accelerate_engine: str = "cuda", debug: bool = False) -> None
2724
f"{accelerate_engine} is not a valid accelerate_engine"
2825
)
2926

30-
self.bnb_config = BitsAndBytesConfig(
31-
load_in_4bit=True,
32-
bnb_4bit_quant_type="nf4",
33-
bnb_4bit_use_double_quant=True,
34-
bnb_4bit_compute_dtype=torch.bfloat16,
35-
)
36-
self.model = AutoModelForCausalLM.from_pretrained(
37-
self.model_id, quantization_config=self.bnb_config, device_map="auto"
27+
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto")
28+
self.model = torch.quantization.quantize_dynamic(
29+
model,
30+
{torch.nn.Linear},
31+
dtype=torch.qint8,
3832
)
3933
self.tokenizer = AutoTokenizer.from_pretrained(
4034
self.model_id, add_special_tokens=True

requirements.txt

-1
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,4 @@ torch==2.2.2
33
torchvision==0.17.2
44
torchaudio==2.2.2
55
transformers==4.40.1
6-
bitsandbytes==0.43.1
76
accelerate==0.30.0

0 commit comments

Comments
 (0)