From 0a639d590145dcafad8b34f98f456370d1b7f024 Mon Sep 17 00:00:00 2001 From: terracottahaniwa Date: Sat, 7 Sep 2024 19:31:26 +0900 Subject: [PATCH 1/3] support lora block weight --- networks/svd_merge_lora.py | 167 ++++++++++++++++++++++++++++++++++++- 1 file changed, 164 insertions(+), 3 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index cb00a6000..f5c2be19d 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -1,5 +1,8 @@ import argparse +import itertools +import json import os +import re import time import torch from safetensors.torch import load_file, save_file @@ -14,6 +17,126 @@ CLAMP_QUANTILE = 0.99 +# copied from hako-mikan/sd-webui-lora-block-weight/scripts/lora_block_weight.py +BLOCKID26=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"] +BLOCKID17=["BASE","IN01","IN02","IN04","IN05","IN07","IN08","M00","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"] +BLOCKID12=["BASE","IN04","IN05","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05"] +BLOCKID20=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08"] +BLOCKNUMS = [12,17,20,26] +BLOCKIDS=[BLOCKID12,BLOCKID17,BLOCKID20,BLOCKID26] + +BLOCKS=["encoder", # BASE +"diffusion_model_input_blocks_0_", # IN00 +"diffusion_model_input_blocks_1_", # IN01 +"diffusion_model_input_blocks_2_", # IN02 +"diffusion_model_input_blocks_3_", # IN03 +"diffusion_model_input_blocks_4_", # IN04 +"diffusion_model_input_blocks_5_", # IN05 +"diffusion_model_input_blocks_6_", # IN06 +"diffusion_model_input_blocks_7_", # IN07 +"diffusion_model_input_blocks_8_", # IN08 +"diffusion_model_input_blocks_9_", # IN09 +"diffusion_model_input_blocks_10_", # IN10 +"diffusion_model_input_blocks_11_", # IN11 +"diffusion_model_middle_block_", # M00 +"diffusion_model_output_blocks_0_", # OUT00 +"diffusion_model_output_blocks_1_", # OUT01 +"diffusion_model_output_blocks_2_", # OUT02 +"diffusion_model_output_blocks_3_", # OUT03 +"diffusion_model_output_blocks_4_", # OUT04 +"diffusion_model_output_blocks_5_", # OUT05 +"diffusion_model_output_blocks_6_", # OUT06 +"diffusion_model_output_blocks_7_", # OUT07 +"diffusion_model_output_blocks_8_", # OUT08 +"diffusion_model_output_blocks_9_", # OUT09 +"diffusion_model_output_blocks_10_", # OUT10 +"diffusion_model_output_blocks_11_", # OUT11 +"embedders", +"transformer_resblocks"] + + +def convert_diffusers_name_to_compvis(key, is_sd2): + "copied from AUTOMATIC1111/stable-diffusion-webui/extensions-builtin/Lora/networks.py" + + # put original globals here + re_digits = re.compile(r"\d+") + re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") + re_compiled = {} + + suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + } + } # end of original globals + + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, r"lora_unet_conv_in(.*)"): + return 
f'diffusion_model_input_blocks_0_0{m[0]}' + + if match(m, r"lora_unet_conv_out(.*)"): + return f'diffusion_model_out_2{m[0]}' + + if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"): + return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}" + + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" + + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" + + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" + + if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): + if is_sd2: + if 'mlp_fc1' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"): + if 'mlp_fc1' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return key + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -42,12 +165,28 @@ def save_to_file(file_name, state_dict, dtype, metadata): torch.save(state_dict, file_name) -def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): +def merge_lora_models(is_sd2, models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype): logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} v2 = None base_model = None - for model, ratio in zip(models, ratios): + + if lbws: + try: + # lbsは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している + lbws = [json.loads(lbw) for lbw in lbws] + except Exception: + raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") + assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" + assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" + assert all(len(lbw) in BLOCKNUMS for lbw in lbws), f"length of lbw are must be in {BLOCKNUMS} / 層別適用率の長さは{BLOCKNUMS}のいずれかにしてください" + assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" + + BLOCKID = BLOCKIDS[BLOCKNUMS.index(len(lbws[0]))] + conditions = [blockid in BLOCKID for blockid in BLOCKID26] + BLOCKS_ = [block for block, condition in zip(BLOCKS, conditions) if condition] + + for 
model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) @@ -63,6 +202,17 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty if "lora_down" not in key: continue + if lbw: + # keyをlora_unet_down_blocks_0_のようなdiffusers形式から、 + # diffusion_model_input_blocks_0_のようなcompvis形式に変換する + compvis_key = convert_diffusers_name_to_compvis(key, is_sd2) + + block_in_key = [block in compvis_key for block in BLOCKS_] + is_lbw_target = any(block_in_key) + if is_lbw_target: + index = [i for i, in_key in enumerate(block_in_key) if in_key][0] + lbw_weight = lbw[index] + lora_module_name = key[: key.rfind(".lora_down")] down_weight = lora_sd[key] @@ -92,6 +242,9 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty # W <- W + U * D scale = alpha / network_dim + if lbw: + if is_lbw_target: + scale *= lbw_weight # keyがlbwの対象であれば、lbwの重みを掛ける if device: # and isinstance(scale, torch.Tensor): scale = scale.to(device) @@ -170,6 +323,10 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty def merge(args): assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + if args.lbws: + assert len(args.models) == len(args.lbws), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + else: + args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく def str_to_dtype(p): if p == "float": @@ -187,7 +344,7 @@ def str_to_dtype(p): new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank state_dict, metadata, v2, base_model = merge_lora_models( - args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype + args.sd2, args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype ) logger.info(f"calculating hashes and creating metadata...") @@ -233,10 +390,14 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" ) + parser.add_argument( + "--sd2", action="store_true", help="set if LoRA models are for SD2 / マージするLoRAモデルがSD2用なら指定します" + ) parser.add_argument( "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") parser.add_argument( "--new_conv_rank", From b6d633d2b8e8cf6d1edc306f9d9f5bbbd074809e Mon Sep 17 00:00:00 2001 From: "terracottahaniwa@gmail.com" Date: Wed, 11 Sep 2024 06:17:21 +0900 Subject: [PATCH 2/3] solve license incompatibility --- networks/svd_merge_lora.py | 236 ++++++++++++++++--------------------- 1 file changed, 99 insertions(+), 137 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index f5c2be19d..903864934 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -17,125 +17,85 @@ CLAMP_QUANTILE = 0.99 -# copied from hako-mikan/sd-webui-lora-block-weight/scripts/lora_block_weight.py 
-BLOCKID26=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"] -BLOCKID17=["BASE","IN01","IN02","IN04","IN05","IN07","IN08","M00","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"] -BLOCKID12=["BASE","IN04","IN05","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05"] -BLOCKID20=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08"] -BLOCKNUMS = [12,17,20,26] -BLOCKIDS=[BLOCKID12,BLOCKID17,BLOCKID20,BLOCKID26] - -BLOCKS=["encoder", # BASE -"diffusion_model_input_blocks_0_", # IN00 -"diffusion_model_input_blocks_1_", # IN01 -"diffusion_model_input_blocks_2_", # IN02 -"diffusion_model_input_blocks_3_", # IN03 -"diffusion_model_input_blocks_4_", # IN04 -"diffusion_model_input_blocks_5_", # IN05 -"diffusion_model_input_blocks_6_", # IN06 -"diffusion_model_input_blocks_7_", # IN07 -"diffusion_model_input_blocks_8_", # IN08 -"diffusion_model_input_blocks_9_", # IN09 -"diffusion_model_input_blocks_10_", # IN10 -"diffusion_model_input_blocks_11_", # IN11 -"diffusion_model_middle_block_", # M00 -"diffusion_model_output_blocks_0_", # OUT00 -"diffusion_model_output_blocks_1_", # OUT01 -"diffusion_model_output_blocks_2_", # OUT02 -"diffusion_model_output_blocks_3_", # OUT03 -"diffusion_model_output_blocks_4_", # OUT04 -"diffusion_model_output_blocks_5_", # OUT05 -"diffusion_model_output_blocks_6_", # OUT06 -"diffusion_model_output_blocks_7_", # OUT07 -"diffusion_model_output_blocks_8_", # OUT08 -"diffusion_model_output_blocks_9_", # OUT09 -"diffusion_model_output_blocks_10_", # OUT10 -"diffusion_model_output_blocks_11_", # OUT11 -"embedders", -"transformer_resblocks"] - - -def convert_diffusers_name_to_compvis(key, is_sd2): - "copied from AUTOMATIC1111/stable-diffusion-webui/extensions-builtin/Lora/networks.py" - - # put original globals here - re_digits = re.compile(r"\d+") - re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") - re_compiled = {} - - suffix_conversion = { - "attentions": {}, - "resnets": { - "conv1": "in_layers_2", - "conv2": "out_layers_3", - "time_emb_proj": "emb_layers_1", - "conv_shortcut": "skip_connection", - } - } # end of original globals - - def match(match_list, regex_text): - regex = re_compiled.get(regex_text) - if regex is None: - regex = re.compile(regex_text) - re_compiled[regex_text] = regex - - r = re.match(regex, key) - if not r: - return False - - match_list.clear() - match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) - return True - - m = [] - - if match(m, r"lora_unet_conv_in(.*)"): - return f'diffusion_model_input_blocks_0_0{m[0]}' - - if match(m, r"lora_unet_conv_out(.*)"): - return f'diffusion_model_out_2{m[0]}' - - if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"): - return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}" - - if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) - return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - - if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) - return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" - - if match(m, 
r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) - return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - - if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): - return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" - - if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): - return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" - - if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): - if is_sd2: - if 'mlp_fc1' in m[1]: - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" - elif 'mlp_fc2' in m[1]: - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" - else: - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" - - return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" - - if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"): - if 'mlp_fc1' in m[1]: - return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" - elif 'mlp_fc2' in m[1]: - return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" - else: - return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" - - return key +ACCEPTABLE = [12, 17, 20, 26] +SDXL_LAYER_NUM = [12, 20] + +LAYER12 = { + "BASE": True, + "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True, + "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID00": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False +} + +LAYER17 = { + "BASE": True, + "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True, + "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID00": True, + "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, +} + +LAYER20 = { + "BASE": True, + "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, + "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID00": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False, +} + +layer12 = LAYER12.values() +layer17 = LAYER17.values() +layer20 = LAYER20.values() +layer26 = [True] * 26 +assert len([v for v in layer12 if v]) == 12 +assert len([v for v in layer17 if v]) == 17 +assert len([v for v in layer20 if v]) == 20 +assert len([v for v in layer26 if v]) == 26 + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +def get_block_index(lora_name: str, is_sdxl: bool = False) -> int: + block_idx = -1 # invalid lora name + if not is_sdxl: + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + i = int(g[1]) + j = int(g[3]) + if g[2] == "resnets": + idx = 3 * i + j + elif g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers" or g[2] == "downsamplers": + idx = 3 * i + 2 + + if g[0] == "down": + block_idx = 1 + idx # 0に該当するLoRAは存在しない + elif g[0] == "up": + block_idx = 
LoRANetwork.NUM_OF_BLOCKS + 1 + idx + elif "mid_block_" in lora_name: + block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + else: + # copy from sdxl_train + if lora_name.startswith("lora_unet_"): + name = lora_name[len("lora_unet_") :] + if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA + block_idx = 0 # 0 + elif name.startswith("input_blocks_"): # 1-9 + block_idx = 1 + int(name.split("_")[2]) + elif name.startswith("middle_block_"): # 10-12 + block_idx = 10 + int(name.split("_")[2]) + elif name.startswith("output_blocks_"): # 13-21 + block_idx = 13 + int(name.split("_")[2]) + elif name.startswith("out_"): # 22, out, no LoRA + block_idx = 22 + + return block_idx def load_state_dict(file_name, dtype): @@ -165,10 +125,10 @@ def save_to_file(file_name, state_dict, dtype, metadata): torch.save(state_dict, file_name) -def merge_lora_models(is_sd2, models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype): +def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype): logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} - v2 = None + v2 = None # This is meaning LoRA Metadata v2, Not meaning SD2 base_model = None if lbws: @@ -179,12 +139,18 @@ def merge_lora_models(is_sd2, models, ratios, lbws, new_rank, new_conv_rank, dev raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" - assert all(len(lbw) in BLOCKNUMS for lbw in lbws), f"length of lbw are must be in {BLOCKNUMS} / 層別適用率の長さは{BLOCKNUMS}のいずれかにしてください" + assert all(len(lbw) in ACCEPTABLE for lbw in lbws), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" - BLOCKID = BLOCKIDS[BLOCKNUMS.index(len(lbws[0]))] - conditions = [blockid in BLOCKID for blockid in BLOCKID26] - BLOCKS_ = [block for block, condition in zip(BLOCKS, conditions) if condition] + layer_num = len(lbws[0]) + FLAGS = { + "12": layer12, + "17": layer17, + "20": layer20, + "26": layer26, + }[str(layer_num)] + is_sdxl = True if layer_num in SDXL_LAYER_NUM else False + TARGET = [i for i, flag in enumerate(FLAGS) if flag] for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") @@ -196,6 +162,10 @@ def merge_lora_models(is_sd2, models, ratios, lbws, new_rank, new_conv_rank, dev if base_model is None: base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + full_lbw = [1] * 26 + for weight, index in zip(lbw, TARGET): + full_lbw[index] = weight + # merge logger.info(f"merging...") for key in tqdm(list(lora_sd.keys())): @@ -203,15 +173,10 @@ def merge_lora_models(is_sd2, models, ratios, lbws, new_rank, new_conv_rank, dev continue if lbw: - # keyをlora_unet_down_blocks_0_のようなdiffusers形式から、 - # diffusion_model_input_blocks_0_のようなcompvis形式に変換する - compvis_key = convert_diffusers_name_to_compvis(key, is_sd2) - - block_in_key = [block in compvis_key for block in BLOCKS_] - is_lbw_target = any(block_in_key) + index = 0 if "encoder" in key else get_block_index(key, is_sdxl) + is_lbw_target = index in TARGET if is_lbw_target: - index = [i for i, in_key in enumerate(block_in_key) if in_key][0] - lbw_weight = lbw[index] + lbw_weight = 
full_lbw[index] lora_module_name = key[: key.rfind(".lora_down")] @@ -344,7 +309,7 @@ def str_to_dtype(p): new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank state_dict, metadata, v2, base_model = merge_lora_models( - args.sd2, args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype + args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype ) logger.info(f"calculating hashes and creating metadata...") @@ -390,9 +355,6 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" ) - parser.add_argument( - "--sd2", action="store_true", help="set if LoRA models are for SD2 / マージするLoRAモデルがSD2用なら指定します" - ) parser.add_argument( "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" ) From 1b912775f5cfd7781d34ee010e948fa629851cf3 Mon Sep 17 00:00:00 2001 From: terracottahaniwa Date: Thu, 12 Sep 2024 19:06:39 +0900 Subject: [PATCH 3/3] Fix issue: lbw index calculation --- networks/svd_merge_lora.py | 115 +++++++++++++++++++++---------------- 1 file changed, 67 insertions(+), 48 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 903864934..6e163aecf 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -24,7 +24,7 @@ "BASE": True, "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True, "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, - "MID00": True, + "MID": True, "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False } @@ -33,7 +33,7 @@ "BASE": True, "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True, "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, - "MID00": True, + "MID": True, "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True, "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, } @@ -42,58 +42,78 @@ "BASE": True, "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, - "MID00": True, + "MID": True, "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False, } -layer12 = LAYER12.values() -layer17 = LAYER17.values() -layer20 = LAYER20.values() -layer26 = [True] * 26 -assert len([v for v in layer12 if v]) == 12 -assert len([v for v in layer17 if v]) == 17 -assert len([v for v in layer20 if v]) == 20 -assert len([v for v in layer26 if v]) == 26 +LAYER26 = { + "BASE": True, + "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, + "IN06": True, "IN07": True, "IN08": True, "IN09": True, "IN10": True, "IN11": True, + "MID": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, +} + +assert len([v for v in LAYER12.values() if v]) == 12 +assert len([v for v in LAYER17.values() if v]) == 17 +assert 
len([v for v in LAYER20.values() if v]) == 20 +assert len([v for v in LAYER26.values() if v]) == 26 RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") -def get_block_index(lora_name: str, is_sdxl: bool = False) -> int: +def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int: + # lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder + if "text_model_encoder_" in lora_name: # LoRA for text encoder + return 0 + + # lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2 block_idx = -1 # invalid lora name if not is_sdxl: + NUM_OF_BLOCKS = 12 # up/down blocks m = RE_UPDOWN.search(lora_name) if m: g = m.groups() + up_down = g[0] i = int(g[1]) j = int(g[3]) - if g[2] == "resnets": - idx = 3 * i + j - elif g[2] == "attentions": - idx = 3 * i + j - elif g[2] == "upsamplers" or g[2] == "downsamplers": - idx = 3 * i + 2 + if up_down == "down": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + 1 + elif g[2] == "downsamplers": + idx = 3 * (i + 1) + else: + return block_idx # invalid lora name + elif up_down == "up": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers": + idx = 3 * i + 2 + else: + return block_idx # invalid lora name if g[0] == "down": - block_idx = 1 + idx # 0に該当するLoRAは存在しない + block_idx = 1 + idx # 1-based index, down block index elif g[0] == "up": - block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx + block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index + elif "mid_block_" in lora_name: - block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block else: - # copy from sdxl_train if lora_name.startswith("lora_unet_"): name = lora_name[len("lora_unet_") :] - if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA - block_idx = 0 # 0 - elif name.startswith("input_blocks_"): # 1-9 + if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts + block_idx = 1 + elif name.startswith("input_blocks_"): # 1-8 to 2-9 block_idx = 1 + int(name.split("_")[2]) - elif name.startswith("middle_block_"): # 10-12 - block_idx = 10 + int(name.split("_")[2]) - elif name.startswith("output_blocks_"): # 13-21 - block_idx = 13 + int(name.split("_")[2]) - elif name.startswith("out_"): # 22, out, no LoRA - block_idx = 22 + elif name.startswith("middle_block_"): # 10 + block_idx = 10 + elif name.startswith("output_blocks_"): # 0-8 to 11-19 + block_idx = 11 + int(name.split("_")[2]) + elif name.startswith("out_"): # 20, No LoRA in sd-scripts + block_idx = 20 return block_idx @@ -133,7 +153,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer if lbws: try: - # lbsは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している + # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している lbws = [json.loads(lbw) for lbw in lbws] except Exception: raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") @@ -143,14 +163,14 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" layer_num = len(lbws[0]) + is_sdxl = True if layer_num in SDXL_LAYER_NUM else False FLAGS = { - "12": layer12, - "17": layer17, - "20": layer20, - "26": layer26, + "12": 
LAYER12.values(), + "17": LAYER17.values(), + "20": LAYER20.values(), + "26": LAYER26.values(), }[str(layer_num)] - is_sdxl = True if layer_num in SDXL_LAYER_NUM else False - TARGET = [i for i, flag in enumerate(FLAGS) if flag] + LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") @@ -162,9 +182,11 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer if base_model is None: base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) - full_lbw = [1] * 26 - for weight, index in zip(lbw, TARGET): - full_lbw[index] = weight + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + print(dict(zip(LAYER26.keys(), lbw_weights))) # merge logger.info(f"merging...") @@ -172,12 +194,6 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer if "lora_down" not in key: continue - if lbw: - index = 0 if "encoder" in key else get_block_index(key, is_sdxl) - is_lbw_target = index in TARGET - if is_lbw_target: - lbw_weight = full_lbw[index] - lora_module_name = key[: key.rfind(".lora_down")] down_weight = lora_sd[key] @@ -207,9 +223,12 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer # W <- W + U * D scale = alpha / network_dim + if lbw: + index = get_lbw_block_index(key, is_sdxl) + is_lbw_target = index in LBW_TARGET_IDX if is_lbw_target: - scale *= lbw_weight # keyがlbwの対象であれば、lbwの重みを掛ける + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける if device: # and isinstance(scale, torch.Tensor): scale = scale.to(device)
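
For reference, the block-weight mechanics added by this series can be sketched outside the patch. The snippet below is a simplified, illustrative re-implementation, not the code added here: it parses an --lbws value the way the final patch does, expands a 17-value preset onto the full 26-slot BASE/IN00-IN11/MID/OUT00-OUT11 layout with untargeted slots defaulting to 1, and looks up the multiplier that would scale a key whose LBW block index is 8 (IN07). The command line and file names in the leading comment are placeholders, and LAYER17_FLAGS simply restates the LAYER17 table from the patch.

# Hypothetical invocation (file names and weight values are placeholders):
#   python networks/svd_merge_lora.py --save_to merged.safetensors --new_rank 16 \
#       --models style_lora.safetensors --ratios 1.0 \
#       --lbws "[1,0,0,0,0,0.25,0.5,1,1,1,1,1,1,1,1,1,1]"

import json

# True/False per slot, in BASE, IN00..IN11, MID, OUT00..OUT11 order (17 True entries).
LAYER17_FLAGS = [
    True,
    False, True, True, False, True, True, False, True, True, False, False, False,
    True,
    False, False, False, True, True, True, True, True, True, True, True, True,
]

def expand_lbw(lbw_str: str) -> list:
    """Expand a 17-value --lbws string onto the 26-slot layout; untargeted slots stay at 1."""
    lbw = json.loads(lbw_str)
    assert len(lbw) == sum(LAYER17_FLAGS), "expected 17 values for this preset"
    full = [1] * 26
    targets = [i for i, flag in enumerate(LAYER17_FLAGS) if flag]
    for index, value in zip(targets, lbw):
        full[index] = value
    return full

full = expand_lbw("[1,0,0,0,0,0.25,0.5,1,1,1,1,1,1,1,1,1,1]")
# lora_unet_down_blocks_2_attentions_0_... maps to block index 8 (IN07) in the
# 1-based scheme used by get_lbw_block_index, so its merge scale is multiplied by:
print(full[8])  # 0.25

Defaulting untargeted slots to 1 mirrors the patch's behavior: blocks outside the chosen preset are merged at the plain --ratios value, and only blocks flagged in the preset are rescaled by the per-block weight.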