diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index 451a4b7ba..f5527428a 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -20,11 +20,13 @@ All distribution classes now implement IDisposable.
__Bug Fixes__:
-#1154 : `mu_product` was not initialized in `NAdam` optimizer
-#1170 : Calling `torch.nn.rnn.utils.pad_packed_sequence` with a CUDA tensor and unsorted_indices threw an error
-#1172 : `optim.LoadStateDict` from an existing `StateDictionary` updated to make sure to copy value and to the right device.
-#1176 : When specific `Optimizers` load in a conditional tensor, made sure to copy to the right device.
-#1174 : Loading CUDA tensor from stream threw an error
+#1154 : `mu_product` was not initialized in `NAdam` optimizer
+#1170 : Calling `torch.nn.rnn.utils.pad_packed_sequence` with a CUDA tensor and unsorted_indices threw an error
+#1172 : `optim.LoadStateDict` from an existing `StateDictionary` updated to make sure to copy value and to the right device.
+#1176 : When specific `Optimizers` load in a conditional tensor, made sure to copy to the right device.
+#1174 : Loading CUDA tensor from stream threw an error
+#1179 : Calling `Module.to()` with the `ParameterList` and `ParameterDict` modules didn't move the parameters stored in their fields.
+#1148 : Calling `Module.to()` shouldn't be differentiable
## NuGet Version 0.101.2
diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs
index f6f280043..031e415ac 100644
--- a/src/TorchSharp/NN/Module.cs
+++ b/src/TorchSharp/NN/Module.cs
@@ -163,66 +163,6 @@ protected internal virtual Module _to(Device device, ScalarType dtype)
return this;
}
- protected void _toEpilog(Device device, ScalarType dtype)
- {
- foreach (var (_, sm) in named_children()) sm._to(device, dtype);
-
- var alreadyHandled = new HashSet();
-
- foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {
-
- var fieldName = field.ComponentName();
- var value = field.GetValue(this);
-
- switch (value) {
- // This order in which these cases are arranged is significant.
- case Parameter param when dtype == param.dtype && device.type == param.device_type && device.index == param.device_index:
- alreadyHandled.Add(param.handle);
- continue;
-
- case Parameter param: {
- var t = param.to(dtype, device);
- t.retain_grad();
- var p = new Parameter(t, param.requires_grad);
- field.SetValue(this, p);
- ConditionallyRegisterParameter(fieldName, p);
- alreadyHandled.Add(p.handle);
- break;
- }
-
- case Tensor tensor when (device.type != tensor.device_type || device.index != tensor.device_index): {
- var t = tensor.to(dtype, device);
- field.SetValue(this, t);
- ConditionallyRegisterBuffer(fieldName, t);
- alreadyHandled.Add(t.handle);
- break;
- }
-
- case Tensor tensor:
- alreadyHandled.Add(tensor.handle);
- break;
- }
- }
-
- foreach (var (name, param) in named_parameters(false).ToList()) {
- if (alreadyHandled.Contains(param.handle)) continue;
- var t = param.to(dtype, device);
- ConditionallyRegisterParameter(name, t);
- }
-
- foreach (var (name, buffer) in named_buffers(false).ToList()) {
- if (alreadyHandled.Contains(buffer.handle)) continue;
- var t = buffer.to(dtype, device);
- ConditionallyRegisterBuffer(name, t);
- }
-
- _deviceType = device.type;
- _deviceIndex = device.index;
-
- Debug.Assert(_deviceType == DeviceType.CUDA || _deviceIndex == -1);
- }
-
-
///
/// Moves the parameters and buffers.
///
@@ -249,63 +189,6 @@ protected internal virtual Module _to(DeviceType deviceType, int deviceIndex = -
return this;
}
- protected void _toEpilog(DeviceType deviceType, int deviceIndex)
- {
- foreach (var (_, sm) in named_children()) sm._to(deviceType, deviceIndex);
-
- var alreadyHandled = new HashSet();
-
- foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {
-
- var fieldName = field.ComponentName();
- var value = field.GetValue(this);
-
- switch (value) {
- // This order in which these cases are arranged is significant.
- case Parameter param when deviceType == param.device_type && deviceIndex == param.device_index:
- alreadyHandled.Add(param.handle);
- continue;
-
- case Parameter param: {
- var t = param.to(deviceType, deviceIndex);
- t.retain_grad();
- var p = new Parameter(t, param.requires_grad);
- field.SetValue(this, p);
- ConditionallyRegisterParameter(fieldName, p);
- alreadyHandled.Add(p.handle);
- break;
- }
-
- case Tensor tensor when (deviceType != tensor.device_type || deviceIndex != tensor.device_index): {
- var t = tensor.to(deviceType, deviceIndex);
- field.SetValue(this, t);
- ConditionallyRegisterBuffer(fieldName, t);
- alreadyHandled.Add(t.handle);
- break;
- }
-
- case Tensor tensor:
- alreadyHandled.Add(tensor.handle);
- break;
- }
- }
-
- foreach (var (name, param) in named_parameters(false).ToList()) {
- if (alreadyHandled.Contains(param.handle)) continue;
- var t = param.to(deviceType, deviceIndex);
- ConditionallyRegisterParameter(name, t);
- }
-
- foreach (var (name, buffer) in named_buffers(false).ToList()) {
- if (alreadyHandled.Contains(buffer.handle)) continue;
- var t = buffer.to(deviceType, deviceIndex);
- ConditionallyRegisterBuffer(name, t);
- }
-
- _deviceType = deviceType;
- _deviceIndex = deviceIndex;
- }
-
private DeviceType _deviceType = DeviceType.CPU;
private int _deviceIndex = -1;
@@ -325,55 +208,62 @@ protected internal virtual Module _to(ScalarType dtype)
protected void _toEpilog(ScalarType dtype)
{
- foreach (var (_, sm) in named_children()) sm._to(dtype);
+ _toEpilog(dtype, null);
+ }
- var alreadyHandled = new HashSet();
+ protected void _toEpilog(Device device, ScalarType dtype)
+ {
+ _toEpilog(dtype, device);
+ }
- foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {
+ protected void _toEpilog(DeviceType deviceType, int deviceIndex)
+ {
+ _toEpilog(null, new Device(deviceType, deviceIndex));
+ }
- var fieldName = field.ComponentName();
- var value = field.GetValue(this);
+ private void _toEpilog(ScalarType? dtype, Device device)
+ {
+ foreach (var (_, sm) in named_children()) {
+ if (device is null) sm._to(dtype.Value);
+ else if (dtype is null) sm._to(device.type, device.index);
+ else sm._to(device, dtype.Value);
+ }
- switch (value) {
- // This order in which these cases are arranged is significant.
- case Parameter param when dtype == param.dtype:
- alreadyHandled.Add(param.handle);
- continue;
-
- case Parameter param: {
- var t = param.to(dtype);
- t.retain_grad();
- var p = new Parameter(t, param.requires_grad);
- field.SetValue(this, p);
- ConditionallyRegisterParameter(fieldName, p);
- alreadyHandled.Add(p.handle);
- break;
- }
+ var fieldsByComponentName = GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)
+ .ToDictionary(field => field.ComponentName());
- case Tensor tensor when dtype == tensor.dtype:
- alreadyHandled.Add(tensor.handle);
- continue;
+ foreach (var (name, param) in named_parameters(false).ToList()) {
+ if (!param.toWillCopy(dtype ?? param.dtype, device ?? param.device)) continue;
- case Tensor tensor: {
- var t = tensor.to(dtype);
- field.SetValue(this, t);
- ConditionallyRegisterBuffer(fieldName, t);
- alreadyHandled.Add(t.handle);
- break;
- }
- }
- }
+ // Store the requires_grad flag ahead of time, since we dispose the parameter after moving
+ bool requiresGrad = param.requires_grad;
+ Parameter p;
+ // When moving the parameter, we don't want the autograd to track this movement on the graph.
+ // In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
+ // disable grad we would need to call .detach() on the moved tensor.
+ using (var d = torch.no_grad())
+ p = new Parameter(param.to(dtype ?? param.dtype, device ?? param.device, disposeAfter: true), requiresGrad);
+ ConditionallyRegisterParameter(name, p);
- foreach (var (name, param) in named_parameters(false).ToList()) {
- if (alreadyHandled.Contains(param.handle)) continue;
- var t = param.to(dtype);
- ConditionallyRegisterParameter(name, t);
+ // If this parameter is a field, set it
+ if (fieldsByComponentName.TryGetValue(name, out var field))
+ field.SetValue(this, p);
}
foreach (var (name, buffer) in named_buffers(false).ToList()) {
- if (alreadyHandled.Contains(buffer.handle)) continue;
- var t = buffer.to(dtype);
+ if (!buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;
+
+ // Buffers don't get grads so we don't need to detach them afterwards
+ var t = buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true);
ConditionallyRegisterBuffer(name, t);
+
+ if (fieldsByComponentName.TryGetValue(name, out var field))
+ field.SetValue(this, t);
+ }
+
+ if (device is not null) {
+ _deviceType = device.type;
+ _deviceIndex = device.index;
}
}
diff --git a/src/TorchSharp/NN/ParameterDict.cs b/src/TorchSharp/NN/ParameterDict.cs
index 64e4e3d56..810b89434 100644
--- a/src/TorchSharp/NN/ParameterDict.cs
+++ b/src/TorchSharp/NN/ParameterDict.cs
@@ -60,6 +60,37 @@ protected override void RegisterComponents()
private bool _registered = false;
+ protected internal override Module _to(DeviceType deviceType, int deviceIndex = -1)
+ {
+ base._to(deviceType, deviceIndex);
+ _toEpilog();
+ return this;
+ }
+
+ protected internal override Module _to(torch.Device device, torch.ScalarType dtype)
+ {
+ base._to(device, dtype);
+ _toEpilog();
+ return this;
+ }
+
+ protected internal override Module _to(torch.ScalarType dtype)
+ {
+ base._to(dtype);
+ _toEpilog();
+ return this;
+ }
+
+ void _toEpilog()
+ {
+ for (int i = 0; i < _list.Count; i++) {
+ string name = _list[i].Item1;
+ var param = base.get_parameter(name);
+ _list[i] = (name, param);
+ _dict[name] = param;
+ }
+ }
+
///
/// Return the ParameterDict values.
///
diff --git a/src/TorchSharp/NN/ParameterList.cs b/src/TorchSharp/NN/ParameterList.cs
index 129753d35..49dba64f3 100644
--- a/src/TorchSharp/NN/ParameterList.cs
+++ b/src/TorchSharp/NN/ParameterList.cs
@@ -33,6 +33,35 @@ protected override void RegisterComponents()
_registered = true;
}
+
+ protected internal override Module _to(DeviceType deviceType, int deviceIndex = -1)
+ {
+ base._to(deviceType, deviceIndex);
+ _toEpilog();
+ return this;
+ }
+
+ protected internal override Module _to(torch.Device device, torch.ScalarType dtype)
+ {
+ base._to(device, dtype);
+ _toEpilog();
+ return this;
+ }
+
+ protected internal override Module _to(torch.ScalarType dtype)
+ {
+ base._to(dtype);
+ _toEpilog();
+ return this;
+ }
+
+ void _toEpilog()
+ {
+ for (int i = 0; i < _list.Count; i++) {
+ _list[i] = base.get_parameter($"{i}");
+ }
+ }
+
public override IEnumerable<(string name, Parameter parameter)> named_parameters(bool recurse = true)
{
return Enumerable.Range(0, _list.Count).Select(i => ($"{i}", _list[i]));
diff --git a/src/TorchSharp/Tensor/TensorExtensionMethods.cs b/src/TorchSharp/Tensor/TensorExtensionMethods.cs
index 90dcbc20d..09d79028f 100644
--- a/src/TorchSharp/Tensor/TensorExtensionMethods.cs
+++ b/src/TorchSharp/Tensor/TensorExtensionMethods.cs
@@ -704,5 +704,66 @@ public static long TotalSize(this IEnumerable shape)
}
return result;
}
+
+ ///
+ /// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+ ///
+ /// The tensor
+ /// The device to move to
+ /// True if the tensor will be copied
+ internal static bool toWillCopy(this Tensor tensor, Device device)
+ {
+ return tensor.toWillCopy(device.type, device.index);
+ }
+
+ ///
+ /// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+ ///
+ /// The tensor
+ /// The device type to move to
+ /// The device index to move to
+ /// True if the tensor will be copied
+ internal static bool toWillCopy(this Tensor tensor, DeviceType deviceType, int deviceIndex)
+ {
+ return tensor.device_index != deviceIndex || tensor.device_type != deviceType;
+ }
+
+ ///
+ /// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+ ///
+ /// The tensor
+ /// The dtype to move to
+ /// True if the tensor will be copied
+ internal static bool toWillCopy(this Tensor tensor, ScalarType dtype)
+ {
+ return tensor.dtype != dtype;
+ }
+
+ ///
+ /// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+ ///
+ /// The tensor
+ /// The dtype to move to
+ /// The device to move to
+ /// True if the tensor will be copied
+ internal static bool toWillCopy(this Tensor tensor, ScalarType dtype, Device device)
+ {
+ return tensor.toWillCopy(dtype, device.type, device.index);
+ }
+
+ ///
+ /// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+ ///
+ /// The tensor
+ /// The dtype to move to
+ /// The device type to move to
+ /// The device index to move to
+ /// True if the tensor will be copied
+ internal static bool toWillCopy(this Tensor tensor, ScalarType dtype, DeviceType deviceType, int deviceIndex)
+ {
+ return tensor.device_index != deviceIndex || tensor.device_type != deviceType || tensor.dtype != dtype;
+ }
+
+
}
}
diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs
index 0330b0128..c07ec02e0 100644
--- a/test/TorchSharpTest/NN.cs
+++ b/test/TorchSharpTest/NN.cs
@@ -3042,6 +3042,117 @@ public void TestDeviceTo()
}
}
+ [Fact]
+ public void TestCustomModuleWithDeviceMove()
+ {
+ if (torch.cuda.is_available()) {
+ var module = new TestModule1(torch.randn(2, 2), true);
+
+ // Move the device to cuda, and make sure gradients are calculated for all the parameters
+ module.to(torch.CUDA);
+ var x = torch.randn(2, 2, device: torch.CUDA);
+ var y = torch.randn(2, device: torch.CUDA);
+ torch.nn.functional.mse_loss(module.call(x), y).backward();
+ foreach (var (pName, parm) in module.named_parameters()) {
+ var grad = parm.grad();
+ Assert.NotNull(grad);
+ }
+
+ // Reset and then try again with moving back to CPU
+ module.zero_grad();
+
+ // Try moving back to CPU
+ module.to(torch.CPU);
+ x = torch.randn(2, 2);
+ y = torch.randn(2);
+ torch.nn.functional.mse_loss(module.call(x), y).backward();
+ foreach (var (pName, parm) in module.named_parameters()) {
+ var grad = parm.grad();
+ Assert.NotNull(grad);
+ }
+ }
+ }
+
+ [Fact]
+ public void TestCustomModuleWithTypeMove()
+ {
+ var module = new TestModule1(torch.randn(2, 2), true);
+
+ // Move the module to 16-bit floats, and make sure gradients are calculated for all the parameters
+ module.@double();
+ var x = torch.randn(2, 2, float64);
+ var y = torch.randn(2, float64);
+ torch.nn.functional.mse_loss(module.call(x), y).backward();
+ foreach (var (pName, parm) in module.named_parameters()) {
+ var grad = parm.grad();
+ Assert.NotNull(grad);
+ }
+
+ // Reset and then try again with moving back to float 32
+ module.zero_grad();
+
+ // Try moving back to float 32
+ module.@float();
+ x = torch.randn(2, 2);
+ y = torch.randn(2);
+ torch.nn.functional.mse_loss(module.call(x), y).backward();
+ foreach (var (pName, parm) in module.named_parameters()) {
+ var grad = parm.grad();
+ Assert.NotNull(grad);
+ }
+ }
+ [Fact]
+ public void TestCustomModuleWithDeviceAndTypeMove()
+ {
+ if (torch.cuda.is_available()) {
+ var module = new TestModule1(torch.randn(2, 2), true);
+
+ // Move the device to cuda & float 16, and make sure gradients are calculated for all the parameters
+ module.to(torch.CUDA, float16);
+ var x = torch.randn(2, 2, float16, torch.CUDA);
+ var y = torch.randn(2, float16, torch.CUDA);
+ torch.nn.functional.mse_loss(module.call(x), y).backward();
+ foreach (var (pName, parm) in module.named_parameters()) {
+ var grad = parm.grad();
+ Assert.NotNull(grad);
+ }
+
+ // Reset and then try again with moving back to CPU & float 32
+ module.zero_grad();
+
+ // Try moving back to CPU & float 32
+ module.to(torch.CPU, float32);
+ x = torch.randn(2, 2);
+ y = torch.randn(2);
+ torch.nn.functional.mse_loss(module.call(x), y).backward();
+ foreach (var (pName, parm) in module.named_parameters()) {
+ var grad = parm.grad();
+ Assert.NotNull(grad);
+ }
+ }
+ }
+
+ [Fact]
+ public void TestCustomModuleWithMoveAndDisabledGradOnParameter()
+ {
+ var module = new TestModule1(torch.randn(2, 2), true);
+ // Disable grad on test, and make sure that it is able to move and retains the gradient state
+ module.get_parameter("test").requires_grad = false;
+
+ // Move the module to 16-bit floats
+ module.half();
+ Assert.False(module.get_parameter("test").requires_grad);
+
+ // Move to a different device
+ if (torch.cuda.is_available()) {
+ module.cuda();
+ Assert.False(module.get_parameter("test").requires_grad);
+
+ // Try a different device & type
+ module.to(torch.CPU, float32);
+ Assert.False(module.get_parameter("test").requires_grad);
+ }
+ }
[Fact]
public void TestCustomComponentName()