# -*- coding: utf-8 -*- """IS_LR3_2 Automatically generated by Colab. Original file is located at https://colab.research.google.com/drive/1ATu8wYdHLgC6dGpFJboJXvIoTohx65eT """ from google.colab import drive drive.mount('/content/drive') import os os.chdir('/content/drive/MyDrive/Colab Notebooks/is_lab3') from tensorflow import keras from tensorflow.keras.models import Sequential from tensorflow.keras import layers import matplotlib.pyplot as plt import numpy as np from sklearn.metrics import classification_report, confusion_matrix from sklearn.metrics import ConfusionMatrixDisplay # загрузка датасета from keras.datasets import cifar10 (X_train, y_train), (X_test, y_test) = cifar10.load_data() # создание своего разбиения датасета from sklearn.model_selection import train_test_split # объединяем в один набор X = np.concatenate((X_train, X_test)) y = np.concatenate((y_train, y_test)) # разбиваем по вариантам X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 10000, train_size = 50000, random_state = 7) # вывод размерностей print('Shape of X train:', X_train.shape) print('Shape of y train:', y_train.shape) print('Shape of X test:', X_test.shape) print('Shape of y test:', y_test.shape) class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] plt.figure(figsize=(10,10)) for i in range(25): plt.subplot(5,5,i+1) plt.xticks([]) plt.yticks([]) plt.grid(False) plt.imshow(X_train[i]) plt.xlabel(class_names[y_train[i][0]]) plt.show() # Зададим параметры данных и модели num_classes = 10 input_shape = (32, 32, 3) # Приведение входных данных к диапазону [0, 1] X_train = X_train / 255 X_test = X_test / 255 print('Shape of transformed X train:', X_train.shape) print('Shape of transformed X test:', X_test.shape) # переведем метки в one-hot y_train = keras.utils.to_categorical(y_train, num_classes) y_test = keras.utils.to_categorical(y_test, num_classes) print('Shape of transformed y train:', y_train.shape) print('Shape of transformed y test:', y_test.shape) # создаем модель model = Sequential() # Блок 1 model.add(layers.Conv2D(32, (3, 3), padding="same", activation="relu", input_shape=input_shape)) model.add(layers.BatchNormalization()) model.add(layers.Conv2D(32, (3, 3), padding="same", activation="relu")) model.add(layers.BatchNormalization()) model.add(layers.MaxPooling2D((2, 2))) model.add(layers.Dropout(0.25)) # Блок 2 model.add(layers.Conv2D(64, (3, 3), padding="same", activation="relu")) model.add(layers.BatchNormalization()) model.add(layers.Conv2D(64, (3, 3), padding="same", activation="relu")) model.add(layers.BatchNormalization()) model.add(layers.MaxPooling2D((2, 2))) model.add(layers.Dropout(0.25)) model.add(layers.Flatten()) model.add(layers.Dense(128, activation='relu')) model.add(layers.Dropout(0.5)) model.add(layers.Dense(num_classes, activation="softmax")) model.summary() batch_size = 64 epochs = 50 model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"]) model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1) scores = model.evaluate(X_test, y_test) print('Loss on test data:', scores[0]) print('Accuracy on test data:', scores[1]) for n in [5,17]: result = model.predict(X_test[n:n+1]) print('NN output:', result) plt.imshow(X_test[n].reshape(32,32,3), cmap=plt.get_cmap('gray')) plt.show() print('Real mark: ', np.argmax(y_test[n])) print('NN answer: ', np.argmax(result)) # истинные метки классов true_labels = np.argmax(y_test, axis=1) # предсказанные метки классов predicted_labels = np.argmax(model.predict(X_test), axis=1) # отчет о качестве классификации print(classification_report(true_labels, predicted_labels, target_names=class_names)) # вычисление матрицы ошибок conf_matrix = confusion_matrix(true_labels, predicted_labels) # отрисовка матрицы ошибок в виде "тепловой карты" fig, ax = plt.subplots(figsize=(6, 6)) disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix,display_labels=class_names) disp.plot(ax=ax, xticks_rotation=45) # поворот подписей по X и приятная палитра plt.tight_layout() # чтобы всё влезло plt.show()