@@ -148,29 +148,6 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
148
148
feat_idx [0 ] += 1
149
149
return x
150
150
151
- def init_weight (self , conv ):
152
- conv_weight = conv .weight
153
- nn .init .zeros_ (conv_weight )
154
- c1 , c2 , t , h , w = conv_weight .size ()
155
- one_matrix = torch .eye (c1 , c2 )
156
- init_matrix = one_matrix
157
- nn .init .zeros_ (conv_weight )
158
- #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
159
- conv_weight .data [:, :, 1 , 0 , 0 ] = init_matrix #* 0.5
160
- conv .weight .data .copy_ (conv_weight )
161
- nn .init .zeros_ (conv .bias .data )
162
-
163
- def init_weight2 (self , conv ):
164
- conv_weight = conv .weight .data
165
- nn .init .zeros_ (conv_weight )
166
- c1 , c2 , t , h , w = conv_weight .size ()
167
- init_matrix = torch .eye (c1 // 2 , c2 )
168
- #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
169
- conv_weight [:c1 // 2 , :, - 1 , 0 , 0 ] = init_matrix
170
- conv_weight [c1 // 2 :, :, - 1 , 0 , 0 ] = init_matrix
171
- conv .weight .data .copy_ (conv_weight )
172
- nn .init .zeros_ (conv .bias .data )
173
-
174
151
175
152
class ResidualBlock (nn .Module ):
176
153
@@ -485,12 +462,6 @@ def __init__(self,
485
462
self .decoder = Decoder3d (dim , z_dim , dim_mult , num_res_blocks ,
486
463
attn_scales , self .temperal_upsample , dropout )
487
464
488
- def forward (self , x ):
489
- mu , log_var = self .encode (x )
490
- z = self .reparameterize (mu , log_var )
491
- x_recon = self .decode (z )
492
- return x_recon , mu , log_var
493
-
494
465
def encode (self , x ):
495
466
self .clear_cache ()
496
467
## cache
@@ -536,18 +507,6 @@ def decode(self, z):
536
507
self .clear_cache ()
537
508
return out
538
509
539
- def reparameterize (self , mu , log_var ):
540
- std = torch .exp (0.5 * log_var )
541
- eps = torch .randn_like (std )
542
- return eps * std + mu
543
-
544
- def sample (self , imgs , deterministic = False ):
545
- mu , log_var = self .encode (imgs )
546
- if deterministic :
547
- return mu
548
- std = torch .exp (0.5 * log_var .clamp (- 30.0 , 20.0 ))
549
- return mu + std * torch .randn_like (std )
550
-
551
510
def clear_cache (self ):
552
511
self ._conv_num = count_conv3d (self .decoder )
553
512
self ._conv_idx = [0 ]
0 commit comments