Вы не можете выбрать более 25 тем
Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.
77 строки
2.0 KiB
Python
77 строки
2.0 KiB
Python
import numpy as np
|
|
import joblib
|
|
|
|
from sklearn.datasets import load_iris
|
|
from sklearn.model_selection import train_test_split, GridSearchCV
|
|
from sklearn.linear_model import LogisticRegression
|
|
from sklearn.metrics import classification_report
|
|
|
|
# ============================
|
|
# 1. ЗАГРУЗКА ДАННЫХ
|
|
# ============================
|
|
iris = load_iris()
|
|
X = iris.data
|
|
y = iris.target
|
|
|
|
# ============================
|
|
# 2. РАЗБИЕНИЕ НА ОБУЧАЮЩУЮ И ТЕСТОВУЮ ВЫБОРКИ
|
|
# ============================
|
|
X_train, X_test, y_train, y_test = train_test_split(
|
|
X, y, test_size=0.3, random_state=42
|
|
)
|
|
|
|
# ============================
|
|
# 3. КЛАССИФИКАТОР
|
|
# ============================
|
|
model = LogisticRegression(max_iter=2000)
|
|
|
|
# ============================
|
|
# 4. СЕТКА ПАРАМЕТРОВ
|
|
# ============================
|
|
param_grid = {
|
|
'C': [0.01, 0.1, 1, 10, 100],
|
|
'solver': ['lbfgs', 'liblinear'],
|
|
'penalty': ['l2']
|
|
}
|
|
|
|
# ============================
|
|
# 5. НАСТРОЙКА GridSearchCV
|
|
# ============================
|
|
grid = GridSearchCV(
|
|
estimator=model,
|
|
param_grid=param_grid,
|
|
cv=5,
|
|
scoring='accuracy',
|
|
n_jobs=-1
|
|
)
|
|
|
|
# ============================
|
|
# 6. ОБУЧЕНИЕ
|
|
# ============================
|
|
grid.fit(X_train, y_train)
|
|
|
|
# ============================
|
|
# 7. ЛУЧШИЕ ПАРАМЕТРЫ
|
|
# ============================
|
|
print("Лучшие параметры:")
|
|
print(grid.best_params_)
|
|
|
|
print("\nЛучшая точность по кросс-валидации:")
|
|
print(grid.best_score_)
|
|
|
|
# ============================
|
|
# 8. ПРОВЕРКА НА ТЕСТЕ
|
|
# ============================
|
|
best_model = grid.best_estimator_
|
|
y_pred = best_model.predict(X_test)
|
|
|
|
print("\nОтчёт по классификации:")
|
|
print(classification_report(y_test, y_pred))
|
|
|
|
# ============================
|
|
# 9. СОХРАНЕНИЕ МОДЕЛИ
|
|
# ============================
|
|
joblib.dump(best_model, "best_model.pkl")
|
|
|
|
print("\nЛучшая модель сохранена в файл: best_model.pkl")
|