Clustering en Python: Parte II (Clustering jerárquico y dendograma)

Vamos a ver un ejemplo de cómo podemos aplicar clustering jerárquico y generar un dendograma.

Importamos los módulos que vamos a necesitar:

  • para medir tiempos de ejecución: time
  • para lectura y manejo de los datos: numpy y pandas
  • para disponer de implementaciones de distintos algoritmos de clustering: cluster de Scikit-learn
  • para normalizar el conjunto de datos: preprocessing de Scikit-learn
  • para las visualizaciones y los dendogramas: matplotlib, seaborn y scipy
In [1]:
%matplotlib inline
In [2]:
import time

import pandas as pd
import numpy as np

from sklearn import cluster
from sklearn import preprocessing

import matplotlib.pyplot as plt
import seaborn as sns
from scipy.cluster import hierarchy

Creamos la función norm_to_zero_one para aplicar normalización min-max.

In [3]:
def norm_to_zero_one(df):
    return (df - df.min()) * 1.0 / (df.max() - df.min())

Datos

Trabajamos sobre el conjunto de datos facilitado para la segunda práctica de la asignatura que contiene información recogida por el INE.

In [4]:
censo = pd.read_csv('censo_granada.csv')

Los valores en blanco en realidad son otra categoría que vamos a nombrar con el valor 0.

In [5]:
censo = censo.replace(np.NaN,0)
In [6]:
censo.shape
Out[6]:
(83499, 142)
In [7]:
censo.columns.values
Out[7]:
array(['CPRO', 'CMUN', 'IDHUECO', 'NORDEN', 'FACTOR', 'MNAC', 'ANAC',
       'EDAD', 'SEXO', 'NACI', 'CPAISN', 'CPRON', 'CMUNN', 'ANORES',
       'ANOM', 'ANOC', 'ANOE', 'CLPAIS', 'CLPRO', 'CLMUNP',
       'RES_ANTERIOR', 'CPAISUNANO', 'CPROUNANO', 'CMUNANO', 'RES_UNANO',
       'CPAISDANO', 'CPRODANO', 'CMUNDANO', 'RES_DANO', 'SEG_VIV',
       'SEG_PAIS', 'SEG_PROV', 'SEG_MUN', 'SEG_NOCHES', 'SEG_DISP',
       'ECIVIL', 'ESCOLAR', 'ESREAL', 'TESTUD', 'TAREA1', 'TAREA2',
       'TAREA3', 'TAREA4', 'HIJOS', 'NHIJOS', 'RELA', 'JORNADA', 'CNO',
       'CNAE', 'SITU', 'CSE', 'ESCUR1', 'ESCUR2', 'ESCUR3', 'LTRABA',
       'PAISTRABA', 'PROTRABA', 'MUNTRABA', 'CODTRABA', 'NVIAJE',
       'MDESP1', 'MDESP2', 'TDESP', 'TENEN', 'CALE', 'ASEO', 'BADUCH',
       'INTERNET', 'AGUACOR', 'SUT', 'NHAB', 'PLANTAS', 'PLANTAB',
       'TIPOEDIF', 'ANOCONS', 'ESTADO', 'ASCENSOR', 'ACCESIB', 'GARAJE',
       'PLAZAS', 'GAS', 'TELEF', 'ACAL', 'RESID', 'FAMILIA', 'PAD_NORDEN',
       'MAD_NORDEN', 'CON_NORDEN', 'OPA_NORDEN', 'TIPOPER', 'NUCLEO',
       'NMIEM', 'NFAM', 'NNUC', 'NGENER', 'ESTHOG', 'TIPOHOG', 'NOCU',
       'NPARAIN', 'NPFAM', 'NPNUC', 'HM5', 'H0515', 'H1624', 'H2534',
       'H3564', 'H6584', 'H85M', 'HESPHOG', 'MESPHOG', 'HEXTHOG',
       'MEXTHOG', 'COMBINAC', 'EDADPAD', 'PAISNACPAD', 'NACIPAD',
       'ECIVPAD', 'ESTUPAD', 'SITUPAD', 'SITPPAD', 'EDADMAD',
       'PAISNACMAD', 'NACIMAD', 'ECIVMAD', 'ESTUMAD', 'SITUMAD',
       'SITPMAD', 'EDADCON', 'NACCON', 'NACICON', 'ECIVCON', 'ESTUCON',
       'SITUCON', 'SITPCON', 'TIPONUC', 'TAMNUC', 'NHIJO', 'NHIJOC',
       'FAMNUM', 'TIPOPARECIV', 'TIPOPARSEX', 'DIFEDAD'], dtype=object)

