From 9c782f310a64c14283e32953cb01b4e0db33e8e8 Mon Sep 17 00:00:00 2001
From: sdbds <865105819@qq.com>
Date: Fri, 28 Jun 2024 09:30:52 +0800
Subject: [PATCH 1/7] add new lr scheduler

---
 library/train_util.py | 43 ++++++++++++++++++++++++++++++++++++++++++-
 requirements.txt      |  2 +-
 2 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index d1405643c..63c3405c5 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -37,7 +37,7 @@
 from transformers import CLIPTokenizer
 import transformers
 import diffusers
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import (
     StableDiffusionPipeline,
     DDPMScheduler,
@@ -2082,6 +2082,12 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=0,
         help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
     )
+    parser.add_argument(
+        "--lr_decay_steps",
+        type=int,
+        default=0,
+        help="Number of steps for the decay in the lr scheduler (default is 0) ",
+    )
     parser.add_argument(
         "--lr_scheduler_num_cycles",
         type=int,
@@ -2094,6 +2100,18 @@
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
+    parser.add_argument(
+        "--lr_scheduler_timescale",
+        type=int,
+        default=None,
+        help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`",
+    )
+    parser.add_argument(
+        "--lr_scheduler_min_lr_ratio",
+        type=float,
+        default=None,
+        help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler",
+    )
 
 
 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -2921,8 +2939,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     name = args.lr_scheduler
     num_warmup_steps: Optional[int] = args.lr_warmup_steps
     num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps
+    num_decay_steps: Optional[int] = args.lr_decay_steps
+    num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
     num_cycles = args.lr_scheduler_num_cycles
     power = args.lr_scheduler_power
+    timescale = args.lr_scheduler_timescale
+    min_lr_ratio = args.lr_scheduler_min_lr_ratio
 
     lr_scheduler_kwargs = {}  # get custom lr_scheduler kwargs
     if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
@@ -2982,6 +3004,9 @@ def wrap_check_needless_num_warmup_steps(return_vals):
     if name == SchedulerType.CONSTANT_WITH_WARMUP:
         return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
 
+    if name == SchedulerType.INVERSE_SQRT:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale)
+
     # All other schedulers require `num_training_steps`
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
@@ -2994,6 +3019,22 @@ def wrap_check_needless_num_warmup_steps(return_vals):
     if name == SchedulerType.POLYNOMIAL:
         return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power)
 
+    if name == SchedulerType.COSINE_WITH_MIN_LR:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, min_lr_rate=min_lr_ratio)
+
+    # All other schedulers require `num_decay_steps`
+    if num_decay_steps is None:
+        raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
+    if name == SchedulerType.WARMUP_STABLE_DECAY:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_stable_steps=num_stable_steps,
+            num_decay_steps=num_decay_steps,
+            num_cycles=num_cycles,
+            min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0
+        )
+
     return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
 
 
diff --git a/requirements.txt b/requirements.txt
index debe2c789..3f55709de 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 accelerate==0.15.0
-transformers==4.26.0
+transformers==4.41.2
 ftfy==6.1.1
 albumentations==1.3.0
 opencv-python==4.7.0.68
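For orientation, PATCH 1/7 resolves the new scheduler names through `transformers.optimization.TYPE_TO_SCHEDULER_FUNCTION`, the same table its import swap points at. Below is a minimal standalone sketch of that lookup for the `WARMUP_STABLE_DECAY` branch, assuming the `transformers==4.41.2` pin from this patch; the optimizer, step counts, and min-lr ratio are placeholder values, not taken from the patch.

```python
import torch
from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION

# Placeholder optimizer and step budget, for illustration only.
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
num_training_steps = 1000
num_warmup_steps = 100  # --lr_warmup_steps
num_decay_steps = 200   # --lr_decay_steps
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps

# Same dispatch style as get_scheduler_fix: look the schedule up by enum, then call it.
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[SchedulerType.WARMUP_STABLE_DECAY]
lr_scheduler = schedule_func(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_stable_steps=num_stable_steps,
    num_decay_steps=num_decay_steps,
    min_lr_ratio=0.1,  # --lr_scheduler_min_lr_ratio
)
```

With these placeholder numbers the learning rate should ramp up for 100 steps, hold for 700, then decay toward 10% of the initial value over the final 200 steps.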
From dc6767a88601ab18401e5cc17c3f31de1cd3a597 Mon Sep 17 00:00:00 2001
From: sdbds <865105819@qq.com>
Date: Fri, 28 Jun 2024 12:54:04 +0800
Subject: [PATCH 2/7] fix bugs and use num_cycles / 2

---
 library/train_util.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index a64bffc8b..b0fe33cb8 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4359,9 +4359,6 @@ def wrap_check_needless_num_warmup_steps(return_vals):
     if name == SchedulerType.CONSTANT:
         return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
 
-    if name == SchedulerType.PIECEWISE_CONSTANT:
-        return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs
-
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
@@ -4404,7 +4401,7 @@ def wrap_check_needless_num_warmup_steps(return_vals):
             num_warmup_steps=num_warmup_steps,
             num_stable_steps=num_stable_steps,
             num_decay_steps=num_decay_steps,
-            num_cycles=num_cycles,
+            num_cycles=num_cycles / 2,
             min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
             **lr_scheduler_kwargs
         )

From 5488b5133a8da0099fd26141ffce2cad2e0b54d9 Mon Sep 17 00:00:00 2001
From: sdbds <865105819@qq.com>
Date: Fri, 28 Jun 2024 23:18:53 +0800
Subject: [PATCH 3/7] Update requirements.txt

---
 requirements.txt | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 15f5ea592..1e0a278ca 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-accelerate==0.25.0
+accelerate==0.30.0
 transformers==4.41.2
 diffusers[torch]==0.25.0
 ftfy==6.1.1
@@ -16,7 +16,7 @@ altair==4.2.2
 easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
-huggingface-hub==0.20.1
+huggingface-hub==0.23.3
 # for Image utils
 imagesize==1.4.1
 # for BLIP captioning
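The `num_cycles / 2` change in PATCH 2/7 converts `--lr_scheduler_num_cycles` into the convention used by transformers' cosine-style schedules, where `num_cycles` counts full cosine waves and `0.5` means a single smooth decay from the peak to the floor. A small sketch of that factor (same shape as the cosine term in those schedules); the function name here is only for illustration:

```python
import math


def cosine_factor(progress: float, num_cycles: float) -> float:
    # progress runs from 0.0 to 1.0 over the decay phase; num_cycles = 0.5 gives a
    # single half-cosine from 1.0 down to 0.0, while a full wave dips and recovers.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))


# With the division by 2, --lr_scheduler_num_cycles=1 is passed on as num_cycles=0.5,
# i.e. one smooth decay over the run rather than a dip back up to the initial lr.
print(cosine_factor(0.0, 0.5), cosine_factor(0.5, 0.5), cosine_factor(1.0, 0.5))  # 1.0 0.5 ~0.0
```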
From 005a232faa26564278ccba95b7a13a7373810986 Mon Sep 17 00:00:00 2001
From: sdbds <865105819@qq.com>
Date: Fri, 30 Aug 2024 23:17:57 +0800
Subject: [PATCH 4/7] add num_cycles for min lr

---
 library/train_util.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/library/train_util.py b/library/train_util.py
index b0fe33cb8..fd24d29b5 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4389,7 +4389,12 @@ def wrap_check_needless_num_warmup_steps(return_vals):
 
     if name == SchedulerType.COSINE_WITH_MIN_LR:
         return schedule_func(
-            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, min_lr_rate=min_lr_ratio, **lr_scheduler_kwargs
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_rate=min_lr_ratio,
+            **lr_scheduler_kwargs
         )
 
     # All other schedulers require `num_decay_steps`

From 545fef2ad0b33326180d8ffde94de4ba55c8bc0a Mon Sep 17 00:00:00 2001
From: sdbds <865105819@qq.com>
Date: Mon, 2 Sep 2024 00:23:25 +0800
Subject: [PATCH 5/7] keep PIECEWISE_CONSTANT

---
 library/train_util.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index fd24d29b5..05cf6f26c 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -42,6 +42,7 @@
 from torchvision import transforms
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
 import transformers
+from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION
 from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import (
     StableDiffusionPipeline,
@@ -4353,12 +4354,15 @@ def wrap_check_needless_num_warmup_steps(return_vals):
         # logger.info(f"adafactor scheduler init lr {initial_lr}")
         return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
 
-    name = SchedulerType(name)
-    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    name = SchedulerType(name) or DiffusersSchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
 
     if name == SchedulerType.CONSTANT:
         return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
 
+    if name == DiffusersSchedulerType.PIECEWISE_CONSTANT:
+        return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs
+
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
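PATCH 5/7 keeps the diffusers-only `piecewise_constant` scheduler reachable by importing diffusers' enum and lookup table under aliases next to the transformers ones. A minimal sketch of that diffusers path, assuming the `step_rules` keyword named in the restored comment; the `step_rules` string is an illustrative value (in sd-scripts it would arrive through `--lr_scheduler_args`), not something taken from the patch:

```python
import torch
from diffusers.optimization import (
    SchedulerType as DiffusersSchedulerType,
    TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
)

# Placeholder optimizer, for illustration only.
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)

# Resolve the diffusers scheduler from its table, mirroring the aliased lookup above.
schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[DiffusersSchedulerType.PIECEWISE_CONSTANT]
# Illustrative rules: lr multiplier 1.0 before step 100, 0.1 before step 200, 0.01 after.
lr_scheduler = schedule_func(optimizer, step_rules="1:100,0.1:200,0.01")
```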
From 717c379765ef0105d94ae7d6625fb91e53ee369d Mon Sep 17 00:00:00 2001
From: sdbds <865105819@qq.com>
Date: Mon, 2 Sep 2024 00:50:36 +0800
Subject: [PATCH 6/7] allow use float with warmup or decay ratio.

---
 library/train_util.py | 26 ++++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 05cf6f26c..195bf379e 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2972,6 +2972,20 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
 
 
 def add_optimizer_arguments(parser: argparse.ArgumentParser):
+    def int_or_float(value):
+        if value.endswith('%'):
+            try:
+                return float(value[:-1]) / 100.0
+            except ValueError:
+                raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
+        try:
+            float_value = float(value)
+            if float_value >= 1:
+                return int(value)
+            return float(value)
+        except ValueError:
+            raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
+
     parser.add_argument(
         "--optimizer_type",
         type=str,
@@ -3024,15 +3038,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
     )
     parser.add_argument(
         "--lr_warmup_steps",
-        type=int,
+        type=int_or_float,
         default=0,
-        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
+        help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
     )
     parser.add_argument(
         "--lr_decay_steps",
-        type=int,
+        type=int_or_float,
        default=0,
-        help="Number of steps for the decay in the lr scheduler (default is 0) ",
+        help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps",
     )
     parser.add_argument(
         "--lr_scheduler_num_cycles",
@@ -4311,9 +4325,9 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     Unified API to get any scheduler from its name.
     """
     name = args.lr_scheduler
-    num_warmup_steps: Optional[int] = args.lr_warmup_steps
     num_training_steps = args.max_train_steps * num_processes  # * args.gradient_accumulation_steps
-    num_decay_steps: Optional[int] = args.lr_decay_steps
+    num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
+    num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
     num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
     num_cycles = args.lr_scheduler_num_cycles
     power = args.lr_scheduler_power

From 416e521ab84aaf845afee51bea56e43bc63a3a80 Mon Sep 17 00:00:00 2001
From: sdbds <865105819@qq.com>
Date: Mon, 2 Sep 2024 00:58:32 +0800
Subject: [PATCH 7/7] Update train_util.py

---
 library/train_util.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 195bf379e..cf5d4f332 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4412,7 +4412,7 @@ def wrap_check_needless_num_warmup_steps(return_vals):
             num_training_steps=num_training_steps,
             num_cycles=num_cycles / 2,
             min_lr_rate=min_lr_ratio,
-            **lr_scheduler_kwargs
+            **lr_scheduler_kwargs,
         )
 
     # All other schedulers require `num_decay_steps`
@@ -4426,7 +4426,7 @@ def wrap_check_needless_num_warmup_steps(return_vals):
             num_decay_steps=num_decay_steps,
             num_cycles=num_cycles / 2,
             min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
-            **lr_scheduler_kwargs
+            **lr_scheduler_kwargs,
        )
 
     return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs)
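To make PATCH 6/7 concrete, here is a small standalone demo of `int_or_float` (reproduced from the hunk above) together with the ratio-to-steps conversion now done in `get_scheduler_fix`; the step budget is a placeholder value:

```python
import argparse


def int_or_float(value):
    # Reproduced from the patch: "10%" and "0.1" both mean a fraction of the total
    # training steps, while whole numbers of 1 or more are absolute step counts.
    if value.endswith('%'):
        try:
            return float(value[:-1]) / 100.0
        except ValueError:
            raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
    try:
        float_value = float(value)
        if float_value >= 1:
            return int(value)
        return float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")


print(int_or_float("500"))  # 500  -> used directly as a step count
print(int_or_float("0.1"))  # 0.1  -> fraction of num_training_steps
print(int_or_float("10%"))  # 0.1  -> same fraction, percent form

# Fraction -> steps, as in the patched get_scheduler_fix (placeholder step budget).
num_training_steps = 2000
lr_warmup_steps = int_or_float("0.1")
num_warmup_steps = int(lr_warmup_steps * num_training_steps) if isinstance(lr_warmup_steps, float) else lr_warmup_steps
print(num_warmup_steps)  # 200
```

Note that a bare `1` therefore parses as a single warmup step, not as 100% of training.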