Closed
Description
TorchSharp/src/TorchSharp/NN/Losses.cs
Line 422 in ba2fa75
if ((variance < 0).any().item<bool>()) throw new ArgumentException("variance has negative entry/entries");
I had to change
if ((variance < 0).any().item<bool>())
throw new ArgumentException("variance has negative entry/entries");
into
if ((variance < 0).any().to(DeviceType.CPU).item<bool>())
throw new ArgumentException("variance has negative entry/entries");
to make it work on a GPU, because item() could not be extracted from the tensor unless the tensor resides on the CPU.
However, looking at this code, I think the whole check is unnecessary (a variance below zero gets clamped to eps on the next line) as well as slow: it requires synchronizing GPU state and copying data back to the CPU, which adds latency.
Also,
variance = variance.clone().maximum(torch.tensor(eps))
clones the tensor unnecessarily (maximum is supposed to return a new tensor anyway) and allocates an extra tensor for eps.
I suggest replacing that with:
variance = variance.clamp_min(eps);
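For comparison, here is a minimal sketch of the two variants side by side. The class and method names (VarianceSketch, ClampWithMaximum, ClampWithClampMin) are made up for illustration and are not part of TorchSharp. The sketch assumes the negative-value check is dropped, so negative variances are silently clamped instead of throwing.

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Hypothetical helper for illustration only; not part of TorchSharp.
public static class VarianceSketch
{
    // Approach currently in Losses.cs (without the negative-value check):
    // clones the input, allocates a tensor for eps, then takes the
    // element-wise maximum, which allocates the result tensor as well.
    public static Tensor ClampWithMaximum(Tensor variance, double eps)
        => variance.clone().maximum(torch.tensor(eps));

    // Proposed approach: a single out-of-place clamp with a scalar bound.
    // The input tensor is left untouched and no value is read back to the
    // host, so it should not force a device-to-CPU transfer.
    public static Tensor ClampWithClampMin(Tensor variance, double eps)
        => variance.clamp_min(eps);
}
```

Both methods should return the same values for the same input, for example for a float tensor holding -1, 0 and 0.5 with eps = 1e-6, and neither modifies the tensor passed in.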