Вы не можете выбрать более 25 тем
Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.
459 KiB
459 KiB
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import cross_val_score, train_test_split, GridSearchCV, StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.manifold import TSNE
from sklearn.manifold import MDS
from sklearn import preprocessing

# Load the Titanic passenger data; columns are listed by df.info() below.
df = pd.read_csv('titanic.csv')

# Getting to know the dataset
# Display the raw dataset: pandas shows the first and last five rows plus the shape (891 rows x 12 columns).
df| PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S |
| 1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
| 2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S |
| 3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S |
| 4 | 5 | 0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 886 | 887 | 0 | 2 | Montvila, Rev. Juozas | male | 27.0 | 0 | 0 | 211536 | 13.0000 | NaN | S |
| 887 | 888 | 1 | 1 | Graham, Miss. Margaret Edith | female | 19.0 | 0 | 0 | 112053 | 30.0000 | B42 | S |
| 888 | 889 | 0 | 3 | Johnston, Miss. Catherine Helen "Carrie" | female | NaN | 1 | 2 | W./C. 6607 | 23.4500 | NaN | S |
| 889 | 890 | 1 | 1 | Behr, Mr. Karl Howell | male | 26.0 | 0 | 0 | 111369 | 30.0000 | C148 | C |
| 890 | 891 | 0 | 3 | Dooley, Mr. Patrick | male | 32.0 | 0 | 0 | 370376 | 7.7500 | NaN | Q |
891 rows × 12 columns
# Column dtypes and non-null counts; Age, Cabin and Embarked contain missing values.
df.info()<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 PassengerId 891 non-null int64
1 Survived 891 non-null int64
2 Pclass 891 non-null int64
3 Name 891 non-null object
4 Sex 891 non-null object
5 Age 714 non-null float64
6 SibSp 891 non-null int64
7 Parch 891 non-null int64
8 Ticket 891 non-null object
9 Fare 891 non-null float64
10 Cabin 204 non-null object
11 Embarked 889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
# Summary statistics (count, mean, std, quartiles) of the numeric columns.
df.describe()| PassengerId | Survived | Pclass | Age | SibSp | Parch | Fare | |
|---|---|---|---|---|---|---|---|
| count | 891.000000 | 891.000000 | 891.000000 | 714.000000 | 891.000000 | 891.000000 | 891.000000 |
| mean | 446.000000 | 0.383838 | 2.308642 | 29.699118 | 0.523008 | 0.381594 | 32.204208 |
| std | 257.353842 | 0.486592 | 0.836071 | 14.526497 | 1.102743 | 0.806057 | 49.693429 |
| min | 1.000000 | 0.000000 | 1.000000 | 0.420000 | 0.000000 | 0.000000 | 0.000000 |
| 25% | 223.500000 | 0.000000 | 2.000000 | 20.125000 | 0.000000 | 0.000000 | 7.910400 |
| 50% | 446.000000 | 0.000000 | 3.000000 | 28.000000 | 0.000000 | 0.000000 | 14.454200 |
| 75% | 668.500000 | 1.000000 | 3.000000 | 38.000000 | 1.000000 | 0.000000 | 31.000000 |
| max | 891.000000 | 1.000000 | 3.000000 | 80.000000 | 8.000000 | 6.000000 | 512.329200 |
Предварительная обработка
# Drop columns that are not used as model features: identifiers, free text,
# and Cabin (only 204 of 891 values are present).
df = df.drop(columns=['PassengerId', 'Ticket', 'Cabin', 'Name'])

# Encode Sex as a boolean: True for 'male', False for 'female'.
df['Sex'] = df['Sex'] == 'male'

