25 changes: 17 additions & 8 deletions examples/controlnet/train_controlnet.py
@@ -86,6 +86,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
         controlnet=controlnet,
         safety_checker=None,
         revision=args.revision,
+        variant=args.variant,
         torch_dtype=weight_dtype,
     )
     pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -249,10 +250,13 @@ def parse_args(input_args=None):
         type=str,
         default=None,
         required=False,
-        help=(
-            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
-            " float32 precision."
-        ),
+        help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+    )
     parser.add_argument(
         "--tokenizer_name",
@@ -752,12 +756,15 @@ def main(args):
 
     # Load the tokenizer
     if args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.tokenizer_name, revision=args.revision, variant=args.variant, use_fast=False
+        )
     elif args.pretrained_model_name_or_path:
         tokenizer = AutoTokenizer.from_pretrained(
             args.pretrained_model_name_or_path,
             subfolder="tokenizer",
             revision=args.revision,
+            variant=args.variant,
             use_fast=False,
         )
 
@@ -767,11 +774,13 @@
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+    )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )
 
     if args.controlnet_model_name_or_path:
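For context on what the new flag does: in diffusers, `variant` selects alternate weight files that live alongside the defaults in the same repo, e.g. `unet/diffusion_pytorch_model.fp16.safetensors` instead of `unet/diffusion_pytorch_model.safetensors`. A minimal sketch of the loading behavior the script now exposes, assuming the referenced repo actually publishes fp16 variant files (the repo id is illustrative):

```python
import torch

from diffusers import UNet2DConditionModel

# `variant` chooses which weight files are fetched (*.fp16.safetensors here),
# while `torch_dtype` controls the precision the weights are loaded into.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    variant="fp16",
    torch_dtype=torch.float16,
)
```

Passing `--variant fp16` to a training script simply forwards this value to every `from_pretrained` call touched by the diff.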
12 changes: 11 additions & 1 deletion examples/controlnet/train_controlnet_flax.py
@@ -190,6 +190,12 @@ def parse_args():
         default=None,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+    )
     parser.add_argument(
         "--from_pt",
         action="store_true",
@@ -696,7 +702,7 @@ def main():
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+            args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, variant=args.variant
         )
     else:
         raise NotImplementedError("No tokenizer specified!")
@@ -726,11 +732,13 @@ def main():
         subfolder="text_encoder",
         dtype=weight_dtype,
         revision=args.revision,
+        variant=args.variant,
         from_pt=args.from_pt,
     )
     vae, vae_params = FlaxAutoencoderKL.from_pretrained(
         args.pretrained_model_name_or_path,
         revision=args.revision,
+        variant=args.variant,
         subfolder="vae",
         dtype=weight_dtype,
         from_pt=args.from_pt,
@@ -740,6 +748,7 @@
         subfolder="unet",
         dtype=weight_dtype,
         revision=args.revision,
+        variant=args.variant,
         from_pt=args.from_pt,
     )
 
@@ -787,6 +796,7 @@ def main():
         safety_checker=None,
         dtype=weight_dtype,
         revision=args.revision,
+        variant=args.variant,
         from_pt=args.from_pt,
     )
     pipeline_params = jax_utils.replicate(pipeline_params)
31 changes: 22 additions & 9 deletions examples/controlnet/train_controlnet_sdxl.py
@@ -74,6 +74,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
         unet=unet,
         controlnet=controlnet,
         revision=args.revision,
+        variant=args.variant,
         torch_dtype=weight_dtype,
     )
     pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -243,15 +244,18 @@ def parse_args(input_args=None):
         help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
         " If not specified controlnet weights are initialized from unet.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+    )
     parser.add_argument(
         "--revision",
         type=str,
         default=None,
         required=False,
-        help=(
-            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
-            " float32 precision."
-        ),
+        help="Revision of pretrained model identifier from huggingface.co/models.",
     )
     parser.add_argument(
         "--tokenizer_name",
@@ -793,10 +797,18 @@ def main(args):
 
     # Load the tokenizers
     tokenizer_one = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=args.revision,
+        variant=args.variant,
+        use_fast=False,
     )
     tokenizer_two = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=args.revision,
+        variant=args.variant,
+        use_fast=False,
     )
 
     # import correct text encoder classes
@@ -810,10 +822,10 @@
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder_one = text_encoder_cls_one.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
     text_encoder_two = text_encoder_cls_two.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
     )
     vae_path = (
         args.pretrained_model_name_or_path
@@ -824,9 +836,10 @@
         vae_path,
         subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
         revision=args.revision,
+        variant=args.variant,
     )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )
 
     if args.controlnet_model_name_or_path:
20 changes: 16 additions & 4 deletions examples/custom_diffusion/train_custom_diffusion.py
@@ -332,6 +332,12 @@ def parse_args(input_args=None):
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+    )
     parser.add_argument(
         "--tokenizer_name",
         type=str,
@@ -740,6 +746,7 @@ def main(args):
                     torch_dtype=torch_dtype,
                     safety_checker=None,
                     revision=args.revision,
+                    variant=args.variant,
                 )
                 pipeline.set_progress_bar_config(disable=True)
 
@@ -785,13 +792,15 @@
         tokenizer = AutoTokenizer.from_pretrained(
             args.tokenizer_name,
             revision=args.revision,
+            variant=args.variant,
             use_fast=False,
         )
     elif args.pretrained_model_name_or_path:
         tokenizer = AutoTokenizer.from_pretrained(
             args.pretrained_model_name_or_path,
             subfolder="tokenizer",
             revision=args.revision,
+            variant=args.variant,
             use_fast=False,
         )
 
@@ -801,11 +810,13 @@
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+    )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )
 
     # Adding a modifier token which is optimized ####
@@ -1229,6 +1240,7 @@ def main(args):
                     text_encoder=accelerator.unwrap_model(text_encoder),
                     tokenizer=tokenizer,
                     revision=args.revision,
+                    variant=args.variant,
                     torch_dtype=weight_dtype,
                 )
                 pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -1278,7 +1290,7 @@
     # Final inference
     # Load previous pipeline
     pipeline = DiffusionPipeline.from_pretrained(
-        args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
+        args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
    )
     pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
     pipeline = pipeline.to(accelerator.device)
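Since the final-inference reload above now passes `variant`, it only finds weight files whose names carry the matching suffix. A hedged round-trip sketch (directories are hypothetical; `save_pretrained` takes the same `variant` keyword in current diffusers releases):

```python
import torch

from diffusers import DiffusionPipeline

# `variant` must match between save and load because it is encoded in the
# weight filenames (e.g. unet/diffusion_pytorch_model.fp16.safetensors).
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
)
pipeline.save_pretrained("./my-finetuned-pipeline", variant="fp16")
reloaded = DiffusionPipeline.from_pretrained(
    "./my-finetuned-pipeline", variant="fp16", torch_dtype=torch.float16
)
```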
29 changes: 20 additions & 9 deletions examples/dreambooth/train_dreambooth.py
@@ -139,6 +139,7 @@ def log_validation(
         text_encoder=text_encoder,
         unet=accelerator.unwrap_model(unet),
         revision=args.revision,
+        variant=args.variant,
         torch_dtype=weight_dtype,
         **pipeline_args,
     )
@@ -239,10 +240,13 @@ def parse_args(input_args=None):
         type=str,
         default=None,
         required=False,
-        help=(
-            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
-            " float32 precision."
-        ),
+        help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+    )
     parser.add_argument(
         "--tokenizer_name",
@@ -758,7 +762,9 @@ def model_has_vae(args):
         config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
         return os.path.isfile(config_file_name)
     else:
-        files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
+        files_in_repo = model_info(
+            args.pretrained_model_name_or_path, revision=args.revision
+        ).siblings
         return any(file.rfilename == config_file_name for file in files_in_repo)
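One caveat around `model_has_vae`: `huggingface_hub.model_info` accepts no `variant` parameter, so only `revision` can be forwarded to it; a variant-aware existence check has to filter the returned file list itself. A hedged sketch of such a check (the helper name and filename pattern are assumptions, not part of this PR):

```python
from typing import Optional

from huggingface_hub import model_info


def repo_has_vae_weights(repo_id: str, revision: Optional[str] = None, variant: Optional[str] = None) -> bool:
    # The variant only changes the weight filename, e.g.
    # vae/diffusion_pytorch_model.fp16.safetensors for variant="fp16".
    weight_name = (
        "diffusion_pytorch_model.safetensors"
        if variant is None
        else f"diffusion_pytorch_model.{variant}.safetensors"
    )
    files = {sibling.rfilename for sibling in model_info(repo_id, revision=revision).siblings}
    return f"vae/{weight_name}" in files
```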


@@ -859,6 +865,7 @@ def main(args):
                 torch_dtype=torch_dtype,
                 safety_checker=None,
                 revision=args.revision,
+                variant=args.variant,
             )
             pipeline.set_progress_bar_config(disable=True)
 
@@ -897,12 +904,15 @@
 
     # Load the tokenizer
     if args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.tokenizer_name, revision=args.revision, variant=args.variant, use_fast=False
+        )
     elif args.pretrained_model_name_or_path:
         tokenizer = AutoTokenizer.from_pretrained(
             args.pretrained_model_name_or_path,
             subfolder="tokenizer",
             revision=args.revision,
+            variant=args.variant,
             use_fast=False,
         )
 
@@ -912,18 +922,18 @@
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
 
     if model_has_vae(args):
         vae = AutoencoderKL.from_pretrained(
-            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
         )
     else:
         vae = None
 
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -1379,6 +1389,7 @@ def compute_text_embeddings(prompt):
             args.pretrained_model_name_or_path,
             unet=accelerator.unwrap_model(unet),
             revision=args.revision,
+            variant=args.variant,
             **pipeline_args,
         )
 