|
5 | 5 | import torch.nn.functional as F |
6 | 6 |
|
7 | 7 |
|
class Upsample1D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied after interpolation.
        use_conv_transpose: a bool determining if a stride-2 transposed convolution
            performs the upsampling instead of nearest-neighbor interpolation.
        out_channels: number of output channels; falls back to ``channels`` when unset.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        # kernel 4 / stride 2 / padding 1 doubles the temporal length exactly.
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        else:
            self.conv = None

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            # The transposed convolution performs the upsampling by itself.
            return self.conv(x)

        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled
| 44 | + |
| 45 | + |
class Downsample1D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a strided convolution is used; otherwise
            average pooling performs the downsampling.
        out_channels: number of output channels; falls back to ``channels`` when unset.
        padding: padding applied to the strided convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        stride = 2
        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            # Pooling cannot change the channel count, so the two must agree.
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.conv(x)
| 75 | + |
| 76 | + |
8 | 77 | class Upsample2D(nn.Module): |
9 | 78 | """ |
10 | 79 | An upsampling layer with an optional convolution. |
11 | 80 |
|
12 | 81 | Parameters: |
13 | 82 | channels: channels in the inputs and outputs. |
14 | 83 | use_conv: a bool determining if a convolution is applied. |
15 | | - dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. |
| 84 | + use_conv_transpose: |
| 85 | + out_channels: |
16 | 86 | """ |
17 | 87 |
|
18 | 88 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): |
@@ -80,7 +150,8 @@ class Downsample2D(nn.Module): |
80 | 150 | Parameters: |
81 | 151 | channels: channels in the inputs and outputs. |
82 | 152 | use_conv: a bool determining if a convolution is applied. |
83 | | - dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. |
| 153 | + out_channels: |
| 154 | + padding: |
84 | 155 | """ |
85 | 156 |
|
86 | 157 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): |
@@ -415,6 +486,69 @@ def forward(self, hidden_states): |
415 | 486 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) |
416 | 487 |
|
417 | 488 |
|
# unet_rl.py
def rearrange_dims(tensor):
    """Shuttle a tensor between 3D and 4D layouts around a singleton axis.

    - 2D ``(B, C)``       -> ``(B, C, 1)``    (append a length-1 width axis)
    - 3D ``(B, C, W)``    -> ``(B, C, 1, W)`` (insert a length-1 axis before width)
    - 4D ``(B, C, 1, W)`` -> ``(B, C, W)``    (drop the inserted axis again)

    Raises:
        ValueError: if the tensor is not 2-, 3-, or 4-dimensional.
    """
    ndim = len(tensor.shape)
    if ndim == 2:
        return tensor[:, :, None]
    if ndim == 3:
        return tensor[:, :, None, :]
    if ndim == 4:
        return tensor[:, :, 0, :]
    # Bug fix: the original message interpolated len(tensor) (the size of dim 0,
    # i.e. the batch size), not the number of dimensions it claims to report.
    raise ValueError(f"`len(tensor.shape)`: {ndim} has to be 2, 3 or 4.")
| 499 | + |
| 500 | + |
class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    Parameters:
        inp_channels: channels in the input.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size; "same" padding via ``kernel_size // 2``.
        n_groups: number of groups for GroupNorm.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, x):
        conv_out = self.conv1d(x)
        # GroupNorm is applied on a temporary 4D view, then collapsed back to 3D.
        normed = rearrange_dims(self.group_norm(rearrange_dims(conv_out)))
        return self.mish(normed)
| 520 | + |
| 521 | + |
# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    """A 1D residual block whose hidden activations are shifted by a time embedding."""

    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        # 1x1 convolution matches channel counts on the skip path; identity otherwise.
        if inp_channels != out_channels:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, t):
        """
        Args:
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        emb = self.time_emb(self.time_emb_act(t))
        # Broadcast the per-sample embedding across the horizon axis.
        hidden = self.conv_in(x) + rearrange_dims(emb)
        hidden = self.conv_out(hidden)
        return hidden + self.residual_conv(x)
| 550 | + |
| 551 | + |
418 | 552 | def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): |
419 | 553 | r"""Upsample2D a batch of 2D images with the given filter. |
420 | 554 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given |
|
0 commit comments