Seleccionamos el caso de estudio, en este ejemplo concreto, aquellos casos para los que el campo 'EDADMAD' (grupo quinquenal de la edad de la madre) no estaba vacío.

In [8]:
subset = censo.loc[censo['EDADMAD']>0]

Seleccionamos variables de interés para clustering.

In [9]:
usadas = ['EDAD', 'ANORES', 'NPFAM', 'H6584', 'ESREAL']
X = subset[usadas]

Podemos comprobar las dimensiones (variables e instancias) del subconjunto seleccionado.

In [10]:
X.shape
Out[10]:
(27207, 5)

Para sacar el dendrograma en el jerárquico, no podemos tener muchos elementos. Hacemos un muestreo aleatorio para quedarnos solo con 1000, aunque lo ideal es elegir un caso de estudio que ya dé un tamaño pequeño.

In [11]:
X = X.sample(1000, random_state=123456)

En clustering hay que normalizar para las métricas de distancia. Normalizamos el dataframe aplicando la función 'norm_to_zero_one'.

  • 'apply' aplica una función a lo largo de un eje concreto de un dataframe. Por defecto el parámetro 'axis' tiene el valor cero lo que significa que la función se aplica a cada columna del dataframe.
In [12]:
X_normal = X.apply(norm_to_zero_one)

Podemos comprobar las dimensiones del subconjunto seleccionado con el que vamos a trabajar.

In [13]:
X_normal.shape
Out[13]:
(1000, 5)
In [14]:
list(X_normal)
Out[14]:
['EDAD', 'ANORES', 'NPFAM', 'H6584', 'ESREAL']

Clustering Jerárquico

El clustering jerárquico engloba una familia de algoritmos que construyen clusters anidadas fusionándolos o dividiéndolos sucesivamente. Esta jerarquía de clusters se representa como un árbol (o dendrograma). La raíz del árbol es el cluster único que recoge todas las muestras, siendo las hojas del árbol los cluster que incluyen un solo dato o muestra.

El objeto 'AgglomerativeClustering' realiza un clustering jerárquico usando un enfoque 'de abajo a arriba': cada dato/muestra comienza en su propio cluster, y los clusters van fusionándose sucesivamente. En cada paso fusiona los dos clusters más cercanos (la definición de cercanos va a depender de la métrica elegida). Las opciones son:

  • Ward: fusionar el par de cluster que genera un agrupamiento con mínima varianza (media de la distancia cuadrática de cada elemento al centroide)
  • Maximum or complete linkage: minimiza la distancia máxima entre elementos de dos clusters.
  • Average linkage: minimiza la distancia media entre elementos de dos clusters.
  • Single linkage: minimiza la distancia mínima entre elementos de dos clusters.

Vamos a utilizar 'AgglomerativeClustering' con 'Ward' como criterio de enlace y eligiendo quedarnos con 100 clusters (100 ramificaciones del dendograma).

In [15]:
ward = cluster.AgglomerativeClustering(n_clusters=100, linkage='ward') # n_clusters: nº de clusters a encontrar
name, algorithm = ('Ward', ward)

Aplicamos el algoritmo de clustering jerárquico sobre nuestros datos (registrando el tiempo de ejecución).

In [16]:
cluster_predict = {}
k = {}

