Skip to content

Commit cd5fa97

Browse files
committed
ability to get LoRA weights in SDLoraManager
1 parent fb90b00 commit cd5fa97

File tree

1 file changed

+19
-0
lines changed
  • src/refiners/foundationals/latent_diffusion

1 file changed

+19
-0
lines changed

src/refiners/foundationals/latent_diffusion/lora.py

+19
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,25 @@ def add_loras(
9292
# set the scale of the LoRA
9393
self.set_scale(name, scale)
9494

95+
def _get_lora_weights(self, base: fl.Chain, name: str, accum: dict[str, Tensor]) -> None:
    """Collect the up/down weights of the LoRA layers named ``name`` found under ``base``.

    Walks ``base`` for ``LoraAdapter`` instances and, for each adapter holding a
    LoRA layer whose ``name`` matches, stores its ``down`` and ``up`` weights in
    ``accum`` (mutated in place). Keys follow the pattern
    ``"<parent_path>.<i>.<TargetClassName>.{down,up}.weight"``, where ``<i>`` is a
    1-based counter that restarts whenever the parent chain changes.

    Args:
        base: the chain to search for adapters.
        name: the LoRA name to match against each adapter's layers.
        accum: mapping updated in place with the discovered weight tensors.
    """
    previous_parent: fl.Chain | None = None
    index = 0
    for lora_adapter, parent in base.walk(LoraAdapter):
        lora = next((layer for layer in lora_adapter.lora_layers if layer.name == name), None)
        if lora is None:
            continue
        # 1-based position of this adapter within its parent; reset when the
        # parent changes. (Replaces the fragile `cond and a or b` idiom.)
        index = index + 1 if parent == previous_parent else 1
        prefix = f"{parent.get_path()}.{index}.{lora_adapter.target.__class__.__name__}"
        accum[f"{prefix}.down.weight"] = lora.down.weight
        accum[f"{prefix}.up.weight"] = lora.up.weight
        previous_parent = parent
107+
108+
def get_lora_weights(self, name: str) -> dict[str, Tensor]:
    """Return the weights of every LoRA layer registered under ``name``.

    Both the UNet and the CLIP text encoder are scanned, and the results are
    merged into a single mapping keyed by layer path.
    """
    weights: dict[str, Tensor] = {}
    for model in (self.unet, self.clip_text_encoder):
        self._get_lora_weights(model, name, weights)
    return weights
113+
95114
def add_loras_to_text_encoder(
96115
self,
97116
loras: dict[str, Lora[Any]],

0 commit comments

Comments
 (0)