{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "2b6fb168", "metadata": {}, "outputs": [], "source": [ "# Data handling\n", "import numpy as np\n", "import pandas as pd\n", "\n", "# Plotting\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "# scikit-learn: metrics, preprocessing, feature extraction, model selection, models, embeddings\n", "from sklearn import metrics, preprocessing\n", "from sklearn.feature_extraction.text import TfidfVectorizer\n", "from sklearn.model_selection import (\n", "    GridSearchCV,\n", "    StratifiedKFold,\n", "    cross_val_score,\n", "    train_test_split,\n", ")\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.manifold import MDS, TSNE" ] },
{ "cell_type": "code", "execution_count": 2, "id": "75f28a82", "metadata": {}, "outputs": [], "source": [ "# Load the raw Titanic table; later cells reassign and modify `df` step by step\n", "df = pd.read_csv('titanic.csv')" ] },
{ "cell_type": "markdown", "id": "34567fe6", "metadata": {}, "source": [ "# Знакомство с датасетом" ] },
{ "cell_type": "code", "execution_count": 3, "id": "26b735e0", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PassengerIdSurvivedPclassNameSexAgeSibSpParchTicketFareCabinEmbarked
0103Braund, Mr. Owen Harrismale22.010A/5 211717.2500NaNS
1211Cumings, Mrs. John Bradley (Florence Briggs Th...female38.010PC 1759971.2833C85C
2313Heikkinen, Miss. Lainafemale26.000STON/O2. 31012827.9250NaNS
3411Futrelle, Mrs. Jacques Heath (Lily May Peel)female35.01011380353.1000C123S
4503Allen, Mr. William Henrymale35.0003734508.0500NaNS
.......................................
88688702Montvila, Rev. Juozasmale27.00021153613.0000NaNS
88788811Graham, Miss. Margaret Edithfemale19.00011205330.0000B42S
88888903Johnston, Miss. Catherine Helen \"Carrie\"femaleNaN12W./C. 660723.4500NaNS
88989011Behr, Mr. Karl Howellmale26.00011136930.0000C148C
89089103Dooley, Mr. Patrickmale32.0003703767.7500NaNQ
\n", "

891 rows × 12 columns

\n", "
" ], "text/plain": [ " PassengerId Survived Pclass \\\n", "0 1 0 3 \n", "1 2 1 1 \n", "2 3 1 3 \n", "3 4 1 1 \n", "4 5 0 3 \n", ".. ... ... ... \n", "886 887 0 2 \n", "887 888 1 1 \n", "888 889 0 3 \n", "889 890 1 1 \n", "890 891 0 3 \n", "\n", " Name Sex Age SibSp \\\n", "0 Braund, Mr. Owen Harris male 22.0 1 \n", "1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", "2 Heikkinen, Miss. Laina female 26.0 0 \n", "3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", "4 Allen, Mr. William Henry male 35.0 0 \n", ".. ... ... ... ... \n", "886 Montvila, Rev. Juozas male 27.0 0 \n", "887 Graham, Miss. Margaret Edith female 19.0 0 \n", "888 Johnston, Miss. Catherine Helen \"Carrie\" female NaN 1 \n", "889 Behr, Mr. Karl Howell male 26.0 0 \n", "890 Dooley, Mr. Patrick male 32.0 0 \n", "\n", " Parch Ticket Fare Cabin Embarked \n", "0 0 A/5 21171 7.2500 NaN S \n", "1 0 PC 17599 71.2833 C85 C \n", "2 0 STON/O2. 3101282 7.9250 NaN S \n", "3 0 113803 53.1000 C123 S \n", "4 0 373450 8.0500 NaN S \n", ".. ... ... ... ... ... \n", "886 0 211536 13.0000 NaN S \n", "887 0 112053 30.0000 B42 S \n", "888 2 W./C. 
6607 23.4500 NaN S \n", "889 0 111369 30.0000 C148 C \n", "890 0 370376 7.7500 NaN Q \n", "\n", "[891 rows x 12 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Raw frame preview (pandas truncates the middle rows)\n", "df" ] }, { "cell_type": "code", "execution_count": 4, "id": "546f361a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 891 entries, 0 to 890\n", "Data columns (total 12 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 PassengerId 891 non-null int64 \n", " 1 Survived 891 non-null int64 \n", " 2 Pclass 891 non-null int64 \n", " 3 Name 891 non-null object \n", " 4 Sex 891 non-null object \n", " 5 Age 714 non-null float64\n", " 6 SibSp 891 non-null int64 \n", " 7 Parch 891 non-null int64 \n", " 8 Ticket 891 non-null object \n", " 9 Fare 891 non-null float64\n", " 10 Cabin 204 non-null object \n", " 11 Embarked 889 non-null object \n", "dtypes: float64(2), int64(5), object(5)\n", "memory usage: 83.7+ KB\n" ] } ], "source": [ "# Dtypes and non-null counts: Age, Cabin and Embarked contain missing values\n", "df.info()" ] }, { "cell_type": "code", "execution_count": 5, "id": "7136451b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PassengerIdSurvivedPclassAgeSibSpParchFare
count891.000000891.000000891.000000714.000000891.000000891.000000891.000000
mean446.0000000.3838382.30864229.6991180.5230080.38159432.204208
std257.3538420.4865920.83607114.5264971.1027430.80605749.693429
min1.0000000.0000001.0000000.4200000.0000000.0000000.000000
25%223.5000000.0000002.00000020.1250000.0000000.0000007.910400
50%446.0000000.0000003.00000028.0000000.0000000.00000014.454200
75%668.5000001.0000003.00000038.0000001.0000000.00000031.000000
max891.0000001.0000003.00000080.0000008.0000006.000000512.329200
\n", "
" ], "text/plain": [ " PassengerId Survived Pclass Age SibSp \\\n", "count 891.000000 891.000000 891.000000 714.000000 891.000000 \n", "mean 446.000000 0.383838 2.308642 29.699118 0.523008 \n", "std 257.353842 0.486592 0.836071 14.526497 1.102743 \n", "min 1.000000 0.000000 1.000000 0.420000 0.000000 \n", "25% 223.500000 0.000000 2.000000 20.125000 0.000000 \n", "50% 446.000000 0.000000 3.000000 28.000000 0.000000 \n", "75% 668.500000 1.000000 3.000000 38.000000 1.000000 \n", "max 891.000000 1.000000 3.000000 80.000000 8.000000 \n", "\n", " Parch Fare \n", "count 891.000000 891.000000 \n", "mean 0.381594 32.204208 \n", "std 0.806057 49.693429 \n", "min 0.000000 0.000000 \n", "25% 0.000000 7.910400 \n", "50% 0.000000 14.454200 \n", "75% 0.000000 31.000000 \n", "max 6.000000 512.329200 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.describe()" ] },
{ "cell_type": "markdown", "id": "2f55a184", "metadata": {}, "source": [ "# Предварительная обработка" ] },
{ "cell_type": "code", "execution_count": 6, "id": "6991cc16", "metadata": {}, "outputs": [], "source": [ "# Drop identifier / free-text columns that are not used as features\n", "# (Cabin is also mostly missing: 204 of 891 non-null)\n", "df = df.drop(columns=['PassengerId', 'Ticket', 'Cabin', 'Name'])" ] },
{ "cell_type": "code", "execution_count": 7, "id": "7826ffb8", "metadata": {}, "outputs": [], "source": [ "# Encode Sex as a boolean flag: True = male, False = female.\n", "# Map the strings directly instead of writing ints into a string column (that raises with the\n", "# pandas string dtype); the dtype guard keeps the cell safe to re-run once the column is boolean.\n", "if df['Sex'].dtype != bool:\n", "    df['Sex'] = df['Sex'].map({'male': True, 'female': False})\n", "assert df['Sex'].dtype == bool, 'unexpected values in Sex'" ] },
{ "cell_type": "code", "execution_count": 8, "id": "f90a0ba2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SurvivedPclassSexAgeSibSpParchFareEmbarked
003True22.0107.25002
111False38.01071.28330
213False26.0007.92502
311False35.01053.10002
403True35.0008.05002
...........................
88602True27.00013.00002
88711False19.00030.00002
88803FalseNaN1223.45002
88911True26.00030.00000
89003True32.0007.75001
\n", "

891 rows × 8 columns

\n", "
" ], "text/plain": [ " Survived Pclass Sex Age SibSp Parch Fare Embarked\n", "0 0 3 True 22.0 1 0 7.2500 2\n", "1 1 1 False 38.0 1 0 71.2833 0\n", "2 1 3 False 26.0 0 0 7.9250 2\n", "3 1 1 False 35.0 1 0 53.1000 2\n", "4 0 3 True 35.0 0 0 8.0500 2\n", ".. ... ... ... ... ... ... ... ...\n", "886 0 2 True 27.0 0 0 13.0000 2\n", "887 1 1 False 19.0 0 0 30.0000 2\n", "888 0 3 False NaN 1 2 23.4500 2\n", "889 1 1 True 26.0 0 0 30.0000 0\n", "890 0 3 True 32.0 0 0 7.7500 1\n", "\n", "[891 rows x 8 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Encode Embarked: fill the missing ports with the most frequent value, then label-encode.\n", "# Assign the result back: chained `df.Embarked.fillna(..., inplace=True)` is a no-op under pandas Copy-on-Write.\n", "df['Embarked'] = df['Embarked'].fillna(df['Embarked'].mode()[0])\n", "\n", "le = preprocessing.LabelEncoder()\n", "df['Embarked'] = le.fit_transform(df['Embarked'])\n", "df" ] },
{ "cell_type": "code", "execution_count": null, "id": "2aaacde7", "metadata": {}, "outputs": [], "source": [ "# Fill missing ages with the median age (assigned back, not chained inplace=True).\n", "# NOTE: the median is taken over the whole dataset before the train/test split, so it also sees test rows.\n", "df['Age'] = df['Age'].fillna(df['Age'].median())" ] },
{ "cell_type": "code", "execution_count": null, "id": "17786cd0", "metadata": {}, "outputs": [], "source": [ "df.info()" ] },
{ "cell_type": "code", "execution_count": null, "id": "45039730", "metadata": {}, "outputs": [], "source": [ "df['Survived'].value_counts()" ] },
{ "cell_type": "markdown", "id": "8a170235", "metadata": {}, "source": [ "# Визуализация датасета" ] },
{ "cell_type": "code", "execution_count": null, "id": "8abc136d", "metadata": {}, "outputs": [], "source": [ "# 2-D t-SNE with cosine distance. Features are unscaled (see df.describe()), so wide-range\n", "# columns such as Fare likely dominate the geometry.\n", "# The iteration count stays at its default of 1000: the old `n_iter` argument is deprecated\n", "# since scikit-learn 1.5 (renamed `max_iter`) and removed in 1.7.\n", "tsne = TSNE(n_components=2,\n", "            init=\"pca\",\n", "            random_state=0,\n", "            perplexity=50,\n", "            metric='cosine')" ] },
{ "cell_type": "code", "execution_count": null, "id": "b248c498", "metadata": {}, "outputs": [], "source": [ "# Embed every feature column except the Survived target\n", "Y = tsne.fit_transform(df.drop(columns='Survived'))" ] },
{ "cell_type": "code", "execution_count": null, "id": 
"cccd1aa8", "metadata": {}, "outputs": [], "source": [ "# t-SNE embedding, coloured by the Survived label\n", "fig, ax = plt.subplots()\n", "ax.scatter(Y[:, 0], Y[:, 1], c=df['Survived'])\n", "ax.set(title='t-SNE of passenger features (colour = Survived)', xlabel='t-SNE 1', ylabel='t-SNE 2');" ] },
{ "cell_type": "code", "execution_count": null, "id": "5401296d", "metadata": {}, "outputs": [], "source": [ "# Multidimensional scaling (MDS) down to 2 components\n", "mds = MDS(n_components=2,\n", "          random_state=0)" ] },
{ "cell_type": "code", "execution_count": null, "id": "835c6abe", "metadata": {}, "outputs": [], "source": [ "# Embed every feature column except the Survived target\n", "Y_MDS = mds.fit_transform(df.drop(columns='Survived'))" ] },
{ "cell_type": "code", "execution_count": null, "id": "c0306f25", "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "ax.scatter(Y_MDS[:, 0], Y_MDS[:, 1], c=df['Survived'])\n", "ax.set(title='MDS of passenger features (colour = Survived)', xlabel='MDS 1', ylabel='MDS 2');" ] },
{ "cell_type": "markdown", "id": "f272811a", "metadata": {}, "source": [ "## Графики по столбцам" ] },
{ "cell_type": "code", "execution_count": null, "id": "8f42d003", "metadata": {}, "outputs": [], "source": [ "sns.countplot(data=df, x='Survived');" ] },
{ "cell_type": "code", "execution_count": null, "id": "959232fa", "metadata": {}, "outputs": [], "source": [ "sns.countplot(data=df, x='Pclass', hue='Survived');" ] },
{ "cell_type": "code", "execution_count": null, "id": "00bbd94f", "metadata": {}, "outputs": [], "source": [ "# Wider figures for the histograms below (the setting persists for later cells)\n", "sns.set(rc={'figure.figsize':(10,5)})\n", "# Fare is strongly right-skewed (max 512), so only fares below 200 are plotted\n", "sns.histplot(data=df.loc[df['Fare'] < 200], x='Fare', hue='Survived', kde=True);" ] },
{ "cell_type": "code", "execution_count": null, "id": "1ae491dd", "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=df, x='Age', hue='Survived', kde=True);" ] },
{ "cell_type": "code", "execution_count": null, "id": "aafbfaaf", "metadata": {}, "outputs": [], "source": [ "sns.countplot(data=df, x='Sex', hue='Survived');" ] },
{ "cell_type": "code", "execution_count": null, "id": "ddb15368", "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=df, x='SibSp', hue='Survived', kde=True);" ] },
{ "cell_type": "markdown", "id": "3e34f4ec", "metadata": {}, "source": [ "## Корреляционная матрица" ] }, { 
"cell_type": "code", "execution_count": null, "id": "52722264", "metadata": {}, "outputs": [], "source": [ "sns.heatmap(df.corr(numeric_only=True), annot=True, vmin=-1, vmax=1, cmap='bwr');" ] },
{ "cell_type": "markdown", "id": "345f8d73", "metadata": {}, "source": [ "# Классификация" ] },
{ "cell_type": "code", "execution_count": null, "id": "6ce53d2c", "metadata": {}, "outputs": [], "source": [ "# Features are every column except the target; select by name rather than by position\n", "features = df.drop(columns='Survived')\n", "target = df['Survived']\n", "\n", "# stratify keeps the survival rate (~38%) the same in the train and test splits\n", "X_train, X_test, y_train, y_test = train_test_split(\n", "    features, target, test_size=0.33, random_state=42, stratify=target\n", ")" ] },
{ "cell_type": "code", "execution_count": null, "id": "f868d885", "metadata": {}, "outputs": [], "source": [ "# Random forest on all features\n", "RF = RandomForestClassifier(random_state=42)\n", "RF.fit(X_train, y_train)\n", "rf_prediction = RF.predict(X_test)" ] },
{ "cell_type": "code", "execution_count": null, "id": "378e01b0", "metadata": {}, "outputs": [], "source": [ "# sklearn metrics take (y_true, y_pred): confusion-matrix rows are true classes, columns are predictions\n", "print('Conf matrix')\n", "print(metrics.confusion_matrix(y_test, rf_prediction))\n", "print('Classification report')\n", "print(metrics.classification_report(y_test, rf_prediction))" ] },
{ "cell_type": "code", "execution_count": null, "id": "7087f545", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Probability of each class per test row (column order follows RF.classes_); show the first 10 rows\n", "rf_prediction_proba = RF.predict_proba(X_test)\n", "rf_prediction_proba[:10]" ] },
{ "cell_type": "code", "execution_count": null, "id": "34f7f23b", "metadata": {}, "outputs": [], "source": [ "rf_prediction[:10]" ] },
{ "cell_type": "markdown", "id": "a7dba003", "metadata": {}, "source": [ "## Важность признаков" ] },
{ "cell_type": "code", "execution_count": null, "id": "e5652417", "metadata": {}, "outputs": [], "source": [ "# Impurity-based feature importances of the full-feature forest, largest first\n", "fi = (\n", "    pd.DataFrame({'feature': RF.feature_names_in_, 'importance': RF.feature_importances_})\n", "    .sort_values('importance', ascending=False)\n", "    .reset_index(drop=True)\n", ")\n", "fi" ] },
{ "cell_type": "code", "execution_count": null, "id": "6ca8c32a", "metadata": {}, "outputs": [], "source": [ "sns.barplot(data=fi, x='importance', y='feature');" ] },
{ "cell_type": "code", "execution_count": null, "id": "e8fea763", "metadata": {}, "outputs": [], "source": [ "# Reduced feature set: Sex, Age and Fare only (the \"saf\" suffix)\n", "X_train_saf = X_train[['Sex', 'Age', 
'Fare']]\n", "X_test_saf = X_test[['Sex', 'Age', 'Fare']]\n", "RF_saf = RandomForestClassifier(random_state=42)\n", "RF_saf.fit(X_train_saf, y_train)\n", "rf_saf_prediction = RF_saf.predict(X_test_saf)" ] },
{ "cell_type": "code", "execution_count": null, "id": "7ba5e827", "metadata": {}, "outputs": [], "source": [ "# Metrics take (y_true, y_pred). The matrix must be printed explicitly: it is not the cell's\n", "# last expression, so a bare call would be computed and silently discarded.\n", "print('Conf matrix')\n", "print(metrics.confusion_matrix(y_test, rf_saf_prediction))\n", "print('Classification report')\n", "print(metrics.classification_report(y_test, rf_saf_prediction))" ] },
{ "cell_type": "markdown", "id": "81cffad2", "metadata": {}, "source": [ "## XGBoost" ] },
{ "cell_type": "code", "execution_count": null, "id": "4dae05fa", "metadata": {}, "outputs": [], "source": [ "from xgboost import XGBClassifier" ] },
{ "cell_type": "code", "execution_count": null, "id": "824c59ca", "metadata": {}, "outputs": [], "source": [ "# Gradient boosting on all features\n", "XGB = XGBClassifier(random_state=42)\n", "XGB.fit(X_train, y_train)\n", "xgb_prediction = XGB.predict(X_test)\n", "print('Conf matrix')\n", "print(metrics.confusion_matrix(y_test, xgb_prediction))\n", "print('Classification report')\n", "print(metrics.classification_report(y_test, xgb_prediction))" ] },
{ "cell_type": "code", "execution_count": null, "id": "7b73d67b", "metadata": {}, "outputs": [], "source": [ "# Gradient boosting on the reduced Sex/Age/Fare feature set\n", "XGB_saf = XGBClassifier(random_state=42)\n", "XGB_saf.fit(X_train_saf, y_train)\n", "xgb_saf_prediction = XGB_saf.predict(X_test_saf)\n", "print('Conf matrix')\n", "print(metrics.confusion_matrix(y_test, xgb_saf_prediction))\n", "print('Classification report')\n", "print(metrics.classification_report(y_test, xgb_saf_prediction))" ] },
{ "cell_type": "code", "execution_count": null, "id": "3964f444", "metadata": {}, "outputs": [], "source": [ "# Rows = true label, columns = predicted label; fmt='d' prints raw integer counts\n", "fig, ax = plt.subplots()\n", "sns.heatmap(metrics.confusion_matrix(y_test, xgb_saf_prediction), annot=True, fmt='d', cmap='RdYlGn', ax=ax)\n", "ax.set(title='XGBoost (Sex, Age, Fare): confusion matrix', xlabel='Predicted label', ylabel='True label');" ] },
{ "cell_type": "markdown", "id": "837596c9", "metadata": {}, "source": [ "## Попробуем настроить параметры для RF" ] }, { "cell_type": "code", "execution_count": 
null, "id": "8ea4088c", "metadata": {}, "outputs": [], "source": [ "# Grid search over forest size, depth and split criterion on the Sex/Age/Fare features.\n", "# 5-fold StratifiedKFold without shuffling is exactly what GridSearchCV picks by default for a\n", "# classifier; it is spelled out here so the CV scheme is visible. Scoring is the estimator's accuracy.\n", "params = {\n", "    'n_estimators': [1, 10, 50, 100],\n", "    'max_depth': [1, 5, 10],\n", "    'criterion': ['gini', 'entropy'],\n", "}\n", "\n", "RF_saf_gs = RandomForestClassifier(random_state=42)\n", "gs = GridSearchCV(estimator=RF_saf_gs, param_grid=params, cv=StratifiedKFold(n_splits=5))\n", "gs.fit(X_train_saf, y_train)\n", "rf_saf_gs_prediction = gs.predict(X_test_saf)" ] },
{ "cell_type": "code", "execution_count": null, "id": "1eaa5d33", "metadata": {}, "outputs": [], "source": [ "print(f'Best mean CV accuracy: {gs.best_score_:.3f}')\n", "gs.best_params_" ] },
{ "cell_type": "code", "execution_count": null, "id": "bb7206ef", "metadata": {}, "outputs": [], "source": [ "# sklearn metrics take (y_true, y_pred)\n", "print('Conf matrix')\n", "print(metrics.confusion_matrix(y_test, rf_saf_gs_prediction))\n", "print('Classification report')\n", "print(metrics.classification_report(y_test, rf_saf_gs_prediction))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 5 }