Mean shift clustering aims to discover “blobs” in a smooth density of samples. It is a centroid-based algorithm that works by repeatedly updating each candidate centroid to be the mean of the points within a given region around it; the size of that region is set by the bandwidth. These candidates are then filtered in a post-processing stage that eliminates near-duplicates, leaving the final set of centroids.
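To make the update rule concrete, here is a minimal NumPy sketch of the idea, not scikit-learn's actual implementation. It assumes a flat kernel (every sample inside the window counts equally) and a simple merge rule for near-duplicates; the function name `mean_shift_flat`, its parameters, and the synthetic two-blob data are all made up for illustration.

```python
import numpy as np

def mean_shift_flat(X, bandwidth, n_iter=50, tol=1e-4):
    """Illustrative mean shift with a flat kernel (hypothetical helper)."""
    candidates = X.astype(float).copy()  # every sample starts as a candidate
    for _ in range(n_iter):
        # Distances from each candidate to each original sample.
        dists = np.linalg.norm(candidates[:, None, :] - X[None, :, :], axis=2)
        within = dists <= bandwidth  # samples inside each candidate's window
        counts = within.sum(axis=1, keepdims=True)
        sums = within.astype(float) @ X
        # Move each candidate to the mean of the samples in its window;
        # a candidate with an empty window stays where it is.
        shifted = np.where(counts > 0, sums / np.maximum(counts, 1), candidates)
        moved = np.linalg.norm(shifted - candidates, axis=1).max()
        candidates = shifted
        if moved < tol:
            break
    # Post-processing: keep a candidate only if it is not a near-duplicate
    # (closer than one bandwidth) of a centre that was already kept.
    centers = []
    for c in candidates:
        if all(np.linalg.norm(c - k) > bandwidth for k in centers):
            centers.append(c)
    return np.array(centers)

# Synthetic data: two well-separated 2D blobs.
rng = np.random.default_rng(0)
blob_a = rng.normal(loc=(0.0, 0.0), scale=0.5, size=(50, 2))
blob_b = rng.normal(loc=(10.0, 10.0), scale=0.5, size=(50, 2))
centers = mean_shift_flat(np.vstack([blob_a, blob_b]), bandwidth=3.0)
print(centers)
```

Note that the number of clusters is never passed in: it falls out of the bandwidth and the shape of the data, which is the main practical difference from KMeans.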
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
df = pd.read_csv('my_machine-learning/datasets/customers.csv')
df.sample(2)
| | CustomerID | Genre | Age | Annual Income (k$) | Spending Score (1-100) |
|---|---|---|---|---|---|
| 127 | 128 | Male | 40 | 71 | 95 |
| 9 | 10 | Female | 30 | 19 | 72 |
# X = df.iloc[:, [2, 3]].values  # 2D variant: Age and Annual Income (k$) only
X = df.iloc[:, [2, 3, 4]].values  # Age, Annual Income (k$), Spending Score (1-100)
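from sklearn.cluster import MeanShift
from sklearn.cluster import KMeans
# MeanShift infers the number of clusters from the bandwidth
ms = MeanShift(bandwidth=22)
y_ms = ms.fit_predict(X)
# KMeans needs the number of clusters up front
kmean = KMeans(n_clusters=5)
y_kmeans = kmean.fit_predict(X)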
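The `bandwidth=22` above is hard-coded. As a hedged alternative, scikit-learn provides `estimate_bandwidth`, which derives a value from nearest-neighbour distances in the data. The sketch below runs on synthetic 3D blobs rather than the customers file, and the `quantile=0.2` setting and the names `X_demo`, `bw`, and `ms_demo` are illustrative choices, not values taken from this post.

```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth

# Synthetic stand-in for the customer features: two 3D blobs.
rng = np.random.default_rng(0)
X_demo = np.vstack([
    rng.normal(loc=(20.0, 20.0, 20.0), scale=3.0, size=(60, 3)),
    rng.normal(loc=(80.0, 80.0, 80.0), scale=3.0, size=(60, 3)),
])

# A smaller quantile gives a smaller bandwidth and usually more clusters.
bw = estimate_bandwidth(X_demo, quantile=0.2, random_state=0)

ms_demo = MeanShift(bandwidth=bw).fit(X_demo)
n_found = len(ms_demo.cluster_centers_)
print(f"estimated bandwidth: {bw:.2f}, clusters found: {n_found}")
```

Trying a few `quantile` values and checking how many centres come back is a quick way to see how sensitive the result is to the bandwidth.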
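from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib
fig = plt.figure(figsize=(14, 20))
# Top panel: MeanShift (for the 2D variant, use fig.add_subplot(121) without a projection)
# Call scatter on the 3D axes itself: plt.scatter treats a third positional array as marker size, not z
ax = fig.add_subplot(211, projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y_ms, cmap='prism')
centers = ms.cluster_centers_
ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], c='black')
ax.set_title('Clusters using MeanShift')
ax.set_xlabel('Age')
ax.set_ylabel('Annual Income (k$)')
ax.set_zlabel('Spending Score (1-100)')
# Bottom panel: KMeans (for the 2D variant, use fig.add_subplot(122) without a projection)
ax = fig.add_subplot(212, projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y_kmeans, cmap='prism')
centers = kmean.cluster_centers_
ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], c='black')
ax.set_title('Clusters using KMeans')
ax.set_xlabel('Age')
ax.set_ylabel('Annual Income (k$)')
ax.set_zlabel('Spending Score (1-100)')
plt.show()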