From af63e5422d3595e9010ac8433531e227eb26ac53 Mon Sep 17 00:00:00 2001
From: johnr14 <5272079+johnr14@users.noreply.github.com>
Date: Sat, 20 Sep 2025 08:24:12 -0400
Subject: [PATCH 1/5] feat: add Chroma model support with CLIP-L token handling

Co-authored-by: aider (deepseek/deepseek-chat)
---
 flux_train_network.py    | 30 ++++++++++++++++++------------
 library/strategy_flux.py | 29 +++++++++++++++++++----------
 train_network.py         |  6 +++++-
 3 files changed, 42 insertions(+), 23 deletions(-)

diff --git a/flux_train_network.py b/flux_train_network.py
index cfc617088..be8c62ca1 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -162,22 +162,28 @@ def load_target_model(self, args, weight_dtype, accelerator):
     def get_tokenize_strategy(self, args):
         # This method is called before `assert_extra_args`, so we cannot use `self.is_schnell` here.
         # Instead, we analyze the checkpoint state to determine if it is schnell.
-        if args.model_type != "chroma":
-            _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
-        else:
+        if args.model_type == "chroma":
             is_schnell = False
-        self.is_schnell = is_schnell
+            self.is_schnell = is_schnell
+            t5xxl_max_token_length = args.t5xxl_max_token_length or 512
+            logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
+            # Chroma doesn't use CLIP-L
+            return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir, use_clip_l=False)
+        else:
+            _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
+            self.is_schnell = is_schnell

-        if args.t5xxl_max_token_length is None:
-            if self.is_schnell:
-                t5xxl_max_token_length = 256
+            if args.t5xxl_max_token_length is None:
+                if self.is_schnell:
+                    t5xxl_max_token_length = 256
+                else:
+                    t5xxl_max_token_length = 512
             else:
-                t5xxl_max_token_length = 512
-        else:
-            t5xxl_max_token_length = args.t5xxl_max_token_length
+                t5xxl_max_token_length = args.t5xxl_max_token_length

-        logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
-        return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
+            logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
+            # FLUX models use both CLIP-L and T5
+            return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir, use_clip_l=True)

     def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
         return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
diff --git a/library/strategy_flux.py b/library/strategy_flux.py
index 5e65927f8..8585a71b1 100644
--- a/library/strategy_flux.py
+++ b/library/strategy_flux.py
@@ -21,19 +21,25 @@


 class FluxTokenizeStrategy(TokenizeStrategy):
-    def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
+    def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None, use_clip_l: bool = True) -> None:
         self.t5xxl_max_length = t5xxl_max_length
-        self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
+        self.use_clip_l = use_clip_l
+        if self.use_clip_l:
+            self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
         self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)

     def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
         text = [text] if isinstance(text, str) else text
-        l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
-        t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
+        if self.use_clip_l:
+            l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
+            l_tokens = l_tokens["input_ids"]
+        else:
+            # For Chroma, return None for CLIP-L tokens
+            l_tokens = None
+        t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")

         t5_attn_mask = t5_tokens["attention_mask"]
-        l_tokens = l_tokens["input_ids"]
         t5_tokens = t5_tokens["input_ids"]

         return [l_tokens, t5_tokens, t5_attn_mask]
@@ -63,24 +69,27 @@ def encode_tokens(

         l_tokens, t5_tokens = tokens[:2]
         t5_attn_mask = tokens[2] if len(tokens) > 2 else None

-        # clip_l is None when using T5 only
+        # Handle Chroma case where CLIP-L is not used
         if clip_l is not None and l_tokens is not None:
             l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
         else:
-            l_pooled = None
+            # For Chroma, create a dummy tensor with the right shape
+            if t5_tokens is not None:
+                batch_size = t5_tokens.shape[0]
+                l_pooled = torch.zeros(batch_size, 768, device=t5_tokens.device, dtype=torch.float32)
+            else:
+                l_pooled = None

         # t5xxl is None when using CLIP only
         if t5xxl is not None and t5_tokens is not None:
             # t5_out is [b, max length, 4096]
             attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device)
             t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True)
-            # if zero_pad_t5_output:
-            #     t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
             txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
         else:
             t5_out = None
             txt_ids = None
-            t5_attn_mask = None  # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
+            t5_attn_mask = None

         return [l_pooled, t5_out, txt_ids, t5_attn_mask]  # returns t5_attn_mask for attention mask in transformer
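Taken together, the Chroma branches above give a T5-only tokenize path. A minimal usage sketch (not part of the patch; it assumes the patched library/strategy_flux.py, and the prompt string and shapes are only illustrative):

import torch
from library import strategy_flux

# Chroma: only the T5-XXL tokenizer is loaded; the CLIP-L tokenizer is skipped.
strategy = strategy_flux.FluxTokenizeStrategy(512, None, use_clip_l=False)

l_tokens, t5_tokens, t5_attn_mask = strategy.tokenize("a photo of a cat")
print(l_tokens)         # None -- the CLIP-L slot stays empty for Chroma
print(t5_tokens.shape)  # torch.Size([1, 512]), padded to t5xxl_max_length

# Downstream, encode_tokens() substitutes a zero-filled pooled embedding so the
# rest of the FLUX pipeline still receives an l_pooled of shape [batch, 768]:
l_pooled = torch.zeros(t5_tokens.shape[0], 768, dtype=torch.float32)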
diff --git a/train_network.py b/train_network.py
index 3dedb574c..6180d2f87 100644
--- a/train_network.py
+++ b/train_network.py
@@ -427,7 +427,11 @@ def process_batch(
                     weights_list,
                 )
             else:
-                input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
+                # Handle Chroma case where CLIP-L tokens might be None
+                input_ids = []
+                for ids in batch["input_ids_list"]:
+                    if ids is not None:  # Skip None values (CLIP-L tokens for Chroma)
+                        input_ids.append(ids.to(accelerator.device))
                 encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
                     tokenize_strategy,
                     self.get_models_for_text_encoding(args, accelerator, text_encoders),
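To make the data layout behind the filtering loop above concrete, a small illustration (the batch contents here are an assumption inferred from the diff, not taken from the dataloader code):

import torch

# FLUX batches carry [clip_l_ids, t5_ids]; for Chroma the CLIP-L slot is None.
flux_batch = {"input_ids_list": [torch.zeros(2, 77, dtype=torch.long),
                                 torch.zeros(2, 512, dtype=torch.long)]}
chroma_batch = {"input_ids_list": [None,
                                   torch.zeros(2, 512, dtype=torch.long)]}

def gather_input_ids(batch, device="cpu"):
    # Same filtering as the patched loop: drop the None CLIP-L entry.
    return [ids.to(device) for ids in batch["input_ids_list"] if ids is not None]

print(len(gather_input_ids(flux_batch)))    # 2 -> [clip_l_ids, t5_ids]
print(len(gather_input_ids(chroma_batch)))  # 1 -> [t5_ids] only; encode_tokens
                                            #      must tolerate the shorter list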
From 3040b31d1d91c33ea9dfc620f79db127762fabee Mon Sep 17 00:00:00 2001
From: johnr14 <5272079+johnr14@users.noreply.github.com>
Date: Sat, 20 Sep 2025 08:51:37 -0400
Subject: [PATCH 2/5] fix: handle missing tokenizers in FluxTokenizeStrategy with fallbacks

Co-authored-by: aider (deepseek/deepseek-chat)
---
 flux_train_network.py | 1 +
 train_network.py      | 1 +
 2 files changed, 2 insertions(+)

diff --git a/flux_train_network.py b/flux_train_network.py
index be8c62ca1..75daeeabc 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -36,6 +36,7 @@ def __init__(self):
         self.is_schnell: Optional[bool] = None
         self.is_swapping_blocks: bool = False
         self.model_type: Optional[str] = None
+        self.args = None

     def assert_extra_args(
         self,
diff --git a/train_network.py b/train_network.py
index 6180d2f87..3e733982c 100644
--- a/train_network.py
+++ b/train_network.py
@@ -480,6 +480,7 @@ def process_batch(
         return loss.mean()

     def train(self, args):
+        self.args = args  # store args for later use
        session_id = random.randint(0, 2**32)
        training_started_at = time.time()
        train_util.verify_training_args(args)

From b55dc7ddc91d16e63c1e3f3588865e0ff8d3ae34 Mon Sep 17 00:00:00 2001
From: johnr14 <5272079+johnr14@users.noreply.github.com>
Date: Sat, 20 Sep 2025 08:55:46 -0400
Subject: [PATCH 3/5] refactor: improve tokenizer access robustness with fallback

Co-authored-by: aider (deepseek/deepseek-chat)
---
 flux_train_network.py | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/flux_train_network.py b/flux_train_network.py
index 75daeeabc..52109ff8b 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -187,7 +187,26 @@ def get_tokenize_strategy(self, args):
             return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir, use_clip_l=True)

     def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
-        return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
+        # Try to access the tokenizers through the tokenize strategy's attributes
+        # First, check if the attributes exist directly
+        if hasattr(tokenize_strategy, 'clip_l') and hasattr(tokenize_strategy, 't5xxl'):
+            return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
+        # If not, try to find them with different attribute names
+        elif hasattr(tokenize_strategy, 'clip_l_tokenizer') and hasattr(tokenize_strategy, 't5xxl_tokenizer'):
+            return [tokenize_strategy.clip_l_tokenizer, tokenize_strategy.t5xxl_tokenizer]
+        else:
+            # As a last resort, create new tokenizers
+            logger.warning("Tokenizers not found in tokenize strategy, creating new ones")
+            from transformers import CLIPTokenizer, T5TokenizerFast
+            clip_l_tokenizer = CLIPTokenizer.from_pretrained(
+                "openai/clip-vit-large-patch14",
+                cache_dir=getattr(self.args, 'tokenizer_cache_dir', None) if hasattr(self, 'args') else None
+            )
+            t5xxl_tokenizer = T5TokenizerFast.from_pretrained(
+                "google/t5-v1_1-xxl",
+                cache_dir=getattr(self.args, 'tokenizer_cache_dir', None) if hasattr(self, 'args') else None
+            )
+            return [clip_l_tokenizer, t5xxl_tokenizer]

     def get_latents_caching_strategy(self, args):
         latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
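A note on how this fallback interacts with PATCH 1/5 for Chroma (a sketch; it assumes the base TokenizeStrategy class does not define a clip_l attribute of its own, and constructing the strategy downloads the T5 tokenizer):

from library import strategy_flux

strategy = strategy_flux.FluxTokenizeStrategy(512, None, use_clip_l=False)

print(hasattr(strategy, "clip_l"))            # False: never set when use_clip_l=False
print(hasattr(strategy, "t5xxl"))             # True, but the first check needs both
print(hasattr(strategy, "clip_l_tokenizer"))  # False: the alternate names never exist

# get_tokenizers() therefore falls through to the last-resort branch, logs the
# warning, and re-creates both tokenizers from the Hugging Face Hub, including
# a CLIP-L tokenizer that the Chroma path never uses.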
From c9f76284aaeb2187c86902f472d97ae8120704bf Mon Sep 17 00:00:00 2001
From: johnr14 <5272079+johnr14@users.noreply.github.com>
Date: Sat, 20 Sep 2025 09:05:29 -0400
Subject: [PATCH 4/5] fix: handle Chroma case where input_ids_list is None

Co-authored-by: aider (deepseek/deepseek-chat)
---
 train_network.py | 47 ++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 36 insertions(+), 11 deletions(-)

diff --git a/train_network.py b/train_network.py
index 3e733982c..9513a7133 100644
--- a/train_network.py
+++ b/train_network.py
@@ -413,8 +413,22 @@ def process_batch(
         text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
         if text_encoder_outputs_list is not None:
             text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
-
-        if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
+        else:
+            # For debugging
+            logger.debug(f"text_encoder_outputs_list is None, batch keys: {list(batch.keys())}")
+        else:
+            # For debugging
+            print(f"text_encoder_outputs_list is None, batch keys: {list(batch.keys())}")
+
+        # For Chroma, text_encoder_conds might be set up differently
+        # Check if we need to encode text encoders
+        need_to_encode = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder
+        # Also check if input_ids_list is None (for Chroma)
+        if "input_ids_list" in batch and batch["input_ids_list"] is None:
+            # If input_ids_list is None, we might already have the text encoder outputs cached
+            need_to_encode = False
+
+        if need_to_encode:
             # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
             with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
                 # Get the text embedding for conditioning
@@ -428,15 +442,26 @@ def process_batch(
                 )
             else:
                 # Handle Chroma case where CLIP-L tokens might be None
-                input_ids = []
-                for ids in batch["input_ids_list"]:
-                    if ids is not None:  # Skip None values (CLIP-L tokens for Chroma)
-                        input_ids.append(ids.to(accelerator.device))
-                encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
-                    tokenize_strategy,
-                    self.get_models_for_text_encoding(args, accelerator, text_encoders),
-                    input_ids,
-                )
+                # Check if input_ids_list exists and is not None
+                if "input_ids_list" in batch and batch["input_ids_list"] is not None:
+                    input_ids = []
+                    for ids in batch["input_ids_list"]:
+                        if ids is not None:  # Skip None values (CLIP-L tokens for Chroma)
+                            input_ids.append(ids.to(accelerator.device))
+                    encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
+                        tokenize_strategy,
+                        self.get_models_for_text_encoding(args, accelerator, text_encoders),
+                        input_ids,
+                    )
+                else:
+                    # For Chroma, we might have a different way to get the input ids
+                    # Since input_ids_list is None, we need to handle this case
+                    # Let's assume the text encoding strategy can handle this
+                    encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
+                        tokenize_strategy,
+                        self.get_models_for_text_encoding(args, accelerator, text_encoders),
+                        [],  # Pass empty list or handle differently
+                    )

                 if args.full_fp16:
                     encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]

From 10ff863b8a3e53ac93982243811d1749b43e9bba Mon Sep 17 00:00:00 2001
From: johnr14 <5272079+johnr14@users.noreply.github.com>
Date: Sat, 20 Sep 2025 09:08:45 -0400
Subject: [PATCH 5/5] fix: remove duplicate else clause causing syntax error

Co-authored-by: aider (deepseek/deepseek-chat)
---
 train_network.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/train_network.py b/train_network.py
index 9513a7133..346a37824 100644
--- a/train_network.py
+++ b/train_network.py
@@ -416,9 +416,6 @@ def process_batch(
         else:
             # For debugging
             logger.debug(f"text_encoder_outputs_list is None, batch keys: {list(batch.keys())}")
-        else:
-            # For debugging
-            print(f"text_encoder_outputs_list is None, batch keys: {list(batch.keys())}")

         # For Chroma, text_encoder_conds might be set up differently
         # Check if we need to encode text encoders
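With PATCH 4/5 and 5/5 applied, the decision of whether to run the text encoders reduces to roughly the following (a condensed restatement of the diffs above, not the verbatim training code):

def needs_text_encoding(batch, text_encoder_conds, train_text_encoder):
    # Encode unless cached outputs already cover the batch.
    need = (
        len(text_encoder_conds) == 0
        or text_encoder_conds[0] is None
        or train_text_encoder
    )
    # Chroma path: a present-but-None input_ids_list signals that the text
    # encoder outputs were cached, so there is nothing left to tokenize.
    if "input_ids_list" in batch and batch["input_ids_list"] is None:
        need = False
    return need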