import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram
from scipy.spatial import distance
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import load_iris

iris = load_iris()
iris_pd = pd.DataFrame(data=iris['data'], columns=iris['feature_names'])
data = iris_pd[
    ['petal length (cm)', 'petal width (cm)', 'sepal length (cm)']
].to_numpy()

# Compute the pairwise Chebyshev distance matrix for all samples in one
# vectorized call; the result is square, symmetric, with a zero diagonal.
distance_matrix = distance.cdist(data, data, metric="chebyshev")


def plot_dendrogram(model, **kwargs):
    """Plot the dendrogram of a fitted hierarchical clustering model.

    Builds a SciPy-style linkage matrix from the model (the two merged
    children, the merge distance, and the number of original samples under
    the new node) and forwards it to ``scipy.cluster.hierarchy.dendrogram``.

    Args:
        model: Fitted estimator exposing ``children_``, ``distances_`` and
            ``labels_``; these are the only attributes read here.
        **kwargs: Passed through unchanged to ``dendrogram``.
    """
    n_samples = len(model.labels_)

    # counts[i] is the number of original samples under the node created by
    # the i-th merge.
    counts = np.zeros(model.children_.shape[0])
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                # Internal node: indices >= n_samples refer to earlier
                # merges, whose sizes are already stored in counts.
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    linkage_matrix = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)

    # Plot the corresponding dendrogram
    dendrogram(linkage_matrix, **kwargs)


# NOTE: distance_threshold is not passed to the estimator below;
# compute_distances=True is what makes distances_ available to
# plot_dendrogram.
# Fit single-linkage agglomerative clustering on the precomputed distance
# matrix, then show the dendrogram and a 3D scatter of the cluster labels.
metric = 'precomputed'
linkage = "single"

# compute_distances=True is required so the fitted model exposes distances_,
# which plot_dendrogram reads when building the linkage matrix.
model = AgglomerativeClustering(
    compute_distances=True, metric=metric, linkage=linkage
)
model = model.fit(distance_matrix)
print(model.labels_)

# The closing quote after the linkage value was missing in the original title.
plt.title(
    'Hierarchical Clustering Dendrogram \n metric="{}", linkage="{}"'.format(
        metric, linkage
    )
)
# plot the top three levels of the dendrogram
plot_dendrogram(model, truncate_mode="level", p=3)
plt.xlabel("Number of points in node")

# Second figure: the three features used for clustering, colored by label.
x_name = 'petal length (cm)'
y_name = 'petal width (cm)'
z_name = 'sepal length (cm)'

fig1 = plt.figure()
ax = plt.axes(projection='3d')
ax.scatter3D(
    iris_pd[x_name],
    iris_pd[y_name],
    iris_pd[z_name],
    c=model.labels_,
    cmap='tab10',
)
ax.set_title(
    'Agglomerative Clustering \n metric="{}", linkage="{}"'.format(
        metric, linkage
    )
)
ax.set_xlabel(x_name)
ax.set_ylabel(y_name)
ax.set_zlabel(z_name)
plt.show()