class AgglomerativeClustering:
    """Flat clustering obtained by cutting a SciPy hierarchical merge tree.

    The merge tree is built with ``linkage`` and then split from the root
    downwards until ``n_clusters`` subtrees remain; every sample is
    labelled with the index of the subtree that contains it.

    Parameters
    ----------
    n_clusters : int, default=2
        Number of clusters to extract from the tree.
    linkage : str, default="ward"
        Forwarded as the ``method`` argument of the ``linkage`` function.

    Attributes
    ----------
    children_ : ndarray of int, shape (n_samples - 1, 2)
        The two node ids merged at each step. Ids below ``n_samples`` are
        original samples; id ``n_samples + i`` is the node created by
        row ``i``.
    labels_ : ndarray of int, shape (n_samples,)
        Cluster index assigned to each sample by ``fit``.
    """

    def __init__(self, n_clusters=2, linkage="ward"):
        self.n_clusters = n_clusters
        self.linkage = linkage

    def _get_descendent(self, node, n_samples):
        """Return the ids of all original samples located under ``node``.

        Uses an explicit stack instead of recursion so that deep trees do
        not hit the interpreter's recursion limit.
        """
        pending = [node]
        leaves = []
        while pending:
            current = pending.pop()
            if current < n_samples:
                # Ids below n_samples are leaves (original samples).
                leaves.append(current)
            else:
                # Internal node: descend into the two nodes it merged.
                pending.extend(self.children_[current - n_samples])
        return leaves

    def fit(self, X):
        """Build the merge tree for ``X`` and compute ``labels_``.

        Parameters
        ----------
        X : array-like
            Data handed to ``linkage`` after conversion with ``np.asarray``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``n_clusters`` is larger than the number of samples.
        """
        X = np.asarray(X)
        Z = linkage(X, method=self.linkage)
        # ``np.int`` was removed from NumPy; the builtin ``int`` is the
        # drop-in replacement.
        children = Z[:, :2].astype(int)
        # Every row of Z merges two nodes into one, so n - 1 rows
        # correspond to n samples.
        n_samples = Z.shape[0] + 1
        if self.n_clusters > n_samples:
            # Splitting further would index ``children`` with a leaf id,
            # which silently wraps around and yields meaningless labels.
            raise ValueError(
                "n_clusters=%r exceeds the number of samples (%d)"
                % (self.n_clusters, n_samples)
            )
        self.children_ = children

        # Min-heap of negated node ids: nodes[0] is always the most
        # recently created remaining node, i.e. the next one to split.
        nodes = []
        heappush(nodes, -(2 * n_samples - 2))  # root node
        for _ in range(self.n_clusters - 1):
            left, right = children[-nodes[0] - n_samples]
            heappush(nodes, -left)
            # Add the second child and drop the node that was just split.
            heappushpop(nodes, -right)

        # Labels are indices, so store them as integers rather than floats.
        labels = np.zeros(n_samples, dtype=int)
        for cluster_id, node in enumerate(nodes):
            labels[self._get_descendent(-node, n_samples)] = cluster_id
        self.labels_ = labels
        return self

    def fit_predict(self, X):
        """Fit the model on ``X`` and return the resulting ``labels_``."""
        self.fit(X)
        return self.labels_