Skip to content

Commit 99d7fa0

Browse files
committed
fix: quantization
1 parent b2cbb20 commit 99d7fa0

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

HakaseCore/llm/llama3.py

+9-6
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import os.path
33

44
import torch
5-
from transformers import AutoModelForCausalLM, AutoTokenizer
5+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
66

77

88
class LLama3(object):
@@ -24,11 +24,14 @@ def __init__(self, accelerate_engine: str = "cuda", debug: bool = False) -> None
2424
f"{accelerate_engine} is not a valid accelerate_engine"
2525
)
2626

27-
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto")
28-
self.model = torch.quantization.quantize_dynamic(
29-
model,
30-
{torch.nn.Linear},
31-
dtype=torch.qint8,
27+
bnb_config = BitsAndBytesConfig(
28+
load_in_4bit=True,
29+
bnb_4bit_quant_type="nf4",
30+
bnb_4bit_use_double_quant=True,
31+
bnb_4bit_compute_dtype=torch.bfloat16,
32+
)
33+
self.model = AutoModelForCausalLM.from_pretrained(
34+
self.model_id, quantization_config=bnb_config, device_map="auto"
3235
)
3336
self.tokenizer = AutoTokenizer.from_pretrained(
3437
self.model_id, add_special_tokens=True

requirements.txt

+2-1
Original file line number | Diff line number | Diff line change
@@ -3,4 +3,5 @@ torch==2.2.2
33
torchvision==0.17.2
44
torchaudio==2.2.2
55
transformers==4.40.1
6-
accelerate==0.30.0
6+
accelerate==0.30.0
7+
bitsandbytes==0.43.1

0 commit comments

Comments
 (0)