@@ -95,9 +95,9 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name=
             assert self.channels == self.out_channels
             self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
 
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        return self.conv(x)
+    def forward(self, inputs):
+        assert inputs.shape[1] == self.channels
+        return self.conv(inputs)
 
 
 class Upsample2D(nn.Module):
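
The avg-pool branch visible in the context lines halves the temporal axis when `stride=2`, since `AvgPool1d` with `kernel_size == stride` averages non-overlapping windows. A quick standalone check (shapes here are illustrative, not from the diff):

```python
import torch
import torch.nn as nn

# AvgPool1d with kernel_size == stride averages non-overlapping windows,
# so stride=2 halves the sequence length. Shapes are assumptions.
pool = nn.AvgPool1d(kernel_size=2, stride=2)
inputs = torch.randn(1, 32, 64)  # [batch, channels, length]
print(pool(inputs).shape)        # torch.Size([1, 32, 32])
```
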
@@ -431,13 +431,13 @@ def __init__(self, pad_mode="reflect"):
         self.pad = kernel_1d.shape[1] // 2 - 1
         self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
 
-    def forward(self, x):
-        x = F.pad(x, (self.pad,) * 4, self.pad_mode)
-        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
-        indices = torch.arange(x.shape[1], device=x.device)
-        kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
+    def forward(self, inputs):
+        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
+        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+        indices = torch.arange(inputs.shape[1], device=inputs.device)
+        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
         weight[indices, indices] = kernel
-        return F.conv2d(inputs, weight, stride=2)
+        return F.conv2d(inputs, weight, stride=2)
 
 
 class KUpsample2D(nn.Module):
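
The renamed `forward` builds a `[C, C, k, k]` weight with the blur kernel on its diagonal, so `F.conv2d` filters each channel independently (a depthwise-style blur) before the stride-2 downsample. A minimal standalone sketch of that trick; the `[1/8, 3/8, 3/8, 1/8]` kernel and the shapes are assumptions, since the hunk only shows `kernel_1d.T @ kernel_1d`:

```python
import torch
import torch.nn.functional as F

# Sketch of the diagonal-weight trick in KDownsample2D.forward.
# kernel_1d is an assumption; the hunk only shows kernel_1d.T @ kernel_1d.
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
kernel = kernel_1d.T @ kernel_1d          # [4, 4] separable blur kernel
pad = kernel_1d.shape[1] // 2 - 1         # 1, as in __init__ above

inputs = torch.randn(1, 3, 8, 8)          # [batch, channels, H, W]
inputs = F.pad(inputs, (pad,) * 4, mode="reflect")
weight = inputs.new_zeros(3, 3, 4, 4)
indices = torch.arange(3)
weight[indices, indices] = kernel         # broadcasting stands in for .expand()
print(F.conv2d(inputs, weight, stride=2).shape)  # torch.Size([1, 3, 4, 4])
```
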
@@ -448,13 +448,13 @@ def __init__(self, pad_mode="reflect"):
         self.pad = kernel_1d.shape[1] // 2 - 1
         self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
 
-    def forward(self, x):
-        x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
-        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
-        indices = torch.arange(x.shape[1], device=x.device)
-        kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
+    def forward(self, inputs):
+        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
+        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+        indices = torch.arange(inputs.shape[1], device=inputs.device)
+        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
         weight[indices, indices] = kernel
-        return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
+        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
 
 
 class ResnetBlock2D(nn.Module):
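
`KUpsample2D.forward` mirrors this with a transposed convolution: the same diagonal weight, but `conv_transpose2d` with `stride=2` and `padding=self.pad * 2 + 1` doubles the spatial size. Continuing the sketch above (the kernel, scaled by 2 as an upsampling gain, is again an assumption):

```python
import torch
import torch.nn.functional as F

# Transposed-convolution counterpart; kernel_1d (scaled by 2 as a gain) is an
# assumption carried over from the sketch above.
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
kernel = kernel_1d.T @ kernel_1d
pad = kernel_1d.shape[1] // 2 - 1         # 1

inputs = torch.randn(1, 3, 4, 4)
inputs = F.pad(inputs, ((pad + 1) // 2,) * 4, mode="reflect")  # pads by 1
weight = inputs.new_zeros(3, 3, 4, 4)
indices = torch.arange(3)
weight[indices, indices] = kernel
out = F.conv_transpose2d(inputs, weight, stride=2, padding=pad * 2 + 1)
print(out.shape)                          # torch.Size([1, 3, 8, 8])
```
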
@@ -664,13 +664,13 @@ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
         self.group_norm = nn.GroupNorm(n_groups, out_channels)
         self.mish = nn.Mish()
 
-    def forward(self, x):
-        x = self.conv1d(x)
-        x = rearrange_dims(x)
-        x = self.group_norm(x)
-        x = rearrange_dims(x)
-        x = self.mish(x)
-        return x
+    def forward(self, inputs):
+        intermediate_repr = self.conv1d(inputs)
+        intermediate_repr = rearrange_dims(intermediate_repr)
+        intermediate_repr = self.group_norm(intermediate_repr)
+        intermediate_repr = rearrange_dims(intermediate_repr)
+        output = self.mish(intermediate_repr)
+        return output
 
 
 # unet_rl.py
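
For the `Conv1dBlock` change: `rearrange_dims` toggles a dummy height axis around the `GroupNorm` (3D to 4D and back), so the renamed forward is conv, then norm, then Mish over `[batch, channels, horizon]` tensors. A rough standalone equivalent, with module sizes assumed for illustration:

```python
import torch
import torch.nn as nn

# Approximate trace of Conv1dBlock.forward; channel/group sizes are assumptions.
conv1d = nn.Conv1d(14, 32, kernel_size=5, padding=2)
group_norm = nn.GroupNorm(8, 32)
mish = nn.Mish()

inputs = torch.randn(2, 14, 16)               # [batch, inp_channels, horizon]
h = conv1d(inputs)                            # [2, 32, 16]
h = group_norm(h[:, :, None, :])[:, :, 0, :]  # the rearrange_dims round trip
out = mish(h)
print(out.shape)                              # torch.Size([2, 32, 16])
```
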
@@ -687,20 +687,20 @@ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
             nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
         )
 
-    def forward(self, x, t):
+    def forward(self, inputs, t):
         """
         Args:
-            x : [ batch_size x inp_channels x horizon ]
+            inputs : [ batch_size x inp_channels x horizon ]
             t : [ batch_size x embed_dim ]
 
         returns:
             out : [ batch_size x out_channels x horizon ]
         """
         t = self.time_emb_act(t)
         t = self.time_emb(t)
-        out = self.conv_in(x) + rearrange_dims(t)
+        out = self.conv_in(inputs) + rearrange_dims(t)
         out = self.conv_out(out)
-        return out + self.residual_conv(x)
+        return out + self.residual_conv(inputs)
 
 
 def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
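
Callers of the RL block now pass `inputs` positionally or via the new keyword; per the docstring shapes, a usage sketch (the import path and sizes are assumptions):

```python
import torch
from diffusers.models.resnet import ResidualTemporalBlock1D  # assumed location

block = ResidualTemporalBlock1D(inp_channels=14, out_channels=32, embed_dim=128)
inputs = torch.randn(4, 14, 16)  # [batch_size, inp_channels, horizon]
t = torch.randn(4, 128)          # [batch_size, embed_dim]
out = block(inputs, t)           # residual_conv projects inputs to out_channels
print(out.shape)                 # torch.Size([4, 32, 16])
```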