diff --git a/flux_train_network.py b/flux_train_network.py
index cfc617088..1e38cefc1 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -232,21 +232,21 @@ def cache_text_encoder_outputs_if_needed(
                 logger.info("move vae and unet to cpu to save memory")
                 org_vae_device = vae.device
                 org_unet_device = unet.device
-                vae.to("cpu")
-                unet.to("cpu")
+                vae = vae.to("cpu")
+                unet = unet.to("cpu")
                 clean_memory_on_device(accelerator.device)
 
             # When TE is not be trained, it will not be prepared so we need to use explicit autocast
             logger.info("move text encoders to gpu")
-            text_encoders[0].to(accelerator.device, dtype=weight_dtype)  # always not fp8
-            text_encoders[1].to(accelerator.device)
+            text_encoders[0] = text_encoders[0].to(accelerator.device, dtype=weight_dtype, non_blocking=True)  # always not fp8
+            text_encoders[1] = text_encoders[1].to(accelerator.device, non_blocking=True)
 
             if text_encoders[1].dtype == torch.float8_e4m3fn:
                 # if we load fp8 weights, the model is already fp8, so we use it as is
                 self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
             else:
                 # otherwise, we need to convert it to target dtype
-                text_encoders[1].to(weight_dtype)
+                text_encoders[1] = text_encoders[1].to(weight_dtype, non_blocking=True)
 
             with accelerator.autocast():
                 dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
@@ -276,19 +276,19 @@ def cache_text_encoder_outputs_if_needed(
             # move back to cpu
             if not self.is_train_text_encoder(args):
                 logger.info("move CLIP-L back to cpu")
-                text_encoders[0].to("cpu")
+                text_encoders[0] = text_encoders[0].to("cpu", non_blocking=True)
                 logger.info("move t5XXL back to cpu")
-                text_encoders[1].to("cpu")
+                text_encoders[1] = text_encoders[1].to("cpu", non_blocking=True)
             clean_memory_on_device(accelerator.device)
 
             if not args.lowram:
                 logger.info("move vae and unet back to original device")
-                vae.to(org_vae_device)
-                unet.to(org_unet_device)
+                vae = vae.to(org_vae_device, non_blocking=True)
+                unet = unet.to(org_unet_device, non_blocking=True)
         else:
             # Text Encoderから毎回出力を取得するので、GPUに乗せておく
-            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
-            text_encoders[1].to(accelerator.device)
+            text_encoders[0] = text_encoders[0].to(accelerator.device, dtype=weight_dtype, non_blocking=True)
+            text_encoders[1] = text_encoders[1].to(accelerator.device, non_blocking=True)
 
     def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
         text_encoders = text_encoder  # for compatibility
@@ -429,7 +429,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
                     noisy_model_input[diff_output_pr_indices],
                     sigmas[diff_output_pr_indices] if sigmas is not None else None,
                 )
-                target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
+                target[diff_output_pr_indices] = model_pred_prior.to(target.dtype, non_blocking=True)
 
         return model_pred, target, timesteps, weighting
 
@@ -468,8 +468,8 @@ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
     def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
         if index == 0:  # CLIP-L
             logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
-            text_encoder.to(te_weight_dtype)  # fp8
-            text_encoder.text_model.embeddings.to(dtype=weight_dtype)
+            text_encoder = text_encoder.to(te_weight_dtype, non_blocking=True)  # fp8
+            text_encoder.text_model.embeddings = text_encoder.text_model.embeddings.to(dtype=weight_dtype)
         else:  # T5XXL
 
             def prepare_fp8(text_encoder, target_dtype):
@@ -488,7 +488,7 @@ def forward(hidden_states):
                 for module in text_encoder.modules():
                     if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
                         # print("set", module.__class__.__name__, "to", target_dtype)
-                        module.to(target_dtype)
+                        module = module.to(target_dtype, non_blocking=True)
                     if module.__class__.__name__ in ["T5DenseGatedActDense"]:
                         # print("set", module.__class__.__name__, "hooks")
                         module.forward = forward_hook(module)
@@ -497,7 +497,7 @@ def forward(hidden_states):
                 logger.info(f"T5XXL already prepared for fp8")
             else:
                 logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
-                text_encoder.to(te_weight_dtype)  # fp8
+                text_encoder = text_encoder.to(te_weight_dtype, non_blocking=True)  # fp8
                 prepare_fp8(text_encoder, weight_dtype)
 
     def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py
index 0681dcdcb..48faca277 100644
--- a/library/custom_offloading_utils.py
+++ b/library/custom_offloading_utils.py
@@ -53,7 +53,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
             # print(
             #     f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
             # )
-            module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
+            module_to_cuda.weight.data = module_to_cuda.weight.data.to(device, non_blocking=True)
 
     torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
 
diff --git a/library/flux_models.py b/library/flux_models.py
index d2d7e06c7..84b5aa358 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -307,7 +307,7 @@ def forward(self, z: Tensor) -> Tensor:
         mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
         if self.sample:
             std = torch.exp(0.5 * logvar)
-            return mean + std * torch.randn_like(mean)
+            return mean + std * torch.randn_like(mean, pin_memory=True)
         else:
             return mean
 
@@ -532,7 +532,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     """
     t = time_factor * t
     half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, pin_memory=True) / half).to(t.device, non_blocking=True)
 
     args = t[:, None].float() * freqs[None]
     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -600,7 +600,7 @@ def __init__(self, dim: int):
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
         q = self.query_norm(q)
         k = self.key_norm(k)
-        return q.to(v), k.to(v)
+        return q.to(v, non_blocking=True), k.to(v, non_blocking=True)
 
 
 class SelfAttention(nn.Module):
@@ -997,7 +997,7 @@ def move_to_device_except_swap_blocks(self, device: torch.device):
             self.double_blocks = None
             self.single_blocks = None
 
-        self.to(device)
+        self = self.to(device, non_blocking=True)
 
         if self.blocks_to_swap:
             self.double_blocks = save_double_blocks
@@ -1081,8 +1081,8 @@ def forward(
         img = img[:, txt.shape[1] :, ...]
 
         if self.training and self.cpu_offload_checkpointing:
-            img = img.to(self.device)
-            vec = vec.to(self.device)
+            img = img.to(self.device, non_blocking=True)
+            vec = vec.to(self.device, non_blocking=True)
 
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img
@@ -1243,7 +1243,7 @@ def move_to_device_except_swap_blocks(self, device: torch.device):
             self.double_blocks = nn.ModuleList()
             self.single_blocks = nn.ModuleList()
 
-        self.to(device)
+        self = self.to(device, non_blocking=True)
 
         if self.blocks_to_swap:
             self.double_blocks = save_double_blocks
diff --git a/library/strategy_sd.py b/library/strategy_sd.py
index a44fc4092..d0a3a68bf 100644
--- a/library/strategy_sd.py
+++ b/library/strategy_sd.py
@@ -40,7 +40,7 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
         text = [text] if isinstance(text, str) else text
         return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
 
-    def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
         text = [text] if isinstance(text, str) else text
         tokens_list = []
         weights_list = []
diff --git a/library/train_util.py b/library/train_util.py
index 756d88b1c..719611f61 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4,6 +4,7 @@
 import ast
 import asyncio
 from concurrent.futures import Future, ThreadPoolExecutor
+from contextlib import nullcontext
 import datetime
 import importlib
 import json
@@ -26,6 +27,7 @@
 
 # from concurrent.futures import ThreadPoolExecutor, as_completed
 
+from torch.cuda import Stream
 from tqdm import tqdm
 from packaging.version import Version
 
@@ -1415,10 +1417,11 @@ def cache_text_encoder_outputs_common(
         return
 
     # prepare tokenizers and text encoders
-    for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes):
-        text_encoder.to(device)
+    for i, (text_encoder, device, te_dtype) in enumerate(zip(text_encoders, devices, te_dtypes)):
+        te_kwargs = {}
         if te_dtype is not None:
-            text_encoder.to(dtype=te_dtype)
+            te_kwargs['dtype'] = te_dtype
+        text_encoders[i] = text_encoder.to(device, non_blocking=True, **te_kwargs)
 
     # create batch
     is_sd3 = len(tokenizers) == 1
@@ -1440,6 +1443,8 @@
     if len(batch) > 0:
         batches.append(batch)
 
+    torch.cuda.synchronize()
+
     # iterate batches: call text encoder and cache outputs for memory or disk
     logger.info("caching text encoder outputs...")
     if not is_sd3:
@@ -3120,7 +3125,10 @@
         images.append(image)
 
     img_tensors = torch.stack(images, dim=0)
-    img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
+
+    s = Stream()
+
+    img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype, non_blocking=True)
 
     with torch.no_grad():
         latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
@@ -3156,12 +3164,13 @@
     if not HIGH_VRAM:
         clean_memory_on_device(vae.device)
+    torch.cuda.synchronize()
 
 
 def cache_batch_text_encoder_outputs(
     image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
 ):
-    input_ids1 = input_ids1.to(text_encoders[0].device)
-    input_ids2 = input_ids2.to(text_encoders[1].device)
+    input_ids1 = input_ids1.to(text_encoders[0].device, non_blocking=True)
+    input_ids2 = input_ids2.to(text_encoders[1].device, non_blocking=True)
 
     with torch.no_grad():
         b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
@@ -5619,9 +5628,9 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
     )
     # work on low-ram device
     if args.lowram:
-        text_encoder.to(accelerator.device)
-        unet.to(accelerator.device)
-        vae.to(accelerator.device)
+        text_encoder = text_encoder.to(accelerator.device, non_blocking=True)
+        unet = unet.to(accelerator.device, non_blocking=True)
+        vae = vae.to(accelerator.device, non_blocking=True)
         clean_memory_on_device(accelerator.device)
 
     accelerator.wait_for_everyone()
@@ -6435,7 +6444,7 @@ def sample_images_common(
     distributed_state = PartialState()  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
 
     org_vae_device = vae.device  # CPUにいるはず
-    vae.to(distributed_state.device)  # distributed_state.device is same as accelerator.device
+    vae = vae.to(distributed_state.device)  # distributed_state.device is same as accelerator.device
 
     # unwrap unet and text_encoder(s)
     unet = accelerator.unwrap_model(unet_wrapped)
@@ -6470,7 +6479,7 @@
         requires_safety_checker=False,
         clip_skip=args.clip_skip,
     )
-    pipeline.to(distributed_state.device)
+    pipeline = pipeline.to(distributed_state.device)
 
     save_dir = args.output_dir + "/sample"
     os.makedirs(save_dir, exist_ok=True)
@@ -6521,7 +6530,7 @@
         torch.set_rng_state(rng_state)
         if torch.cuda.is_available() and cuda_rng_state is not None:
             torch.cuda.set_rng_state(cuda_rng_state)
-    vae.to(org_vae_device)
+    vae = vae.to(org_vae_device)
 
     clean_memory_on_device(accelerator.device)
 
diff --git a/library/utils.py b/library/utils.py
index 296fc4151..7ae03f812 100644
--- a/library/utils.py
+++ b/library/utils.py
@@ -110,7 +110,7 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
         # cuda to cpu
         for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
             cuda_data_view.record_stream(stream)
-            module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
+            module_to_cpu.weight.data = cuda_data_view.data.to("cpu")
 
     stream.synchronize()
 
diff --git a/networks/oft.py b/networks/oft.py
index 0c3a5393f..cbadf9b70 100644
--- a/networks/oft.py
+++ b/networks/oft.py
@@ -49,11 +49,11 @@ def __init__(
 
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().numpy()
-
+
         # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility
         # original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha
-        self.constraint = alpha * out_dim
-
+        self.constraint = alpha * out_dim
+
         self.register_buffer("alpha", torch.tensor(alpha))
 
         self.block_size = out_dim // self.num_blocks
diff --git a/train_db.py b/train_db.py
index 4bf3b31ce..689d6c970 100644
--- a/train_db.py
+++ b/train_db.py
@@ -239,8 +239,8 @@ def train(args):
             args.mixed_precision == "fp16"
         ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
         accelerator.print("enable full fp16 training.")
-        unet.to(weight_dtype)
-        text_encoder.to(weight_dtype)
+        unet = unet.to(weight_dtype)
+        text_encoder = text_encoder.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい
     if args.deepspeed:
@@ -335,6 +335,7 @@ def train(args):
             text_encoder.train()
 
         for step, batch in enumerate(train_dataloader):
+            optimizer.train()
             current_step.value = global_step
             # 指定したステップ数でText Encoderの学習を止める
             if global_step == args.stop_text_encoder_training:
diff --git a/train_network.py b/train_network.py
index 6cebf5fc7..4c6aa1ecd 100644
--- a/train_network.py
+++ b/train_network.py
@@ -222,8 +222,8 @@ def is_train_text_encoder(self, args):
         return not args.network_train_unet_only
 
     def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype):
-        for t_enc in text_encoders:
-            t_enc.to(accelerator.device, dtype=weight_dtype)
+        for i, t_enc in enumerate(text_encoders):
+            text_encoders[i] = t_enc.to(accelerator.device, dtype=weight_dtype)
 
     def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs):
         noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
@@ -323,7 +323,7 @@ def get_noise_pred_and_target(
                     indices=diff_output_pr_indices,
                 )
                 network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step
-                target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
+                target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype, non_blocking=True)
 
         return noise_pred, target, timesteps, None
 
@@ -352,7 +352,7 @@ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
         text_encoder.text_model.embeddings.requires_grad_(True)
 
     def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
-        text_encoder.text_model.embeddings.to(dtype=weight_dtype)
+        text_encoder.text_model.embeddings = text_encoder.text_model.embeddings.to(dtype=weight_dtype)
 
     def prepare_unet_with_accelerator(
         self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
@@ -390,11 +390,11 @@ def process_batch(
         """
         with torch.no_grad():
            if "latents" in batch and batch["latents"] is not None:
-                latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
+                latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device, non_blocking=True))
             else:
                 # latentに変換
                 if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
