52 changes: 39 additions & 13 deletions flux_train_network.py
@@ -36,6 +36,7 @@ def __init__(self):
         self.is_schnell: Optional[bool] = None
         self.is_swapping_blocks: bool = False
         self.model_type: Optional[str] = None
+        self.args = None

     def assert_extra_args(
         self,
@@ -162,25 +163,50 @@ def load_target_model(self, args, weight_dtype, accelerator):
     def get_tokenize_strategy(self, args):
         # This method is called before `assert_extra_args`, so we cannot use `self.is_schnell` here.
         # Instead, we analyze the checkpoint state to determine if it is schnell.
-        if args.model_type != "chroma":
-            _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
-        else:
-            is_schnell = False
-        self.is_schnell = is_schnell
+        if args.model_type == "chroma":
+            is_schnell = False
+            self.is_schnell = is_schnell
+            t5xxl_max_token_length = args.t5xxl_max_token_length or 512
+            logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
+            # Chroma doesn't use CLIP-L
+            return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir, use_clip_l=False)
+        else:
+            _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
+            self.is_schnell = is_schnell

-        if args.t5xxl_max_token_length is None:
-            if self.is_schnell:
-                t5xxl_max_token_length = 256
-            else:
-                t5xxl_max_token_length = 512
+            if args.t5xxl_max_token_length is None:
+                if self.is_schnell:
+                    t5xxl_max_token_length = 256
+                else:
+                    t5xxl_max_token_length = 512
Review comment (author): Not sure if this is OK with Chroma, or should it keep 256?
-        else:
-            t5xxl_max_token_length = args.t5xxl_max_token_length
+            else:
+                t5xxl_max_token_length = args.t5xxl_max_token_length

-        logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
-        return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
+            logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
+            # FLUX models use both CLIP-L and T5
+            return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir, use_clip_l=True)

     def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
-        return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
+        # Try to access the tokenizers through the tokenize strategy's attributes
+        # First, check if the attributes exist directly
+        if hasattr(tokenize_strategy, 'clip_l') and hasattr(tokenize_strategy, 't5xxl'):
+            return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
+        # If not, try to find them with different attribute names
+        elif hasattr(tokenize_strategy, 'clip_l_tokenizer') and hasattr(tokenize_strategy, 't5xxl_tokenizer'):
+            return [tokenize_strategy.clip_l_tokenizer, tokenize_strategy.t5xxl_tokenizer]
+        else:
+            # As a last resort, create new tokenizers
+            logger.warning("Tokenizers not found in tokenize strategy, creating new ones")
+            from transformers import CLIPTokenizer, T5TokenizerFast
+            clip_l_tokenizer = CLIPTokenizer.from_pretrained(
+                "openai/clip-vit-large-patch14",
+                cache_dir=getattr(self.args, 'tokenizer_cache_dir', None) if hasattr(self, 'args') else None
+            )
+            t5xxl_tokenizer = T5TokenizerFast.from_pretrained(
+                "google/t5-v1_1-xxl",
+                cache_dir=getattr(self.args, 'tokenizer_cache_dir', None) if hasattr(self, 'args') else None
+            )
+            return [clip_l_tokenizer, t5xxl_tokenizer]

     def get_latents_caching_strategy(self, args):
         latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
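Note on the new get_tokenizers() fallback: with the PR's FluxTokenizeStrategy, a Chroma strategy (use_clip_l=False) never assigns self.clip_l, so the first hasattr() branch fails and the method falls back to loading fresh tokenizers from the Hugging Face hub. A minimal sketch of that behaviour, assuming the class as modified in library/strategy_flux.py below (not code from the PR):

# Sketch only: which get_tokenizers() branch is taken, under this PR's FluxTokenizeStrategy.
from library import strategy_flux

chroma_strategy = strategy_flux.FluxTokenizeStrategy(512, use_clip_l=False)
print(hasattr(chroma_strategy, "clip_l"))   # False -> get_tokenizers() ends up rebuilding tokenizers
print(hasattr(chroma_strategy, "t5xxl"))    # True  -> the T5-XXL tokenizer is always loaded

flux_strategy = strategy_flux.FluxTokenizeStrategy(512, use_clip_l=True)
print(hasattr(flux_strategy, "clip_l"))     # True  -> the first branch returns [clip_l, t5xxl] directly

Whether the extra CLIP-L tokenizer created by the fallback is ever used for Chroma depends on the caller; the tokenize strategy itself never touches it when use_clip_l is False.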
29 changes: 19 additions & 10 deletions library/strategy_flux.py
@@ -21,19 +21,25 @@


 class FluxTokenizeStrategy(TokenizeStrategy):
-    def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
+    def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None, use_clip_l: bool = True) -> None:
         self.t5xxl_max_length = t5xxl_max_length
-        self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
+        self.use_clip_l = use_clip_l
+        if self.use_clip_l:
+            self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
         self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)

     def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
         text = [text] if isinstance(text, str) else text

-        l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
-        t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
+        if self.use_clip_l:
+            l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
+            l_tokens = l_tokens["input_ids"]
+        else:
+            # For Chroma, return None for CLIP-L tokens
+            l_tokens = None
+
+        t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
         t5_attn_mask = t5_tokens["attention_mask"]
-        l_tokens = l_tokens["input_ids"]
         t5_tokens = t5_tokens["input_ids"]

         return [l_tokens, t5_tokens, t5_attn_mask]
@@ -63,24 +69,27 @@ def encode_tokens(
         l_tokens, t5_tokens = tokens[:2]
         t5_attn_mask = tokens[2] if len(tokens) > 2 else None

-        # clip_l is None when using T5 only
+        # Handle Chroma case where CLIP-L is not used
         if clip_l is not None and l_tokens is not None:
             l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
         else:
-            l_pooled = None
+            # For Chroma, create a dummy tensor with the right shape
+            if t5_tokens is not None:
+                batch_size = t5_tokens.shape[0]
+                l_pooled = torch.zeros(batch_size, 768, device=t5_tokens.device, dtype=torch.float32)
+            else:
+                l_pooled = None

         # t5xxl is None when using CLIP only
         if t5xxl is not None and t5_tokens is not None:
             # t5_out is [b, max length, 4096]
             attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device)
             t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True)
             # if zero_pad_t5_output:
             #     t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
             txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
         else:
             t5_out = None
             txt_ids = None
-            t5_attn_mask = None  # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
+            t5_attn_mask = None

         return [l_pooled, t5_out, txt_ids, t5_attn_mask]  # returns t5_attn_mask for attention mask in transformer

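For reference, a rough sketch (not part of the diff) of what the modified strategy produces on the Chroma path: tokenize() leaves the CLIP-L slot as None, and encode_tokens() substitutes a zero pooled vector of shape (batch, 768) so downstream code still receives a tensor. The snippet assumes the PR's FluxTokenizeStrategy and skips the actual T5 encoder:

# Sketch only: Chroma-path outputs of the modified FluxTokenizeStrategy.
import torch
from library import strategy_flux

strategy = strategy_flux.FluxTokenizeStrategy(512, use_clip_l=False)
l_tokens, t5_tokens, t5_attn_mask = strategy.tokenize("a photo of a cat")

print(l_tokens)            # None -> no CLIP-L tokens for Chroma
print(t5_tokens.shape)     # torch.Size([1, 512]) -> padded T5 input ids
print(t5_attn_mask.shape)  # torch.Size([1, 512])

# encode_tokens() later fills the missing CLIP-L pooled output with zeros:
l_pooled = torch.zeros(t5_tokens.shape[0], 768, dtype=torch.float32)  # mirrors the dummy built above
print(l_pooled.shape)      # torch.Size([1, 768])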
43 changes: 35 additions & 8 deletions train_network.py
@@ -413,8 +413,19 @@ def process_batch(
         text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
         if text_encoder_outputs_list is not None:
             text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
-
-        if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
+        else:
+            # For debugging
+            logger.debug(f"text_encoder_outputs_list is None, batch keys: {list(batch.keys())}")
+
+        # For Chroma, text_encoder_conds might be set up differently
+        # Check if we need to encode text encoders
+        need_to_encode = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder
+        # Also check if input_ids_list is None (for Chroma)
+        if "input_ids_list" in batch and batch["input_ids_list"] is None:
+            # If input_ids_list is None, we might already have the text encoder outputs cached
+            need_to_encode = False
+
+        if need_to_encode:
             # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
             with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
                 # Get the text embedding for conditioning
@@ -427,12 +438,27 @@ def process_batch(
                         weights_list,
                     )
                 else:
-                    input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
-                    encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
-                        tokenize_strategy,
-                        self.get_models_for_text_encoding(args, accelerator, text_encoders),
-                        input_ids,
-                    )
+                    # Handle Chroma case where CLIP-L tokens might be None
+                    # Check if input_ids_list exists and is not None
+                    if "input_ids_list" in batch and batch["input_ids_list"] is not None:
+                        input_ids = []
+                        for ids in batch["input_ids_list"]:
+                            if ids is not None:  # Skip None values (CLIP-L tokens for Chroma)
+                                input_ids.append(ids.to(accelerator.device))
+                        encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
+                            tokenize_strategy,
+                            self.get_models_for_text_encoding(args, accelerator, text_encoders),
+                            input_ids,
+                        )
+                    else:
+                        # For Chroma, we might have a different way to get the input ids
+                        # Since input_ids_list is None, we need to handle this case
+                        # Let's assume the text encoding strategy can handle this
+                        encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
+                            tokenize_strategy,
+                            self.get_models_for_text_encoding(args, accelerator, text_encoders),
+                            [],  # Pass empty list or handle differently
+                        )
                 if args.full_fp16:
                     encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]

@@ -476,6 +502,7 @@ def process_batch(
         return loss.mean()

     def train(self, args):
+        self.args = args  # store args for later use
         session_id = random.randint(0, 2**32)
         training_started_at = time.time()
         train_util.verify_training_args(args)
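To make the Chroma branch in process_batch() above easier to follow, here is a standalone sketch (hypothetical helper, not code from the PR) of the None-filtering it performs on the per-encoder token lists: entries that are None (the CLIP-L slot for Chroma) are dropped, and a missing input_ids_list is treated as "already cached, nothing to encode":

# Hypothetical helper mirroring the Chroma-aware branch of process_batch() above.
from typing import List, Optional
import torch

def collect_input_ids(input_ids_list: Optional[List[Optional[torch.Tensor]]], device: str) -> List[torch.Tensor]:
    # Drop None entries (e.g. the CLIP-L slot for Chroma) and move the rest to the target device.
    if input_ids_list is None:
        return []  # outputs assumed to be cached; encoding is skipped
    return [ids.to(device) for ids in input_ids_list if ids is not None]

# Example: a Chroma batch carries [None, t5_input_ids]
t5_ids = torch.zeros(2, 512, dtype=torch.long)
print(len(collect_input_ids([None, t5_ids], "cpu")))  # 1 -> only the T5 ids go to encode_tokens()
print(collect_input_ids(None, "cpu"))                 # [] -> nothing to encode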