diff --git a/src/TorchSharp/TorchAudio/Models.cs b/src/TorchSharp/TorchAudio/Models.cs
index 6fc0919e1..a0f6b3e31 100644
--- a/src/TorchSharp/TorchAudio/Models.cs
+++ b/src/TorchSharp/TorchAudio/Models.cs
@@ -1,4 +1,4 @@
-// Copyright
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
 using System;
 using static TorchSharp.torch;

@@ -85,6 +85,46 @@ public static Modules.Tacotron2 Tacotron2(
                     postnet_embedding_dim,
                     gate_threshold);
             }
+
+            /// <summary>
+            /// WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`.
+            /// </summary>
+            /// <param name="upsample_scales">The list of upsample scales.</param>
+            /// <param name="n_classes">The number of output classes.</param>
+            /// <param name="hop_length">The number of samples between the starts of consecutive frames.</param>
+            /// <param name="n_res_block">The number of ResBlocks in the stack.</param>
+            /// <param name="n_rnn">The dimension of the RNN layer.</param>
+            /// <param name="n_fc">The dimension of the fully connected layer.</param>
+            /// <param name="kernel_size">The kernel size of the first Conv1d layer.</param>
+            /// <param name="n_freq">The number of bins in a spectrogram.</param>
+            /// <param name="n_hidden">The number of hidden dimensions of ResBlock.</param>
+            /// <param name="n_output">The number of output dimensions of MelResNet.</param>
+            /// <returns>The WaveRNN model.</returns>
+            public static Modules.WaveRNN WaveRNN(
+                long[] upsample_scales,
+                int n_classes,
+                int hop_length,
+                int n_res_block = 10,
+                int n_rnn = 512,
+                int n_fc = 512,
+                int kernel_size = 5,
+                int n_freq = 128,
+                int n_hidden = 128,
+                int n_output = 128)
+            {
+                return new Modules.WaveRNN(
+                    "wavernn",
+                    upsample_scales,
+                    n_classes,
+                    hop_length,
+                    n_res_block,
+                    n_rnn,
+                    n_fc,
+                    kernel_size,
+                    n_freq,
+                    n_hidden,
+                    n_output);
+            }
         }
     }
 }
\ No newline at end of file
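For orientation, a minimal usage sketch of the factory above (an editor's illustration, not part of the diff; the hyperparameter values are borrowed from the unit test at the end of this change):

    // The upsample scales must multiply to hop_length: 5 * 5 * 11 == 275.
    var model = torchaudio.models.WaveRNN(
        upsample_scales: new long[] { 5, 5, 11 },
        n_classes: 1 << 8,   // 256 output classes (8-bit)
        hop_length: 275,
        n_freq: 80);
    // forward takes a (n_batch, 1, n_samples) waveform and a (n_batch, 1, n_freq, n_time)
    // spectrogram, and returns per-sample class logits of shape (n_batch, 1, n_samples, n_classes).
    var specgram = torch.randn(2, 1, 80, 6);
    var waveform = torch.randn(2, 1, (6 - 5 + 1) * 275);
    var logits = model.forward(waveform, specgram);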
diff --git a/src/TorchSharp/TorchAudio/Modules/WaveRNN.cs b/src/TorchSharp/TorchAudio/Modules/WaveRNN.cs
new file mode 100644
index 000000000..ecacebb5c
--- /dev/null
+++ b/src/TorchSharp/TorchAudio/Modules/WaveRNN.cs
@@ -0,0 +1,347 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+
+// A number of implementation details in this file have been translated from the Python version of torchaudio,
+// largely located in the files found in this folder:
+//
+// https://github.com/pytorch/audio/blob/c15eee23964098f88ab0afe25a8d5cd9d728af54/torchaudio/models/wavernn.py
+//
+// The origin has the following copyright notice and license:
+//
+// https://github.com/pytorch/audio/blob/main/LICENSE
+//
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Linq;
+using static TorchSharp.torch;
+
+using static TorchSharp.torch.nn;
+using F = TorchSharp.torch.nn.functional;
+
+#nullable enable
+namespace TorchSharp.Modules
+{
+    /// <summary>
+    /// This class is used to represent a WaveRNN module.
+    /// </summary>
+    public class WaveRNN : nn.Module
+    {
+        private readonly int _pad;
+        public readonly nn.Module fc;
+        public readonly nn.Module fc1;
+        public readonly nn.Module fc2;
+        public readonly nn.Module fc3;
+        public readonly int hop_length;
+        public readonly int kernel_size;
+        public readonly int n_aux;
+        public readonly int n_bits;
+        public readonly int n_classes;
+        public readonly int n_rnn;
+        public readonly nn.Module relu1;
+        public readonly nn.Module relu2;
+        public readonly GRU rnn1;
+        public readonly GRU rnn2;
+        internal readonly UpsampleNetwork upsample;
+
+        internal WaveRNN(
+            string name,
+            long[] upsample_scales,
+            int n_classes,
+            int hop_length,
+            int n_res_block = 10,
+            int n_rnn = 512,
+            int n_fc = 512,
+            int kernel_size = 5,
+            int n_freq = 128,
+            int n_hidden = 128,
+            int n_output = 128) : base(name)
+        {
+            this.kernel_size = kernel_size;
+            this._pad = (kernel_size % 2 == 1 ? kernel_size - 1 : kernel_size) / 2;
+            this.n_rnn = n_rnn;
+            this.n_aux = n_output / 4;
+            this.hop_length = hop_length;
+            this.n_classes = n_classes;
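+            // Editor's note: n_classes is assumed to be a power of two, so this
+            // recovers the output bit depth as n_bits = round(log2(n_classes)).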
+            this.n_bits = (int)(Math.Log(this.n_classes) / Math.Log(2) + 0.5);
+
+            long total_scale = 1;
+            foreach (var upsample_scale in upsample_scales) {
+                total_scale *= upsample_scale;
+            }
+            if (total_scale != this.hop_length) {
+                throw new ArgumentException($"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}");
+            }
+
+            this.upsample = new UpsampleNetwork("upsamplenetwork", upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size);
+            this.fc = nn.Linear(n_freq + this.n_aux + 1, n_rnn);
+
+            this.rnn1 = nn.GRU(n_rnn, n_rnn, batchFirst: true);
+            this.rnn2 = nn.GRU(n_rnn + this.n_aux, n_rnn, batchFirst: true);
+
+            this.relu1 = nn.ReLU(inPlace: true);
+            this.relu2 = nn.ReLU(inPlace: true);
+
+            this.fc1 = nn.Linear(n_rnn + this.n_aux, n_fc);
+            this.fc2 = nn.Linear(n_fc + this.n_aux, n_fc);
+            this.fc3 = nn.Linear(n_fc, this.n_classes);
+
+            this.RegisterComponents();
+        }
+
+        /// <summary>
+        /// Pass the input through the WaveRNN model.
+        /// </summary>
+        /// <param name="waveform">The input waveform to the WaveRNN layer</param>
+        /// <param name="specgram">The input spectrogram to the WaveRNN layer</param>
+        /// <returns>The output tensor, with shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes).</returns>
+        public override Tensor forward(Tensor waveform, Tensor specgram)
+        {
+            if (waveform.size(1) != 1) {
+                throw new ArgumentException("Expected the input channel of waveform to be 1");
+            }
+            if (specgram.size(1) != 1) {
+                throw new ArgumentException("Expected the input channel of specgram to be 1");
+            }
+            // remove channel dimension until the end
+            waveform = waveform.squeeze(1);
+            specgram = specgram.squeeze(1);
+
+            var batch_size = waveform.size(0);
+            var h1 = torch.zeros(1, batch_size, this.n_rnn, dtype: waveform.dtype, device: waveform.device);
+            var h2 = torch.zeros(1, batch_size, this.n_rnn, dtype: waveform.dtype, device: waveform.device);
+            // output of upsample:
+            // specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
+            // aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
+            Tensor aux;
+            (specgram, aux) = this.upsample.forward(specgram);
+            specgram = specgram.transpose(1, 2);
+            aux = aux.transpose(1, 2);
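+            // The conditioning features are split along the last axis into four equal
+            // slices of width n_aux; one slice is injected at each stage of the network
+            // (the fc input, the rnn2 input, the fc1 input, and the fc2 input, respectively).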
+            var aux_idx = new long[5];
+            for (int i = 0; i < aux_idx.Length; i++) {
+                aux_idx[i] = this.n_aux * i;
+            }
+            var a1 = aux[TensorIndex.Colon, TensorIndex.Colon, TensorIndex.Slice(aux_idx[0], aux_idx[1])];
+            var a2 = aux[TensorIndex.Colon, TensorIndex.Colon, TensorIndex.Slice(aux_idx[1], aux_idx[2])];
+            var a3 = aux[TensorIndex.Colon, TensorIndex.Colon, TensorIndex.Slice(aux_idx[2], aux_idx[3])];
+            var a4 = aux[TensorIndex.Colon, TensorIndex.Colon, TensorIndex.Slice(aux_idx[3], aux_idx[4])];
+
+            var x = torch.cat(new Tensor[] { waveform.unsqueeze(-1), specgram, a1 }, dimension: -1);
+            x = this.fc.forward(x);
+            var res = x;
+            (x, _) = this.rnn1.forward(x, h1);
+
+            x = x + res;
+            res = x;
+            x = torch.cat(new Tensor[] { x, a2 }, dimension: -1);
+            (x, _) = this.rnn2.forward(x, h2);
+
+            x = x + res;
+            x = torch.cat(new Tensor[] { x, a3 }, dimension: -1);
+            x = this.fc1.forward(x);
+            x = this.relu1.forward(x);
+
+            x = torch.cat(new Tensor[] { x, a4 }, dimension: -1);
+            x = this.fc2.forward(x);
+            x = this.relu2.forward(x);
+            x = this.fc3.forward(x);
+
+            // bring back channel dimension
+            return x.unsqueeze(1);
+        }
+
+        /// <summary>
+        /// Inference method of WaveRNN.
+        /// </summary>
+        /// <param name="specgram">Batch of spectrograms.</param>
+        /// <param name="lengths">Indicates the valid length of each audio in the batch.</param>
+        /// <returns>The inferred waveform and the valid length in time axis of the output Tensor.</returns>
+        public virtual (Tensor, Tensor?) infer(Tensor specgram, Tensor? lengths = null)
+        {
+            var device = specgram.device;
+            var dtype = specgram.dtype;
+
+            specgram = torch.nn.functional.pad(specgram, (this._pad, this._pad));
+            Tensor aux;
+            (specgram, aux) = this.upsample.forward(specgram);
+            if (lengths is not null) {
+                lengths = lengths * this.upsample.total_scale;
+            }
+
+            var output = new List<Tensor>();
+            long b_size = specgram.size(0);
+            long seq_len = specgram.size(2);
+
+            var h1 = torch.zeros(new long[] { 1, b_size, this.n_rnn }, device: device, dtype: dtype);
+            var h2 = torch.zeros(new long[] { 1, b_size, this.n_rnn }, device: device, dtype: dtype);
+            var x = torch.zeros(new long[] { b_size, 1 }, device: device, dtype: dtype);
+
+            var aux_split = new Tensor[4];
+            for (int i = 0; i < 4; i++) {
+                aux_split[i] = aux[TensorIndex.Colon, TensorIndex.Slice(this.n_aux * i, this.n_aux * (i + 1)), TensorIndex.Colon];
+            }
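+            // Autoregressive generation: emit one sample per upsampled time step,
+            // feeding the previous sample back in together with the conditioning
+            // features for the current step.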
+            for (int i = 0; i < seq_len; i++) {
+
+                var m_t = specgram[TensorIndex.Colon, TensorIndex.Colon, i];
+
+                var a1_t = aux_split[0][TensorIndex.Colon, TensorIndex.Colon, i];
+                var a2_t = aux_split[1][TensorIndex.Colon, TensorIndex.Colon, i];
+                var a3_t = aux_split[2][TensorIndex.Colon, TensorIndex.Colon, i];
+                var a4_t = aux_split[3][TensorIndex.Colon, TensorIndex.Colon, i];
+
+                x = torch.cat(new Tensor[] { x, m_t, a1_t }, dimension: 1);
+                x = this.fc.forward(x);
+                (_, h1) = this.rnn1.forward(x.unsqueeze(1), h1);
+
+                x = x + h1[0];
+                var inp = torch.cat(new Tensor[] { x, a2_t }, dimension: 1);
+                (_, h2) = this.rnn2.forward(inp.unsqueeze(1), h2);
+
+                x = x + h2[0];
+                x = torch.cat(new Tensor[] { x, a3_t }, dimension: 1);
+                x = F.relu(this.fc1.forward(x));
+
+                x = torch.cat(new Tensor[] { x, a4_t }, dimension: 1);
+                x = F.relu(this.fc2.forward(x));
+
+                var logits = this.fc3.forward(x);
+
+                var posterior = F.softmax(logits, dim: 1);
+
+                x = torch.multinomial(posterior, 1).@float();
+                // Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
+                x = 2 * x / ((1 << this.n_bits) - 1.0) - 1.0;
+
+                output.Add(x);
+            }
+            return (torch.stack(output).permute(1, 2, 0), lengths);
+        }
+
+        private class ResBlock : nn.Module
+        {
+            public nn.Module resblock_model;
+
+            public ResBlock(string name, int n_freq = 128) : base(name)
+            {
+                this.resblock_model = nn.Sequential(
+                    nn.Conv1d(inputChannel: n_freq, outputChannel: n_freq, kernelSize: 1, bias: false),
+                    nn.BatchNorm1d(n_freq),
+                    nn.ReLU(inPlace: true),
+                    nn.Conv1d(inputChannel: n_freq, outputChannel: n_freq, kernelSize: 1, bias: false),
+                    nn.BatchNorm1d(n_freq));
+                RegisterComponents();
+            }
+
+            public override Tensor forward(Tensor specgram)
+            {
+                return this.resblock_model.forward(specgram) + specgram;
+            }
+        }
+
+        internal class MelResNet : nn.Module
+        {
+            public readonly nn.Module melresnet_model;
+
+            public MelResNet(
+                string name,
+                int n_res_block = 10,
+                int n_freq = 128,
+                int n_hidden = 128,
+                int n_output = 128,
+                int kernel_size = 5) : base(name)
+            {
+                var modules = new List<nn.Module>();
+                modules.Add(nn.Conv1d(inputChannel: n_freq, outputChannel: n_hidden, kernelSize: kernel_size, bias: false));
+                modules.Add(nn.BatchNorm1d(n_hidden));
+                modules.Add(nn.ReLU(inPlace: true));
+                for (int i = 0; i < n_res_block; i++) {
+                    modules.Add(new ResBlock("resblock", n_hidden));
+                }
+                modules.Add(nn.Conv1d(inputChannel: n_hidden, outputChannel: n_output, kernelSize: 1));
+                this.melresnet_model = nn.Sequential(modules);
+                RegisterComponents();
+            }
+
+            public override Tensor forward(Tensor specgram)
+            {
+                return this.melresnet_model.forward(specgram);
+            }
+        }
+
+        public class Stretch2d : nn.Module
+        {
+            public long freq_scale;
+            public long time_scale;
+
+            public Stretch2d(string name, long time_scale, long freq_scale) : base(name)
+            {
+                this.freq_scale = freq_scale;
+                this.time_scale = time_scale;
+                this.RegisterComponents();
+            }
+
+            public override Tensor forward(Tensor specgram)
+            {
+                return specgram.repeat_interleave(this.freq_scale, -2).repeat_interleave(this.time_scale, -1);
+            }
+        }
+
+        internal class UpsampleNetwork : nn.Module
+        {
+            public readonly long indent;
+            public readonly MelResNet resnet;
+            public readonly Stretch2d resnet_stretch;
+            public readonly long total_scale;
+            public readonly nn.Module upsample_layers;
+
+            public UpsampleNetwork(
+                string name,
+                long[] upsample_scales,
+                int n_res_block = 10,
+                int n_freq = 128,
+                int n_hidden = 128,
+                int n_output = 128,
+                int kernel_size = 5) : base(name)
+            {
+                long total_scale = 1;
+                foreach (var upsample_scale in upsample_scales) {
+                    total_scale *= upsample_scale;
+                }
+                this.total_scale = total_scale;
+
+                this.indent = (kernel_size - 1) / 2 * total_scale;
+                this.resnet = new MelResNet("melresnet", n_res_block, n_freq, n_hidden, n_output, kernel_size);
+                this.resnet_stretch = new Stretch2d("stretch2d", total_scale, 1);
+
+                var up_layers = new List<nn.Module>();
+                foreach (var scale in upsample_scales) {
+                    var stretch = new Stretch2d("stretch2d", scale, 1);
+                    var conv = nn.Conv2d(inputChannel: 1, outputChannel: 1, kernelSize: (1, scale * 2 + 1), padding: (0, scale), bias: false);
+                    torch.nn.init.constant_(conv.weight, 1.0 / (scale * 2 + 1));
+                    up_layers.Add(stretch);
+                    up_layers.Add(conv);
+                }
+                this.upsample_layers = nn.Sequential(up_layers);
+                this.RegisterComponents();
+            }
+
+            public new (Tensor, Tensor) forward(Tensor specgram)
+            {
+                var resnet_output = this.resnet.forward(specgram).unsqueeze(1);
+                resnet_output = this.resnet_stretch.forward(resnet_output);
+                resnet_output = resnet_output.squeeze(1);
+
+                specgram = specgram.unsqueeze(1);
+                var upsampling_output = this.upsample_layers.forward(specgram);
+                upsampling_output = upsampling_output.squeeze(1)[TensorIndex.Colon, TensorIndex.Colon, TensorIndex.Slice(this.indent, -this.indent)];
+
+                return (upsampling_output, resnet_output);
+            }
+        }
+    }
+}
diff --git a/test/TorchSharpTest/TestTorchAudioModels.cs b/test/TorchSharpTest/TestTorchAudioModels.cs
index 0f845a901..014be0e4d 100644
--- a/test/TorchSharpTest/TestTorchAudioModels.cs
+++ b/test/TorchSharpTest/TestTorchAudioModels.cs
@@ -35,6 +35,21 @@ private Modules.Tacotron2 CreateTacotron2(int n_symbols)
             );
         }

+        private Modules.WaveRNN CreateWaveRNN()
+        {
+            return torchaudio.models.WaveRNN(
+                upsample_scales: new long[] { 5, 5, 11 },
+                n_classes: 1 << 8, // n_bits = 8
+                hop_length: 275,
+                n_res_block: 10,
+                n_rnn: 512,
+                n_fc: 512,
+                kernel_size: 5,
+                n_freq: 80,
+                n_hidden: 128,
+                n_output: 128);
+        }
+
         [Fact]
         public void Tacotron2ModelForward()
         {
@@ -87,5 +102,20 @@ public void Tacotron2ModelInfer()
                 }
             }
         }
+
+        [Fact]
+        public void WaveRNNModelForward()
+        {
+            using (var scope = torch.NewDisposeScope()) {
+                var wavernn = CreateWaveRNN();
+                long batch_size = 2;
+                var specgram = torch.randn(new long[] { batch_size, 1, 80, 6 });
+                var waveform_len = (specgram.shape[3] - 5 + 1) * (5 * 5 * 11);
+                var waveform = torch.randn(new long[] { batch_size, 1, waveform_len });
+                // specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
+                var output = wavernn.forward(waveform, specgram);
+                Assert.Equal(new long[] { batch_size, 1, waveform.shape[2], 1 << 8 }, output.shape);
+            }
+        }
     }
 }
\ No newline at end of file
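Finally, a minimal sketch of generation with the new infer method (again an editor's illustration, not part of the diff; note that infer, unlike forward, takes a 3-D (n_batch, n_freq, n_time) spectrogram with no channel dimension):

    var wavernn = torchaudio.models.WaveRNN(
        upsample_scales: new long[] { 5, 5, 11 },
        n_classes: 1 << 8,
        hop_length: 275,
        n_freq: 80);
    wavernn.eval();   // switch to eval mode so BatchNorm uses running statistics
    using (torch.no_grad()) {
        var specgram = torch.randn(1, 80, 10);
        var (waveform, lengths) = wavernn.infer(specgram);
        // waveform: (1, 1, n_time * hop_length) samples in [-1, 1]; lengths is null here.
    }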