# Fit and visualize one transport-based morphometry (TBM) model.
#
# `model` selects the analysis ('PCA', 'PLDA', 'CCA' or 'NS') and `transform`
# selects the transform-domain backend ('RCDT', 'CLOT' or 'PLOT').  Each
# pytranskit class is imported inside its own branch, so only the selected
# backend module is loaded.
#
# Names this cell reads (defined in earlier cells):
#   model, transform                          -- selectors
#   x_train_hat, y_train, x_test_hat, y_test  -- passed to every fit/classify call
#   template                                  -- RCDT and PLOT fit calls
#   thetas                                    -- RCDT visualize calls
#   img0, lr, alpha, max_iter                 -- CLOT (VOT_*) models
#   mean_x_train_hat, P_tem                   -- PLOT visualize calls
#
# Names this cell binds:
#   PCA / PLDA: n_components, b_hat, p_tr, p_te, and the model instance
#   CCA:        n_components, b_hat1, b_hat2 (alias bhat2), p_tr1, p_tr2,
#               p_te1, p_te2, and the model instance
#   NS:         y_predicted and the classifier instance
#
# Raises ValueError for an unsupported model/transform.  Previously the cell
# silently did nothing, and later cells then failed with a NameError.
_TBM_MODELS = ('PCA', 'PLDA', 'CCA', 'NS')
_TBM_TRANSFORMS = ('RCDT', 'CLOT', 'PLOT')
if model not in _TBM_MODELS:
    raise ValueError(f"Unsupported model {model!r}; expected one of {_TBM_MODELS}")
if transform not in _TBM_TRANSFORMS:
    raise ValueError(f"Unsupported transform {transform!r}; expected one of {_TBM_TRANSFORMS}")

if model == 'PCA':
    n_components = 12
    if transform == 'RCDT':
        from pytranskit.TBM.TBM_RCDT import RCDT_PCA
        rcdt_pca = RCDT_PCA(n_components)
        b_hat, p_tr, p_te = rcdt_pca.rcdt_pca(x_train_hat, y_train, x_test_hat, y_test, template)
        rcdt_pca.visualize(directions=9, points=7, thetas=thetas, SD_spread=2)
    elif transform == 'CLOT':
        from pytranskit.TBM.TBM_CLOT import VOT_PCA
        clot_pca = VOT_PCA(n_components, lr, alpha, max_iter)
        b_hat, p_tr, p_te = clot_pca.vot_pca(x_train_hat, y_train, x_test_hat, y_test, img0)
        clot_pca.visualize(directions=9, points=7, SD_spread=2)
    elif transform == 'PLOT':
        from pytranskit.TBM.TBM_PLOT import PLOT_PCA
        plot_pca = PLOT_PCA(n_components)
        b_hat, p_tr, p_te = plot_pca.plot_pca(x_train_hat, y_train, x_test_hat, y_test, template)
        plot_pca.visualize(mean_x_train_hat, P_tem, directions=9, points=7, SD_spread=2)

elif model == 'PLDA':
    n_components = 12
    if transform == 'RCDT':
        from pytranskit.TBM.TBM_RCDT import RCDT_PLDA
        rcdt_plda = RCDT_PLDA(n_components)
        b_hat, p_tr, p_te = rcdt_plda.rcdt_plda(x_train_hat, y_train, x_test_hat, y_test, template)
        rcdt_plda.visualize(directions=9, points=7, thetas=thetas, SD_spread=2)
    elif transform == 'CLOT':
        from pytranskit.TBM.TBM_CLOT import VOT_PLDA
        clot_plda = VOT_PLDA(n_components, lr, alpha, max_iter)
        b_hat, p_tr, p_te = clot_plda.vot_plda(x_train_hat, y_train, x_test_hat, y_test, img0)
        clot_plda.visualize(directions=9, points=7, SD_spread=2)
    elif transform == 'PLOT':
        from pytranskit.TBM.TBM_PLOT import PLOT_PLDA
        plot_plda = PLOT_PLDA(n_components)
        b_hat, p_tr, p_te = plot_plda.plot_plda(x_train_hat, y_train, x_test_hat, y_test, template)
        plot_plda.visualize(mean_x_train_hat, P_tem, directions=9, points=7, SD_spread=2)

elif model == 'CCA':
    n_components = 9
    # The CCA fit methods return six values, unpacked in this order.
    if transform == 'RCDT':
        from pytranskit.TBM.TBM_RCDT import RCDT_CCA
        rcdt_cca = RCDT_CCA(n_components)
        b_hat1, b_hat2, p_tr1, p_tr2, p_te1, p_te2 = rcdt_cca.rcdt_cca(
            x_train_hat, y_train, x_test_hat, y_test, template)
        # NOTE(review): RCDT uses directions=5 while CLOT/PLOT use 7 -- confirm intentional.
        rcdt_cca.visualize(directions=5, points=7, thetas=thetas, SD_spread=2)
    elif transform == 'CLOT':
        from pytranskit.TBM.TBM_CLOT import VOT_CCA
        clot_cca = VOT_CCA(n_components, lr, alpha, max_iter)
        b_hat1, b_hat2, p_tr1, p_tr2, p_te1, p_te2 = clot_cca.vot_cca(
            x_train_hat, y_train, x_test_hat, y_test, img0)
        clot_cca.visualize(directions=7, points=7, SD_spread=2)
    elif transform == 'PLOT':
        from pytranskit.TBM.TBM_PLOT import PLOT_CCA
        plot_cca = PLOT_CCA(n_components)
        b_hat1, b_hat2, p_tr1, p_tr2, p_te1, p_te2 = plot_cca.plot_cca(
            x_train_hat, y_train, x_test_hat, y_test, template)
        plot_cca.visualize(mean_x_train_hat, P_tem, directions=7, points=7, SD_spread=2)
    # Legacy spelling kept so later cells that read `bhat2` keep working.
    bhat2 = b_hat2

elif model == 'NS':
    if transform == 'RCDT':
        from pytranskit.TBM.TBM_RCDT import RCDT_NS_Classifier
        rcdt_ns = RCDT_NS_Classifier(train_sample=32, use_gpu=False)
        y_predicted = rcdt_ns.classify_RCDT_NS(x_train_hat, y_train, x_test_hat, y_test)
    elif transform == 'CLOT':
        from pytranskit.TBM.TBM_CLOT import VOT_NS_Classifier
        clot_ns = VOT_NS_Classifier(train_sample=32, use_gpu=False)
        y_predicted = clot_ns.classify_VOT_NS(x_train_hat, y_train, x_test_hat, y_test)
    elif transform == 'PLOT':
        from pytranskit.TBM.TBM_PLOT import PLOT_NS_Classifier
        plot_ns = PLOT_NS_Classifier(train_sample=32, use_gpu=False)
        y_predicted = plot_ns.classify_PLOT_NS(x_train_hat, y_train, x_test_hat, y_test)