#!/usr/bin/env python
# coding: utf-8

# # SD-TSIA204 : Linear models
# ## Least squares definition: 2 variables and 3D visualization
# ### *Joseph Salmon*
#
# This notebook reproduces the pictures for the course "LeastSquare_Def".

# In[ ]:


import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
from matplotlib import rc
import seaborn as sns
from os import makedirs, mkdir, path
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' projection on old matplotlib)

# An interactive backend is better for rotating 3D plots. ``get_ipython`` only
# exists inside IPython/Jupyter, so skip the magic when run as a plain script.
try:
    get_ipython().run_line_magic('matplotlib', 'notebook')  # noqa: F821
except NameError:
    pass


# # Plot initialization

# In[ ]:


dirname = "../srcimages/"
# makedirs also creates missing parent directories and tolerates an existing one.
makedirs(dirname, exist_ok=True)
imageformat = '.pdf'
rc('font', **{'family': 'sans-serif', 'sans-serif': ['Computer Modern Roman']})
params = {'axes.labelsize': 12,
          'font.size': 12,
          'legend.fontsize': 12,
          'xtick.labelsize': 10,
          'ytick.labelsize': 10,
          'text.usetex': True,
          'figure.figsize': (8, 6)}
plt.rcParams.update(params)
plt.close("all")

# sns.set_context("poster")
sns.set_palette("colorblind")
sns.set_style({'legend.frameon': True})
color_blind_list = sns.color_palette("colorblind", 8)
my_orange = color_blind_list[2]
my_green = color_blind_list[1]

###############################################################################
# display function:

saving = False


def my_saving_display(fig, dirname, filename, imageformat):
    """Save ``fig`` to disk, but only when the module-level flag ``saving`` is true.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save.
    dirname : str
        Output directory.
    filename : str
        Base file name without extension. Every "." is replaced by "pt" so
        that names built from float values do not look like extra extensions.
    imageformat : str
        File extension including the leading dot, e.g. ".pdf".
    """
    filename = filename.replace('.', 'pt')  # remove "." to avoid floats issues
    if saving:
        image_name = path.join(dirname, filename + imageformat)
        fig.savefig(image_name)


# # 3D case drawing

# In[ ]:


plt.close("all")

# Load data
url = 'https://vincentarelbundock.github.io/Rdatasets/csv/datasets/trees.csv'
dat3 = pd.read_csv(url)

# Fit regression model: Volume ~ const + Girth + Height
X = dat3[['Girth', 'Height']]
X = sm.add_constant(X)
y = dat3['Volume']
results = sm.OLS(y, X).fit().params

XX = np.arange(8, 22, 0.5)
YY = np.arange(64, 90, 0.5)
xx, yy = np.meshgrid(XX, YY)
# Read coefficients by label: positional ``Series[int]`` access on a
# label-indexed Series is deprecated in recent pandas.
zz = results['const'] + results['Girth'] * xx + results['Height'] * yy

fig = plt.figure()
# ``Axes3D(fig)`` no longer attaches itself to the figure in recent matplotlib
# (blank output); add_subplot with the 3d projection works across versions.
ax = fig.add_subplot(111, projection='3d')
ax.set_xlabel('Girth')
ax.set_ylabel('Height')
ax.set_zlabel('Volume')
ax.set_zlim(5, 80)
ax.plot(X['Girth'], X['Height'], y, 'o')
ax.plot_wireframe(xx, yy, zz, rstride=10, cstride=10, alpha=0.3)
ax.plot_surface(xx, yy, zz, alpha=0.3)
my_saving_display(fig, dirname, "tree_data_plot_regression", imageformat)
plt.show()


# # Non-trivial minima: 3D visualisation

# In[ ]:


sns.set_style("white")
XX = np.arange(-1, 1, 0.05)
YY = XX
xx, yy = np.meshgrid(XX, YY)
zz = (xx - yy) ** 2


def _make_valley_figure(xx, yy, zz):
    """Create a bare 3D figure (no axes, no ticks) drawing the surface once.

    The surface is drawn a single time: rotating the view afterwards only
    needs ``ax.view_init``. Drawing it again would stack two semi-transparent
    copies and darken the rendering.

    Returns
    -------
    (fig, ax) : the new figure and its 3D axes, initially viewed at elev=20, azim=50.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(elev=20., azim=50)
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    ax.set_zlabel('$z$')
    ax.set_axis_off()
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.plot_surface(xx, yy, zz, rstride=2, cstride=2,
                    antialiased=False, alpha=0.5)
    return fig, ax


# In[ ]:


fig, ax = _make_valley_figure(xx, yy, zz)
my_saving_display(fig, dirname, "CN0_2d_non_trivial1", imageformat)
ax.view_init(elev=20., azim=90)
my_saving_display(fig, dirname, "CN0_2d_non_trivial2", imageformat)


# In[ ]:


# One figure per additional viewing angle, as in the original notebook cells.
for azim, name in ((130, "CN0_2d_non_trivial3"),
                   (170, "CN0_2d_non_trivial4"),
                   (210, "CN0_2d_non_trivial5")):
    fig, ax = _make_valley_figure(xx, yy, zz)
    ax.view_init(elev=20., azim=azim)
    my_saving_display(fig, dirname, name, imageformat)