Vamos a ver un ejemplo de cómo podemos aplicar clustering jerárquico y generar un dendograma.
Importamos los módulos que vamos a necesitar:
%matplotlib inline
import time
import pandas as pd
import numpy as np
from sklearn import cluster
from sklearn import preprocessing
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.cluster import hierarchy
Creamos la función norm_to_zero_one para aplicar normalización min-max.
def norm_to_zero_one(df):
return (df - df.min()) * 1.0 / (df.max() - df.min())
Trabajamos sobre el conjunto de datos facilitado para la segunda práctica de la asignatura que contiene información recogida por el INE.
censo = pd.read_csv('censo_granada.csv')
Los valores en blanco en realidad son otra categoría que vamos a nombrar con el valor 0.
censo = censo.replace(np.NaN,0)
censo.shape
(83499, 142)
censo.columns.values
array(['CPRO', 'CMUN', 'IDHUECO', 'NORDEN', 'FACTOR', 'MNAC', 'ANAC', 'EDAD', 'SEXO', 'NACI', 'CPAISN', 'CPRON', 'CMUNN', 'ANORES', 'ANOM', 'ANOC', 'ANOE', 'CLPAIS', 'CLPRO', 'CLMUNP', 'RES_ANTERIOR', 'CPAISUNANO', 'CPROUNANO', 'CMUNANO', 'RES_UNANO', 'CPAISDANO', 'CPRODANO', 'CMUNDANO', 'RES_DANO', 'SEG_VIV', 'SEG_PAIS', 'SEG_PROV', 'SEG_MUN', 'SEG_NOCHES', 'SEG_DISP', 'ECIVIL', 'ESCOLAR', 'ESREAL', 'TESTUD', 'TAREA1', 'TAREA2', 'TAREA3', 'TAREA4', 'HIJOS', 'NHIJOS', 'RELA', 'JORNADA', 'CNO', 'CNAE', 'SITU', 'CSE', 'ESCUR1', 'ESCUR2', 'ESCUR3', 'LTRABA', 'PAISTRABA', 'PROTRABA', 'MUNTRABA', 'CODTRABA', 'NVIAJE', 'MDESP1', 'MDESP2', 'TDESP', 'TENEN', 'CALE', 'ASEO', 'BADUCH', 'INTERNET', 'AGUACOR', 'SUT', 'NHAB', 'PLANTAS', 'PLANTAB', 'TIPOEDIF', 'ANOCONS', 'ESTADO', 'ASCENSOR', 'ACCESIB', 'GARAJE', 'PLAZAS', 'GAS', 'TELEF', 'ACAL', 'RESID', 'FAMILIA', 'PAD_NORDEN', 'MAD_NORDEN', 'CON_NORDEN', 'OPA_NORDEN', 'TIPOPER', 'NUCLEO', 'NMIEM', 'NFAM', 'NNUC', 'NGENER', 'ESTHOG', 'TIPOHOG', 'NOCU', 'NPARAIN', 'NPFAM', 'NPNUC', 'HM5', 'H0515', 'H1624', 'H2534', 'H3564', 'H6584', 'H85M', 'HESPHOG', 'MESPHOG', 'HEXTHOG', 'MEXTHOG', 'COMBINAC', 'EDADPAD', 'PAISNACPAD', 'NACIPAD', 'ECIVPAD', 'ESTUPAD', 'SITUPAD', 'SITPPAD', 'EDADMAD', 'PAISNACMAD', 'NACIMAD', 'ECIVMAD', 'ESTUMAD', 'SITUMAD', 'SITPMAD', 'EDADCON', 'NACCON', 'NACICON', 'ECIVCON', 'ESTUCON', 'SITUCON', 'SITPCON', 'TIPONUC', 'TAMNUC', 'NHIJO', 'NHIJOC', 'FAMNUM', 'TIPOPARECIV', 'TIPOPARSEX', 'DIFEDAD'], dtype=object)
Seleccionamos el caso de estudio, en este ejemplo concreto, aquellos casos para los que el campo 'EDADMAD' (grupo quinquenal de la edad de la madre) no estaba vacío.
subset = censo.loc[censo['EDADMAD']>0]
Seleccionamos variables de interés para clustering.
usadas = ['EDAD', 'ANORES', 'NPFAM', 'H6584', 'ESREAL']
X = subset[usadas]
Podemos comprobar las dimensiones (variables e instancias) del subconjunto seleccionado.
X.shape
(27207, 5)
Para sacar el dendrograma en el jerárquico, no podemos tener muchos elementos. Hacemos un muestreo aleatorio para quedarnos solo con 1000, aunque lo ideal es elegir un caso de estudio que ya dé un tamaño pequeño.
X = X.sample(1000, random_state=123456)
En clustering hay que normalizar para las métricas de distancia. Normalizamos el dataframe aplicando la función 'norm_to_zero_one'.
X_normal = X.apply(norm_to_zero_one)
Podemos comprobar las dimensiones del subconjunto seleccionado con el que vamos a trabajar.
X_normal.shape
(1000, 5)
list(X_normal)
['EDAD', 'ANORES', 'NPFAM', 'H6584', 'ESREAL']
El clustering jerárquico engloba una familia de algoritmos que construyen clusters anidadas fusionándolos o dividiéndolos sucesivamente. Esta jerarquía de clusters se representa como un árbol (o dendrograma). La raíz del árbol es el cluster único que recoge todas las muestras, siendo las hojas del árbol los cluster que incluyen un solo dato o muestra.
El objeto 'AgglomerativeClustering' realiza un clustering jerárquico usando un enfoque 'de abajo a arriba': cada dato/muestra comienza en su propio cluster, y los clusters van fusionándose sucesivamente. En cada paso fusiona los dos clusters más cercanos (la definición de cercanos va a depender de la métrica elegida). Las opciones son:
Vamos a utilizar 'AgglomerativeClustering' con 'Ward' como criterio de enlace y eligiendo quedarnos con 100 clusters (100 ramificaciones del dendograma).
ward = cluster.AgglomerativeClustering(n_clusters=100, linkage='ward') # n_clusters: nº de clusters a encontrar
name, algorithm = ('Ward', ward)
Aplicamos el algoritmo de clustering jerárquico sobre nuestros datos (registrando el tiempo de ejecución).
cluster_predict = {}
k = {}
print(name,end='')
t = time.time()
cluster_predict[name] = algorithm.fit_predict(X_normal)
tiempo = time.time() - t
k[name] = len(set(cluster_predict[name]))
print(": k: {:3.0f}, ".format(k[name]),end='')
print("{:6.2f} segundos".format(tiempo))
Ward: k: 100, 0.04 segundos
cluster_predict['Ward']
array([33, 61, 43, 12, 57, 37, 95, 7, 36, 30, 61, 27, 7, 36, 61, 20, 1, 95, 82, 95, 12, 40, 61, 37, 40, 35, 98, 35, 33, 37, 20, 16, 10, 7, 24, 2, 33, 7, 32, 95, 37, 2, 37, 1, 0, 85, 85, 37, 28, 37, 71, 37, 87, 60, 44, 17, 18, 19, 7, 73, 36, 41, 33, 48, 0, 36, 43, 14, 75, 28, 33, 16, 30, 65, 21, 1, 36, 8, 83, 23, 33, 30, 42, 98, 55, 80, 28, 3, 33, 94, 75, 36, 7, 7, 92, 70, 45, 39, 16, 75, 1, 40, 52, 21, 10, 17, 40, 23, 36, 12, 37, 12, 37, 28, 21, 2, 40, 40, 36, 13, 7, 33, 95, 7, 33, 59, 28, 64, 28, 40, 51, 56, 12, 33, 37, 28, 85, 61, 33, 30, 17, 28, 92, 33, 33, 43, 0, 37, 28, 52, 1, 95, 85, 18, 24, 37, 23, 1, 0, 70, 10, 95, 43, 20, 92, 61, 15, 7, 49, 26, 4, 61, 44, 24, 12, 61, 46, 43, 92, 28, 91, 6, 59, 64, 10, 62, 33, 95, 33, 30, 7, 10, 17, 10, 11, 16, 1, 96, 10, 7, 35, 61, 7, 61, 90, 21, 90, 95, 36, 36, 33, 1, 7, 42, 28, 11, 7, 22, 37, 37, 95, 37, 17, 96, 87, 92, 68, 35, 40, 96, 36, 28, 33, 1, 61, 42, 39, 65, 4, 37, 61, 37, 75, 34, 73, 28, 40, 61, 64, 22, 95, 55, 33, 37, 23, 9, 61, 36, 58, 1, 18, 36, 12, 99, 7, 9, 7, 28, 46, 81, 35, 52, 1, 30, 17, 87, 22, 24, 36, 30, 17, 24, 33, 1, 60, 39, 10, 22, 10, 75, 37, 35, 37, 17, 44, 95, 92, 33, 28, 41, 7, 37, 62, 22, 9, 36, 9, 18, 57, 64, 28, 61, 28, 1, 92, 8, 88, 1, 35, 75, 79, 95, 62, 28, 67, 90, 1, 18, 19, 9, 24, 33, 37, 28, 61, 18, 19, 37, 9, 8, 63, 86, 27, 40, 75, 63, 61, 17, 37, 37, 33, 70, 7, 96, 7, 95, 47, 1, 9, 51, 82, 41, 0, 61, 39, 61, 33, 0, 4, 70, 77, 21, 91, 24, 10, 53, 37, 27, 29, 27, 28, 37, 8, 28, 69, 35, 28, 87, 73, 37, 98, 2, 37, 7, 15, 28, 9, 61, 43, 10, 7, 28, 93, 33, 17, 57, 28, 12, 36, 35, 1, 33, 40, 51, 21, 24, 28, 52, 94, 40, 36, 50, 41, 94, 9, 36, 46, 36, 80, 30, 2, 28, 22, 89, 36, 15, 9, 43, 17, 19, 24, 87, 87, 37, 6, 17, 52, 33, 73, 5, 37, 75, 85, 12, 33, 99, 7, 1, 84, 43, 4, 24, 42, 35, 96, 87, 36, 18, 43, 71, 37, 45, 37, 95, 90, 40, 43, 90, 32, 10, 37, 58, 23, 65, 24, 33, 61, 46, 1, 18, 28, 33, 52, 9, 94, 28, 13, 99, 85, 31, 98, 15, 60, 28, 12, 32, 7, 10, 40, 94, 9, 33, 95, 1, 7, 0, 1, 28, 92, 20, 17, 37, 92, 32, 54, 36, 8, 47, 32, 5, 92, 31, 5, 27, 3, 18, 96, 7, 85, 1, 7, 87, 82, 61, 1, 17, 24, 10, 40, 85, 49, 87, 61, 95, 92, 18, 2, 28, 37, 35, 11, 65, 90, 79, 92, 58, 30, 7, 95, 58, 9, 0, 10, 37, 26, 1, 24, 33, 61, 43, 43, 61, 7, 37, 50, 31, 92, 18, 11, 28, 68, 75, 79, 1, 29, 9, 61, 97, 33, 17, 0, 9, 61, 18, 36, 38, 87, 8, 33, 11, 12, 41, 37, 7, 1, 12, 33, 4, 61, 33, 53, 48, 43, 28, 49, 10, 7, 75, 94, 7, 11, 16, 45, 37, 32, 9, 87, 61, 37, 95, 7, 28, 73, 92, 28, 10, 19, 9, 37, 61, 11, 0, 85, 18, 47, 85, 1, 7, 38, 61, 91, 11, 6, 51, 46, 18, 28, 85, 28, 19, 61, 33, 33, 19, 10, 18, 7, 57, 38, 38, 85, 71, 64, 95, 61, 79, 36, 69, 92, 43, 96, 37, 37, 52, 14, 58, 18, 4, 12, 40, 78, 37, 3, 79, 85, 17, 7, 28, 51, 18, 29, 7, 24, 95, 81, 40, 33, 91, 10, 80, 24, 33, 19, 73, 7, 1, 18, 37, 36, 25, 84, 33, 7, 37, 17, 11, 85, 88, 24, 43, 31, 24, 64, 92, 33, 40, 61, 1, 85, 41, 74, 88, 37, 2, 28, 36, 23, 33, 69, 37, 17, 56, 1, 1, 22, 43, 1, 85, 95, 10, 80, 12, 7, 2, 18, 18, 90, 35, 32, 79, 36, 39, 66, 29, 36, 9, 17, 37, 40, 10, 78, 65, 11, 52, 76, 33, 85, 60, 76, 1, 52, 8, 95, 7, 37, 37, 34, 22, 34, 36, 38, 28, 10, 37, 46, 61, 37, 65, 99, 7, 1, 95, 25, 61, 37, 64, 61, 12, 41, 58, 61, 26, 33, 33, 5, 0, 24, 24, 12, 37, 33, 9, 37, 61, 74, 18, 45, 83, 73, 61, 61, 85, 35, 44, 61, 91, 22, 22, 43, 52, 7, 5, 61, 38, 9, 73, 28, 47, 27, 7, 7, 95, 10, 21, 7, 79, 95, 57, 2, 61, 37, 75, 33, 79, 33, 33, 8, 65, 99, 28, 58, 82, 44, 7, 28, 38, 61, 22, 33, 80, 36, 15, 7, 15, 12, 7, 93, 74, 61, 80, 24, 37, 65, 12, 61, 92, 95, 37, 73, 18, 8, 1, 37, 13, 48, 1, 24, 7, 37, 10, 84, 10, 47, 86, 7, 64, 1, 66, 61, 17, 72, 43, 11, 10, 85, 43, 22, 61, 87, 3, 85, 48, 90, 22, 61, 33, 81, 33, 7, 92, 61, 10, 1, 12, 64, 7, 79, 75, 90, 1, 13, 40, 10, 37, 37, 65, 28, 37, 1, 90, 16, 87, 99, 7, 58, 79, 10, 45, 7, 43, 24, 15, 40, 87, 26, 43, 95, 1, 28, 58], dtype=int64)
type(cluster_predict)
dict
type(cluster_predict['Ward'])
numpy.ndarray
cluster_predict['Ward'].shape
(1000,)
Convertimos la asignación de clusters a un DataFrame con una única columna 'cluster'.
clusters = pd.DataFrame(cluster_predict['Ward'],index=X.index,columns=['cluster'])
Añadimos la asignación de clusters a las variables de entrada que habíamos seleccionado para el clustering.
X_cluster = pd.concat([X, clusters], axis=1)
print(list(X_cluster))
print(X_cluster.shape)
['EDAD', 'ANORES', 'NPFAM', 'H6584', 'ESREAL', 'cluster'] (1000, 6)
Filtramos outliers quitando aquellos elementos que el algoritmo ha agrupado en clusters muy pequeños:
min_size = 3
# Agrupamos las instancias de 'X_cluster' por el cluster asignado
X_filtrado = X_cluster[X_cluster.groupby('cluster').cluster.transform(len) > min_size]
Para entender un poco mejor lo que está pasando en esa última línea de código, podemos ver los grupos de instancias generados por 'groupby'.
print(X_cluster.groupby('cluster').groups)
{0: Int64Index([14371, 51412, 78530, 30988, 6415, 28072, 68722, 81878, 53883, 75314, 16863], dtype='int64'), 1: Int64Index([23582, 38202, 24612, 69235, 73230, 51732, 17212, 13518, 6087, 32339, 79830, 51447, 79590, 52288, 6574, 76175, 6532, 79431, 75775, 44872, 17180, 58702, 17628, 31644, 77018, 72129, 24292, 51720, 59230, 71775, 17021, 17324, 44843, 65705, 17168, 59716, 9896, 31959, 52503, 35004, 55655], dtype='int64'), 2: Int64Index([68937, 5648, 9145, 27598, 2459, 65033, 82093, 48096, 9132], dtype='int64'), 3: Int64Index([54269, 12430, 81675, 74907], dtype='int64'), 4: Int64Index([27192, 33874, 333, 30172, 15180, 7905], dtype='int64'), 5: Int64Index([962, 36467, 20400, 12544, 4952], dtype='int64'), 6: Int64Index([9601, 75041, 79817], dtype='int64'), 7: Int64Index([69729, 67531, 3987, 28672, 39396, 67687, 56101, 80652, 49671, 49026, 3613, 4410, 32674, 76565, 67236, 24890, 67271, 49603, 35510, 21687, 56223, 49667, 35540, 39385, 11464, 79947, 56287, 33547, 67073, 56205, 80944, 55869, 59516, 53776, 56271, 56130, 60582, 67778, 35583, 3523, 19395, 77105, 47179, 70026, 42161, 40126, 4260, 76644, 26114, 59111, 53374, 46501, 28739, 49147, 11574], dtype='int64'), 8: Int64Index([79725, 68441, 14792, 74982, 66355, 61370, 62399, 6209, 63716], dtype='int64'), 9: Int64Index([58332, 16868, 11202, 9556, 2107, 73379, 55359, 3196, 17808, 62942, 17235, 3497, 16586, 9900, 51828, 80154, 59449, 40775, 28068, 78971], dtype='int64'), 10: Int64Index([76943, 76304, 45527, 31596, 11029, 80026, 76305, 46209, 79239, 18589, 46281, 38069, 63622, 28853, 56241, 39221, 79495, 11001, 38948, 82332, 73744, 65992, 18105, 76168, 3107, 11119, 67164, 42850, 59146], dtype='int64'), 11: Int64Index([6119, 32636, 24112, 31535, 44557, 11897, 78730, 21040, 69161, 35082, 42283], dtype='int64'), 12: Int64Index([58323, 68922, 71463, 24204, 34962, 30423, 44048, 9354, 44282, 68926, 68757, 58008, 56384, 38395, 71982, 16787, 60839, 47323, 8776], dtype='int64'), 13: Int64Index([81225, 51217, 65542, 51218], dtype='int64'), 14: Int64Index([4873, 1422], dtype='int64'), 15: Int64Index([61859, 25424, 42515, 72976, 79809, 53034, 54744], dtype='int64'), 16: Int64Index([63486, 22948, 26751, 23133, 82781, 22677], dtype='int64'), 17: Int64Index([72374, 6118, 64857, 70611, 17328, 62520, 55679, 5802, 48581, 78849, 22107, 6817, 63614, 41819, 28188, 23478, 31424, 55105, 44285, 82063], dtype='int64'), 18: Int64Index([37359, 31375, 25192, 20726, 72143, 16506, 74551, 72770, 2889, 1030, 10143, 17752, 75598, 48137, 58923, 16289, 6113, 62753, 34527, 31601, 2550, 66084], dtype='int64'), 19: Int64Index([32399, 55453, 23977, 78214, 69019, 65186, 44590, 12053], dtype='int64'), 20: Int64Index([7794, 68599, 68643, 8336], dtype='int64'), 21: Int64Index([12774, 21881, 19753, 83179, 54378, 33951, 7103], dtype='int64'), 22: Int64Index([35522, 63137, 42624, 14398, 28077, 55441, 14107, 14035, 76887, 55515, 21400, 76847, 75997], dtype='int64'), 23: Int64Index([20137, 8318, 15183, 9737, 493, 32216], dtype='int64'), 24: Int64Index([ 9358, 3274, 48337, 51704, 38942, 48908, 16667, 20967, 46106, 62724, 10929, 25396, 79407, 69280, 16634, 45182, 34996, 2380, 52460, 23903, 9453, 83443], dtype='int64'), 25: Int64Index([51279, 12122], dtype='int64'), 26: Int64Index([12095, 54146, 71426, 74644], dtype='int64'), 27: Int64Index([64853, 24239, 41731, 80736, 20804, 27684], dtype='int64'), 28: Int64Index([66969, 3175, 38505, 55553, 19029, 38822, 56254, 39028, 80040, 80232, 45269, 17313, 30557, 32479, 32743, 20689, 66737, 25378, 73015, 82734, 24977, 25415, 76136, 33448, 25842, 27806, 72074, 10328, 6234, 29116, 42575, 24624, 33443, 24762, 3368, 21149, 45540, 72819, 46930, 45205, 39027, 82156, 6839, 17544, 31045, 17818], dtype='int64'), 29: Int64Index([57281, 56069, 9604, 14961], dtype='int64'), 30: Int64Index([51796, 58488, 64943, 45651, 32939, 46343, 58839, 24730, 22175], dtype='int64'), 31: Int64Index([53849, 17734, 61263, 68451], dtype='int64'), 32: Int64Index([54529, 64377, 47604, 75183, 61739, 50796, 81787], dtype='int64'), 33: Int64Index([59950, 21488, 82738, 28442, 79911, 6461, 59433, 59771, 10757, 6330, 24679, 39586, 10057, 24418, 52016, 59612, 24097, 42354, 10376, 25363, 65489, 79085, 82225, 45842, 79063, 66902, 13417, 69633, 3835, 66781, 52631, 10729, 17560, 65493, 49489, 49091, 10488, 18355, 10717, 45669, 38765, 76124, 28209, 24357, 2984, 13610, 60057, 28704, 45672, 53194, 2909, 45014], dtype='int64'), 34: Int64Index([289, 17651, 34894], dtype='int64'), 35: Int64Index([25165, 51622, 58047, 55165, 9398, 79088, 31525, 65389, 44664, 27966, 54010, 62043, 24261], dtype='int64'), 36: Int64Index([80558, 22293, 27363, 18194, 17300, 25542, 1750, 11601, 32832, 39792, 39825, 70365, 65710, 54048, 16077, 18343, 39650, 62943, 39996, 53835, 4851, 49736, 40298, 10668, 70555, 18984, 2860, 66101, 32258, 42737], dtype='int64'), 37: Int64Index([67528, 43090, 60711, 67159, 25957, 18764, 33058, 4078, 17932, 67294, 4735, 10298, 21699, 83108, 10684, 6982, 42885, 60396, 18147, 70499, 46257, 62998, 66699, 82987, 25655, 67834, 67033, 60221, 46946, 83096, 47107, 83061, 46849, 26084, 59686, 35739, 80080, 46878, 74023, 21439, 35545, 39336, 17978, 19158, 67032, 14495, 80332, 53364, 53576, 53392, 18608, 70095, 55932, 70200, 53356, 70147, 46256, 56195, 26062, 63370, 19537, 18391, 33033, 11170, 60438, 60274], dtype='int64'), 38: Int64Index([24464, 10047, 85, 22189, 39117, 2012, 62412], dtype='int64'), 39: Int64Index([47834, 19975, 47446, 5339, 34263], dtype='int64'), 40: Int64Index([35135, 66369, 5705, 26525, 58868, 64596, 51617, 32065, 12084, 69090, 16085, 65314, 64612, 58434, 44385, 5753, 51833, 72904, 77158, 42336, 33680], dtype='int64'), 41: Int64Index([56609, 38141, 73536, 62176, 62531, 66123, 11143], dtype='int64'), 42: Int64Index([33922, 5404, 52561, 12225], dtype='int64'), 43: Int64Index([35390, 52334, 76568, 69340, 4730, 35459, 56404, 41821, 22299, 28537, 63221, 39375, 53046, 59659, 46697, 35150, 7656, 42381, 13623, 78822, 76530], dtype='int64'), 44: Int64Index([15921, 57596, 81930, 81093, 71627], dtype='int64'), 45: Int64Index([6210, 50797, 73234, 63345, 27575], dtype='int64'), 46: Int64Index([49389, 59400, 35604, 53259, 3880, 53701], dtype='int64'), 47: Int64Index([22251, 74534, 42542, 55921, 80748], dtype='int64'), 48: Int64Index([30735, 42759, 13041, 72383], dtype='int64'), 49: Int64Index([61605, 55985, 82967], dtype='int64'), 50: Int64Index([75962, 3074], dtype='int64'), 51: Int64Index([42767, 59009, 12276, 77760, 4996], dtype='int64'), 52: Int64Index([37671, 73344, 82239, 82036, 78873, 78696, 17259, 1835, 41904, 1181], dtype='int64'), 53: Int64Index([81155, 75709], dtype='int64'), 54: Int64Index([56173], dtype='int64'), 55: Int64Index([40671, 19423], dtype='int64'), 56: Int64Index([7812, 33921], dtype='int64'), 57: Int64Index([18734, 44978, 76555, 74487, 75910], dtype='int64'), 58: Int64Index([53881, 77048, 69526, 6829, 34949, 21565, 28169, 82204, 25325], dtype='int64'), 59: Int64Index([14841, 27773], dtype='int64'), 60: Int64Index([27321, 56955, 57405, 41037], dtype='int64'), 61: Int64Index([32492, 60359, 18962, 11216, 40017, 39914, 46384, 73693, 69852, 11227, 14162, 39817, 46455, 7207, 13827, 80563, 69695, 79839, 32689, 18869, 2882, 35785, 73372, 39726, 17795, 39708, 53333, 53319, 80414, 80555, 33151, 25799, 45920, 53113, 62870, 32006, 62737, 3187, 35213, 73856, 4294, 46673, 46250, 49283, 25088, 33156, 70421, 53177, 83164, 61055, 49064], dtype='int64'), 62: Int64Index([81313, 81172, 1498], dtype='int64'), 63: Int64Index([76815, 49362], dtype='int64'), 64: Int64Index([58157, 44153, 30317, 49195, 48399, 34549, 8972, 2608, 23461], dtype='int64'), 65: Int64Index([40217, 59739, 74210, 48189, 49870, 46810, 32923, 47252, 66089], dtype='int64'), 66: Int64Index([29459, 64367], dtype='int64'), 67: Int64Index([47534], dtype='int64'), 68: Int64Index([6191, 15630], dtype='int64'), 69: Int64Index([54439, 698, 26918], dtype='int64'), 70: Int64Index([51659, 48317, 79428, 82654], dtype='int64'), 71: Int64Index([81345, 12191, 74791], dtype='int64'), 72: Int64Index([43751], dtype='int64'), 73: Int64Index([69220, 9053, 79055, 28213, 78891, 51261, 24391, 1196, 75485], dtype='int64'), 74: Int64Index([47595, 5213, 19645], dtype='int64'), 75: Int64Index([32222, 66641, 17609, 69310, 76208, 24703, 54021, 31516, 46320, 69118, 5692, 28646], dtype='int64'), 76: Int64Index([12579, 12113], dtype='int64'), 77: Int64Index([40676], dtype='int64'), 78: Int64Index([12338, 27144], dtype='int64'), 79: Int64Index([37934, 23557, 28101, 37981, 2939, 2271, 2018, 48736, 2536, 72392], dtype='int64'), 80: Int64Index([44384, 43074, 29262, 69229, 13431, 16672], dtype='int64'), 81: Int64Index([12208, 12771, 47723], dtype='int64'), 82: Int64Index([78321, 65249, 4812, 82042], dtype='int64'), 83: Int64Index([47471, 75562], dtype='int64'), 84: Int64Index([20899, 57310, 54689], dtype='int64'), 85: Int64Index([42071, 51844, 23986, 75673, 45638, 62253, 65886, 31565, 34880, 2279, 75848, 48600, 82558, 24905, 51871, 78815, 17657, 6219, 71843, 2693], dtype='int64'), 86: Int64Index([27412, 48697], dtype='int64'), 87: Int64Index([21047, 63646, 48801, 82269, 78112, 55785, 42416, 42297, 49032, 25749, 59568, 20872, 28328, 40139], dtype='int64'), 88: Int64Index([52317, 74619, 81129], dtype='int64'), 89: Int64Index([40188], dtype='int64'), 90: Int64Index([82312, 28250, 24353, 15876, 78726, 64795, 44738, 65554, 58216, 37720], dtype='int64'), 91: Int64Index([50935, 5342, 29725, 56786, 75994], dtype='int64'), 92: Int64Index([ 6565, 28181, 52690, 5921, 73348, 66220, 73003, 69546, 78262, 23764, 44917, 16542, 73325, 78594, 21451, 58278, 6208, 44545], dtype='int64'), 93: Int64Index([78258, 37216], dtype='int64'), 94: Int64Index([45343, 36162, 16812, 47243, 63560, 46238], dtype='int64'), 95: Int64Index([28466, 52461, 13977, 18092, 25242, 46915, 25482, 59853, 32727, 82719, 32487, 73141, 30786, 80115, 66876, 42052, 27647, 18659, 48958, 66418, 26601, 49003, 25113, 17992, 66852, 80224, 14161, 7173], dtype='int64'), 96: Int64Index([2924, 38067, 38068, 72622, 73389, 23933, 65091], dtype='int64'), 97: Int64Index([47469], dtype='int64'), 98: Int64Index([71607, 81981, 65513, 68810], dtype='int64'), 99: Int64Index([8201, 41086, 50497, 478, 318, 54482], dtype='int64')}
Lo que queremos es saber la longitud (el número de instancias) de cada uno de los grupos. Si llamamos a la función 'len' con la salida de groupby como argumento ...
print(len(X_cluster.groupby('cluster')))
100
Usando 'transform' 'len' nos devuelve para cada instancia de 'X_cluster' la longitud (nº de instancias) del grupo en el que se ha incluido a esa instancia.
print(X_cluster.groupby('cluster').cluster.transform(len))
59950 52 32492 51 35390 21 58323 19 18734 5 67528 66 28466 28 69729 55 80558 30 51796 9 60359 51 64853 6 67531 55 22293 30 18962 51 7794 4 23582 41 52461 28 78321 4 13977 28 68922 19 35135 21 11216 51 43090 66 66369 21 25165 13 71607 4 51622 13 21488 52 60711 66 .. 51218 4 42336 21 42850 29 11170 66 60438 66 66089 9 31045 46 60274 66 35004 41 37720 10 22677 6 28328 14 54482 6 49147 55 82204 9 72392 10 59146 29 27575 5 11574 55 78822 21 83443 22 54744 7 33680 21 40139 14 74644 4 76530 21 7173 28 55655 41 17818 46 25325 9 Name: cluster, Length: 1000, dtype: int64
Comprobamos cuántos clusters quedan después de haber eliminado aquellos que no llegaban al tamaño mínimo.
k_filtrado = len(set(X_filtrado['cluster']))
print("De los {:.0f} clusters hay {:.0f} con más de {:.0f} elementos. Del total de {:.0f} elementos, se seleccionan {:.0f}".format(k['Ward'],k_filtrado,min_size,len(X),len(X_filtrado)))
De los 100 clusters hay 69 con más de 3 elementos. Del total de 1000 elementos, se seleccionan 934
Eliminamos la columna con la asignación de cluster.
X_filtrado = X_filtrado.drop('cluster', axis=1) # axis=1 para indicar que lo que queremos eliminar es una columna y no una fila
print(list(X_filtrado))
['EDAD', 'ANORES', 'NPFAM', 'H6584', 'ESREAL']
Normalizamos el conjunto filtrado. Volvemos a normalizar porque al eliminar outliers puede que los valores mínimos/máximos de algunas variables hayan cambiado.
X_filtrado_normal = X_filtrado.apply(norm_to_zero_one)
Obtenemos el dendograma usando scipy (que realmente va a volver a ejecutar el clustering jerárquico).
linkage_array = hierarchy.ward(X_filtrado_normal)
plt.figure(1)
plt.clf()
h_dict = hierarchy.dendrogram(linkage_array,orientation='left') #lo ponemos en horizontal para compararlo con el generado por seaborn
'h_dict' es un diccionario con información para representar el dendograma.
Generamos el dendograma usando seaborn (que a su vez usa scipy) para incluir un heatmap.
#Ahora lo saco usando seaborn (que a su vez usa scipy) para incluir un heatmap
sns.clustermap(X_filtrado_normal, method='ward', col_cluster=False, figsize=(20,10), cmap="YlGnBu", yticklabels=False)
<seaborn.matrix.ClusterGrid at 0x1a2de692a20>