форкнуто от main/is_dnn
				
			
							Родитель
							
								
									3f260f388c
								
							
						
					
					
						Сommit
						6b9946cbe5
					
				
											
												Двоичный файл не отображается.
											
										
									
								
											
												Двоичный файл не отображается.
											
										
									
								
											
												Двоичный файл не отображается.
											
										
									
								@ -1,119 +0,0 @@
 | 
				
			|||||||
import numpy as np
 | 
					 | 
				
			||||||
import matplotlib.pyplot as plt
 | 
					 | 
				
			||||||
import seaborn as sns
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def make_confusion_matrix(cf,
 | 
					 | 
				
			||||||
                          group_names=None,
 | 
					 | 
				
			||||||
                          categories='auto',
 | 
					 | 
				
			||||||
                          count=True,
 | 
					 | 
				
			||||||
                          percent=True,
 | 
					 | 
				
			||||||
                          cbar=True,
 | 
					 | 
				
			||||||
                          xyticks=True,
 | 
					 | 
				
			||||||
                          xyplotlabels=True,
 | 
					 | 
				
			||||||
                          sum_stats=True,
 | 
					 | 
				
			||||||
                          figsize=None,
 | 
					 | 
				
			||||||
                          cmap='Blues',
 | 
					 | 
				
			||||||
                          title=None,
 | 
					 | 
				
			||||||
                          f_name=None,):
 | 
					 | 
				
			||||||
    '''
 | 
					 | 
				
			||||||
    This function will make a pretty plot of an sklearn Confusion Matrix cm using a Seaborn heatmap visualization.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Arguments
 | 
					 | 
				
			||||||
    ---------
 | 
					 | 
				
			||||||
    cf:            confusion matrix to be passed in
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    group_names:   List of strings that represent the labels row by row to be shown in each square.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    categories:    List of strings containing the categories to be displayed on the x,y axis. Default is 'auto'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    count:         If True, show the raw number in the confusion matrix. Default is True.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    percent:       If True, show the proportions for each category. Default is True.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cbar:          If True, show the color bar. The cbar values are based off the values in the confusion matrix.
 | 
					 | 
				
			||||||
                   Default is True.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    xyticks:       If True, show x and y ticks. Default is True.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    xyplotlabels:  If True, show 'True Label' and 'Predicted Label' on the figure. Default is True.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    sum_stats:     If True, display summary statistics below the figure. Default is True.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    figsize:       Tuple representing the figure size. Default will be the matplotlib rcParams value.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cmap:          Colormap of the values displayed from matplotlib.pyplot.cm. Default is 'Blues'
 | 
					 | 
				
			||||||
                   See http://matplotlib.org/examples/color/colormaps_reference.html
 | 
					 | 
				
			||||||
                   
 | 
					 | 
				
			||||||
    title:         Title for the heatmap. Default is None.
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    f_name:        Filename for saving picture. Default is None, which means no saving.
 | 
					 | 
				
			||||||
 
 | 
					 | 
				
			||||||
    '''
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # CODE TO GENERATE TEXT INSIDE EACH SQUARE
 | 
					 | 
				
			||||||
    blanks = ['' for i in range(cf.size)]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if group_names and len(group_names)==cf.size:
 | 
					 | 
				
			||||||
        group_labels = ["{}\n".format(value) for value in group_names]
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        group_labels = blanks
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if count:
 | 
					 | 
				
			||||||
        group_counts = ["{0:0.0f}\n".format(value) for value in cf.flatten()]
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        group_counts = blanks
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if percent:
 | 
					 | 
				
			||||||
        group_percentages = ["{0:.2%}".format(value) for value in cf.flatten()/np.sum(cf)]
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        group_percentages = blanks
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    box_labels = [f"{v1}{v2}{v3}".strip() for v1, v2, v3 in zip(group_labels,group_counts,group_percentages)]
 | 
					 | 
				
			||||||
    box_labels = np.asarray(box_labels).reshape(cf.shape[0],cf.shape[1])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # CODE TO GENERATE SUMMARY STATISTICS & TEXT FOR SUMMARY STATS
 | 
					 | 
				
			||||||
    if sum_stats:
 | 
					 | 
				
			||||||
        #Accuracy is sum of diagonal divided by total observations
 | 
					 | 
				
			||||||
        accuracy  = np.trace(cf) / float(np.sum(cf))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        #if it is a binary confusion matrix, show some more stats
 | 
					 | 
				
			||||||
        if len(cf)==2:
 | 
					 | 
				
			||||||
            #Metrics for Binary Confusion Matrices
 | 
					 | 
				
			||||||
            precision = cf[1,1] / sum(cf[:,1])
 | 
					 | 
				
			||||||
            recall    = cf[1,1] / sum(cf[1,:])
 | 
					 | 
				
			||||||
            f1_score  = 2*precision*recall / (precision + recall)
 | 
					 | 
				
			||||||
            stats_text = "\n\nAccuracy={:0.3f}\nPrecision={:0.3f}\nRecall={:0.3f}\nF1 Score={:0.3f}".format(
 | 
					 | 
				
			||||||
                accuracy,precision,recall,f1_score)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            stats_text = "\n\nAccuracy={:0.3f}".format(accuracy)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        stats_text = ""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # SET FIGURE PARAMETERS ACCORDING TO OTHER ARGUMENTS
 | 
					 | 
				
			||||||
    if figsize==None:
 | 
					 | 
				
			||||||
        #Get default figure size if not set
 | 
					 | 
				
			||||||
        figsize = plt.rcParams.get('figure.figsize')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if xyticks==False:
 | 
					 | 
				
			||||||
        #Do not show categories if xyticks is False
 | 
					 | 
				
			||||||
        categories=False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # MAKE THE HEATMAP VISUALIZATION
 | 
					 | 
				
			||||||
    plt.figure(figsize=figsize)
 | 
					 | 
				
			||||||
    sns.heatmap(cf,annot=box_labels,fmt="",cmap=cmap,cbar=cbar,xticklabels=categories,yticklabels=categories,annot_kws={"size": 20})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if xyplotlabels:
 | 
					 | 
				
			||||||
        plt.ylabel('True label')
 | 
					 | 
				
			||||||
        plt.xlabel('Predicted label' + stats_text)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        plt.xlabel(stats_text)
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    if title:
 | 
					 | 
				
			||||||
        plt.title(title)
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    if f_name:
 | 
					 | 
				
			||||||
        plt.savefig(fname = f_name, dpi=None, facecolor='w', edgecolor='w', orientation='portrait', pad_inches=0.1)
 | 
					 | 
				
			||||||
					Загрузка…
					
					
				
		Ссылка в новой задаче