import numpy as np
import sklearn.datasets as ds
from sklearn.decomposition import ProbabilisticPCA as PPCA
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import as plt_cm
import matplotlib.colors as plt_col
import matplotlib.patches as plt_patches
import seaborn as sns

%matplotlib inline
%load_ext autoreload
%autoreload 2

import vbmfa.fa as vbfa

# Apply seaborn's 'darkgrid' style globally before any figure is created.
sns.set_style('darkgrid')

def plot_scatter(x, classes, ax=None):
    """Scatter the first two rows of `x`, colouring every point by its class.

    Parameters
    ----------
    x : array
        Indexed as ``x[0, :]`` (horizontal coordinates) and ``x[1, :]``
        (vertical coordinates).
    classes : array-like
        One numeric label per point.  Labels are normalised between
        ``np.min(classes)`` and ``np.max(classes)`` and looked up in the
        ``jet`` colormap.
    ax : matplotlib axes, optional
        Axes to draw on; the current axes (``plt.gca()``) when omitted.
    """
    if ax is None:
        ax = plt.gca()
    # Map the label range onto the colormap's input range.
    lowest, highest = np.min(classes), np.max(classes)
    scale = plt_col.Normalize(vmin=lowest, vmax=highest)
    rgba = plt_cm.ScalarMappable(cmap=plt_cm.jet, norm=scale).to_rgba(classes)
    xs, ys = x[0, :], x[1, :]
    ax.scatter(xs, ys, color=rgba, s=20)

def plot_mse(mse):
    """Plot the sequence `mse` against its index on a new 10x4 inch figure.

    The x axis is labelled 'Iteration' and the y axis 'MSE'; points are drawn
    as square markers with a red face.
    """
    _, axis = plt.subplots(figsize=(10, 4))
    line_style = dict(linewidth=2, marker='s', markersize=5,
                      markerfacecolor='red')
    axis.plot(mse, **line_style)
    axis.set_xlabel('Iteration')
    axis.set_ylabel('MSE')

def plot_grid(n, ncols=4, size=(5, 5)):
    """Create a figure holding a grid with room for at least `n` axes.

    Parameters
    ----------
    n : int
        Number of panels that must fit in the grid.
    ncols : int, optional
        Number of columns; the number of rows is ``ceil(n / ncols)``.
    size : tuple, optional
        ``(width, height)`` of a single panel in inches; the figure size is
        this multiplied by the number of columns / rows.

    Returns
    -------
    list
        ``[fig, ax]`` where ``ax`` is a flat 1-D array of all
        ``nrows * ncols`` axes (so it may hold more than `n` entries).
    """
    nrows = int(np.ceil(n / float(ncols)))
    # squeeze=False makes subplots() always hand back a 2-D array.  With the
    # default squeezing a 1x1 grid comes back as a bare Axes object, which has
    # no ravel() and made this function fail for n == ncols == 1.
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols,
                           figsize=(size[0] * ncols, size[1] * nrows),
                           squeeze=False)
    return [fig, ax.ravel()]

def plot_compress(q, n=30):
    """Show the first `n` columns of a `q`-factor model's output as images.

    Seeds NumPy's global RNG with 0, builds ``vbfa.VbFa`` on the module-level
    ``data_y`` with `q` as second argument, and draws the first `n` columns of
    ``x_to_y()`` as square 'binary'-colormap tiles, ten per row.

    NOTE(review): no fitting/update call is made between constructing the
    model and calling ``x_to_y()`` -- confirm that is intended.
    """
    np.random.seed(0)
    model = vbfa.VbFa(data_y, q)
    recon = model.x_to_y()
    _, axes = plot_grid(n, ncols=10)
    # Each column is reshaped to a square whose side is sqrt(model.P).
    side = int(np.sqrt(model.P))
    for idx in range(n):
        tile = recon[:, idx].reshape(side, side)
        axes[idx].matshow(tile, cmap='binary')

def plot_images(images, n=30, size=2):
    """Display the first `n` columns of `images` as square tiles.

    Parameters
    ----------
    images : array
        Each column ``images[:, i]`` is reshaped to ``(dim, dim)`` with
        ``dim = int(sqrt(images.shape[0]))``.
    n : int, optional
        Number of columns to show, laid out ten per row.
    size : number, optional
        Width and height of one tile in inches.
    """
    _, axes = plot_grid(n, ncols=10, size=(size, size))
    side = int(np.sqrt(images.shape[0]))
    with sns.axes_style('white'):
        for idx in range(n):
            panel = axes[idx]
            panel.grid()
            panel.set_axis_off()
            tile = images[:, idx].reshape(side, side)
            panel.matshow(tile, cmap='binary')

def plot_ve(ve):
    """Plot the values in `ve` as one bar per factor on a new 5x4 inch figure.

    Parameters
    ----------
    ve : sequence of numbers
        One value per factor; plotted against the factor index
        ``0 .. len(ve) - 1``.  The y axis is labelled '% Variance explained'.
    """
    x = np.arange(len(ve))
    fig, ax = plt.subplots(figsize=(5, 4))
    # NOTE(review): the original plotting call was lost (only a dangling
    # ", ve)" survived, which was a syntax error).  A bar chart of `ve` over
    # the factor index is the reconstruction -- confirm against the notebook.
    ax.bar(x, ve)
    ax.set_xlabel('Factor')
    ax.set_ylabel('% Variance explained')

digits = ds.load_digits()
# The rest of this script treats `data_y` as (features, samples): it is
# transposed before being passed to scikit-learn, and its columns are reshaped
# into square images.  `digits.data` is (samples, features), hence the
# transpose.
data_y = digits.data.transpose()
# Class label of every sample, used to colour the scatter plots.
data_t = digits.target

