Skip to content

Commit 002aefa

Browse files
Support lcm models.
Use the "lcm" sampler to sample them, you also have to use the ModelSamplingDiscrete node to set them as lcm models to use them properly.
1 parent ca71e54 commit 002aefa

File tree

3 files changed

+88
-4
lines changed

3 files changed

+88
-4
lines changed

comfy/k_diffusion/sampling.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
717717
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
718718
return mu
719719

720-
721720
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
722721
extra_args = {} if extra_args is None else extra_args
723722
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -737,3 +736,17 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
737736
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
738737
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
739738

739+
@torch.no_grad()
740+
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
741+
extra_args = {} if extra_args is None else extra_args
742+
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
743+
s_in = x.new_ones([x.shape[0]])
744+
for i in trange(len(sigmas) - 1, disable=disable):
745+
denoised = model(x, sigmas[i] * s_in, **extra_args)
746+
if callback is not None:
747+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
748+
749+
x = denoised
750+
if sigmas[i + 1] > 0:
751+
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
752+
return x

comfy/samplers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N
519519

520520
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
521521
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
522-
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
522+
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
523523

524524
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
525525
class KSAMPLER(Sampler):

comfy_extras/nodes_model_advanced.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,72 @@
11
import folder_paths
22
import comfy.sd
33
import comfy.model_sampling
4+
import torch
5+
6+
class LCM(comfy.model_sampling.EPS):
7+
def calculate_denoised(self, sigma, model_output, model_input):
8+
timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
9+
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
10+
x0 = model_input - model_output * sigma
11+
12+
sigma_data = 0.5
13+
scaled_timestep = timestep * 10.0 #timestep_scaling
14+
15+
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
16+
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
17+
18+
return c_out * x0 + c_skip * model_input
19+
20+
class ModelSamplingDiscreteLCM(torch.nn.Module):
21+
def __init__(self):
22+
super().__init__()
23+
self.sigma_data = 1.0
24+
timesteps = 1000
25+
beta_start = 0.00085
26+
beta_end = 0.012
27+
28+
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
29+
alphas = 1.0 - betas
30+
alphas_cumprod = torch.cumprod(alphas, dim=0)
31+
32+
original_timesteps = 50
33+
self.skip_steps = timesteps // original_timesteps
34+
35+
36+
alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
37+
for x in range(original_timesteps):
38+
alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
39+
40+
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
41+
self.set_sigmas(sigmas)
42+
43+
def set_sigmas(self, sigmas):
44+
self.register_buffer('sigmas', sigmas)
45+
self.register_buffer('log_sigmas', sigmas.log())
46+
47+
@property
48+
def sigma_min(self):
49+
return self.sigmas[0]
50+
51+
@property
52+
def sigma_max(self):
53+
return self.sigmas[-1]
54+
55+
def timestep(self, sigma):
56+
log_sigma = sigma.log()
57+
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
58+
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
59+
60+
def sigma(self, timestep):
61+
t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
62+
low_idx = t.floor().long()
63+
high_idx = t.ceil().long()
64+
w = t.frac()
65+
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
66+
return log_sigma.exp()
67+
68+
def percent_to_sigma(self, percent):
69+
return self.sigma(torch.tensor(percent * 999.0))
470

571

672
def rescale_zero_terminal_snr_sigmas(sigmas):
@@ -26,7 +92,7 @@ class ModelSamplingDiscrete:
2692
@classmethod
2793
def INPUT_TYPES(s):
2894
return {"required": { "model": ("MODEL",),
29-
"sampling": (["eps", "v_prediction"],),
95+
"sampling": (["eps", "v_prediction", "lcm"],),
3096
"zsnr": ("BOOLEAN", {"default": False}),
3197
}}
3298

@@ -38,17 +104,22 @@ def INPUT_TYPES(s):
38104
def patch(self, model, sampling, zsnr):
39105
m = model.clone()
40106

107+
sampling_base = comfy.model_sampling.ModelSamplingDiscrete
41108
if sampling == "eps":
42109
sampling_type = comfy.model_sampling.EPS
43110
elif sampling == "v_prediction":
44111
sampling_type = comfy.model_sampling.V_PREDICTION
112+
elif sampling == "lcm":
113+
sampling_type = LCM
114+
sampling_base = ModelSamplingDiscreteLCM
45115

46-
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type):
116+
class ModelSamplingAdvanced(sampling_base, sampling_type):
47117
pass
48118

49119
model_sampling = ModelSamplingAdvanced()
50120
if zsnr:
51121
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
122+
52123
m.add_object_patch("model_sampling", model_sampling)
53124
return (m, )
54125

0 commit comments

Comments
 (0)