Skip to content

Commit 5d784be

Browse files
committed
add test for "Adapting SDXL" guide
1 parent cd5fa97 commit 5d784be

9 files changed

+354
-14
lines changed

docs/guides/adapting_sdxl/index.md

+14-14
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ Then, define the inference parameters by setting the appropriate prompt / seed /
5858
prompt = "a futuristic castle surrounded by a forest, mountains in the background"
5959
seed = 42
6060
sdxl.set_inference_steps(50, first_step=0)
61-
sdxl.set_self_attention_guidance(
62-
enable=True, scale=0.75
63-
) # Enable self-attention guidance to enhance the quality of the generated images
61+
62+
# Enable self-attention guidance to enhance the quality of the generated images
63+
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
6464

6565
# ... Inference process
6666

@@ -76,10 +76,10 @@ with no_grad(): # Disable gradient calculation for memory-efficient inference
7676
)
7777
time_ids = sdxl.default_time_ids
7878

79-
manual_seed(seed=seed)
79+
manual_seed(seed)
8080

81-
# Using a higher latents inner dim to improve resolution of generated images
82-
x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype)
81+
# SDXL typically generates 1024x1024, here we use a higher resolution.
82+
x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)
8383

8484
# Diffusion process
8585
for step in sdxl.steps:
@@ -131,8 +131,8 @@ predicted_image.save("vanilla_sdxl.png")
131131

132132
manual_seed(seed=seed)
133133

134-
# Using a higher latents inner dim to improve resolution of generated images
135-
x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype)
134+
# SDXL typically generates 1024x1024, here we use a higher resolution.
135+
x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)
136136

137137
# Diffusion process
138138
for step in sdxl.steps:
@@ -213,8 +213,8 @@ manager.add_loras("scifi-lora", tensors=scifi_lora_weights)
213213

214214
manual_seed(seed=seed)
215215

216-
# Using a higher latents inner dim to improve resolution of generated images
217-
x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype)
216+
# SDXL typically generates 1024x1024, here we use a higher resolution.
217+
x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)
218218

219219
# Diffusion process
220220
for step in sdxl.steps:
@@ -304,8 +304,8 @@ manager.add_loras("pixel-art-lora", load_from_safetensors("pixel-art-xl-v1.1.saf
304304

305305
manual_seed(seed=seed)
306306

307-
# Using a higher latents inner dim to improve resolution of generated images
308-
x = torch.randn(size=(1, 4, 256, 256), device=sdxl.device, dtype=sdxl.dtype)
307+
# SDXL typically generates 1024x1024, here we use a higher resolution.
308+
x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)
309309

310310
# Diffusion process
311311
for step in sdxl.steps:
@@ -440,7 +440,7 @@ with torch.no_grad():
440440
ip_adapter.set_clip_image_embedding(clip_image_embedding)
441441

442442
manual_seed(seed=seed)
443-
x = torch.randn(size=(1, 4, 128, 128), device=sdxl.device, dtype=sdxl.dtype)
443+
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
444444

445445
# Diffusion process
446446
for step in sdxl.steps:
@@ -578,7 +578,7 @@ with torch.no_grad():
578578
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
579579

580580
manual_seed(seed=seed)
581-
x = torch.randn(size=(1, 4, 128, 128), device=sdxl.device, dtype=sdxl.dtype)
581+
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
582582

583583
# Diffusion process
584584
for step in sdxl.steps:

scripts/prepare_test_weights.py

+14
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,20 @@ def download_loras():
253253
)
254254
download_file("https://sliders.baulab.info/weights/xl_sliders/eyesize.pt", dest_folder, expected_hash="ee170e4d")
255255

256+
dest_folder = os.path.join(test_weights_dir, "loras")
257+
download_file(
258+
"https://civitai.com/api/download/models/140624",
259+
filename="Sci-fi_Environments_sdxl.safetensors",
260+
dest_folder=dest_folder,
261+
expected_hash="6a4afda8",
262+
)
263+
download_file(
264+
"https://civitai.com/api/download/models/135931",
265+
filename="pixel-art-xl-v1.1.safetensors",
266+
dest_folder=dest_folder,
267+
expected_hash="71aaa6ca",
268+
)
269+
256270

257271
def download_preprocessors():
258272
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")

0 commit comments

Comments
 (0)