# Baseline: project `data_y` onto its first two principal components.
# `data_y` is transposed on the way into scikit-learn, and the projection is
# transposed back because plot_scatter reads coordinates from rows 0 and 1.
pca = PCA(n_components=2)
pca_x = pca.fit_transform(data_y.transpose())
plot_scatter(pca_x.transpose(), data_t)
# NOTE(review): the value printed as 'MSE' is np.linalg.norm of the
# reconstruction residual, not a mean of squared errors -- confirm this is the
# quantity meant to be compared with fa.mse().
print('MSE: {:.3f}'.format(np.linalg.norm(data_y.transpose() - pca.inverse_transform(pca_x))))

# `vbfa.VbFa` model built with second argument 2 on `data_y`: plot
# `fa.q_x.mean` and record `fa.mse()` once before any update and once after
# each of `maxit` updates.
np.random.seed(0)
fa = vbfa.VbFa(data_y, 2)
# First entry is the error of the freshly constructed model.
mse = [fa.mse()]
maxit = 7
# One panel for the initial state plus one per update.
fig, ax = plot_grid(maxit + 1)
plot_scatter(fa.q_x.mean, data_t, ax[0])
# NOTE(review): init() runs only after the initial state has already been
# scored and plotted above -- confirm that ordering is intended.
fa.init()
for i in range(maxit):
    fa.update()
    # Panel 0 holds the initial state, so update i goes to panel i + 1.
    j = i + 1
    plot_scatter(fa.q_x.mean, data_t, ax[j])
    ax[j].set_title('Iteration {}'.format(j))
    mse.append(fa.mse())
plot_mse(mse)
print('MSE: {:f}'.format(mse[-1]))

# `vbfa.VbFa` model built with second argument 8: show the output of
# x_to_y() as images, the per-factor values from variance_explained(), and a
# scatter of the first two rows of `fa.q_x.mean`.
# NOTE(review): there is no update()/fitting call between construction and
# use in this cell -- confirm whether one is needed or was lost.
np.random.seed(0)
fa = vbfa.VbFa(data_y, 8)
fa.order_factors()
plot_images(fa.x_to_y())
plot_ve(fa.variance_explained(sort=False))
plot_scatter(fa.q_x.mean[:2, :], data_t)

# Same image display with the second VbFa argument raised to 32.
# NOTE(review): as written the model is used straight after construction,
# without any update()/fitting call -- confirm.
np.random.seed(0)
fa = vbfa.VbFa(data_y, 32)
plot_images(fa.x_to_y())

# And again with 64.
np.random.seed(0)
fa = vbfa.VbFa(data_y, 64)
plot_images(fa.x_to_y())

faces = ds.fetch_olivetti_faces()
# As everywhere else in this script, `data_y` is kept as (features, samples);
# `faces.data` is (samples, features), hence the transpose.
data_y = faces.data.transpose()
# Label of every sample, used to colour the scatter plots.
data_t = faces.target

# Baseline on the current `data_y`: project onto the first two principal
# components.  `data_y` is transposed on the way into scikit-learn, and the
# projection is transposed back because plot_scatter reads rows 0 and 1.
pca = PCA(n_components=2)
pca_x = pca.fit_transform(data_y.transpose())
plot_scatter(pca_x.transpose(), data_t)
# NOTE(review): the value printed as 'MSE' is np.linalg.norm of the
# reconstruction residual, not a mean of squared errors -- confirm.
print('MSE: {:f}'.format(np.linalg.norm(data_y.transpose() - pca.inverse_transform(pca_x))))

# `vbfa.VbFa` model built with second argument 2 on the current `data_y`:
# plot `fa.q_x.mean` and record `fa.mse()` before any update and after each
# of `maxit` updates.
# NOTE(review): unlike the earlier seven-iteration cell, fa.init() is never
# called here -- confirm which of the two is intended.
np.random.seed(0)
fa = vbfa.VbFa(data_y, 2)
# First entry is the error of the freshly constructed model.
mse = [fa.mse()]
maxit = 3
# One panel for the initial state plus one per update.
fig, ax = plot_grid(maxit + 1)
plot_scatter(fa.q_x.mean, data_t, ax[0])
for i in range(maxit):
    fa.update()
    # Panel 0 holds the initial state, so update i goes to panel i + 1.
    j = i + 1
    plot_scatter(fa.q_x.mean, data_t, ax[j])
    ax[j].set_title('Iteration {}'.format(j))
    mse.append(fa.mse())
plot_mse(mse)
print('MSE: {:f}'.format(fa.mse()))

# Show the first 50 columns of the raw data for reference.
plot_images(data_y, 50)

# `vbfa.VbFa` model built with second argument 64: show the first 50 columns
# of x_to_y() as images, the values from variance_explained(), and a scatter
# of the first two rows of `fa.q_x.mean`.
# NOTE(review): there is no update()/fitting call between construction and
# use in this cell -- confirm whether one is needed or was lost.
np.random.seed(0)
fa = vbfa.VbFa(data_y, 64)
fa.order_factors()
plot_images(fa.x_to_y(), 50)
plot_ve(fa.variance_explained())
plot_scatter(fa.q_x.mean[:2, :], data_t)

# Same image display with the second VbFa argument raised to 128.
# NOTE(review): as written the model is used straight after construction,
# without any update()/fitting call -- confirm.
np.random.seed(0)
fa = vbfa.VbFa(data_y, 128)
plot_images(fa.x_to_y(), 50)