extensions-builtin/Lora/networks.py (5 additions, 1 deletion)

@@ -1,7 +1,7 @@
 import os.path
 import re
 
-from ldm_patched.modules.sd import load_lora_for_models
+from ldm_patched.modules.sd import load_lora_for_models, unpatch_models_prior_to_patch
 from ldm_patched.modules.utils import load_torch_file
 from modules import errors, scripts, sd_models, shared
 
@@ -58,6 +58,10 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
     if current_sd.current_lora_hash == compiled_lora_targets_hash:
         return
 
+    if current_sd.current_lora_hash != str([]):
+        # Persistent patches only: unpatch before patching when the LoRA hash has changed.
+        unpatch_models_prior_to_patch(current_sd.forge_objects.unet, current_sd.forge_objects.clip)
+
     current_sd.current_lora_hash = compiled_lora_targets_hash
     current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
     current_sd.forge_objects.clip = current_sd.forge_objects_original.clip
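Why the pre-unpatch matters: with PERSISTENT_PATCHES enabled, LoRA deltas stay baked into the live weights between generations, while each new LoRA selection rebuilds the patchers from forge_objects_original with empty backups. A toy sketch of the failure mode this hunk prevents (illustrative names only, not the PR's code):

# Toy illustration (not the PR's code) of why weights must be unpatched
# before re-patching when patches persist across generations.

weights = {"w": 10}            # the shared underlying weight, reduced to one int

class ToyPatcher:
    """Each LoRA selection builds a fresh patcher with an empty backup."""
    def __init__(self, weights):
        self.weights = weights
        self.backup = {}

    def patch(self, delta):
        self.backup["w"] = self.weights["w"]   # assumes the weight is pristine
        self.weights["w"] += delta

    def unpatch(self):
        if self.backup:
            self.weights["w"] = self.backup.pop("w")

p1 = ToyPatcher(weights)
p1.patch(5)                    # LoRA A: w == 15, and it persists after the run

p2 = ToyPatcher(weights)       # user switches to LoRA B -> fresh patcher
p2.patch(3)                    # backs up the already-patched 15, stacks B on A
assert weights["w"] == 18      # corrupted: should be 13

# With the pre-patch unpatch this PR adds, p1.unpatch() would run first,
# so LoRA B lands on clean weights: 10 + 3 == 13.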
ldm_patched/modules/model_patcher.py (15 additions, 3 deletions)

@@ -37,7 +37,10 @@ def require_patch(self) -> bool:
 
         return self.current == 0
 
-    def require_unpatch(self) -> bool:
+    def require_unpatch(self, is_prepatch=False) -> bool:
+        if is_prepatch:
+            return self.require_prepatch_unpatch()
+
         if not PERSISTENT_PATCHES:
             return True
 
@@ -46,6 +49,15 @@ def require_unpatch(self) -> bool:
 
         return self.current != self.updated
 
+    def require_prepatch_unpatch(self) -> bool:
+        if not PERSISTENT_PATCHES:
+            return False
+
+        if not PatchStatus.has_lora():
+            return False
+
+        return True
+
     def patch(self):
         if self.updated > 0:
             self.current = self.updated
 
@@ -463,8 +475,8 @@ def calculate_weight(self, patches, weight, key):
 
         return weight
 
-    def unpatch_model(self, device_to=None):
-        if self.backup and self.patch_status.require_unpatch():
+    def unpatch_model(self, device_to=None, is_prepatch=False):
+        if self.backup and (self.patch_status.require_unpatch(is_prepatch)):
             keys = list(self.backup.keys())
 
             if self.weight_inplace_update:
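Read together, the two predicates give this decision logic (a condensed sketch in standalone form; in the PR these are methods on the patch-status object, and the collapsed hunk lines may hold extra conditions not shown here):

# Condensed sketch of the unpatch decision after this change.
def require_unpatch(persistent_patches, has_lora, current, updated, is_prepatch=False):
    if is_prepatch:
        # Pre-patch path: unpatch only when persistent patches are enabled
        # and a LoRA is actually baked into the weights.
        return persistent_patches and has_lora
    if not persistent_patches:
        return True                  # legacy behavior: always unpatch
    return current != updated        # persistent: unpatch only if the patch set changed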
ldm_patched/modules/sd.py (4 additions, 0 deletions)

@@ -69,6 +69,10 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename):
 
     return model, clip
 
+def unpatch_models_prior_to_patch(model, clip):
+    model.unpatch_model(model.offload_device, is_prepatch=True)
+    clip.patcher.unpatch_model(clip.patcher.offload_device, is_prepatch=True)
+
 
 class CLIP:
     def __init__(self, target=None, embedding_directory=None, no_init=False):
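For reference, a hypothetical call site mirroring how networks.py uses the new helper (the forge_objects names come from the first hunk):

# Hypothetical call site, mirroring the networks.py hunk above.
unpatch_models_prior_to_patch(
    current_sd.forge_objects.unet,   # a ModelPatcher
    current_sd.forge_objects.clip,   # CLIP wrapper; the helper reaches clip.patcher
)
# Internally each model runs unpatch_model(offload_device, is_prepatch=True),
# which restores backed-up weights only when require_unpatch(is_prepatch=True)
# reports a persistent LoRA patch is present.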