
Update model.py - To show which exact VAE component (encoder/decoder) is loaded #9152

Open · wants to merge 1 commit into master
49 changes: 48 additions & 1 deletion comfy/ldm/modules/diffusionmodules/model.py
@@ -9,6 +9,11 @@
import comfy.ops
ops = comfy.ops.disable_weight_init

import inspect
import os

# Module-level counter of VAE attention initializations per model; the
# globals() guard preserves existing counts if this module is re-executed.
if 'vae_attention_counter' not in globals():
    vae_attention_counter = {}

if model_management.xformers_enabled_vae():
    import xformers
    import xformers.ops
@@ -298,7 +303,49 @@ def vae_attention():
logging.info("Using xformers attention in VAE")
return xformers_attention
elif model_management.pytorch_attention_enabled_vae():
logging.info("Using pytorch attention in VAE")
#Common causes for duplicate VAE loading:
#1.Different precision/dtype loading - ComfyUI might load the same VAE twice with different precisions (fp16/fp32) or for different devices (CPU/GPU)
#2.Encoder and Decoder initialization - VAEs have separate encoder and decoder components that might initialize attention separately
#3.Model switching or reinitialization - If you have workflows or settings that switch between models during startup
#4.Checkpoint with embedded VAE + separate VAE - You might have a checkpoint that includes a VAE, plus a separate VAE file
# Global counter for VAE attention initialization
def get_detailed_vae_info():
frame = inspect.currentframe()
try:
# Get model info
model_name = "Unknown model"
for f in inspect.getouterframes(frame):
local_vars = f.frame.f_locals
for var_name in ['model_path', 'vae_path', 'checkpoint_path', 'config_path']:
if var_name in local_vars and local_vars[var_name]:
model_name = os.path.basename(str(local_vars[var_name]))
break

if 'self' in local_vars:
obj = local_vars['self']
if hasattr(obj, 'model_path') and obj.model_path:
model_name = os.path.basename(obj.model_path)
break

if model_name != "Unknown model":
break

# Use the global counter
if model_name not in vae_attention_counter:
vae_attention_counter[model_name] = 0

vae_attention_counter[model_name] += 1
count = vae_attention_counter[model_name]

# Determine component based on call count
components = ["encoder", "decoder"]
component = components[min(count - 1, len(components) - 1)]

return model_name, component, count
finally:
del frame
model_name, component, count = get_detailed_vae_info()
logging.info(f"Using pytorch attention in VAE {component} ({count}/2) - Model: {model_name}")
        return pytorch_attention
    else:
        logging.info("Using split attention in VAE")
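For reviewers, a minimal standalone sketch of the counting heuristic this patch relies on (the function name and "example.safetensors" are placeholders, not part of the diff): the first attention initialization seen for a model name is labeled "encoder", every later one "decoder", so the count in the "(count/2)" log suffix can exceed 2 when a VAE's attention is initialized more than twice, e.g. on a precision or device switch.

# Standalone sketch of the heuristic above; names are hypothetical.
vae_attention_counter = {}

def label_component(model_name):
    vae_attention_counter[model_name] = vae_attention_counter.get(model_name, 0) + 1
    count = vae_attention_counter[model_name]
    components = ["encoder", "decoder"]
    return components[min(count - 1, len(components) - 1)], count

for _ in range(3):
    component, count = label_component("example.safetensors")
    print(f"Using pytorch attention in VAE {component} ({count}/2) - Model: example.safetensors")
# Prints: encoder (1/2), then decoder (2/2), then decoder (3/2)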