Skip to content

Conversation

shaltielshmid
Copy link
Contributor

@shaltielshmid shaltielshmid commented Dec 9, 2023

Fixes #1148 and #1179.

I went on a whim and attempted to make a style change on the _toEpilog method. If you see a fatal error or prefer to leave it as it used to be - no problem, I'll revert that change.

Main changes:
1] Merged the three _toEpilog methods into one method.
2] Instead of re-iterating the fields every time and reassigning the parameters, I build a dictionary of the fields and as we go through the registered parameters we check if they have a field, and if so - assign it. Saves us from duplicating the code which handles the moving of the parameters.
3] When moving Parameters, we do it in a "torch.no_grad()" scope to avoid having autograd track the movement and so that the resulting tensor will be a leaf.
4] After moving the parameters and buffers, the old ones are disposed.
5] Overrode the _to() method in the ParameterList and ParameterDict classes. I didn't do it in the ModuleDict and ModuleList methods since modules themselves don't get a new reference, only their parameters and buffers, and _toEpilog() iterates through every registered module.
6] Added Tensor extension method toWillCopy() which given a set of arguments to to() will return whether the tensor will be copied.

Question about the memo properties _deviceType and _deviceIndex. I see the value in having them, but that can cause an issue if someone calls .to() on a sub module, and then tries to call it on the main module() to make sure all the submodules() are aligned, it wont work. Same thing if any of the parameters are moved separately, then calling Module.to() wont behave as people expect.

alreadyHandled.Add(p.handle);
break;
}
var fieldsByComponentName = GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this have any performance implications?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Under the assumption that the costly operation is the reflection, then there shouldn't be.
We're creating a hash collection of the same size (alreadyHandled vs the Dictionary), so I guess the cost would be a the lookups in the dictionary, which should be very minimal.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The brings up a functional concern -- the alreadyHandled set is there to make sure that we don't accidentally deal with the same tensor twice. Is that no longer a possibility?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so.
The reason it was needed before was because the function was iterating through two different lists of parameters. One using reflection (the parameters registered through the RegisterComponents() function), and the other is using the internal list of registered parameters.
In the proposed code we only go through the list of registered parameters, so there isn't a concern of dealing with the same tensor twice.
Unless there is a case where someone can register the same tensor twice?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If someone registers the same parameter under two different names, then we have an issue - the first encounter will dispose the parameter, and then the second time it will have a null tensor error.
Should this be a use case we should handle?

foreach (var (name, buffer) in named_buffers(false).ToList()) {
if (alreadyHandled.Contains(buffer.handle)) continue;
var t = buffer.to(dtype);
if (!buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are old buffers/parameters disposed anywhere?

Copy link
Contributor Author

@shaltielshmid shaltielshmid Dec 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup. In the ".to()" call I use the disposeAfter parameter.

@NiklasGustafsson NiklasGustafsson merged commit 6efd3e9 into dotnet:main Dec 11, 2023
@shaltielshmid shaltielshmid deleted the fix-module-to-2 branch December 11, 2023 20:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Module.to should not be differentiable

2 participants