Skip to content

Commit 8eca8c7

Browse files
authored
Merge pull request #189 from intel/update-branch
refactor: update FluxSchnell initialization and model conversion process (#503)
2 parents 5a91600 + cb4e267 commit 8eca8c7

File tree

1 file changed

+12
-9
lines changed
  • usecases/ai/microservices/text-to-image/flux-schnell/backend

1 file changed

+12
-9
lines changed

usecases/ai/microservices/text-to-image/flux-schnell/backend/server.py

+12 -9
Original file line number | Diff line number | Diff line change
@@ -67,16 +67,16 @@ def randn_tensor(self, shape: ov.Shape):
6767
# Main Class for Flux.1 Schnell
6868
# -------------------------------------------------------------------------
6969
class FluxSchnell:
70-
def __init__(self, quantize=False):
71-
self.quantize = quantize
70+
def __init__(self):
71+
self.weight_format = os.getenv("WEIGHT_FORMAT", "int8")
72+
self.model_dir = Path(os.getenv("MODEL_DIR", "openvino-flux-schnell"))
7273
self.model_name = "black-forest-labs/FLUX.1-schnell"
73-
self.model_dir = Path("openvino-flux-schnell")
7474
self.random_generator = Generator(42)
7575

7676
# Automatically convert models during initialization
7777
print("Converting models during initialization...")
7878
try:
79-
self.convert_models()
79+
self.convert_models(self.weight_format)
8080
print("Model conversion completed successfully.")
8181
except Exception as e:
8282
print(f"Model conversion failed: {e}")
@@ -91,14 +91,17 @@ def __init__(self, quantize=False):
9191
raise
9292

9393
@log_elapsed_time
94-
def convert_models(self):
94+
def convert_models(self, weight_format="int8"):
9595
"""Convert PyTorch models to OpenVINO IR (if not already done)."""
9696
if not self.model_dir.exists():
97-
print(f"Downloading model: {self.model_name} to {self.model_dir}...")
97+
print(f"Downloading model: {self.model_name} ({weight_format}) to {self.model_dir}...")
9898
additional_args = {}
99-
additional_args.update({"weight-format": "int8", "group-size": "64", "ratio": "1.0"})
99+
additional_args.update({
100+
"weight-format": weight_format,
101+
"group-size": "64",
102+
"ratio": "1.0"
103+
})
100104
optimum_cli(self.model_name, self.model_dir, additional_args=additional_args)
101-
# optimum_cli(self.model_name, self.model_dir)
102105
print("Model conversion completed.")
103106

104107
@staticmethod
@@ -212,7 +215,7 @@ class PromptRequest(BaseModel):
212215

213216
class Sdv3API:
214217
def __init__(self):
215-
self.installer = FluxSchnell(quantize=True)
218+
self.installer = FluxSchnell()
216219
self.image_path = "tmp_output_image.png"
217220
self.pipeline_status = {"running": False, "completed": False}
218221

0 commit comments

Comments
 (0)