@@ -67,16 +67,16 @@ def randn_tensor(self, shape: ov.Shape):
67
67
# Main Class for Flux.1 Schnell
68
68
# -------------------------------------------------------------------------
69
69
class FluxSchnell :
70
- def __init__ (self , quantize = False ):
71
- self .quantize = quantize
70
+ def __init__ (self ):
71
+ self .weight_format = os .getenv ("WEIGHT_FORMAT" , "int8" )
72
+ self .model_dir = Path (os .getenv ("MODEL_DIR" , "openvino-flux-schnell" ))
72
73
self .model_name = "black-forest-labs/FLUX.1-schnell"
73
- self .model_dir = Path ("openvino-flux-schnell" )
74
74
self .random_generator = Generator (42 )
75
75
76
76
# Automatically convert models during initialization
77
77
print ("Converting models during initialization..." )
78
78
try :
79
- self .convert_models ()
79
+ self .convert_models (self . weight_format )
80
80
print ("Model conversion completed successfully." )
81
81
except Exception as e :
82
82
print (f"Model conversion failed: { e } " )
@@ -91,14 +91,17 @@ def __init__(self, quantize=False):
91
91
raise
92
92
93
93
@log_elapsed_time
def convert_models(self, weight_format="int8"):
    """Export the model to ``self.model_dir`` unless that directory exists.

    Args:
        weight_format: Value forwarded to the export helper as the
            ``weight-format`` option. Defaults to ``"int8"``.

    The export is skipped whenever ``self.model_dir`` already exists. The
    check looks only at the directory, not at the format that produced it,
    so a notice is printed to make the reuse visible to the operator.
    """
    if self.model_dir.exists():
        # An existing directory is reused as-is; changing the weight format
        # requires a different (or removed) model directory.
        print(
            f"Model directory {self.model_dir} already exists; skipping conversion "
            f"(requested weight format '{weight_format}' is not re-applied)."
        )
        return

    print(f"Downloading model: {self.model_name} ({weight_format}) to {self.model_dir}...")

    # Options handed to the export helper. NOTE(review): presumably these map
    # to `optimum-cli export openvino` flags -- confirm against optimum_cli.
    export_options = {
        "weight-format": weight_format,
        "group-size": "64",
        "ratio": "1.0",
    }
    optimum_cli(self.model_name, self.model_dir, additional_args=export_options)
    print("Model conversion completed.")
103
106
104
107
@staticmethod
@@ -212,7 +215,7 @@ class PromptRequest(BaseModel):
212
215
213
216
class Sdv3API :
214
217
def __init__ (self ):
215
- self .installer = FluxSchnell (quantize = True )
218
+ self .installer = FluxSchnell ()
216
219
self .image_path = "tmp_output_image.png"
217
220
self .pipeline_status = {"running" : False , "completed" : False }
218
221
0 commit comments