Skip to content

Commit 0bf6aeb

Browse files
feat: rename single-letter vars in resnet.py (#3868)
feat: rename single-letter vars
1 parent 9a45d7f commit 0bf6aeb

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

src/diffusers/models/resnet.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name=
9595
assert self.channels == self.out_channels
9696
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
9797

98-
def forward(self, x):
99-
assert x.shape[1] == self.channels
100-
return self.conv(x)
98+
def forward(self, inputs):
99+
assert inputs.shape[1] == self.channels
100+
return self.conv(inputs)
101101

102102

103103
class Upsample2D(nn.Module):
@@ -431,13 +431,13 @@ def __init__(self, pad_mode="reflect"):
431431
self.pad = kernel_1d.shape[1] // 2 - 1
432432
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
433433

434-
def forward(self, x):
435-
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
436-
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
437-
indices = torch.arange(x.shape[1], device=x.device)
438-
kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
434+
def forward(self, inputs):
435+
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
436+
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
437+
indices = torch.arange(inputs.shape[1], device=inputs.device)
438+
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
439439
weight[indices, indices] = kernel
440-
return F.conv2d(x, weight, stride=2)
440+
return F.conv2d(inputs, weight, stride=2)
441441

442442

443443
class KUpsample2D(nn.Module):
@@ -448,13 +448,13 @@ def __init__(self, pad_mode="reflect"):
448448
self.pad = kernel_1d.shape[1] // 2 - 1
449449
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
450450

451-
def forward(self, x):
452-
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
453-
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
454-
indices = torch.arange(x.shape[1], device=x.device)
455-
kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
451+
def forward(self, inputs):
452+
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
453+
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
454+
indices = torch.arange(inputs.shape[1], device=inputs.device)
455+
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
456456
weight[indices, indices] = kernel
457-
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
457+
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
458458

459459

460460
class ResnetBlock2D(nn.Module):
@@ -664,13 +664,13 @@ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
664664
self.group_norm = nn.GroupNorm(n_groups, out_channels)
665665
self.mish = nn.Mish()
666666

667-
def forward(self, x):
668-
x = self.conv1d(x)
669-
x = rearrange_dims(x)
670-
x = self.group_norm(x)
671-
x = rearrange_dims(x)
672-
x = self.mish(x)
673-
return x
667+
def forward(self, inputs):
668+
intermediate_repr = self.conv1d(inputs)
669+
intermediate_repr = rearrange_dims(intermediate_repr)
670+
intermediate_repr = self.group_norm(intermediate_repr)
671+
intermediate_repr = rearrange_dims(intermediate_repr)
672+
output = self.mish(intermediate_repr)
673+
return output
674674

675675

676676
# unet_rl.py
@@ -687,20 +687,20 @@ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
687687
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
688688
)
689689

690-
def forward(self, x, t):
690+
def forward(self, inputs, t):
691691
"""
692692
Args:
693-
x : [ batch_size x inp_channels x horizon ]
693+
inputs : [ batch_size x inp_channels x horizon ]
694694
t : [ batch_size x embed_dim ]
695695
696696
returns:
697697
out : [ batch_size x out_channels x horizon ]
698698
"""
699699
t = self.time_emb_act(t)
700700
t = self.time_emb(t)
701-
out = self.conv_in(x) + rearrange_dims(t)
701+
out = self.conv_in(inputs) + rearrange_dims(t)
702702
out = self.conv_out(out)
703-
return out + self.residual_conv(x)
703+
return out + self.residual_conv(inputs)
704704

705705

706706
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):

0 commit comments

Comments
 (0)