print(name,end='')
t = time.time()
cluster_predict[name] = algorithm.fit_predict(X_normal) 
tiempo = time.time() - t
k[name] = len(set(cluster_predict[name]))
print(": k: {:3.0f}, ".format(k[name]),end='')
print("{:6.2f} segundos".format(tiempo))
Ward: k: 100,   0.04 segundos
In [17]:
cluster_predict['Ward']
Out[17]:
array([33, 61, 43, 12, 57, 37, 95,  7, 36, 30, 61, 27,  7, 36, 61, 20,  1,
       95, 82, 95, 12, 40, 61, 37, 40, 35, 98, 35, 33, 37, 20, 16, 10,  7,
       24,  2, 33,  7, 32, 95, 37,  2, 37,  1,  0, 85, 85, 37, 28, 37, 71,
       37, 87, 60, 44, 17, 18, 19,  7, 73, 36, 41, 33, 48,  0, 36, 43, 14,
       75, 28, 33, 16, 30, 65, 21,  1, 36,  8, 83, 23, 33, 30, 42, 98, 55,
       80, 28,  3, 33, 94, 75, 36,  7,  7, 92, 70, 45, 39, 16, 75,  1, 40,
       52, 21, 10, 17, 40, 23, 36, 12, 37, 12, 37, 28, 21,  2, 40, 40, 36,
       13,  7, 33, 95,  7, 33, 59, 28, 64, 28, 40, 51, 56, 12, 33, 37, 28,
       85, 61, 33, 30, 17, 28, 92, 33, 33, 43,  0, 37, 28, 52,  1, 95, 85,
       18, 24, 37, 23,  1,  0, 70, 10, 95, 43, 20, 92, 61, 15,  7, 49, 26,
        4, 61, 44, 24, 12, 61, 46, 43, 92, 28, 91,  6, 59, 64, 10, 62, 33,
       95, 33, 30,  7, 10, 17, 10, 11, 16,  1, 96, 10,  7, 35, 61,  7, 61,
       90, 21, 90, 95, 36, 36, 33,  1,  7, 42, 28, 11,  7, 22, 37, 37, 95,
       37, 17, 96, 87, 92, 68, 35, 40, 96, 36, 28, 33,  1, 61, 42, 39, 65,
        4, 37, 61, 37, 75, 34, 73, 28, 40, 61, 64, 22, 95, 55, 33, 37, 23,
        9, 61, 36, 58,  1, 18, 36, 12, 99,  7,  9,  7, 28, 46, 81, 35, 52,
        1, 30, 17, 87, 22, 24, 36, 30, 17, 24, 33,  1, 60, 39, 10, 22, 10,
       75, 37, 35, 37, 17, 44, 95, 92, 33, 28, 41,  7, 37, 62, 22,  9, 36,
        9, 18, 57, 64, 28, 61, 28,  1, 92,  8, 88,  1, 35, 75, 79, 95, 62,
       28, 67, 90,  1, 18, 19,  9, 24, 33, 37, 28, 61, 18, 19, 37,  9,  8,
       63, 86, 27, 40, 75, 63, 61, 17, 37, 37, 33, 70,  7, 96,  7, 95, 47,
        1,  9, 51, 82, 41,  0, 61, 39, 61, 33,  0,  4, 70, 77, 21, 91, 24,
       10, 53, 37, 27, 29, 27, 28, 37,  8, 28, 69, 35, 28, 87, 73, 37, 98,
        2, 37,  7, 15, 28,  9, 61, 43, 10,  7, 28, 93, 33, 17, 57, 28, 12,
       36, 35,  1, 33, 40, 51, 21, 24, 28, 52, 94, 40, 36, 50, 41, 94,  9,
       36, 46, 36, 80, 30,  2, 28, 22, 89, 36, 15,  9, 43, 17, 19, 24, 87,
       87, 37,  6, 17, 52, 33, 73,  5, 37, 75, 85, 12, 33, 99,  7,  1, 84,
       43,  4, 24, 42, 35, 96, 87, 36, 18, 43, 71, 37, 45, 37, 95, 90, 40,
       43, 90, 32, 10, 37, 58, 23, 65, 24, 33, 61, 46,  1, 18, 28, 33, 52,
        9, 94, 28, 13, 99, 85, 31, 98, 15, 60, 28, 12, 32,  7, 10, 40, 94,
        9, 33, 95,  1,  7,  0,  1, 28, 92, 20, 17, 37, 92, 32, 54, 36,  8,
       47, 32,  5, 92, 31,  5, 27,  3, 18, 96,  7, 85,  1,  7, 87, 82, 61,
        1, 17, 24, 10, 40, 85, 49, 87, 61, 95, 92, 18,  2, 28, 37, 35, 11,
       65, 90, 79, 92, 58, 30,  7, 95, 58,  9,  0, 10, 37, 26,  1, 24, 33,
       61, 43, 43, 61,  7, 37, 50, 31, 92, 18, 11, 28, 68, 75, 79,  1, 29,
        9, 61, 97, 33, 17,  0,  9, 61, 18, 36, 38, 87,  8, 33, 11, 12, 41,
       37,  7,  1, 12, 33,  4, 61, 33, 53, 48, 43, 28, 49, 10,  7, 75, 94,
        7, 11, 16, 45, 37, 32,  9, 87, 61, 37, 95,  7, 28, 73, 92, 28, 10,
       19,  9, 37, 61, 11,  0, 85, 18, 47, 85,  1,  7, 38, 61, 91, 11,  6,
       51, 46, 18, 28, 85, 28, 19, 61, 33, 33, 19, 10, 18,  7, 57, 38, 38,
       85, 71, 64, 95, 61, 79, 36, 69, 92, 43, 96, 37, 37, 52, 14, 58, 18,
        4, 12, 40, 78, 37,  3, 79, 85, 17,  7, 28, 51, 18, 29,  7, 24, 95,
       81, 40, 33, 91, 10, 80, 24, 33, 19, 73,  7,  1, 18, 37, 36, 25, 84,
       33,  7, 37, 17, 11, 85, 88, 24, 43, 31, 24, 64, 92, 33, 40, 61,  1,
       85, 41, 74, 88, 37,  2, 28, 36, 23, 33, 69, 37, 17, 56,  1,  1, 22,
       43,  1, 85, 95, 10, 80, 12,  7,  2, 18, 18, 90, 35, 32, 79, 36, 39,
       66, 29, 36,  9, 17, 37, 40, 10, 78, 65, 11, 52, 76, 33, 85, 60, 76,
        1, 52,  8, 95,  7, 37, 37, 34, 22, 34, 36, 38, 28, 10, 37, 46, 61,
       37, 65, 99,  7,  1, 95, 25, 61, 37, 64, 61, 12, 41, 58, 61, 26, 33,
       33,  5,  0, 24, 24, 12, 37, 33,  9, 37, 61, 74, 18, 45, 83, 73, 61,
       61, 85, 35, 44, 61, 91, 22, 22, 43, 52,  7,  5, 61, 38,  9, 73, 28,
       47, 27,  7,  7, 95, 10, 21,  7, 79, 95, 57,  2, 61, 37, 75, 33, 79,
       33, 33,  8, 65, 99, 28, 58, 82, 44,  7, 28, 38, 61, 22, 33, 80, 36,
       15,  7, 15, 12,  7, 93, 74, 61, 80, 24, 37, 65, 12, 61, 92, 95, 37,
       73, 18,  8,  1, 37, 13, 48,  1, 24,  7, 37, 10, 84, 10, 47, 86,  7,
       64,  1, 66, 61, 17, 72, 43, 11, 10, 85, 43, 22, 61, 87,  3, 85, 48,
       90, 22, 61, 33, 81, 33,  7, 92, 61, 10,  1, 12, 64,  7, 79, 75, 90,
        1, 13, 40, 10, 37, 37, 65, 28, 37,  1, 90, 16, 87, 99,  7, 58, 79,
       10, 45,  7, 43, 24, 15, 40, 87, 26, 43, 95,  1, 28, 58],
      dtype=int64)
