One of the most amazing advantages of using MelGAN is that it runs in real time on a CPU!
Try it!
%matplotlib inline
import matplotlib.pyplot as plt
import IPython.display as ipd
import tqdm
import torch
import models
from torch.utils.data import DataLoader
from data import LJspeechDataset, collate_fn, collate_fn_synthesize
import commons
import librosa
import numpy as np
import os
import json
# DataLoader options: keep workers at 0 (notebook-friendly) and pin memory
# so host->device copies stay cheap if a GPU is used later.
kwargs = dict(num_workers=0, pin_memory=True)

# 90/10 train/test split of LJSpeech; batch_size=1 since we synthesize
# one utterance at a time.
data_root = './DATASETS/ljspeech/'
train_dataset = LJspeechDataset(data_root, True, 0.1)
test_dataset = LJspeechDataset(data_root, False, 0.1)

train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
    **kwargs,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    collate_fn=collate_fn_synthesize,
    **kwargs,
)
# Build the generator from the training config and restore its weights on CPU.
model_dir = "./logs/test/"
# Use a context manager so the config file handle is closed promptly
# (the original json.load(open(...)) leaked the handle).
with open(os.path.join(model_dir, "config.json"), "r") as f:
    configs = json.load(f)
model = models.Generator(configs["data"]["n_channels"])#.to("cuda")
checkpoint_path = os.path.join(model_dir, 'G_205.pth')
# map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only machine.
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
state_dict = checkpoint_dict['model']

# Copy checkpoint weights into a fresh state dict keyed by the model's own
# parameter names; keep the randomly-initialized value for any key the
# checkpoint is missing (e.g. after an architecture change).
new_state_dict = {}
for k, v in model.state_dict().items():
    try:
        new_state_dict[k] = state_dict[k]
    except KeyError:  # narrow: only a missing key is expected here
        print("%s is not in the checkpoint" % k)
        new_state_dict[k] = v
model.load_state_dict(new_state_dict)
# Fuse weight norm into plain weights for faster CPU inference.
model.remove_weight_norm()
# Pick one example from the test set: iterate until batch index `idx_stop`
# and keep that (waveform, mel-spectrogram) pair.
idx_stop = 0
for i, batch in enumerate(test_loader):
    x, c, _ = batch
    #x, c = x.to("cuda"), c.to("cuda")
    if i == idx_stop:
        break

# Listen to the ground-truth waveform (LJSpeech is sampled at 22050 Hz).
ipd.Audio(x.cpu().numpy().reshape(-1), rate=22050)

# Vocode the mel-spectrogram back to audio; no gradients needed at inference.
with torch.no_grad():
    x_hat = model(c)

# Listen to the generated waveform for comparison.
ipd.Audio(x_hat.cpu().detach().numpy().reshape(-1), rate=22050)