In [1]:
%pylab inline
rcParams["figure.figsize"] = (16,5)

import sys
sys.path.insert(0, "..")
Populating the interactive namespace from numpy and matplotlib
In [2]:
import torch

from scipy.io import wavfile
import pyworld
import pysptk

import IPython
from IPython.display import Audio

import librosa
import librosa.display

from nnmnkwii import preprocessing as P
from glob import glob
from os.path import join

import gantts
from hparams import tts_acoustic as hp_acoustic
from hparams import tts_duration as hp_duration
In [3]:
# Depends on experimental conditions
data_dir = "../data/cmu_arctic_tts_order24"
checkpoints_dir = "../checkpoints/tts_order24/"
duration_epoch = 100
acoustic_epoch = 50
In [4]:
X_min = {}
X_max = {}
Y_mean = {}
Y_var = {}
Y_std = {}

for typ in ["acoustic", "duration"]:
    X_min[typ] = np.load(join(data_dir, "X_{}_data_min.npy".format(typ)))
    X_max[typ] = np.load(join(data_dir, "X_{}_data_max.npy".format(typ)))
    Y_mean[typ] = np.load(join(data_dir, "Y_{}_data_mean.npy".format(typ)))
    Y_var[typ] = np.load(join(data_dir, "Y_{}_data_var.npy".format(typ)))
    Y_std[typ] = np.sqrt(Y_var[typ])
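
These statistics follow nnmnkwii's preprocessing conventions: linguistic input features are min-max scaled, and duration/acoustic targets are mean-variance normalized. A minimal sketch of how they would be applied at synthesis time, assuming feature_range=(0.01, 0.99) as in the gantts recipes (verify against the repository's hparams):

# Hedged sketch: apply/invert normalization with nnmnkwii.preprocessing
# (assumptions: inputs use min-max scaling with feature_range=(0.01, 0.99),
#  outputs use mean/std scaling -- check hparams for the actual ranges)
x = np.random.rand(10, X_min["acoustic"].shape[-1])   # dummy linguistic features
x_scaled = P.minmax_scale(x, X_min["acoustic"], X_max["acoustic"],
                          feature_range=(0.01, 0.99))
y_scaled = np.random.randn(10, Y_mean["acoustic"].shape[-1])  # dummy model output
y = P.inv_scale(y_scaled, Y_mean["acoustic"], Y_std["acoustic"])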

Models

In [5]:
models = {"baseline": {}, "gan": {}}

for method_typ in ["baseline", "gan"]:
    models[method_typ] = {"duration": {}, "acoustic": {}}
    for (typ, epoch) in zip(["duration", "acoustic"], [duration_epoch, acoustic_epoch]):
        # Set missing hyper params from data
        hp = hp_acoustic if typ == "acoustic" else hp_duration
        if hp.generator_params["in_dim"] is None:
            hp.generator_params["in_dim"] = X_min[typ].shape[-1]
        if hp.generator_params["out_dim"] is None:
            hp.generator_params["out_dim"] = Y_mean[typ].shape[-1]
    
        models[method_typ][typ] = getattr(gantts.models, hp.generator)(**hp.generator_params)
        print("Model for {}, {}\n".format(method_typ, typ), models[method_typ][typ])
        
        checkpoint_path = join(checkpoints_dir, "tts_{}/{}/checkpoint_epoch{}_Generator.pth".format(
            typ, method_typ, epoch))
        print("Load checkpoint from: {}".format(checkpoint_path))
        
        checkpoint = torch.load(checkpoint_path)
        models[method_typ][typ].load_state_dict(checkpoint["state_dict"])
        models[method_typ][typ].eval()
Model for baseline, duration
 LSTMRNN (
  (lstm): LSTM(416, 512, num_layers=3, batch_first=True, dropout=0.5, bidirectional=True)
  (hidden2out): Linear (1024 -> 5)
  (sigmoid): Sigmoid ()
)
Load checkpoint from: ../checkpoints/tts_order24/tts_duration/baseline/checkpoint_epoch100_Generator.pth
Model for baseline, acoustic
 MLP (
  (layers): ModuleList (
    (0): Linear (425 -> 512)
    (1): Linear (512 -> 512)
    (2): Linear (512 -> 512)
  )
  (last_linear): Linear (512 -> 82)
  (relu): LeakyReLU (0.01, inplace)
  (sigmoid): Sigmoid ()
  (dropout): Dropout (p = 0.5)
)
Load checkpoint from: ../checkpoints/tts_order24/tts_acoustic/baseline/checkpoint_epoch50_Generator.pth
Model for gan, duration
 LSTMRNN (
  (lstm): LSTM(416, 512, num_layers=3, batch_first=True, dropout=0.5, bidirectional=True)
  (hidden2out): Linear (1024 -> 5)
  (sigmoid): Sigmoid ()
)
Load checkpoint from: ../checkpoints/tts_order24/tts_duration/gan/checkpoint_epoch100_Generator.pth
Model for gan, acoustic
 MLP (
  (layers): ModuleList (
    (0): Linear (425 -> 512)
    (1): Linear (512 -> 512)
    (2): Linear (512 -> 512)
  )
  (last_linear): Linear (512 -> 82)
  (relu): LeakyReLU (0.01, inplace)
  (sigmoid): Sigmoid ()
  (dropout): Dropout (p = 0.5)
)
Load checkpoint from: ../checkpoints/tts_order24/tts_acoustic/gan/checkpoint_epoch50_Generator.pth
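
Note that torch.load above restores tensors to the device they were saved from. For CPU-only inference a map_location can be supplied; this is standard PyTorch, not specific to this repository:

# Hedged sketch: force checkpoint tensors onto CPU when loading
checkpoint = torch.load(checkpoint_path,
                        map_location=lambda storage, loc: storage)
models["gan"]["acoustic"].load_state_dict(checkpoint["state_dict"])
models["gan"]["acoustic"].eval()  # disable dropout for deterministic inference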

Compare generated audio samples

Baseline vs GAN

In [6]:
from evaluation_tts import tts_from_label, get_lab_files, get_wav_files
In [7]:
label_dir = "../nnmnkwii_gallery/data/slt_arctic_full_data/label_state_align/"
wav_dir = "../nnmnkwii_gallery/data/slt_arctic_full_data/wav/"

test_label_paths = get_lab_files(data_dir, label_dir, test=True)
test_wav_paths = get_wav_files(data_dir, wav_dir, test=True)
print(len(test_label_paths))
5
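
The two lists are expected to be paired one-to-one. A quick sanity check, assuming the files correspond by basename:

# Hedged sanity check (assumption: label and wav files are paired by basename)
from os.path import basename, splitext
for lab_path, wav_path in zip(test_label_paths, test_wav_paths):
    assert splitext(basename(lab_path))[0] == splitext(basename(wav_path))[0]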

Apply acoustic model only

In [8]:
for wav_path, label_path in zip(test_wav_paths, test_label_paths):
    print("Input label:", label_path)
    
    fs, waveform = wavfile.read(wav_path)
    
    ty = "baseline"
    baseline_waveform, mgc_baseline, _, _, _ = tts_from_label(
        models[ty], label_path, X_min, X_max, Y_mean, Y_std,
        apply_duration_model=False, post_filter=False)
    ty = "gan"
    gan_waveform, mgc_gan, _, _, _ = tts_from_label(
        models[ty], label_path, X_min, X_max, Y_mean, Y_std,
        apply_duration_model=False, post_filter=False)
    
    IPython.display.display(Audio(waveform, rate=fs))
    IPython.display.display(Audio(baseline_waveform, rate=fs))
    IPython.display.display(Audio(gan_waveform, rate=fs))
Input label: ../nnmnkwii_gallery/data/slt_arctic_full_data/label_state_align/arctic_b0535.lab
Input label: ../nnmnkwii_gallery/data/slt_arctic_full_data/label_state_align/arctic_b0536.lab
Input label: ../nnmnkwii_gallery/data/slt_arctic_full_data/label_state_align/arctic_b0537.lab
Input label: ../nnmnkwii_gallery/data/slt_arctic_full_data/label_state_align/arctic_b0538.lab
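
Beyond listening, the predicted mel-cepstra can be compared visually. A minimal sketch using pysptk and librosa.display, assuming the mgc arrays are (T, 25) static mel-cepstra (order 24), 16 kHz audio with a 5 ms frame shift, and all-pass constant alpha=0.41 (verify against the experimental conditions above):

# Hedged sketch: spectral envelope implied by the GAN-predicted mel-cepstrum
# (assumptions: order-24 static mgc, fs=16000, 5 ms hop, alpha=0.41)
alpha, fftlen, hop_length = 0.41, 1024, 80
sp = pysptk.mc2sp(mgc_gan.astype(np.float64), alpha=alpha, fftlen=fftlen)
librosa.display.specshow(np.log(sp).T, sr=fs, hop_length=hop_length,
                         x_axis="time", y_axis="linear")
colorbar()
title("GAN: spectral envelope from predicted mel-cepstrum")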