Conversation

kaiidams
Contributor

Since we don't have torchaudio.pipelines to download pretrained models, we need to do it manually.
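A minimal sketch of that manual step might look like the following, assuming you already have a URL for the converted weight files (they are not published at an official location, as noted further down in this thread) and that TORCH_HOME is the same local directory used in the test code below. `DownloadWeightsAsync` and the URL are illustrative, not part of this PR.

        // Hedged sketch: fetch a converted weight file into TORCH_HOME unless it is already there.
        // Requires System.IO, System.Net.Http and System.Threading.Tasks.
        // The base URL is a placeholder -- substitute wherever you actually host the converted files.
        private static async Task DownloadWeightsAsync(string fileName)
        {
            var targetPath = Path.Combine(TORCH_HOME, fileName);
            if (File.Exists(targetPath)) return;

            Directory.CreateDirectory(TORCH_HOME);
            var url = $"https://example.com/torchsharp-models/{fileName}";  // placeholder URL
            using var client = new HttpClient();
            var bytes = await client.GetByteArrayAsync(url);
            File.WriteAllBytes(targetPath, bytes);
        }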

        private static void Test()
        {
            var tacotron2 = torchaudio.models.Tacotron2(
                mask_padding: false,
                n_mels: 80,
                n_frames_per_step: 1,
                symbol_embedding_dim: 512,
                encoder_embedding_dim: 512,
                encoder_n_convolution: 3,
                encoder_kernel_size: 5,
                decoder_rnn_dim: 1024,
                decoder_max_step: 2000,
                decoder_dropout: 0.1,
                decoder_early_stopping: true,
                attention_rnn_dim: 1024,
                attention_hidden_dim: 128,
                attention_location_n_filter: 32,
                attention_location_kernel_size: 31,
                attention_dropout: 0.1,
                prenet_dim: 256,
                postnet_n_convolution: 5,
                postnet_kernel_size: 5,
                postnet_embedding_dim: 512,
                gate_threshold: 0.5,
                n_symbol: 96);
            var path = Path.Combine(TORCH_HOME, "tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.bin");
            tacotron2.load(path);
            tacotron2.eval();

            var wavernn = torchaudio.models.WaveRNN(
                upsample_scales: new long[] { 5, 5, 11 },
                n_classes: 1 << 8,  // n_bits = 8
                hop_length: 275,
                n_res_block: 10,
                n_rnn: 512,
                n_fc: 512,
                kernel_size: 5,
                n_freq: 80,
                n_hidden: 128,
                n_output: 128);
            var vocoder = new WaveRNNVocoder("vocoder", wavernn);
            path = Path.Combine(TORCH_HOME, "wavernn_10k_epochs_8bits_ljspeech.bin");
            vocoder.load(path);
            vocoder.eval();

            // HH AH L OW   W ER L D !   T EH K S T   T AH   S P IY CH !
            var data = new int[] {
                54, 20, 65, 69, 11, 92, 44, 65, 38, 2, 11, 81, 40, 64, 79, 81, 11, 81,
                20, 11, 79, 77, 59, 37, 2 };
            var tokens = torch.tensor(data, dimensions: new long[] { 1, data.Length });

            var (mel_spec, mel_spec_len, _) = tacotron2.infer(tokens);

            Tensor waveform;
            (waveform, _) = vocoder.forward(mel_spec, mel_spec_len);

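            // The vocoder output is raw 32-bit float PCM at 22,050 Hz; dump the samples as bytes.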
            var waveform_bytes = MemoryMarshal.Cast<float, byte>(
                waveform.data<float>().ToArray()).ToArray();
            File.WriteAllBytes("waveform.bin", waveform_bytes);
        }

        public class WaveRNNVocoder : nn.Module
        {
            private readonly int _sample_rate;
            private readonly Modules.WaveRNN _model;
            private readonly double? _min_level_db;

            public WaveRNNVocoder(string name, Modules.WaveRNN model, double? min_level_db = -100) : base(name)
            {
                this._sample_rate = 22050;
                this._model = model;
                this._min_level_db = min_level_db;
                this.RegisterComponents();
            }

            public int sample_rate => this._sample_rate;

            public override (Tensor, Tensor?) forward(Tensor mel_spec, Tensor? lengths = null)
            {
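                // Tacotron2 emits a natural-log mel spectrogram; convert it to dB and then
                // normalize to [0, 1] against the minimum level (in dB) before feeding WaveRNN.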
                mel_spec = torch.exp(mel_spec);
                mel_spec = 20 * torch.log10(torch.clamp(mel_spec, min: 1e-5));
                if (this._min_level_db is double min_level_db)
                {
                    mel_spec = (min_level_db - mel_spec) / min_level_db;
                    mel_spec = torch.clamp(mel_spec, min: 0, max: 1);
                }
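                // WaveRNN.infer returns mu-law-companded samples in [-1, 1]; map them back to
                // integer classes, then mu-law-decode to obtain a linear-amplitude waveform.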
                Tensor waveform;
                (waveform, lengths) = this._model.infer(mel_spec, lengths);
                waveform = _unnormalize_waveform(waveform, this._model.n_bits);
                waveform = torchaudio.functional.mu_law_decoding(waveform, this._model.n_classes);
                waveform = waveform.squeeze(1);
                return (waveform, lengths);
            }

            static Tensor _unnormalize_waveform(Tensor waveform, int bits)
            {
                waveform = torch.clamp(waveform, -1, 1);
                waveform = (waveform + 1.0) * ((1 << bits) - 1) / 2;
                return torch.clamp(waveform, 0, (1 << bits) - 1).@int();
            }
        }
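As a side note, waveform.bin above contains raw 32-bit float samples. If you want something directly playable, a small sketch like the one below (not part of the PR) wraps the same samples in a minimal RIFF/WAVE header for IEEE-float PCM; call it as `WriteWav("waveform.wav", waveform.data<float>().ToArray())` in place of the File.WriteAllBytes line.

        // Hedged sketch (not part of the PR): write float samples as an IEEE-float PCM WAV file
        // (mono, 22,050 Hz by default) so that ordinary audio players can open the result.
        private static void WriteWav(string path, float[] samples, int sampleRate = 22050)
        {
            const short channels = 1;
            const short bitsPerSample = 32;
            int byteRate = sampleRate * channels * bitsPerSample / 8;
            short blockAlign = (short)(channels * bitsPerSample / 8);
            int dataSize = samples.Length * sizeof(float);

            using var writer = new BinaryWriter(File.Create(path));
            writer.Write(System.Text.Encoding.ASCII.GetBytes("RIFF"));
            writer.Write(36 + dataSize);
            writer.Write(System.Text.Encoding.ASCII.GetBytes("WAVE"));
            writer.Write(System.Text.Encoding.ASCII.GetBytes("fmt "));
            writer.Write(16);                    // fmt chunk size
            writer.Write((short)3);              // WAVE_FORMAT_IEEE_FLOAT
            writer.Write(channels);
            writer.Write(sampleRate);
            writer.Write(byteRate);
            writer.Write(blockAlign);
            writer.Write(bitsPerSample);
            writer.Write(System.Text.Encoding.ASCII.GetBytes("data"));
            writer.Write(dataSize);
            foreach (var s in samples) writer.Write(s);
        }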

@GeorgeS2019

@kaiidams

Where can you download tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.bin?

@kaiidams
Contributor Author

@NiklasGustafsson NiklasGustafsson merged commit 8fb956f into dotnet:main Aug 17, 2022
@GeorgeS2019

@NiklasGustafsson the shared code above is only meant for this thread and hasn't been committed anywhere. Do you have a plan for how to commit it? Should the code be put in the TorchSharp.Example repo?

@kaiidams Thank you for making TorchAudio increasingly complete.

@NiklasGustafsson
Contributor

@NiklasGustafsson the shared code above is only meant for this thread and hasn't been committed anywhere. Do you have a plan for how to commit it? Should the code be put in the TorchSharp.Example repo?

I don't know. @kaiidams has single-handedly made torchaudio real for us, so I will let him take the lead on anything related to audio examples and such.

@GeorgeS2019

single-handedly

Exactly!

@kaiidams
Contributor Author

@NiklasGustafsson Can we have a place to put converted pretrained models for TorchSharpVision and TorchSharpAudio? Python torchvision also downloads pretrained models from the same server as torchaudio (for example, MobileNet v2 from https://download.pytorch.org/models/mobilenet_v2-b0353104.pth). I assume their models are BSD-licensed like their code, but I'm not 100% sure.

@kaiidams
Contributor Author

Related issue about pretrained models #588

@NiklasGustafsson
Contributor

This is an issue we have to address for ML.NET as well as TorchSharp. Adding @luisquintanilla and @ericstj.

@GeorgeS2019

GeorgeS2019 commented Aug 18, 2022

These are notes for me to track the discussion (I just realized that *.pth is not supported, but *.pt is, by TorchSharp).

A list of Torchvision pre-trained models (*.pth) saved using pickling

I made a mistake, the following is wrong; only *.pt is supported:

I think TorchSharp can load these pre-trained models while keeping the extension .pth

Reference1: MODELS AND PRE-TRAINED WEIGHTS

Reference2: https://github.com/saharhekmatdoust/Pre-trained-models-with-PyTorch/blob/main/resnet18_PyTorch.ipynb

Reference3: load_state_dict_from_url

@ericstj
Member

ericstj commented Aug 18, 2022

We have an internal repo with git LFS for storing large model files that are used by ML.NET, and we build packages out of it. ML.NET consumes those packages during the build for files it needs to redistribute in the final packages. We also use production Azure Storage for files which are not included in the packages and are downloaded by ML.NET at runtime. Not sure which path makes the most sense here - let me know.

Python torchvision also downloads pretrained models from the same server as torchaudio

My understanding is that TorchSharp cannot reuse these models as-is. They need to be ported/converted - is that correct @NiklasGustafsson?

@GeorgeS2019

GeorgeS2019 commented Aug 18, 2022

There are three file formats to consider when discussing TorchSharp (see the small sketch below):

*.pth, which TorchSharp does not support. This is the format for PyTorch modules saved via pickling.
*.pt, which TorchSharp can load and save, but not create from scratch. This is the format for PyTorch modules saved via TorchScript.
*.ts/.data/.whatever, which is for saving model weights only, in the TorchSharp-specific format.
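To make the distinction concrete, here is a small hedged sketch of the two supported load paths; the file names and the little Sequential model are illustrative, not from this PR, and the usual `using static TorchSharp.torch` style imports are assumed.

    // 1) TorchScript (*.pt): a module exported from Python with torch.jit.script/trace + torch.jit.save.
    var scripted = torch.jit.load("model.pt");
    scripted.eval();

    // 2) TorchSharp's weights-only format: build the module in C# first, then load the weights
    //    (this is what tacotron2.load(...) and vocoder.load(...) do in the example above).
    var mlp = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1));
    mlp.load("mlp_weights.bin");   // previously saved with mlp.save("mlp_weights.bin")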

@NiklasGustafsson
Contributor

Yes, it all depends on what format models were saved in. We do not support pickled models, but we do support TorchScript and the custom weights-only format we invented for PT/TorchSharp data exchange.
