diff --git a/fine_tune.py b/fine_tune.py
index 875a91951..c5e97d267 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -108,6 +108,7 @@ def train(args):

     # mixed precisionに対応した型を用意しておき適宜castする
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
+    vae_dtype = torch.float32 if args.no_half_vae else weight_dtype

     # モデルを読み込む
     text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -158,7 +159,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

     # 学習を準備する
     if cache_latents:
-        vae.to(accelerator.device, dtype=weight_dtype)
+        vae.to(accelerator.device, dtype=vae_dtype)
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
@@ -191,7 +192,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     if not cache_latents:
         vae.requires_grad_(False)
         vae.eval()
-        vae.to(accelerator.device, dtype=weight_dtype)
+        vae.to(accelerator.device, dtype=vae_dtype)

     for m in training_models:
         m.requires_grad_(True)
@@ -218,7 +219,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         batch_size=1,
         shuffle=True,
         collate_fn=collator,
-        num_workers=n_workers,
+        num_workers=n_workers if not args.deepspeed else 1,  # to avoid "RuntimeError: DataLoader worker exited unexpectedly with exit code 1"
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -230,7 +231,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         accelerator.print(
             f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
         )
-
+    # データセット側にも学習ステップを送信 / also send the training step count to the dataset side
     train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -246,13 +247,28 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         unet.to(weight_dtype)
         text_encoder.to(weight_dtype)

-    # acceleratorがなんかよろしくやってくれるらしい
-    if args.train_text_encoder:
-        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-        )
-    else:
-        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+    if args.deepspeed:
+        training_models_dict = {}
+        training_models_dict["unet"] = unet
+        if args.train_text_encoder:
+            training_models_dict["text_encoder"] = text_encoder
+
+        ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
+        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            ds_model, optimizer, train_dataloader, lr_scheduler
+        )
+
+        training_models = []
+        unet = ds_model.models["unet"]
+        training_models.append(unet)
+        if args.train_text_encoder:
+            text_encoder = ds_model.models["text_encoder"]
+            training_models.append(text_encoder)
+
+    else:  # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
+        if args.train_text_encoder:
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+            )
+        else:
+            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

     # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
     if args.full_fp16:
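Note: in the DeepSpeed branch above, every trainable model is folded into a single nn.Module before accelerator.prepare, because the Accelerate/DeepSpeed integration builds one engine per prepared model; the originals are then recovered by key. A minimal, self-contained sketch of that wrap/unwrap round-trip, with toy Linear layers standing in for the real U-Net and text encoder (Wrapper mirrors the DeepSpeedWrapper added to library/train_util.py below):

    import torch

    class Wrapper(torch.nn.Module):
        # mirrors DeepSpeedWrapper: one ModuleDict holding every trainable model
        def __init__(self, **models) -> None:
            super().__init__()
            self.models = torch.nn.ModuleDict(models)

    unet = torch.nn.Linear(4, 4)          # stand-in for the real U-Net
    text_encoder = torch.nn.Linear(4, 4)  # stand-in for the real text encoder
    wrapped = Wrapper(unet=unet, text_encoder=text_encoder)

    # after accelerator.prepare(wrapped), the originals are looked up by key:
    unet = wrapped.models["unet"]
    text_encoder = wrapped.models["text_encoder"]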
diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index 1932bf881..a29013e34 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -24,7 +24,6 @@


 def load_target_model(args, accelerator, model_version: str, weight_dtype):
-    # load models for each process
     model_dtype = match_mixed_precision(args, weight_dtype)  # prepare fp16/bf16
     for pi in range(accelerator.state.num_processes):
         if pi == accelerator.state.local_process_index:
diff --git a/library/train_util.py b/library/train_util.py
index b71e4edc6..3781dcde8 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -20,7 +20,8 @@
     Tuple,
     Union,
 )
-from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
+from accelerate import DeepSpeedPlugin
 import glob
 import math
 import os
@@ -3242,6 +3243,52 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
     )

+    # DeepSpeed arguments: https://huggingface.co/docs/accelerate/usage_guides/deepspeed
+    parser.add_argument("--deepspeed", action="store_true", help="enable DeepSpeed training")
+    parser.add_argument(
+        "--zero_stage",
+        type=int,
+        default=2,
+        choices=[0, 1, 2, 3],
+        help="ZeRO optimization stage. Possible options are 0, 1, 2, 3.",
+    )
+    parser.add_argument(
+        "--offload_optimizer_device",
+        type=str,
+        default=None,
+        choices=[None, "cpu", "nvme"],
+        help="device to offload optimizer states to. Possible options are none|cpu|nvme. Only applicable with ZeRO stages 2 and 3.",
+    )
+    parser.add_argument(
+        "--offload_optimizer_nvme_path",
+        type=str,
+        default=None,
+        help="path to the NVMe device for optimizer offloading, e.g. /nvme or /local_nvme. Only applicable with ZeRO stage 3.",
+    )
+    parser.add_argument(
+        "--offload_param_device",
+        type=str,
+        default=None,
+        choices=[None, "cpu", "nvme"],
+        help="device to offload parameters to. Possible options are none|cpu|nvme. Only applicable with ZeRO stage 3.",
+    )
+    parser.add_argument(
+        "--offload_param_nvme_path",
+        type=str,
+        default=None,
+        help="path to the NVMe device for parameter offloading, e.g. /nvme or /local_nvme. Only applicable with ZeRO stage 3.",
+    )
+    parser.add_argument(
+        "--zero3_init_flag",
+        action="store_true",
+        help="flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models. "
+        "Only applicable with ZeRO stage 3.",
+    )
+    parser.add_argument(
+        "--zero3_save_16bit_model",
+        action="store_true",
+        help="flag to indicate whether to save a 16-bit model. Only applicable with ZeRO stage 3.",
+    )
+    parser.add_argument(
+        "--fp16_master_weights_and_gradients",
+        action="store_true",
+        help="keep fp16 master weights and gradients while keeping the optimizer states in fp32; "
+        "requires an optimizer that supports it (DeepSpeedCPUAdam with ZeRO-2 CPU offload).",
+    )

 def verify_training_args(args: argparse.Namespace):
     r"""
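Note: the new flags map onto Accelerate's DeepSpeedPlugin. DeepSpeed additionally requires that the global train_batch_size in its config equal the per-GPU micro batch times gradient accumulation times world size, which prepare_deepspeed_plugin (below) computes from the WORLD_SIZE environment variable. Illustrative arithmetic with assumed values:

    train_batch_size_per_gpu = 4     # --train_batch_size
    gradient_accumulation_steps = 2  # --gradient_accumulation_steps
    world_size = 2                   # number of processes; read from the WORLD_SIZE env var at runtime

    global_train_batch_size = train_batch_size_per_gpu * gradient_accumulation_steps * world_size
    assert global_train_batch_size == 16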
@@ -4088,6 +4135,8 @@ def prepare_accelerator(args: argparse.Namespace):
         ),
     )
     kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
+    deepspeed_plugin = prepare_deepspeed_plugin(args)
+
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
@@ -4095,10 +4144,67 @@ def prepare_accelerator(args: argparse.Namespace):
         project_dir=logging_dir,
         kwargs_handlers=kwargs_handlers,
         dynamo_backend=dynamo_backend,
+        deepspeed_plugin=deepspeed_plugin,
     )
     print("accelerator device:", accelerator.device)
     return accelerator

+
+def prepare_deepspeed_plugin(args: argparse.Namespace):
+    if not args.deepspeed:
+        return None
+    try:
+        import deepspeed
+    except ImportError:
+        print(
+            "deepspeed is not installed. Please install it in your environment with the following command: "
+            "DS_BUILD_OPS=0 pip install deepspeed"
+        )
+        exit(1)
+
+    deepspeed_plugin = DeepSpeedPlugin(
+        zero_stage=args.zero_stage,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        gradient_clipping=args.max_grad_norm,
+        offload_optimizer_device=args.offload_optimizer_device,
+        offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
+        offload_param_device=args.offload_param_device,
+        offload_param_nvme_path=args.offload_param_nvme_path,
+        zero3_init_flag=args.zero3_init_flag,
+        zero3_save_16bit_model=args.zero3_save_16bit_model,
+    )
+    deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
+    deepspeed_plugin.deepspeed_config["train_batch_size"] = (
+        args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
+    )
+    deepspeed_plugin.set_mixed_precision(args.mixed_precision)
+    if args.mixed_precision.lower() == "fp16":
+        deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0  # start at loss scale 2**0 to prevent overflow
+    if args.full_fp16 or args.fp16_master_weights_and_gradients:
+        if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
+            deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
+            print("[DeepSpeed] full fp16 enabled.")
+        else:
+            print(
+                "[DeepSpeed] full fp16 / fp16_master_weights_and_grads is currently only supported "
+                "with ZeRO-Offload (DeepSpeedCPUAdam) on ZeRO stage 2."
+            )
+
+    if args.offload_optimizer_device is not None:
+        print("[DeepSpeed] start to manually build cpu_adam.")
+        deepspeed.ops.op_builder.CPUAdamBuilder().load()
+        print("[DeepSpeed] building cpu_adam done.")
+
+    return deepspeed_plugin
+
+
+def prepare_deepspeed_model(args: argparse.Namespace, **models):
+    class DeepSpeedWrapper(torch.nn.Module):
+        def __init__(self, **kw_models) -> None:
+            super().__init__()
+            self.models = torch.nn.ModuleDict()
+
+            for key, model in kw_models.items():
+                if isinstance(model, list):
+                    model = torch.nn.ModuleList(model)
+                assert isinstance(model, torch.nn.Module), f"model must be an instance of torch.nn.Module, but {key} is {type(model)}"
+                self.models.update({key: model})
+
+        def get_models(self):
+            return self.models
+
+    ds_model = DeepSpeedWrapper(**models)
+    return ds_model


 def prepare_dtype(args: argparse.Namespace):
     weight_dtype = torch.float32
@@ -4165,7 +4271,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une


 def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
-    # load models for each process
     for pi in range(accelerator.state.num_processes):
         if pi == accelerator.state.local_process_index:
             logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
@@ -4176,7 +4281,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
                 accelerator.device if args.lowram else "cpu",
                 unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
             )
-
             # work on low-ram device
             if args.lowram:
                 text_encoder.to(accelerator.device)
@@ -4185,7 +4289,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
             clean_memory_on_device(accelerator.device)

         accelerator.wait_for_everyone()
-
     return text_encoder, vae, unet, load_stable_diffusion_format
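Note: for reference, a sketch of the deepspeed_config entries that prepare_deepspeed_plugin sets (values are illustrative; the dict itself is created and owned by Accelerate's DeepSpeedPlugin). initial_scale_power=0 starts the fp16 loss scale at 2**0 = 1, avoiding the overflow spam a large default initial scale produces early in training:

    # illustrative shape of the mutated config; not the full DeepSpeed config
    deepspeed_config = {
        "zero_optimization": {"stage": 2},
        "train_micro_batch_size_per_gpu": 1,        # per-GPU micro batch
        "train_batch_size": 2,                      # micro batch * grad accum * world size
        "fp16": {
            "enabled": True,
            "initial_scale_power": 0,               # loss scale starts at 2**0 == 1
            "fp16_master_weights_and_grads": True,  # only with ZeRO-2 + CPU optimizer offload
        },
    }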
diff --git a/sdxl_train.py b/sdxl_train.py
index e0df263d6..5e5e9f291 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -361,7 +361,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         batch_size=1,
         shuffle=True,
         collate_fn=collator,
-        num_workers=n_workers,
+        num_workers=n_workers if not args.deepspeed else 1,  # to avoid "RuntimeError: DataLoader worker exited unexpectedly with exit code 1"
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -398,18 +398,41 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         text_encoder1.to(weight_dtype)
         text_encoder2.to(weight_dtype)

-    # acceleratorがなんかよろしくやってくれるらしい
-    if train_unet:
-        unet = accelerator.prepare(unet)
-    if train_text_encoder1:
-        # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
-        text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
-        text_encoder1.text_model.final_layer_norm.requires_grad_(False)
-        text_encoder1 = accelerator.prepare(text_encoder1)
-    if train_text_encoder2:
-        text_encoder2 = accelerator.prepare(text_encoder2)
-
-    optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+    if args.deepspeed:
+        training_models_dict = {}
+        if train_unet:
+            training_models_dict["unet"] = unet
+        if train_text_encoder1:
+            # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
+            text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
+            text_encoder1.text_model.final_layer_norm.requires_grad_(False)
+            training_models_dict["text_encoder1"] = text_encoder1
+        if train_text_encoder2:
+            training_models_dict["text_encoder2"] = text_encoder2
+
+        ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
+        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            ds_model, optimizer, train_dataloader, lr_scheduler
+        )
+
+        training_models = []  # override training_models
+        if train_unet:
+            unet = ds_model.models["unet"]
+            training_models.append(unet)
+        if train_text_encoder1:
+            text_encoder1 = ds_model.models["text_encoder1"]
+            training_models.append(text_encoder1)
+        if train_text_encoder2:
+            text_encoder2 = ds_model.models["text_encoder2"]
+            training_models.append(text_encoder2)
+
+    else:  # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
+        if train_unet:
+            unet = accelerator.prepare(unet)
+        if train_text_encoder1:
+            # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
+            text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
+            text_encoder1.text_model.final_layer_norm.requires_grad_(False)
+            text_encoder1 = accelerator.prepare(text_encoder1)
+        if train_text_encoder2:
+            text_encoder2 = accelerator.prepare(text_encoder2)
+        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
@@ -423,7 +446,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder2.to(accelerator.device)

     # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
-    if args.full_fp16:
+    if args.full_fp16 and not args.deepspeed:
+        # during DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision via the grad scaler; the DeepSpeed engine handles it
         train_util.patch_accelerator_for_fp16_training(accelerator)

     # resumeする
@@ -484,10 +508,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             for step, batch in enumerate(train_dataloader):
                 current_step.value = global_step
                 with accelerator.accumulate(*training_models):
-                    if "latents" in batch and batch["latents"] is not None:
-                        latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
-                    else:
-                        with torch.no_grad():
+                    with torch.no_grad():  # TODO: why does this block differ from train_network.py?
+                        if "latents" in batch and batch["latents"] is not None:
+                            latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
+                        else:
                             # latentに変換
                             latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
@@ -495,7 +519,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                             if torch.any(torch.isnan(latents)):
                                 accelerator.print("NaN found in latents, replacing with zeros")
                                 latents = torch.nan_to_num(latents, 0, out=latents)
-                    latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
+                        latents = latents * sdxl_model_util.VAE_SCALE_FACTOR

                     if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
                         input_ids1 = batch["input_ids"]
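Note: the loop change above moves both latent branches under a single torch.no_grad() and scales the latents inside it. A runnable toy of the cached path (VAE_SCALE_FACTOR stands in for sdxl_model_util.VAE_SCALE_FACTOR; the on-the-fly VAE encode is stubbed out):

    import torch

    VAE_SCALE_FACTOR = 0.13025  # stand-in for sdxl_model_util.VAE_SCALE_FACTOR
    weight_dtype = torch.float16
    batch = {"latents": torch.randn(1, 4, 8, 8)}  # pretend the latents were cached to disk

    with torch.no_grad():
        if batch.get("latents") is not None:
            latents = batch["latents"].to(dtype=weight_dtype)
        else:
            raise NotImplementedError("on-the-fly path: vae.encode(...).latent_dist.sample()")
        if torch.any(torch.isnan(latents)):
            latents = torch.nan_to_num(latents, 0.0)
        latents = latents * VAE_SCALE_FACTOR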
+ if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: # latentに変換 latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype) @@ -495,7 +519,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) - latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: input_ids1 = batch["input_ids"] diff --git a/train_db.py b/train_db.py index 8d36097a5..66a83d1df 100644 --- a/train_db.py +++ b/train_db.py @@ -187,7 +187,7 @@ def train(args): batch_size=1, shuffle=True, collate_fn=collator, - num_workers=n_workers, + num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1. persistent_workers=args.persistent_data_loader_workers, ) @@ -219,15 +219,31 @@ def train(args): text_encoder.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい - if train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler - ) + if args.deepspeed: + training_models_dict = {} + training_models_dict["unet"] = unet + if train_text_encoder: training_models_dict["text_encoder"] = text_encoder + + ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler) + + training_models = [] + unet = ds_model.models["unet"] + training_models.append(unet) + if train_text_encoder: + text_encoder = ds_model.models["text_encoder"] + training_models.append(text_encoder) + else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - if not train_text_encoder: - text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error + if not train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: diff --git a/train_network.py b/train_network.py index e5b26d8a2..af1b7f635 100644 --- a/train_network.py +++ b/train_network.py @@ -357,7 +357,7 @@ def train(self, args): batch_size=1, shuffle=True, collate_fn=collator, - num_workers=n_workers, + num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1. 
diff --git a/train_network.py b/train_network.py
index e5b26d8a2..af1b7f635 100644
--- a/train_network.py
+++ b/train_network.py
@@ -357,7 +357,7 @@ def train(self, args):
             batch_size=1,
             shuffle=True,
             collate_fn=collator,
-            num_workers=n_workers,
+            num_workers=n_workers if not args.deepspeed else 1,  # to avoid "RuntimeError: DataLoader worker exited unexpectedly with exit code 1"
             persistent_workers=args.persistent_data_loader_workers,
         )
@@ -413,20 +413,38 @@ def train(self, args):
                 t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))

         # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
-        if train_unet:
-            unet = accelerator.prepare(unet)
-        else:
-            unet.to(accelerator.device, dtype=unet_weight_dtype)  # move to device because unet is not prepared by accelerator
-        if train_text_encoder:
-            if len(text_encoders) > 1:
-                text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
-            else:
-                text_encoder = accelerator.prepare(text_encoder)
-                text_encoders = [text_encoder]
-        else:
-            pass  # if text_encoder is not trained, no need to prepare. and device and dtype are already set
-
-        network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
+        if args.deepspeed:
+            training_models_dict = {}
+            if train_unet:
+                training_models_dict["unet"] = unet
+            if train_text_encoder:
+                training_models_dict["text_encoder"] = text_encoders
+            training_models_dict["network"] = network
+
+            ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
+
+            if train_unet:
+                unet = ds_model.models["unet"]
+            if train_text_encoder:
+                text_encoder = ds_model.models["text_encoder"]
+                if len(ds_model.models["text_encoder"]) > 1:
+                    text_encoders = text_encoder
+                else:
+                    text_encoders = [text_encoder]
+
+        else:
+            if train_unet:
+                unet = accelerator.prepare(unet)
+            else:
+                unet.to(accelerator.device, dtype=unet_weight_dtype)  # move to device because unet is not prepared by accelerator
+            if train_text_encoder:
+                if len(text_encoders) > 1:
+                    text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
+                else:
+                    text_encoder = accelerator.prepare(text_encoder)
+                    text_encoders = [text_encoder]
+            else:
+                pass  # if text_encoder is not trained, no need to prepare. and device and dtype are already set
+
+            network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)

     if args.gradient_checkpointing:
         # according to TI example in Diffusers, train is required
@@ -453,7 +471,8 @@ def train(self, args):
             vae.to(accelerator.device, dtype=vae_dtype)

     # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
-    if args.full_fp16:
+    if args.full_fp16 and not args.deepspeed:
+        # during DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision via the grad scaler; the DeepSpeed engine handles it
         train_util.patch_accelerator_for_fp16_training(accelerator)

     # resumeする
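Note: in train_network.py the text encoders may arrive as a list (SDXL has two); prepare_deepspeed_model converts a list into an nn.ModuleList, and the unwrap above branches on its length. A self-contained sketch of that round-trip with toy modules standing in for the encoders:

    import torch

    # a list becomes a ModuleList inside the wrapper, as prepare_deepspeed_model does
    encoders = torch.nn.ModuleList([torch.nn.Linear(2, 2), torch.nn.Linear(2, 2)])
    wrapped = torch.nn.ModuleDict({"text_encoder": encoders})

    te = wrapped["text_encoder"]
    text_encoders = te if len(te) > 1 else [te]  # mirrors the unwrap branch above
    assert len(text_encoders) == 2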