import os.path

import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+     TextStreamer,
+ )


class LLama3(object):
@@ -24,10 +29,20 @@ def __init__(self, accelerate_engine: str = "cuda", debug: bool = False) -> None
                f"{accelerate_engine} is not a valid accelerate_engine"
            )

-         self.model = AutoModelForCausalLM.from_pretrained(self.model_id).to(
-             self.accelerate_engine
+         self.bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_compute_dtype=torch.bfloat16,
        )
-         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             self.model_id, quantization_config=self.bnb_config, device_map="auto"
+         )
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.model_id, add_special_tokens=True
+         )
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+         self.tokenizer.padding_side = "right"

        self.streamer = TextStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
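For reference, the same 4-bit loading path as a standalone sketch (the model id below is a stand-in for whatever `self.model_id` resolves to in this repo; `bitsandbytes` and `accelerate` must be installed for `load_in_4bit` and `device_map="auto"` to work):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# NF4 4-bit weight quantization with double quantization of the
# quantization constants; matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# device_map="auto" lets accelerate place the quantized weights,
# so no manual .to(device) call is needed afterwards.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",  # assumed model id, for illustration
    quantization_config=bnb_config,
    device_map="auto",
)
print(model.get_memory_footprint())  # roughly a quarter of the fp16 footprint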
@@ -48,13 +63,23 @@ def generate_instruction(self, instruction: str) -> None:

    def generate_text(self, instruction: str) -> str:
        self.generate_instruction(instruction=instruction)
-         inputs = self.tokenizer.apply_chat_template(
-             self.prompt, tokenize=True, return_tensors="pt"
-         ).to(self.accelerate_engine)
+         input_ids = self.tokenizer.apply_chat_template(
+             self.prompt, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+         ).to(self.model.device)
        outputs = self.model.generate(
-             inputs,
+             input_ids,
            streamer=self.streamer,
            max_new_tokens=1024,
-             pad_token_id=self.tokenizer.eos_token_id,
+             do_sample=True,
+             temperature=0.4,
+             top_p=0.9,
+             eos_token_id=[
+                 self.tokenizer.eos_token_id,
+                 self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+             ],
        )
-         print(outputs)
+         completion = self.tokenizer.decode(
+             outputs[0][input_ids.shape[-1] :], skip_special_tokens=True
+         )
+         print(completion)
+         return completion
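A minimal usage sketch of the patched class (assuming the rest of the file defines `self.model_id` and builds `self.prompt` inside `generate_instruction`, as the context lines suggest):

llm = LLama3(accelerate_engine="cuda")
# Tokens stream to stdout via TextStreamer during generation;
# the decoded completion is printed and returned at the end.
text = llm.generate_text("Summarize what NF4 quantization does.")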