|
4 | 4 | # This source code is licensed under the BSD-style license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
| 7 | +import math |
7 | 8 | from typing import Tuple
|
8 | 9 |
|
9 | 10 | import torch
|
10 | 11 | import torch.nn.functional as F
|
| 12 | +from torch.nn import Parameter, init |
11 | 13 |
|
12 | 14 |
|
13 |
| -class LinearWithRepeat(torch.nn.Linear): |
| 15 | +class LinearWithRepeat(torch.nn.Module): |
14 | 16 | """
|
15 | 17 | if x has shape (..., k, n1)
|
16 | 18 | and y has shape (..., n2)
|
@@ -50,6 +52,40 @@ class LinearWithRepeat(torch.nn.Linear):
|
50 | 52 | and sent that through the Linear.
|
51 | 53 | """
|
52 | 54 |
|
| 55 | + def __init__( |
| 56 | + self, |
| 57 | + in_features: int, |
| 58 | + out_features: int, |
| 59 | + bias: bool = True, |
| 60 | + device=None, |
| 61 | + dtype=None, |
| 62 | + ) -> None: |
| 63 | + """ |
| 64 | + Copied from torch.nn.Linear. |
| 65 | + """ |
| 66 | + factory_kwargs = {"device": device, "dtype": dtype} |
| 67 | + super().__init__() |
| 68 | + self.in_features = in_features |
| 69 | + self.out_features = out_features |
| 70 | + self.weight = Parameter( |
| 71 | + torch.empty((out_features, in_features), **factory_kwargs) |
| 72 | + ) |
| 73 | + if bias: |
| 74 | + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) |
| 75 | + else: |
| 76 | + self.register_parameter("bias", None) |
| 77 | + self.reset_parameters() |
| 78 | + |
| 79 | + def reset_parameters(self) -> None: |
| 80 | + """ |
| 81 | + Copied from torch.nn.Linear. |
| 82 | + """ |
| 83 | + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) |
| 84 | + if self.bias is not None: |
| 85 | + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) |
| 86 | + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 |
| 87 | + init.uniform_(self.bias, -bound, bound) |
| 88 | + |
53 | 89 | def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
54 | 90 | n1 = input[0].shape[-1]
|
55 | 91 | output1 = F.linear(input[0], self.weight[:, :n1], self.bias)
|
|
0 commit comments