Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions src/Native/LibTorchSharp/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,3 @@ const char * make_sharable_string(const std::string str)
return result;
}

Tensor ResultTensor(const at::Tensor& res)
{
if (res.defined())
return new torch::Tensor(res);
else
return NULL;
}
8 changes: 7 additions & 1 deletion src/Native/LibTorchSharp/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,13 @@ typedef std::shared_ptr<torch::optim::Optimizer> * Optimizer;
#define CATCH_RETURN_Tensor(stmt) CATCH_RETURN_RES(Tensor, NULL, stmt)

// Return undefined tensors as NULL to C#
Tensor ResultTensor(const at::Tensor & res);
inline Tensor ResultTensor(const at::Tensor & res)
{
if (res.defined())
return new torch::Tensor(res);
else
return NULL;
}

#define CATCH_TENSOR(expr) \
at::Tensor res = at::Tensor(); \
Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/NN/AdaptiveAvgPool2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ internal AdaptiveAvgPool2D (IntPtr handle, IntPtr boxedHandle) : base (handle, b
public TorchTensor Forward (TorchTensor tensor)
{
var res = THSNN_AdaptiveAvgPool2d_forward (handle.DangerousGetHandle (), tensor.Handle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
}
}
Expand All @@ -34,7 +34,7 @@ static public AdaptiveAvgPool2D AdaptiveAvgPool2D (long[] kernelSize)
unsafe {
fixed (long* pkernelSize = kernelSize) {
var handle = THSNN_AdaptiveAvgPool2d_ctor ((IntPtr)pkernelSize, kernelSize.Length, out var boxedHandle);
Torch.CheckForErrors ();
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
return new AdaptiveAvgPool2D (handle, boxedHandle);
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/NN/AvgPool2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ internal AvgPool2D (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHand
public TorchTensor Forward (TorchTensor tensor)
{
var res = THSNN_AvgPool2d_forward (handle.DangerousGetHandle (), tensor.Handle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
}
}
Expand All @@ -34,7 +34,7 @@ static public AvgPool2D AvgPool2D (long[] kernelSize, long[] strides = null)
unsafe {
fixed (long* pkernelSize = kernelSize, pstrides = strides) {
var handle = THSNN_AvgPool2d_ctor ((IntPtr)pkernelSize, kernelSize.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length), out var boxedHandle);
Torch.CheckForErrors ();
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
return new AvgPool2D (handle, boxedHandle);
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/NN/Conv2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal Conv2D (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle)
public TorchTensor Forward (TorchTensor tensor)
{
var res = THSNN_Conv2d_forward (handle, tensor.Handle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
}
}
Expand All @@ -27,7 +27,7 @@ public static partial class Modules
static public Conv2D Conv2D (long inputChannel, long outputChannel, long kernelSize, long stride = 1, long padding = 0)
{
var res = THSNN_Conv2d_ctor (inputChannel, outputChannel, kernelSize, stride, padding, out var boxedHandle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new Conv2D (res, boxedHandle);
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/NN/Dropout.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ internal Dropout (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle
public TorchTensor Forward (TorchTensor tensor)
{
var res = THSNN_Dropout_forward (handle, tensor.Handle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
}
}
Expand All @@ -30,7 +30,7 @@ public static partial class Modules
static public Dropout Dropout (double probability = 0.5)
{
var handle = THSNN_Dropout_ctor (probability, out var boxedHandle);
Torch.CheckForErrors ();
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
return new Dropout (handle, boxedHandle);
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/NN/FeatureDropout.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ internal FeatureAlphaDropout (IntPtr handle, IntPtr boxedHandle) : base (handle,
public TorchTensor Forward (TorchTensor tensor)
{
var res = THSNN_FeatureAlphaDropout_forward (handle, tensor.Handle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
}
}
Expand All @@ -32,7 +32,7 @@ public static partial class Modules
static public FeatureAlphaDropout FeatureAlphaDropout (double probability = 0.5)
{
var handle = THSNN_FeatureAlphaDropout_ctor (probability, out var boxedHandle);
Torch.CheckForErrors ();
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
return new FeatureAlphaDropout (handle, boxedHandle);
}
}
Expand Down
9 changes: 4 additions & 5 deletions src/TorchSharp/NN/Linear.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ internal Linear (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle)
public new static Linear Load (String modelPath)
{
var res = Module.Load (modelPath);
Torch.CheckForErrors ();
return new Linear (res.handle.DangerousGetHandle(), IntPtr.Zero);
}

Expand All @@ -25,7 +24,7 @@ internal Linear (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle)
public TorchTensor Forward (TorchTensor tensor)
{
var res = THSNN_Linear_forward (handle, tensor.Handle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
}
[DllImport ("LibTorchSharp")]
Expand All @@ -36,7 +35,7 @@ public TorchTensor Forward (TorchTensor tensor)
public TorchTensor? Bias {
get {
var res = THSNN_Linear_bias (handle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return ((res == IntPtr.Zero) ? null : new TorchTensor (res));
}
set {
Expand All @@ -52,7 +51,7 @@ public TorchTensor? Bias {
public TorchTensor Weight {
get {
var res = THSNN_Linear_weight (handle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
}
set {
Expand All @@ -69,7 +68,7 @@ public static partial class Modules
static public Linear Linear (long inputSize, long outputSize, bool hasBias = true)
{
var res = THSNN_Linear_ctor (inputSize, outputSize, hasBias, out var boxedHandle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new Linear (res, boxedHandle);
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/NN/LogSoftMax.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ internal LogSoftMax (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHan
public TorchTensor Forward (TorchTensor tensor)
{
var res = THSNN_LogSoftMax_forward (handle, tensor.Handle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
}
}
Expand All @@ -32,7 +32,7 @@ public static partial class Modules
static public LogSoftMax LogSoftMax (long dimension)
{
var handle = THSNN_LogSoftMax_ctor (dimension, out var boxedHandle);
Torch.CheckForErrors ();
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
return new LogSoftMax (handle, boxedHandle);
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/TorchSharp/NN/Losses.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public static Loss BCE (TorchTensor? weigths = null, Reduction reduction = Reduc
{
return (TorchTensor src, TorchTensor target) => {
var res = THSNN_binary_cross_entropy (src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
};
}
Expand All @@ -34,7 +34,7 @@ public static Loss MSE (Reduction reduction = Reduction.Mean)
{
return (TorchTensor src, TorchTensor target) => {
var res = THSNN_mse_loss (src.Handle, target.Handle, (long)reduction);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
};
}
Expand All @@ -46,7 +46,7 @@ public static Loss NLL (TorchTensor? weigths = null, Reduction reduction = Reduc
{
return (TorchTensor src, TorchTensor target) => {
var res = THSNN_nll_loss (src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
};
}
Expand All @@ -58,7 +58,7 @@ public static Loss PoissonNLL (bool logInput = true, bool full = false, float ep
{
return (TorchTensor src, TorchTensor target) => {
var res = THSNN_poisson_loss (src.Handle, target.Handle, logInput, full, eps, (long)reduction);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
};
}
Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/NN/MaxPool2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ internal MaxPool2D (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHand
public TorchTensor Forward (TorchTensor tensor)
{
var res = THSNN_MaxPool2d_forward (handle, tensor.Handle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
}
}
Expand All @@ -34,7 +34,7 @@ static public MaxPool2D MaxPool2D (long[] kernelSize, long[] strides = null)
unsafe {
fixed (long* pkernelSize = kernelSize, pstrides = strides) {
var handle = THSNN_MaxPool2d_ctor ((IntPtr)pkernelSize, kernelSize.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length), out var boxedHandle);
Torch.CheckForErrors ();
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
return new MaxPool2D (handle, boxedHandle);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/TorchSharp/NN/Module.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ protected void Dispose (bool disposing)
public static Module Load(String location)
{
var handle = THSNN_Module_load (location);
Torch.CheckForErrors ();
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
return new Module (handle, IntPtr.Zero);
}

Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/NN/Optimizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public static Optimizer Adam (IEnumerable<TorchTensor> parameters, double learni
IntPtr paramsRef = parray.CreateArray (parameters.Select (p => p.Handle).ToArray ());

var res = THSNN_Adam_ctor (paramsRef, parray.Array.Length, learningRate);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new Optimizer (res);
}

Expand All @@ -100,7 +100,7 @@ public static Optimizer SGD (IEnumerable<TorchTensor> parameters, double learnin
IntPtr paramsRef = parray.CreateArray (parameters.Select (p => p.Handle).ToArray ());

var res = THSNN_SGD_ctor (paramsRef, parray.Array.Length, learningRate, momentum);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new Optimizer (res);
}

Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/NN/ReLu.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ internal ReLU (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle) {
public TorchTensor Forward (TorchTensor tensor)
{
var res = THSNN_ReLU_forward (handle, tensor.Handle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
}

Expand All @@ -36,7 +36,7 @@ public static partial class Modules
static public ReLU Relu (bool inPlace = false)
{
var handle = THSNN_ReLU_ctor (inPlace, out var boxedHandle);
Torch.CheckForErrors ();
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
return new ReLU (handle, boxedHandle);
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/NN/Sequential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ internal Sequential (IntPtr handle) : base (handle, IntPtr.Zero)
public TorchTensor Forward (TorchTensor tensor)
{
var res = THSNN_Sequential_forward (handle, tensor.Handle);
Torch.CheckForErrors ();
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
return new TorchTensor (res);
}

Expand All @@ -52,7 +52,7 @@ public static partial class Modules
static public Sequential Sequential (params (string name, Module submodule)[] modules)
{
var handle = THSNN_Sequential_ctor ();
Torch.CheckForErrors ();
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
var res = new Sequential (handle);
foreach (var module in modules)
res.Add(module.name, module.submodule);
Expand Down
Loading