
Commit d2a99a1

Merge pull request #1056 from kohya-ss/dev
fix vram usage in LoRA training
2 parents e6b15c7 + 0395a35

File tree

3 files changed: +18 -8 lines changed

README.md: 10 additions & 0 deletions

@@ -249,6 +249,16 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
 
 ## Change History
 
+### Jan 17, 2024 / 2024/1/17: v0.8.1
+
+- Fixed a bug where VRAM usage was larger than before when the Text Encoder is not trained, in the training scripts for LoRA etc. (`train_network.py`, `sdxl_train_network.py`).
+  - The Text Encoders were not moved to the CPU.
+- Fixed typos. Thanks to akx! [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053)
+
+- Fixed a bug in the training scripts for LoRA etc. (`train_network.py`, `sdxl_train_network.py`) where VRAM usage when the Text Encoder is not trained had grown compared to before.
+  - The Text Encoder was being kept on the GPU.
+- Fixed typos. Thanks to akx for [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053).
+
 ### Jan 15, 2024 / 2024/1/15: v0.8.0
 
 - Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade).
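The changelog entry above captures the essence of the fix: when the Text Encoder is not being trained and its outputs are cached, its weights have no reason to stay resident in VRAM. A minimal sketch of that pattern, assuming a standalone CLIP text model and placeholder prompts rather than the training scripts' actual objects:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Hypothetical stand-ins for the scripts' tokenizer/text encoder objects.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Compute (and cache) the text embeddings once while the encoder is on the GPU.
text_encoder.to(device)
prompts = ["a photo of a cat"]  # placeholder captions
with torch.no_grad():
    input_ids = tokenizer(
        prompts, padding="max_length", truncation=True, return_tensors="pt"
    ).input_ids.to(device)
    cached_embeddings = text_encoder(input_ids)[0].cpu()

# The frozen encoder is no longer needed on the GPU; moving it back to the CPU
# releases its weights from VRAM, which is the step the fixed bug was skipping.
text_encoder.to("cpu")
torch.cuda.empty_cache()
```

Calling `torch.cuda.empty_cache()` afterwards returns the freed blocks from PyTorch's caching allocator to the driver, so the reclaimed memory also becomes visible to tools such as `nvidia-smi`.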

sdxl_train_network.py: 2 additions & 2 deletions

@@ -95,8 +95,8 @@ def cache_text_encoder_outputs_if_needed(
             unet.to(org_unet_device)
         else:
             # outputs are fetched from the Text Encoders every step, so keep them on the GPU
-            text_encoders[0].to(accelerator.device)
-            text_encoders[1].to(accelerator.device)
+            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
+            text_encoders[1].to(accelerator.device, dtype=weight_dtype)
 
     def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
         if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:

train_network.py: 6 additions & 6 deletions

@@ -117,7 +117,7 @@ def cache_text_encoder_outputs_if_needed(
         self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
     ):
         for t_enc in text_encoders:
-            t_enc.to(accelerator.device)
+            t_enc.to(accelerator.device, dtype=weight_dtype)
 
     def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
         input_ids = batch["input_ids"].to(accelerator.device)

@@ -278,6 +278,7 @@ def train(self, args):
         accelerator.wait_for_everyone()
 
         # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
+        # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
         self.cache_text_encoder_outputs_if_needed(
             args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
         )

@@ -394,8 +395,7 @@ def train(self, args):
         for t_enc in text_encoders:
             t_enc.requires_grad_(False)
 
-        # acceleratorがなんかよろしくやってくれるらしい
-        # TODO めちゃくちゃ冗長なのでコードを整理する
+        # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
         if train_unet:
             unet = accelerator.prepare(unet)
         else:

@@ -407,8 +407,8 @@
                 text_encoder = accelerator.prepare(text_encoder)
                 text_encoders = [text_encoder]
         else:
-            for t_enc in text_encoders:
-                t_enc.to(accelerator.device, dtype=weight_dtype)
+            pass  # if text_encoder is not trained, no need to prepare. and device and dtype are already set
+
         network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
 
         if args.gradient_checkpointing:

@@ -685,7 +685,7 @@ def train(self, args):
         if accelerator.is_main_process:
             init_kwargs = {}
             if args.wandb_run_name:
-                init_kwargs['wandb'] = {'name': args.wandb_run_name}
+                init_kwargs["wandb"] = {"name": args.wandb_run_name}
             if args.log_tracker_config is not None:
                 init_kwargs = toml.load(args.log_tracker_config)
             accelerator.init_trackers(
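The remaining hunks apply the same reasoning at the `accelerate` level: only modules that are actually optimized go through `accelerator.prepare()`, while frozen modules have their device and dtype set once and are otherwise left alone (hence the `pass` branch replacing the unconditional `.to(accelerator.device, ...)` that had been pulling the Text Encoders back onto the GPU). A hedged sketch of that split with toy modules; the names `trainable_net` and `frozen_encoder` are illustrative, not taken from the scripts:

```python
import torch
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
weight_dtype = torch.float16  # plays the role of the scripts' weight_dtype

# Toy modules: one is trained (like the LoRA network), one stays frozen.
trainable_net = nn.Linear(16, 16)
frozen_encoder = nn.Linear(16, 16).requires_grad_(False)

# Frozen module: no prepare() needed; place it on the right device/dtype once.
frozen_encoder.to(accelerator.device, dtype=weight_dtype)

optimizer = torch.optim.AdamW(trainable_net.parameters(), lr=1e-4)

# Only the trained module and its optimizer are wrapped by accelerate.
trainable_net, optimizer = accelerator.prepare(trainable_net, optimizer)
```

The quote-style change in the last hunk is purely cosmetic; the dict is still passed to `accelerator.init_trackers()` via its `init_kwargs` argument so the wandb run gets the requested name.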
