Skip to content
21 changes: 16 additions & 5 deletions uberduck_ml_dev/utils/denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import sys
import torch
from ..models.common import STFT
from ..vocoders.istftnet import iSTFTNetGenerator, TorchSTFT


class Denoiser(torch.nn.Module):
Expand All @@ -46,11 +47,21 @@ def __init__(
raise Exception("Mode {} if not supported".format(mode))

with torch.no_grad():
bias_audio = (
hifigan.vocoder.forward(mel_input.to(hifigan.device))
.view(1, -1)
.float()
)
if isinstance(hifigan, iSTFTNetGenerator):
self.stft = TorchSTFT(filter_length=16, hop_length=4, win_length=16, device="cpu").to("cpu")
spec, phase = hifigan.vocoder(mel_input.to(hifigan.device))
y_g_hat = self.stft.inverse(spec.cpu(), phase.cpu())
bias_audio = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try using rearrange from einops instead of view - its more verbose

y_g_hat
.view(1, -1)
.float()
)
else:
bias_audio = (
hifigan.vocoder.forward(mel_input.to(hifigan.device))
.view(1, -1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

.float()
)
bias_spec, _ = self.stft.transform(bias_audio.cpu())

self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])
Expand Down
Loading