Conversation

@NiklasGustafsson (Contributor)

There are a couple of things in this WIP/draft PR:

  1. Overriding _to() implementations in modules that are known not to have any parameters or buffers. This will save a small amount of runtime overhead.

  2. An alternative implementation of Linear that aligns more closely with how it works in PyTorch. Rather than creating native module instances and managing their lifetimes, this approach involves only a .NET instance and calls into the torch.nn.functional APIs to perform the forward pass. It's simpler and gets us out of the business of managing native module instances. The downside is that it means a lot of work without any new functionality; it just retires some technical debt and aligns us better with the Python implementation, which could be reason enough. A rough sketch of the idea is shown right below.
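
Here is roughly what #2 could look like for Linear. Purely illustrative: the torch.empty overload, the torch.nn.functional.linear call, and the ConditionallyRegisterParameter calls follow patterns used elsewhere in TorchSharp and may not match the final code exactly.

    using System;
    using TorchSharp;
    using TorchSharp.Modules;
    using static TorchSharp.torch;
    using static TorchSharp.torch.nn;

    // Managed-only Linear: no native module handle; forward just calls the functional API.
    public sealed class Linear : torch.nn.Module<Tensor, Tensor>
    {
        internal Linear(long inputSize, long outputSize, bool hasBias = true,
                        Device? device = null, ScalarType? dtype = null) : base(nameof(Linear))
        {
            weight = new Parameter(torch.empty(new long[] { outputSize, inputSize }, dtype: dtype, device: device));
            init.kaiming_uniform_(weight, a: Math.Sqrt(5));

            if (hasBias) {
                bias = new Parameter(torch.empty(new long[] { outputSize }, dtype: dtype, device: device));
                var (fanIn, _) = init.CalculateFanInAndFanOut(weight);
                var bound = fanIn > 0 ? 1.0 / Math.Sqrt(fanIn) : 0.0;
                init.uniform_(bias, -bound, bound);
            }

            // Parameters are registered explicitly; RegisterComponents() is intentionally not called.
            ConditionallyRegisterParameter(nameof(weight), weight);
            ConditionallyRegisterParameter(nameof(bias), bias);
        }

        public override Tensor forward(Tensor input)
        {
            // The whole forward pass goes through the functional API -- no native module instance involved.
            return torch.nn.functional.linear(input, weight, bias);
        }

        public Parameter weight;
        public Parameter? bias;
    }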

@MovGP0, @lostmsu, @kaiidams, @ChengYen-Tang, @dayo05 -- your thoughts on #2?

@MovGP0 (Contributor) commented Mar 6, 2023

internal Linear(long inputSize, long outputSize, bool hasBias = true, Device device = null, ScalarType? dtype = null) : base(nameof(Linear))

It should be Device? device = null. Same problem with the other constructors.
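
Presumably the fix is just the nullable annotation, i.e. something like:

    internal Linear(long inputSize, long outputSize, bool hasBias = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Linear))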

@MovGP0 (Contributor) commented Mar 6, 2023

init.kaiming_uniform_(weight, a: Math.Sqrt(5));

I think we could cache the result of Math.Sqrt(5) in a private static readonly field, but it probably won't matter in the grand scheme of things.

Note: I've checked whether the compiler is smart enough to replace it with a constant; it isn't.
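
A minimal sketch of the suggested caching (the field name is made up):

    private static readonly double Sqrt5 = Math.Sqrt(5);
    ...
    init.kaiming_uniform_(weight, a: Sqrt5);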

@NiklasGustafsson (Contributor, PR author)

> It should be Device? device = null. Same problem with the other constructors.

Right. I took the #nullable enable out. Putting it back.

@NiklasGustafsson (Contributor, PR author)

Well, caching it probably won't make much of a difference, but it also won't hurt.

@lostmsu (Contributor) commented Mar 6, 2023

> Well, caching it probably won't make much of a difference, but it also won't hurt.

Seems like a premature optimization that worsens readability to me.

@lostmsu (Contributor) left a comment

Mostly NITs on the code itself; however, I am unsure about the idea of a manually reimplemented Linear. What is the plan for ensuring we actually track upstream changes?

As far as PRs go, I think the part that overrides _to should be separate. OCD 🙃

THSNN_Linear_set_weight(handle, value!.Handle);
torch.CheckForErrors();
ConditionallyRegisterParameter("weight", value);
if (value is null) throw new ArgumentNullException("weight");

Contributor:
NIT: nameof(weight)
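
I.e., presumably:

    if (value is null) throw new ArgumentNullException(nameof(weight));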

public sealed class Linear : torch.nn.Module<Tensor, Tensor>
{
internal Linear(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle)
internal Linear(long inputSize, long outputSize, bool hasBias = true, Device device = null, ScalarType? dtype = null) : base(nameof(Linear))

@lostmsu (Contributor) commented Mar 6, 2023
Is there any reason to not have this constructor public? It does not take handles anymore.

@NiklasGustafsson (PR author):
Only because the Python-likeness we're striving for would prefer that users use the factories in torch.nn instead of the constructors.
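
For illustration, the intended usage would be via the factory rather than the constructor:

    // Preferred: the torch.nn factory, mirroring Python's torch.nn.Linear(...)
    var lin = torch.nn.Linear(512, 256, hasBias: true);
    // rather than 'new Linear(512, 256)' -- the constructor stays internal.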

@NiklasGustafsson (Contributor, PR author)

> Mostly NITs on the code itself; however, I am unsure about the idea of a manually reimplemented Linear. What is the plan for ensuring we actually track upstream changes?

We would have to get to all of the modules, so it's not just Linear. In some cases, it may be easier to track. There have been times when we've moved from one version to another, and the native module API didn't have some new feature, but the functional API did.

/// <param name="device">The desired device of the parameters and buffers in this module</param>
/// <param name="dtype">The desired floating point or complex dtype of the parameters and buffers in this module</param>
public static Linear Linear(long inputSize, long outputSize, bool hasBias = true, Device? device = null, ScalarType? dtype = null)
public static Linear Linear(long inputSize, long outputSize, bool hasBias = true, Device device = null, ScalarType? dtype = null)

Contributor:
Should this receive in_features and out_features to be compatible with PyTorch?
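
I.e., presumably renaming the parameters for PyTorch compatibility, along the lines of:

    public static Linear Linear(long in_features, long out_features, bool hasBias = true, Device? device = null, ScalarType? dtype = null)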

public static Linear Linear(long inputSize, long outputSize, bool hasBias = true, Device? device = null, ScalarType? dtype = null)
public static Linear Linear(long inputSize, long outputSize, bool hasBias = true, Device device = null, ScalarType? dtype = null)
{
var res = THSNN_Linear_ctor(inputSize, outputSize, hasBias, out var boxedHandle);

@kaiidams (Contributor) commented Mar 11, 2023
I hope these temporary objects are gone with this change.

                public static Tensor relu(Tensor x, bool inplace = false)
                {
                    using (var m = nn.ReLU(inplace)) {
                        return m.call(x);
                    }
                }
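
Presumably, once modules like ReLU no longer wrap native instances, wrappers like this could dispatch to the tensor ops directly instead of allocating a temporary module; a sketch, assuming Tensor.relu()/relu_() remain available:

    public static Tensor relu(Tensor x, bool inplace = false)
    {
        // No temporary nn.ReLU module; just call the tensor operation.
        return inplace ? x.relu_() : x.relu();
    }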

var (fanIn, _) = init.CalculateFanInAndFanOut(weight);
init.uniform_(_bias, -bound, bound);
}
//NOTE: it's important not to call 'RegisterComponents' here.

Contributor:
Adding the reason why would be helpful for future readers. I assume it is called in the base class?

}
}

private Parameter? _weight;

Contributor:
Why not make this readonly and non-nullable?

Comment on lines 19 to 20
this._p = p;
this._inplace = inplace;

Contributor:
Should these be part of the state dict?

Comment on lines +32 to +33
scope.Include(this);
scope.Detach(data);

Contributor:
This needs some explanation.

Comment on lines 326 to 330
var output = lin.call(input);

output[0, 511] = 10; // When we modify the copy, the original should be altered, too.

Assert.Equal(device.type, output.device_type);

Contributor:
Is this how Identity behaves in PyTorch? o-O

@NiklasGustafsson (PR author):
Yup:

>>> id = torch.nn.Identity()
>>> input = torch.zeros(10,10)
>>> output = id(input)
>>> output[0,0] = 13
>>> input
tensor([[13.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])

@NiklasGustafsson (PR author):
From the PyTorch source code:

    def forward(self, input: Tensor) -> Tensor:
        return input
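
A TorchSharp-side check of the same aliasing behavior might look roughly like this (names and values are illustrative):

    // Identity returns its input unchanged, so output and input alias the same storage.
    var id = torch.nn.Identity();
    var input = torch.zeros(10, 10);
    var output = id.call(input);
    output[0, 0] = 13;   // visible through 'input' as well, since no copy was made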