In [18]:
type(cluster_predict)
Out[18]:
dict
In [19]:
type(cluster_predict['Ward'])
Out[19]:
numpy.ndarray
In [20]:
cluster_predict['Ward'].shape
Out[20]:
(1000,)

Convertimos la asignación de clusters a un DataFrame con una única columna 'cluster'.

In [21]:
clusters = pd.DataFrame(cluster_predict['Ward'],index=X.index,columns=['cluster'])

Añadimos la asignación de clusters a las variables de entrada que habíamos seleccionado para el clustering.

In [22]:
X_cluster = pd.concat([X, clusters], axis=1)
In [23]:
print(list(X_cluster))
print(X_cluster.shape)
['EDAD', 'ANORES', 'NPFAM', 'H6584', 'ESREAL', 'cluster']
(1000, 6)

Filtramos outliers quitando aquellos elementos que el algoritmo ha agrupado en clusters muy pequeños:

  • usamos 'groupby' para generar grupos de instancias según el valor de la columna 'cluster', es decir, según el cluster asignado
  • 'transform' devuelve un objeto que está indexado igual (mismo tamaño) que el objeto que está siendo agrupado (en este caso 'X_cluster')

https://pandas.pydata.org/pandas-docs/stable/groupby.html

In [24]:
min_size = 3
# Agrupamos las instancias de 'X_cluster' por el cluster asignado 
X_filtrado = X_cluster[X_cluster.groupby('cluster').cluster.transform(len) > min_size]

Para entender un poco mejor lo que está pasando en esa última línea de código, podemos ver los grupos de instancias generados por 'groupby'.

