12 changes: 7 additions & 5 deletions RELEASENOTES.md
@@ -20,11 +20,13 @@ All distribution classes now implement IDisposable.<br/>

__Bug Fixes__:

#1154 : `mu_product` was not initialized in `NAdam` optimizer
#1170 : Calling `torch.nn.rnn.utils.pad_packed_sequence` with a CUDA tensor and unsorted_indices threw an error
#1172 : `optim.LoadStateDict` from an existing `StateDictionary` updated to make sure to copy value and to the right device.
#1176 : When specific `Optimizers` load in a conditional tensor, made sure to copy to the right device.
#1174 : Loading CUDA tensor from stream threw an error
#1154 : `mu_product` was not initialized in `NAdam` optimizer<br/>
#1170 : Calling `torch.nn.rnn.utils.pad_packed_sequence` with a CUDA tensor and unsorted_indices threw an error<br/>
#1172 : `optim.LoadStateDict` from an existing `StateDictionary` updated to make sure to copy value and to the right device.<br/>
#1176 : When specific `Optimizers` load in a conditional tensor, made sure to copy to the right device.<br/>
#1174 : Loading CUDA tensor from stream threw an error<br/>
#1179 : Calling `Module.to()` with the `ParameterList` and `ParameterDict` module didn't move the parameters stored in the field.<br/>
#1148 : Calling `Module.to()` shouldn't be differentiable<br/>
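
As a quick illustration of the last two fixes, a hedged sketch (not part of the release notes themselves; it assumes the `nn.ParameterList` factory and a CUDA device are available):

```csharp
using System;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

// Illustrative only: with #1179 fixed, the parameters held by a ParameterList
// (or ParameterDict) move together with the module, and with #1148 fixed the
// move itself is not recorded by autograd, so requires_grad is preserved.
var list = nn.ParameterList(new Parameter(randn(2, 3)), new Parameter(randn(3, 4)));

if (cuda.is_available()) {
    list.to(CUDA);
    foreach (var (name, p) in list.named_parameters())
        Console.WriteLine($"{name}: device={p.device}, requires_grad={p.requires_grad}");
}
```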

## NuGet Version 0.101.2

200 changes: 45 additions & 155 deletions src/TorchSharp/NN/Module.cs
@@ -163,66 +163,6 @@ protected internal virtual Module _to(Device device, ScalarType dtype)
return this;
}

protected void _toEpilog(Device device, ScalarType dtype)
{
foreach (var (_, sm) in named_children()) sm._to(device, dtype);

var alreadyHandled = new HashSet<IntPtr>();

foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {

var fieldName = field.ComponentName();
var value = field.GetValue(this);

switch (value) {
// The order in which these cases are arranged is significant.
case Parameter param when dtype == param.dtype && device.type == param.device_type && device.index == param.device_index:
alreadyHandled.Add(param.handle);
continue;

case Parameter param: {
var t = param.to(dtype, device);
t.retain_grad();
var p = new Parameter(t, param.requires_grad);
field.SetValue(this, p);
ConditionallyRegisterParameter(fieldName, p);
alreadyHandled.Add(p.handle);
break;
}

case Tensor tensor when (device.type != tensor.device_type || device.index != tensor.device_index): {
var t = tensor.to(dtype, device);
field.SetValue(this, t);
ConditionallyRegisterBuffer(fieldName, t);
alreadyHandled.Add(t.handle);
break;
}

case Tensor tensor:
alreadyHandled.Add(tensor.handle);
break;
}
}

foreach (var (name, param) in named_parameters(false).ToList()) {
if (alreadyHandled.Contains(param.handle)) continue;
var t = param.to(dtype, device);
ConditionallyRegisterParameter(name, t);
}

foreach (var (name, buffer) in named_buffers(false).ToList()) {
if (alreadyHandled.Contains(buffer.handle)) continue;
var t = buffer.to(dtype, device);
ConditionallyRegisterBuffer(name, t);
}

_deviceType = device.type;
_deviceIndex = device.index;

Debug.Assert(_deviceType == DeviceType.CUDA || _deviceIndex == -1);
}


/// <summary>
/// Moves the parameters and buffers.
/// </summary>
@@ -249,63 +189,6 @@ protected internal virtual Module _to(DeviceType deviceType, int deviceIndex = -
return this;
}

protected void _toEpilog(DeviceType deviceType, int deviceIndex)
{
foreach (var (_, sm) in named_children()) sm._to(deviceType, deviceIndex);

var alreadyHandled = new HashSet<IntPtr>();

foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {

var fieldName = field.ComponentName();
var value = field.GetValue(this);

switch (value) {
// The order in which these cases are arranged is significant.
case Parameter param when deviceType == param.device_type && deviceIndex == param.device_index:
alreadyHandled.Add(param.handle);
continue;

case Parameter param: {
var t = param.to(deviceType, deviceIndex);
t.retain_grad();
var p = new Parameter(t, param.requires_grad);
field.SetValue(this, p);
ConditionallyRegisterParameter(fieldName, p);
alreadyHandled.Add(p.handle);
break;
}

case Tensor tensor when (deviceType != tensor.device_type || deviceIndex != tensor.device_index): {
var t = tensor.to(deviceType, deviceIndex);
field.SetValue(this, t);
ConditionallyRegisterBuffer(fieldName, t);
alreadyHandled.Add(t.handle);
break;
}

case Tensor tensor:
alreadyHandled.Add(tensor.handle);
break;
}
}

foreach (var (name, param) in named_parameters(false).ToList()) {
if (alreadyHandled.Contains(param.handle)) continue;
var t = param.to(deviceType, deviceIndex);
ConditionallyRegisterParameter(name, t);
}

foreach (var (name, buffer) in named_buffers(false).ToList()) {
if (alreadyHandled.Contains(buffer.handle)) continue;
var t = buffer.to(deviceType, deviceIndex);
ConditionallyRegisterBuffer(name, t);
}

_deviceType = deviceType;
_deviceIndex = deviceIndex;
}

private DeviceType _deviceType = DeviceType.CPU;
private int _deviceIndex = -1;

@@ -325,55 +208,62 @@ protected internal virtual Module _to(ScalarType dtype)

protected void _toEpilog(ScalarType dtype)
{
foreach (var (_, sm) in named_children()) sm._to(dtype);
_toEpilog(dtype, null);
}

var alreadyHandled = new HashSet<IntPtr>();
protected void _toEpilog(Device device, ScalarType dtype)
{
_toEpilog(dtype, device);
}

foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {
protected void _toEpilog(DeviceType deviceType, int deviceIndex)
{
_toEpilog(null, new Device(deviceType, deviceIndex));
}

var fieldName = field.ComponentName();
var value = field.GetValue(this);
private void _toEpilog(ScalarType? dtype, Device device)
{
foreach (var (_, sm) in named_children()) {
if (device is null) sm._to(dtype.Value);
else if (dtype is null) sm._to(device.type, device.index);
else sm._to(device, dtype.Value);
}

switch (value) {
// The order in which these cases are arranged is significant.
case Parameter param when dtype == param.dtype:
alreadyHandled.Add(param.handle);
continue;

case Parameter param: {
var t = param.to(dtype);
t.retain_grad();
var p = new Parameter(t, param.requires_grad);
field.SetValue(this, p);
ConditionallyRegisterParameter(fieldName, p);
alreadyHandled.Add(p.handle);
break;
}
var fieldsByComponentName = GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)
.ToDictionary(field => field.ComponentName());

Contributor: Does this have any performance implications?

Contributor (Author): Under the assumption that the costly operation is the reflection, there shouldn't be. We're creating a hash collection of the same size (alreadyHandled vs. the Dictionary), so the added cost is the lookups in the dictionary, which should be very minimal.

Contributor: This brings up a functional concern -- the alreadyHandled set is there to make sure that we don't accidentally deal with the same tensor twice. Is that no longer a possibility?

Contributor (Author): I don't think so. The reason it was needed before was that the function iterated through two different lists of parameters: one using reflection (the parameters registered through the RegisterComponents() function), and the other using the internal list of registered parameters. In the proposed code we only go through the list of registered parameters, so there isn't a concern of dealing with the same tensor twice. Unless there is a case where someone can register the same tensor twice?

Contributor (Author): If someone registers the same parameter under two different names, then we have an issue - the first encounter will dispose the parameter, and then the second time it will have a null tensor error. Should this be a use case we should handle?
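
A hypothetical sketch of the edge case being discussed (the module and names are made up for illustration, not taken from the PR):

```csharp
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

// Same Parameter instance registered under two different names. With
// disposeAfter: true, converting the module disposes the shared native tensor
// when the first name is handled, so the second name would then refer to an
// already-disposed tensor.
sealed class SharedParamModule : nn.Module
{
    public SharedParamModule() : base(nameof(SharedParamModule))
    {
        var shared = new Parameter(randn(4, 4));
        register_parameter("weight_a", shared);
        register_parameter("weight_b", shared);   // same instance, second name
    }
}
```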

case Tensor tensor when dtype == tensor.dtype:
alreadyHandled.Add(tensor.handle);
continue;
foreach (var (name, param) in named_parameters(false).ToList()) {
if (!param.toWillCopy(dtype ?? param.dtype, device ?? param.device)) continue;

case Tensor tensor: {
var t = tensor.to(dtype);
field.SetValue(this, t);
ConditionallyRegisterBuffer(fieldName, t);
alreadyHandled.Add(t.handle);
break;
}
}
}
// Store the requires_grad flag ahead of time, since we dispose the parameter after moving it
bool requiresGrad = param.requires_grad;
Parameter p;
// When moving the parameter, we don't want the autograd to track this movement on the graph.
// In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
// disable grad we would need to call .detach() on the moved tensor.
using (var d = torch.no_grad())
p = new Parameter(param.to(dtype ?? param.dtype, device ?? param.device, disposeAfter: true), requiresGrad);
ConditionallyRegisterParameter(name, p);

foreach (var (name, param) in named_parameters(false).ToList()) {
if (alreadyHandled.Contains(param.handle)) continue;
var t = param.to(dtype);
ConditionallyRegisterParameter(name, t);
// If this parameter is a field, set it
if (fieldsByComponentName.TryGetValue(name, out var field))
field.SetValue(this, p);
}

foreach (var (name, buffer) in named_buffers(false).ToList()) {
if (alreadyHandled.Contains(buffer.handle)) continue;
var t = buffer.to(dtype);
if (!buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;
Contributor: Are old buffers/parameters disposed anywhere?

Contributor (Author, @shaltielshmid, Dec 11, 2023): Yup. In the `.to()` call I use the disposeAfter parameter.


// Buffers don't get grads so we don't need to detach them afterwards
var t = buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true);
ConditionallyRegisterBuffer(name, t);

if (fieldsByComponentName.TryGetValue(name, out var field))
field.SetValue(this, t);
}

if (device is not null) {
_deviceType = device.type;
_deviceIndex = device.index;
}
}
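
A minimal sanity check of the reasoning above (illustrative only; it assumes the stock `nn.Linear` module and that `is_leaf` is exposed on Tensor):

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

// Illustrative only: the conversion runs under no_grad and disposes the old
// storage via disposeAfter, so afterwards each parameter is a fresh leaf
// tensor with the new dtype/device and its requires_grad flag preserved.
var lin = nn.Linear(4, 2);
lin.to(ScalarType.Float64);

foreach (var (name, p) in lin.named_parameters())
    Console.WriteLine($"{name}: dtype={p.dtype}, requires_grad={p.requires_grad}, is_leaf={p.is_leaf}");
```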

31 changes: 31 additions & 0 deletions src/TorchSharp/NN/ParameterDict.cs
@@ -60,6 +60,37 @@ protected override void RegisterComponents()

private bool _registered = false;

protected internal override Module _to(DeviceType deviceType, int deviceIndex = -1)
{
base._to(deviceType, deviceIndex);
_toEpilog();
return this;
}

protected internal override Module _to(torch.Device device, torch.ScalarType dtype)
{
base._to(device, dtype);
_toEpilog();
return this;
}

protected internal override Module _to(torch.ScalarType dtype)
{
base._to(dtype);
_toEpilog();
return this;
}

void _toEpilog()
{
for (int i = 0; i < _list.Count; i++) {
string name = _list[i].Item1;
var param = base.get_parameter(name);
_list[i] = (name, param);
_dict[name] = param;
}
}
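
A short usage sketch of the effect (hedged: it assumes ParameterDict's dictionary-style Add and indexer):

```csharp
using System;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

// Illustrative only: converting the dictionary re-registers its parameters,
// and _toEpilog re-syncs _list/_dict, so lookups return the converted
// parameter rather than a stale, already-disposed one.
var pd = new ParameterDict();
pd.Add("w", new Parameter(randn(3, 3)));

pd.to(ScalarType.Float64);
Console.WriteLine(pd["w"].dtype);   // expected: Float64
```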

/// <summary>
/// Return the ParameterDict values.
/// </summary>
29 changes: 29 additions & 0 deletions src/TorchSharp/NN/ParameterList.cs
@@ -33,6 +33,35 @@ protected override void RegisterComponents()
_registered = true;
}


protected internal override Module _to(DeviceType deviceType, int deviceIndex = -1)
{
base._to(deviceType, deviceIndex);
_toEpilog();
return this;
}

protected internal override Module _to(torch.Device device, torch.ScalarType dtype)
{
base._to(device, dtype);
_toEpilog();
return this;
}

protected internal override Module _to(torch.ScalarType dtype)
{
base._to(dtype);
_toEpilog();
return this;
}

void _toEpilog()
{
for (int i = 0; i < _list.Count; i++) {
_list[i] = base.get_parameter($"{i}");
}
}

public override IEnumerable<(string name, Parameter parameter)> named_parameters(bool recurse = true)
{
return Enumerable.Range(0, _list.Count).Select(i => ($"{i}", _list[i]));
61 changes: 61 additions & 0 deletions src/TorchSharp/Tensor/TensorExtensionMethods.cs
@@ -704,5 +704,66 @@ public static long TotalSize(this IEnumerable<long> shape)
}
return result;
}

/// <summary>
/// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
/// </summary>
/// <param name="tensor">The tensor</param>
/// <param name="device">The device to move to</param>
/// <returns>True if the tensor will be copied</returns>
internal static bool toWillCopy(this Tensor tensor, Device device)
{
return tensor.toWillCopy(device.type, device.index);
}

/// <summary>
/// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
/// </summary>
/// <param name="tensor">The tensor</param>
/// <param name="deviceType">The device type to move to</param>
/// <param name="deviceIndex">The device index to move to</param>
/// <returns>True if the tensor will be copied</returns>
internal static bool toWillCopy(this Tensor tensor, DeviceType deviceType, int deviceIndex)
{
return tensor.device_index != deviceIndex || tensor.device_type != deviceType;
}

/// <summary>
/// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
/// </summary>
/// <param name="tensor">The tensor</param>
/// <param name="dtype">The dtype to move to</param>
/// <returns>True if the tensor will be copied</returns>
internal static bool toWillCopy(this Tensor tensor, ScalarType dtype)
{
return tensor.dtype != dtype;
}

/// <summary>
/// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
/// </summary>
/// <param name="tensor">The tensor</param>
/// <param name="dtype">The dtype to move to</param>
/// <param name="device">The device to move to</param>
/// <returns>True if the tensor will be copied</returns>
internal static bool toWillCopy(this Tensor tensor, ScalarType dtype, Device device)
{
return tensor.toWillCopy(dtype, device.type, device.index);
}

/// <summary>
/// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
/// </summary>
/// <param name="tensor">The tensor</param>
/// <param name="dtype">The dtype to move to</param>
/// <param name="deviceType">The device type to move to</param>
/// <param name="deviceIndex">The device index to move to</param>
/// <returns>True if the tensor will be copied</returns>
internal static bool toWillCopy(this Tensor tensor, ScalarType dtype, DeviceType deviceType, int deviceIndex)
{
return tensor.device_index != deviceIndex || tensor.device_type != deviceType || tensor.dtype != dtype;
}
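
A quick illustration of the contract (the helpers are internal, so this would only compile inside the TorchSharp assembly; shown as a hedged sketch):

```csharp
using TorchSharp;

// A copy happens iff the requested device or dtype actually differs.
var t = torch.zeros(2, 2);                                  // CPU, Float32 by default

bool sameDtype   = t.toWillCopy(torch.ScalarType.Float32);  // false - nothing changes
bool otherDtype  = t.toWillCopy(torch.ScalarType.Float64);  // true  - dtype changes
bool otherDevice = t.toWillCopy(DeviceType.CUDA, 0);        // true  - different device
```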


}
}