Skip to content

Commit 3bad6d7

Browse files
comfyanonymousrakki194
authored andcommitted
Remove useless code. (comfyanonymous#9059)
1 parent ec6a916 commit 3bad6d7

File tree

1 file changed

+0
-41
lines changed

1 file changed

+0
-41
lines changed

comfy/ldm/wan/vae.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -148,29 +148,6 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
148148
feat_idx[0] += 1
149149
return x
150150

151-
def init_weight(self, conv):
152-
conv_weight = conv.weight
153-
nn.init.zeros_(conv_weight)
154-
c1, c2, t, h, w = conv_weight.size()
155-
one_matrix = torch.eye(c1, c2)
156-
init_matrix = one_matrix
157-
nn.init.zeros_(conv_weight)
158-
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
159-
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
160-
conv.weight.data.copy_(conv_weight)
161-
nn.init.zeros_(conv.bias.data)
162-
163-
def init_weight2(self, conv):
164-
conv_weight = conv.weight.data
165-
nn.init.zeros_(conv_weight)
166-
c1, c2, t, h, w = conv_weight.size()
167-
init_matrix = torch.eye(c1 // 2, c2)
168-
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
169-
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
170-
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
171-
conv.weight.data.copy_(conv_weight)
172-
nn.init.zeros_(conv.bias.data)
173-
174151

175152
class ResidualBlock(nn.Module):
176153

@@ -485,12 +462,6 @@ def __init__(self,
485462
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
486463
attn_scales, self.temperal_upsample, dropout)
487464

488-
def forward(self, x):
489-
mu, log_var = self.encode(x)
490-
z = self.reparameterize(mu, log_var)
491-
x_recon = self.decode(z)
492-
return x_recon, mu, log_var
493-
494465
def encode(self, x):
495466
self.clear_cache()
496467
## cache
@@ -536,18 +507,6 @@ def decode(self, z):
536507
self.clear_cache()
537508
return out
538509

539-
def reparameterize(self, mu, log_var):
540-
std = torch.exp(0.5 * log_var)
541-
eps = torch.randn_like(std)
542-
return eps * std + mu
543-
544-
def sample(self, imgs, deterministic=False):
545-
mu, log_var = self.encode(imgs)
546-
if deterministic:
547-
return mu
548-
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
549-
return mu + std * torch.randn_like(std)
550-
551510
def clear_cache(self):
552511
self._conv_num = count_conv3d(self.decoder)
553512
self._conv_idx = [0]

0 commit comments

Comments
 (0)