In [25]:
print(X_cluster.groupby('cluster').groups)
{0: Int64Index([14371, 51412, 78530, 30988, 6415, 28072, 68722, 81878, 53883,
            75314, 16863],
           dtype='int64'), 1: Int64Index([23582, 38202, 24612, 69235, 73230, 51732, 17212, 13518,  6087,
            32339, 79830, 51447, 79590, 52288,  6574, 76175,  6532, 79431,
            75775, 44872, 17180, 58702, 17628, 31644, 77018, 72129, 24292,
            51720, 59230, 71775, 17021, 17324, 44843, 65705, 17168, 59716,
             9896, 31959, 52503, 35004, 55655],
           dtype='int64'), 2: Int64Index([68937, 5648, 9145, 27598, 2459, 65033, 82093, 48096, 9132], dtype='int64'), 3: Int64Index([54269, 12430, 81675, 74907], dtype='int64'), 4: Int64Index([27192, 33874, 333, 30172, 15180, 7905], dtype='int64'), 5: Int64Index([962, 36467, 20400, 12544, 4952], dtype='int64'), 6: Int64Index([9601, 75041, 79817], dtype='int64'), 7: Int64Index([69729, 67531,  3987, 28672, 39396, 67687, 56101, 80652, 49671,
            49026,  3613,  4410, 32674, 76565, 67236, 24890, 67271, 49603,
            35510, 21687, 56223, 49667, 35540, 39385, 11464, 79947, 56287,
            33547, 67073, 56205, 80944, 55869, 59516, 53776, 56271, 56130,
            60582, 67778, 35583,  3523, 19395, 77105, 47179, 70026, 42161,
            40126,  4260, 76644, 26114, 59111, 53374, 46501, 28739, 49147,
            11574],
           dtype='int64'), 8: Int64Index([79725, 68441, 14792, 74982, 66355, 61370, 62399, 6209, 63716], dtype='int64'), 9: Int64Index([58332, 16868, 11202,  9556,  2107, 73379, 55359,  3196, 17808,
            62942, 17235,  3497, 16586,  9900, 51828, 80154, 59449, 40775,
            28068, 78971],
           dtype='int64'), 10: Int64Index([76943, 76304, 45527, 31596, 11029, 80026, 76305, 46209, 79239,
            18589, 46281, 38069, 63622, 28853, 56241, 39221, 79495, 11001,
            38948, 82332, 73744, 65992, 18105, 76168,  3107, 11119, 67164,
            42850, 59146],
           dtype='int64'), 11: Int64Index([6119, 32636, 24112, 31535, 44557, 11897, 78730, 21040, 69161,
            35082, 42283],
           dtype='int64'), 12: Int64Index([58323, 68922, 71463, 24204, 34962, 30423, 44048,  9354, 44282,
            68926, 68757, 58008, 56384, 38395, 71982, 16787, 60839, 47323,
             8776],
           dtype='int64'), 13: Int64Index([81225, 51217, 65542, 51218], dtype='int64'), 14: Int64Index([4873, 1422], dtype='int64'), 15: Int64Index([61859, 25424, 42515, 72976, 79809, 53034, 54744], dtype='int64'), 16: Int64Index([63486, 22948, 26751, 23133, 82781, 22677], dtype='int64'), 17: Int64Index([72374,  6118, 64857, 70611, 17328, 62520, 55679,  5802, 48581,
            78849, 22107,  6817, 63614, 41819, 28188, 23478, 31424, 55105,
            44285, 82063],
           dtype='int64'), 18: Int64Index([37359, 31375, 25192, 20726, 72143, 16506, 74551, 72770,  2889,
             1030, 10143, 17752, 75598, 48137, 58923, 16289,  6113, 62753,
            34527, 31601,  2550, 66084],
           dtype='int64'), 19: Int64Index([32399, 55453, 23977, 78214, 69019, 65186, 44590, 12053], dtype='int64'), 20: Int64Index([7794, 68599, 68643, 8336], dtype='int64'), 21: Int64Index([12774, 21881, 19753, 83179, 54378, 33951, 7103], dtype='int64'), 22: Int64Index([35522, 63137, 42624, 14398, 28077, 55441, 14107, 14035, 76887,
            55515, 21400, 76847, 75997],
           dtype='int64'), 23: Int64Index([20137, 8318, 15183, 9737, 493, 32216], dtype='int64'), 24: Int64Index([ 9358,  3274, 48337, 51704, 38942, 48908, 16667, 20967, 46106,
            62724, 10929, 25396, 79407, 69280, 16634, 45182, 34996,  2380,
            52460, 23903,  9453, 83443],
           dtype='int64'), 25: Int64Index([51279, 12122], dtype='int64'), 26: Int64Index([12095, 54146, 71426, 74644], dtype='int64'), 27: Int64Index([64853, 24239, 41731, 80736, 20804, 27684], dtype='int64'), 28: Int64Index([66969,  3175, 38505, 55553, 19029, 38822, 56254, 39028, 80040,
            80232, 45269, 17313, 30557, 32479, 32743, 20689, 66737, 25378,
            73015, 82734, 24977, 25415, 76136, 33448, 25842, 27806, 72074,
            10328,  6234, 29116, 42575, 24624, 33443, 24762,  3368, 21149,
            45540, 72819, 46930, 45205, 39027, 82156,  6839, 17544, 31045,
            17818],
           dtype='int64'), 29: Int64Index([57281, 56069, 9604, 14961], dtype='int64'), 30: Int64Index([51796, 58488, 64943, 45651, 32939, 46343, 58839, 24730, 22175], dtype='int64'), 31: Int64Index([53849, 17734, 61263, 68451], dtype='int64'), 32: Int64Index([54529, 64377, 47604, 75183, 61739, 50796, 81787], dtype='int64'), 33: Int64Index([59950, 21488, 82738, 28442, 79911,  6461, 59433, 59771, 10757,
             6330, 24679, 39586, 10057, 24418, 52016, 59612, 24097, 42354,
            10376, 25363, 65489, 79085, 82225, 45842, 79063, 66902, 13417,
            69633,  3835, 66781, 52631, 10729, 17560, 65493, 49489, 49091,
            10488, 18355, 10717, 45669, 38765, 76124, 28209, 24357,  2984,
            13610, 60057, 28704, 45672, 53194,  2909, 45014],
           dtype='int64'), 34: Int64Index([289, 17651, 34894], dtype='int64'), 35: Int64Index([25165, 51622, 58047, 55165,  9398, 79088, 31525, 65389, 44664,
            27966, 54010, 62043, 24261],
           dtype='int64'), 36: Int64Index([80558, 22293, 27363, 18194, 17300, 25542,  1750, 11601, 32832,
            39792, 39825, 70365, 65710, 54048, 16077, 18343, 39650, 62943,
            39996, 53835,  4851, 49736, 40298, 10668, 70555, 18984,  2860,
            66101, 32258, 42737],
           dtype='int64'), 37: Int64Index([67528, 43090, 60711, 67159, 25957, 18764, 33058,  4078, 17932,
            67294,  4735, 10298, 21699, 83108, 10684,  6982, 42885, 60396,
            18147, 70499, 46257, 62998, 66699, 82987, 25655, 67834, 67033,
            60221, 46946, 83096, 47107, 83061, 46849, 26084, 59686, 35739,
            80080, 46878, 74023, 21439, 35545, 39336, 17978, 19158, 67032,
            14495, 80332, 53364, 53576, 53392, 18608, 70095, 55932, 70200,
            53356, 70147, 46256, 56195, 26062, 63370, 19537, 18391, 33033,
            11170, 60438, 60274],
           dtype='int64'), 38: Int64Index([24464, 10047, 85, 22189, 39117, 2012, 62412], dtype='int64'), 39: Int64Index([47834, 19975, 47446, 5339, 34263], dtype='int64'), 40: Int64Index([35135, 66369,  5705, 26525, 58868, 64596, 51617, 32065, 12084,
            69090, 16085, 65314, 64612, 58434, 44385,  5753, 51833, 72904,
            77158, 42336, 33680],
           dtype='int64'), 41: Int64Index([56609, 38141, 73536, 62176, 62531, 66123, 11143], dtype='int64'), 42: Int64Index([33922, 5404, 52561, 12225], dtype='int64'), 43: Int64Index([35390, 52334, 76568, 69340,  4730, 35459, 56404, 41821, 22299,
            28537, 63221, 39375, 53046, 59659, 46697, 35150,  7656, 42381,
            13623, 78822, 76530],
           dtype='int64'), 44: Int64Index([15921, 57596, 81930, 81093, 71627], dtype='int64'), 45: Int64Index([6210, 50797, 73234, 63345, 27575], dtype='int64'), 46: Int64Index([49389, 59400, 35604, 53259, 3880, 53701], dtype='int64'), 47: Int64Index([22251, 74534, 42542, 55921, 80748], dtype='int64'), 48: Int64Index([30735, 42759, 13041, 72383], dtype='int64'), 49: Int64Index([61605, 55985, 82967], dtype='int64'), 50: Int64Index([75962, 3074], dtype='int64'), 51: Int64Index([42767, 59009, 12276, 77760, 4996], dtype='int64'), 52: Int64Index([37671, 73344, 82239, 82036, 78873, 78696, 17259, 1835, 41904,
            1181],
           dtype='int64'), 53: Int64Index([81155, 75709], dtype='int64'), 54: Int64Index([56173], dtype='int64'), 55: Int64Index([40671, 19423], dtype='int64'), 56: Int64Index([7812, 33921], dtype='int64'), 57: Int64Index([18734, 44978, 76555, 74487, 75910], dtype='int64'), 58: Int64Index([53881, 77048, 69526, 6829, 34949, 21565, 28169, 82204, 25325], dtype='int64'), 59: Int64Index([14841, 27773], dtype='int64'), 60: Int64Index([27321, 56955, 57405, 41037], dtype='int64'), 61: Int64Index([32492, 60359, 18962, 11216, 40017, 39914, 46384, 73693, 69852,
            11227, 14162, 39817, 46455,  7207, 13827, 80563, 69695, 79839,
            32689, 18869,  2882, 35785, 73372, 39726, 17795, 39708, 53333,
            53319, 80414, 80555, 33151, 25799, 45920, 53113, 62870, 32006,
            62737,  3187, 35213, 73856,  4294, 46673, 46250, 49283, 25088,
            33156, 70421, 53177, 83164, 61055, 49064],
           dtype='int64'), 62: Int64Index([81313, 81172, 1498], dtype='int64'), 63: Int64Index([76815, 49362], dtype='int64'), 64: Int64Index([58157, 44153, 30317, 49195, 48399, 34549, 8972, 2608, 23461], dtype='int64'), 65: Int64Index([40217, 59739, 74210, 48189, 49870, 46810, 32923, 47252, 66089], dtype='int64'), 66: Int64Index([29459, 64367], dtype='int64'), 67: Int64Index([47534], dtype='int64'), 68: Int64Index([6191, 15630], dtype='int64'), 69: Int64Index([54439, 698, 26918], dtype='int64'), 70: Int64Index([51659, 48317, 79428, 82654], dtype='int64'), 71: Int64Index([81345, 12191, 74791], dtype='int64'), 72: Int64Index([43751], dtype='int64'), 73: Int64Index([69220, 9053, 79055, 28213, 78891, 51261, 24391, 1196, 75485], dtype='int64'), 74: Int64Index([47595, 5213, 19645], dtype='int64'), 75: Int64Index([32222, 66641, 17609, 69310, 76208, 24703, 54021, 31516, 46320,
            69118,  5692, 28646],
           dtype='int64'), 76: Int64Index([12579, 12113], dtype='int64'), 77: Int64Index([40676], dtype='int64'), 78: Int64Index([12338, 27144], dtype='int64'), 79: Int64Index([37934, 23557, 28101, 37981, 2939, 2271, 2018, 48736, 2536, 72392], dtype='int64'), 80: Int64Index([44384, 43074, 29262, 69229, 13431, 16672], dtype='int64'), 81: Int64Index([12208, 12771, 47723], dtype='int64'), 82: Int64Index([78321, 65249, 4812, 82042], dtype='int64'), 83: Int64Index([47471, 75562], dtype='int64'), 84: Int64Index([20899, 57310, 54689], dtype='int64'), 85: Int64Index([42071, 51844, 23986, 75673, 45638, 62253, 65886, 31565, 34880,
             2279, 75848, 48600, 82558, 24905, 51871, 78815, 17657,  6219,
            71843,  2693],
           dtype='int64'), 86: Int64Index([27412, 48697], dtype='int64'), 87: Int64Index([21047, 63646, 48801, 82269, 78112, 55785, 42416, 42297, 49032,
            25749, 59568, 20872, 28328, 40139],
           dtype='int64'), 88: Int64Index([52317, 74619, 81129], dtype='int64'), 89: Int64Index([40188], dtype='int64'), 90: Int64Index([82312, 28250, 24353, 15876, 78726, 64795, 44738, 65554, 58216,
            37720],
           dtype='int64'), 91: Int64Index([50935, 5342, 29725, 56786, 75994], dtype='int64'), 92: Int64Index([ 6565, 28181, 52690,  5921, 73348, 66220, 73003, 69546, 78262,
            23764, 44917, 16542, 73325, 78594, 21451, 58278,  6208, 44545],
           dtype='int64'), 93: Int64Index([78258, 37216], dtype='int64'), 94: Int64Index([45343, 36162, 16812, 47243, 63560, 46238], dtype='int64'), 95: Int64Index([28466, 52461, 13977, 18092, 25242, 46915, 25482, 59853, 32727,
            82719, 32487, 73141, 30786, 80115, 66876, 42052, 27647, 18659,
            48958, 66418, 26601, 49003, 25113, 17992, 66852, 80224, 14161,
             7173],
           dtype='int64'), 96: Int64Index([2924, 38067, 38068, 72622, 73389, 23933, 65091], dtype='int64'), 97: Int64Index([47469], dtype='int64'), 98: Int64Index([71607, 81981, 65513, 68810], dtype='int64'), 99: Int64Index([8201, 41086, 50497, 478, 318, 54482], dtype='int64')}