# Encode Embarked as integer codes.
# Two rows have no Embarked value. Left as NaN, LabelEncoder would give the
# missing value its own artificial code, so fill it with the most frequent
# value first.
df['Embarked'] = df['Embarked'].fillna(df['Embarked'].mode()[0])
le = preprocessing.LabelEncoder()
df['Embarked'] = le.fit_transform(df['Embarked'])
# Inspect the encoded frame; Age still contains NaN at this point (e.g. row 888).
df| Survived | Pclass | Sex | Age | SibSp | Parch | Fare | Embarked | |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 3 | True | 22.0 | 1 | 0 | 7.2500 | 2 |
| 1 | 1 | 1 | False | 38.0 | 1 | 0 | 71.2833 | 0 |
| 2 | 1 | 3 | False | 26.0 | 0 | 0 | 7.9250 | 2 |
| 3 | 1 | 1 | False | 35.0 | 1 | 0 | 53.1000 | 2 |
| 4 | 0 | 3 | True | 35.0 | 0 | 0 | 8.0500 | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 886 | 0 | 2 | True | 27.0 | 0 | 0 | 13.0000 | 2 |
| 887 | 1 | 1 | False | 19.0 | 0 | 0 | 30.0000 | 2 |
| 888 | 0 | 3 | False | NaN | 1 | 2 | 23.4500 | 2 |
| 889 | 1 | 1 | True | 26.0 | 0 | 0 | 30.0000 | 0 |
| 890 | 0 | 3 | True | 32.0 | 0 | 0 | 7.7500 | 1 |
891 rows × 8 columns
# Заполняем возраст медианой
df.Age.fillna(df.Age.median(), inplace = True)df| Survived | Pclass | Sex | Age | SibSp | Parch | Fare | Embarked | |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 3 | True | 22.0 | 1 | 0 | 7.2500 | 2 |
| 1 | 1 | 1 | False | 38.0 | 1 | 0 | 71.2833 | 0 |
| 2 | 1 | 3 | False | 26.0 | 0 | 0 | 7.9250 | 2 |
| 3 | 1 | 1 | False | 35.0 | 1 | 0 | 53.1000 | 2 |
| 4 | 0 | 3 | True | 35.0 | 0 | 0 | 8.0500 | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 886 | 0 | 2 | True | 27.0 | 0 | 0 | 13.0000 | 2 |
| 887 | 1 | 1 | False | 19.0 | 0 | 0 | 30.0000 | 2 |
| 888 | 0 | 3 | False | 28.0 | 1 | 2 | 23.4500 | 2 |
| 889 | 1 | 1 | True | 26.0 | 0 | 0 | 30.0000 | 0 |
| 890 | 0 | 3 | True | 32.0 | 0 | 0 | 7.7500 | 1 |
891 rows × 8 columns
Визуализация датасета
# 2-D t-SNE embedding of the feature columns (everything after Survived),
# coloured by the Survived label in column 0.
# n_iter is not passed: it was renamed max_iter in scikit-learn 1.5 and later
# removed. 1000 iterations is the default under either name.
tsne = TSNE(n_components=2,
            init="pca",
            random_state=0,
            perplexity=50,
            metric='cosine')
Y = tsne.fit_transform(df.iloc[:, 1:])
plt.scatter(Y[:, 0], Y[:, 1], c=df.iloc[:, 0])

# Metric MDS embedding of the same feature columns, for comparison with t-SNE.
# normalized_stress='auto' is set explicitly. For metric MDS it means no
# normalisation (the previous default), and it silences the FutureWarning
# scikit-learn emits when the argument is left unset.
mds = MDS(n_components=2,
          random_state=0,
          normalized_stress='auto')
Y_MDS = mds.fit_transform(df.iloc[:, 1:])
plt.scatter(Y_MDS[:, 0], Y_MDS[:, 1], c=df.iloc[:, 0])

Графики по столбцам
# Class balance of the target column.
sns.countplot(data=df, x='Survived')

# Passenger counts per Pclass, split by survival.
sns.countplot(data=df, x='Pclass', hue='Survived')

# Fare distribution by survival. Fares of 100 or more are left out so the
# long tail (max 512) does not squash the histogram.
sns.set(rc={'figure.figsize': (10, 5)})
sns.histplot(data=df[df['Fare'] < 100], x='Fare', hue='Survived', kde=True)

# Age distribution by survival.
sns.set(rc={'figure.figsize': (10, 5)})
sns.histplot(data=df, x='Age', hue='Survived', kde=True)

# SibSp distribution by survival. SibSp holds small integers (0-8), so
# discrete=True draws one bar per value instead of arbitrary bin edges.
sns.set(rc={'figure.figsize': (10, 5)})
sns.histplot(data=df, x='SibSp', hue='Survived', kde=True, discrete=True)

