11import  folder_paths 
22import  comfy .sd 
33import  comfy .model_sampling 
4+ import  torch 
5+ 
6+ class  LCM (comfy .model_sampling .EPS ):
7+     def  calculate_denoised (self , sigma , model_output , model_input ):
8+         timestep  =  self .timestep (sigma ).view (sigma .shape [:1 ] +  (1 ,) *  (model_output .ndim  -  1 ))
9+         sigma  =  sigma .view (sigma .shape [:1 ] +  (1 ,) *  (model_output .ndim  -  1 ))
10+         x0  =  model_input  -  model_output  *  sigma 
11+ 
12+         sigma_data  =  0.5 
13+         scaled_timestep  =  timestep  *  10.0  #timestep_scaling 
14+ 
15+         c_skip  =  sigma_data ** 2  /  (scaled_timestep ** 2  +  sigma_data ** 2 )
16+         c_out  =  scaled_timestep  /  (scaled_timestep ** 2  +  sigma_data ** 2 ) **  0.5 
17+ 
18+         return  c_out  *  x0  +  c_skip  *  model_input 
19+ 
20+ class  ModelSamplingDiscreteLCM (torch .nn .Module ):
21+     def  __init__ (self ):
22+         super ().__init__ ()
23+         self .sigma_data  =  1.0 
24+         timesteps  =  1000 
25+         beta_start  =  0.00085 
26+         beta_end  =  0.012 
27+ 
28+         betas  =  torch .linspace (beta_start ** 0.5 , beta_end ** 0.5 , timesteps , dtype = torch .float32 ) **  2 
29+         alphas  =  1.0  -  betas 
30+         alphas_cumprod  =  torch .cumprod (alphas , dim = 0 )
31+ 
32+         original_timesteps  =  50 
33+         self .skip_steps  =  timesteps  //  original_timesteps 
34+ 
35+ 
36+         alphas_cumprod_valid  =  torch .zeros ((original_timesteps ), dtype = torch .float32 )
37+         for  x  in  range (original_timesteps ):
38+             alphas_cumprod_valid [original_timesteps  -  1  -  x ] =  alphas_cumprod [timesteps  -  1  -  x  *  self .skip_steps ]
39+ 
40+         sigmas  =  ((1  -  alphas_cumprod_valid ) /  alphas_cumprod_valid ) **  0.5 
41+         self .set_sigmas (sigmas )
42+ 
43+     def  set_sigmas (self , sigmas ):
44+         self .register_buffer ('sigmas' , sigmas )
45+         self .register_buffer ('log_sigmas' , sigmas .log ())
46+ 
47+     @property  
48+     def  sigma_min (self ):
49+         return  self .sigmas [0 ]
50+ 
51+     @property  
52+     def  sigma_max (self ):
53+         return  self .sigmas [- 1 ]
54+ 
55+     def  timestep (self , sigma ):
56+         log_sigma  =  sigma .log ()
57+         dists  =  log_sigma .to (self .log_sigmas .device ) -  self .log_sigmas [:, None ]
58+         return  dists .abs ().argmin (dim = 0 ).view (sigma .shape ) *  self .skip_steps  +  (self .skip_steps  -  1 )
59+ 
60+     def  sigma (self , timestep ):
61+         t  =  torch .clamp (((timestep  -  (self .skip_steps  -  1 )) /  self .skip_steps ).float (), min = 0 , max = (len (self .sigmas ) -  1 ))
62+         low_idx  =  t .floor ().long ()
63+         high_idx  =  t .ceil ().long ()
64+         w  =  t .frac ()
65+         log_sigma  =  (1  -  w ) *  self .log_sigmas [low_idx ] +  w  *  self .log_sigmas [high_idx ]
66+         return  log_sigma .exp ()
67+ 
68+     def  percent_to_sigma (self , percent ):
69+         return  self .sigma (torch .tensor (percent  *  999.0 ))
470
571
672def  rescale_zero_terminal_snr_sigmas (sigmas ):
@@ -26,7 +92,7 @@ class ModelSamplingDiscrete:
2692    @classmethod  
2793    def  INPUT_TYPES (s ):
2894        return  {"required" : { "model" : ("MODEL" ,),
29-                               "sampling" : (["eps" , "v_prediction" ],),
95+                               "sampling" : (["eps" , "v_prediction" ,  "lcm" ],),
3096                              "zsnr" : ("BOOLEAN" , {"default" : False }),
3197                              }}
3298
@@ -38,17 +104,22 @@ def INPUT_TYPES(s):
38104    def  patch (self , model , sampling , zsnr ):
39105        m  =  model .clone ()
40106
107+         sampling_base  =  comfy .model_sampling .ModelSamplingDiscrete 
41108        if  sampling  ==  "eps" :
42109            sampling_type  =  comfy .model_sampling .EPS 
43110        elif  sampling  ==  "v_prediction" :
44111            sampling_type  =  comfy .model_sampling .V_PREDICTION 
112+         elif  sampling  ==  "lcm" :
113+             sampling_type  =  LCM 
114+             sampling_base  =  ModelSamplingDiscreteLCM 
45115
46-         class  ModelSamplingAdvanced (comfy . model_sampling . ModelSamplingDiscrete , sampling_type ):
116+         class  ModelSamplingAdvanced (sampling_base , sampling_type ):
47117            pass 
48118
49119        model_sampling  =  ModelSamplingAdvanced ()
50120        if  zsnr :
51121            model_sampling .set_sigmas (rescale_zero_terminal_snr_sigmas (model_sampling .sigmas ))
122+ 
52123        m .add_object_patch ("model_sampling" , model_sampling )
53124        return  (m , )
54125
0 commit comments