Lo que queremos es saber la longitud (el número de instancias) de cada uno de los grupos. Si llamamos a la función 'len' con la salida de groupby como argumento ...

In [26]:
print(len(X_cluster.groupby('cluster')))
100

Usando 'transform' 'len' nos devuelve para cada instancia de 'X_cluster' la longitud (nº de instancias) del grupo en el que se ha incluido a esa instancia.

In [27]:
print(X_cluster.groupby('cluster').cluster.transform(len))
59950    52
32492    51
35390    21
58323    19
18734     5
67528    66
28466    28
69729    55
80558    30
51796     9
60359    51
64853     6
67531    55
22293    30
18962    51
7794      4
23582    41
52461    28
78321     4
13977    28
68922    19
35135    21
11216    51
43090    66
66369    21
25165    13
71607     4
51622    13
21488    52
60711    66
         ..
51218     4
42336    21
42850    29
11170    66
60438    66
66089     9
31045    46
60274    66
35004    41
37720    10
22677     6
28328    14
54482     6
49147    55
82204     9
72392    10
59146    29
27575     5
11574    55
78822    21
83443    22
54744     7
33680    21
40139    14
74644     4
76530    21
7173     28
55655    41
17818    46
25325     9
Name: cluster, Length: 1000, dtype: int64

