diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index 451a4b7ba..f5527428a 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -20,11 +20,13 @@ All distribution classes now implement IDisposable.
 
 __Bug Fixes__:
 
-#1154 : `mu_product` was not initialized in `NAdam` optimizer
-#1170 : Calling `torch.nn.rnn.utils.pad_packed_sequence` with a CUDA tensor and unsorted_indices threw an error
-#1172 : `optim.LoadStateDict` from an existing `StateDictionary` updated to make sure to copy value and to the right device.
-#1176 : When specific `Optimizers` load in a conditional tensor, made sure to copy to the right device.
-#1174 : Loading CUDA tensor from stream threw an error
+#1154 : `mu_product` was not initialized in `NAdam` optimizer
+#1170 : Calling `torch.nn.utils.rnn.pad_packed_sequence` with a CUDA tensor and unsorted_indices threw an error
+#1172 : `optim.LoadStateDict` from an existing `StateDictionary` now makes sure to copy values to the right device.
+#1176 : When specific `Optimizers` load a conditional tensor, it is now copied to the right device.
+#1174 : Loading CUDA tensor from stream threw an error
+#1179 : Calling `Module.to()` on a `ParameterList` or `ParameterDict` module didn't move the parameters stored in its fields.
+#1148 : Calling `Module.to()` shouldn't be differentiable
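
[Editor's note: as a usage illustration of #1179, a minimal sketch, assuming the standard TorchSharp custom-module pattern and the `nn.ParameterList` factory; `ListModule` is hypothetical.]

```csharp
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

// Hypothetical module that keeps its parameters in a ParameterList field.
class ListModule : nn.Module<Tensor, Tensor>
{
    private readonly ParameterList plist;

    public ListModule() : base(nameof(ListModule))
    {
        plist = nn.ParameterList(new Parameter(randn(2, 2)), new Parameter(randn(2)));
        RegisterComponents();
    }

    public override Tensor forward(Tensor input) => input.matmul(plist[0]) + plist[1];
}

// With #1179 fixed, to() also moves the parameters held by the list:
var m = new ListModule();
if (cuda.is_available()) {
    m.to(CUDA);
    foreach (var (_, p) in m.named_parameters())
        System.Diagnostics.Debug.Assert(p.device_type == DeviceType.CUDA);
}
```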
 
 ## NuGet Version 0.101.2
 
diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs
index f6f280043..031e415ac 100644
--- a/src/TorchSharp/NN/Module.cs
+++ b/src/TorchSharp/NN/Module.cs
@@ -163,66 +163,6 @@ protected internal virtual Module _to(Device device, ScalarType dtype)
             return this;
         }
 
-        protected void _toEpilog(Device device, ScalarType dtype)
-        {
-            foreach (var (_, sm) in named_children()) sm._to(device, dtype);
-
-            var alreadyHandled = new HashSet<IntPtr>();
-
-            foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {
-
-                var fieldName = field.ComponentName();
-                var value = field.GetValue(this);
-
-                switch (value) {
-                    // This order in which these cases are arranged is significant.
-                    case Parameter param when dtype == param.dtype && device.type == param.device_type && device.index == param.device_index:
-                        alreadyHandled.Add(param.handle);
-                        continue;
-
-                    case Parameter param: {
-                            var t = param.to(dtype, device);
-                            t.retain_grad();
-                            var p = new Parameter(t, param.requires_grad);
-                            field.SetValue(this, p);
-                            ConditionallyRegisterParameter(fieldName, p);
-                            alreadyHandled.Add(p.handle);
-                            break;
-                        }
-
-                    case Tensor tensor when (device.type != tensor.device_type || device.index != tensor.device_index): {
-                            var t = tensor.to(dtype, device);
-                            field.SetValue(this, t);
-                            ConditionallyRegisterBuffer(fieldName, t);
-                            alreadyHandled.Add(t.handle);
-                            break;
-                        }
-
-                    case Tensor tensor:
-                        alreadyHandled.Add(tensor.handle);
-                        break;
-                }
-            }
-
-            foreach (var (name, param) in named_parameters(false).ToList()) {
-                if (alreadyHandled.Contains(param.handle)) continue;
-                var t = param.to(dtype, device);
-                ConditionallyRegisterParameter(name, t);
-            }
-
-            foreach (var (name, buffer) in named_buffers(false).ToList()) {
-                if (alreadyHandled.Contains(buffer.handle)) continue;
-                var t = buffer.to(dtype, device);
-                ConditionallyRegisterBuffer(name, t);
-            }
-
-            _deviceType = device.type;
-            _deviceIndex = device.index;
-
-            Debug.Assert(_deviceType == DeviceType.CUDA || _deviceIndex == -1);
-        }
-
         /// <summary>
         /// Moves the parameters and buffers.
         /// </summary>
@@ -249,63 +189,6 @@ protected internal virtual Module _to(DeviceType deviceType, int deviceIndex = -1)
             return this;
         }
 
-        protected void _toEpilog(DeviceType deviceType, int deviceIndex)
-        {
-            foreach (var (_, sm) in named_children()) sm._to(deviceType, deviceIndex);
-
-            var alreadyHandled = new HashSet<IntPtr>();
-
-            foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {
-
-                var fieldName = field.ComponentName();
-                var value = field.GetValue(this);
-
-                switch (value) {
-                    // This order in which these cases are arranged is significant.
-                    case Parameter param when deviceType == param.device_type && deviceIndex == param.device_index:
-                        alreadyHandled.Add(param.handle);
-                        continue;
-
-                    case Parameter param: {
-                            var t = param.to(deviceType, deviceIndex);
-                            t.retain_grad();
-                            var p = new Parameter(t, param.requires_grad);
-                            field.SetValue(this, p);
-                            ConditionallyRegisterParameter(fieldName, p);
-                            alreadyHandled.Add(p.handle);
-                            break;
-                        }
-
-                    case Tensor tensor when (deviceType != tensor.device_type || deviceIndex != tensor.device_index): {
-                            var t = tensor.to(deviceType, deviceIndex);
-                            field.SetValue(this, t);
-                            ConditionallyRegisterBuffer(fieldName, t);
-                            alreadyHandled.Add(t.handle);
-                            break;
-                        }
-
-                    case Tensor tensor:
-                        alreadyHandled.Add(tensor.handle);
-                        break;
-                }
-            }
-
-            foreach (var (name, param) in named_parameters(false).ToList()) {
-                if (alreadyHandled.Contains(param.handle)) continue;
-                var t = param.to(deviceType, deviceIndex);
-                ConditionallyRegisterParameter(name, t);
-            }
-
-            foreach (var (name, buffer) in named_buffers(false).ToList()) {
-                if (alreadyHandled.Contains(buffer.handle)) continue;
-                var t = buffer.to(deviceType, deviceIndex);
-                ConditionallyRegisterBuffer(name, t);
-            }
-
-            _deviceType = deviceType;
-            _deviceIndex = deviceIndex;
-        }
-
         private DeviceType _deviceType = DeviceType.CPU;
         private int _deviceIndex = -1;
 
@@ -325,55 +208,62 @@ protected internal virtual Module _to(ScalarType dtype)
 
         protected void _toEpilog(ScalarType dtype)
         {
-            foreach (var (_, sm) in named_children()) sm._to(dtype);
+            _toEpilog(dtype, null);
+        }
 
-            var alreadyHandled = new HashSet<IntPtr>();
+        protected void _toEpilog(Device device, ScalarType dtype)
+        {
+            _toEpilog(dtype, device);
+        }
 
-            foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {
+        protected void _toEpilog(DeviceType deviceType, int deviceIndex)
+        {
+            _toEpilog(null, new Device(deviceType, deviceIndex));
+        }
 
-                var fieldName = field.ComponentName();
-                var value = field.GetValue(this);
+        private void _toEpilog(ScalarType? dtype, Device device)
+        {
+            foreach (var (_, sm) in named_children()) {
+                if (device is null) sm._to(dtype.Value);
+                else if (dtype is null) sm._to(device.type, device.index);
+                else sm._to(device, dtype.Value);
+            }
 
-                switch (value) {
-                    // This order in which these cases are arranged is significant.
-                    case Parameter param when dtype == param.dtype:
-                        alreadyHandled.Add(param.handle);
-                        continue;
-
-                    case Parameter param: {
-                            var t = param.to(dtype);
-                            t.retain_grad();
-                            var p = new Parameter(t, param.requires_grad);
-                            field.SetValue(this, p);
-                            ConditionallyRegisterParameter(fieldName, p);
-                            alreadyHandled.Add(p.handle);
-                            break;
-                        }
+            var fieldsByComponentName = GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)
+                .ToDictionary(field => field.ComponentName());
 
-                    case Tensor tensor when dtype == tensor.dtype:
-                        alreadyHandled.Add(tensor.handle);
-                        continue;
+            foreach (var (name, param) in named_parameters(false).ToList()) {
+                if (!param.toWillCopy(dtype ?? param.dtype, device ?? param.device)) continue;
 
-                    case Tensor tensor: {
-                            var t = tensor.to(dtype);
-                            field.SetValue(this, t);
-                            ConditionallyRegisterBuffer(fieldName, t);
-                            alreadyHandled.Add(t.handle);
-                            break;
-                        }
-                }
-            }
+                // Store the requires_grad flag ahead, since we dispose the parameter after moving
+                bool requiresGrad = param.requires_grad;
+                Parameter p;
+                // When moving the parameter, we don't want the autograd to track this movement on the graph.
+                // In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
+                // disable grad we would need to call .detach() on the moved tensor.
+                using (var d = torch.no_grad())
+                    p = new Parameter(param.to(dtype ?? param.dtype, device ?? param.device, disposeAfter: true), requiresGrad);
+                ConditionallyRegisterParameter(name, p);
 
-            foreach (var (name, param) in named_parameters(false).ToList()) {
-                if (alreadyHandled.Contains(param.handle)) continue;
-                var t = param.to(dtype);
-                ConditionallyRegisterParameter(name, t);
+                // If this parameter is a field, set it
+                if (fieldsByComponentName.TryGetValue(name, out var field))
+                    field.SetValue(this, p);
             }
 
             foreach (var (name, buffer) in named_buffers(false).ToList()) {
-                if (alreadyHandled.Contains(buffer.handle)) continue;
-                var t = buffer.to(dtype);
+                if (!buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;
+
+                // Buffers don't get grads so we don't need to detach them afterwards
+                var t = buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true);
                 ConditionallyRegisterBuffer(name, t);
+
+                if (fieldsByComponentName.TryGetValue(name, out var field))
+                    field.SetValue(this, t);
+            }
+
+            if (device is not null) {
+                _deviceType = device.type;
+                _deviceIndex = device.index;
             }
         }
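
[Editor's note: the `torch.no_grad()` block in the new `_toEpilog` above is the substance of the #1148 fix: moving a parameter must not be recorded on the autograd graph, and the moved tensor must be a leaf so it can accumulate gradients. A minimal sketch of the behavior it relies on, assuming TorchSharp's `Tensor.is_leaf` property:]

```csharp
using TorchSharp;
using static TorchSharp.torch;

var w = randn(2, 2, requires_grad: true);

// With grad enabled, to() is differentiable: the result is a non-leaf
// tensor whose history points back at `w`, so it won't accumulate .grad.
var tracked = w.to(ScalarType.Float64);
// tracked.is_leaf -> false

// Under no_grad, the copy is detached from the graph and is a leaf, so a
// fresh Parameter wrapped around it accumulates gradients as expected.
using (no_grad()) {
    var moved = w.to(ScalarType.Float64);
    // moved.is_leaf -> true
}
```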
diff --git a/src/TorchSharp/NN/ParameterDict.cs b/src/TorchSharp/NN/ParameterDict.cs
index 64e4e3d56..810b89434 100644
--- a/src/TorchSharp/NN/ParameterDict.cs
+++ b/src/TorchSharp/NN/ParameterDict.cs
@@ -60,6 +60,37 @@ protected override void RegisterComponents()
 
         private bool _registered = false;
 
+        protected internal override Module _to(DeviceType deviceType, int deviceIndex = -1)
+        {
+            base._to(deviceType, deviceIndex);
+            _toEpilog();
+            return this;
+        }
+
+        protected internal override Module _to(torch.Device device, torch.ScalarType dtype)
+        {
+            base._to(device, dtype);
+            _toEpilog();
+            return this;
+        }
+
+        protected internal override Module _to(torch.ScalarType dtype)
+        {
+            base._to(dtype);
+            _toEpilog();
+            return this;
+        }
+
+        void _toEpilog()
+        {
+            for (int i = 0; i < _list.Count; i++) {
+                string name = _list[i].Item1;
+                var param = base.get_parameter(name);
+                _list[i] = (name, param);
+                _dict[name] = param;
+            }
+        }
+
         /// <summary>
         /// Return the ParameterDict values.
         /// </summary>
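
[Editor's note: why the epilog above is needed: `ParameterDict` caches `Parameter` references in its `_list`/`_dict` fields, and `base._to()` re-registers moved parameters without updating those caches. A hypothetical usage sketch of the observable effect, assuming the dictionary-style `Add` and indexer:]

```csharp
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

var pd = new ParameterDict();
pd.Add("weight", new Parameter(randn(2, 2)));

// half() funnels into _to(ScalarType), which replaces the registered
// parameter; the epilog then re-syncs _list and _dict.
pd.half();

var w = pd["weight"];
// expected: w.dtype == ScalarType.Float16 (previously a stale float32 reference)
```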
diff --git a/src/TorchSharp/NN/ParameterList.cs b/src/TorchSharp/NN/ParameterList.cs
index 129753d35..49dba64f3 100644
--- a/src/TorchSharp/NN/ParameterList.cs
+++ b/src/TorchSharp/NN/ParameterList.cs
@@ -33,6 +33,35 @@ protected override void RegisterComponents()
             _registered = true;
         }
 
+        protected internal override Module _to(DeviceType deviceType, int deviceIndex = -1)
+        {
+            base._to(deviceType, deviceIndex);
+            _toEpilog();
+            return this;
+        }
+
+        protected internal override Module _to(torch.Device device, torch.ScalarType dtype)
+        {
+            base._to(device, dtype);
+            _toEpilog();
+            return this;
+        }
+
+        protected internal override Module _to(torch.ScalarType dtype)
+        {
+            base._to(dtype);
+            _toEpilog();
+            return this;
+        }
+
+        void _toEpilog()
+        {
+            for (int i = 0; i < _list.Count; i++) {
+                _list[i] = base.get_parameter($"{i}");
+            }
+        }
+
         public override IEnumerable<(string name, Parameter parameter)> named_parameters(bool recurse = true)
         {
             return Enumerable.Range(0, _list.Count).Select(i => ($"{i}", _list[i]));
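
[Editor's note: same pattern for `ParameterList`: after `base._to()` re-registers each moved parameter under its index name, the epilog writes it back into `_list`, so both the indexer and the overridden `named_parameters()` above see the moved tensor. A sketch, assuming the `nn.ParameterList` factory:]

```csharp
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

var pl = nn.ParameterList(new Parameter(randn(2, 2)));

pl.half();  // base._to replaces the registered parameter named "0"...

// ...and the epilog copies it back into the backing list:
var p0 = pl[0];
// expected: p0.dtype == ScalarType.Float16, not a stale float32 reference
```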
diff --git a/src/TorchSharp/Tensor/TensorExtensionMethods.cs b/src/TorchSharp/Tensor/TensorExtensionMethods.cs
index 90dcbc20d..09d79028f 100644
--- a/src/TorchSharp/Tensor/TensorExtensionMethods.cs
+++ b/src/TorchSharp/Tensor/TensorExtensionMethods.cs
@@ -704,5 +704,66 @@ public static long TotalSize(this IEnumerable<long> shape)
             }
             return result;
         }
+
+        /// <summary>
+        /// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+        /// </summary>
+        /// <param name="tensor">The tensor</param>
+        /// <param name="device">The device to move to</param>
+        /// <returns>True if the tensor will be copied</returns>
+        internal static bool toWillCopy(this Tensor tensor, Device device)
+        {
+            return tensor.toWillCopy(device.type, device.index);
+        }
+
+        /// <summary>
+        /// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+        /// </summary>
+        /// <param name="tensor">The tensor</param>
+        /// <param name="deviceType">The device type to move to</param>
+        /// <param name="deviceIndex">The device index to move to</param>
+        /// <returns>True if the tensor will be copied</returns>
+        internal static bool toWillCopy(this Tensor tensor, DeviceType deviceType, int deviceIndex)
+        {
+            return tensor.device_index != deviceIndex || tensor.device_type != deviceType;
+        }
+
+        /// <summary>
+        /// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+        /// </summary>
+        /// <param name="tensor">The tensor</param>
+        /// <param name="dtype">The dtype to move to</param>
+        /// <returns>True if the tensor will be copied</returns>
+        internal static bool toWillCopy(this Tensor tensor, ScalarType dtype)
+        {
+            return tensor.dtype != dtype;
+        }
+
+        /// <summary>
+        /// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+        /// </summary>
+        /// <param name="tensor">The tensor</param>
+        /// <param name="dtype">The dtype to move to</param>
+        /// <param name="device">The device to move to</param>
+        /// <returns>True if the tensor will be copied</returns>
+        internal static bool toWillCopy(this Tensor tensor, ScalarType dtype, Device device)
+        {
+            return tensor.toWillCopy(dtype, device.type, device.index);
+        }
+
+        /// <summary>
+        /// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+        /// </summary>
+        /// <param name="tensor">The tensor</param>
+        /// <param name="dtype">The dtype to move to</param>
+        /// <param name="deviceType">The device type to move to</param>
+        /// <param name="deviceIndex">The device index to move to</param>
+        /// <returns>True if the tensor will be copied</returns>
+        internal static bool toWillCopy(this Tensor tensor, ScalarType dtype, DeviceType deviceType, int deviceIndex)
+        {
+            return tensor.device_index != deviceIndex || tensor.device_type != deviceType || tensor.dtype != dtype;
+        }
     }
 }
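
[Editor's note: the `toWillCopy` helpers let `_toEpilog` skip tensors that `to()` would return unchanged. They are `internal`, so as illustration here is a local stand-in with the same logic; a sketch, not part of the diff:]

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Mirrors the internal helper: a move copies only if device or dtype changes.
static bool WillCopy(Tensor t, ScalarType dtype, Device device) =>
    t.device_index != device.index || t.device_type != device.type || t.dtype != dtype;

var t = zeros(3, 3);                                        // float32, CPU
var noOp     = WillCopy(t, t.dtype, t.device);              // false: to() is a no-op
var retype   = WillCopy(t, ScalarType.Float16, t.device);   // true: new storage needed
var redevice = WillCopy(t, t.dtype, new Device(DeviceType.CUDA, 0)); // true: device change
```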
diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs
index 0330b0128..c07ec02e0 100644
--- a/test/TorchSharpTest/NN.cs
+++ b/test/TorchSharpTest/NN.cs
@@ -3042,6 +3042,117 @@ public void TestDeviceTo()
             }
         }
 
+        [Fact]
+        public void TestCustomModuleWithDeviceMove()
+        {
+            if (torch.cuda.is_available()) {
+                var module = new TestModule1(torch.randn(2, 2), true);
+
+                // Move the module to CUDA, and make sure gradients are calculated for all the parameters
+                module.to(torch.CUDA);
+                var x = torch.randn(2, 2, device: torch.CUDA);
+                var y = torch.randn(2, device: torch.CUDA);
+                torch.nn.functional.mse_loss(module.call(x), y).backward();
+                foreach (var (pName, parm) in module.named_parameters()) {
+                    var grad = parm.grad();
+                    Assert.NotNull(grad);
+                }
+
+                // Reset and then try again with moving back to CPU
+                module.zero_grad();
+
+                // Try moving back to CPU
+                module.to(torch.CPU);
+                x = torch.randn(2, 2);
+                y = torch.randn(2);
+                torch.nn.functional.mse_loss(module.call(x), y).backward();
+                foreach (var (pName, parm) in module.named_parameters()) {
+                    var grad = parm.grad();
+                    Assert.NotNull(grad);
+                }
+            }
+        }
+
+        [Fact]
+        public void TestCustomModuleWithTypeMove()
+        {
+            var module = new TestModule1(torch.randn(2, 2), true);
+
+            // Move the module to 64-bit floats, and make sure gradients are calculated for all the parameters
+            module.@double();
+            var x = torch.randn(2, 2, float64);
+            var y = torch.randn(2, float64);
+            torch.nn.functional.mse_loss(module.call(x), y).backward();
+            foreach (var (pName, parm) in module.named_parameters()) {
+                var grad = parm.grad();
+                Assert.NotNull(grad);
+            }
+
+            // Reset and then try again with moving back to float32
+            module.zero_grad();
+
+            // Try moving back to float32
+            module.@float();
+            x = torch.randn(2, 2);
+            y = torch.randn(2);
+            torch.nn.functional.mse_loss(module.call(x), y).backward();
+            foreach (var (pName, parm) in module.named_parameters()) {
+                var grad = parm.grad();
+                Assert.NotNull(grad);
+            }
+        }
+
+        [Fact]
+        public void TestCustomModuleWithDeviceAndTypeMove()
+        {
+            if (torch.cuda.is_available()) {
+                var module = new TestModule1(torch.randn(2, 2), true);
+
+                // Move the module to CUDA & float16, and make sure gradients are calculated for all the parameters
+                module.to(torch.CUDA, float16);
+                var x = torch.randn(2, 2, float16, torch.CUDA);
+                var y = torch.randn(2, float16, torch.CUDA);
+                torch.nn.functional.mse_loss(module.call(x), y).backward();
+                foreach (var (pName, parm) in module.named_parameters()) {
+                    var grad = parm.grad();
+                    Assert.NotNull(grad);
+                }
+
+                // Reset and then try again with moving back to CPU & float32
+                module.zero_grad();
+
+                // Try moving back to CPU & float32
+                module.to(torch.CPU, float32);
+                x = torch.randn(2, 2);
+                y = torch.randn(2);
+                torch.nn.functional.mse_loss(module.call(x), y).backward();
+                foreach (var (pName, parm) in module.named_parameters()) {
+                    var grad = parm.grad();
+                    Assert.NotNull(grad);
+                }
+            }
+        }
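
[Editor's note: these tests use `TestModule1`, which is defined elsewhere in NN.cs and not shown in this diff. Purely as an assumption for readers following along, a plausible minimal shape — a custom module with a parameter field named "test":]

```csharp
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

// Hypothetical sketch; the real TestModule1 lives elsewhere in the test suite.
internal class TestModule1 : nn.Module<Tensor, Tensor>
{
    // A parameter held in a field, so Module._toEpilog must rewrite the field on to().
    private readonly Parameter test;

    public TestModule1(Tensor weight, bool requiresGrad) : base(nameof(TestModule1))
    {
        test = new Parameter(weight, requiresGrad);
        RegisterComponents();
    }

    public override Tensor forward(Tensor input) => input.matmul(test).mean(new long[] { 1 });
}
```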
+
+        [Fact]
+        public void TestCustomModuleWithMoveAndDisabledGradOnParameter()
+        {
+            var module = new TestModule1(torch.randn(2, 2), true);
+            // Disable grad on the "test" parameter, and make sure it is still able to move and
+            // retain its requires_grad state
+            module.get_parameter("test").requires_grad = false;
+
+            // Move the module to 16-bit floats
+            module.half();
+            Assert.False(module.get_parameter("test").requires_grad);
+
+            // Move to a different device
+            if (torch.cuda.is_available()) {
+                module.cuda();
+                Assert.False(module.get_parameter("test").requires_grad);
+
+                // Try a different device & type
+                module.to(torch.CPU, float32);
+                Assert.False(module.get_parameter("test").requires_grad);
+            }
+        }
+
         [Fact]
         public void TestCustomComponentName()