WIP: Re-implement Linear in managed code. #946
Conversation
internal Linear(long inputSize, long outputSize, bool hasBias = true, Device device = null, ScalarType? dtype = null) : base(nameof(Linear))

It should be …
init.kaiming_uniform_(weight, a: Math.Sqrt(5));

I think we could cache the result of Math.Sqrt(5). Note: I've checked whether the compiler is smart enough to replace it with a constant; it isn't.
Right. I took the …
Well, caching it probably won't make much of a difference, but it also won't hurt.
Seems like premature optimization that worsens readability, to me.
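A minimal sketch of the caching being discussed; the field name is hypothetical:

// Hypothetical sketch: compute sqrt(5) once and reuse it.
private static readonly double KaimingA = Math.Sqrt(5);

// ...then, in the initialization code quoted above:
init.kaiming_uniform_(weight, a: KaimingA);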
Mostly nits on the code itself; however, I am unsure about the idea of a manually reimplemented Linear. What is the plan for ensuring we actually track upstream changes?
As far as PRs go, I think the part that overrides _to should be separate. OCD 🙃
src/TorchSharp/NN/Linear.cs
Outdated
THSNN_Linear_set_weight(handle, value!.Handle);
torch.CheckForErrors();
ConditionallyRegisterParameter("weight", value);
if (value is null) throw new ArgumentNullException("weight");
NIT: nameof(weight)
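A sketch of the quoted line with the nit applied:

if (value is null) throw new ArgumentNullException(nameof(weight));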
src/TorchSharp/NN/Linear.cs
Outdated
public sealed class Linear : torch.nn.Module<Tensor, Tensor>
{
    internal Linear(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle)
    internal Linear(long inputSize, long outputSize, bool hasBias = true, Device device = null, ScalarType? dtype = null) : base(nameof(Linear))
Is there any reason not to make this constructor public? It does not take handles anymore.
Only because of the Python-likeness we're striving for: we'd prefer that users go through the factories in torch.nn rather than the constructors.
We would have to get to all of the modules, so it's not just Linear. In some cases, it may actually make tracking easier: there have been times when we moved from one version to another and the native module API didn't have some new feature, but the functional API did.
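A short usage sketch of the factory style being referred to, using the torch.nn.Linear factory from this PR's diff (tensor shapes are made up for illustration):

using static TorchSharp.torch;

// Preferred, Python-like style: construct modules through the torch.nn factories.
var lin = nn.Linear(1000, 100, hasBias: true);
var output = lin.call(randn(32, 1000));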
src/TorchSharp/NN/Linear.cs
Outdated
/// <param name="device">The desired device of the parameters and buffers in this module</param>
/// <param name="dtype">The desired floating point or complex dtype of the parameters and buffers in this module</param>
public static Linear Linear(long inputSize, long outputSize, bool hasBias = true, Device? device = null, ScalarType? dtype = null)
public static Linear Linear(long inputSize, long outputSize, bool hasBias = true, Device device = null, ScalarType? dtype = null)
Should this receive in_features and out_features to be compatible with PyTorch?
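A sketch of what the PyTorch-aligned parameter names would look like on this factory (a hypothetical rename, not something the PR does):

public static Linear Linear(long in_features, long out_features, bool hasBias = true, Device? device = null, ScalarType? dtype = null)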
public static Linear Linear(long inputSize, long outputSize, bool hasBias = true, Device? device = null, ScalarType? dtype = null)
public static Linear Linear(long inputSize, long outputSize, bool hasBias = true, Device device = null, ScalarType? dtype = null)
{
    var res = THSNN_Linear_ctor(inputSize, outputSize, hasBias, out var boxedHandle);
Is THSNN_Linear_ctor going to be deleted from https://github.com/dotnet/TorchSharp/blob/main/src/Native/LibTorchSharp/THSNN.cpp?
I hope these temporary objects are gone with this change:

public static Tensor relu(Tensor x, bool inplace = false)
{
    using (var m = nn.ReLU(inplace)) {
        return m.call(x);
    }
}
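A sketch of the module-free alternative being hoped for, assuming a functional entry point such as torch.nn.functional.relu is available:

// Sketch: delegate straight to the functional API, so no temporary module
// has to be constructed and disposed on every call.
public static Tensor relu(Tensor x, bool inplace = false)
    => nn.functional.relu(x, inplace);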
var (fanIn, _) = init.CalculateFanInAndFanOut(weight);
init.uniform_(_bias, -bound, bound);
}
//NOTE: it's important not to call 'RegisterComponents' here.
Adding why would be helpful for future readers. I assume it is called in base?
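For context on the quoted init code, the bound passed to uniform_ is normally derived from the fan-in, mirroring PyTorch's Linear.reset_parameters; a sketch using the same init helpers quoted above:

var (fanIn, _) = init.CalculateFanInAndFanOut(weight);
// Bias is drawn uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)].
var bound = fanIn > 0 ? 1.0 / Math.Sqrt(fanIn) : 0.0;
init.uniform_(_bias, -bound, bound);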
}
}

private Parameter? _weight;
Why not make this readonly and non-nullable?
src/TorchSharp/NN/Dropout.cs
Outdated
this._p = p;
this._inplace = inplace;
Should these be part of the state dict?
scope.Include(this);
scope.Detach(data);
This needs some explanation.
var output = lin.call(input);

output[0, 511] = 10; // When we modify the copy, the original should be altered, too.

Assert.Equal(device.type, output.device_type);
Is this how Identity behaves in PyTorch? o-O
Yup:
>>> id = torch.nn.Identity()
>>> input = torch.zeros(10,10)
>>> output = id(input)
>>> output[0,0] = 13
>>> input
tensor([[13., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
From the PyTorch source code:
def forward(self, input: Tensor) -> Tensor:
return input
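A managed TorchSharp equivalent would be just as trivial (a sketch of the forward override):

public override Tensor forward(Tensor input) => input;   // Identity: hand back the same tensor, no copy.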
There are a couple of things in this WIP/draft PR:

1. Overriding _to() implementations in modules that are known not to have any parameters or buffers. This will save a small amount of runtime overhead.

2. An alternative implementation of Linear to more closely align with how it works in PyTorch. Rather than creating native module instances and managing their lifetimes, this approach only involves a .NET instance, which then calls into the torch.nn.functional APIs to perform the forward pass. It's simpler and gets us out of the business of managing native module instances. The downside is that we need to do a whole bunch of work without any new functionality. It's just getting rid of some technical debt and aligning better with the Python implementation, which could be enough of a reason.
@MovGP0, @lostmsu, @kaiidams, @ChengYen-Tang, @dayo05 -- your thoughts on #2?
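A rough sketch of the managed approach described in item 2 (not the PR's actual code; it assumes torch.nn.functional.linear is available as the functional entry point and that Parameter can wrap a plain tensor):

using System;
using TorchSharp.Modules;
using static TorchSharp.torch;

// Sketch: only a .NET object holds the parameters; forward() delegates to the
// functional API, so there is no native module instance to create or dispose.
internal sealed class ManagedLinear : nn.Module<Tensor, Tensor>
{
    private readonly Parameter weight;
    private readonly Parameter? bias;

    internal ManagedLinear(long inputSize, long outputSize, bool hasBias = true)
        : base(nameof(ManagedLinear))
    {
        // Weight/bias initialization (kaiming_uniform_, etc.) omitted for brevity.
        weight = new Parameter(empty(outputSize, inputSize));
        bias = hasBias ? new Parameter(empty(outputSize)) : null;
        RegisterComponents();   // registers the managed parameters with the module
    }

    public override Tensor forward(Tensor input) =>
        nn.functional.linear(input, weight, bias);
}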