diff --git a/src/Native/LibTorchSharp/THSConvolution.cpp b/src/Native/LibTorchSharp/THSConvolution.cpp
index e1500d939..fb51e4c91 100644
--- a/src/Native/LibTorchSharp/THSConvolution.cpp
+++ b/src/Native/LibTorchSharp/THSConvolution.cpp
@@ -408,218 +408,6 @@ Tensor THSNN_FractionalMaxPool3d_forward_with_indices(const NNModule module, con
     return ResultTensor(std::get<0>(res));
 }
 
-NNModule THSNN_ZeroPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ZeroPad2dOptions(padding);
-        res = create_module<torch::nn::ZeroPad2dImpl>(opts, outAsAnyModule);
-    );
-}
-
-NNModule THSNN_ZeroPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ZeroPad2dOptions({ padding_left, padding_right, padding_top, padding_bottom });
-        res = create_module<torch::nn::ZeroPad2dImpl>(opts, outAsAnyModule);
-    );
-}
-
-Tensor THSNN_ZeroPad2d_forward(const NNModule module, const Tensor tensor)
-{
-    CATCH_TENSOR((*module)->as<torch::nn::ZeroPad2d>()->forward(*tensor));
-}
-
-NNModule THSNN_ConstantPad1d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ConstantPad1dOptions(padding, value);
-        res = create_module<torch::nn::ConstantPad1dImpl>(opts, outAsAnyModule);
-    );
-}
-
-Tensor THSNN_ConstantPad1d_forward(const NNModule module, const Tensor tensor)
-{
-    CATCH_TENSOR((*module)->as<torch::nn::ConstantPad1d>()->forward(*tensor));
-}
-
-NNModule THSNN_ConstantPad2d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ConstantPad2dOptions(padding, value);
-        res = create_module<torch::nn::ConstantPad2dImpl>(opts, outAsAnyModule);
-    );
-}
-
-Tensor THSNN_ConstantPad2d_forward(const NNModule module, const Tensor tensor)
-{
-    CATCH_TENSOR((*module)->as<torch::nn::ConstantPad2d>()->forward(*tensor));
-}
-
-NNModule THSNN_ConstantPad3d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ConstantPad3dOptions(padding, value);
-        res = create_module<torch::nn::ConstantPad3dImpl>(opts, outAsAnyModule);
-    );
-}
-
-Tensor THSNN_ConstantPad3d_forward(const NNModule module, const Tensor tensor)
-{
-    CATCH_TENSOR((*module)->as<torch::nn::ConstantPad3d>()->forward(*tensor));
-}
-
-NNModule THSNN_ConstantPad1d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ConstantPad1dOptions({ padding_left, padding_right }, value);
-        res = create_module<torch::nn::ConstantPad1dImpl>(opts, outAsAnyModule);
-    );
-}
-
-NNModule THSNN_ConstantPad2d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ConstantPad2dOptions({ padding_left, padding_right, padding_top, padding_bottom }, value);
-        res = create_module<torch::nn::ConstantPad2dImpl>(opts, outAsAnyModule);
-    );
-}
-
-NNModule THSNN_ConstantPad3d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ConstantPad3dOptions({ padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back }, value);
-        res = create_module<torch::nn::ConstantPad3dImpl>(opts, outAsAnyModule);
-    );
-}
-
-NNModule THSNN_ReplicationPad1d_ctor(const int64_t padding,
-    NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ReplicationPad1dOptions(padding);
-        res = create_module<torch::nn::ReplicationPad1dImpl>(opts, outAsAnyModule);
-    );
-}
-
-Tensor THSNN_ReplicationPad1d_forward(const NNModule module, const Tensor tensor)
-{
-    CATCH_TENSOR((*module)->as<torch::nn::ReplicationPad1d>()->forward(*tensor));
-}
-
-NNModule THSNN_ReplicationPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ReplicationPad2dOptions(padding);
-        res = create_module<torch::nn::ReplicationPad2dImpl>(opts, outAsAnyModule);
-    );
-}
-
-Tensor THSNN_ReplicationPad2d_forward(const NNModule module, const Tensor tensor)
-{
-    CATCH_TENSOR((*module)->as<torch::nn::ReplicationPad2d>()->forward(*tensor));
-}
-
-NNModule THSNN_ReplicationPad3d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ReplicationPad3dOptions(padding);
-        res = create_module<torch::nn::ReplicationPad3dImpl>(opts, outAsAnyModule);
-    );
-}
-
-
-Tensor THSNN_ReplicationPad3d_forward(const NNModule module, const Tensor tensor)
-{
-    CATCH_TENSOR((*module)->as<torch::nn::ReplicationPad3d>()->forward(*tensor));
-}
-
-NNModule THSNN_ReplicationPad1d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ReplicationPad1dOptions({ padding_left, padding_right });
-        res = create_module<torch::nn::ReplicationPad1dImpl>(opts, outAsAnyModule);
-    );
-}
-
-NNModule THSNN_ReplicationPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ReplicationPad2dOptions({ padding_left, padding_right, padding_top, padding_bottom });
-        res = create_module<torch::nn::ReplicationPad2dImpl>(opts, outAsAnyModule);
-    );
-}
-
-NNModule THSNN_ReplicationPad3d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ReplicationPad3dOptions({ padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back });
-        res = create_module<torch::nn::ReplicationPad3dImpl>(opts, outAsAnyModule);
-    );
-}
-
-NNModule THSNN_ReflectionPad1d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ReflectionPad1dOptions(padding);
-        res = create_module<torch::nn::ReflectionPad1dImpl>(opts, outAsAnyModule);
-    );
-}
-
-Tensor THSNN_ReflectionPad1d_forward(const NNModule module, const Tensor tensor)
-{
-    CATCH_TENSOR((*module)->as<torch::nn::ReflectionPad1d>()->forward(*tensor));
-}
-
-NNModule THSNN_ReflectionPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ReflectionPad2dOptions(padding);
-        res = create_module<torch::nn::ReflectionPad2dImpl>(opts, outAsAnyModule);
-    );
-}
-
-Tensor THSNN_ReflectionPad2d_forward(const NNModule module, const Tensor tensor)
-{
-    CATCH_TENSOR((*module)->as<torch::nn::ReflectionPad2d>()->forward(*tensor));
-}
-
-NNModule THSNN_ReflectionPad3d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ReflectionPad3dOptions(padding);
-        res = create_module<torch::nn::ReflectionPad3dImpl>(opts, outAsAnyModule);
-    );
-}
-
-Tensor THSNN_ReflectionPad3d_forward(const NNModule module, const Tensor tensor)
-{
-    CATCH_TENSOR((*module)->as<torch::nn::ReflectionPad3d>()->forward(*tensor));
-}
-
-NNModule THSNN_ReflectionPad1d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ReflectionPad1dOptions({ padding_left,
-            padding_right });
-        res = create_module<torch::nn::ReflectionPad1dImpl>(opts, outAsAnyModule);
-    );
-}
-
-NNModule THSNN_ReflectionPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ReflectionPad2dOptions({ padding_left, padding_right, padding_top, padding_bottom });
-        res = create_module<torch::nn::ReflectionPad2dImpl>(opts, outAsAnyModule);
-    );
-}
-
-NNModule THSNN_ReflectionPad3d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::ReflectionPad3dOptions({ padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back });
-        res = create_module<torch::nn::ReflectionPad3dImpl>(opts, outAsAnyModule);
-    );
-}
-
 template<typename T>
 void ApplyPaddingMode(T& opts, const int64_t padding)
 {
diff --git a/src/Native/LibTorchSharp/THSNN.cpp b/src/Native/LibTorchSharp/THSNN.cpp
index 6f9ca1742..b522c6a74 100644
--- a/src/Native/LibTorchSharp/THSNN.cpp
+++ b/src/Native/LibTorchSharp/THSNN.cpp
@@ -16,40 +16,6 @@ Tensor THSNN_Identity_forward(const NNModule module, const Tensor tensor)
     CATCH_TENSOR((*module)->as<torch::nn::Identity>()->forward(*tensor));
 }
 
-NNModule THSNN_Linear_ctor(const int64_t input_size, const int64_t output_size, const bool bias,
-    NNAnyModule* outAsAnyModule)
-{
-    CATCH_RETURN_NNModule(
-        auto opts = torch::nn::LinearOptions(input_size, output_size).bias(bias);
-        res = create_module<torch::nn::LinearImpl>(opts, outAsAnyModule);
-    );
-}
-
-Tensor THSNN_Linear_forward(const NNModule module, const Tensor tensor)
-{
-    CATCH_TENSOR((*module)->as<torch::nn::Linear>()->forward(*tensor));
-}
-
-Tensor THSNN_Linear_bias(const NNModule module)
-{
-    return get_bias<torch::nn::Linear>(module);
-}
-
-void THSNN_Linear_set_bias(const NNModule module, const Tensor bias)
-{
-    set_bias<torch::nn::Linear>(module, bias);
-}
-
-Tensor THSNN_Linear_weight(const NNModule module)
-{
-    return get_weight<torch::nn::Linear>(module);
-}
-
-void THSNN_Linear_set_weight(const NNModule module, const Tensor weight)
-{
-    set_weight<torch::nn::Linear>(module, weight);
-}
-
 Tensor THSNN_functional_linear(const Tensor input, const Tensor weights, const Tensor bias)
 {
     CATCH_TENSOR(bias == nullptr ?
diff --git a/src/Native/LibTorchSharp/THSNN.h b/src/Native/LibTorchSharp/THSNN.h
index 330ceb1d2..7d7b3fc0f 100644
--- a/src/Native/LibTorchSharp/THSNN.h
+++ b/src/Native/LibTorchSharp/THSNN.h
@@ -94,42 +94,6 @@ EXPORT_API(Tensor) THSNN_LPPool1d_forward(const NNModule module, const Tensor
 EXPORT_API(NNModule) THSNN_LPPool2d_ctor(double norm_type, const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, bool ceil_mode, NNAnyModule* outAsAnyModule);
 EXPORT_API(Tensor) THSNN_LPPool2d_forward(const NNModule module, const Tensor tensor);
 
-// Padding
-
-EXPORT_API(NNModule) THSNN_ZeroPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
-EXPORT_API(NNModule) THSNN_ZeroPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule);
-EXPORT_API(Tensor) THSNN_ZeroPad2d_forward(const NNModule module, const Tensor tensor);
-
-EXPORT_API(NNModule) THSNN_ConstantPad1d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule);
-EXPORT_API(NNModule) THSNN_ConstantPad1d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule);
-EXPORT_API(Tensor) THSNN_ConstantPad1d_forward(const NNModule module, const Tensor tensor);
-EXPORT_API(NNModule) THSNN_ConstantPad2d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule);
-EXPORT_API(NNModule) THSNN_ConstantPad2d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule);
-EXPORT_API(Tensor) THSNN_ConstantPad2d_forward(const NNModule module, const Tensor tensor);
-EXPORT_API(NNModule) THSNN_ConstantPad3d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule);
-EXPORT_API(NNModule) THSNN_ConstantPad3d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule);
-EXPORT_API(Tensor) THSNN_ConstantPad3d_forward(const NNModule module, const Tensor tensor);
-
-EXPORT_API(NNModule) THSNN_ReplicationPad1d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
-EXPORT_API(NNModule) THSNN_ReplicationPad1d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule);
-EXPORT_API(Tensor) THSNN_ReplicationPad1d_forward(const NNModule module, const Tensor tensor);
-EXPORT_API(NNModule) THSNN_ReplicationPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
-EXPORT_API(NNModule) THSNN_ReplicationPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule);
-EXPORT_API(Tensor) THSNN_ReplicationPad2d_forward(const NNModule module, const Tensor tensor);
-EXPORT_API(NNModule) THSNN_ReplicationPad3d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
-EXPORT_API(NNModule) THSNN_ReplicationPad3d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule);
-EXPORT_API(Tensor) THSNN_ReplicationPad3d_forward(const NNModule module, const Tensor tensor);
-
-EXPORT_API(NNModule) THSNN_ReflectionPad1d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
-EXPORT_API(NNModule) THSNN_ReflectionPad1d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule);
-EXPORT_API(Tensor) THSNN_ReflectionPad1d_forward(const NNModule module, const Tensor tensor);
-EXPORT_API(NNModule) THSNN_ReflectionPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
-EXPORT_API(NNModule) THSNN_ReflectionPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule);
-EXPORT_API(Tensor) THSNN_ReflectionPad2d_forward(const NNModule module, const Tensor tensor);
-EXPORT_API(NNModule) THSNN_ReflectionPad3d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
-EXPORT_API(NNModule) THSNN_ReflectionPad3d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule);
-EXPORT_API(Tensor) THSNN_ReflectionPad3d_forward(const NNModule module, const Tensor tensor);
-
 // Convolution
 
 EXPORT_API(NNModule) THSNN_Conv1d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
@@ -325,12 +289,6 @@ EXPORT_API(Tensor) THSNN_unfold(const Tensor input, const int64_t kernel1, const
 EXPORT_API(NNModule) THSNN_Identity_ctor(NNAnyModule* outAsAnyModule);
 EXPORT_API(Tensor) THSNN_Identity_forward(const NNModule module, const Tensor tensor);
 
-EXPORT_API(NNModule) THSNN_Linear_ctor(const int64_t input_size, const int64_t output_size, const bool with_bias, NNAnyModule* outAsAnyModule);
-EXPORT_API(Tensor) THSNN_Linear_forward(const NNModule module, const Tensor tensor);
-EXPORT_API(Tensor) THSNN_Linear_bias(const NNModule module);
-EXPORT_API(void) THSNN_Linear_set_bias(const NNModule module, const Tensor tensor);
-EXPORT_API(Tensor) THSNN_Linear_weight(const NNModule module);
-EXPORT_API(void) THSNN_Linear_set_weight(const NNModule module, const Tensor tensor);
 EXPORT_API(Tensor) THSNN_functional_linear(const Tensor input, const Tensor weights, const Tensor bias);
 EXPORT_API(Tensor) THSNN_functional_bilinear(const Tensor input1, const Tensor input2, const Tensor weights, const Tensor bias);
 
diff --git a/src/TorchSharp/NN/Activation/CELU.cs b/src/TorchSharp/NN/Activation/CELU.cs
index 0af37f6b6..85ba92195 100644
--- a/src/TorchSharp/NN/Activation/CELU.cs
+++ b/src/TorchSharp/NN/Activation/CELU.cs
@@ -31,7 +31,9 @@ public override string GetName()
             // Rather than spending cycles only to discover that this module has neither
             // parameters nor buffers, just shortcut the move completely.
             protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
             protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
             protected internal override nn.Module _to(ScalarType dtype) => this;
         }
     }
diff --git a/src/TorchSharp/NN/Activation/ELU.cs b/src/TorchSharp/NN/Activation/ELU.cs
index 365280f87..4dc818b0a 100644
--- a/src/TorchSharp/NN/Activation/ELU.cs
+++ b/src/TorchSharp/NN/Activation/ELU.cs
@@ -31,7 +31,9 @@ public override string GetName()
             // Rather than spending cycles only to discover that this module has neither
             // parameters nor buffers, just shortcut the move completely.
             protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
             protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
             protected internal override nn.Module _to(ScalarType dtype) => this;
         }
     }
diff --git a/src/TorchSharp/NN/Activation/GELU.cs b/src/TorchSharp/NN/Activation/GELU.cs
index 71dda326e..c069ca639 100644
--- a/src/TorchSharp/NN/Activation/GELU.cs
+++ b/src/TorchSharp/NN/Activation/GELU.cs
@@ -31,7 +31,9 @@ public override string GetName()
             // Rather than spending cycles only to discover that this module has neither
             // parameters nor buffers, just shortcut the move completely.
             protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
             protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
             protected internal override nn.Module _to(ScalarType dtype) => this;
         }
     }
diff --git a/src/TorchSharp/NN/Activation/GLU.cs b/src/TorchSharp/NN/Activation/GLU.cs
index 15e717272..95ca2c25d 100644
--- a/src/TorchSharp/NN/Activation/GLU.cs
+++ b/src/TorchSharp/NN/Activation/GLU.cs
@@ -31,7 +31,9 @@ public override string GetName()
             // Rather than spending cycles only to discover that this module has neither
             // parameters nor buffers, just shortcut the move completely.
             protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
             protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
             protected internal override nn.Module _to(ScalarType dtype) => this;
         }
     }
diff --git a/src/TorchSharp/NN/Activation/Hardshrink.cs b/src/TorchSharp/NN/Activation/Hardshrink.cs
index b3f2684d3..c4f880b72 100644
--- a/src/TorchSharp/NN/Activation/Hardshrink.cs
+++ b/src/TorchSharp/NN/Activation/Hardshrink.cs
@@ -31,7 +31,9 @@ public override string GetName()
             // Rather than spending cycles only to discover that this module has neither
             // parameters nor buffers, just shortcut the move completely.
             protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
             protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
             protected internal override nn.Module _to(ScalarType dtype) => this;
         }
     }
diff --git a/src/TorchSharp/NN/Activation/Hardsigmoid.cs b/src/TorchSharp/NN/Activation/Hardsigmoid.cs
index c4f354cf2..66e0aa816 100644
--- a/src/TorchSharp/NN/Activation/Hardsigmoid.cs
+++ b/src/TorchSharp/NN/Activation/Hardsigmoid.cs
@@ -33,7 +33,9 @@ public override string GetName()
             // Rather than spending cycles only to discover that this module has neither
             // parameters nor buffers, just shortcut the move completely.
             protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
             protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
             protected internal override nn.Module _to(ScalarType dtype) => this;
         }
     }
diff --git a/src/TorchSharp/NN/Activation/Hardswish.cs b/src/TorchSharp/NN/Activation/Hardswish.cs
index 6f8b401da..d1065be3b 100644
--- a/src/TorchSharp/NN/Activation/Hardswish.cs
+++ b/src/TorchSharp/NN/Activation/Hardswish.cs
@@ -33,7 +33,9 @@ public override string GetName()
             // Rather than spending cycles only to discover that this module has neither
             // parameters nor buffers, just shortcut the move completely.
             protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
             protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
             protected internal override nn.Module _to(ScalarType dtype) => this;
         }
     }
diff --git a/src/TorchSharp/NN/Activation/Hardtanh.cs b/src/TorchSharp/NN/Activation/Hardtanh.cs
index f7236a72a..9451c2132 100644
--- a/src/TorchSharp/NN/Activation/Hardtanh.cs
+++ b/src/TorchSharp/NN/Activation/Hardtanh.cs
@@ -31,7 +31,9 @@ public override string GetName()
             // Rather than spending cycles only to discover that this module has neither
             // parameters nor buffers, just shortcut the move completely.
             protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
             protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
             protected internal override nn.Module _to(ScalarType dtype) => this;
         }
     }
diff --git a/src/TorchSharp/NN/Bilinear.cs b/src/TorchSharp/NN/Bilinear.cs
index 316f81ea0..17a9dc903 100644
--- a/src/TorchSharp/NN/Bilinear.cs
+++ b/src/TorchSharp/NN/Bilinear.cs
@@ -13,46 +13,56 @@ namespace Modules
     {
         public sealed class Bilinear : Module<Tensor, Tensor, Tensor>
         {
-            internal Bilinear(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
-
-            public new static Bilinear Load(string modelPath)
+            internal Bilinear(long in1_features, long in2_features, long out_features, bool hasBias = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Bilinear))
             {
-                var res = Module.Load(modelPath);
-                return new Bilinear(res.handle.DangerousGetHandle(), IntPtr.Zero);
+                weight = torch.empty(out_features, in1_features, in2_features, device: device, dtype: dtype).AsParameter();
+                var bound = 1 / Math.Sqrt(weight!.shape[1]);
+
+                init.uniform_(_weight, -bound, bound);
+
+                if (hasBias) {
+                    bias = torch.empty(out_features, device: device, dtype: dtype).AsParameter();
+                    init.uniform_(_bias, -bound, bound);
+                }
+                //NOTE: it's important not to call 'RegisterComponents' here.
             }
 
             public override Tensor forward(Tensor input1, Tensor input2)
             {
-                var res = THSNN_Bilinear_forward(handle, input1.Handle, input2.Handle);
-                if (res == IntPtr.Zero) { CheckForErrors(); }
-                return new Tensor(res);
+                return torch.nn.functional.bilinear(input1, input2, _weight!, _bias);
             }
 
-            public Parameter? bias {
-                get {
-                    var res = THSNN_Bilinear_bias(handle);
-                    if (res == IntPtr.Zero) { CheckForErrors(); }
-                    return ((res == IntPtr.Zero) ? null : new Parameter(res));
+            protected override void Dispose(bool disposing)
+            {
+                if (disposing) {
+                    _weight?.Dispose();
+                    _bias?.Dispose();
                 }
+            }
+
+            public Parameter? bias {
+                get => _bias;
                 set {
-                    THSNN_Bilinear_set_bias(handle, value?.Handle ?? IntPtr.Zero);
-                    CheckForErrors();
-                    ConditionallyRegisterParameter("bias", value);
+                    _bias?.Dispose();
+                    _bias = value?.DetachFromDisposeScope() as Parameter;
+                    ConditionallyRegisterParameter(nameof(bias), _bias);
                 }
             }
+            private Parameter? _bias;
 
-            public Parameter? weight {
-                get {
-                    var res = THSNN_Bilinear_weight(handle);
-                    if (res == IntPtr.Zero) { CheckForErrors(); }
-                    return (res == IntPtr.Zero) ? null : new Parameter(res);
-                }
+            public Parameter weight {
+                get => _weight!;
                 set {
-                    THSNN_Bilinear_set_weight(handle, value?.Handle ?? IntPtr.Zero);
-                    CheckForErrors();
-                    ConditionallyRegisterParameter("weight", value);
+                    if (value is null) throw new ArgumentNullException(nameof(weight));
+                    if (value.Handle != _weight?.Handle) {
+                        _weight?.Dispose();
+                        _weight = (value.DetachFromDisposeScope() as Parameter)!;
+                        ConditionallyRegisterParameter(nameof(weight), _weight);
+                    }
                 }
             }
+
+            private Parameter? _weight;
         }
     }
 
@@ -64,19 +74,16 @@ public static partial class nn
             /// <summary>
             /// Applies a bilinear transformation to the incoming data
             /// </summary>
-            /// <param name="in1Features">size of each first input sample</param>
-            /// <param name="in2Features">size of each second input sample</param>
-            /// <param name="outputSize">size of each output sample</param>
+            /// <param name="in1_features">size of each first input sample</param>
+            /// <param name="in2_features">size of each second input sample</param>
+            /// <param name="out_features">size of each output sample</param>
             /// <param name="hasBias">If set to false, the layer will not learn an additive bias</param>
             /// <param name="device">The desired device of the parameters and buffers in this module</param>
            /// <param name="dtype">The desired floating point or complex dtype of the parameters and buffers in this module</param>
             /// <returns></returns>
-            public static Bilinear Bilinear(long in1Features, long in2Features, long outputSize, bool hasBias = true, Device? device = null, ScalarType? dtype = null)
+            public static Bilinear Bilinear(long in1_features, long in2_features, long out_features, bool hasBias = true, Device? device = null, ScalarType? dtype = null)
             {
-                var res = THSNN_Bilinear_ctor(in1Features, in2Features, outputSize, hasBias, out var boxedHandle);
-                if (res == IntPtr.Zero) { CheckForErrors(); }
-
-                return new Bilinear(res, boxedHandle).MoveModule<Bilinear>(device, dtype);
+                return new Bilinear(in1_features, in2_features, out_features, hasBias, device, dtype);
             }
 
             public static partial class functional
@@ -94,7 +101,7 @@ public static Tensor bilinear(Tensor input1, Tensor input2, Tensor weight, Tenso
                 {
                     IntPtr bPtr = bias?.Handle ?? IntPtr.Zero;
                     var res = THSNN_functional_bilinear(input1.Handle, input2.Handle, weight.Handle, bPtr);
-                    if (res == IntPtr.Zero) { CheckForErrors(); }
+                    if (res == IntPtr.Zero) { torch.CheckForErrors(); }
                     return new Tensor(res);
                 }
             }
diff --git a/src/TorchSharp/NN/CosineSimilarity.cs b/src/TorchSharp/NN/CosineSimilarity.cs
index 261565e03..31323384d 100644
--- a/src/TorchSharp/NN/CosineSimilarity.cs
+++ b/src/TorchSharp/NN/CosineSimilarity.cs
@@ -24,6 +24,14 @@ public override Tensor forward(Tensor input1, Tensor input2)
                 if (res == IntPtr.Zero) { torch.CheckForErrors(); }
                 return new Tensor(res);
             }
+
+            // Rather than spending cycles only to discover that this module has neither
+            // parameters nor buffers, just shortcut the move completely.
+            protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
+            protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
+            protected internal override nn.Module _to(ScalarType dtype) => this;
         }
     }
 
diff --git a/src/TorchSharp/NN/Dropout.cs b/src/TorchSharp/NN/Dropout.cs
index b2d31dbae..f1c3a04b0 100644
--- a/src/TorchSharp/NN/Dropout.cs
+++ b/src/TorchSharp/NN/Dropout.cs
@@ -12,7 +12,7 @@ namespace Modules
         /// <summary>
         /// This class is used to represent a dropout module.
         /// </summary>
-        public sealed class Dropout : torch.nn.Module<Tensor, Tensor>
+        public sealed class Dropout : ParamLessModule<Tensor, Tensor>
         {
             internal Dropout(double p = 0.5, bool inplace = false) : base(nameof(Dropout))
             {
diff --git a/src/TorchSharp/NN/Flatten.cs b/src/TorchSharp/NN/Flatten.cs
index b1568938c..e6bf6ccc9 100644
--- a/src/TorchSharp/NN/Flatten.cs
+++ b/src/TorchSharp/NN/Flatten.cs
@@ -10,21 +10,24 @@ namespace TorchSharp
     namespace Modules
     {
         /// <summary>
-        /// This class is used to represent a dropout module for 2d/3d convolutational layers.
+        /// This class is used to represent a flattening of the input tensors.
         /// </summary>
-        public sealed class Flatten : torch.nn.Module<Tensor, Tensor>
+        public sealed class Flatten : ParamLessModule<Tensor, Tensor>
         {
-            internal Flatten(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle)
+            internal Flatten(long startDim = 1, long endDim = -1) : base(nameof(Flatten))
             {
+                _startDim = startDim;
+                _endDim = endDim;
             }
 
             public override Tensor forward(Tensor tensor)
             {
-                var res = THSNN_Flatten_forward(handle, tensor.Handle);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
+                return tensor.flatten(_startDim, _endDim);
             }
 
+            private long _startDim;
+            private long _endDim;
+
             // Rather than spending cycles only to discover that this module has neither
             // parameters nor buffers, just shortcut the move completely.
             protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
@@ -40,14 +43,12 @@ public static partial class nn
         /// <summary>
         /// Flattens a contiguous range of dims into a tensor. For use with Sequential.
         /// </summary>
-        /// <param name="startDim">First dim to flatten (default = 1).</param>
-        /// <param name="endDim">Last dim to flatten (default = -1).</param>
+        /// <param name="start_dim">First dim to flatten (default = 1).</param>
+        /// <param name="end_dim">Last dim to flatten (default = -1).</param>
        /// <returns></returns>
-        public static Flatten Flatten(long startDim = 1, long endDim = -1)
+        public static Flatten Flatten(long start_dim = 1, long end_dim = -1)
         {
-            var handle = THSNN_Flatten_ctor(startDim, endDim, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new Flatten(handle, boxedHandle);
+            return new Flatten(start_dim, end_dim);
         }
     }
 }
diff --git a/src/TorchSharp/NN/Identity.cs b/src/TorchSharp/NN/Identity.cs
index 7296a52a1..97bd68098 100644
--- a/src/TorchSharp/NN/Identity.cs
+++ b/src/TorchSharp/NN/Identity.cs
@@ -10,15 +10,13 @@ namespace TorchSharp
 
     namespace Modules
     {
-        public sealed class Identity : torch.nn.Module<Tensor, Tensor>
+        public sealed class Identity : ParamLessModule<Tensor, Tensor>
         {
-            internal Identity(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
+            internal Identity() : base(nameof(Identity)) { }
 
             public override Tensor forward(Tensor tensor)
             {
-                var res = THSNN_Identity_forward(handle, tensor.Handle);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
+                return tensor.alias();
             }
 
             // Rather than spending cycles only to discover that this module has neither
@@ -39,9 +37,7 @@ public static partial class nn
         /// <returns>The same tensor as is input.</returns>
         public static Identity Identity()
         {
-            var res = THSNN_Identity_ctor(out var boxedHandle);
-            if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new Identity(res, boxedHandle);
+            return new Identity();
         }
     }
 }
diff --git a/src/TorchSharp/NN/Linear.cs b/src/TorchSharp/NN/Linear.cs
index b5aa0c653..1b87c4a3a 100644
--- a/src/TorchSharp/NN/Linear.cs
+++ b/src/TorchSharp/NN/Linear.cs
@@ -13,48 +13,57 @@ namespace Modules
     {
         public sealed class Linear : torch.nn.Module<Tensor, Tensor>
         {
-            internal Linear(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle)
+            internal Linear(long inputSize, long outputSize, bool hasBias = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Linear))
             {
+                weight = torch.empty(outputSize, inputSize, device: device, dtype: dtype).AsParameter();
+                init.kaiming_uniform_(weight, a: _sqrt5);
+
+                if (hasBias) {
+                    bias = torch.empty(outputSize, device: device, dtype: dtype).AsParameter();
+                    var (fanIn, _) = init.CalculateFanInAndFanOut(weight);
+                    var bound = fanIn > 0 ? 1 / Math.Sqrt(fanIn) : 0;
+                    init.uniform_(_bias, -bound, bound);
+                }
+                //NOTE: it's important not to call 'RegisterComponents' here.
             }
 
-            public new static Linear Load(string modelPath)
+            public override Tensor forward(Tensor tensor)
             {
-                var res = Module.Load(modelPath);
-                return new Linear(res.handle.DangerousGetHandle(), IntPtr.Zero);
+                return torch.nn.functional.linear(tensor, _weight!, _bias);
             }
 
-            public override Tensor forward(Tensor tensor)
+            protected override void Dispose(bool disposing)
             {
-                var res = THSNN_Linear_forward(handle, tensor.Handle);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
+                if (disposing) {
+                    _weight?.Dispose();
+                    _bias?.Dispose();
+                }
             }
 
             public Parameter? bias {
-                get {
-                    var res = THSNN_Linear_bias(handle);
-                    if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                    return ((res == IntPtr.Zero) ? null : new Parameter(res));
-                }
+                get => _bias;
                 set {
-                    THSNN_Linear_set_bias(handle, value?.Handle ?? IntPtr.Zero);
-                    torch.CheckForErrors();
-                    ConditionallyRegisterParameter("bias", value);
+                    _bias?.Dispose();
+                    _bias = value?.DetachFromDisposeScope() as Parameter;
+                    ConditionallyRegisterParameter(nameof(bias), _bias);
                 }
             }
+            private Parameter? _bias;
 
-            public Parameter? weight {
-                get {
-                    var res = THSNN_Linear_weight(handle);
-                    if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                    return (res == IntPtr.Zero) ? null : new Parameter(res);
-                }
+            public Parameter weight {
+                get => _weight!;
                 set {
-                    THSNN_Linear_set_weight(handle, value!.Handle);
-                    torch.CheckForErrors();
-                    ConditionallyRegisterParameter("weight", value);
+                    if (value is null) throw new ArgumentNullException(nameof(weight));
+                    if (value.Handle != _weight?.Handle) {
+                        _weight?.Dispose();
+                        _weight = (value.DetachFromDisposeScope() as Parameter)!;
+                        ConditionallyRegisterParameter(nameof(weight), _weight);
+                    }
                 }
             }
+
+            private Parameter? _weight;
+            private static readonly double _sqrt5 = Math.Sqrt(5);
         }
     }
 
@@ -72,10 +81,7 @@ public static partial class nn
         /// <param name="dtype">The desired floating point or complex dtype of the parameters and buffers in this module</param>
         public static Linear Linear(long inputSize, long outputSize, bool hasBias = true, Device? device = null, ScalarType? dtype = null)
         {
-            var res = THSNN_Linear_ctor(inputSize, outputSize, hasBias, out var boxedHandle);
-            if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-
-            return new Linear(res, boxedHandle).MoveModule<Linear>(device, dtype);
+            return new Linear(inputSize, outputSize, hasBias, device, dtype);
         }
 
         public static partial class functional
diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs
index d423f05c2..2cf86411b 100644
--- a/src/TorchSharp/NN/Module.cs
+++ b/src/TorchSharp/NN/Module.cs
@@ -126,10 +126,10 @@ protected virtual void Dispose(bool disposing)
             if (disposing && !handle.IsInvalid) {
 
                 foreach (var (_, p) in named_buffers(false)) {
-                    p.Dispose();
+                    p.DetachFromDisposeScope().Dispose();
                 }
                 foreach (var (_, b) in named_parameters(false)) {
-                    b.Dispose();
+                    b.DetachFromDisposeScope().Dispose();
                 }
 
                 foreach (var (_, m) in named_modules()) {
diff --git a/src/TorchSharp/NN/Padding/ConstantPad1d.cs b/src/TorchSharp/NN/Padding/ConstantPad1d.cs
index ad6771e7b..ec905b4b7 100644
--- a/src/TorchSharp/NN/Padding/ConstantPad1d.cs
+++ b/src/TorchSharp/NN/Padding/ConstantPad1d.cs
@@ -12,27 +12,9 @@ namespace Modules
         /// <summary>
         /// This class is used to represent a ConstantPad1d module.
         /// </summary>
-        public sealed class ConstantPad1d : torch.nn.Module<Tensor, Tensor>
+        public sealed class ConstantPad1d : PadBase
         {
-            internal ConstantPad1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
-
-            /// <summary>
-            /// Forward pass.
-            /// </summary>
-            /// <param name="tensor">Input tensor</param>
-            /// <returns></returns>
-            public override Tensor forward(Tensor tensor)
-            {
-                var res = THSNN_ConstantPad1d_forward(handle, tensor.Handle);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
-            }
-
-            // Rather than spending cycles only to discover that this module has neither
-            // parameters nor buffers, just shortcut the move completely.
-            protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
-            protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
-            protected internal override nn.Module _to(ScalarType dtype) => this;
+            internal ConstantPad1d(double value, params long[] padding) : base(nameof(ConstantPad1d), PaddingModes.Constant, value, padding) { }
         }
     }
 
@@ -48,9 +30,7 @@ public static partial class nn
         /// <returns></returns>
         public static ConstantPad1d ConstantPad1d(long padding, double value)
         {
-            var handle = THSNN_ConstantPad1d_ctor(value, padding, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ConstantPad1d(handle, boxedHandle);
+            return new ConstantPad1d(value, padding, padding);
         }
 
         /// <summary>
@@ -61,9 +41,7 @@ public static ConstantPad1d ConstantPad1d(long padding, double value)
         /// <returns></returns>
         public static ConstantPad1d ConstantPad1d((long, long) padding, double value)
         {
-            var handle = THSNN_ConstantPad1d_ctor_tuple(value, padding.Item1, padding.Item2, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ConstantPad1d(handle, boxedHandle);
+            return new ConstantPad1d(value, padding.Item1, padding.Item2);
         }
     }
 }
diff --git a/src/TorchSharp/NN/Padding/ConstantPad2d.cs b/src/TorchSharp/NN/Padding/ConstantPad2d.cs
index 7d54b7bc6..9bc47b2be 100644
--- a/src/TorchSharp/NN/Padding/ConstantPad2d.cs
+++ b/src/TorchSharp/NN/Padding/ConstantPad2d.cs
@@ -12,27 +12,9 @@ namespace Modules
         /// <summary>
         /// This class is used to represent a ConstantPad2d module.
         /// </summary>
-        public sealed class ConstantPad2d : torch.nn.Module<Tensor, Tensor>
+        public sealed class ConstantPad2d : PadBase
        {
-            internal ConstantPad2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
-
-            /// <summary>
-            /// Forward pass.
-            /// </summary>
-            /// <param name="tensor">Input tensor</param>
-            /// <returns></returns>
-            public override Tensor forward(Tensor tensor)
-            {
-                var res = THSNN_ConstantPad2d_forward(handle, tensor.Handle);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
-            }
-
-            // Rather than spending cycles only to discover that this module has neither
-            // parameters nor buffers, just shortcut the move completely.
-            protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
-            protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
-            protected internal override nn.Module _to(ScalarType dtype) => this;
+            internal ConstantPad2d(double value, params long[] padding) : base(nameof(ConstantPad2d), PaddingModes.Constant, value, padding) { }
         }
     }
 
@@ -48,9 +30,7 @@ public static partial class nn
         /// <returns></returns>
         public static ConstantPad2d ConstantPad2d(long padding, double value)
         {
-            var handle = THSNN_ConstantPad2d_ctor(value, padding, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ConstantPad2d(handle, boxedHandle);
+            return new ConstantPad2d(value, padding, padding, padding, padding);
         }
 
         /// <summary>
@@ -61,9 +41,7 @@ public static ConstantPad2d ConstantPad2d(long padding, double value)
         /// <returns></returns>
         public static ConstantPad2d ConstantPad2d((long, long, long, long) padding, double value)
         {
-            var handle = THSNN_ConstantPad2d_ctor_tuple(value, padding.Item1, padding.Item2, padding.Item3, padding.Item4, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ConstantPad2d(handle, boxedHandle);
+            return new ConstantPad2d(value, padding.Item1, padding.Item2, padding.Item3, padding.Item4);
         }
     }
 }
diff --git a/src/TorchSharp/NN/Padding/ConstantPad3d.cs b/src/TorchSharp/NN/Padding/ConstantPad3d.cs
index 4ab2c55fb..4da9344e0 100644
--- a/src/TorchSharp/NN/Padding/ConstantPad3d.cs
+++ b/src/TorchSharp/NN/Padding/ConstantPad3d.cs
@@ -12,27 +12,9 @@ namespace Modules
         /// <summary>
         /// This class is used to represent a ConstantPad3d module.
         /// </summary>
-        public sealed class ConstantPad3d : torch.nn.Module<Tensor, Tensor>
+        public sealed class ConstantPad3d : PadBase
         {
-            internal ConstantPad3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
-
-            /// <summary>
-            /// Forward pass.
-            /// </summary>
-            /// <param name="tensor">Input tensor</param>
-            /// <returns></returns>
-            public override Tensor forward(Tensor tensor)
-            {
-                var res = THSNN_ConstantPad3d_forward(handle, tensor.Handle);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
-            }
-
-            // Rather than spending cycles only to discover that this module has neither
-            // parameters nor buffers, just shortcut the move completely.
-            protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
-            protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
-            protected internal override nn.Module _to(ScalarType dtype) => this;
+            internal ConstantPad3d(double value, params long[] padding) : base(nameof(ConstantPad3d), PaddingModes.Constant, value, padding) { }
         }
     }
 
@@ -48,9 +30,7 @@ public static partial class nn
         /// <returns></returns>
         public static ConstantPad3d ConstantPad3d(long padding, double value)
         {
-            var handle = THSNN_ConstantPad3d_ctor(value, padding, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ConstantPad3d(handle, boxedHandle);
+            return new ConstantPad3d(value, padding, padding, padding, padding, padding, padding);
         }
 
         /// <summary>
@@ -61,9 +41,7 @@ public static ConstantPad3d ConstantPad3d(long padding, double value)
         /// <returns></returns>
         public static ConstantPad3d ConstantPad3d((long, long, long, long, long, long) padding, double value)
         {
-            var handle = THSNN_ConstantPad3d_ctor_tuple(value, padding.Item1, padding.Item2, padding.Item3, padding.Item4, padding.Item5, padding.Item6, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ConstantPad3d(handle, boxedHandle);
+            return new ConstantPad3d(value, padding.Item1, padding.Item2, padding.Item3, padding.Item4, padding.Item5, padding.Item6);
         }
     }
 }
diff --git a/src/TorchSharp/NN/Padding/PadBase.cs b/src/TorchSharp/NN/Padding/PadBase.cs
new file mode 100644
index 000000000..20c0dde47
--- /dev/null
+++ b/src/TorchSharp/NN/Padding/PadBase.cs
@@ -0,0 +1,39 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+using System;
+using static TorchSharp.torch;
+using static TorchSharp.PInvoke.NativeMethods;
+
+namespace TorchSharp
+{
+    using Modules;
+
+    namespace Modules
+    {
+        /// <summary>
+        /// This class is used to represent the base of all padding-related modules.
+        /// </summary>
+        public class PadBase : ParamLessModule<Tensor, Tensor>
+        {
+            protected PadBase(string name, PaddingModes mode, double value, params long[] padding) : base(name)
+            {
+                _value = value;
+                _padding = padding;
+                _paddingMode = mode;
+            }
+
+            /// <summary>
+            /// Forward pass.
+            /// </summary>
+            /// <param name="input">Input tensor</param>
+            /// <returns></returns>
+            public override Tensor forward(Tensor input)
+            {
+                return nn.functional.pad(input, _padding, _paddingMode, _value);
+            }
+
+            private PaddingModes _paddingMode;
+            private long[] _padding;
+            private double _value = 0.0;
+        }
+    }
+}
diff --git a/src/TorchSharp/NN/Padding/ReflectionPad1d.cs b/src/TorchSharp/NN/Padding/ReflectionPad1d.cs
index 1a975dd7d..780f77550 100644
--- a/src/TorchSharp/NN/Padding/ReflectionPad1d.cs
+++ b/src/TorchSharp/NN/Padding/ReflectionPad1d.cs
@@ -12,27 +12,9 @@ namespace Modules
         /// <summary>
         /// This class is used to represent a ReflectionPad1d module.
         /// </summary>
-        public sealed class ReflectionPad1d : torch.nn.Module<Tensor, Tensor>
+        public sealed class ReflectionPad1d : PadBase
         {
-            internal ReflectionPad1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
-
-            /// <summary>
-            /// Forward pass.
-            /// </summary>
-            /// <param name="tensor">Input tensor</param>
-            /// <returns></returns>
-            public override Tensor forward(Tensor tensor)
-            {
-                var res = THSNN_ReflectionPad1d_forward(handle, tensor.Handle);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
-            }
-
-            // Rather than spending cycles only to discover that this module has neither
-            // parameters nor buffers, just shortcut the move completely.
-            protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
-            protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
-            protected internal override nn.Module _to(ScalarType dtype) => this;
+            internal ReflectionPad1d(params long[] padding) : base(nameof(ReflectionPad1d), PaddingModes.Reflect, 0, padding) { }
         }
     }
 
@@ -47,9 +29,7 @@ public static partial class nn
         /// <returns></returns>
         public static ReflectionPad1d ReflectionPad1d(long padding)
         {
-            var handle = THSNN_ReflectionPad1d_ctor(padding, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ReflectionPad1d(handle, boxedHandle);
+            return new ReflectionPad1d(padding, padding);
         }
 
         /// <summary>
@@ -59,9 +39,7 @@ public static ReflectionPad1d ReflectionPad1d(long padding)
         /// <returns></returns>
         public static ReflectionPad1d ReflectionPad1d((long, long) padding)
         {
-            var handle = THSNN_ReflectionPad1d_ctor_tuple(padding.Item1, padding.Item2, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ReflectionPad1d(handle, boxedHandle);
+            return new ReflectionPad1d(padding.Item1, padding.Item2);
         }
     }
 }
diff --git a/src/TorchSharp/NN/Padding/ReflectionPad2d.cs b/src/TorchSharp/NN/Padding/ReflectionPad2d.cs
index 418e971c3..1aa15f2e8 100644
--- a/src/TorchSharp/NN/Padding/ReflectionPad2d.cs
+++ b/src/TorchSharp/NN/Padding/ReflectionPad2d.cs
@@ -12,27 +12,9 @@ namespace Modules
         /// <summary>
         /// This class is used to represent a ReflectionPad2d module.
         /// </summary>
-        public sealed class ReflectionPad2d : torch.nn.Module<Tensor, Tensor>
+        public sealed class ReflectionPad2d : PadBase
         {
-            internal ReflectionPad2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
-
-            /// <summary>
-            /// Forward pass.
-            /// </summary>
-            /// <param name="tensor">Input tensor</param>
-            /// <returns></returns>
-            public override Tensor forward(Tensor tensor)
-            {
-                var res = THSNN_ReflectionPad2d_forward(handle, tensor.Handle);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
-            }
-
-            // Rather than spending cycles only to discover that this module has neither
-            // parameters nor buffers, just shortcut the move completely.
-            protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
-            protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
-            protected internal override nn.Module _to(ScalarType dtype) => this;
+            internal ReflectionPad2d(params long[] padding) : base(nameof(ReflectionPad2d), PaddingModes.Reflect, 0, padding) { }
         }
     }
 
@@ -47,9 +29,7 @@ public static partial class nn
         /// <returns></returns>
         public static ReflectionPad2d ReflectionPad2d(long padding)
         {
-            var handle = THSNN_ReflectionPad2d_ctor(padding, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ReflectionPad2d(handle, boxedHandle);
+            return new ReflectionPad2d(padding, padding, padding, padding);
         }
 
         /// <summary>
@@ -59,9 +39,7 @@ public static ReflectionPad2d ReflectionPad2d(long padding)
         /// <returns></returns>
         public static ReflectionPad2d ReflectionPad2d((long, long, long, long) padding)
         {
-            var handle = THSNN_ReflectionPad2d_ctor_tuple(padding.Item1, padding.Item2, padding.Item3, padding.Item4, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ReflectionPad2d(handle, boxedHandle);
+            return new ReflectionPad2d(padding.Item1, padding.Item2, padding.Item3, padding.Item4);
         }
     }
 }
diff --git a/src/TorchSharp/NN/Padding/ReflectionPad3d.cs b/src/TorchSharp/NN/Padding/ReflectionPad3d.cs
index 18db464be..cf26874ff 100644
--- a/src/TorchSharp/NN/Padding/ReflectionPad3d.cs
+++ b/src/TorchSharp/NN/Padding/ReflectionPad3d.cs
@@ -12,27 +12,9 @@ namespace Modules
         /// <summary>
         /// This class is used to represent a ReflectionPad3d module.
         /// </summary>
-        public sealed class ReflectionPad3d : torch.nn.Module<Tensor, Tensor>
+        public sealed class ReflectionPad3d : PadBase
         {
-            internal ReflectionPad3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
-
-            /// <summary>
-            /// Forward pass.
-            /// </summary>
-            /// <param name="tensor">Input tensor</param>
-            /// <returns></returns>
-            public override Tensor forward(Tensor tensor)
-            {
-                var res = THSNN_ReflectionPad3d_forward(handle, tensor.Handle);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
-            }
-
-            // Rather than spending cycles only to discover that this module has neither
-            // parameters nor buffers, just shortcut the move completely.
-            protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
-            protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
-            protected internal override nn.Module _to(ScalarType dtype) => this;
+            internal ReflectionPad3d(params long[] padding) : base(nameof(ReflectionPad3d), PaddingModes.Reflect, 0, padding) { }
         }
     }
 
@@ -47,9 +29,7 @@ public static partial class nn
         /// <returns></returns>
         public static ReflectionPad3d ReflectionPad3d(long padding)
         {
-            var handle = THSNN_ReflectionPad3d_ctor(padding, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ReflectionPad3d(handle, boxedHandle);
+            return new ReflectionPad3d(padding, padding, padding, padding, padding, padding);
         }
 
         /// <summary>
@@ -59,9 +39,7 @@ public static ReflectionPad3d ReflectionPad3d(long padding)
         /// <returns></returns>
         public static ReflectionPad3d ReflectionPad3d((long, long, long, long, long, long) padding)
         {
-            var handle = THSNN_ReflectionPad3d_ctor_tuple(padding.Item1, padding.Item2, padding.Item3, padding.Item4, padding.Item5, padding.Item6, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ReflectionPad3d(handle, boxedHandle);
+            return new ReflectionPad3d(padding.Item1, padding.Item2, padding.Item3, padding.Item4, padding.Item5, padding.Item6);
         }
     }
 }
diff --git a/src/TorchSharp/NN/Padding/ReplicationPad1d.cs b/src/TorchSharp/NN/Padding/ReplicationPad1d.cs
index 55f572ee8..fb3744f5b 100644
--- a/src/TorchSharp/NN/Padding/ReplicationPad1d.cs
+++ b/src/TorchSharp/NN/Padding/ReplicationPad1d.cs
@@ -12,27 +12,9 @@ namespace Modules
         /// <summary>
         /// This class is used to represent a ReplicationPad1d module.
         /// </summary>
-        public sealed class ReplicationPad1d : torch.nn.Module<Tensor, Tensor>
+        public sealed class ReplicationPad1d : PadBase
         {
-            internal ReplicationPad1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
-
-            /// <summary>
-            /// Forward pass.
-            /// </summary>
-            /// <param name="tensor">Input tensor</param>
-            /// <returns></returns>
-            public override Tensor forward(Tensor tensor)
-            {
-                var res = THSNN_ReplicationPad1d_forward(handle, tensor.Handle);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
-            }
-
-            // Rather than spending cycles only to discover that this module has neither
-            // parameters nor buffers, just shortcut the move completely.
-            protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
-            protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
-            protected internal override nn.Module _to(ScalarType dtype) => this;
+            internal ReplicationPad1d(params long[] padding) : base(nameof(ReplicationPad1d), PaddingModes.Replicate, 0, padding) { }
         }
     }
 
@@ -47,9 +29,7 @@ public static partial class nn
         /// <returns></returns>
         public static ReplicationPad1d ReplicationPad1d(long padding)
         {
-            var handle = THSNN_ReplicationPad1d_ctor(padding, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ReplicationPad1d(handle, boxedHandle);
+            return new ReplicationPad1d(padding, padding);
         }
 
         /// <summary>
@@ -59,9 +39,7 @@ public static ReplicationPad1d ReplicationPad1d(long padding)
         /// <returns></returns>
         public static ReplicationPad1d ReplicationPad1d((long, long) padding)
         {
-            var handle = THSNN_ReplicationPad1d_ctor_tuple(padding.Item1, padding.Item2, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ReplicationPad1d(handle, boxedHandle);
+            return new ReplicationPad1d(padding.Item1, padding.Item2);
         }
     }
 }
diff --git a/src/TorchSharp/NN/Padding/ReplicationPad2d.cs b/src/TorchSharp/NN/Padding/ReplicationPad2d.cs
index 205ac9e59..0bd779b0e 100644
--- a/src/TorchSharp/NN/Padding/ReplicationPad2d.cs
+++ b/src/TorchSharp/NN/Padding/ReplicationPad2d.cs
@@ -12,27 +12,9 @@ namespace Modules
         /// <summary>
         /// This class is used to represent a ReplicationPad2d module.
         /// </summary>
-        public sealed class ReplicationPad2d : torch.nn.Module<Tensor, Tensor>
+        public sealed class ReplicationPad2d : PadBase
         {
-            internal ReplicationPad2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
-
-            /// <summary>
-            /// Forward pass.
-            /// </summary>
-            /// <param name="tensor">Input tensor</param>
-            /// <returns></returns>
-            public override Tensor forward(Tensor tensor)
-            {
-                var res = THSNN_ReplicationPad2d_forward(handle, tensor.Handle);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
-            }
-
-            // Rather than spending cycles only to discover that this module has neither
-            // parameters nor buffers, just shortcut the move completely.
-            protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
-            protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
-            protected internal override nn.Module _to(ScalarType dtype) => this;
+            internal ReplicationPad2d(params long[] padding) : base(nameof(ReplicationPad2d), PaddingModes.Replicate, 0, padding) { }
         }
     }
 
@@ -41,15 +23,13 @@ public static partial class torch
     public static partial class nn
     {
         /// <summary>
-        /// Pads the input tensor using replication of the input boundary.
+        /// Pads the input tensor using the replication of the input boundary.
         /// </summary>
         /// <param name="padding">The size of the padding.</param>
         /// <returns></returns>
        public static ReplicationPad2d ReplicationPad2d(long padding)
         {
-            var handle = THSNN_ReplicationPad2d_ctor(padding, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ReplicationPad2d(handle, boxedHandle);
+            return new ReplicationPad2d(padding, padding, padding, padding);
         }
 
         /// <summary>
@@ -59,9 +39,7 @@ public static ReplicationPad2d ReplicationPad2d(long padding)
         /// <returns></returns>
         public static ReplicationPad2d ReplicationPad2d((long, long, long, long) padding)
         {
-            var handle = THSNN_ReplicationPad2d_ctor_tuple(padding.Item1, padding.Item2, padding.Item3, padding.Item4, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ReplicationPad2d(handle, boxedHandle);
+            return new ReplicationPad2d(padding.Item1, padding.Item2, padding.Item3, padding.Item4);
         }
     }
 }
diff --git a/src/TorchSharp/NN/Padding/ReplicationPad3d.cs b/src/TorchSharp/NN/Padding/ReplicationPad3d.cs
index 6b92f2972..5a243489e 100644
--- a/src/TorchSharp/NN/Padding/ReplicationPad3d.cs
+++ b/src/TorchSharp/NN/Padding/ReplicationPad3d.cs
@@ -12,27 +12,9 @@ namespace Modules
         /// <summary>
         /// This class is used to represent a ReplicationPad3d module.
         /// </summary>
-        public sealed class ReplicationPad3d : torch.nn.Module<Tensor, Tensor>
+        public sealed class ReplicationPad3d : PadBase
         {
-            internal ReplicationPad3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
-
-            /// <summary>
-            /// Forward pass.
-            /// </summary>
-            /// <param name="tensor">Input tensor</param>
-            /// <returns></returns>
-            public override Tensor forward(Tensor tensor)
-            {
-                var res = THSNN_ReplicationPad3d_forward(handle, tensor.Handle);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
-            }
-
-            // Rather than spending cycles only to discover that this module has neither
-            // parameters nor buffers, just shortcut the move completely.
-            protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
-            protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
-            protected internal override nn.Module _to(ScalarType dtype) => this;
+            internal ReplicationPad3d(params long[] padding) : base(nameof(ReplicationPad3d), PaddingModes.Replicate, 0, padding) { }
         }
     }
 
@@ -47,9 +29,7 @@ public static partial class nn
         /// <returns></returns>
         public static ReplicationPad3d ReplicationPad3d(long padding)
         {
-            var handle = THSNN_ReplicationPad3d_ctor(padding, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ReplicationPad3d(handle, boxedHandle);
+            return new ReplicationPad3d(padding, padding, padding, padding, padding, padding);
         }
 
         /// <summary>
@@ -59,9 +39,7 @@ public static ReplicationPad3d ReplicationPad3d(long padding)
         /// <returns></returns>
         public static ReplicationPad3d ReplicationPad3d((long, long, long, long, long, long) padding)
         {
-            var handle = THSNN_ReplicationPad3d_ctor_tuple(padding.Item1, padding.Item2, padding.Item3, padding.Item4, padding.Item5, padding.Item6, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ReplicationPad3d(handle, boxedHandle);
+            return new ReplicationPad3d(padding.Item1, padding.Item2, padding.Item3, padding.Item4, padding.Item5, padding.Item6);
         }
     }
 }
diff --git a/src/TorchSharp/NN/Padding/ZeroPad2d.cs b/src/TorchSharp/NN/Padding/ZeroPad2d.cs
index 82a075d86..679e96e4d 100644
--- a/src/TorchSharp/NN/Padding/ZeroPad2d.cs
+++ b/src/TorchSharp/NN/Padding/ZeroPad2d.cs
@@ -12,27 +12,9 @@ namespace Modules
         /// <summary>
         /// This class is used to represent a ZeroPad2d module.
         /// </summary>
-        public sealed class ZeroPad2d : torch.nn.Module<Tensor, Tensor>
+        public sealed class ZeroPad2d : PadBase
         {
-            internal ZeroPad2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
-
-            /// <summary>
-            /// Forward pass.
-            /// </summary>
-            /// <param name="tensor">Input tensor</param>
-            /// <returns></returns>
-            public override Tensor forward(Tensor tensor)
-            {
-                var res = THSNN_ZeroPad2d_forward(handle, tensor.Handle);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
-            }
-
-            // Rather than spending cycles only to discover that this module has neither
-            // parameters nor buffers, just shortcut the move completely.
-            protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
-            protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
-            protected internal override nn.Module _to(ScalarType dtype) => this;
+            internal ZeroPad2d(params long[] padding) : base(nameof(ZeroPad2d), PaddingModes.Zeros, 0, padding) { }
         }
     }
 
@@ -47,9 +29,7 @@ public static partial class nn
         /// <returns></returns>
         public static ZeroPad2d ZeroPad2d(long padding)
         {
-            var handle = THSNN_ZeroPad2d_ctor(padding, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ZeroPad2d(handle, boxedHandle);
+            return new ZeroPad2d(padding, padding, padding, padding);
        }
 
         /// <summary>
@@ -59,9 +39,7 @@ public static ZeroPad2d ZeroPad2d(long padding)
         /// <returns></returns>
         public static ZeroPad2d ZeroPad2d((long, long, long, long) padding)
         {
-            var handle = THSNN_ZeroPad2d_ctor_tuple(padding.Item1, padding.Item2, padding.Item3, padding.Item4, out var boxedHandle);
-            if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
-            return new ZeroPad2d(handle, boxedHandle);
+            return new ZeroPad2d(padding.Item1, padding.Item2, padding.Item3, padding.Item4);
         }
     }
 }
diff --git a/src/TorchSharp/NN/ParamLessModule.cs b/src/TorchSharp/NN/ParamLessModule.cs
new file mode 100644
index 000000000..c2a274a52
--- /dev/null
+++ b/src/TorchSharp/NN/ParamLessModule.cs
@@ -0,0 +1,46 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+using System;
+using static TorchSharp.torch;
+using static TorchSharp.PInvoke.NativeMethods;
+
+namespace TorchSharp
+{
+    using Modules;
+
+    namespace Modules
+    {
+        /// <summary>
+        /// Base class for all modules that do not have any parameters or buffers, and
+        /// for which the `_to()` implementation can therefore be simplified.
+        /// </summary>
+        public abstract class ParamLessModule<T, TResult> : nn.Module<T, TResult>
+        {
+            protected ParamLessModule(string name) : base(name) { }
+
+            // Rather than spending cycles only to discover that this module has neither
+            // parameters nor buffers, just shortcut the move completely.
+            protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
+            protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
+            protected internal override nn.Module _to(ScalarType dtype) => this;
+        }
+
+        /// <summary>
+        /// Base class for all modules that do not have any parameters or buffers, and
+        /// for which the `_to()` implementation can therefore be simplified.
+        /// </summary>
+        public abstract class ParamLessModule<T1, T2, TResult> : nn.Module<T1, T2, TResult>
+        {
+            protected ParamLessModule(string name) : base(name) { }
+
+            // Rather than spending cycles only to discover that this module has neither
+            // parameters nor buffers, just shortcut the move completely.
+            protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
+            protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
+            protected internal override nn.Module _to(ScalarType dtype) => this;
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/TorchSharp/NN/Parameter.cs b/src/TorchSharp/NN/Parameter.cs
index 81e9051d8..4c1faa01e 100644
--- a/src/TorchSharp/NN/Parameter.cs
+++ b/src/TorchSharp/NN/Parameter.cs
@@ -26,6 +26,12 @@ public class Parameter : Tensor
         public Parameter(Tensor data, bool requires_grad = true) :
             base(data.with_requires_grad(requires_grad).MoveHandle())
         {
+            var scope = data.OwningDisposeScope;
+            if (scope is not null) {
+                this.OwningDisposeScope = scope;
+                scope.Include(this);
+                scope.Detach(data);
+            }
         }
 
         /// <summary>
@@ -35,7 +41,6 @@ public Parameter(Tensor data, bool requires_grad = true) :
         internal Parameter(System.IntPtr handle) : base(handle)
         {
         }
-
     };
 }
 
diff --git a/src/TorchSharp/NN/PixelShuffle.cs b/src/TorchSharp/NN/PixelShuffle.cs
index fe1d94bd5..e7054be2b 100644
--- a/src/TorchSharp/NN/PixelShuffle.cs
+++ b/src/TorchSharp/NN/PixelShuffle.cs
@@ -27,6 +27,14 @@ public override Tensor forward(Tensor tensor)
             if (res == IntPtr.Zero) { torch.CheckForErrors(); }
             return new Tensor(res);
         }
+
+        // Rather than spending cycles only to discover that this module has neither
+        // parameters nor buffers, just shortcut the move completely.
+        protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
+        protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
+        protected internal override nn.Module _to(ScalarType dtype) => this;
     }
 }
 
diff --git a/src/TorchSharp/NN/PixelUnshuffle.cs b/src/TorchSharp/NN/PixelUnshuffle.cs
index e6d3f120a..5467ab59b 100644
--- a/src/TorchSharp/NN/PixelUnshuffle.cs
+++ b/src/TorchSharp/NN/PixelUnshuffle.cs
@@ -27,6 +27,14 @@ public override Tensor forward(Tensor tensor)
             if (res == IntPtr.Zero) { torch.CheckForErrors(); }
             return new Tensor(res);
         }
+
+        // Rather than spending cycles only to discover that this module has neither
+        // parameters nor buffers, just shortcut the move completely.
+        protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
+        protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
+        protected internal override nn.Module _to(ScalarType dtype) => this;
     }
 }
 
diff --git a/src/TorchSharp/NN/Pooling/FractionalMaxPool2d.cs b/src/TorchSharp/NN/Pooling/FractionalMaxPool2d.cs
index 7fbfa371a..169215cde 100644
--- a/src/TorchSharp/NN/Pooling/FractionalMaxPool2d.cs
+++ b/src/TorchSharp/NN/Pooling/FractionalMaxPool2d.cs
@@ -35,7 +35,9 @@ public override Tensor forward(Tensor tensor)
         // Rather than spending cycles only to discover that this module has neither
         // parameters nor buffers, just shortcut the move completely.
         protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
         protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
         protected internal override nn.Module _to(ScalarType dtype) => this;
     }
 }
 
diff --git a/src/TorchSharp/NN/Pooling/LPPool1d.cs b/src/TorchSharp/NN/Pooling/LPPool1d.cs
index 424da18d5..bd0c7d9bd 100644
--- a/src/TorchSharp/NN/Pooling/LPPool1d.cs
+++ b/src/TorchSharp/NN/Pooling/LPPool1d.cs
@@ -24,12 +24,6 @@ public override Tensor forward(Tensor tensor)
             if (res == IntPtr.Zero) { torch.CheckForErrors(); }
             return new Tensor(res);
         }
-
-        // Rather than spending cycles only to discover that this module has neither
-        // parameters nor buffers, just shortcut the move completely.
-        protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
-        protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
-        protected internal override nn.Module _to(ScalarType dtype) => this;
     }
 }
 
diff --git a/src/TorchSharp/NN/Pooling/LPPool2d.cs b/src/TorchSharp/NN/Pooling/LPPool2d.cs
index 67c06b58b..811d97598 100644
--- a/src/TorchSharp/NN/Pooling/LPPool2d.cs
+++ b/src/TorchSharp/NN/Pooling/LPPool2d.cs
@@ -28,7 +28,9 @@ public override Tensor forward(Tensor tensor)
         // Rather than spending cycles only to discover that this module has neither
         // parameters nor buffers, just shortcut the move completely.
         protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
         protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
         protected internal override nn.Module _to(ScalarType dtype) => this;
     }
 }
 
diff --git a/src/TorchSharp/NN/Pooling/MaxPool1D.cs b/src/TorchSharp/NN/Pooling/MaxPool1D.cs
index 79a521f59..6484c149e 100644
--- a/src/TorchSharp/NN/Pooling/MaxPool1D.cs
+++ b/src/TorchSharp/NN/Pooling/MaxPool1D.cs
@@ -36,7 +36,9 @@ public override Tensor forward(Tensor tensor)
         // Rather than spending cycles only to discover that this module has neither
         // parameters nor buffers, just shortcut the move completely.
         protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
         protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
         protected internal override nn.Module _to(ScalarType dtype) => this;
     }
 }
 
diff --git a/src/TorchSharp/NN/Pooling/MaxPool2D.cs b/src/TorchSharp/NN/Pooling/MaxPool2D.cs
index 55808c454..b8657ccb0 100644
--- a/src/TorchSharp/NN/Pooling/MaxPool2D.cs
+++ b/src/TorchSharp/NN/Pooling/MaxPool2D.cs
@@ -35,7 +35,9 @@ public override Tensor forward(Tensor tensor)
         // Rather than spending cycles only to discover that this module has neither
         // parameters nor buffers, just shortcut the move completely.
         protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
+
         protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
+
         protected internal override nn.Module _to(ScalarType dtype) => this;
     }
 }
 
diff --git a/src/TorchSharp/NN/Pooling/MaxPool3D.cs b/src/TorchSharp/NN/Pooling/MaxPool3D.cs
index 1ab30d15d..165a0c847 100644
--- a/src/TorchSharp/NN/Pooling/MaxPool3D.cs
+++ b/src/TorchSharp/NN/Pooling/MaxPool3D.cs
@@ -36,7 +36,9 @@ public override Tensor forward(Tensor tensor)
         // Rather than spending cycles only to discover that this module has neither
         // parameters nor buffers, just shortcut the move completely.
protected internal override nn.Module _to(Device device, ScalarType dtype) => this; + protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; + protected internal override nn.Module _to(ScalarType dtype) => this; } } diff --git a/src/TorchSharp/NN/Pooling/MaxUnpool1d.cs b/src/TorchSharp/NN/Pooling/MaxUnpool1d.cs index 2d8d7e908..0f2fe9bd9 100644 --- a/src/TorchSharp/NN/Pooling/MaxUnpool1d.cs +++ b/src/TorchSharp/NN/Pooling/MaxUnpool1d.cs @@ -37,7 +37,9 @@ public override Tensor forward(Tensor tensor, Tensor indices, long[] output_size // Rather than spending cycles only to discover that this module has neither // parameters nor buffers, just shortcut the move completely. protected internal override nn.Module _to(Device device, ScalarType dtype) => this; + protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; + protected internal override nn.Module _to(ScalarType dtype) => this; } } diff --git a/src/TorchSharp/NN/Pooling/MaxUnpool2d.cs b/src/TorchSharp/NN/Pooling/MaxUnpool2d.cs index 84e8c6cb3..ded7b81ee 100644 --- a/src/TorchSharp/NN/Pooling/MaxUnpool2d.cs +++ b/src/TorchSharp/NN/Pooling/MaxUnpool2d.cs @@ -37,7 +37,9 @@ public override Tensor forward(Tensor tensor, Tensor indices, long[] output_size // Rather than spending cycles only to discover that this module has neither // parameters nor buffers, just shortcut the move completely. protected internal override nn.Module _to(Device device, ScalarType dtype) => this; + protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; + protected internal override nn.Module _to(ScalarType dtype) => this; } } diff --git a/src/TorchSharp/NN/Pooling/MaxUnpool3d.cs b/src/TorchSharp/NN/Pooling/MaxUnpool3d.cs index a5473d8d6..32b85ce5a 100644 --- a/src/TorchSharp/NN/Pooling/MaxUnpool3d.cs +++ b/src/TorchSharp/NN/Pooling/MaxUnpool3d.cs @@ -37,7 +37,9 @@ public override Tensor forward(Tensor tensor, Tensor indices, long[] output_size // Rather than spending cycles only to discover that this module has neither // parameters nor buffers, just shortcut the move completely. protected internal override nn.Module _to(Device device, ScalarType dtype) => this; + protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; + protected internal override nn.Module _to(ScalarType dtype) => this; } } diff --git a/src/TorchSharp/NN/Shuffle/ChannelShuffle.cs b/src/TorchSharp/NN/Shuffle/ChannelShuffle.cs index f820cae57..65b8e4fdd 100644 --- a/src/TorchSharp/NN/Shuffle/ChannelShuffle.cs +++ b/src/TorchSharp/NN/Shuffle/ChannelShuffle.cs @@ -32,7 +32,9 @@ public override string GetName() // Rather than spending cycles only to discover that this module has neither // parameters nor buffers, just shortcut the move completely. protected internal override nn.Module _to(Device device, ScalarType dtype) => this; + protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; + protected internal override nn.Module _to(ScalarType dtype) => this; } } diff --git a/src/TorchSharp/NN/Unflatten.cs b/src/TorchSharp/NN/Unflatten.cs index 71c7b6a23..fa5954f6b 100644 --- a/src/TorchSharp/NN/Unflatten.cs +++ b/src/TorchSharp/NN/Unflatten.cs @@ -12,17 +12,17 @@ namespace Modules /// /// This class is used to represent an unflattening operation. 
/// - public sealed class Unflatten : torch.nn.Module + public sealed class Unflatten : ParamLessModule { - internal Unflatten(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) + internal Unflatten(long dim, long[] unflattenedSize) : base(nameof(Unflatten)) { + this._dim = dim; + this._unflattenedSize = unflattenedSize; } public override Tensor forward(Tensor tensor) { - var res = THSNN_Unflatten_forward(handle, tensor.Handle); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); + return tensor.unflatten(_dim, _unflattenedSize); } // Rather than spending cycles only to discover that this module has neither @@ -30,6 +30,9 @@ public override Tensor forward(Tensor tensor) protected internal override nn.Module _to(Device device, ScalarType dtype) => this; protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this; protected internal override nn.Module _to(ScalarType dtype) => this; + + long _dim; + long[] _unflattenedSize; } } @@ -45,13 +48,7 @@ public static partial class nn /// public static Unflatten Unflatten(long dim, long[] unflattenedSize) { - unsafe { - fixed (long* pUnflattenedSize = unflattenedSize) { - var handle = THSNN_Unflatten_ctor(dim, (IntPtr)pUnflattenedSize, unflattenedSize.Length, out var boxedHandle); - if (handle == IntPtr.Zero) { CheckForErrors(); } - return new Unflatten(handle, boxedHandle); - } - } + return new Unflatten(dim, unflattenedSize); } } } diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs index efeaec1b3..3fd598823 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs @@ -1115,96 +1115,6 @@ internal static extern IntPtr THSNN_custom_module( [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_LocalResponseNorm_ctor(long size, double alpha, double beta, double k, out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad1d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad1d_ctor(double value, long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad1d_ctor_tuple(double value, long padding_left, long padding_right, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad2d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad2d_ctor(double value, long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad2d_ctor_tuple(double value, long padding_left, long padding_right, long padding_top, long padding_bottom, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad3d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad3d_ctor(double value, long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ConstantPad3d_ctor_tuple(double value, long padding_left, long padding_right, long padding_top, long padding_bottom, long padding_front, long padding_back, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr 
THSNN_ReflectionPad1d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad1d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad1d_ctor_tuple(long padding_left, long padding_right, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad2d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad2d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad2d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad3d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad3d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReflectionPad3d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, long padding_front, long padding_back, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad1d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad1d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad1d_ctor_tuple(long padding_left, long padding_right, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad2d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad2d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad2d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad3d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad3d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ReplicationPad3d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, long padding_front, long padding_back, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ZeroPad2d_forward(torch.nn.Module.HType module, IntPtr tensor); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ZeroPad2d_ctor(long padding, out IntPtr pBoxedModule); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_ZeroPad2d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, out IntPtr pBoxedModule); - [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_AdaptiveAvgPool1d_forward(IntPtr module, IntPtr tensor); diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs index 25943888f..455734fc2 100644 --- a/test/TorchSharpTest/NN.cs +++ b/test/TorchSharpTest/NN.cs @@ -323,6 
+323,8 @@ public void TestIdentity() var input = torch.randn(new long[] { 1, 1000 }, device: device); var output = lin.call(input); + output[0, 511] = 10; // Identity aliases its input, so modifying the output should alter the original, too. + Assert.Equal(device.type, output.device_type); Assert.Equal(input.data(), output.data()); } @@ -2124,12 +2126,12 @@ public void TestConv1d() var shape = new long[] { 16, 3, 28 }; foreach (var device in TestUtils.AvailableDevices(false)) { Tensor t = torch.rand(shape, device: device); - var conv = Conv1d(3, 64, 3, device: device); + var conv = Conv1d(3, 64, 5, device: device); var output = conv.call(t); Assert.Equal(device.type, output.device_type); Assert.Equal(16, output.shape[0]); Assert.Equal(64, output.shape[1]); - Assert.Equal(26, output.shape[2]); + Assert.Equal(24, output.shape[2]); } } @@ -4700,6 +4702,7 @@ public void TestMultiheadAttention() using (var V = torch.tensor(v_data, src_seq_len, batch_size, vembed_dim)) using (var Attn = torch.tensor(attn_data, batch_size, src_seq_len, src_seq_len)) { + var children = mha.children().ToList(); mha.eval(); Assert.False(mha.training); @@ -4892,13 +4895,13 @@ public void TestFlatten() Assert.Equal(new long[] { 32, 360 }, output.shape); } - using (var flat = Flatten(startDim: 2)) { + using (var flat = Flatten(start_dim: 2)) { var output = flat.call(data); Assert.Equal(device.type, output.device_type); Assert.Equal(new long[] { 32, 3, 120 }, output.shape); } - using (var flat = Flatten(startDim: 0)) { + using (var flat = Flatten(start_dim: 0)) { var output = flat.call(data); Assert.Equal(device.type, output.device_type); Assert.Equal(new long[] { 32 * 360 }, output.shape);
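
A quick way to sanity-check the refactored padding path is to compare one of the rewritten managed modules against the functional API. This is a minimal sketch, not part of the patch; it assumes PadBase ultimately delegates to torch.nn.functional.pad with the stored mode and value, and that the functional overload used here exists with this signature:

using TorchSharp;
using static TorchSharp.torch;

var t = torch.randn(1, 1, 3, 3);

// The rewritten factory expands a uniform padding into (left, right, top, bottom).
var viaModule = nn.ZeroPad2d(1).call(t);    // expected shape: [1, 1, 5, 5]

// Assumed equivalent: constant padding with value 0 through the functional API.
var viaFunctional = nn.functional.pad(t, new long[] { 1, 1, 1, 1 }, PaddingModes.Constant, 0.0);

bool same = viaModule.allclose(viaFunctional);   // expected: true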
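The Parameter constructor change is about dispose-scope bookkeeping: the new Parameter takes over the source tensor's native handle via MoveHandle(), so it should also take the tensor's place in any enclosing DisposeScope. A minimal sketch of the intended behavior, for illustration only, assuming the standard NewDisposeScope API:

using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

using (var scope = torch.NewDisposeScope()) {
    var data = torch.randn(3, 3);   // registered with the enclosing scope
    var p = new Parameter(data);    // consumes data's handle via MoveHandle()

    // With the change above, p is included in the scope and data is detached,
    // so disposing the scope releases the parameter rather than the dead source tensor.
}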
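Unflatten is now implemented entirely in managed code as a call to tensor.unflatten, so no native handle is involved. A small usage sketch consistent with the shapes used in TestFlatten:

using TorchSharp;
using static TorchSharp.torch;

var x = torch.randn(32, 360);

// Split dimension 1 (size 360) into 3 x 120.
var unflatten = nn.Unflatten(1, new long[] { 3, 120 });
var y = unflatten.call(x);   // expected shape: [32, 3, 120]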
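The updated TestConv1d expectations follow from the usual convolution length formula: with unit stride, no padding, and no dilation, L_out = L_in - (kernel_size - 1), so a kernel of 5 over length 28 yields 28 - 4 = 24 (the old kernel of 3 yielded 26). For example:

using TorchSharp;
using static TorchSharp.torch;

var t = torch.rand(16, 3, 28);
var conv = nn.Conv1d(3, 64, 5);   // in_channels: 3, out_channels: 64, kernel_size: 5
var y = conv.call(t);             // expected shape: [16, 64, 24]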