diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index cb00a6000..6e163aecf 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -1,5 +1,8 @@ import argparse +import itertools +import json import os +import re import time import torch from safetensors.torch import load_file, save_file @@ -14,6 +17,106 @@ CLAMP_QUANTILE = 0.99 +ACCEPTABLE = [12, 17, 20, 26] +SDXL_LAYER_NUM = [12, 20] + +LAYER12 = { + "BASE": True, + "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True, + "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False +} + +LAYER17 = { + "BASE": True, + "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True, + "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID": True, + "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, +} + +LAYER20 = { + "BASE": True, + "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, + "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False, +} + +LAYER26 = { + "BASE": True, + "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, + "IN06": True, "IN07": True, "IN08": True, "IN09": True, "IN10": True, "IN11": True, + "MID": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, +} + +assert len([v for v in LAYER12.values() if v]) == 12 +assert len([v for v in LAYER17.values() if v]) == 17 +assert len([v for v in LAYER20.values() if v]) == 20 +assert len([v for v in LAYER26.values() if v]) == 26 + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int: + # lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder + if "text_model_encoder_" in lora_name: # LoRA for text encoder + return 0 + + # lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2 + block_idx = -1 # invalid lora name + if not is_sdxl: + NUM_OF_BLOCKS = 12 # up/down blocks + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + up_down = g[0] + i = int(g[1]) + j = int(g[3]) + if up_down == "down": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + 1 + elif g[2] == "downsamplers": + idx = 3 * (i + 1) + else: + return block_idx # invalid lora name + elif up_down == "up": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers": + idx = 3 * i + 2 + else: + return block_idx # invalid lora name + + if g[0] == "down": + block_idx = 1 + idx # 1-based index, down block index + elif g[0] == "up": + block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index + + elif "mid_block_" in lora_name: + block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block + else: + if lora_name.startswith("lora_unet_"): + name = lora_name[len("lora_unet_") :] + if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts + block_idx = 1 + elif name.startswith("input_blocks_"): # 1-8 to 2-9 + block_idx = 1 + int(name.split("_")[2]) + elif name.startswith("middle_block_"): # 10 + block_idx = 10 + elif name.startswith("output_blocks_"): # 0-8 to 11-19 + block_idx = 11 + int(name.split("_")[2]) + elif name.startswith("out_"): # 20, No LoRA in sd-scripts + block_idx = 20 + + return block_idx + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -42,12 +145,34 @@ def save_to_file(file_name, state_dict, dtype, metadata): torch.save(state_dict, file_name) -def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): +def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype): logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} - v2 = None + v2 = None # This is meaning LoRA Metadata v2, Not meaning SD2 base_model = None - for model, ratio in zip(models, ratios): + + if lbws: + try: + # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している + lbws = [json.loads(lbw) for lbw in lbws] + except Exception: + raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") + assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" + assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" + assert all(len(lbw) in ACCEPTABLE for lbw in lbws), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" + assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" + + layer_num = len(lbws[0]) + is_sdxl = True if layer_num in SDXL_LAYER_NUM else False + FLAGS = { + "12": LAYER12.values(), + "17": LAYER17.values(), + "20": LAYER20.values(), + "26": LAYER26.values(), + }[str(layer_num)] + LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) @@ -57,6 +182,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty if base_model is None: base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + print(dict(zip(LAYER26.keys(), lbw_weights))) + # merge logger.info(f"merging...") for key in tqdm(list(lora_sd.keys())): @@ -93,6 +224,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty # W <- W + U * D scale = alpha / network_dim + if lbw: + index = get_lbw_block_index(key, is_sdxl) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + if device: # and isinstance(scale, torch.Tensor): scale = scale.to(device) @@ -170,6 +307,10 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty def merge(args): assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + if args.lbws: + assert len(args.models) == len(args.lbws), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + else: + args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく def str_to_dtype(p): if p == "float": @@ -187,7 +328,7 @@ def str_to_dtype(p): new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank state_dict, metadata, v2, base_model = merge_lora_models( - args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype + args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype ) logger.info(f"calculating hashes and creating metadata...") @@ -237,6 +378,7 @@ def setup_parser() -> argparse.ArgumentParser: "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") parser.add_argument( "--new_conv_rank",