459 KiB
459 KiB
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import cross_val_score, train_test_split, GridSearchCV, StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.manifold import TSNE
from sklearn.manifold import MDS
from sklearn import preprocessing
= pd.read_csv('titanic.csv') df
Знакомство с датасетом
df
PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S |
1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S |
3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S |
4 | 5 | 0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
886 | 887 | 0 | 2 | Montvila, Rev. Juozas | male | 27.0 | 0 | 0 | 211536 | 13.0000 | NaN | S |
887 | 888 | 1 | 1 | Graham, Miss. Margaret Edith | female | 19.0 | 0 | 0 | 112053 | 30.0000 | B42 | S |
888 | 889 | 0 | 3 | Johnston, Miss. Catherine Helen "Carrie" | female | NaN | 1 | 2 | W./C. 6607 | 23.4500 | NaN | S |
889 | 890 | 1 | 1 | Behr, Mr. Karl Howell | male | 26.0 | 0 | 0 | 111369 | 30.0000 | C148 | C |
890 | 891 | 0 | 3 | Dooley, Mr. Patrick | male | 32.0 | 0 | 0 | 370376 | 7.7500 | NaN | Q |
891 rows × 12 columns
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 PassengerId 891 non-null int64
1 Survived 891 non-null int64
2 Pclass 891 non-null int64
3 Name 891 non-null object
4 Sex 891 non-null object
5 Age 714 non-null float64
6 SibSp 891 non-null int64
7 Parch 891 non-null int64
8 Ticket 891 non-null object
9 Fare 891 non-null float64
10 Cabin 204 non-null object
11 Embarked 889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
df.describe()
PassengerId | Survived | Pclass | Age | SibSp | Parch | Fare | |
---|---|---|---|---|---|---|---|
count | 891.000000 | 891.000000 | 891.000000 | 714.000000 | 891.000000 | 891.000000 | 891.000000 |
mean | 446.000000 | 0.383838 | 2.308642 | 29.699118 | 0.523008 | 0.381594 | 32.204208 |
std | 257.353842 | 0.486592 | 0.836071 | 14.526497 | 1.102743 | 0.806057 | 49.693429 |
min | 1.000000 | 0.000000 | 1.000000 | 0.420000 | 0.000000 | 0.000000 | 0.000000 |
25% | 223.500000 | 0.000000 | 2.000000 | 20.125000 | 0.000000 | 0.000000 | 7.910400 |
50% | 446.000000 | 0.000000 | 3.000000 | 28.000000 | 0.000000 | 0.000000 | 14.454200 |
75% | 668.500000 | 1.000000 | 3.000000 | 38.000000 | 1.000000 | 0.000000 | 31.000000 |
max | 891.000000 | 1.000000 | 3.000000 | 80.000000 | 8.000000 | 6.000000 | 512.329200 |
Предварительная обработка
# Удаляем ненужные столбцы
= df.drop(['PassengerId', 'Ticket', 'Cabin', 'Name'], axis=1) df
# Кодируем поле Пол
'Sex'] == 'male', 'Sex'] = 1
df.loc[df['Sex'] == 'female', 'Sex'] = 0
df.loc[df[= df.Sex.astype(bool) df.Sex
# Кодируем поле Embarked
= preprocessing.LabelEncoder()
le 'Embarked'])
le.fit(df['Embarked'] = le.transform(df['Embarked'])
df[ df
Survived | Pclass | Sex | Age | SibSp | Parch | Fare | Embarked | |
---|---|---|---|---|---|---|---|---|
0 | 0 | 3 | True | 22.0 | 1 | 0 | 7.2500 | 2 |
1 | 1 | 1 | False | 38.0 | 1 | 0 | 71.2833 | 0 |
2 | 1 | 3 | False | 26.0 | 0 | 0 | 7.9250 | 2 |
3 | 1 | 1 | False | 35.0 | 1 | 0 | 53.1000 | 2 |
4 | 0 | 3 | True | 35.0 | 0 | 0 | 8.0500 | 2 |
... | ... | ... | ... | ... | ... | ... | ... | ... |
886 | 0 | 2 | True | 27.0 | 0 | 0 | 13.0000 | 2 |
887 | 1 | 1 | False | 19.0 | 0 | 0 | 30.0000 | 2 |
888 | 0 | 3 | False | NaN | 1 | 2 | 23.4500 | 2 |
889 | 1 | 1 | True | 26.0 | 0 | 0 | 30.0000 | 0 |
890 | 0 | 3 | True | 32.0 | 0 | 0 | 7.7500 | 1 |
891 rows × 8 columns
# Заполняем возраст медианой
= True) df.Age.fillna(df.Age.median(), inplace
df
Survived | Pclass | Sex | Age | SibSp | Parch | Fare | Embarked | |
---|---|---|---|---|---|---|---|---|
0 | 0 | 3 | True | 22.0 | 1 | 0 | 7.2500 | 2 |
1 | 1 | 1 | False | 38.0 | 1 | 0 | 71.2833 | 0 |
2 | 1 | 3 | False | 26.0 | 0 | 0 | 7.9250 | 2 |
3 | 1 | 1 | False | 35.0 | 1 | 0 | 53.1000 | 2 |
4 | 0 | 3 | True | 35.0 | 0 | 0 | 8.0500 | 2 |
... | ... | ... | ... | ... | ... | ... | ... | ... |
886 | 0 | 2 | True | 27.0 | 0 | 0 | 13.0000 | 2 |
887 | 1 | 1 | False | 19.0 | 0 | 0 | 30.0000 | 2 |
888 | 0 | 3 | False | 28.0 | 1 | 2 | 23.4500 | 2 |
889 | 1 | 1 | True | 26.0 | 0 | 0 | 30.0000 | 0 |
890 | 0 | 3 | True | 32.0 | 0 | 0 | 7.7500 | 1 |
891 rows × 8 columns
Визуализация датасета
= TSNE(n_components=2,
tsne ="pca",
init=0,
random_state=50,
perplexity= 1000,
n_iter = 'cosine') metric
= tsne.fit_transform(df.iloc[:,1:]) Y
0], Y[:,1], c = df.iloc[:,0]) plt.scatter(Y[:,
<matplotlib.collections.PathCollection at 0x1a19d41f820>
= MDS(n_components=2,
mds =0) random_state
= mds.fit_transform(df.iloc[:,1:]) Y_MDS
C:\Users\Андрей\AppData\Local\Programs\Python\Python39\lib\site-packages\sklearn\manifold\_mds.py:299: FutureWarning: The default value of `normalized_stress` will change to `'auto'` in version 1.4. To suppress this warning, manually set the value of `normalized_stress`.
warnings.warn(
0], Y_MDS[:,1], c = df.iloc[:,0]) plt.scatter(Y_MDS[:,
<matplotlib.collections.PathCollection at 0x1a19d64fdc0>
Графики по столбцам
= df.Survived) sns.countplot(x
<Axes: xlabel='Survived', ylabel='count'>
= df.Pclass, hue = df.Survived) sns.countplot(x
<Axes: xlabel='Pclass', ylabel='count'>
set(rc={'figure.figsize':(10,5)})
sns.= df.loc[df.Fare < 100].Fare, hue = df.Survived, kde=True, ) sns.histplot(x
<Axes: xlabel='Fare', ylabel='Count'>
set(rc={'figure.figsize':(10,5)})
sns.= df.Age, hue = df.Survived, kde=True ) sns.histplot(x
<Axes: xlabel='Age', ylabel='Count'>
set(rc={'figure.figsize':(10,5)})
sns.= df.SibSp, hue = df.Survived, kde=True ) sns.histplot(x
<Axes: xlabel='SibSp', ylabel='Count'>
Корреляционная матрица
= True), annot = True, vmin=-1, vmax=1, cmap = 'bwr') sns.heatmap(df.corr(numeric_only
<Axes: >
Классификация
= train_test_split(df.iloc[:,1:], df.iloc[:,0], test_size=0.33, random_state=42) X_train, X_test, y_train, y_test
= RandomForestClassifier()
RF
RF.fit(X_train, y_train)= RF.predict(X_test) rf_prediction
print('Conf matrix')
print(metrics.confusion_matrix(rf_prediction, y_test))
print('Classification report')
print(metrics.classification_report(rf_prediction, y_test))
Conf matrix
[[143 33]
[ 32 87]]
Classification report
precision recall f1-score support
0 0.82 0.81 0.81 176
1 0.72 0.73 0.73 119
accuracy 0.78 295
macro avg 0.77 0.77 0.77 295
weighted avg 0.78 0.78 0.78 295
# Вероятности каждого класса
= RF.predict_proba(X_test)
rf_prediction_proba rf_prediction_proba
array([[0.66 , 0.34 ],
[0.965 , 0.035 ],
[0.83833333, 0.16166667],
[0.02 , 0.98 ],
[0.76 , 0.24 ],
[0.06 , 0.94 ],
[0.13648918, 0.86351082],
[0.89 , 0.11 ],
[0.25266667, 0.74733333],
[0.03 , 0.97 ],
[0.56 , 0.44 ],
[0.935 , 0.065 ],
[0.94 , 0.06 ],
[0.92 , 0.08 ],
[0.679 , 0.321 ],
[0.05 , 0.95 ],
[0.67 , 0.33 ],
[0.14376299, 0.85623701],
[0.81 , 0.19 ],
[0.96 , 0.04 ],
[0.99 , 0.01 ],
[0.43666667, 0.56333333],
[0.88 , 0.12 ],
[1. , 0. ],
[0.99 , 0.01 ],
[0.91 , 0.09 ],
[0.68 , 0.32 ],
[0.901 , 0.099 ],
[0.78 , 0.22 ],
[0.38 , 0.62 ],
[1. , 0. ],
[0.61166667, 0.38833333],
[0.36 , 0.64 ],
[0.39 , 0.61 ],
[0.70516667, 0.29483333],
[0.81 , 0.19 ],
[0.99 , 0.01 ],
[0.13648918, 0.86351082],
[0.2 , 0.8 ],
[1. , 0. ],
[0.98 , 0.02 ],
[0.77 , 0.23 ],
[1. , 0. ],
[0.82158261, 0.17841739],
[0.61 , 0.39 ],
[0.92 , 0.08 ],
[0.466 , 0.534 ],
[1. , 0. ],
[0.53 , 0.47 ],
[0.37 , 0.63 ],
[0.12 , 0.88 ],
[0.03 , 0.97 ],
[0.97 , 0.03 ],
[0.42 , 0.58 ],
[0.98666667, 0.01333333],
[0.04 , 0.96 ],
[0.87555556, 0.12444444],
[0.1 , 0.9 ],
[0.09 , 0.91 ],
[0.74 , 0.26 ],
[0.65833333, 0.34166667],
[0.11 , 0.89 ],
[0.09 , 0.91 ],
[0.93 , 0.07 ],
[0.82158261, 0.17841739],
[0.04 , 0.96 ],
[0.956 , 0.044 ],
[0.7925 , 0.2075 ],
[0.74857143, 0.25142857],
[0.06 , 0.94 ],
[0.15 , 0.85 ],
[0.21 , 0.79 ],
[0.19 , 0.81 ],
[0.03 , 0.97 ],
[0.97 , 0.03 ],
[0.96 , 0.04 ],
[0.21845346, 0.78154654],
[0.04 , 0.96 ],
[0.01 , 0.99 ],
[0.51 , 0.49 ],
[0.98666667, 0.01333333],
[0.01 , 0.99 ],
[0.06 , 0.94 ],
[0.95369012, 0.04630988],
[0.45 , 0.55 ],
[0.6 , 0.4 ],
[0.05 , 0.95 ],
[0.03 , 0.97 ],
[0.98 , 0.02 ],
[0.97833333, 0.02166667],
[0.48 , 0.52 ],
[0.78 , 0.22 ],
[0.98 , 0.02 ],
[0.96369012, 0.03630988],
[1. , 0. ],
[0.78583333, 0.21416667],
[0.95 , 0.05 ],
[0.945 , 0.055 ],
[0. , 1. ],
[0.8 , 0.2 ],
[0.89 , 0.11 ],
[0.98 , 0.02 ],
[0.03 , 0.97 ],
[0.95666667, 0.04333333],
[0.92 , 0.08 ],
[0.31083333, 0.68916667],
[0.09 , 0.91 ],
[0.95 , 0.05 ],
[0.99 , 0.01 ],
[0.5 , 0.5 ],
[0.1 , 0.9 ],
[0.95 , 0.05 ],
[0.06 , 0.94 ],
[0.14 , 0.86 ],
[0.66133333, 0.33866667],
[1. , 0. ],
[0.43 , 0.57 ],
[0.454 , 0.546 ],
[0.09 , 0.91 ],
[0.82 , 0.18 ],
[0.77 , 0.23 ],
[0.01 , 0.99 ],
[0.09 , 0.91 ],
[0.18 , 0.82 ],
[0.91 , 0.09 ],
[0.67 , 0.33 ],
[0.04 , 0.96 ],
[0.41 , 0.59 ],
[0.71 , 0.29 ],
[0.97433333, 0.02566667],
[0.17026299, 0.82973701],
[0.66 , 0.34 ],
[0.99 , 0.01 ],
[0.49 , 0.51 ],
[0.21 , 0.79 ],
[0.87 , 0.13 ],
[0.02 , 0.98 ],
[0.57504762, 0.42495238],
[0.98 , 0.02 ],
[0.74 , 0.26 ],
[0.99 , 0.01 ],
[0.07 , 0.93 ],
[0.26 , 0.74 ],
[0.9 , 0.1 ],
[0.8 , 0.2 ],
[0.07 , 0.93 ],
[0.93 , 0.07 ],
[0.32 , 0.68 ],
[0.12 , 0.88 ],
[0.955 , 0.045 ],
[0.91171429, 0.08828571],
[0.83 , 0.17 ],
[0.69 , 0.31 ],
[0.37 , 0.63 ],
[1. , 0. ],
[0.91 , 0.09 ],
[0.65891667, 0.34108333],
[0.20331457, 0.79668543],
[0.05 , 0.95 ],
[0.23 , 0.77 ],
[0.45496429, 0.54503571],
[0.82 , 0.18 ],
[0.60490476, 0.39509524],
[0.09 , 0.91 ],
[0.92652742, 0.07347258],
[0.45 , 0.55 ],
[0.9575 , 0.0425 ],
[0.09 , 0.91 ],
[1. , 0. ],
[0.99 , 0.01 ],
[0.43 , 0.57 ],
[0.07 , 0.93 ],
[0.14 , 0.86 ],
[0.34 , 0.66 ],
[0.77333333, 0.22666667],
[0.97833333, 0.02166667],
[0.95 , 0.05 ],
[0.09 , 0.91 ],
[0.3 , 0.7 ],
[0.31 , 0.69 ],
[1. , 0. ],
[0.46621429, 0.53378571],
[0.52 , 0.48 ],
[0.6 , 0.4 ],
[0.56 , 0.44 ],
[0.13 , 0.87 ],
[0.8 , 0.2 ],
[1. , 0. ],
[0.99 , 0.01 ],
[0.02 , 0.98 ],
[0.94 , 0.06 ],
[0.63738889, 0.36261111],
[0.86 , 0.14 ],
[0.97 , 0.03 ],
[0.03 , 0.97 ],
[0. , 1. ],
[0.01 , 0.99 ],
[0.58322222, 0.41677778],
[0.07 , 0.93 ],
[0.81 , 0.19 ],
[0.96 , 0.04 ],
[0.42 , 0.58 ],
[0.03 , 0.97 ],
[0.92 , 0.08 ],
[0.60954762, 0.39045238],
[0.24 , 0.76 ],
[0.97 , 0.03 ],
[0.58 , 0.42 ],
[0.71571429, 0.28428571],
[0.985 , 0.015 ],
[0.37 , 0.63 ],
[0.98 , 0.02 ],
[0.03 , 0.97 ],
[0.96369012, 0.03630988],
[0.94 , 0.06 ],
[0.64 , 0.36 ],
[0.14 , 0.86 ],
[0.53 , 0.47 ],
[0.07 , 0.93 ],
[0.97833333, 0.02166667],
[1. , 0. ],
[0.02 , 0.98 ],
[0.93 , 0.07 ],
[0.37 , 0.63 ],
[0.81 , 0.19 ],
[0.04 , 0.96 ],
[0.97 , 0.03 ],
[0.44 , 0.56 ],
[0.17 , 0.83 ],
[0.38 , 0.62 ],
[0.94 , 0.06 ],
[0.22 , 0.78 ],
[0.82158261, 0.17841739],
[0.30816667, 0.69183333],
[0.73 , 0.27 ],
[0.79 , 0.21 ],
[0.07 , 0.93 ],
[0.89 , 0.11 ],
[0.01 , 0.99 ],
[0.685 , 0.315 ],
[0.52 , 0.48 ],
[0.07 , 0.93 ],
[0.97833333, 0.02166667],
[0.82 , 0.18 ],
[0.99 , 0.01 ],
[0.97 , 0.03 ],
[1. , 0. ],
[1. , 0. ],
[0.97 , 0.03 ],
[0.67 , 0.33 ],
[0.51 , 0.49 ],
[0.79 , 0.21 ],
[0.82158261, 0.17841739],
[0.47496429, 0.52503571],
[0.73 , 0.27 ],
[0.69 , 0.31 ],
[0.98333333, 0.01666667],
[0.05 , 0.95 ],
[0.99 , 0.01 ],
[0.03 , 0.97 ],
[0.08 , 0.92 ],
[0.36 , 0.64 ],
[0.02 , 0.98 ],
[0.79 , 0.21 ],
[0.99 , 0.01 ],
[0.95 , 0.05 ],
[1. , 0. ],
[0.82 , 0.18 ],
[0.47 , 0.53 ],
[1. , 0. ],
[0.57 , 0.43 ],
[0.8 , 0.2 ],
[0.84 , 0.16 ],
[0.67 , 0.33 ],
[0.08 , 0.92 ],
[0.98 , 0.02 ],
[0.72583333, 0.27416667],
[0.36571429, 0.63428571],
[0.05 , 0.95 ],
[0.60490476, 0.39509524],
[0.83 , 0.17 ],
[0.84 , 0.16 ],
[0.1 , 0.9 ],
[0.94 , 0.06 ],
[0.34 , 0.66 ],
[0.86666667, 0.13333333],
[0.17026299, 0.82973701],
[0.975 , 0.025 ],
[0.05 , 0.95 ],
[0.25 , 0.75 ],
[0.96 , 0.04 ],
[0.98380952, 0.01619048],
[0.17 , 0.83 ],
[0.23 , 0.77 ],
[0.89166667, 0.10833333]])
rf_prediction
array([0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0,
1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0,
0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1,
0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0,
0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0,
1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0,
0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0,
0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0,
1, 0, 1, 1, 0, 0, 1, 1, 0], dtype=int64)
Важность признаков
= pd.DataFrame(RF.feature_importances_, RF.feature_names_in_)
fi fi
0 | |
---|---|
Pclass | 0.087579 |
Sex | 0.249358 |
Age | 0.255040 |
SibSp | 0.051566 |
Parch | 0.040596 |
Fare | 0.275444 |
Embarked | 0.040416 |
sns.barplot(fi.T)
<Axes: >
= X_train[['Sex', 'Age', 'Fare']]
X_train_sef = X_test[['Sex', 'Age', 'Fare']]
X_test_sef = RandomForestClassifier()
RF_sef
RF_sef.fit(X_train_sef, y_train)=RF_sef.predict(X_test_sef ) rf_sef_prediction
print('Conf matrix')
metrics.confusion_matrix(rf_sef_prediction, y_test)print('Classification report')
print(metrics.classification_report(rf_sef_prediction, y_test))
Conf matrix
Classification report
precision recall f1-score support
0 0.84 0.80 0.82 184
1 0.69 0.75 0.72 111
accuracy 0.78 295
macro avg 0.77 0.77 0.77 295
weighted avg 0.78 0.78 0.78 295
XGBoost
from xgboost import XGBClassifier
= XGBClassifier()
XGB
XGB.fit(X_train, y_train)=XGB.predict(X_test )
xgb_prediction print('Conf matrix')
print(metrics.confusion_matrix(xgb_prediction, y_test))
print('Classification report')
print(metrics.classification_report(xgb_prediction, y_test))
Conf matrix
[[141 31]
[ 34 89]]
Classification report
precision recall f1-score support
0 0.81 0.82 0.81 172
1 0.74 0.72 0.73 123
accuracy 0.78 295
macro avg 0.77 0.77 0.77 295
weighted avg 0.78 0.78 0.78 295
= XGBClassifier()
XGB_sef
XGB_sef.fit(X_train_sef, y_train)=XGB_sef.predict(X_test_sef)
xgb_sef_prediction print('Conf matrix')
print(metrics.confusion_matrix(xgb_sef_prediction, y_test))
print('Classification report')
print(metrics.classification_report(xgb_sef_prediction, y_test))
Conf matrix
[[150 37]
[ 25 83]]
Classification report
precision recall f1-score support
0 0.86 0.80 0.83 187
1 0.69 0.77 0.73 108
accuracy 0.79 295
macro avg 0.77 0.79 0.78 295
weighted avg 0.80 0.79 0.79 295
= True, cmap = 'RdYlGn') sns.heatmap(metrics.confusion_matrix(xgb_sef_prediction, y_test), annot
<Axes: >
Попробуем настроить параметры для RF
= X_train[['Sex', 'Age', 'Fare']]
X_train_sef = X_test[['Sex', 'Age', 'Fare']]
X_test_sef
= {
params 'n_estimators': [1, 10, 50, 100],
'max_depth': [1, 5, 10],
'criterion':['gini', 'entropy']
}
= RandomForestClassifier()
RF_sef_gs
= GridSearchCV(param_grid=params, estimator=RF_sef_gs)
gs
gs.fit(X_train_sef, y_train)=gs.predict(X_test_sef ) rf_sef_gs_prediction
gs.best_params_
{'criterion': 'entropy', 'max_depth': 5, 'n_estimators': 50}
print('Conf matrix')
print(metrics.confusion_matrix(rf_sef_gs_prediction, y_test))
print('Classification report')
print(metrics.classification_report(rf_sef_gs_prediction, y_test))
Conf matrix
[[154 34]
[ 21 86]]
Classification report
precision recall f1-score support
0 0.88 0.82 0.85 188
1 0.72 0.80 0.76 107
accuracy 0.81 295
macro avg 0.80 0.81 0.80 295
weighted avg 0.82 0.81 0.82 295