4 changes: 4 additions & 0 deletions src/TorchSharp/NN/Utils/RNNUtils.cs
@@ -38,6 +38,7 @@ public static partial class rnn
    public static PackedSequence pack_padded_sequence(torch.Tensor input, torch.Tensor lengths, bool batch_first = false, bool enforce_sorted = true)
    {
        var res = THSNN_pack_padded_sequence(input.Handle, lengths.Handle, batch_first, enforce_sorted);
        if (res.IsInvalid) { torch.CheckForErrors(); }
        return new PackedSequence(res);
    }

@@ -54,6 +55,7 @@ public static (torch.Tensor, torch.Tensor) pad_packed_sequence(PackedSequence se
        IntPtr res1, res2;
        long total_length_arg = total_length.HasValue ? total_length.Value : -1;
        THSNN_pad_packed_sequence(sequence.Handle, batch_first, padding_value, total_length_arg, out res1, out res2);
        if (res1 == IntPtr.Zero || res2 == IntPtr.Zero) { torch.CheckForErrors(); }
        return (new torch.Tensor(res1), new torch.Tensor(res2));
    }
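
Continuing the sketch above, the inverse call unpacks back to a padded tensor plus the per-sequence lengths, returned as a tuple:

    // Deconstruct the (tensor, lengths) pair returned by the binding.
    var (padded, outLengths) = nn.utils.rnn.pad_packed_sequence(packed, batch_first: true);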

@@ -68,6 +70,7 @@ public static torch.Tensor pad_sequence(IEnumerable<torch.Tensor> sequences, boo
    {
        var sequences_arg = sequences.Select(p => p.Handle).ToArray();
        var res = THSNN_pad_sequence(sequences_arg, sequences_arg.Length, batch_first, padding_value);
        if (res == IntPtr.Zero) { torch.CheckForErrors(); }
        return new torch.Tensor(res);
    }
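
In the same vein, a sketch of pad_sequence, which right-pads a list of variable-length tensors to the longest one (tensor values illustrative):

    // Two 1-D sequences of lengths 5 and 3 become a (2, 5) tensor,
    // with the shorter row right-padded with padding_value.
    var padded2 = nn.utils.rnn.pad_sequence(new[] { ones(5), ones(3) }, batch_first: true, padding_value: 0.0);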

@@ -81,6 +84,7 @@ public static PackedSequence pack_sequence(IEnumerable<torch.Tensor> sequences,
    {
        var sequences_arg = sequences.Select(p => p.Handle).ToArray();
        var res = THSNN_pack_sequence(sequences_arg, sequences_arg.Length, enforce_sorted);
        if (res.IsInvalid) { torch.CheckForErrors(); }
        return new PackedSequence(res);
    }
}
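
And pack_sequence, which packs variable-length tensors directly without an explicit lengths argument; the inputs must be sorted by decreasing length when enforce_sorted is true:

    // Equivalent to padding and then packing; lengths are inferred.
    var packed2 = nn.utils.rnn.pack_sequence(new[] { ones(5), ones(3) }, enforce_sorted: true);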
90 changes: 90 additions & 0 deletions src/TorchSharp/TorchAudio/Models.cs
@@ -0,0 +1,90 @@
// Copyright (c) .NET Foundation and Contributors.  All Rights Reserved.  See License.txt in the project root for license information.

using System;
using static TorchSharp.torch;

namespace TorchSharp
{
    public static partial class torchaudio
    {
        public static partial class models
        {
            /// <summary>
            /// Tacotron2 model based on the implementation from
            /// Nvidia https://github.com/NVIDIA/DeepLearningExamples/.
            /// </summary>
            /// <param name="mask_padding">Use mask padding</param>
            /// <param name="n_mels">Number of mel bins</param>
            /// <param name="n_symbol">Number of symbols for the input text</param>
            /// <param name="n_frames_per_step">Number of frames processed per step; only 1 is supported</param>
            /// <param name="symbol_embedding_dim">Input embedding dimension</param>
            /// <param name="encoder_n_convolution">Number of encoder convolutions</param>
            /// <param name="encoder_kernel_size">Encoder kernel size</param>
            /// <param name="encoder_embedding_dim">Encoder embedding dimension</param>
            /// <param name="decoder_rnn_dim">Number of units in decoder LSTM</param>
            /// <param name="decoder_max_step">Maximum number of output mel spectrograms</param>
            /// <param name="decoder_dropout">Dropout probability for decoder LSTM</param>
            /// <param name="decoder_early_stopping">Stop decoding as soon as all samples in the batch have finished</param>
            /// <param name="attention_rnn_dim">Number of units in attention LSTM</param>
            /// <param name="attention_hidden_dim">Dimension of attention hidden representation</param>
            /// <param name="attention_location_n_filter">Number of filters for attention model</param>
            /// <param name="attention_location_kernel_size">Kernel size for attention model</param>
            /// <param name="attention_dropout">Dropout probability for attention LSTM</param>
            /// <param name="prenet_dim">Number of ReLU units in prenet layers</param>
            /// <param name="postnet_n_convolution">Number of postnet convolutions</param>
            /// <param name="postnet_kernel_size">Postnet kernel size</param>
            /// <param name="postnet_embedding_dim">Postnet embedding dimension</param>
            /// <param name="gate_threshold">Probability threshold for stop token</param>
            /// <returns>Tacotron2 model</returns>
            public static Modules.Tacotron2 Tacotron2(
                bool mask_padding = false,
                int n_mels = 80,
                int n_symbol = 148,
                int n_frames_per_step = 1,
                int symbol_embedding_dim = 512,
                int encoder_embedding_dim = 512,
                int encoder_n_convolution = 3,
                int encoder_kernel_size = 5,
                int decoder_rnn_dim = 1024,
                int decoder_max_step = 2000,
                double decoder_dropout = 0.1,
                bool decoder_early_stopping = true,
                int attention_rnn_dim = 1024,
                int attention_hidden_dim = 128,
                int attention_location_n_filter = 32,
                int attention_location_kernel_size = 31,
                double attention_dropout = 0.1,
                int prenet_dim = 256,
                int postnet_n_convolution = 5,
                int postnet_kernel_size = 5,
                int postnet_embedding_dim = 512,
                double gate_threshold = 0.5)
            {
                return new Modules.Tacotron2(
                    "tacotron2",
                    mask_padding,
                    n_mels,
                    n_symbol,
                    n_frames_per_step,
                    symbol_embedding_dim,
                    encoder_embedding_dim,
                    encoder_n_convolution,
                    encoder_kernel_size,
                    decoder_rnn_dim,
                    decoder_max_step,
                    decoder_dropout,
                    decoder_early_stopping,
                    attention_rnn_dim,
                    attention_hidden_dim,
                    attention_location_n_filter,
                    attention_location_kernel_size,
                    attention_dropout,
                    prenet_dim,
                    postnet_n_convolution,
                    postnet_kernel_size,
                    postnet_embedding_dim,
                    gate_threshold);
            }
        }
    }
}
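
For reference, a minimal construction sketch against the factory above; the using-static import and variable name are illustrative, not part of this PR:

    using static TorchSharp.torchaudio;

    // Build a Tacotron2 with the stock hyper-parameters: 80 mel bins and
    // 148 input text symbols, matching the defaults declared above.
    var tacotron = models.Tacotron2(n_mels: 80, n_symbol: 148);

Anything beyond constructing the module (such as running inference) depends on Modules.Tacotron2, which sits outside this diff.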