import torch.distributed as dist

import deepspeed
-from transformers import AutoConfig, AutoTokenizer
+from huggingface_hub import try_to_load_from_cache
+from transformers import AutoConfig

from ..utils import print_rank_n, run_rank_n
-from .model import Model, get_downloaded_model_path, get_hf_model_class, load_tokenizer
+from .model import Model, get_hf_model_class


# basic DeepSpeed inference model class for benchmarking
@@ -24,26 +25,23 @@ def __init__(self, args: Namespace) -> None:
        world_size = int(os.getenv("WORLD_SIZE", "1"))

-        downloaded_model_path = get_downloaded_model_path(args.model_name)
-
-        self.tokenizer = load_tokenizer(downloaded_model_path)
-        self.pad = self.tokenizer.pad_token_id
-
        # create dummy tensors for allocating space which will be filled with
        # the actual weights while calling deepspeed.init_inference in the
        # following code
        with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
            self.model = get_hf_model_class(args.model_class).from_config(
-                AutoConfig.from_pretrained(downloaded_model_path), torch_dtype=torch.bfloat16
+                AutoConfig.from_pretrained(args.model_name), torch_dtype=torch.bfloat16
            )
        self.model = self.model.eval()

+        downloaded_model_path = get_model_path(args.model_name)
+
        if args.dtype in [torch.float16, torch.int8]:
            # We currently support the weights provided by microsoft (which are
            # pre-sharded)
-            if args.use_pre_sharded_checkpoints:
-                checkpoints_json = os.path.join(downloaded_model_path, "ds_inference_config.json")
+            checkpoints_json = os.path.join(downloaded_model_path, "ds_inference_config.json")

+            if os.path.isfile(checkpoints_json):
                self.model = deepspeed.init_inference(
                    self.model,
                    mp_size=world_size,
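For readers unfamiliar with the pattern in the OnDevice block above: constructing the model on the "meta" device allocates only shape and dtype metadata, so no real weight memory is touched until deepspeed.init_inference later streams in the checkpoint. A minimal standalone sketch of that pattern (the model id bigscience/bloom and the AutoModelForCausalLM class are illustrative assumptions, not taken from this diff):

    # Standalone sketch of meta-device construction (illustrative model id,
    # not part of this diff): parameters stay unmaterialized until DeepSpeed
    # later fills them in from real checkpoint shards.
    import torch
    import deepspeed
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("bigscience/bloom")  # assumed model id
    with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
    # here model's parameters live on the "meta" device and hold no data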
@@ -60,6 +58,7 @@ def __init__(self, args: Namespace) -> None:
                self.model = deepspeed.init_inference(
                    self.model,
                    mp_size=world_size,
+                    base_dir=downloaded_model_path,
                    dtype=args.dtype,
                    checkpoint=checkpoints_json,
                    replace_with_kernel_inject=True,
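The base_dir addition above matters because the shard filenames listed in ds_inference_config.json are resolved relative to a root directory; base_dir tells DeepSpeed which root to use now that the path comes out of the Hugging Face cache. A usage sketch, with placeholder paths that are assumptions rather than anything from this repo:

    # Usage sketch only: wiring a pre-sharded checkpoint into init_inference.
    # "/some/cache/snapshot" stands in for the directory get_model_path returns.
    model = deepspeed.init_inference(
        model,                            # meta-device model built earlier
        mp_size=world_size,               # tensor-parallel degree (WORLD_SIZE)
        base_dir="/some/cache/snapshot",  # root for relative shard paths in the json
        dtype=torch.float16,
        checkpoint="/some/cache/snapshot/ds_inference_config.json",
        replace_with_kernel_inject=True,  # use DeepSpeed's fused inference kernels
    )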
@@ -74,6 +73,8 @@ def __init__(self, args: Namespace) -> None:
        print_rank_n("Model loaded")
        dist.barrier()

+        self.post_init(args.model_name)
+

class TemporaryCheckpointsJSON:
    def __init__(self, model_path: str):
@@ -93,3 +94,16 @@ def __enter__(self):
    def __exit__(self, type, value, traceback):
        return
+
+
+def get_model_path(model_name: str):
+    config_file = "config.json"
+
+    # will fall back to HUGGINGFACE_HUB_CACHE
+    config_path = try_to_load_from_cache(model_name, config_file, cache_dir=os.getenv("TRANSFORMERS_CACHE"))
+
+    if config_path is not None:
+        return os.path.dirname(config_path)
+    # treat the model name as an explicit model path
+    elif os.path.isfile(os.path.join(model_name, config_file)):
+        return model_name
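Semantics of the new helper above, for reference: huggingface_hub.try_to_load_from_cache returns the cached file's path when the repo has already been downloaded into the local cache, and None when nothing is cached; passing cache_dir=None falls back to the default hub cache, which is why an unset TRANSFORMERS_CACHE is harmless. A usage sketch with an illustrative model id:

    # Usage sketch (model id is illustrative): resolve the local snapshot
    # directory of an already-downloaded model without touching the network.
    import os
    from huggingface_hub import try_to_load_from_cache

    config_path = try_to_load_from_cache(
        "bigscience/bloom",                         # assumed model id
        "config.json",
        cache_dir=os.getenv("TRANSFORMERS_CACHE"),  # None -> default hub cache
    )
    if config_path is not None:
        print(os.path.dirname(config_path))         # snapshot dir with all model files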