Comprobamos cuántos clusters quedan después de haber eliminado aquellos que no llegaban al tamaño mínimo.

In [28]:
k_filtrado = len(set(X_filtrado['cluster']))
print("De los {:.0f} clusters hay {:.0f} con más de {:.0f} elementos. Del total de {:.0f} elementos, se seleccionan {:.0f}".format(k['Ward'],k_filtrado,min_size,len(X),len(X_filtrado)))
De los 100 clusters hay 69 con más de 3 elementos. Del total de 1000 elementos, se seleccionan 934

Eliminamos la columna con la asignación de cluster.

  • drop permite eliminar una fila o columna de un DataFrame
In [29]:
X_filtrado = X_filtrado.drop('cluster', axis=1) # axis=1 para indicar que lo que queremos eliminar es una columna y no una fila
print(list(X_filtrado))
['EDAD', 'ANORES', 'NPFAM', 'H6584', 'ESREAL']

Normalizamos el conjunto filtrado. Volvemos a normalizar porque al eliminar outliers puede que los valores mínimos/máximos de algunas variables hayan cambiado.

In [30]:
X_filtrado_normal = X_filtrado.apply(norm_to_zero_one)

Obtenemos el dendograma usando scipy (que realmente va a volver a ejecutar el clustering jerárquico).

In [31]:
linkage_array = hierarchy.ward(X_filtrado_normal)
plt.figure(1)
plt.clf()
h_dict = hierarchy.dendrogram(linkage_array,orientation='left') #lo ponemos en horizontal para compararlo con el generado por seaborn

'h_dict' es un diccionario con información para representar el dendograma.

Generamos el dendograma usando seaborn (que a su vez usa scipy) para incluir un heatmap.

In [32]:
#Ahora lo saco usando seaborn (que a su vez usa scipy) para incluir un heatmap
sns.clustermap(X_filtrado_normal, method='ward', col_cluster=False, figsize=(20,10), cmap="YlGnBu", yticklabels=False)
Out[32]:
<seaborn.matrix.ClusterGrid at 0x1a2de692a20>