diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index b56d921a3..f68b4ae93 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -6,7 +6,7 @@ import cv2 import numpy as np import torch -from huggingface_hub import hf_hub_download +from huggingface_hub import hf_hub_download, snapshot_download from PIL import Image from tqdm import tqdm @@ -84,20 +84,7 @@ def main(args): # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 if not os.path.exists(args.model_dir) or args.force_download: logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") - files = FILES - if args.onnx: - files += FILES_ONNX - for file in files: - hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) - for file in SUB_DIR_FILES: - hf_hub_download( - args.repo_id, - file, - subfolder=SUB_DIR, - cache_dir=os.path.join(args.model_dir, SUB_DIR), - force_download=True, - force_filename=file, - ) + snapshot_download(args.repo_id, cache_dir=args.model_dir, force_download=True) else: logger.info("using existing wd14 tagger model")