diff --git a/RELEASENOTES.md b/RELEASENOTES.md index ac759131b..3e93c0382 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -8,6 +8,8 @@ __Breaking Changes__: __API Changes__: +- #1291 `Tensor.grad()` and `Tensor.set_grad()` have been replaced by a new property `Tensor.grad`. + __Bug Fixes__: diff --git a/src/Examples/AdversarialExampleGeneration.cs b/src/Examples/AdversarialExampleGeneration.cs index 1e558c791..7bfc174b2 100644 --- a/src/Examples/AdversarialExampleGeneration.cs +++ b/src/Examples/AdversarialExampleGeneration.cs @@ -133,7 +133,7 @@ private static double Test( model.zero_grad(); loss.backward(); - var perturbed = Attack(data, ε, data.grad()); + var perturbed = Attack(data, ε, data.grad); using (var final = model.call(perturbed)) { diff --git a/src/FSharp.Examples/AdversarialExampleGeneration.fs b/src/FSharp.Examples/AdversarialExampleGeneration.fs index 09dd6fdf3..46ff364f3 100644 --- a/src/FSharp.Examples/AdversarialExampleGeneration.fs +++ b/src/FSharp.Examples/AdversarialExampleGeneration.fs @@ -79,7 +79,7 @@ let test (model:MNIST.Model) (eps:float) (data:Dataset) size = model.zero_grad() loss.backward() - use perturbed = attack input (eps.ToScalar()) (input.grad()) + use perturbed = attack input (eps.ToScalar()) (input.grad) use final = perturbed --> model correct <- correct + final.argmax(1L).eq(labels).sum().ToInt32() end diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs index 6d6cca212..f5e49616b 100644 --- a/src/TorchSharp/NN/Module.cs +++ b/src/TorchSharp/NN/Module.cs @@ -240,7 +240,7 @@ private void _toEpilog(ScalarType? dtype, Device device) foreach (var (name, param) in named_parameters(false).ToList()) { if (!param.toWillCopy(dtype ?? param.dtype, device ?? param.device) && - (param.grad() is null || !param.grad().toWillCopy(dtype ?? param.dtype, device ?? param.device))) + (param.grad is null || !param.grad.toWillCopy(dtype ?? param.dtype, device ?? param.device))) continue; Parameter p; @@ -256,11 +256,10 @@ private void _toEpilog(ScalarType? dtype, Device device) .DetachFromDisposeScope() as Parameter; // Copy the gradient over as well, if it exists - var grad = param.grad(); + var grad = param.grad; if (grad is not null) { - p.set_grad(grad.to(paramType, device ?? param.device) - .with_requires_grad(grad.requires_grad) - .MoveToOtherDisposeScope(p)); + p.grad = grad.to(paramType, device ?? param.device) + .with_requires_grad(grad.requires_grad); } // Dispose the param and gradient @@ -360,10 +359,10 @@ public virtual void zero_grad(bool set_to_none = true) CheckForErrors(); foreach (var (_, p) in named_parameters()) { - var grad = p.grad(); + var grad = p.grad; if (grad is not null) { if (set_to_none) { - p.set_grad(null); + p.grad = null; grad.DetachFromDisposeScope().Dispose(); } else { grad.zero_(); diff --git a/src/TorchSharp/Optimizers/ASGD.cs b/src/TorchSharp/Optimizers/ASGD.cs index c4ed54ead..260810aa0 100644 --- a/src/TorchSharp/Optimizers/ASGD.cs +++ b/src/TorchSharp/Optimizers/ASGD.cs @@ -145,7 +145,7 @@ public override Tensor step(Func closure = null) foreach (var param in group.Parameters) { - var grad = param.grad(); + var grad = param.grad; if (grad is null) continue; diff --git a/src/TorchSharp/Optimizers/Adadelta.cs b/src/TorchSharp/Optimizers/Adadelta.cs index e262ceda9..8e2027a1a 100644 --- a/src/TorchSharp/Optimizers/Adadelta.cs +++ b/src/TorchSharp/Optimizers/Adadelta.cs @@ -136,7 +136,7 @@ public override Tensor step(Func closure = null) foreach (var param in group.Parameters) { - var grad = (maximize) ? 
-param.grad() : param.grad(); + var grad = (maximize) ? -param.grad : param.grad; if (grad is null) continue; diff --git a/src/TorchSharp/Optimizers/Adagrad.cs b/src/TorchSharp/Optimizers/Adagrad.cs index 7abdab0eb..a4d4b70fc 100644 --- a/src/TorchSharp/Optimizers/Adagrad.cs +++ b/src/TorchSharp/Optimizers/Adagrad.cs @@ -147,7 +147,7 @@ public override Tensor step(Func closure = null) var state = (State)_state[param.handle]; - var grad = param.grad(); + var grad = param.grad; if (grad is null) continue; diff --git a/src/TorchSharp/Optimizers/Adam.cs b/src/TorchSharp/Optimizers/Adam.cs index 8233e05b9..39156f6a6 100644 --- a/src/TorchSharp/Optimizers/Adam.cs +++ b/src/TorchSharp/Optimizers/Adam.cs @@ -164,7 +164,7 @@ public override Tensor step(Func closure = null) var state = (State)_state[param.handle]; - var grad = (maximize) ? -param.grad() : param.grad(); + var grad = (maximize) ? -param.grad : param.grad; if (grad is null) continue; diff --git a/src/TorchSharp/Optimizers/AdamW.cs b/src/TorchSharp/Optimizers/AdamW.cs index 225c6660e..23624369b 100644 --- a/src/TorchSharp/Optimizers/AdamW.cs +++ b/src/TorchSharp/Optimizers/AdamW.cs @@ -164,7 +164,7 @@ public override Tensor step(Func closure = null) var state = (State)_state[param.handle]; - var grad = (maximize) ? -param.grad() : param.grad(); + var grad = (maximize) ? -param.grad : param.grad; if (grad is null) continue; diff --git a/src/TorchSharp/Optimizers/Adamax.cs b/src/TorchSharp/Optimizers/Adamax.cs index 4e421a8ee..e09ef9170 100644 --- a/src/TorchSharp/Optimizers/Adamax.cs +++ b/src/TorchSharp/Optimizers/Adamax.cs @@ -148,7 +148,7 @@ public override Tensor step(Func closure = null) foreach (var param in group.Parameters) { - var grad = param.grad(); + var grad = param.grad; if (grad is null) continue; diff --git a/src/TorchSharp/Optimizers/NAdam.cs b/src/TorchSharp/Optimizers/NAdam.cs index 251b22e7d..6118cc5d1 100644 --- a/src/TorchSharp/Optimizers/NAdam.cs +++ b/src/TorchSharp/Optimizers/NAdam.cs @@ -154,7 +154,7 @@ public override Tensor step(Func closure = null) foreach (var param in group.Parameters) { - var grad = param.grad(); + var grad = param.grad; if (grad is null) continue; diff --git a/src/TorchSharp/Optimizers/Optimizer.cs b/src/TorchSharp/Optimizers/Optimizer.cs index 79b542de8..002aa7fea 100644 --- a/src/TorchSharp/Optimizers/Optimizer.cs +++ b/src/TorchSharp/Optimizers/Optimizer.cs @@ -399,7 +399,7 @@ public override void zero_grad() foreach (var p in g.Parameters) { - using var grad = p.grad(); + using var grad = p.grad; if (grad is null) continue; diff --git a/src/TorchSharp/Optimizers/RAdam.cs b/src/TorchSharp/Optimizers/RAdam.cs index 4b8690839..d64416196 100644 --- a/src/TorchSharp/Optimizers/RAdam.cs +++ b/src/TorchSharp/Optimizers/RAdam.cs @@ -147,7 +147,7 @@ public override Tensor step(Func closure = null) foreach (var param in group.Parameters) { - var grad = param.grad(); + var grad = param.grad; if (grad is null) continue; diff --git a/src/TorchSharp/Optimizers/RMSprop.cs b/src/TorchSharp/Optimizers/RMSprop.cs index 2ce65d874..9bc77f95f 100644 --- a/src/TorchSharp/Optimizers/RMSprop.cs +++ b/src/TorchSharp/Optimizers/RMSprop.cs @@ -162,7 +162,7 @@ public override Tensor step(Func closure = null) var state = (State)_state[param.handle]; - var grad = param.grad(); + var grad = param.grad; if (grad is null) continue; diff --git a/src/TorchSharp/Optimizers/Rprop.cs b/src/TorchSharp/Optimizers/Rprop.cs index 1bcf41f2b..47e01d982 100644 --- a/src/TorchSharp/Optimizers/Rprop.cs +++ 
b/src/TorchSharp/Optimizers/Rprop.cs @@ -144,7 +144,7 @@ public override Tensor step(Func closure = null) foreach (var param in group.Parameters) { - var grad = param.grad(); + var grad = param.grad; if (grad is null) continue; diff --git a/src/TorchSharp/Optimizers/SGD.cs b/src/TorchSharp/Optimizers/SGD.cs index 87cacdadd..fed1f912b 100644 --- a/src/TorchSharp/Optimizers/SGD.cs +++ b/src/TorchSharp/Optimizers/SGD.cs @@ -150,7 +150,7 @@ public override Tensor step(Func closure = null) var state = (State)_state[param.handle]; - var grad = param.grad(); + var grad = param.grad; if (grad is null) continue; diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs index 8fe5fe013..4c53ff6df 100644 --- a/src/TorchSharp/Tensor/Tensor.cs +++ b/src/TorchSharp/Tensor/Tensor.cs @@ -1340,25 +1340,21 @@ public Tensor pin_memory() /// This attribute is null by default and becomes a Tensor the first time a call to backward() computes gradients for the tensor. /// The attribute will then contain the gradients computed and future calls to backward() will accumulate (add) gradients into it. /// - public Tensor? grad() - { - var res = NativeMethods.THSTensor_grad(Handle); - CheckForErrors(); - - if (res == IntPtr.Zero) - return null; + public Tensor? grad { + get { + var res = NativeMethods.THSTensor_grad(Handle); + CheckForErrors(); - return new Tensor(res); - } + if (res == IntPtr.Zero) + return null; - /// - /// This function will set the `tensor.grad()` attribute to a custom tensor. - /// - /// The new gradient tensor - public void set_grad(Tensor grad) - { - NativeMethods.THSTensor_set_grad(Handle, grad?.DetachFromDisposeScope().Handle ?? IntPtr.Zero); - CheckForErrors(); + return new Tensor(res); + } + set { + value?.DetachFromDisposeScope(); + NativeMethods.THSTensor_set_grad(Handle, value?.Handle ?? 
IntPtr.Zero); + CheckForErrors(); + } } internal void EncodeIndices(TensorIndex[] indices, diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs index cd59ad5b9..e448d2982 100644 --- a/test/TorchSharpTest/NN.cs +++ b/test/TorchSharpTest/NN.cs @@ -2152,14 +2152,14 @@ public void TestBackward() output.backward(); foreach (var parm in seq.parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.NotNull(grad); } seq.zero_grad(); foreach (var parm in seq.parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.True(grad is null || grad!.count_nonzero().item() == 0); } } @@ -2186,14 +2186,14 @@ public void TestGettingParameters() output.backward(); foreach (var parm in seq.parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.NotNull(grad); } seq.zero_grad(); foreach (var parm in seq.parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.True(grad is null || grad!.count_nonzero().item() == 0); } } @@ -2220,14 +2220,14 @@ public void TestGrad() output.backward(); foreach (var parm in seq.parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.NotNull(grad); } seq.zero_grad(); foreach (var parm in seq.parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.True(grad is null || grad!.count_nonzero().item() == 0); } } @@ -2254,9 +2254,9 @@ public void TestGrad2() output.backward(); - var scalerGrad = scaler.grad(); - var weightGrad = linear.weight.grad(); - var biasGrad = linear.bias.grad(); + var scalerGrad = scaler.grad; + var weightGrad = linear.weight.grad; + var biasGrad = linear.bias.grad; Assert.True(scalerGrad is not null && scalerGrad.shape.Length == 2); Assert.True(weightGrad is not null && weightGrad.shape.Length == 2); Assert.True(biasGrad is not null && biasGrad.shape.Length == 2); @@ -2328,7 +2328,7 @@ public void TestGradConditional() var gradCounts = 0; foreach (var (name, parm) in modT.named_parameters()) { - var grad = parm.grad(); + var grad = parm.grad; gradCounts += grad is not null ? (grad.Handle == IntPtr.Zero ? 0 : 1) : 0; } @@ -2346,7 +2346,7 @@ public void TestGradConditional() gradCounts = 0; foreach (var parm in modF.parameters()) { - var grad = parm.grad(); + var grad = parm.grad; gradCounts += grad is not null ? (grad.Handle == IntPtr.Zero ? 
0 : 1) : 0; } @@ -2839,14 +2839,14 @@ public void TestCustomModule1() output.backward(); foreach (var (pName, parm) in module.named_parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.NotNull(grad); } module.zero_grad(); foreach (var (pName, parm) in module.named_parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.True(grad is null || grad!.count_nonzero().item() == 0); } @@ -3016,7 +3016,7 @@ public void TestDerivedSequence1Grad() output.backward(); foreach (var parm in seq.parameters()) { - var grad = parm.grad(); + var grad = parm.grad; } } @@ -3037,7 +3037,7 @@ public void TestDerivedSequence2Grad() output.backward(); foreach (var parm in seq.parameters()) { - var grad = parm.grad(); + var grad = parm.grad; } } @@ -3121,7 +3121,7 @@ public void TestCustomModuleWithDeviceMove() var y = torch.randn(2, device: torch.CUDA); torch.nn.functional.mse_loss(module.call(x), y).backward(); foreach (var (pName, parm) in module.named_parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.NotNull(grad); } @@ -3134,7 +3134,7 @@ public void TestCustomModuleWithDeviceMove() y = torch.randn(2); torch.nn.functional.mse_loss(module.call(x), y).backward(); foreach (var (pName, parm) in module.named_parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.NotNull(grad); } } @@ -3151,7 +3151,7 @@ public void TestCustomModuleWithTypeMove() var y = torch.randn(2, float64); torch.nn.functional.mse_loss(module.call(x), y).backward(); foreach (var (pName, parm) in module.named_parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.NotNull(grad); } @@ -3164,7 +3164,7 @@ public void TestCustomModuleWithTypeMove() y = torch.randn(2); torch.nn.functional.mse_loss(module.call(x), y).backward(); foreach (var (pName, parm) in module.named_parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.NotNull(grad); } } @@ -3180,7 +3180,7 @@ public void TestCustomModuleWithDeviceAndTypeMove() var y = torch.randn(2, float16, torch.CUDA); torch.nn.functional.mse_loss(module.call(x), y).backward(); foreach (var (pName, parm) in module.named_parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.NotNull(grad); } @@ -3193,7 +3193,7 @@ public void TestCustomModuleWithDeviceAndTypeMove() y = torch.randn(2); torch.nn.functional.mse_loss(module.call(x), y).backward(); foreach (var (pName, parm) in module.named_parameters()) { - var grad = parm.grad(); + var grad = parm.grad; Assert.NotNull(grad); } } diff --git a/test/TorchSharpTest/TestAutogradFunction.cs b/test/TorchSharpTest/TestAutogradFunction.cs index 2c137116f..f778d45d2 100644 --- a/test/TorchSharpTest/TestAutogradFunction.cs +++ b/test/TorchSharpTest/TestAutogradFunction.cs @@ -21,8 +21,8 @@ private void TestCustomLinearFunction(Device device, bool requires_grad) y.sum().backward(); - Assert.NotNull(x.grad()); - Assert.NotNull(weight.grad()); + Assert.NotNull(x.grad); + Assert.NotNull(weight.grad); } @@ -37,9 +37,9 @@ private void TestCustomTwoInputLinearFunction(Device device, bool requires_grad) (y[0].sum() + y[1].sum()).backward(); - Assert.NotNull(x1.grad()); - Assert.NotNull(x2.grad()); - Assert.NotNull(weight.grad()); + Assert.NotNull(x1.grad); + Assert.NotNull(x2.grad); + Assert.NotNull(weight.grad); } private void TestCustomTwoInputOneGradientLinearFunction(Device device, bool requires_grad) @@ -53,9 +53,9 @@ private void TestCustomTwoInputOneGradientLinearFunction(Device device, bool req (y[0].sum() + y[1].sum()).backward(); - 
Assert.NotNull(x1.grad()); - Assert.NotNull(x2.grad()); - Assert.Null(weight.grad()); + Assert.NotNull(x1.grad); + Assert.NotNull(x2.grad); + Assert.Null(weight.grad); } private float TrainXOR(Device device) @@ -203,8 +203,8 @@ private void TestCustomLinearFunctionWithGC() y.sum().backward(); - Assert.NotNull(x.grad()); - Assert.NotNull(weight.grad()); + Assert.NotNull(x.grad); + Assert.NotNull(weight.grad); } [Fact] @@ -214,7 +214,7 @@ private void TestBackwardWithPartialGradInput() var y = MulConstantFunction.apply(x, 2.0); y.sum().backward(); - Assert.NotNull(x.grad()); + Assert.NotNull(x.grad); } class MulConstantFunction : torch.autograd.SingleTensorFunction diff --git a/test/TorchSharpTest/TestNNUtils.cs b/test/TorchSharpTest/TestNNUtils.cs index ddf35bbfe..e01fe3bdf 100644 --- a/test/TorchSharpTest/TestNNUtils.cs +++ b/test/TorchSharpTest/TestNNUtils.cs @@ -80,10 +80,10 @@ public void TestAutoGradBackward1() var y = x1.pow(2) + 5 * x2; torch.autograd.backward(new[] { y }, new[] { torch.ones_like(y) }); - Assert.Equal(x1.shape, x1.grad().shape); - Assert.Equal(x2.shape, x2.grad().shape); - Assert.Equal(2.0f*x1.item(), x1.grad().item()); - Assert.Equal(5.0f, x2.grad().item()); + Assert.Equal(x1.shape, x1.grad.shape); + Assert.Equal(x2.shape, x2.grad.shape); + Assert.Equal(2.0f*x1.item(), x1.grad.item()); + Assert.Equal(5.0f, x2.grad.item()); } [Fact] @@ -96,10 +96,10 @@ public void TestAutoGradBackward2() var y = x1.pow(2) + 5 * x2; y.backward(new[] { torch.ones_like(y) }); - Assert.Equal(x1.shape, x1.grad().shape); - Assert.Equal(x2.shape, x2.grad().shape); - Assert.Equal(2.0f * x1.item(), x1.grad().item()); - Assert.Equal(5.0f, x2.grad().item()); + Assert.Equal(x1.shape, x1.grad.shape); + Assert.Equal(x2.shape, x2.grad.shape); + Assert.Equal(2.0f * x1.item(), x1.grad.item()); + Assert.Equal(5.0f, x2.grad.item()); } [Fact] diff --git a/test/TorchSharpTest/TestTorchTensor.cs b/test/TorchSharpTest/TestTorchTensor.cs index 60f06d812..9bbbb2a34 100644 --- a/test/TorchSharpTest/TestTorchTensor.cs +++ b/test/TorchSharpTest/TestTorchTensor.cs @@ -4429,7 +4429,7 @@ public void AutoGradMode() Assert.True(torch.is_grad_enabled()); var sum = x.sum(); sum.backward(); - var grad = x.grad(); + var grad = x.grad; Assert.False(grad is null || grad.Handle == IntPtr.Zero); var data = grad is not null ? grad.data().ToArray() : new float[] { }; for (int i = 0; i < 2 * 3; i++) { @@ -4448,7 +4448,7 @@ public void AutoGradMode() Assert.True(torch.is_grad_enabled()); var sum = x.sum(); sum.backward(); - var grad = x.grad(); + var grad = x.grad; Assert.False(grad is not null && grad.Handle == IntPtr.Zero); var data = grad is not null ? 
grad.data().ToArray() : new float[] { }; for (int i = 0; i < 2 * 3; i++) { diff --git a/test/TorchSharpTest/TestTorchTensorBugs.cs b/test/TorchSharpTest/TestTorchTensorBugs.cs index e2b65ac13..7c3a44fcc 100644 --- a/test/TorchSharpTest/TestTorchTensorBugs.cs +++ b/test/TorchSharpTest/TestTorchTensorBugs.cs @@ -571,10 +571,10 @@ public void ValidateIssue516() loss.backward(); optimizer.step(); - var grad1 = optimizer.parameters().ToArray()[0].grad(); + var grad1 = optimizer.parameters().ToArray()[0].grad; Assert.NotNull(grad1); - var grad2 = model.Weight.grad(); + var grad2 = model.Weight.grad; Assert.NotNull(grad2); } } @@ -1198,7 +1198,7 @@ public void Validate1116_1() example.backward(); - var grads = x1.grad(); + var grads = x1.grad; Assert.True(x1.requires_grad); Assert.NotNull(grads); } @@ -1217,7 +1217,7 @@ public void Validate1116_2() example.backward(); - var grads = x1.grad(); + var grads = x1.grad; Assert.True(x1.requires_grad); Assert.NotNull(grads); } @@ -1235,7 +1235,7 @@ public void Validate1116_3() example.backward(); - var grads = x1.grad(); + var grads = x1.grad; Assert.True(x1.requires_grad); Assert.NotNull(grads); } @@ -1253,7 +1253,7 @@ public void Validate1116_4() example.backward(); - var grads = x1.grad(); + var grads = x1.grad; Assert.True(x1.requires_grad); Assert.NotNull(grads); } @@ -1272,7 +1272,7 @@ public void Validate1116_5() example.backward(); - var grads = x1.grad(); + var grads = x1.grad; Assert.True(x1.requires_grad); Assert.NotNull(grads); } @@ -1291,7 +1291,7 @@ public void Validate1116_6() example.backward(); - var grads = x1.grad(); + var grads = x1.grad; Assert.True(x1.requires_grad); Assert.NotNull(grads); } @@ -1517,14 +1517,14 @@ public void Validate_1191_1() // Build graph 1 on CUDA torch.nn.functional.mse_loss(module.forward(torch.rand(10).cuda()), torch.rand(10).cuda()).backward(); - Assert.Equal(DeviceType.CUDA, module.ln.weight!.grad()!.device_type); - Assert.Equal(DeviceType.CUDA, module.p.grad()!.device_type); + Assert.Equal(DeviceType.CUDA, module.ln.weight!.grad!.device_type); + Assert.Equal(DeviceType.CUDA, module.p.grad!.device_type); // Move to CPU module.to(torch.CPU); - Assert.Equal(DeviceType.CPU, module.ln.weight!.grad()!.device_type); - Assert.Equal(DeviceType.CPU, module.p.grad()!.device_type); + Assert.Equal(DeviceType.CPU, module.ln.weight!.grad!.device_type); + Assert.Equal(DeviceType.CPU, module.p.grad!.device_type); // Build graph 2 on CPU. // This should've crashed, saying something about the gradients being on the wrong device. 
@@ -1545,7 +1545,7 @@ public void Validate_1191_2() var resultBatch = rand(32, 1).to(aDevice); aModule.to(aDevice); foreach (var (name, p) in aModule.named_parameters()) { - Console.WriteLine($"{name} {p.device} {p.grad()}"); + Console.WriteLine($"{name} {p.device} {p.grad}"); } var aMseLoss = nn.MSELoss(); var optimizer = torch.optim.AdamW(aModule.parameters()); @@ -1570,7 +1570,7 @@ public void Validate_1191_2() aModule.to(aDevice); aModule.zero_grad(); foreach (var (name, p) in aModule.named_parameters()) { - Console.WriteLine($"{name} {p.device} {p.grad()}"); + Console.WriteLine($"{name} {p.device} {p.grad}"); } var aMseLoss = nn.MSELoss(); var optimizer = torch.optim.AdamW(aModule.parameters()); @@ -1600,18 +1600,18 @@ public void Validate_1191_3() module.zero_grad(); - Assert.Null(module.p.grad()); - Assert.Null(module.ln.weight!.grad()); - Assert.Null(module.ln.bias!.grad()); + Assert.Null(module.p.grad); + Assert.Null(module.ln.weight!.grad); + Assert.Null(module.ln.bias!.grad); // Build graph again, this time convert gradients to zero torch.nn.functional.mse_loss(module.forward(torch.rand(10)), torch.rand(10)).backward(); module.zero_grad(false); - Assert.NotNull(module.p.grad()); - Assert.NotNull(module.ln.weight!.grad()); - Assert.NotNull(module.ln.bias!.grad()); + Assert.NotNull(module.p.grad); + Assert.NotNull(module.ln.weight!.grad); + Assert.NotNull(module.ln.bias!.grad); } diff --git a/test/TorchSharpTest/TestTraining.cs b/test/TorchSharpTest/TestTraining.cs index 877a68ffd..1ff1336a2 100644 --- a/test/TorchSharpTest/TestTraining.cs +++ b/test/TorchSharpTest/TestTraining.cs @@ -55,7 +55,7 @@ public void Training1() using (torch.no_grad()) { foreach (var param in seq.parameters()) { - var grad = param.grad(); + var grad = param.grad; if (grad is not null) { var update = grad.mul(learning_rate); param.sub_(update); @@ -99,7 +99,7 @@ public void TrainingWithDropout() using (torch.no_grad()) { foreach (var param in seq.parameters()) { - var grad = param.grad(); + var grad = param.grad; if (grad is not null) { var update = grad.mul(learning_rate); param.sub_(update);
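For context on the RELEASENOTES.md entry above, here is a minimal editorial sketch of the migration from the removed `Tensor.grad()` / `Tensor.set_grad()` methods to the new read/write `Tensor.grad` property. The variable names are hypothetical, the snippet is not part of the patch, and it assumes the usual TorchSharp factory overloads (e.g. `randn(..., requires_grad: true)`):

// Editorial usage sketch for the #1291 change (hypothetical names; not part of the patch).
using TorchSharp;
using static TorchSharp.torch;

var x = randn(2, 3, requires_grad: true);
var loss = (x * x).sum();
loss.backward();

// Previously: var g = x.grad();   ...   x.set_grad(null);
// Now:
var g = x.grad;                 // property read; null until backward() has produced a gradient
if (g is not null) {
    using (no_grad()) {
        x.sub_(g.mul(0.01));    // manual update, in the style of the TestTraining.cs loop above
    }
}
x.grad = null;                  // property write replaces set_grad(); zero_grad(set_to_none: true) does this per parameter

As the Tensor.cs hunk shows, the property setter detaches the assigned tensor from any enclosing dispose scope before passing its handle to native code, and assigning null clears the gradient just as the old `set_grad(null)` path did.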