-                    latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
+                    latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype, non_blocking=True))
                 else:
                     chunks = [
                         batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
@@ -402,7 +402,7 @@ def process_batch(
                     list_latents = []
                     for chunk in chunks:
                         with torch.no_grad():
-                            chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
+                            chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype, non_blocking=True))
                         list_latents.append(chunk)
                     latents = torch.cat(list_latents, dim=0)
@@ -431,14 +431,14 @@ def process_batch(
                         weights_list,
                     )
                 else:
-                    input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
+                    input_ids = [ids.to(accelerator.device, non_blocking=True) for ids in batch["input_ids_list"]]
                     encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
                         tokenize_strategy,
                         self.get_models_for_text_encoding(args, accelerator, text_encoders),
                         input_ids,
                     )
                 if args.full_fp16:
-                    encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
+                    encoded_text_encoder_conds = [c.to(weight_dtype, non_blocking=True) for c in encoded_text_encoder_conds]
 
             # if text_encoder_conds is not cached, use encoded_text_encoder_conds
             if len(text_encoder_conds) == 0:
@@ -449,6 +449,8 @@ def process_batch(
                     if encoded_text_encoder_conds[i] is not None:
                         text_encoder_conds[i] = encoded_text_encoder_conds[i]
 
+        torch.cuda.synchronize()
+
         # sample noise, call unet, get target
         noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
             args,
@@ -816,13 +818,13 @@ def train(self, args):
                 args.mixed_precision == "fp16"
             ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
             accelerator.print("enable full fp16 training.")
-            network.to(weight_dtype)
+            network = network.to(weight_dtype)
         elif args.full_bf16:
             assert (
                 args.mixed_precision == "bf16"
             ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
             accelerator.print("enable full bf16 training.")
-            network.to(weight_dtype)
+            network = network.to(weight_dtype)
 
         unet_weight_dtype = te_weight_dtype = weight_dtype
         # Experimental Feature: Put base model into fp8 to save vram
@@ -844,7 +846,7 @@ def train(self, args):
            # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
             # unet.to(accelerator.device, dtype=unet_weight_dtype)  # this seems to be safer than above
             logger.info(f"set U-Net weight dtype to {unet_weight_dtype}")
-            unet.to(dtype=unet_weight_dtype)  # do not move to device because unet is not prepared by accelerator
+            unet = unet.to(dtype=unet_weight_dtype)  # do not move to device because unet is not prepared by accelerator
 
         unet.requires_grad_(False)
         if self.cast_unet(args):
@@ -858,7 +860,7 @@ def train(self, args):
 
             # nn.Embedding not support FP8
             if te_weight_dtype != weight_dtype:
-                self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)
+                self.prepare_text_encoder_fp8(i, text_encoders[i], te_weight_dtype, weight_dtype)
 
         # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
         if args.deepspeed:
@@ -920,7 +922,7 @@ def train(self, args):
         if not cache_latents:  # キャッシュしない場合はVAEを使うのでVAEを準備する
             vae.requires_grad_(False)
             vae.eval()
-            vae.to(accelerator.device, dtype=vae_dtype)
+            vae = vae.to(accelerator.device, dtype=vae_dtype)
 
         # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
         if args.full_fp16:
@@ -1398,6 +1400,8 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen
                 torch.cuda.set_rng_state(gpu_rng_state)
             random.setstate(python_rng_state)
 
+        torch.cuda.empty_cache()
+
         for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
             current_epoch.value = epoch + 1
@@ -1454,6 +1458,12 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen
                     if hasattr(network, "update_norms"):
                         network.update_norms()
 
+                    torch.cuda.synchronize()  # Ensure GPU ops complete before next batch
+
+                    # Periodic cleanup
+                    if step % 50 == 0:
+                        torch.cuda.empty_cache()
+
                     optimizer.step()
                     lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)