@@ -44,6 +44,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
4444 weight_dtype ,
4545 accelerator .device if args .lowram else "cpu" ,
4646 model_dtype ,
47+ args .disable_mmap_load_safetensors
4748 )
4849
4950 # work on low-ram device
@@ -60,7 +61,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
6061
6162
6263def _load_target_model (
63- name_or_path : str , vae_path : Optional [str ], model_version : str , weight_dtype , device = "cpu" , model_dtype = None
64+ name_or_path : str , vae_path : Optional [str ], model_version : str , weight_dtype , device = "cpu" , model_dtype = None , disable_mmap = False
6465):
6566 # model_dtype only work with full fp16/bf16
6667 name_or_path = os .readlink (name_or_path ) if os .path .islink (name_or_path ) else name_or_path
@@ -75,7 +76,7 @@ def _load_target_model(
7576 unet ,
7677 logit_scale ,
7778 ckpt_info ,
78- ) = sdxl_model_util .load_models_from_sdxl_checkpoint (model_version , name_or_path , device , model_dtype )
79+ ) = sdxl_model_util .load_models_from_sdxl_checkpoint (model_version , name_or_path , device , model_dtype , disable_mmap )
7980 else :
8081 # Diffusers model is loaded to CPU
8182 from diffusers import StableDiffusionXLPipeline
@@ -332,6 +333,10 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
332333 action = "store_true" ,
333334 help = "cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする" ,
334335 )
336+ parser .add_argument (
337+ "--disable_mmap_load_safetensors" ,
338+ action = "store_true" ,
339+ )
335340
336341
337342def verify_sdxl_training_args (args : argparse .Namespace , supportTextEncoderCaching : bool = True ):
0 commit comments