Skip to content

Commit 9755601

Browse files
committed
improve docstrings
1 parent 5d784be commit 9755601

File tree

2 files changed

+28
-7
lines changed
  • src/refiners

2 files changed

+28
-7
lines changed

src/refiners/fluxion/adapters/lora.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -486,10 +486,11 @@ def auto_attach_loras(
486486
"""Auto-attach several LoRA layers to a Chain.
487487
488488
Args:
489-
loras: A dictionary of LoRA layers associated to their respective key.
489+
loras: A dictionary of LoRA layers associated with their respective keys. The keys are typically
490+
derived from the state dict and only used for `debug_map` and the return value.
490491
target: The target Chain.
491-
include: A list of layer names, only layers with such a layer in its parents will be considered.
492-
exclude: A list of layer names, layers with such a layer in its parents will not be considered.
492+
include: A list of layer names; only layers with such a layer in their ancestors will be considered.
493+
exclude: A list of layer names; layers with such a layer in their ancestors will not be considered.
493494
sanity_check: Check that LoRAs passed are correctly attached.
494495
debug_map: Pass a list to get a debug mapping of key–path pairs of attachment points.
495496
Returns:
@@ -507,7 +508,7 @@ def auto_attach_loras(
507508
f"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed"
508509
)
509510

510-
# Sanity check: if we re-run the attach, all layers should fail.
511+
# Extra sanity check: if we re-run the attach, all layers should fail.
511512
debug_map_2: list[tuple[str, str]] = []
512513
failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2)
513514
if debug_map_2 or len(failed_keys_2) != len(loras):

src/refiners/foundationals/latent_diffusion/lora.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,19 @@ def add_loras(
6161
name: The name of the LoRA.
6262
tensors: The `state_dict` of the LoRA to load.
6363
scale: The scale to use for the LoRA.
64+
unet_inclusions: A list of layer names; only layers with such a layer
65+
in their ancestors will be considered when patching the UNet.
66+
unet_exclusions: A list of layer names; layers with such a layer in
67+
their ancestors will not be considered when patching the UNet.
68+
If this is `None` then it defaults to `["TimestepEncoder"]`.
69+
unet_preprocess: A map between parts of state dict keys and layer names.
70+
This is used to attach some keys to specific parts of the UNet.
71+
You should leave it set to `None` (it has a default value),
72+
otherwise read the source code to understand how it works.
73+
text_encoder_inclusions: A list of layer names; only layers with such a layer
74+
in their ancestors will be considered when patching the text encoder.
75+
text_encoder_exclusions: A list of layer names; layers with such a layer in
76+
their ancestors will not be considered when patching the text encoder.
6477
6578
Raises:
6679
AssertionError: If the Manager already has a LoRA with the same name.
@@ -117,15 +130,22 @@ def add_loras_to_text_encoder(
117130
/,
118131
include: list[str] | None = None,
119132
exclude: list[str] | None = None,
133+
debug_map: list[tuple[str, str]] | None = None,
120134
) -> None:
121-
"""Add multiple LoRAs to the text encoder.
135+
"""Add multiple LoRAs to the text encoder. See `add_loras` for details about arguments.
122136
123137
Args:
124138
loras: The dictionary of LoRAs to add to the text encoder.
125139
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
126140
"""
127141
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
128-
auto_attach_loras(text_encoder_loras, self.clip_text_encoder, exclude=exclude, include=include)
142+
auto_attach_loras(
143+
text_encoder_loras,
144+
self.clip_text_encoder,
145+
exclude=exclude,
146+
include=include,
147+
debug_map=debug_map,
148+
)
129149

130150
def add_loras_to_unet(
131151
self,
@@ -136,7 +156,7 @@ def add_loras_to_unet(
136156
preprocess: dict[str, str] | None = None,
137157
debug_map: list[tuple[str, str]] | None = None,
138158
) -> None:
139-
"""Add multiple LoRAs to the U-Net.
159+
"""Add multiple LoRAs to the U-Net. See `add_loras` for details about arguments.
140160
141161
Args:
142162
loras: The dictionary of LoRAs to add to the U-Net.

0 commit comments

Comments
 (0)