Корреляционная матрица
# Pairwise correlations of the numeric/boolean columns on a fixed -1..1 diverging scale.
sns.heatmap(data=df.corr(numeric_only=True), vmin=-1, vmax=1, cmap='bwr', annot=True)

Классификация
# Hold out a third of the rows for testing. Column 0 is Survived (the target);
# the remaining columns are the features.
X_train, X_test, y_train, y_test = train_test_split(
    df.iloc[:, 1:], df.iloc[:, 0], test_size=0.33, random_state=42)

# Random forest on all features; seeded so the results are reproducible.
RF = RandomForestClassifier(random_state=42)
RF.fit(X_train, y_train)
rf_prediction = RF.predict(X_test)

# scikit-learn metrics take (y_true, y_pred). Passing predictions first
# transposes the confusion matrix, swaps precision with recall, and makes
# "support" count predicted rather than true labels.
print('Conf matrix')
print(metrics.confusion_matrix(y_test, rf_prediction))
print('Classification report')
print(metrics.classification_report(y_test, rf_prediction))
[[143 33]
[ 32 87]]
Classification report
precision recall f1-score support
0 0.82 0.81 0.81 176
1 0.72 0.73 0.73 119
accuracy 0.78 295
macro avg 0.77 0.77 0.77 295
weighted avg 0.78 0.78 0.78 295
# Per-class probabilities for every test row; columns follow RF.classes_
# (Survived = 0, then Survived = 1).
rf_prediction_proba = RF.predict_proba(X_test)
rf_prediction_probaarray([[0.66 , 0.34 ],
[0.965 , 0.035 ],
[0.83833333, 0.16166667],
[0.02 , 0.98 ],
[0.76 , 0.24 ],
[0.06 , 0.94 ],
[0.13648918, 0.86351082],
[0.89 , 0.11 ],
[0.25266667, 0.74733333],
[0.03 , 0.97 ],
[0.56 , 0.44 ],
[0.935 , 0.065 ],
[0.94 , 0.06 ],
[0.92 , 0.08 ],
[0.679 , 0.321 ],
[0.05 , 0.95 ],
[0.67 , 0.33 ],
[0.14376299, 0.85623701],
[0.81 , 0.19 ],
[0.96 , 0.04 ],
[0.99 , 0.01 ],
[0.43666667, 0.56333333],
[0.88 , 0.12 ],
[1. , 0. ],
[0.99 , 0.01 ],
[0.91 , 0.09 ],
[0.68 , 0.32 ],
[0.901 , 0.099 ],
[0.78 , 0.22 ],
[0.38 , 0.62 ],
[1. , 0. ],
[0.61166667, 0.38833333],
[0.36 , 0.64 ],
[0.39 , 0.61 ],
[0.70516667, 0.29483333],
[0.81 , 0.19 ],
[0.99 , 0.01 ],
[0.13648918, 0.86351082],
[0.2 , 0.8 ],
[1. , 0. ],
[0.98 , 0.02 ],
[0.77 , 0.23 ],
[1. , 0. ],
[0.82158261, 0.17841739],
[0.61 , 0.39 ],
[0.92 , 0.08 ],
[0.466 , 0.534 ],
[1. , 0. ],
[0.53 , 0.47 ],
[0.37 , 0.63 ],
[0.12 , 0.88 ],
[0.03 , 0.97 ],
[0.97 , 0.03 ],
[0.42 , 0.58 ],
[0.98666667, 0.01333333],
[0.04 , 0.96 ],
[0.87555556, 0.12444444],
[0.1 , 0.9 ],
[0.09 , 0.91 ],
[0.74 , 0.26 ],
[0.65833333, 0.34166667],
[0.11 , 0.89 ],
[0.09 , 0.91 ],
[0.93 , 0.07 ],
[0.82158261, 0.17841739],
[0.04 , 0.96 ],
[0.956 , 0.044 ],
[0.7925 , 0.2075 ],
[0.74857143, 0.25142857],
[0.06 , 0.94 ],
[0.15 , 0.85 ],
[0.21 , 0.79 ],
[0.19 , 0.81 ],
[0.03 , 0.97 ],
[0.97 , 0.03 ],
[0.96 , 0.04 ],
[0.21845346, 0.78154654],
[0.04 , 0.96 ],
[0.01 , 0.99 ],
[0.51 , 0.49 ],
[0.98666667, 0.01333333],
[0.01 , 0.99 ],
[0.06 , 0.94 ],
[0.95369012, 0.04630988],
[0.45 , 0.55 ],
[0.6 , 0.4 ],
[0.05 , 0.95 ],
[0.03 , 0.97 ],
[0.98 , 0.02 ],
[0.97833333, 0.02166667],
[0.48 , 0.52 ],
[0.78 , 0.22 ],
[0.98 , 0.02 ],
[0.96369012, 0.03630988],
[1. , 0. ],
[0.78583333, 0.21416667],
[0.95 , 0.05 ],
[0.945 , 0.055 ],
[0. , 1. ],
[0.8 , 0.2 ],
[0.89 , 0.11 ],
[0.98 , 0.02 ],
[0.03 , 0.97 ],
[0.95666667, 0.04333333],
[0.92 , 0.08 ],
[0.31083333, 0.68916667],
[0.09 , 0.91 ],
[0.95 , 0.05 ],
[0.99 , 0.01 ],
[0.5 , 0.5 ],
[0.1 , 0.9 ],
[0.95 , 0.05 ],
[0.06 , 0.94 ],
[0.14 , 0.86 ],
[0.66133333, 0.33866667],
[1. , 0. ],
[0.43 , 0.57 ],
[0.454 , 0.546 ],
[0.09 , 0.91 ],
[0.82 , 0.18 ],
[0.77 , 0.23 ],
[0.01 , 0.99 ],
[0.09 , 0.91 ],
[0.18 , 0.82 ],
[0.91 , 0.09 ],
[0.67 , 0.33 ],
[0.04 , 0.96 ],
[0.41 , 0.59 ],
[0.71 , 0.29 ],
[0.97433333, 0.02566667],
[0.17026299, 0.82973701],
[0.66 , 0.34 ],
[0.99 , 0.01 ],
[0.49 , 0.51 ],
[0.21 , 0.79 ],
[0.87 , 0.13 ],
[0.02 , 0.98 ],
[0.57504762, 0.42495238],
[0.98 , 0.02 ],
[0.74 , 0.26 ],
[0.99 , 0.01 ],
[0.07 , 0.93 ],
[0.26 , 0.74 ],
[0.9 , 0.1 ],
[0.8 , 0.2 ],
[0.07 , 0.93 ],
[0.93 , 0.07 ],
[0.32 , 0.68 ],
[0.12 , 0.88 ],
[0.955 , 0.045 ],
[0.91171429, 0.08828571],
[0.83 , 0.17 ],
[0.69 , 0.31 ],
[0.37 , 0.63 ],
[1. , 0. ],
[0.91 , 0.09 ],
[0.65891667, 0.34108333],
[0.20331457, 0.79668543],
[0.05 , 0.95 ],
[0.23 , 0.77 ],
[0.45496429, 0.54503571],
[0.82 , 0.18 ],
[0.60490476, 0.39509524],
[0.09 , 0.91 ],
[0.92652742, 0.07347258],
[0.45 , 0.55 ],
[0.9575 , 0.0425 ],
[0.09 , 0.91 ],
[1. , 0. ],
[0.99 , 0.01 ],
[0.43 , 0.57 ],
[0.07 , 0.93 ],
[0.14 , 0.86 ],
[0.34 , 0.66 ],
[0.77333333, 0.22666667],
[0.97833333, 0.02166667],
[0.95 , 0.05 ],
[0.09 , 0.91 ],
[0.3 , 0.7 ],
[0.31 , 0.69 ],
[1. , 0. ],
[0.46621429, 0.53378571],
[0.52 , 0.48 ],
[0.6 , 0.4 ],
[0.56 , 0.44 ],
[0.13 , 0.87 ],
[0.8 , 0.2 ],
[1. , 0. ],
[0.99 , 0.01 ],
[0.02 , 0.98 ],
[0.94 , 0.06 ],
[0.63738889, 0.36261111],
[0.86 , 0.14 ],
[0.97 , 0.03 ],
[0.03 , 0.97 ],
[0. , 1. ],
[0.01 , 0.99 ],
[0.58322222, 0.41677778],
[0.07 , 0.93 ],
[0.81 , 0.19 ],
[0.96 , 0.04 ],
[0.42 , 0.58 ],
[0.03 , 0.97 ],
[0.92 , 0.08 ],
[0.60954762, 0.39045238],
[0.24 , 0.76 ],
[0.97 , 0.03 ],
[0.58 , 0.42 ],
[0.71571429, 0.28428571],
[0.985 , 0.015 ],
[0.37 , 0.63 ],
[0.98 , 0.02 ],
[0.03 , 0.97 ],
[0.96369012, 0.03630988],
[0.94 , 0.06 ],
[0.64 , 0.36 ],
[0.14 , 0.86 ],
[0.53 , 0.47 ],
[0.07 , 0.93 ],
[0.97833333, 0.02166667],
[1. , 0. ],
[0.02 , 0.98 ],
[0.93 , 0.07 ],
[0.37 , 0.63 ],
[0.81 , 0.19 ],
[0.04 , 0.96 ],
[0.97 , 0.03 ],
[0.44 , 0.56 ],
[0.17 , 0.83 ],
[0.38 , 0.62 ],
[0.94 , 0.06 ],
[0.22 , 0.78 ],
[0.82158261, 0.17841739],
[0.30816667, 0.69183333],
[0.73 , 0.27 ],
[0.79 , 0.21 ],
[0.07 , 0.93 ],
[0.89 , 0.11 ],
[0.01 , 0.99 ],
[0.685 , 0.315 ],
[0.52 , 0.48 ],
[0.07 , 0.93 ],
[0.97833333, 0.02166667],
[0.82 , 0.18 ],
[0.99 , 0.01 ],
[0.97 , 0.03 ],
[1. , 0. ],
[1. , 0. ],
[0.97 , 0.03 ],
[0.67 , 0.33 ],
[0.51 , 0.49 ],
[0.79 , 0.21 ],
[0.82158261, 0.17841739],
[0.47496429, 0.52503571],
[0.73 , 0.27 ],
[0.69 , 0.31 ],
[0.98333333, 0.01666667],
[0.05 , 0.95 ],
[0.99 , 0.01 ],
[0.03 , 0.97 ],
[0.08 , 0.92 ],
[0.36 , 0.64 ],
[0.02 , 0.98 ],
[0.79 , 0.21 ],
[0.99 , 0.01 ],
[0.95 , 0.05 ],
[1. , 0. ],
[0.82 , 0.18 ],
[0.47 , 0.53 ],
[1. , 0. ],
[0.57 , 0.43 ],
[0.8 , 0.2 ],
[0.84 , 0.16 ],
[0.67 , 0.33 ],
[0.08 , 0.92 ],
[0.98 , 0.02 ],
[0.72583333, 0.27416667],
[0.36571429, 0.63428571],
[0.05 , 0.95 ],
[0.60490476, 0.39509524],
[0.83 , 0.17 ],
[0.84 , 0.16 ],
[0.1 , 0.9 ],
[0.94 , 0.06 ],
[0.34 , 0.66 ],
[0.86666667, 0.13333333],
[0.17026299, 0.82973701],
[0.975 , 0.025 ],
[0.05 , 0.95 ],
[0.25 , 0.75 ],
[0.96 , 0.04 ],
[0.98380952, 0.01619048],
[0.17 , 0.83 ],
[0.23 , 0.77 ],
[0.89166667, 0.10833333]])
# Hard class predictions of the forest for the test rows (same order as X_test).
rf_predictionarray([0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0,
1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0,
0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1,
0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0,
0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0,
1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0,
0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0,
0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0,
1, 0, 1, 1, 0, 0, 1, 1, 0], dtype=int64)
Важность признаков
fi = pd.DataFrame(RF.feature_importances_, RF.feature_names_in_)
fi| 0 | |
|---|---|
| Pclass | 0.087579 |
| Sex | 0.249358 |
| Age | 0.255040 |
| SibSp | 0.051566 |
| Parch | 0.040596 |
| Fare | 0.275444 |
| Embarked | 0.040416 |
# Bar chart of feature importances; the transposed frame has one column per feature.
sns.barplot(data=fi.T)

