From f99fe281cbb6519b7b5f1199c570d496ad4df474 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 1 Apr 2024 15:38:26 -0400 Subject: [PATCH 1/5] Add LoRA+ support --- library/train_util.py | 2 ++ networks/dylora.py | 45 ++++++++++++++++++++++++++---------- networks/lora.py | 54 ++++++++++++++++++++++++++++--------------- train_network.py | 2 +- 4 files changed, 71 insertions(+), 32 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d2b69edb5..4e5ab7370 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2789,6 +2789,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") + parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): diff --git a/networks/dylora.py b/networks/dylora.py index 637f33450..a73ade8bd 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -406,27 +406,48 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): logger.info(f"weights are merged") """ - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): self.requires_grad_(True) all_params = [] - def enumerate_params(loras): - params = [] + def assemble_params(loras, lr, lora_plus_ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - params.extend(lora.parameters()) + for name, param in lora.named_parameters(): + if lora_plus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + # assigned_param_groups = "" + # for group in param_groups: + # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" + # logger.info(assigned_param_groups) + + params = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + if lr is not None: + if key == "plus": + param_data["lr"] = lr * lora_plus_ratio + else: + param_data["lr"] = lr + + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + + params.append(param_data) + return params if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + all_params.extend(params) if self.unet_loras: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + all_params.extend(params) return all_params diff --git a/networks/lora.py b/networks/lora.py index 948b30b0e..8d7619777 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1035,21 +1035,43 @@ def get_lr_weight(self, lora: LoRAModule) -> float: return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + def 
prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): self.requires_grad_(True) all_params = [] - def enumerate_params(loras): - params = [] + def assemble_params(loras, lr, lora_plus_ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - params.extend(lora.parameters()) + for name, param in lora.named_parameters(): + if lora_plus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + # assigned_param_groups = "" + # for group in param_groups: + # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" + # logger.info(assigned_param_groups) + + params = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + if lr is not None: + if key == "plus": + param_data["lr"] = lr * lora_plus_ratio + else: + param_data["lr"] = lr + + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + + params.append(param_data) + return params if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + all_params.extend(params) if self.unet_loras: if self.block_lr: @@ -1063,21 +1085,15 @@ def enumerate_params(loras): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - param_data = {"params": enumerate_params(block_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) + params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) elif default_lr is not None: - param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) - if ("lr" in param_data) and (param_data["lr"] == 0): - continue - all_params.append(param_data) + params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + all_params.extend(params) else: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + all_params.extend(params) return all_params diff --git a/train_network.py b/train_network.py index e0fa69458..ba0c124d1 100644 --- a/train_network.py +++ b/train_network.py @@ -339,7 +339,7 @@ def train(self, args): # 後方互換性を確保するよ try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio) except TypeError: accelerator.print( "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" From c7691607ea1647864b5149c98434a27f23386c65 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 1 Apr 2024 15:43:04 -0400 Subject: [PATCH 2/5] Add LoRA-FA for LoRA+ --- networks/lora_fa.py | 58 +++++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 919222ce8..fcc503e89 100644 --- a/networks/lora_fa.py +++ 
b/networks/lora_fa.py @@ -1033,22 +1033,43 @@ def get_lr_weight(self, lora: LoRAModule) -> float: return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, , unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): self.requires_grad_(True) all_params = [] - def enumerate_params(loras: List[LoRAModule]): - params = [] + def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - # params.extend(lora.parameters()) - params.extend(lora.get_trainable_params()) + for name, param in lora.get_trainable_named_params(): + if lora_plus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + # assigned_param_groups = "" + # for group in param_groups: + # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" + # logger.info(assigned_param_groups) + + params = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + if lr is not None: + if key == "plus": + param_data["lr"] = lr * lora_plus_ratio + else: + param_data["lr"] = lr + + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + + params.append(param_data) + return params if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + all_params.extend(params) if self.unet_loras: if self.block_lr: @@ -1062,21 +1083,15 @@ def enumerate_params(loras: List[LoRAModule]): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - param_data = {"params": enumerate_params(block_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) + params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) elif default_lr is not None: - param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) - if ("lr" in param_data) and (param_data["lr"] == 0): - continue - all_params.append(param_data) + params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + all_params.extend(params) else: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + all_params.extend(params) return all_params @@ -1093,6 +1108,9 @@ def on_epoch_start(self, text_encoder, unet): def get_trainable_params(self): return self.parameters() + def get_trainable_named_params(self): + return self.named_parameters() + def save_weights(self, file, dtype, metadata): if metadata is not None and len(metadata) == 0: metadata = None From 1933ab4b4848b1f8b578c10f25bd050f5e246ac0 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 3 Apr 2024 12:46:34 -0400 Subject: [PATCH 3/5] Fix default_lr being applied --- networks/dylora.py | 21 ++++++++++++++++++--- networks/lora.py | 30 +++++++++++++++++++++++------- networks/lora_fa.py | 30 +++++++++++++++++++++++------- 3 files changed, 64 insertions(+), 17 deletions(-) diff --git a/networks/dylora.py 
b/networks/dylora.py index a73ade8bd..edc3e2229 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -407,7 +407,14 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): """ # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): + def prepare_optimizer_params( + self, + text_encoder_lr, + unet_lr, + default_lr, + unet_lora_plus_ratio=None, + text_encoder_lora_plus_ratio=None + ): self.requires_grad_(True) all_params = [] @@ -442,11 +449,19 @@ def assemble_params(loras, lr, lora_plus_ratio): return params if self.text_encoder_loras: - params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + params = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + text_encoder_lora_plus_ratio + ) all_params.extend(params) if self.unet_loras: - params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + params = assemble_params( + self.unet_loras, + default_lr if unet_lr is None else unet_lr, + unet_lora_plus_ratio + ) all_params.extend(params) return all_params diff --git a/networks/lora.py b/networks/lora.py index 8d7619777..e082941e5 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1035,7 +1035,14 @@ def get_lr_weight(self, lora: LoRAModule) -> float: return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): + def prepare_optimizer_params( + self, + text_encoder_lr, + unet_lr, + default_lr, + unet_lora_plus_ratio=None, + text_encoder_lora_plus_ratio=None + ): self.requires_grad_(True) all_params = [] @@ -1070,7 +1077,11 @@ def assemble_params(loras, lr, lora_plus_ratio): return params if self.text_encoder_loras: - params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + params = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + text_encoder_lora_plus_ratio + ) all_params.extend(params) if self.unet_loras: @@ -1085,14 +1096,19 @@ def assemble_params(loras, lr, lora_plus_ratio): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - if unet_lr is not None: - params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) - elif default_lr is not None: - params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + params = assemble_params( + block_loras, + (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), + unet_lora_plus_ratio + ) all_params.extend(params) else: - params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + params = assemble_params( + self.unet_loras, + default_lr if unet_lr is None else unet_lr, + unet_lora_plus_ratio + ) all_params.extend(params) return all_params diff --git a/networks/lora_fa.py b/networks/lora_fa.py index fcc503e89..3f6774dd8 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1033,7 +1033,14 @@ def get_lr_weight(self, lora: LoRAModule) -> float: return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, , unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): + def prepare_optimizer_params( + self, + 
text_encoder_lr, + unet_lr, + default_lr, + unet_lora_plus_ratio=None, + text_encoder_lora_plus_ratio=None + ): self.requires_grad_(True) all_params = [] @@ -1068,7 +1075,11 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): return params if self.text_encoder_loras: - params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + params = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + text_encoder_lora_plus_ratio + ) all_params.extend(params) if self.unet_loras: @@ -1083,14 +1094,19 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - if unet_lr is not None: - params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) - elif default_lr is not None: - params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + params = assemble_params( + block_loras, + (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), + unet_lora_plus_ratio + ) all_params.extend(params) else: - params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + params = assemble_params( + self.unet_loras, + default_lr if unet_lr is None else unet_lr, + unet_lora_plus_ratio + ) all_params.extend(params) return all_params From 75833e84a1c7e3c2fb0a9e3ce0fe3d8c1758a012 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 8 Apr 2024 19:23:02 -0400 Subject: [PATCH 4/5] Fix default LR, Add overall LoRA+ ratio, Add log `--loraplus_ratio` added for both TE and UNet Add log for lora+ --- library/train_util.py | 1 + networks/dylora.py | 24 ++++++------- networks/lora.py | 28 ++++++++-------- networks/lora_fa.py | 30 ++++++++--------- train_network.py | 78 ++++++++++++++++++++++++++++++++----------- 5 files changed, 101 insertions(+), 60 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 4e5ab7370..7c2bf6935 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2789,6 +2789,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") diff --git a/networks/dylora.py b/networks/dylora.py index edc3e2229..dc5c7cb35 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -412,32 +412,32 @@ def prepare_optimizer_params( text_encoder_lr, unet_lr, default_lr, - unet_lora_plus_ratio=None, - text_encoder_lora_plus_ratio=None + unet_loraplus_ratio=None, + text_encoder_loraplus_ratio=None, + loraplus_ratio=None ): self.requires_grad_(True) all_params = [] - def assemble_params(loras, lr, lora_plus_ratio): + def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): - if lora_plus_ratio is not None and "lora_up" in name: + if ratio is not None and "lora_B" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param - # assigned_param_groups = "" - # for group in 
param_groups: - # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" - # logger.info(assigned_param_groups) - params = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + if lr is not None: if key == "plus": - param_data["lr"] = lr * lora_plus_ratio + param_data["lr"] = lr * ratio else: param_data["lr"] = lr @@ -452,7 +452,7 @@ def assemble_params(loras, lr, lora_plus_ratio): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_lora_plus_ratio + text_encoder_loraplus_ratio or loraplus_ratio ) all_params.extend(params) @@ -460,7 +460,7 @@ def assemble_params(loras, lr, lora_plus_ratio): params = assemble_params( self.unet_loras, default_lr if unet_lr is None else unet_lr, - unet_lora_plus_ratio + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) diff --git a/networks/lora.py b/networks/lora.py index e082941e5..6cb05bcb0 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1040,32 +1040,32 @@ def prepare_optimizer_params( text_encoder_lr, unet_lr, default_lr, - unet_lora_plus_ratio=None, - text_encoder_lora_plus_ratio=None + unet_loraplus_ratio=None, + text_encoder_loraplus_ratio=None, + loraplus_ratio=None ): self.requires_grad_(True) all_params = [] - def assemble_params(loras, lr, lora_plus_ratio): + def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): - if lora_plus_ratio is not None and "lora_up" in name: + if ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param - # assigned_param_groups = "" - # for group in param_groups: - # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" - # logger.info(assigned_param_groups) - params = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + if lr is not None: if key == "plus": - param_data["lr"] = lr * lora_plus_ratio + param_data["lr"] = lr * ratio else: param_data["lr"] = lr @@ -1080,7 +1080,7 @@ def assemble_params(loras, lr, lora_plus_ratio): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_lora_plus_ratio + text_encoder_loraplus_ratio or loraplus_ratio ) all_params.extend(params) @@ -1099,15 +1099,15 @@ def assemble_params(loras, lr, lora_plus_ratio): params = assemble_params( block_loras, (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - unet_lora_plus_ratio + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) else: params = assemble_params( self.unet_loras, - default_lr if unet_lr is None else unet_lr, - unet_lora_plus_ratio + unet_lr if unet_lr is not None else default_lr, + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 3f6774dd8..2eff86d6c 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1038,32 +1038,32 @@ def prepare_optimizer_params( text_encoder_lr, unet_lr, default_lr, - unet_lora_plus_ratio=None, - text_encoder_lora_plus_ratio=None + unet_loraplus_ratio=None, + text_encoder_loraplus_ratio=None, + loraplus_ratio=None ): self.requires_grad_(True) all_params = [] - def assemble_params(loras: 
List[LoRAModule], lr, lora_plus_ratio): + def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: - for name, param in lora.get_trainable_named_params(): - if lora_plus_ratio is not None and "lora_up" in name: + for name, param in lora.named_parameters(): + if ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param - # assigned_param_groups = "" - # for group in param_groups: - # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" - # logger.info(assigned_param_groups) - params = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + if lr is not None: if key == "plus": - param_data["lr"] = lr * lora_plus_ratio + param_data["lr"] = lr * ratio else: param_data["lr"] = lr @@ -1078,7 +1078,7 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_lora_plus_ratio + text_encoder_loraplus_ratio or loraplus_ratio ) all_params.extend(params) @@ -1097,15 +1097,15 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): params = assemble_params( block_loras, (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - unet_lora_plus_ratio + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) else: params = assemble_params( self.unet_loras, - default_lr if unet_lr is None else unet_lr, - unet_lora_plus_ratio + unet_lr if unet_lr is not None else default_lr, + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) diff --git a/train_network.py b/train_network.py index ba0c124d1..43226fc47 100644 --- a/train_network.py +++ b/train_network.py @@ -66,34 +66,69 @@ def generate_step_logs( lrs = lr_scheduler.get_last_lr() - if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block) - if args.network_train_unet_only: - logs["lr/unet"] = float(lrs[0]) - elif args.network_train_text_encoder_only: - logs["lr/textencoder"] = float(lrs[0]) - else: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder - - if ( - args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() - ): # tracking d*lr value of unet. 
- logs["lr/d*lr"] = ( - lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - ) - else: + if len(lrs) > 4: idx = 0 if not args.network_train_unet_only: logs["lr/textencoder"] = float(lrs[0]) idx = 1 for i in range(idx, len(lrs)): - logs[f"lr/group{i}"] = float(lrs[i]) + lora_plus = "" + group_id = i + + if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: + lora_plus = '_lora+' if i % 2 == 1 else '' + group_id = int((i / 2) + (i % 2 + 0.5)) + + logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i]) if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): - logs[f"lr/d*lr/group{i}"] = ( + logs[f"lr/d*lr/group{group_id}{lora_plus}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) + else: + if args.network_train_text_encoder_only: + if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/textencoder_lora+"] = float(lrs[1]) + else: + logs["lr/textencoder"] = float(lrs[0]) + + elif args.network_train_unet_only: + if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: + logs["lr/unet"] = float(lrs[0]) + logs["lr/unet_lora+"] = float(lrs[1]) + else: + logs["lr/unet"] = float(lrs[0]) + else: + if len(lrs) == 2: + if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/textencoder_lora+"] = float(lrs[1]) + elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None: + logs["lr/unet"] = float(lrs[0]) + logs["lr/unet_lora+"] = float(lrs[1]) + elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None: + logs["lr/all"] = float(lrs[0]) + logs["lr/all_lora+"] = float(lrs[1]) + else: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/unet"] = float(lrs[-1]) + elif len(lrs) == 4: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/textencoder_lora+"] = float(lrs[1]) + logs["lr/unet"] = float(lrs[2]) + logs["lr/unet_lora+"] = float(lrs[3]) + else: + logs["lr/all"] = float(lrs[0]) + + if ( + args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() + ): # tracking d*lr value of unet. 
+ logs["lr/d*lr"] = ( + lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + ) + return logs def assert_extra_args(self, args, train_dataset_group): @@ -339,7 +374,7 @@ def train(self, args): # 後方互換性を確保するよ try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio) except TypeError: accelerator.print( "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" @@ -348,6 +383,11 @@ def train(self, args): optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: + assert ( + (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name) + ), "LoRA+ and Prodigy/DAdaptation is not supported" + # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers From 68467bdf4d76ba2c57289209b0ffd6ba599e2080 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 11 Apr 2024 17:33:19 -0400 Subject: [PATCH 5/5] Fix unset or invalid LR from making a param_group --- networks/dylora.py | 4 ++-- networks/lora.py | 5 +++-- networks/lora_fa.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index dc5c7cb35..0546fc7ae 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -412,8 +412,8 @@ def prepare_optimizer_params( text_encoder_lr, unet_lr, default_lr, - unet_loraplus_ratio=None, text_encoder_loraplus_ratio=None, + unet_loraplus_ratio=None, loraplus_ratio=None ): self.requires_grad_(True) @@ -441,7 +441,7 @@ def assemble_params(loras, lr, ratio): else: param_data["lr"] = lr - if ("lr" in param_data) and (param_data["lr"] == 0): + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: continue params.append(param_data) diff --git a/networks/lora.py b/networks/lora.py index 6cb05bcb0..d74608fea 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1040,8 +1040,8 @@ def prepare_optimizer_params( text_encoder_lr, unet_lr, default_lr, - unet_loraplus_ratio=None, text_encoder_loraplus_ratio=None, + unet_loraplus_ratio=None, loraplus_ratio=None ): self.requires_grad_(True) @@ -1069,7 +1069,8 @@ def assemble_params(loras, lr, ratio): else: param_data["lr"] = lr - if ("lr" in param_data) and (param_data["lr"] == 0): + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + print("NO LR skipping!") continue params.append(param_data) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 2eff86d6c..9a608118a 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1038,8 +1038,8 @@ def prepare_optimizer_params( text_encoder_lr, unet_lr, default_lr, - unet_loraplus_ratio=None, text_encoder_loraplus_ratio=None, + unet_loraplus_ratio=None, loraplus_ratio=None ): self.requires_grad_(True) @@ -1067,7 +1067,7 @@ def assemble_params(loras, lr, ratio): else: param_data["lr"] = lr - if ("lr" in param_data) and (param_data["lr"] == 0): + if param_data.get("lr", None) == 0 or 
param_data.get("lr", None) is None: continue params.append(param_data)