diff --git a/RELEASENOTES.md b/RELEASENOTES.md index 20d05f526..0e4a9a4b5 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -10,6 +10,8 @@ __API Changes__: Add torch.utils.rnn
Add torchvision.io
+Add Tensor.trace() and torch.trace() (unrelated to torch.jit.trace)
+Add ability to load and save TorchScript modules created using Pytorch
## NuGet Version 0.96.8 diff --git a/docfx/articles/torchscript.md b/docfx/articles/torchscript.md new file mode 100644 index 000000000..d86378577 --- /dev/null +++ b/docfx/articles/torchscript.md @@ -0,0 +1,25 @@ +# Loading TorchScript Modules + +Starting with release 0.96.9, you can load TorchScript modules and functions that have been either traced or scripted in Pytorch. It is, however, not yet possible to create a TorchScript module from scratch using TorchSharp. Refer to the [Pytorch JIT](https://pytorch.org/docs/stable/jit.html) docs for information on how to create such a file. + +TorchScript is very powerful, because it allows you to save the logic and the weights of a model together, and it furthermore allows the module to be loaded into another program, __without any dependencies on the Python runtime.__ Thus, you can load a model that has been serialized using TorchScript and have it behave as any TorchScript module -- you can use it for training, or you can use it for inference. + +Once you have a TorchScript file, you can load it into TorchSharp using: + +```C# +var m = torch.jit.load("file-name"); +``` + +It returns a ScriptModule, which behaves just like any other TorchSharp module. Whether the original script came from a module or a function, it is deserialized as a module. You can use it for training of inference by calling either `train()` or `eval()`. ScriptModules always start out on the CPU, so you have to call `cuda()` in order to move it to a GPU. + +Note that if you used __tracing__ to create the TorchScript file in Pytorch, submodules that behave differently in training and eval modes will behave according to the mode they were traced in. + +If you use the script module to train, you may want / need to save it afterwards. + +That is easily done using `save()`: + +```C# +torch.jit.save(m, "file-name"); +``` + +While it is possible to save a modified ScriptModule from TorchSharp, it is not (yet) possible to create one _from scratch_ using either tracing or scripting. Another limitation is that the TorchSharp code assumes that the `forward()` function takes only tensors as its arguments and returns a single tensor, a limitation it shares with other TorchSharp modules. diff --git a/docfx/index.md b/docfx/index.md index 30701ffd2..e7dda7beb 100644 --- a/docfx/index.md +++ b/docfx/index.md @@ -1,10 +1,9 @@ -TorchSharp are .NET bindings to the Torch library published -here: +TorchSharp are .NET bindings to the Torch library published here: https://pytorch.org/get-started/locally/ -This surfaces the C API as a strongly-typed C# API. +This surfaces the C++ library as a strongly-typed .NET API. ## Getting Started @@ -18,7 +17,4 @@ Then, start by reading up on [creating your own modules](articles/modules.md). An intruction on how to [share model](articles/saveload.md) weights between applications, whether in Python or .NET. - -## API documentation - -The [API Documentation](api/TorchSharp.html) \ No newline at end of file +Loading existing TorchScript files is now supported and described in [Loading TorchScript](articles/torchscript.md). diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp index 11531aef3..8fb622b56 100644 --- a/src/Native/LibTorchSharp/THSJIT.cpp +++ b/src/Native/LibTorchSharp/THSJIT.cpp @@ -3,9 +3,48 @@ JITModule THSJIT_load(const char* filename) { - auto res = torch::jit::load(filename); - auto copy = new torch::jit::Module(res); - return new std::shared_ptr(copy); + CATCH( + auto res = torch::jit::load(filename); + auto copy = new torch::jit::Module(res); + return new std::shared_ptr(copy); + ); + + return nullptr; +} + +void THSJIT_save(JITModule module, const char* filename) +{ + CATCH( + (*module)->save(filename); + ); +} + +int THSJIT_Module_is_training(JITModule module) +{ + return (*module)->is_training(); +} + +void THSJIT_Module_train(JITModule module, bool on) +{ + (*module)->train(on); +} + +void THSJIT_Module_eval(JITModule module) +{ + (*module)->eval(); +} + +void THSJIT_Module_to_device(JITModule module, int64_t device, int64_t index) +{ + c10::DeviceType dev = c10::kCPU; + if (device == 1) + dev = c10::kCUDA; + (*module)->to(torch::Device(dev, index)); +} + +void THSJIT_Module_to_dtype(JITModule module, int8_t dtype) +{ + (*module)->to((at::ScalarType)dtype); } void THSJIT_Module_modules(const JITModule module, JITModule* (*allocator)(size_t length)) @@ -35,6 +74,22 @@ void THSJIT_Module_named_modules(const JITModule module, } } +void THSJIT_Module_named_children(const JITModule module, + JITModule* (*allocator)(size_t length), + const char** (*allocator2)(size_t length)) +{ + auto modules = (*module)->named_children(); + JITModule* result = allocator(modules.size()); + const char** names = allocator2(modules.size()); + int i = 0; + for (const auto& child : modules) { + auto copy = new torch::jit::Module(child.value); + result[i] = new std::shared_ptr(copy); + names[i] = make_sharable_string(child.name); + i++; + } +} + void THSJIT_Module_parameters(const JITModule module, Tensor* (*allocator)(size_t length)) { auto parameters = (*module)->parameters(); @@ -60,6 +115,21 @@ void THSJIT_Module_named_parameters(const JITModule module, } } +void THSJIT_Module_named_buffers(const JITModule module, + Tensor* (*allocator)(size_t length), + const char** (*allocator2)(size_t length)) +{ + auto parameters = (*module)->named_buffers(); + Tensor* result = allocator(parameters.size()); + const char** names = allocator2(parameters.size()); + int i = 0; + for (const auto& child : parameters) { + result[i] = new torch::Tensor(child.value); + names[i] = make_sharable_string(child.name); + i++; + } +} + JITMethod THSJIT_Module_get_method(const JITModule module, const char* name) { auto method = (*module)->get_method(name); @@ -69,7 +139,7 @@ JITMethod THSJIT_Module_get_method(const JITModule module, const char* name) Tensor THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length) { - return new torch::Tensor((*module)->forward(toTensors((torch::Tensor**)tensorPtrs, length)).toTensor()); + CATCH_TENSOR((*module)->forward(toTensors((torch::Tensor**)tensorPtrs, length)).toTensor()); } void THSJIT_Module_dispose(const JITModule module) @@ -87,6 +157,16 @@ int THSJIT_Method_num_inputs(const JITMethod method) return (int)(*method)->num_inputs(); } +int THSJIT_Module_num_inputs(const JITModule module) +{ + return (int)(*module)->get_method("forward").num_inputs() - 1; // Don't count the 'self' argument. +} + +int THSJIT_Module_num_outputs(const JITModule module) +{ + return (int)(*module)->get_method("forward").function().getSchema().returns().size(); +} + JITFunction THSJIT_Method_function(const JITMethod method) { return new std::shared_ptr(&(*method)->function()); @@ -113,32 +193,77 @@ void THSJIT_Function_dispose(const JITFunction function) delete function; } -//void* THSJIT_typeCast(const JITType type) -//{ -// switch ((*type)->kind()) -// { -// case c10::TypeKind::TensorType: -// return new std::shared_ptr((*type)->cast()); -// case c10::TypeKind::DimensionedTensorType: -// return new std::shared_ptr((*type)->cast()); -// default: -// return NULL; -// } -//} -// -//int8_t THSJIT_typeKind(const JITType type) -//{ -// switch ((*type)->kind()) -// { -// case c10::TypeKind::TensorType: -// return (int8_t)TypeKind::TensorType; -// case c10::TypeKind::DimensionedTensorType: -// return (int8_t)TypeKind::DimensionedTensorType; -// default: -// return -1; -// } -//} -// +void THSJIT_Type_dispose(const JITType type) +{ + delete type; +} + +void THSJIT_TensorType_dispose(const JITTensorType type) +{ + delete type; +} + +void* THSJIT_Type_cast(const JITType type) +{ + switch ((*type)->kind()) + { + case c10::TypeKind::TensorType: + return new std::shared_ptr((*type)->cast()); + //case c10::TypeKind::DimensionedTensorType: + // return new std::shared_ptr((*type)->cast()); + default: + return NULL; + } +} + +int8_t THSJIT_TensorType_dtype(const JITTensorType type) +{ + auto scT = (*type)->scalarType(); + if (scT.has_value()) { + return (int8_t)scT.value(); + } + else { + return -1; + } +} + +void THSJIT_TensorType_sizes(const JITTensorType type, int64_t* (*allocator)(int64_t length)) +{ + //CATCH( + auto& t = *type; + auto dim = t->dim(); + auto res = (*type)->sizes().concrete_sizes(); + if (res.has_value()) { + const size_t sz = res.value().size(); + auto& vec = res.value(); + int64_t* result = allocator(sz); + for (size_t i = 0; i < sz; i++) + result[i] = vec[i]; + } + //); +} + +int8_t THSJIT_Type_kind(const JITType type) +{ + switch ((*type)->kind()) + { + case c10::TypeKind::TensorType: + return (int8_t)TypeKind::TensorType; + //case c10::TypeKind::DimensionedTensorType: + // return (int8_t)TypeKind::DimensionedTensorType; + default: + return -1; + } +} + +JITType THSJIT_Module_getInputType(JITModule module, int8_t index) +{ + auto typ = (*module)->type(); + c10::TypeKind kind = typ->kind(); + auto& schema = typ->getMethod("forward").getSchema(); + return new std::shared_ptr(schema.arguments()[1 + index].type()->cast()); +} + //int8_t THSJIT_getScalarFromDimensionedTensorType(const JITDimensionedTensorType type) //{ // return (int8_t)(*type)->scalarType(); @@ -159,10 +284,10 @@ void THSJIT_Function_dispose(const JITFunction function) // // return make_sharable_string(device_type); //} -// -// -//void THSJIT_typeDispose(const JITType type) -//{ -// delete type; -//} \ No newline at end of file + + +void THSJIT_typeDispose(const JITType type) +{ + delete type; +} \ No newline at end of file diff --git a/src/Native/LibTorchSharp/THSJIT.h b/src/Native/LibTorchSharp/THSJIT.h index 0f1926f41..a1331cc73 100644 --- a/src/Native/LibTorchSharp/THSJIT.h +++ b/src/Native/LibTorchSharp/THSJIT.h @@ -7,38 +7,66 @@ #include "Utils.h" -//// Copied from libtorch to share the type as an int8_t. -//enum TypeKind : int8_t { -//#define DEFINE_TYPE(T) T, -// C10_FORALL_TYPES(DEFINE_TYPE) -//#undef DEFINE_TYPE -//}; -// -//// API. +// Copied from libtorch to share the type as an int8_t. +enum TypeKind : int8_t { +#define DEFINE_TYPE(T) T, + C10_FORALL_TYPES(DEFINE_TYPE) +#undef DEFINE_TYPE +}; + +// API. EXPORT_API(JITModule) THSJIT_load(const char* filename); +EXPORT_API(void) THSJIT_save(JITModule module, const char* filename); -EXPORT_API(void) THSJIT_Module_modules(const JITModule module, JITModule* (*allocator)(size_t length)); +EXPORT_API(void) THSJIT_Module_dispose(const JITModule module); + +EXPORT_API(int) THSJIT_Module_num_inputs(const JITModule method); +EXPORT_API(int) THSJIT_Module_num_outputs(const JITModule method); + +EXPORT_API(Tensor) THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length); + +EXPORT_API(int) THSJIT_Module_is_training(JITModule module); +EXPORT_API(void) THSJIT_Module_train(JITModule module, bool on); +EXPORT_API(void) THSJIT_Module_eval(JITModule module); +EXPORT_API(void) THSJIT_Module_to_device(JITModule module, int64_t device, int64_t index); +EXPORT_API(void) THSJIT_Module_to_dtype(JITModule module, int8_t dtype); + +EXPORT_API(JITType) THSJIT_Module_getInputType(JITModule module, int8_t dtype); + +EXPORT_API(int8_t) THSJIT_Type_kind(JITType handle); +EXPORT_API(void*) THSJIT_Type_cast(const JITType type); + +EXPORT_API(int8_t) THSJIT_TensorType_dtype(const JITTensorType type); +EXPORT_API(void) THSJIT_TensorType_sizes(const JITTensorType type, int64_t* (*allocator)(int64_t length)); + +EXPORT_API(void) THSJIT_Type_dispose(const JITType type); +EXPORT_API(void) THSJIT_TensorType_dispose(const JITTensorType type); + +EXPORT_API(void) THSJIT_Module_modules(const JITModule module, JITModule* (*allocator)(size_t length)); EXPORT_API(void) THSJIT_Module_named_modules(const JITModule module, JITModule* (*allocator)(size_t length), const char** (*allocator2)(size_t length)); +EXPORT_API(void) THSJIT_Module_named_children(const JITModule module, + JITModule* (*allocator)(size_t length), + const char** (*allocator2)(size_t length)); + EXPORT_API(JITMethod) THSJIT_Module_get_method(const JITModule module, const char* name); EXPORT_API(void) THSJIT_Module_parameters(const JITModule module, Tensor* (*allocator)(size_t length)); - EXPORT_API(void) THSJIT_Module_named_parameters(const JITModule module, Tensor* (*allocator)(size_t length), const char** (*allocator2)(size_t length)); -EXPORT_API(Tensor) THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length); - -EXPORT_API(void) THSJIT_Module_dispose(const JITModule module); - -EXPORT_API(const char*) THSJIT_Method_name(const JITMethod method); +EXPORT_API(void) THSJIT_Module_named_buffers(const JITModule module, + Tensor* (*allocator)(size_t length), + const char** (*allocator2)(size_t length)); EXPORT_API(int) THSJIT_Method_num_inputs(const JITMethod method); EXPORT_API(void) THSJIT_Method_dispose(const JITMethod method); + +EXPORT_API(const char*) THSJIT_Method_name(const JITMethod method); diff --git a/src/Native/LibTorchSharp/THSLinearAlgebra.cpp b/src/Native/LibTorchSharp/THSLinearAlgebra.cpp index d9539033b..861cbedcb 100644 --- a/src/Native/LibTorchSharp/THSLinearAlgebra.cpp +++ b/src/Native/LibTorchSharp/THSLinearAlgebra.cpp @@ -293,6 +293,11 @@ Tensor THSTensor_diag(const Tensor tensor, const int64_t diagonal) CATCH_TENSOR(tensor->diag(diagonal)); } +Tensor THSTensor_trace(const Tensor tensor) +{ + CATCH_TENSOR(tensor->trace()); +} + Tensor THSTensor_diagflat(const Tensor tensor, const int64_t offset) { CATCH_TENSOR(tensor->diagflat(offset)); diff --git a/src/Native/LibTorchSharp/THSModule.cpp b/src/Native/LibTorchSharp/THSModule.cpp index a0dc52ec6..7f81ec25c 100644 --- a/src/Native/LibTorchSharp/THSModule.cpp +++ b/src/Native/LibTorchSharp/THSModule.cpp @@ -10,14 +10,9 @@ int THSNN_Module_is_training(NNModule module) return (*module)->is_training(); } -void THSNN_Module_train(NNModule module) +void THSNN_Module_train(NNModule module, bool on) { - (*module)->train(); -} - -void THSNN_Module_eval(NNModule module) -{ - (*module)->eval(); + (*module)->train(on); } const char* THSNN_Module_name(const NNModule module) diff --git a/src/Native/LibTorchSharp/THSNN.h b/src/Native/LibTorchSharp/THSNN.h index 1932c1a9a..269f5cfd5 100644 --- a/src/Native/LibTorchSharp/THSNN.h +++ b/src/Native/LibTorchSharp/THSNN.h @@ -17,8 +17,7 @@ EXPORT_API(void) THSNN_Module_get_named_children(const NNModule module, N EXPORT_API(void) THSNN_Module_get_named_modules(const NNModule module, NNModule* (*allocator1)(size_t length), const char** (*allocator2)(size_t length)); EXPORT_API(void) THSNN_Module_get_parameters(const NNModule module, Tensor* (*allocator1)(size_t length), bool recurse); EXPORT_API(int) THSNN_Module_is_training(NNModule module); -EXPORT_API(void) THSNN_Module_train(NNModule module); -EXPORT_API(void) THSNN_Module_eval(NNModule module); +EXPORT_API(void) THSNN_Module_train(NNModule module, bool on); EXPORT_API(long) THSNN_Module_children_size(const NNModule module); EXPORT_API(NNModule) THSNN_Module_child(const NNModule module, const int index); EXPORT_API(const char*) THSNN_Module_name(const NNModule module); diff --git a/src/Native/LibTorchSharp/THSTensor.h b/src/Native/LibTorchSharp/THSTensor.h index 7b034fac2..1e00105db 100644 --- a/src/Native/LibTorchSharp/THSTensor.h +++ b/src/Native/LibTorchSharp/THSTensor.h @@ -358,6 +358,8 @@ EXPORT_API(int) THSTensor_device_index(const Tensor tensor); EXPORT_API(Tensor) THSTensor_diag(const Tensor tensor, const int64_t diagonal); +EXPORT_API(Tensor) THSTensor_trace(const Tensor tensor); + EXPORT_API(Tensor) THSTensor_diagflat(const Tensor tensor, const int64_t offset); EXPORT_API(Tensor) THSTensor_diagonal(const Tensor tensor, const int64_t offset, const int64_t dim1, const int64_t dim2); diff --git a/src/Native/LibTorchSharp/Utils.h b/src/Native/LibTorchSharp/Utils.h index b399f4640..bbad85e41 100644 --- a/src/Native/LibTorchSharp/Utils.h +++ b/src/Native/LibTorchSharp/Utils.h @@ -20,6 +20,8 @@ typedef std::shared_ptr * JITModule; typedef std::shared_ptr* JITMethod; typedef std::shared_ptr * JITFunction; typedef std::shared_ptr * JITType; +typedef std::shared_ptr* JITTensorType; + //typedef std::shared_ptr* JITDimensionedTensorType; #define THS_API TH_API diff --git a/src/TorchSharp/JIT/Module.cs b/src/TorchSharp/JIT/Module.cs deleted file mode 100644 index d9e63b46f..000000000 --- a/src/TorchSharp/JIT/Module.cs +++ /dev/null @@ -1,154 +0,0 @@ -//// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. -//using System; -//using System.Linq; -//using System.Runtime.InteropServices; -//using static TorchSharp.torch; - -//namespace TorchSharp.JIT -//{ -// public class Module : IDisposable -// { -// /// -// /// Class wrapping PyTorch's module object reference. -// /// -// internal sealed class HType : SafeHandle -// { -// public HType(IntPtr preexistingHandle, bool ownsHandle) : base(IntPtr.Zero, ownsHandle) -// { -// SetHandle(preexistingHandle); -// } - -// public override bool IsInvalid => handle == IntPtr.Zero; - -// // This is just for marshalling -// internal HType() : base(IntPtr.Zero, true) -// { -// } - -// [DllImport("LibTorchSharp")] -// private static extern void THSJIT_moduleDispose(HType handle); - -// protected override bool ReleaseHandle() -// { -// THSJIT_moduleDispose(this); -// return true; -// } - -// protected override void Dispose(bool disposing) -// { -// if (disposing) -// { -// ReleaseHandle(); -// } -// } -// } - -// internal HType handle; - -// internal Module(IntPtr handle) -// { -// this.handle = new HType(handle, true); -// } - -// ~Module() -// { -// Dispose(false); -// } - -// /// -// /// Releases the storage. -// /// -// public void Dispose() -// { -// Dispose(true); -// GC.SuppressFinalize(this); -// } - -// /// -// /// Implements the .NET Dispose pattern. -// /// -// protected void Dispose(bool disposing) -// { -// if (disposing) -// { -// handle.Dispose(); -// handle.SetHandleAsInvalid(); -// } -// } - -// [DllImport("LibTorchSharp")] -// private static extern IntPtr THSJIT_loadModule(string filename); - -// static public Module Load(string filename) -// { -// return new Module(THSJIT_loadModule(filename)); -// } - -// [DllImport("LibTorchSharp")] -// private static extern long THSJIT_getNumModules(HType module); - -// [DllImport("LibTorchSharp")] -// private static extern int THSJIT_getNumberOfInputs(HType module); - -// public int GetNumberOfInputs() -// { -// return THSJIT_getNumberOfInputs(handle); -// } - -// [DllImport("LibTorchSharp")] -// private static extern int THSJIT_getNumberOfOutputs(HType module); - -// public int GetNumberOfOutputs() -// { -// return THSJIT_getNumberOfOutputs(handle); -// } - -// [DllImport("LibTorchSharp")] -// private static extern IntPtr THSJIT_getInputType(HType module, int index); - -// public Type GetInputType(int index) -// { -// var type = new Type(THSJIT_getInputType(handle, index)); - -// return GetType(type); -// } - -// [DllImport("LibTorchSharp")] -// private static extern IntPtr THSJIT_getOutputType(HType module, int index); - -// public Type GetOutputType(int index) -// { -// var type = new Type(THSJIT_getOutputType(handle, index)); - -// return GetType(type); -// } - -// private Type GetType(Type type) -// { -// switch (type.Kind) -// { -// case Type.TypeKind.TensorType: -// var dynamic = type.AsDynamicType(); -// type.Dispose(); -// return dynamic; -// case Type.TypeKind.DimensionedTensorType: -// var tensor = type.AsTensorType(); -// type.Dispose(); -// return tensor; -// default: -// return type; -// } -// } - -// [DllImport("LibTorchSharp")] -// private static extern IntPtr THSJIT_forward(torch.nn.Module.HType module, IntPtr tensors, int length); - -// public Tensor forward(params Tensor[] tensors) -// { -// var parray = new PinnedArray(); -// IntPtr tensorRefs = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); - -// return new Tensor(THSJIT_forward(handle, tensorRefs, parray.Array.Length)); -// } -// } -//} diff --git a/src/TorchSharp/JIT/ScriptModule.cs b/src/TorchSharp/JIT/ScriptModule.cs new file mode 100644 index 000000000..e91aa153f --- /dev/null +++ b/src/TorchSharp/JIT/ScriptModule.cs @@ -0,0 +1,428 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Linq; +using System.Collections.Generic; +using System.Diagnostics; +using System.Reflection; +using System.Runtime.InteropServices; +using static TorchSharp.torch; + +namespace TorchSharp +{ + public static partial class torch + { + public static partial class jit + { + public class ScriptModule : torch.nn.Module + { + internal ScriptModule(IntPtr handle) : base(new HType(handle, true, THSJIT_Module_dispose), null) + { + } + + ~ScriptModule() + { + Dispose(false); + } + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_Module_dispose(HType handle); + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_Module_named_parameters(HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2); + + protected override (string name, TorchSharp.Modules.Parameter parameter)[] _named_parameters() + { + using var pa = new PinnedArray(); + using var sa = new PinnedArray(); + THSJIT_Module_named_parameters(handle, pa.CreateArray, sa.CreateArray); + CheckForErrors(); + var ptrArray = pa.Array; + var strArray = sa.Array; + + return ptrArray.Select((x, i) => (Marshal.PtrToStringAnsi(strArray[i]), new TorchSharp.Modules.Parameter(x))).ToArray(); + } + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_Module_named_buffers(HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2); + + protected override (string name, Tensor buffer)[] _named_buffers() + { + using var pa = new PinnedArray(); + using var sa = new PinnedArray(); + THSJIT_Module_named_buffers(handle, pa.CreateArray, sa.CreateArray); + CheckForErrors(); + var ptrArray = pa.Array; + var strArray = sa.Array; + + return ptrArray.Select((x, i) => (Marshal.PtrToStringAnsi(strArray[i]), new Tensor(x))).ToArray(); + } + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_Module_named_modules(HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2); + + /// + /// Returns an enumerable of all modules in the network, yielding both the name of the module as well as the module itself. + /// + /// (string, Module) – Tuple of name and module + public override IEnumerable<(string name, nn.Module module)> named_modules() + { + using var pa = new PinnedArray(); + using var sa = new PinnedArray(); + THSJIT_Module_named_modules(handle, pa.CreateArray, sa.CreateArray); + CheckForErrors(); + var ptrArray = pa.Array; + var strArray = sa.Array; + + return ptrArray.Select((x, i) => (Marshal.PtrToStringAnsi(strArray[i]), new ScriptModule(x) as nn.Module)).Where(m => !String.IsNullOrEmpty(m.Item1)); + } + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_Module_named_children(HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2); + + /// + /// Returns an enumerable of immediate children modules, yielding both the name of the module as well as the module itself. + /// + /// (string, Module) – Tuple containing a name and child module + public override IEnumerable<(string name, nn.Module module)> named_children() + { + using var pa = new PinnedArray(); + using var sa = new PinnedArray(); + THSJIT_Module_named_children(handle, pa.CreateArray, sa.CreateArray); + CheckForErrors(); + var ptrArray = pa.Array; + var strArray = sa.Array; + + return ptrArray.Select((x, i) => (Marshal.PtrToStringAnsi(strArray[i]), new ScriptModule(x) as nn.Module)); + } + + [DllImport("LibTorchSharp")] + private static extern long THSJIT_getNumModules(HType module); + + [DllImport("LibTorchSharp")] + private static extern int THSJIT_Module_num_inputs(HType module); + + public int GetNumberOfInputs() + { + return THSJIT_Module_num_inputs(handle); + } + + [DllImport("LibTorchSharp")] + private static extern int THSJIT_Module_num_outputs(HType module); + + public int GetNumberOfOutputs() + { + return THSJIT_Module_num_outputs(handle); + } + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_Module_train(HType module, bool on); + + /// + /// Sets the module in evaluation mode. + /// + /// + /// Any script module that was created using torch.jit.trace() will be unaffected. The behavior of such + /// modules will be captured when traced. + /// + public override void train(bool on = true) + { + THSJIT_Module_train(handle, on); + CheckForErrors(); + } + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_Module_eval(HType module); + + /// + /// Sets the module in evaluation mode. + /// + /// + /// Any script module that was created using torch.jit.trace() will be unaffected. The behavior of such + /// modules will be captured when traced. + /// + public override void eval() + { + THSJIT_Module_eval(handle); + CheckForErrors(); + } + + [DllImport("LibTorchSharp")] + private static extern bool THSJIT_Module_is_training(HType module); + + /// + /// Check whether the module is set to training or evaluation mode. + /// + public override bool training { + get { + var res = THSJIT_Module_is_training(handle); + CheckForErrors(); + return res; + } + } + + [DllImport("LibTorchSharp")] + static extern void THSJIT_Module_to_device(HType module, long deviceType, long deviceIndex); + + [DllImport("LibTorchSharp")] + static extern void THSJIT_Module_to_dtype(HType module, sbyte dtype); + + /// + /// Moves the parameters and buffers. + /// + /// The device type, e.g. 'CPU' or 'CUDA'. + /// The optional device index. + /// + public override nn.Module to(DeviceType deviceType, int deviceIndex = -1) + { + if (deviceType != DeviceType.CUDA) deviceIndex = -1; + + if (deviceType == DeviceType.CUDA && !torch.cuda.is_available()) throw new InvalidOperationException("CUDA is not available."); + + if (deviceType != _deviceType || deviceIndex != _deviceIndex) { + + InitializeDeviceType(deviceType); + THSJIT_Module_to_device(handle, (int)deviceType, deviceIndex); + CheckForErrors(); + + foreach (var (_, sm) in named_children()) sm.to(deviceType, deviceIndex); + + foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) { + + var fieldName = field.Name; + var value = field.GetValue(this); + + switch (value) { + // This test must come before the Tensor test + case Modules.Parameter param when deviceType == param.device_type && deviceIndex == param.device_index: + continue; + + case Modules.Parameter param: { + var t = param.to(deviceType, deviceIndex); + t.retain_grad(); + var p = new Modules.Parameter(t, param.requires_grad); + field.SetValue(this, p); + ConditionallyRegisterParameter(fieldName, p); + break; + } + + case Tensor tensor when (deviceType != tensor.device_type || deviceIndex != tensor.device_index): { + var t = tensor.to(deviceType, deviceIndex); + field.SetValue(this, t); + ConditionallyRegisterBuffer(fieldName, t); + break; + } + } + } + + _deviceType = deviceType; + _deviceIndex = deviceIndex; + } + + Debug.Assert(_deviceType == DeviceType.CUDA || _deviceIndex == -1); + + return this; + } + + private DeviceType _deviceType = DeviceType.CPU; + private int _deviceIndex = -1; + + /// + /// Convert the parameters and buffers. + /// + /// + public override nn.Module to(ScalarType dtype) + { + THSJIT_Module_to_dtype(handle, (sbyte)dtype); + CheckForErrors(); + + foreach (var (_, sm) in named_children()) sm.to(dtype); + foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) { + + var fieldName = field.Name; + var value = field.GetValue(this); + + switch (value) { + // This test must come before the Tensor test + case Modules.Parameter param when dtype == param.dtype: + continue; + + case Modules.Parameter param: { + var t = param.to(dtype); + t.retain_grad(); + var p = new Modules.Parameter(t, param.requires_grad); + field.SetValue(this, p); + ConditionallyRegisterParameter(fieldName, p); + break; + } + + case Tensor tensor when dtype == tensor.dtype: + continue; + + case Tensor tensor: { + var t = tensor.to(dtype); + field.SetValue(this, t); + ConditionallyRegisterBuffer(fieldName, t); + break; + } + } + } + + return this; + } + +#if false // These functions "work," but the native code doesn't seem to find any interesting information. + + [DllImport("LibTorchSharp")] + private static extern IntPtr THSJIT_Module_getInputType(HType module, int index); + + public Type GetInputType(int index) + { + var type = new Type(THSJIT_Module_getInputType(handle, index), Type.TypeKind.AnyType); + + return GetType(type); + } + + [DllImport("LibTorchSharp")] + private static extern IntPtr THSJIT_getOutputType(HType module, int index); + + public Type GetOutputType(int index) + { + var type = new Type(THSJIT_getOutputType(handle, index), Type.TypeKind.AnyType); + + return GetType(type); + } + + private Type GetType(Type type) + { + switch (type.Kind) { + case Type.TypeKind.TensorType: + var dynamic = type.AsTensorType(); + type.Dispose(); + return dynamic; + //case Type.TypeKind.DimensionedTensorType: + // var tensor = type.AsTensorType(); + // type.Dispose(); + // return tensor; + default: + return type; + } + } +#endif + + [DllImport("LibTorchSharp")] + private static extern IntPtr THSJIT_Module_forward(HType module, IntPtr tensors, int length); + + /// + /// Invoke the 'forward' function of the script with one tensor as its argument + /// + /// The input tensor + /// + public unsafe override Tensor forward(Tensor tensor) + { + using (var parray = new PinnedArray()) { + IntPtr tensorRefs = parray.CreateArray(new[] { tensor.Handle }); + var res = THSJIT_Module_forward(handle, tensorRefs, 1); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + } + + /// + /// Invoke the 'forward' function of the script with two tensors as its argument + /// + /// The first input tensor + /// The second input tensor + /// + public unsafe override Tensor forward(Tensor x, Tensor y) + { + using (var parray = new PinnedArray()) { + IntPtr tensorRefs = parray.CreateArray(new[] { x.Handle, y.Handle }); + var res = THSJIT_Module_forward(handle, tensorRefs, 2); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + } + + /// + /// Invoke the 'forward' function of the script with three tensors as its argument + /// + /// The first input tensor + /// The second input tensor + /// The third input tensor + /// + public unsafe override Tensor forward(Tensor x, Tensor y, Tensor z) + { + using (var parray = new PinnedArray()) { + IntPtr tensorRefs = parray.CreateArray(new[] { x.Handle, y.Handle, z.Handle }); + var res = THSJIT_Module_forward(handle, tensorRefs, 3); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + } + + /// + /// Invoke the 'forward' function of the script with four or more tensors as its argument + /// + /// The first input tensor + /// The second input tensor + /// The third input tensor + /// The remaining tensors. + /// + public unsafe Tensor forward(Tensor x, Tensor y, Tensor z, params Tensor[] tensors) + { + using (var parray = new PinnedArray()) { + IntPtr tensorRefs = parray.CreateArray(new[] { x, y, z }.Concat(tensors).Select(t => t.Handle).ToArray()); + var res = THSJIT_Module_forward(handle, tensorRefs, parray.Array.Length); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + } + } + + [DllImport("LibTorchSharp")] + private static extern IntPtr THSJIT_load(string filename); + + /// + /// Load a ScriptModule or ScriptFunction previously saved with torch.jit.save + /// + /// + /// A ScriptModule instance, whether the script originated as a module or function. + /// + /// All previously saved modules, no matter their device, are first loaded onto CPU, and then are moved to the devices they were saved from.If this fails (e.g.because the run time system doesn’t have certain devices), an exception is raised. + /// + /// Raised if the file is not found. + public static ScriptModule load(string filename) + { + if (!System.IO.File.Exists(filename)) + throw new System.IO.FileNotFoundException(filename); + + var result = THSJIT_load(filename); + if (result == IntPtr.Zero) + CheckForErrors(); + return new ScriptModule(result); + } + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_save(nn.Module.HType handle, string filename); + + /// + /// Save an offline version of a previously loaded script module. + /// + /// The saved module serializes all of the methods, submodules, parameters, and attributes of this module. + /// It can be loaded into the C++ API using torch::jit::load(filename) or into the .NET API with torch.jit.load(). + /// + /// + /// + public static void save(ScriptModule module, string filename) + { + THSJIT_save(module.handle, filename); + CheckForErrors(); + } + + } + } +} diff --git a/src/TorchSharp/JIT/Type/DynamicType .cs b/src/TorchSharp/JIT/Type/DynamicType .cs index 440d3532c..e65a649f3 100644 --- a/src/TorchSharp/JIT/Type/DynamicType .cs +++ b/src/TorchSharp/JIT/Type/DynamicType .cs @@ -1,13 +1,20 @@ -//// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. -//using System; +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; -//namespace TorchSharp.JIT -//{ -// public sealed class DynamicType : Type -// { -// internal DynamicType(IntPtr handle) : base(handle) -// { -// this.handle = new HType(handle, true); -// } -// } -//} +namespace TorchSharp +{ + public static partial class torch + { + + public static partial class jit + { + public sealed class DynamicType : Type + { + internal DynamicType(IntPtr handle) : base(handle, Type.TypeKind.AnyType) + { + this.handle = new HType(handle, true, Type.TypeKind.AnyType); + } + } + } + } +} \ No newline at end of file diff --git a/src/TorchSharp/JIT/Type/TensorType.cs b/src/TorchSharp/JIT/Type/TensorType.cs index 37d476c3e..296a680a7 100644 --- a/src/TorchSharp/JIT/Type/TensorType.cs +++ b/src/TorchSharp/JIT/Type/TensorType.cs @@ -1,45 +1,71 @@ -//// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. -//using System; -//using System.Runtime.InteropServices; - -//namespace TorchSharp.JIT -//{ -// public sealed class TensorType : Type -// { -// internal TensorType(IntPtr handle) : base(handle) -// { -// this.handle = new HType(handle, true); -// } - -// internal TensorType(Type type) : base() -// { -// handle = type.handle; -// type.handle = new HType(IntPtr.Zero, true); -// type.Dispose(); -// } - -// [DllImport("LibTorchSharp")] -// private static extern short THSJIT_getScalarFromDimensionedTensorType(HType handle); - -// public Tensor.ScalarType GetScalarType() -// { -// return (Tensor.ScalarType)THSJIT_getScalarFromDimensionedTensorType(handle); -// } - -// [DllImport("LibTorchSharp")] -// private static extern int THSJIT_getDimensionedTensorTypeDimensions(HType handle); - -// public int GetDimensions() -// { -// return THSJIT_getDimensionedTensorTypeDimensions(handle); -// } - -// [DllImport("LibTorchSharp")] -// private static extern string THSJIT_getDimensionedTensorDevice(HType handle); - -// public string GetDevice() -// { -// return THSJIT_getDimensionedTensorDevice(handle); -// } -// } -//} +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Runtime.InteropServices; + +namespace TorchSharp +{ + public static partial class torch + { + public static partial class jit + { + public sealed class TensorType : Type + { + internal TensorType(IntPtr handle) : base(handle, TypeKind.TensorType) + { + this.handle = new HType(handle, true, TypeKind.TensorType); + } + + internal TensorType(Type type) : base() + { + handle = type.handle; + type.handle = new HType(IntPtr.Zero, true, TypeKind.TensorType); + type.Dispose(); + } + + [DllImport("LibTorchSharp")] + private static extern sbyte THSJIT_TensorType_dtype(HType handle); + + public torch.ScalarType GetScalarType() + { + return (torch.ScalarType)THSJIT_TensorType_dtype(handle); + } + + + [DllImport("LibTorchSharp")] + static extern long THSJIT_TensorType_sizes(HType handle, AllocatePinnedArray allocator); + + /// + /// Retrieves the sizes of all dimensions of the tensor. + /// + public long[] size() + { + long[] ptrArray; + + using (var pa = new PinnedArray()) { + THSJIT_TensorType_sizes(handle, pa.CreateArray); + torch.CheckForErrors(); + ptrArray = pa.Array; + } + + return ptrArray; + } + + [DllImport("LibTorchSharp")] + private static extern int THSJIT_getDimensionedTensorTypeDimensions(HType handle); + + public int GetDimensions() + { + return THSJIT_getDimensionedTensorTypeDimensions(handle); + } + + [DllImport("LibTorchSharp")] + private static extern string THSJIT_getDimensionedTensorDevice(HType handle); + + public string GetDevice() + { + return THSJIT_getDimensionedTensorDevice(handle); + } + } + } + } +} diff --git a/src/TorchSharp/JIT/Type/Type.cs b/src/TorchSharp/JIT/Type/Type.cs index e06d40ccc..33725a156 100644 --- a/src/TorchSharp/JIT/Type/Type.cs +++ b/src/TorchSharp/JIT/Type/Type.cs @@ -1,108 +1,119 @@ -//// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. -//using System; -//using System.Runtime.InteropServices; - -//namespace TorchSharp.JIT -//{ -// public class Type : IDisposable -// { -// /// -// /// Class wrapping PyTorch's type object reference. -// /// -// internal sealed class HType : SafeHandle -// { -// public HType(IntPtr preexistingHandle, bool ownsHandle) : base(IntPtr.Zero, ownsHandle) -// { -// SetHandle(preexistingHandle); -// } - -// public override bool IsInvalid => handle == IntPtr.Zero; - -// // This is just for marshalling -// internal HType() : base(IntPtr.Zero, true) -// { -// } - -// [DllImport("LibTorchSharp")] -// private static extern void THSJIT_typeDispose(HType handle); - -// protected override bool ReleaseHandle() -// { -// THSJIT_typeDispose(this); -// return true; -// } - -// protected override void Dispose(bool disposing) -// { -// if (disposing) -// { -// ReleaseHandle(); -// } -// } -// } - -// internal HType handle; - -// internal Type(IntPtr handle) -// { -// this.handle = new HType(handle, true); -// } - -// protected Type() -// { -// } - -// ~Type() -// { -// Dispose(false); -// } - -// /// -// /// Releases the storage. -// /// -// public void Dispose() -// { -// Dispose(true); -// GC.SuppressFinalize(this); -// } - -// /// -// /// Implements the .NET Dispose pattern. -// /// -// protected void Dispose(bool disposing) -// { -// if (disposing) -// { -// handle.Dispose(); -// handle.SetHandleAsInvalid(); -// } -// } - -// [DllImport("LibTorchSharp")] -// private static extern sbyte THSJIT_typeKind(HType handle); - -// internal TypeKind Kind -// { -// get { return (TypeKind)THSJIT_typeKind(handle); } -// } - -// [DllImport("LibTorchSharp")] -// private static extern IntPtr THSJIT_typeCast(HType module); - -// internal TensorType AsTensorType() -// { -// return new TensorType(THSJIT_typeCast(handle)); -// } - -// internal DynamicType AsDynamicType() -// { -// return new DynamicType(THSJIT_typeCast(handle)); -// } - -// internal enum TypeKind : sbyte -// { -// TensorType = 0, -// DimensionedTensorType = 1 -// } -// } -//} +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Collections; +using System.Runtime.InteropServices; + +namespace TorchSharp +{ + public static partial class torch + { + public static partial class jit + { + + public class Type : IDisposable + { + /// + /// Class wrapping PyTorch's type object reference. + /// + internal sealed class HType : SafeHandle + { + public HType(IntPtr preexistingHandle, bool ownsHandle, TypeKind kind) : base(IntPtr.Zero, ownsHandle) + { + SetHandle(preexistingHandle); + this.kind = kind; + } + + public override bool IsInvalid => handle == IntPtr.Zero; + + // This is just for marshalling + internal HType() : base(IntPtr.Zero, true) + { + } + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_Type_dispose(HType handle); + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_TensorType_dispose(HType handle); + + protected override bool ReleaseHandle() + { + switch (kind) { + case TypeKind.TensorType: + THSJIT_TensorType_dispose(this); + break; + default: + THSJIT_Type_dispose(this); + break; + } + return true; + } + + private TypeKind kind; + } + + internal HType handle; + + internal Type(IntPtr handle, TypeKind kind) + { + this.handle = new HType(handle, true, kind); + } + + protected Type() + { + } + + ~Type() + { + Dispose(false); + } + + /// + /// Releases the storage. + /// + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Implements the .NET Dispose pattern. + /// + protected void Dispose(bool disposing) + { + if (disposing) { + handle.Dispose(); + handle.SetHandleAsInvalid(); + } + } + + [DllImport("LibTorchSharp")] + private static extern sbyte THSJIT_Type_kind(HType handle); + + internal TypeKind Kind { + get { return (TypeKind)THSJIT_Type_kind(handle); } + } + + [DllImport("LibTorchSharp")] + private static extern IntPtr THSJIT_Type_cast(HType module); + + internal TensorType AsTensorType() + { + return new TensorType(THSJIT_Type_cast(handle)); + } + + internal DynamicType AsDynamicType() + { + return new DynamicType(THSJIT_Type_cast(handle)); + } + + internal enum TypeKind : sbyte + { + AnyType = 0, + TensorType = 3, + } + } + } + } +} \ No newline at end of file diff --git a/src/TorchSharp/NN/EmbeddingBag.cs b/src/TorchSharp/NN/EmbeddingBag.cs index 122efd956..e6c0ff42d 100644 --- a/src/TorchSharp/NN/EmbeddingBag.cs +++ b/src/TorchSharp/NN/EmbeddingBag.cs @@ -32,7 +32,7 @@ internal EmbeddingBag(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHan /// If specified, per_sample_weights must have exactly the same shape as input and is treated as having the same offsets, if those are not None. /// Only supported for mode='sum'. /// - public Tensor forward(Tensor input, Tensor offsets, Tensor perSampleWeights) + public override Tensor forward(Tensor input, Tensor offsets, Tensor perSampleWeights) { if (!input.IsIntegral()) throw new ArgumentException("Embedding input must be an integral tensor."); if (!(offsets is null) && input.dtype != offsets.dtype) throw new ArgumentException("input and offsets must have the same element type."); diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs index 53ddc465b..9f3ca9521 100644 --- a/src/TorchSharp/NN/Module.cs +++ b/src/TorchSharp/NN/Module.cs @@ -31,9 +31,10 @@ public class Module : IDisposable /// internal sealed class HType : SafeHandle { - public HType(IntPtr preexistingHandle, bool ownsHandle) + public HType(IntPtr preexistingHandle, bool ownsHandle, Action dispose = null) : base(IntPtr.Zero, ownsHandle) { + _dispose = dispose??THSNN_Module_dispose; SetHandle(preexistingHandle); } @@ -49,7 +50,9 @@ internal HType() : base(IntPtr.Zero, true) protected override bool ReleaseHandle() { - if (!IsInvalid) THSNN_Module_dispose(this); + if (!IsInvalid) { + _dispose(this); + } SetHandle(IntPtr.Zero); return true; } @@ -60,6 +63,8 @@ protected override void Dispose(bool disposing) ReleaseHandle(); } } + + private Action _dispose; } internal HType handle; @@ -75,6 +80,19 @@ internal BoxedModule BoxedModule { } } + internal Module(HType handle, IntPtr? boxedHandle) + { + this.handle = handle; + boxedModule = boxedHandle.HasValue ? new BoxedModule(boxedHandle.Value) : null; + + foreach (var (parameterName, parameter) in _named_parameters()) { + register_parameter(parameterName, parameter); + } + foreach (var (bufferName, buffer) in _named_buffers()) { + register_buffer(bufferName, buffer); + } + } + internal Module(IntPtr handle, IntPtr? boxedHandle, bool ownsHandle = true) { this.handle = new HType(handle, ownsHandle); @@ -256,8 +274,8 @@ public virtual Module to(ScalarType dtype) /// public Module to(Tensor other) { - to(other.device_type, other.device_index); - return to(other.dtype); + to(other.dtype); + return to(other.device_type, other.device_index); } /// @@ -288,9 +306,12 @@ public virtual Module apply(Action fn) [DllImport("LibTorchSharp")] static extern IntPtr THSNN_Module_load([MarshalAs(UnmanagedType.LPStr)] string location); - public static Module Load(string location) + public static Module Load(string filename) { - var handle = THSNN_Module_load(location); + if (!System.IO.File.Exists(filename)) + throw new System.IO.FileNotFoundException(filename); + + var handle = THSNN_Module_load(filename); if (handle == IntPtr.Zero) { CheckForErrors(); } return new Module(handle, IntPtr.Zero); } @@ -304,29 +325,42 @@ public virtual void Save(string modelPath) => THSNN_Module_save(handle, modelPath); [DllImport("LibTorchSharp")] - private static extern void THSNN_Module_train(HType module); + private static extern void THSNN_Module_train(HType module, bool on); - public virtual void train() + /// + /// Sets the module in training mode. + /// + /// + /// This has any effect only on certain modules.See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g.Dropout, BatchNorm, etc. + /// + public virtual void train(bool train = true) { - THSNN_Module_train(handle); + THSNN_Module_train(handle, train); CheckForErrors(); - foreach (var (_, m) in named_children()) { m.train(); } + foreach (var (_, m) in named_children()) { m.train(train); } } [DllImport("LibTorchSharp")] private static extern void THSNN_Module_eval(HType module); + /// + /// Sets the module in evaluation mode. + /// + /// + /// This has any effect only on certain modules.See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g.Dropout, BatchNorm, etc. + /// public virtual void eval() { - THSNN_Module_eval(handle); - CheckForErrors(); - foreach (var (_, m) in named_children()) { m.eval(); } + train(false); } [DllImport("LibTorchSharp")] private static extern bool THSNN_Module_is_training(HType module); - public bool training { + /// + /// Check whether the module is set to training or evaluation mode. + /// + public virtual bool training { get { var res = THSNN_Module_is_training(handle); CheckForErrors(); @@ -464,7 +498,7 @@ public virtual (IList missing_keys, IList unexpected_keyes) load [DllImport("LibTorchSharp")] private static extern void THSNN_Module_get_named_parameters(HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2); - protected (string name, Parameter parameter)[] _named_parameters() + protected virtual (string name, Parameter parameter)[] _named_parameters() { using var pa = new PinnedArray(); using var sa = new PinnedArray(); @@ -479,7 +513,7 @@ public virtual (IList missing_keys, IList unexpected_keyes) load [DllImport("LibTorchSharp")] private static extern void THSNN_Module_get_named_buffers(HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2); - protected (string name, Tensor buffer)[] _named_buffers() + protected virtual (string name, Tensor buffer)[] _named_buffers() { using var pa = new PinnedArray(); using var sa = new PinnedArray(); @@ -727,6 +761,9 @@ public virtual Tensor forward(Tensor t) public virtual Tensor forward(Tensor x, Tensor y) => throw new NotImplementedException("forward(x,y)"); + public virtual Tensor forward(Tensor x, Tensor y, Tensor z) + => throw new NotImplementedException("forward(x,y,z)"); + /// /// Save the parameters and buffers of the module to a disk location. /// @@ -799,8 +836,11 @@ public static void save_state_dict(System.IO.BinaryWriter writer, Dictionary - public Module load(string location, bool strict = true, IList skip = null) + public virtual Module load(string location, bool strict = true, IList skip = null) { + if (!System.IO.File.Exists(location)) + throw new System.IO.FileNotFoundException(location); + var dt = _deviceType; var di = _deviceIndex; diff --git a/src/TorchSharp/NN/Sequential.cs b/src/TorchSharp/NN/Sequential.cs index a1720c82a..096d204ef 100644 --- a/src/TorchSharp/NN/Sequential.cs +++ b/src/TorchSharp/NN/Sequential.cs @@ -130,11 +130,23 @@ protected override void Dispose(bool disposing) base.Dispose(disposing); } - public override void train() + /// + /// Sets the module in training mode. + /// + /// + /// This has any effect only on certain modules.See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g.Dropout, BatchNorm, etc. + /// + public override void train(bool on = true) { - foreach (var m in _modules) { m.train(); } + foreach (var m in _modules) { m.train(on); } } + /// + /// Sets the module in evaluation mode. + /// + /// + /// This has any effect only on certain modules.See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g.Dropout, BatchNorm, etc. + /// public override void eval() { foreach (var m in _modules) { m.eval(); } diff --git a/src/TorchSharp/NN/TransformerEncoder.cs b/src/TorchSharp/NN/TransformerEncoder.cs index 46858a39a..84b1f05a7 100644 --- a/src/TorchSharp/NN/TransformerEncoder.cs +++ b/src/TorchSharp/NN/TransformerEncoder.cs @@ -29,7 +29,7 @@ internal TransformerEncoder(IntPtr handle, IntPtr boxedHandle) : base(handle, bo /// The additive mask for the src sequence (optional). /// The ByteTensor mask for src keys per batch (optional). /// - public Tensor forward(Tensor src, Tensor src_mask, Tensor src_key_padding_mask) + public override Tensor forward(Tensor src, Tensor src_mask, Tensor src_key_padding_mask) { var res = THSNN_TransformerEncoder_forward(handle, src.Handle, diff --git a/src/TorchSharp/NN/TransformerEncoderLayer.cs b/src/TorchSharp/NN/TransformerEncoderLayer.cs index 4567599da..ae40d83d4 100644 --- a/src/TorchSharp/NN/TransformerEncoderLayer.cs +++ b/src/TorchSharp/NN/TransformerEncoderLayer.cs @@ -23,7 +23,7 @@ internal TransformerEncoderLayer(IntPtr handle, IntPtr boxedHandle) : base(handl /// The additive mask for the src sequence (optional). /// The ByteTensor mask for src keys per batch (optional). /// - public Tensor forward(Tensor src, Tensor src_mask, Tensor src_key_padding_mask) + public override Tensor forward(Tensor src, Tensor src_mask, Tensor src_key_padding_mask) { var res = THSNN_TransformerEncoderLayer_forward(handle, src.Handle, diff --git a/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs b/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs index 002a1471d..e5b6be8cf 100644 --- a/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs +++ b/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs @@ -265,6 +265,15 @@ public static Tensor det(Tensor input) return torch.linalg.det(input); } + public static Tensor diag(Tensor input, long dimension = 0) => input.diag(dimension); + + /// + /// Returns the sum of the elements of the diagonal of the input 2-D matrix. + /// + /// The input tensor + /// + public static Tensor trace(Tensor input) => input.trace(); + /// /// Returns a partial view of input with the its diagonal elements with respect to dim1 and dim2 appended as a dimension at the end of the shape. /// The argument offset controls which diagonal to consider: diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs index 63e4d9c54..420dbaae0 100644 --- a/src/TorchSharp/Tensor/Tensor.cs +++ b/src/TorchSharp/Tensor/Tensor.cs @@ -2990,6 +2990,22 @@ public Tensor diag(long dimension = 0) return new Tensor(res); } + [DllImport("LibTorchSharp")] + static extern IntPtr THSTensor_trace(IntPtr tensor); + + /// + /// Returns the sum of the elements of the diagonal of the input 2-D matrix. + /// + /// + public Tensor trace() + { + if (ndim != 2) + throw new ArgumentException($"Expected a matrix, but got tensor with ndim == {ndim}"); + var res = THSTensor_trace(Handle); + if (res == IntPtr.Zero) { torch.CheckForErrors(); } + return new Tensor(res); + } + [DllImport("LibTorchSharp")] static extern IntPtr THSTensor_diagflat(IntPtr tensor, long offset); diff --git a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj index 5acc6b627..eb145cbb2 100644 --- a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj +++ b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj @@ -42,6 +42,24 @@ + + + PreserveNewest + + + PreserveNewest + + + + + + PreserveNewest + + + PreserveNewest + + + diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs index 8b29ea011..8090754a2 100644 --- a/test/TorchSharpTest/NN.cs +++ b/test/TorchSharpTest/NN.cs @@ -1096,6 +1096,8 @@ public void TestGradConditional() modT.train(); + Assert.True(modT.training); + var eval = modT.forward(x); var loss = mse_loss(Reduction.Sum); var output = loss(eval, y); @@ -3154,7 +3156,10 @@ public void TestMultiheadAttention() using (var K = torch.tensor(k_data, src_seq_len, batch_size, kembed_dim)) using (var V = torch.tensor(v_data, src_seq_len, batch_size, vembed_dim)) using (var Attn = torch.tensor(attn_data, batch_size, src_seq_len, src_seq_len)) { + mha.eval(); + Assert.False(mha.training); + var (att_out, att_wts) = mha.forward(Q, K, V); var t = att_wts.allclose(Attn, rtol: 0.5, atol: 0.5); Assert.True(t); diff --git a/test/TorchSharpTest/TestLoadSave.cs b/test/TorchSharpTest/TestLoadSave.cs index 6242e4a5d..f615240ac 100644 --- a/test/TorchSharpTest/TestLoadSave.cs +++ b/test/TorchSharpTest/TestLoadSave.cs @@ -60,7 +60,7 @@ public void TestSaveLoadLinear2() public void TestSaveLoadLinear3() { if (File.Exists(".model.ts")) File.Delete(".model.ts"); - var linear = Linear(100, 10, true); + using var linear = Linear(100, 10, true); var params0 = linear.parameters(); linear.save(".model.ts"); @@ -76,14 +76,126 @@ public void TestSaveLoadLinear3() } + + [Fact] + public void TestLoadJIT_Func() + { + // One linear layer followed by ReLU. + using var m = torch.jit.load(@"func.script.dat"); + + var sms = m.named_modules().ToArray(); + Assert.Empty(sms); + + var kids = m.named_children().ToArray(); + Assert.Empty(kids); + + var t = m.forward(torch.ones(10), torch.ones(10)); + + Assert.Equal(new long[] { 10 }, t.shape); + Assert.Equal(torch.float32, t.dtype); + Assert.True(torch.tensor(new float[] { 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 }).allclose(t)); + } + + [Fact] + public void TestLoadJIT_1() + { + // One linear layer followed by ReLU. + using var m = torch.jit.load(@"linrelu.script.dat"); + var t = m.forward(torch.ones(10)); + + Assert.Equal(new long[] { 6 }, t.shape); + Assert.Equal(torch.float32, t.dtype); + Assert.True(torch.tensor(new float[] { 0.313458264f, 0, 0.9996568f, 0, 0, 0 }).allclose(t)); + } + + [Fact] + public void TestSaveJIT() + { + if (File.Exists(".model.ts")) File.Delete(".model.ts"); + + // One linear layer followed by ReLU. + using var m1 = torch.jit.load(@"linrelu.script.dat"); + + torch.jit.save(m1, ".model.ts"); + using var m2 = torch.jit.load(@".model.ts"); + + var t = m2.forward(torch.ones(10)); + + Assert.Equal(new long[] { 6 }, t.shape); + Assert.Equal(torch.float32, t.dtype); + Assert.True(torch.tensor(new float[] { 0.313458264f, 0, 0.9996568f, 0, 0, 0 }).allclose(t)); + + if (File.Exists(".model.ts")) File.Delete(".model.ts"); + } + + [Fact] + public void TestLoadJIT_2() + { + // One linear layer followed by ReLU. + using var m = torch.jit.load(@"scripted.script.dat"); + var t = m.forward(torch.ones(6)); + + Assert.Equal(new long[] { 6 }, t.shape); + Assert.Equal(torch.float32, t.dtype); + Assert.True(torch.tensor(new float[] { 1.554085f, 1.01024628f, -1.35086036f, -1.84021854f, 0.0127189457f, 0.5994258f }).allclose(t)); + } + + [Fact] + public void TestLoadJIT_3() + { + // Two linear layers, nested Sequential, ReLU in between. + using var m = torch.jit.load(@"l1000_100_10.script.dat"); + + var sms = m.named_modules().ToArray(); + Assert.Equal(4, sms.Length); + + var kids = m.named_children().ToArray(); + Assert.Equal(2, kids.Length); + + var t = m.forward(torch.ones(1000)); + + Assert.Equal(new long[] { 10 }, t.shape); + Assert.Equal(torch.float32, t.dtype); + Assert.True(torch.tensor(new float[] { 0.564213157f, -0.04519982f, -0.005117342f, 0.395530462f, -0.3780813f, -0.004734449f, -0.3221216f, -0.289159119f, 0.268511474f, 0.180702567f }).allclose(t)); + + Assert.Throws(() => m.forward(torch.ones(100))); + } + + [Fact] + public void TestLoadJIT_4() + { + // Definitely not a TorchScript file. Let's see what the runtime does with it. + Assert.Throws(() => torch.jit.load(@"bug510.dat")); + } + + [Fact] + public void TestSaveLoadJITCUDA() + { + if (torch.cuda.is_available()) { + + using var m = torch.jit.load(@"linrelu.script.dat"); + + m.to(DeviceType.CUDA); + var params0 = m.parameters().ToArray(); + foreach (var p in params0) + Assert.Equal(DeviceType.CUDA, p.device_type); + + var t = m.forward(torch.ones(10).cuda()).cpu(); + + Assert.Equal(new long[] { 6 }, t.shape); + Assert.Equal(torch.float32, t.dtype); + Assert.Equal(new float[] { 0.313458264f, 0, 0.9996568f, 0, 0, 0 }, t.data().ToArray()); + } + } + [Fact] public void TestSaveLoadConv2D() { if (File.Exists(".model.ts")) File.Delete(".model.ts"); - var conv = Conv2d(100, 10, 5); + using var conv = Conv2d(100, 10, 5); var params0 = conv.parameters(); conv.save(".model.ts"); - var loaded = Conv2d(100, 10, 5); + using var loaded = Conv2d(100, 10, 5); loaded.load(".model.ts"); var params1 = loaded.parameters(); File.Delete(".model.ts"); @@ -94,12 +206,12 @@ public void TestSaveLoadConv2D() [Fact] public void TestSaveLoadConv2D_sd() { - var conv = Conv2d(100, 10, 5); + using var conv = Conv2d(100, 10, 5); var params0 = conv.parameters(); var sd = conv.state_dict(); - var loaded = Conv2d(100, 10, 5); + using var loaded = Conv2d(100, 10, 5); Assert.NotEqual(params0, loaded.parameters()); loaded.load_state_dict(sd); @@ -110,10 +222,10 @@ public void TestSaveLoadConv2D_sd() public void TestSaveLoadSequential() { if (File.Exists(".model.ts")) File.Delete(".model.ts"); - var conv = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); + using var conv = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); var params0 = conv.parameters(); conv.save(".model.ts"); - var loaded = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); + using var loaded = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); loaded.load(".model.ts"); var params1 = loaded.parameters(); File.Delete(".model.ts"); @@ -124,12 +236,12 @@ public void TestSaveLoadSequential() [Fact] public void TestSaveLoadSequential_sd() { - var conv = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); + using var conv = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); var params0 = conv.parameters(); var sd = conv.state_dict(); - var loaded = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); + using var loaded = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); Assert.NotEqual(params0, loaded.parameters()); loaded.load_state_dict(sd); @@ -139,13 +251,13 @@ public void TestSaveLoadSequential_sd() [Fact] public void TestSaveLoadSequential_error1() { - var conv = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); + using var conv = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); var params0 = conv.parameters(); var sd = conv.state_dict(); sd.Remove("0.bias"); - var loaded = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); + using var loaded = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); Assert.NotEqual(params0, loaded.parameters()); Assert.Throws(() => loaded.load_state_dict(sd)); @@ -154,7 +266,7 @@ public void TestSaveLoadSequential_error1() [Fact] public void TestSaveLoadSequential_error2() { - var conv = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); + using var conv = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); var params0 = conv.parameters(); var sd = conv.state_dict(); @@ -162,7 +274,7 @@ public void TestSaveLoadSequential_error2() sd.Add("2.bias", t); - var loaded = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); + using var loaded = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); Assert.NotEqual(params0, loaded.parameters()); Assert.Throws(() => loaded.load_state_dict(sd)); @@ -171,7 +283,7 @@ public void TestSaveLoadSequential_error2() [Fact] public void TestSaveLoadSequential_lax() { - var conv = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); + using var conv = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); var params0 = conv.parameters(); var sd = conv.state_dict(); @@ -180,7 +292,7 @@ public void TestSaveLoadSequential_lax() sd.Add("2.bias", t); - var loaded = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); + using var loaded = Sequential(Conv2d(100, 10, 5), Linear(100, 10, true)); Assert.NotEqual(params0, loaded.parameters()); var (m,u) = loaded.load_state_dict(sd, false); @@ -195,14 +307,14 @@ public void TestSaveLoadCustomWithParameters() { if (File.Exists(".model.ts")) File.Delete(".model.ts"); - var original = new TestModule1(); + using var original = new TestModule1(); Assert.True(original.has_parameter("test")); var params0 = original.parameters(); Assert.True(params0.ToArray().ToArray()[0].requires_grad); original.save(".model.ts"); - var loaded = new TestModule1(); + using var loaded = new TestModule1(); Assert.True(loaded.has_parameter("test")); var params1 = loaded.parameters(); diff --git a/test/TorchSharpTest/TestTraining.cs b/test/TorchSharpTest/TestTraining.cs index a8561d409..fbdee3492 100644 --- a/test/TorchSharpTest/TestTraining.cs +++ b/test/TorchSharpTest/TestTraining.cs @@ -186,7 +186,7 @@ private static void ReInitializeLinear(Generator gen, Linear linear) } } - private static float TrainLoop(Sequential seq, Tensor x, Tensor y, optim.Optimizer optimizer) + private static float TrainLoop(Module seq, Tensor x, Tensor y, optim.Optimizer optimizer) { var loss = mse_loss(Reduction.Sum); @@ -213,7 +213,7 @@ private static float TrainLoop(Sequential seq, Tensor x, Tensor y, optim.Optimiz return finalLoss; } - private static float TrainLoop(Sequential seq, Tensor x, Tensor y, optim.Optimizer optimizer, optim.lr_scheduler.LRScheduler scheduler, bool check_lr = true) + private static float TrainLoop(Module seq, Tensor x, Tensor y, optim.Optimizer optimizer, optim.lr_scheduler.LRScheduler scheduler, bool check_lr = true) { var loss = mse_loss(Reduction.Sum); @@ -1665,6 +1665,24 @@ public void TestTrainingLBFGS_ME() //Assert.True(finalLoss < initialLoss); } + [Fact] + public void TestTrainingLoadedTorchScript() + { + var gen = new Generator(4711); + CreateDataAndLabels(gen, out var x, out var y); + + var seq = torch.jit.load(@"l1000_100_10.script.dat"); + + double learning_rate = 0.00004f; + var optimizer = torch.optim.SGD(seq.parameters(), learning_rate); + + var loss = TrainLoop(seq, x, y, optimizer); + + LossIsClose(53.81697f, loss); + + seq.eval(); + } + [Fact] public void TestTrainingConv2d() { diff --git a/test/TorchSharpTest/TorchSharpTest.csproj b/test/TorchSharpTest/TorchSharpTest.csproj index cc9971188..ab3fb075b 100644 --- a/test/TorchSharpTest/TorchSharpTest.csproj +++ b/test/TorchSharpTest/TorchSharpTest.csproj @@ -26,6 +26,21 @@ + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + PreserveNewest diff --git a/test/TorchSharpTest/func.script.dat b/test/TorchSharpTest/func.script.dat new file mode 100644 index 000000000..659837edd Binary files /dev/null and b/test/TorchSharpTest/func.script.dat differ diff --git a/test/TorchSharpTest/l1000_100_10.script.dat b/test/TorchSharpTest/l1000_100_10.script.dat new file mode 100644 index 000000000..399c66b0a Binary files /dev/null and b/test/TorchSharpTest/l1000_100_10.script.dat differ diff --git a/test/TorchSharpTest/linrelu.script.dat b/test/TorchSharpTest/linrelu.script.dat new file mode 100644 index 000000000..7af081e96 Binary files /dev/null and b/test/TorchSharpTest/linrelu.script.dat differ diff --git a/test/TorchSharpTest/scripted.script.dat b/test/TorchSharpTest/scripted.script.dat new file mode 100644 index 000000000..1a769e213 Binary files /dev/null and b/test/TorchSharpTest/scripted.script.dat differ