# Retrain the forest on the three features with the highest importances
# (Fare, Age, Sex).
X_train_sef = X_train[['Sex', 'Age', 'Fare']]
X_test_sef = X_test[['Sex', 'Age', 'Fare']]
RF_sef = RandomForestClassifier(random_state=42)
RF_sef.fit(X_train_sef, y_train)
rf_sef_prediction = RF_sef.predict(X_test_sef)

print('Conf matrix')
# The matrix must be printed explicitly. A bare expression that is not the
# last one in a notebook cell is not displayed, so it never appeared before.
print(metrics.confusion_matrix(y_test, rf_sef_prediction))  # (y_true, y_pred)
print('Classification report')
print(metrics.classification_report(y_test, rf_sef_prediction))
Classification report
precision recall f1-score support
0 0.84 0.80 0.82 184
1 0.69 0.75 0.72 111
accuracy 0.78 295
macro avg 0.77 0.77 0.77 295
weighted avg 0.78 0.78 0.78 295
XGBoost
from xgboost import XGBClassifier

# Gradient-boosted trees on all features, for comparison with the forest.
XGB = XGBClassifier()
XGB.fit(X_train, y_train)
xgb_prediction = XGB.predict(X_test)

# scikit-learn metrics take (y_true, y_pred) in that order.
print('Conf matrix')
print(metrics.confusion_matrix(y_test, xgb_prediction))
print('Classification report')
print(metrics.classification_report(y_test, xgb_prediction))
[[141 31]
[ 34 89]]
Classification report
precision recall f1-score support
0 0.81 0.82 0.81 172
1 0.74 0.72 0.73 123
accuracy 0.78 295
macro avg 0.77 0.77 0.77 295
weighted avg 0.78 0.78 0.78 295
# Gradient-boosted trees on the Sex/Age/Fare subset only.
XGB_sef = XGBClassifier()
XGB_sef.fit(X_train_sef, y_train)
xgb_sef_prediction = XGB_sef.predict(X_test_sef)

