Skip to content

Commit 5f6bf29

Browse files
shirayu and kohya-ss authored
Replace print with logger if they are logs (#905)
* Add get_my_logger() * Use logger instead of print * Fix log level * Removed line-breaks for readability * Use setup_logging() * Add rich to requirements.txt * Make simple * Use logger instead of print --------- Co-authored-by: Kohya S <[email protected]>
1 parent 7f948db commit 5f6bf29

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

62 files changed

+1195
-961
lines changed

fine_tune.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
from accelerate.utils import set_seed
1919
from diffusers import DDPMScheduler
2020

21+
from library.utils import setup_logging
22+
setup_logging()
23+
import logging
24+
logger = logging.getLogger(__name__)
2125
import library.train_util as train_util
2226
import library.config_util as config_util
2327
from library.config_util import (
@@ -49,11 +53,11 @@ def train(args):
4953
if args.dataset_class is None:
5054
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
5155
if args.dataset_config is not None:
52-
print(f"Load dataset config from {args.dataset_config}")
56+
logger.info(f"Load dataset config from {args.dataset_config}")
5357
user_config = config_util.load_user_config(args.dataset_config)
5458
ignored = ["train_data_dir", "in_json"]
5559
if any(getattr(args, attr) is not None for attr in ignored):
56-
print(
60+
logger.warning(
5761
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
5862
", ".join(ignored)
5963
)
@@ -86,7 +90,7 @@ def train(args):
8690
train_util.debug_dataset(train_dataset_group)
8791
return
8892
if len(train_dataset_group) == 0:
89-
print(
93+
logger.error(
9094
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
9195
)
9296
return
@@ -97,7 +101,7 @@ def train(args):
97101
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
98102

99103
# acceleratorを準備する
100-
print("prepare accelerator")
104+
logger.info("prepare accelerator")
101105
accelerator = train_util.prepare_accelerator(args)
102106

103107
# mixed precisionに対応した型を用意しておき適宜castする
@@ -461,7 +465,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
461465
train_util.save_sd_model_on_train_end(
462466
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
463467
)
464-
print("model saved.")
468+
logger.info("model saved.")
465469

466470

467471
def setup_parser() -> argparse.ArgumentParser:

finetune/blip/blip.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
import os
2222
from urllib.parse import urlparse
2323
from timm.models.hub import download_cached_file
24+
from library.utils import setup_logging
25+
setup_logging()
26+
import logging
27+
logger = logging.getLogger(__name__)
2428

2529
class BLIP_Base(nn.Module):
2630
def __init__(self,
@@ -235,6 +239,6 @@ def load_checkpoint(model,url_or_filename):
235239
del state_dict[key]
236240

237241
msg = model.load_state_dict(state_dict,strict=False)
238-
print('load checkpoint from %s'%url_or_filename)
242+
logger.info('load checkpoint from %s'%url_or_filename)
239243
return model,msg
240244

finetune/clean_captions_and_tags.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
import re
99

1010
from tqdm import tqdm
11+
from library.utils import setup_logging
12+
setup_logging()
13+
import logging
14+
logger = logging.getLogger(__name__)
1115

1216
PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
1317
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
@@ -36,13 +40,13 @@ def clean_tags(image_key, tags):
3640
tokens = tags.split(", rating")
3741
if len(tokens) == 1:
3842
# WD14 taggerのときはこちらになるのでメッセージは出さない
39-
# print("no rating:")
40-
# print(f"{image_key} {tags}")
43+
# logger.info("no rating:")
44+
# logger.info(f"{image_key} {tags}")
4145
pass
4246
else:
4347
if len(tokens) > 2:
44-
print("multiple ratings:")
45-
print(f"{image_key} {tags}")
48+
logger.info("multiple ratings:")
49+
logger.info(f"{image_key} {tags}")
4650
tags = tokens[0]
4751

4852
tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
@@ -124,43 +128,43 @@ def clean_caption(caption):
124128

125129
def main(args):
126130
if os.path.exists(args.in_json):
127-
print(f"loading existing metadata: {args.in_json}")
131+
logger.info(f"loading existing metadata: {args.in_json}")
128132
with open(args.in_json, "rt", encoding='utf-8') as f:
129133
metadata = json.load(f)
130134
else:
131-
print("no metadata / メタデータファイルがありません")
135+
logger.error("no metadata / メタデータファイルがありません")
132136
return
133137

134-
print("cleaning captions and tags.")
138+
logger.info("cleaning captions and tags.")
135139
image_keys = list(metadata.keys())
136140
for image_key in tqdm(image_keys):
137141
tags = metadata[image_key].get('tags')
138142
if tags is None:
139-
print(f"image does not have tags / メタデータにタグがありません: {image_key}")
143+
logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}")
140144
else:
141145
org = tags
142146
tags = clean_tags(image_key, tags)
143147
metadata[image_key]['tags'] = tags
144148
if args.debug and org != tags:
145-
print("FROM: " + org)
146-
print("TO: " + tags)
149+
logger.info("FROM: " + org)
150+
logger.info("TO: " + tags)
147151

148152
caption = metadata[image_key].get('caption')
149153
if caption is None:
150-
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
154+
logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
151155
else:
152156
org = caption
153157
caption = clean_caption(caption)
154158
metadata[image_key]['caption'] = caption
155159
if args.debug and org != caption:
156-
print("FROM: " + org)
157-
print("TO: " + caption)
160+
logger.info("FROM: " + org)
161+
logger.info("TO: " + caption)
158162

159163
# metadataを書き出して終わり
160-
print(f"writing metadata: {args.out_json}")
164+
logger.info(f"writing metadata: {args.out_json}")
161165
with open(args.out_json, "wt", encoding='utf-8') as f:
162166
json.dump(metadata, f, indent=2)
163-
print("done!")
167+
logger.info("done!")
164168

165169

166170
def setup_parser() -> argparse.ArgumentParser:
@@ -178,10 +182,10 @@ def setup_parser() -> argparse.ArgumentParser:
178182

179183
args, unknown = parser.parse_known_args()
180184
if len(unknown) == 1:
181-
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
182-
print("All captions and tags in the metadata are processed.")
183-
print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
184-
print("メタデータ内のすべてのキャプションとタグが処理されます。")
185+
logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
186+
logger.warning("All captions and tags in the metadata are processed.")
187+
logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
188+
logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。")
185189
args.in_json = args.out_json
186190
args.out_json = unknown[0]
187191
elif len(unknown) > 0:

finetune/make_captions.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
sys.path.append(os.path.dirname(__file__))
1616
from blip.blip import blip_decoder, is_url
1717
import library.train_util as train_util
18+
from library.utils import setup_logging
19+
setup_logging()
20+
import logging
21+
logger = logging.getLogger(__name__)
1822

1923
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2024

@@ -47,7 +51,7 @@ def __getitem__(self, idx):
4751
# convert to tensor temporarily so dataloader will accept it
4852
tensor = IMAGE_TRANSFORM(image)
4953
except Exception as e:
50-
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
54+
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
5155
return None
5256

5357
return (tensor, img_path)
@@ -74,21 +78,21 @@ def main(args):
7478
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
7579

7680
cwd = os.getcwd()
77-
print("Current Working Directory is: ", cwd)
81+
logger.info(f"Current Working Directory is: {cwd}")
7882
os.chdir("finetune")
7983
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
8084
args.caption_weights = os.path.join("..", args.caption_weights)
8185

82-
print(f"load images from {args.train_data_dir}")
86+
logger.info(f"load images from {args.train_data_dir}")
8387
train_data_dir_path = Path(args.train_data_dir)
8488
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
85-
print(f"found {len(image_paths)} images.")
89+
logger.info(f"found {len(image_paths)} images.")
8690

87-
print(f"loading BLIP caption: {args.caption_weights}")
91+
logger.info(f"loading BLIP caption: {args.caption_weights}")
8892
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
8993
model.eval()
9094
model = model.to(DEVICE)
91-
print("BLIP loaded")
95+
logger.info("BLIP loaded")
9296

9397
# captioningする
9498
def run_batch(path_imgs):
@@ -108,7 +112,7 @@ def run_batch(path_imgs):
108112
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
109113
f.write(caption + "\n")
110114
if args.debug:
111-
print(image_path, caption)
115+
logger.info(f'{image_path} {caption}')
112116

113117
# 読み込みの高速化のためにDataLoaderを使うオプション
114118
if args.max_data_loader_n_workers is not None:
@@ -138,7 +142,7 @@ def run_batch(path_imgs):
138142
raw_image = raw_image.convert("RGB")
139143
img_tensor = IMAGE_TRANSFORM(raw_image)
140144
except Exception as e:
141-
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
145+
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
142146
continue
143147

144148
b_imgs.append((image_path, img_tensor))
@@ -148,7 +152,7 @@ def run_batch(path_imgs):
148152
if len(b_imgs) > 0:
149153
run_batch(b_imgs)
150154

151-
print("done!")
155+
logger.info("done!")
152156

153157

154158
def setup_parser() -> argparse.ArgumentParser:

finetune/make_captions_by_git.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
from transformers.generation.utils import GenerationMixin
1111

1212
import library.train_util as train_util
13-
13+
from library.utils import setup_logging
14+
setup_logging()
15+
import logging
16+
logger = logging.getLogger(__name__)
1417

1518
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1619

@@ -35,8 +38,8 @@ def remove_words(captions, debug):
3538
for pat in PATTERN_REPLACE:
3639
cap = pat.sub("", cap)
3740
if debug and cap != caption:
38-
print(caption)
39-
print(cap)
41+
logger.info(caption)
42+
logger.info(cap)
4043
removed_caps.append(cap)
4144
return removed_caps
4245

@@ -70,16 +73,16 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs)
7073
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
7174
"""
7275

73-
print(f"load images from {args.train_data_dir}")
76+
logger.info(f"load images from {args.train_data_dir}")
7477
train_data_dir_path = Path(args.train_data_dir)
7578
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
76-
print(f"found {len(image_paths)} images.")
79+
logger.info(f"found {len(image_paths)} images.")
7780

7881
# できればcacheに依存せず明示的にダウンロードしたい
79-
print(f"loading GIT: {args.model_id}")
82+
logger.info(f"loading GIT: {args.model_id}")
8083
git_processor = AutoProcessor.from_pretrained(args.model_id)
8184
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
82-
print("GIT loaded")
85+
logger.info("GIT loaded")
8386

8487
# captioningする
8588
def run_batch(path_imgs):
@@ -97,7 +100,7 @@ def run_batch(path_imgs):
97100
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
98101
f.write(caption + "\n")
99102
if args.debug:
100-
print(image_path, caption)
103+
logger.info(f"{image_path} {caption}")
101104

102105
# 読み込みの高速化のためにDataLoaderを使うオプション
103106
if args.max_data_loader_n_workers is not None:
@@ -126,7 +129,7 @@ def run_batch(path_imgs):
126129
if image.mode != "RGB":
127130
image = image.convert("RGB")
128131
except Exception as e:
129-
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
132+
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
130133
continue
131134

132135
b_imgs.append((image_path, image))
@@ -137,7 +140,7 @@ def run_batch(path_imgs):
137140
if len(b_imgs) > 0:
138141
run_batch(b_imgs)
139142

140-
print("done!")
143+
logger.info("done!")
141144

142145

143146
def setup_parser() -> argparse.ArgumentParser:

finetune/merge_captions_to_metadata.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,30 @@
55
from tqdm import tqdm
66
import library.train_util as train_util
77
import os
8+
from library.utils import setup_logging
9+
setup_logging()
10+
import logging
11+
logger = logging.getLogger(__name__)
812

913
def main(args):
1014
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
1115

1216
train_data_dir_path = Path(args.train_data_dir)
1317
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
14-
print(f"found {len(image_paths)} images.")
18+
logger.info(f"found {len(image_paths)} images.")
1519

1620
if args.in_json is None and Path(args.out_json).is_file():
1721
args.in_json = args.out_json
1822

1923
if args.in_json is not None:
20-
print(f"loading existing metadata: {args.in_json}")
24+
logger.info(f"loading existing metadata: {args.in_json}")
2125
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
22-
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
26+
logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
2327
else:
24-
print("new metadata will be created / 新しいメタデータファイルが作成されます")
28+
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
2529
metadata = {}
2630

27-
print("merge caption texts to metadata json.")
31+
logger.info("merge caption texts to metadata json.")
2832
for image_path in tqdm(image_paths):
2933
caption_path = image_path.with_suffix(args.caption_extension)
3034
caption = caption_path.read_text(encoding='utf-8').strip()
@@ -38,12 +42,12 @@ def main(args):
3842

3943
metadata[image_key]['caption'] = caption
4044
if args.debug:
41-
print(image_key, caption)
45+
logger.info(f"{image_key} {caption}")
4246

4347
# metadataを書き出して終わり
44-
print(f"writing metadata: {args.out_json}")
48+
logger.info(f"writing metadata: {args.out_json}")
4549
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
46-
print("done!")
50+
logger.info("done!")
4751

4852

4953
def setup_parser() -> argparse.ArgumentParser:

0 commit comments

Comments (0)