Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

107 строки
3.4 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
import joblib
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.decomposition import PCA
# ============================
# 1. ЗАГРУЗКА ДАННЫХ
# ============================
iris = load_iris()
X = iris.data
y = iris.target
# ============================
# 2. РАЗБИЕНИЕ НА ОБУЧАЮЩУЮ И ТЕСТОВУЮ ВЫБОРКИ
# ============================
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42
)
# ============================
# 3. ЗАГРУЗКА ЛУЧШЕЙ МОДЕЛИ
# ============================
model = joblib.load("best_model.pkl")
print("Модель загружена из файла best_model.pkl")
# ============================
# 4. ОБУЧЕНИЕ НА ТРЕНИРОВОЧНОЙ ВЫБОРКЕ
# ============================
model.fit(X_train, y_train)
# ============================
# 5. ПРЕДСКАЗАНИЯ
# ============================
y_pred = model.predict(X_test)
# ============================
# 6. МЕТРИКИ КАЧЕСТВА
# ============================
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred, average="macro")
recall = recall_score(y_test, y_pred, average="macro")
f1 = f1_score(y_test, y_pred, average="macro")
print("\nМЕТРИКИ КАЧЕСТВА (ПУНКТ 4):")
print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1-score:", f1)
# ============================
# 7. PCA ДЛЯ ВИЗУАЛИЗАЦИИ
# ============================
pca = PCA(n_components=2)
X_test_2d = pca.fit_transform(X_test)
# ============================
# 8. ГРАФИК 1 — ИСТИННЫЕ МЕТКИ
# ============================
plt.figure()
plt.scatter(X_test_2d[:, 0], X_test_2d[:, 1])
plt.title("Тестовая выборка — ИСТИННЫЕ МЕТКИ")
plt.xlabel("Компонента 1")
plt.ylabel("Компонента 2")
plt.grid(True)
plt.savefig("4_true_labels.png", dpi=300)
plt.show()
# ============================
# 9. ГРАФИК 2 — МЕТКИ ИЗ ПУНКТА 2 (БЕЗ GridSearch)
# ============================
# Их можно восстановить повторным обучением обычной LogisticRegression
from sklearn.linear_model import LogisticRegression
base_model = LogisticRegression(max_iter=1000)
base_model.fit(X_train, y_train)
y_pred_base = base_model.predict(X_test)
plt.figure()
plt.scatter(X_test_2d[:, 0], X_test_2d[:, 1])
plt.title("Тестовая выборка — ПРЕДСКАЗАНИЯ (пункт 2)")
plt.xlabel("Компонента 1")
plt.ylabel("Компонента 2")
plt.grid(True)
plt.savefig("4_pred_labels_p2.png", dpi=300)
plt.show()
# ============================
# 10. ГРАФИК 3 — МЕТКИ ИЗ ПУНКТА 4 (ПОСЛЕ GridSearch)
# ============================
plt.figure()
plt.scatter(X_test_2d[:, 0], X_test_2d[:, 1])
plt.title("Тестовая выборка — ПРЕДСКАЗАНИЯ (пункт 4, после GridSearch)")
plt.xlabel("Компонента 1")
plt.ylabel("Компонента 2")
plt.grid(True)
plt.savefig("4_pred_labels_p4.png", dpi=300)
plt.show()
print("\nГрафики сохранены:")
print("4_true_labels.png")
print("4_pred_labels_p2.png")
print("4_pred_labels_p4.png")