# scikit-learn metrics take (y_true, y_pred) in that order.
print('Conf matrix')
print(metrics.confusion_matrix(y_test, xgb_sef_prediction))
print('Classification report')
print(metrics.classification_report(y_test, xgb_sef_prediction))
[[150 37]
[ 25 83]]
Classification report
precision recall f1-score support
0 0.86 0.80 0.83 187
1 0.69 0.77 0.73 108
accuracy 0.79 295
macro avg 0.77 0.79 0.78 295
weighted avg 0.80 0.79 0.79 295
# Confusion matrix of the Sex/Age/Fare XGBoost model: rows are true labels,
# columns are predictions (y_true is passed first).
# fmt='d' prints plain integers; the default '.2g' would render 150 as 1.5e+02.
ax = sns.heatmap(metrics.confusion_matrix(y_test, xgb_sef_prediction),
                 annot=True, fmt='d', cmap='RdYlGn')
ax.set(xlabel='Predicted', ylabel='Actual')

Попробуем настроить параметры для RF
# Tune the random forest on the Sex/Age/Fare subset. GridSearchCV
# cross-validates each combination on the training split (5 folds by default).
X_train_sef = X_train[['Sex', 'Age', 'Fare']]
X_test_sef = X_test[['Sex', 'Age', 'Fare']]
params = {
    'n_estimators': [1, 10, 50, 100],
    'max_depth': [1, 5, 10],
    'criterion': ['gini', 'entropy'],
}
# Seeded so the selected parameters are reproducible between runs.
RF_sef_gs = RandomForestClassifier(random_state=42)
gs = GridSearchCV(param_grid=params, estimator=RF_sef_gs)
gs.fit(X_train_sef, y_train)
rf_sef_gs_prediction = gs.predict(X_test_sef)

# best_params_ is no longer the cell's last expression, so print it.
# An earlier, unseeded run selected
# {'criterion': 'entropy', 'max_depth': 5, 'n_estimators': 50}.
print('Best params:', gs.best_params_)

# scikit-learn metrics take (y_true, y_pred) in that order.
print('Conf matrix')
print(metrics.confusion_matrix(y_test, rf_sef_gs_prediction))
print('Classification report')
print(metrics.classification_report(y_test, rf_sef_gs_prediction))
[[154 34]
[ 21 86]]
Classification report
precision recall f1-score support
0 0.88 0.82 0.85 188
1 0.72 0.80 0.76 107
accuracy 0.81 295
macro avg 0.80 0.81 0.80 295
weighted avg 0.82 0.81 0.82 295