{ "cells": [
{ "cell_type": "markdown", "id": "5c39c249", "metadata": {}, "source": [ "# Исследование и настройка предсказательной модели для цен подержанных автомобилей" ] },
{ "cell_type": "markdown", "id": "f8ee2da9", "metadata": {}, "source": [ "Блокнот использует файл аугментированных данных датасета о подержанных автомобилях, создаваемый блокнотом `eda/cars_eda.py`. См. ниже параметры блокнота для papermill." ] },
{ "cell_type": "code", "execution_count": 1, "id": "030077f5-b1e3-4b5a-9e2f-4dc83a7bfa1e", "metadata": {}, "outputs": [], "source": [ "#XXX: split this notebook into ~5 smaller notebooks" ] },
{ "cell_type": "code", "execution_count": 2, "id": "2a9483a4", "metadata": {}, "outputs": [], "source": [ "from typing import Optional" ] },
{ "cell_type": "code", "execution_count": 3, "id": "3d7aae3e", "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [
 "data_aug_pickle_path: Optional[str] = None\n",
 "# Full path to the input (pickle) file with the augmented dataset that this notebook loads. If not set, `data_aug_pickle_relpath` under `data/` is used.\n",
 "data_aug_pickle_relpath: str = 'cars.aug.pickle'\n",
 "# Path to the input (pickle) file with the augmented dataset, relative to the data directory `data`. Ignored if `data_aug_pickle_path` is set.\n",
 "\n",
 "#model_global_comment_path: Optional[str] = None\n",
 "## Full path to a text file with an arbitrary comment, to be stored in MLFlow as an artifact together with the model. If not set, `research/` is used.\n",
 "#model_comment_relpath: str = 'comment.txt'\n",
 "## Path to the text file with an arbitrary comment (stored in MLFlow as an artifact together with the model), relative to the `research` directory. Ignored if `model_global_comment_path` is set.\n",
 "\n",
 "mlflow_tracking_server_uri: str = 'http://localhost:5000'\n",
 "# URL of the MLFlow tracking server.\n",
 "mlflow_registry_uri: Optional[str] = None\n",
 "# URL of the MLFlow registry server (if not set, `mlflow_tracking_server_uri` is used).\n",
 "\n",
 "mlflow_do_log: bool = False\n",
 "# Whether to record runs in MLFlow.\n",
 "mlflow_experiment_id: Optional[str] = None\n",
 "# MLFlow experiment ID; takes priority over `mlflow_experiment_name`.\n",
 "mlflow_experiment_name: Optional[str] = 'Current price predicion for used cars'\n",
 "# MLFlow experiment name (lower priority than `mlflow_experiment_id`).\n",
 "# NOTE: 'predicion' is a typo, but renaming it would send runs to a different MLFlow experiment.\n",
 "mlflow_root_run_name: str = 'Models'\n",
 "# Name of the root MLFlow run (the notebook creates all other runs inside it, as nested runs)."
] },
{ "cell_type": "code", "execution_count": 4, "id": "7afe82f3", "metadata": {}, "outputs": [], "source": [ "from collections.abc import Collection, Sequence\n", "import os\n", "import pathlib\n", "import pickle\n", "import sys" ] },
{ "cell_type": "code", "execution_count": 5, "id": "a02f69a7", "metadata": {}, "outputs": [], "source": [ "import matplotlib\n", "import mlflow\n", "import mlflow.models\n", "import mlflow.sklearn\n", "import mlxtend.feature_selection\n", "import mlxtend.plotting\n", "import optuna\n", "import optuna.samplers\n", "import sklearn.compose\n", "import sklearn.ensemble\n", "import sklearn.metrics\n", "import sklearn.model_selection\n", "import sklearn.pipeline\n", "import sklearn.preprocessing" ] },
{ "cell_type": "code", "execution_count": 6, "id": "f9a47ec5", "metadata": {}, "outputs": [], "source": [ "BASE_PATH = pathlib.Path('..')" ] },
{ "cell_type": "code", "execution_count": 7, "id": "66d6fe3d", "metadata": {}, "outputs": [], "source": [
 "CODE_PATH = BASE_PATH\n",
 "# put the parent directory (`..`) first on sys.path so the local `iis_project` package can be imported\n",
 "sys.path.insert(0, str(CODE_PATH.resolve()))"
] },
{ "cell_type": "code", "execution_count": 8, "id": "1c227e7d", "metadata": {}, "outputs": [], "source": [ "from 
iis_project.mlxtend_utils.feature_selection import SEQUENTIAL_FEATURE_SELECTOR_PARAMS_COMMON_INCLUDE\n", "from iis_project.sklearn_utils import filter_params\n", "from iis_project.sklearn_utils.compose import COLUMN_TRANSFORMER_PARAMS_COMMON_INCLUDE\n", "from iis_project.sklearn_utils.ensemble import RANDOM_FOREST_REGRESSOR_PARAMS_COMMON_EXCLUDE\n", "from iis_project.sklearn_utils.pandas import pandas_dataframe_from_transformed_artifacts\n", "from iis_project.sklearn_utils.preprocessing import STANDARD_SCALER_PARAMS_COMMON_EXCLUDE" ] },
{ "cell_type": "code", "execution_count": 9, "id": "0b847527", "metadata": {}, "outputs": [], "source": [ "MODEL_INOUT_EXAMPLE_SIZE = 0x10" ] },
{ "cell_type": "code", "execution_count": 10, "id": "2a3a7a2e", "metadata": {}, "outputs": [], "source": [
 "mlflow.set_tracking_uri(mlflow_tracking_server_uri)\n",
 "if mlflow_registry_uri is not None:\n",
 "    mlflow.set_registry_uri(mlflow_registry_uri)"
] },
{ "cell_type": "code", "execution_count": 11, "id": "4f60bfaa", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025/11/02 01:54:17 INFO mlflow.tracking.fluent: Experiment with name 'Current price predicion for used cars' does not exist. Creating a new experiment.\n" ] } ], "source": [
 "if mlflow_do_log:\n",
 "    # `mlflow.set_experiment` raises when both `experiment_id` and `experiment_name` are given,\n",
 "    # so pass only the ID when it is set (the ID takes priority over the name).\n",
 "    if mlflow_experiment_id is not None:\n",
 "        mlflow_experiment = mlflow.set_experiment(experiment_id=mlflow_experiment_id)\n",
 "    else:\n",
 "        mlflow_experiment = mlflow.set_experiment(experiment_name=mlflow_experiment_name)\n",
 "    mlflow_root_run_id = None  # assigned by `mlflow_log_model` when it first creates the root run"
] },
{ "cell_type": "code", "execution_count": 12, "id": "97d23eb9", "metadata": {}, "outputs": [], "source": [
 "# directory of the input dataset: the parent of `data_aug_pickle_path` if it is set, otherwise `../data`\n",
 "DATA_PATH = (\n",
 "    pathlib.Path(os.path.dirname(data_aug_pickle_path))\n",
 "    if data_aug_pickle_path is not None\n",
 "    else (BASE_PATH / 'data')\n",
 ")"
] },
{ "cell_type": "code", "execution_count": 13, "id": "493e9bd3-463c-41f6-b32f-f08bdf4a6323", "metadata": {}, "outputs": [], "source": [
 "def build_sequential_feature_selector(*args, **kwargs):\n",
 "    \"\"\"Build an mlxtend `SequentialFeatureSelector`, forwarding all arguments unchanged.\"\"\"\n",
 "    return mlxtend.feature_selection.SequentialFeatureSelector(*args, **kwargs)\n",
 "\n",
 "def plot_sequential_feature_selection(feature_selector, *args_rest, **kwargs):\n",
 "    \"\"\"Plot `feature_selector.get_metric_dict()` with mlxtend, forwarding the remaining arguments.\"\"\"\n",
 "    metric_dict = feature_selector.get_metric_dict()\n",
 "    return mlxtend.plotting.plot_sequential_feature_selection(metric_dict, *args_rest, **kwargs)"
] },
{ "cell_type": "markdown", "id": "4b20cbda", "metadata": {}, "source": [ "## Загрузка и обзор данных" ] },
{ "cell_type": "code", "execution_count": 14, "id": "e2b45fd1", "metadata": {}, "outputs": [], "source": [
 "# NOTE: `pickle.load` can execute arbitrary code; only load trusted files (e.g. the EDA notebook output).\n",
 "with open(\n",
 "    (\n",
 "        data_aug_pickle_path\n",
 "        if data_aug_pickle_path is not None\n",
 "        else (DATA_PATH / data_aug_pickle_relpath)\n",
 "    ),\n",
 "    'rb',\n",
 ") as input_file:\n",
 "    df_orig = pickle.load(input_file)"
] },
{ "cell_type": "markdown", "id": "c3ef97d1", "metadata": {}, "source": [ "Обзор датасета:" ] },
{ "cell_type": "code", "execution_count": 15, "id": "d45da024", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "299" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(df_orig)" ] },
{ "cell_type": "code", "execution_count": 16, "id": "75b0feea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Index: 299 entries, 0 to 300\n", "Data columns (total 15 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 car_name 299 non-null object \n", " 1 year 299 non-null int64 \n", " 2 selling_price 299 non-null float64 \n", " 3 present_price 299 non-null float64 \n", " 4 driven_kms 299 non-null int64 \n", " 5 fuel_type 299 non-null category\n", " 6 selling_type 299 non-null category\n", " 7 transmission 299 non-null category\n", " 8 owner 299 non-null category\n", " 9 age 299 non-null float64 \n", " 10 present_price_ratio 299 non-null float64 \n", " 11 log_selling_price 299 non-null float64 \n", " 12 log_present_price 299 non-null float64 \n", " 13 log_driven_kms 299 non-null float64 \n", " 14 log_age 299 non-null float64 \n", "dtypes: category(4), float64(8), int64(2), object(1)\n", "memory usage: 29.3+ KB\n" ] } ], "source": [ "df_orig.info()" ] },
{ "cell_type": "code", "execution_count": 17, "id": "5b336654", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
car_nameyearselling_pricepresent_pricedriven_kmsfuel_typeselling_typetransmissionowneragepresent_price_ratiolog_selling_pricelog_present_pricelog_driven_kmslog_age
0ritz20145.593.3527000petroldealermanual05.00.5992840.7474120.5250454.4313640.698970
1sx420139.544.7543000dieseldealermanual06.00.4979040.9795480.6766944.6334680.778151
2ciaz20179.857.256900petroldealermanual02.00.7360410.9934360.8603383.8388490.301030
3wagon r20114.152.855200petroldealermanual08.00.6867470.6180480.4548453.7160030.903090
4swift20146.874.6042450dieseldealermanual05.00.6695780.8369570.6627584.6278780.698970
5vitara brezza20189.839.252071dieseldealermanual01.00.9409970.9925540.9661423.3161800.000000
6ciaz20158.126.7518796petroldealermanual04.00.8312810.9095560.8293044.2740650.602060
7s cross20158.616.5033429dieseldealermanual04.00.7549360.9350030.8129134.5241230.602060
8ciaz20168.898.7520273dieseldealermanual03.00.9842520.9489020.9420084.3069180.477121
9ciaz20158.927.4542367dieseldealermanual04.00.8352020.9503650.8721564.6270280.602060
10alto 80020173.602.852135petroldealermanual02.00.7916670.5563030.4548453.3293980.301030
11ciaz201510.386.8551000dieseldealermanual04.00.6599231.0161970.8356914.7075700.602060
12ciaz20159.947.5015000petroldealerautomatic04.00.7545270.9973860.8750614.1760910.602060
13ertiga20157.716.1026000petroldealermanual04.00.7911800.8870540.7853304.4149730.602060
14dzire20097.212.2577427petroldealermanual010.00.3120670.8579350.3521834.8888921.000000
15ertiga201610.797.7543000dieseldealermanual03.00.7182581.0330210.8893024.6334680.477121
\n", "
" ], "text/plain": [ " car_name year selling_price present_price driven_kms fuel_type \\\n", "0 ritz 2014 5.59 3.35 27000 petrol \n", "1 sx4 2013 9.54 4.75 43000 diesel \n", "2 ciaz 2017 9.85 7.25 6900 petrol \n", "3 wagon r 2011 4.15 2.85 5200 petrol \n", "4 swift 2014 6.87 4.60 42450 diesel \n", "5 vitara brezza 2018 9.83 9.25 2071 diesel \n", "6 ciaz 2015 8.12 6.75 18796 petrol \n", "7 s cross 2015 8.61 6.50 33429 diesel \n", "8 ciaz 2016 8.89 8.75 20273 diesel \n", "9 ciaz 2015 8.92 7.45 42367 diesel \n", "10 alto 800 2017 3.60 2.85 2135 petrol \n", "11 ciaz 2015 10.38 6.85 51000 diesel \n", "12 ciaz 2015 9.94 7.50 15000 petrol \n", "13 ertiga 2015 7.71 6.10 26000 petrol \n", "14 dzire 2009 7.21 2.25 77427 petrol \n", "15 ertiga 2016 10.79 7.75 43000 diesel \n", "\n", " selling_type transmission owner age present_price_ratio \\\n", "0 dealer manual 0 5.0 0.599284 \n", "1 dealer manual 0 6.0 0.497904 \n", "2 dealer manual 0 2.0 0.736041 \n", "3 dealer manual 0 8.0 0.686747 \n", "4 dealer manual 0 5.0 0.669578 \n", "5 dealer manual 0 1.0 0.940997 \n", "6 dealer manual 0 4.0 0.831281 \n", "7 dealer manual 0 4.0 0.754936 \n", "8 dealer manual 0 3.0 0.984252 \n", "9 dealer manual 0 4.0 0.835202 \n", "10 dealer manual 0 2.0 0.791667 \n", "11 dealer manual 0 4.0 0.659923 \n", "12 dealer automatic 0 4.0 0.754527 \n", "13 dealer manual 0 4.0 0.791180 \n", "14 dealer manual 0 10.0 0.312067 \n", "15 dealer manual 0 3.0 0.718258 \n", "\n", " log_selling_price log_present_price log_driven_kms log_age \n", "0 0.747412 0.525045 4.431364 0.698970 \n", "1 0.979548 0.676694 4.633468 0.778151 \n", "2 0.993436 0.860338 3.838849 0.301030 \n", "3 0.618048 0.454845 3.716003 0.903090 \n", "4 0.836957 0.662758 4.627878 0.698970 \n", "5 0.992554 0.966142 3.316180 0.000000 \n", "6 0.909556 0.829304 4.274065 0.602060 \n", "7 0.935003 0.812913 4.524123 0.602060 \n", "8 0.948902 0.942008 4.306918 0.477121 \n", "9 0.950365 0.872156 4.627028 0.602060 \n", "10 0.556303 0.454845 3.329398 
0.301030 \n", "11 1.016197 0.835691 4.707570 0.602060 \n", "12 0.997386 0.875061 4.176091 0.602060 \n", "13 0.887054 0.785330 4.414973 0.602060 \n", "14 0.857935 0.352183 4.888892 1.000000 \n", "15 1.033021 0.889302 4.633468 0.477121 " ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_orig.head(16)" ] },
{ "cell_type": "markdown", "id": "e39f88d0", "metadata": {}, "source": [ "## Разделение датасета на выборки" ] },
{ "cell_type": "markdown", "id": "df5b723b", "metadata": {}, "source": [ "Выделение признаков и целевых переменных:" ] },
{ "cell_type": "code", "execution_count": 18, "id": "7a24a133", "metadata": {}, "outputs": [], "source": [
 "# model inputs ('owner' is currently excluded)\n",
 "feature_columns = (\n",
 "    'selling_price',\n",
 "    'driven_kms',\n",
 "    'fuel_type',\n",
 "    'selling_type',\n",
 "    'transmission',\n",
 "    #'owner',\n",
 "    'age',\n",
 ")\n",
 "\n",
 "# value to predict\n",
 "target_columns = ('present_price',)"
] },
{ "cell_type": "code", "execution_count": 19, "id": "f527d556", "metadata": {}, "outputs": [], "source": [
 "features_to_scale_to_standard_columns = (\n",
 "    'selling_price',\n",
 "    'driven_kms',\n",
 "    'age',\n",
 ")\n",
 "# every column to be standardized must be numeric\n",
 "assert set(features_to_scale_to_standard_columns) <= set(df_orig.select_dtypes(('number',)).columns)\n",
 "\n",
 "features_to_encode_wrt_target_columns = (\n",
 "    'fuel_type',\n",
 "    'selling_type',\n",
 "    'transmission',\n",
 "    #'owner',\n",
 ")\n",
 "# every column to be target-encoded must be categorical (or object)\n",
 "assert set(features_to_encode_wrt_target_columns) <= set(df_orig.select_dtypes(('category', 'object')).columns)"
] },
{ "cell_type": "code", "execution_count": 20, "id": "8ce8c469", "metadata": {}, "outputs": [], "source": [
 "df_orig_features = df_orig.loc[:, list(feature_columns)]\n",
 "df_target = df_orig.loc[:, list(target_columns)]"
] },
{ "cell_type": "markdown", "id": "c82f9d7a", "metadata": {}, "source": [ "Разделение на обучающую и тестовую выборки:" ] },
{ "cell_type": "code", "execution_count": 21, "id": "c9ba918e", "metadata": {}, "outputs": [], "source": [
 "# share of rows held out for the test set\n",
 "DF_TEST_PORTION = 0.25"
] },
{ "cell_type": "code", "execution_count": 22, "id": "0147b1d6", "metadata": {}, "outputs": [], "source": [
 "(\n",
 "    df_orig_features_train, df_orig_features_test,\n",
 "    df_target_train, df_target_test,\n",
 ") = sklearn.model_selection.train_test_split(\n",
 "    df_orig_features,\n",
 "    df_target,\n",
 "    test_size=DF_TEST_PORTION,\n",
 "    random_state=0x7AE6,\n",
 ")"
] },
{ "cell_type": "markdown", "id": "2f2e9fad", "metadata": {}, "source": [ "Размеры обучающей и тестовой выборок соответственно:" ] },
{ "cell_type": "code", "execution_count": 23, "id": "dc58ff10", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(224, 75)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(df_target_train), len(df_target_test)" ] },
{ "cell_type": "markdown", "id": "d9ddbdc7", "metadata": {}, "source": [ "## Модели" ] },
{ "cell_type": "code", "execution_count": 24, "id": "97a58917", "metadata": {}, "outputs": [], "source": [
 "# XXX: a single requirements file is shared by all models\n",
 "MODEL_PIP_REQUIREMENTS_PATH = BASE_PATH.joinpath('requirements', 'requirements-isolated-research-model.txt')"
] },
{ "cell_type": "markdown", "id": "4639cc98", "metadata": {}, "source": [ "Сигнатура модели для MLFlow:" ] },
{ "cell_type": "code", "execution_count": 25, "id": "a78986be", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "D:\\studying\\university\\projects\\sem_03_iis\\mpei-iis-project\\.venv\\Lib\\site-packages\\mlflow\\types\\utils.py:452: UserWarning: Hint: Inferred schema contains integer column(s). Integer columns in Python cannot represent missing values. If your input data contains missing values at inference time, it will be encoded as floats and will cause a schema enforcement error. The best way to avoid this problem is to infer the model schema based on a realistic data sample (training dataset) that includes missing values. 
Alternatively, you can declare integer columns as doubles (float64) whenever these columns may have missing values. See `Handling Integers With Missing Values `_ for more details.\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "inputs: \n", " ['selling_price': double (required), 'driven_kms': long (required), 'fuel_type': string (required), 'selling_type': string (required), 'transmission': string (required), 'age': double (required)]\n", "outputs: \n", " ['present_price': double (required)]\n", "params: \n", " None" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# infer the MLFlow model signature from the full feature and target frames\n",
 "mlflow_model_signature = mlflow.models.infer_signature(model_input=df_orig_features, model_output=df_target)\n",
 "mlflow_model_signature"
] },
{ "cell_type": "raw", "id": "691d63bf", "metadata": { "vscode": { "languageId": "raw" } }, "source": [
 "input_schema = mlflow.types.schema.Schema([\n",
 "    mlflow.types.schema.ColSpec(\"double\", \"selling_price\"),\n",
 "    mlflow.types.schema.ColSpec(\"double\", \"driven_kms\"),\n",
 "    mlflow.types.schema.ColSpec(\"string\", \"fuel_type\"),\n",
 "    mlflow.types.schema.ColSpec(\"string\", \"selling_type\"),\n",
 "    mlflow.types.schema.ColSpec(\"string\", \"transmission\"),\n",
 "    mlflow.types.schema.ColSpec(\"double\", \"age\"),\n",
 "])\n",
 "\n",
 "output_schema = mlflow.types.schema.Schema([\n",
 "    mlflow.types.schema.ColSpec(\"double\", \"present_price\"),\n",
 "])\n",
 "\n",
 "mlflow_model_signature = mlflow.models.ModelSignature(inputs=input_schema, outputs=output_schema)"
] },
{ "cell_type": "code", "execution_count": 26, "id": "e014b988", "metadata": {}, "outputs": [], "source": [
 "def build_features_scaler_standard():\n",
 "    \"\"\"Build a fresh `StandardScaler` for the numeric features.\"\"\"\n",
 "    return sklearn.preprocessing.StandardScaler()"
] },
{ "cell_type": "code", "execution_count": 27, "id": "4e513ece", "metadata": {}, "outputs": [], "source": [
 "#def build_categorical_features_encoder_onehot():\n",
 "#    return sklearn.preprocessing.OneHotEncoder()\n",
 "\n",
 "def build_categorical_features_encoder_target(*, random_state=None):\n",
 "    \"\"\"Build a `TargetEncoder` for a continuous target; its shuffled CV folds are seeded by `random_state`.\"\"\"\n",
 "    return sklearn.preprocessing.TargetEncoder(\n",
 "        target_type='continuous', smooth='auto', shuffle=True, random_state=random_state,\n",
 "    )"
] },
{ "cell_type": "markdown", "id": "814626b7", "metadata": {}, "source": [ "Регрессор — небольшой случайный лес, цель — минимизация квадрата ошибки предсказания:" ] },
{ "cell_type": "code", "execution_count": 28, "id": "e46bedcf", "metadata": {}, "outputs": [], "source": [
 "def build_regressor(n_estimators, *, max_depth=None, max_features='sqrt', random_state=None):\n",
 "    \"\"\"Build a `RandomForestRegressor` with `n_estimators` trees that minimizes squared error.\"\"\"\n",
 "    return sklearn.ensemble.RandomForestRegressor(\n",
 "        n_estimators, criterion='squared_error',\n",
 "        max_depth=max_depth, max_features=max_features,\n",
 "        random_state=random_state,\n",
 "    )\n",
 "\n",
 "def build_regressor_baseline(*, random_state=None):\n",
 "    \"\"\"Build the small baseline forest: 10 trees, depth at most 8, sqrt(n_features) candidates per split.\"\"\"\n",
 "    # forward `random_state` so that the baseline forest is actually seeded (reproducible)\n",
 "    return build_regressor(10, max_depth=8, max_features='sqrt', random_state=random_state)"
] },
{ "cell_type": "code", "execution_count": 29, "id": "e3d8b2f0-e0cd-4fcf-9dd2-bd01f903b9ad", "metadata": {}, "outputs": [], "source": [
 "def score_predictions(target_test, target_test_predicted):\n",
 "    \"\"\"Return a dict with the MSE, MAE and MAPE of `target_test_predicted` against `target_test`.\"\"\"\n",
 "    return {\n",
 "        'mse': sklearn.metrics.mean_squared_error(target_test, target_test_predicted),\n",
 "        'mae': sklearn.metrics.mean_absolute_error(target_test, target_test_predicted),\n",
 "        'mape': sklearn.metrics.mean_absolute_percentage_error(target_test, target_test_predicted),\n",
 "    }"
] },
{ "cell_type": "code", "execution_count": 30, "id": "b62aca3d-d1c6-4075-aded-d4017bdc2129", "metadata": {}, "outputs": [], "source": [
 "# uses the globals mlflow_do_log, mlflow_experiment, mlflow_root_run_name, and reads/assigns mlflow_root_run_id\n",
 "def mlflow_log_model(\n",
 "    model,\n",
 "    model_params,\n",
 "    metrics,\n",
 "    *,\n",
 "    nested_run_name,\n",
 "    model_signature=None,\n",
 "    input_example=None,\n",
 "    pip_requirements=None,\n",
 "    #global_comment_file_path=None,\n",
 "    extra_logs_handler=None,\n",
 "):\n",
 "    \"\"\"Log `model` (plus optional params and metrics) to MLFlow as a nested run under the root run.\n",
 "\n",
 "    Does nothing unless `mlflow_do_log` is set. The root run is created on the first call\n",
 "    (named `mlflow_root_run_name`, its ID stored in the global `mlflow_root_run_id`) and\n",
 "    reopened by ID on later calls.\n",
 "\n",
 "    `pip_requirements` may be a path (converted to `str`) or anything `mlflow.sklearn.log_model`\n",
 "    accepts. `extra_logs_handler` is a callable or a collection of callables; each one is\n",
 "    called with the `mlflow` module while the nested run is active.\n",
 "\n",
 "    Raises RuntimeError if the root run is not in the RUNNING state after (re)starting it.\n",
 "    \"\"\"\n",
 "    global mlflow_root_run_id\n",
 "    if not mlflow_do_log:\n",
 "        return\n",
 "    experiment_id = mlflow_experiment.experiment_id\n",
 "    start_run_root_kwargs_extra = {}\n",
 "    if mlflow_root_run_id is not None:\n",
 "        start_run_root_kwargs_extra['run_id'] = mlflow_root_run_id\n",
 "    else:\n",
 "        start_run_root_kwargs_extra['run_name'] = mlflow_root_run_name\n",
 "    with mlflow.start_run(experiment_id=experiment_id, **start_run_root_kwargs_extra) as root_run:\n",
 "        if root_run.info.status not in ('RUNNING',):\n",
 "            raise RuntimeError('Cannot get the root run to run')\n",
 "        if mlflow_root_run_id is None:\n",
 "            mlflow_root_run_id = root_run.info.run_id\n",
 "        # it is important to use nested=True together with parent_run_id=...:\n",
 "        with mlflow.start_run(experiment_id=experiment_id, run_name=nested_run_name, nested=True, parent_run_id=mlflow_root_run_id):\n",
 "            if isinstance(pip_requirements, pathlib.PurePath):\n",
 "                pip_requirements = str(pip_requirements)\n",
 "            _ = mlflow.sklearn.log_model(\n",
 "                model,\n",
 "                'model',\n",
 "                signature=model_signature,\n",
 "                input_example=input_example,\n",
 "                pip_requirements=pip_requirements,\n",
 "            )\n",
 "            if model_params is not None:\n",
 "                _ = mlflow.log_params(model_params)\n",
 "            if metrics is not None:\n",
 "                _ = mlflow.log_metrics(metrics)\n",
 "            #if (global_comment_file_path is not None) and global_comment_file_path.exists():\n",
 "            #    mlflow.log_artifact(str(global_comment_file_path))\n",
 "            if extra_logs_handler is not None:\n",
 "                # a single callable is wrapped so that it can be iterated like a collection\n",
 "                if callable(extra_logs_handler) and (not isinstance(extra_logs_handler, Collection)):\n",
 "                    extra_logs_handler = (extra_logs_handler,)\n",
 "                for extr_logs_handler_fn in extra_logs_handler:\n",
 "                    extr_logs_handler_fn(mlflow)"
] },
{ "cell_type": "markdown", "id": "9271ef07", "metadata": {}, "source": [ "### Baseline модель" ] },
{ "cell_type": "markdown", "id": "80a5e4c5", "metadata": {}, "source": [ "Пайплайн предобработки признаков:" ] },
{ "cell_type": "code", "execution_count": 31, "id": "869bae01", "metadata": {}, "outputs": [], "source": [
 "preprocess_transformer = sklearn.compose.ColumnTransformer(\n",
 "    [\n",
 "        ('scale_to_standard', build_features_scaler_standard(), features_to_scale_to_standard_columns),\n",
 "        (\n",
 "            #'encode_categoricals_one_hot',\n",
 "            'encode_categoricals_wrt_target',\n",
 "            #build_categorical_features_encoder_onehot(),\n",
 "            build_categorical_features_encoder_target(random_state=0x2ED6),\n",
 "            features_to_encode_wrt_target_columns,\n",
 "        ),\n",
 "    ],\n",
 "    remainder='drop',\n",
 ")"
] },
{ "cell_type": "code", "execution_count": 32, "id": "8959cb29", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RandomForestRegressor(max_depth=8, max_features='sqrt', n_estimators=10)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "RandomForestRegressor(max_depth=8, max_features='sqrt', n_estimators=10)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "regressor = build_regressor_baseline(random_state=0x016B)\n", "regressor" ] }, { "cell_type": "markdown", "id": "cb0f1a67", "metadata": {}, "source": [ "Составной пайплайн:" ] }, { "cell_type": "code", "execution_count": 33, "id": "2ef69753", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('preprocess',\n",
       "                 ColumnTransformer(transformers=[('scale_to_standard',\n",
       "                                                  StandardScaler(),\n",
       "                                                  ('selling_price',\n",
       "                                                   'driven_kms', 'age')),\n",
       "                                                 ('encode_categoricals_wrt_target',\n",
       "                                                  TargetEncoder(random_state=11990,\n",
       "                                                                target_type='continuous'),\n",
       "                                                  ('fuel_type', 'selling_type',\n",
       "                                                   'transmission'))])),\n",
       "                ('regress',\n",
       "                 RandomForestRegressor(max_depth=8, max_features='sqrt',\n",
       "                                       n_estimators=10))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('preprocess',\n", " ColumnTransformer(transformers=[('scale_to_standard',\n", " StandardScaler(),\n", " ('selling_price',\n", " 'driven_kms', 'age')),\n", " ('encode_categoricals_wrt_target',\n", " TargetEncoder(random_state=11990,\n", " target_type='continuous'),\n", " ('fuel_type', 'selling_type',\n", " 'transmission'))])),\n", " ('regress',\n", " RandomForestRegressor(max_depth=8, max_features='sqrt',\n", " n_estimators=10))])" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline = sklearn.pipeline.Pipeline([\n", " ('preprocess', preprocess_transformer),\n", " ('regress', regressor),\n", "])\n", "pipeline" ] }, { "cell_type": "code", "execution_count": 34, "id": "a38b50f8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'preprocess__remainder': 'drop',\n", " 'preprocess__sparse_threshold': 0.3,\n", " 'preprocess__transformer_weights': None,\n", " 'preprocess__scale_to_standard__with_mean': True,\n", " 'preprocess__scale_to_standard__with_std': True,\n", " 'regress__bootstrap': True,\n", " 'regress__ccp_alpha': 0.0,\n", " 'regress__criterion': 'squared_error',\n", " 'regress__max_depth': 8,\n", " 'regress__max_features': 'sqrt',\n", " 'regress__max_leaf_nodes': None,\n", " 'regress__max_samples': None,\n", " 'regress__min_impurity_decrease': 0.0,\n", " 'regress__min_samples_leaf': 1,\n", " 'regress__min_samples_split': 2,\n", " 'regress__min_weight_fraction_leaf': 0.0,\n", " 'regress__monotonic_cst': None,\n", " 'regress__n_estimators': 10,\n", " 'regress__oob_score': False,\n", " 'regress__random_state': None}" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_params = filter_params(\n", " pipeline.get_params(),\n", " include={\n", " 'preprocess': (\n", " False,\n", " {\n", " **{k: True for k in COLUMN_TRANSFORMER_PARAMS_COMMON_INCLUDE},\n", " 'scale_to_standard': True,\n", " 'encode_categorical_wrt_target': 
True,\n", " },\n", " ),\n", " 'regress': (False, True),\n", " },\n", " exclude={\n", " 'preprocess': {'scale_to_standard': STANDARD_SCALER_PARAMS_COMMON_EXCLUDE},\n", " 'regress': RANDOM_FOREST_REGRESSOR_PARAMS_COMMON_EXCLUDE,\n", " },\n", ")\n", "model_params" ] }, { "cell_type": "markdown", "id": "4064c359", "metadata": {}, "source": [ "Обучение модели:" ] }, { "cell_type": "code", "execution_count": 35, "id": "9639f2f4", "metadata": {}, "outputs": [], "source": [ "_ = pipeline.fit(df_orig_features_train, df_target_train.iloc[:, 0])" ] }, { "cell_type": "markdown", "id": "d385bf67", "metadata": {}, "source": [ "Оценка качества:" ] }, { "cell_type": "code", "execution_count": 36, "id": "c15e4e08", "metadata": {}, "outputs": [], "source": [ "target_test_predicted = pipeline.predict(df_orig_features_test)" ] }, { "cell_type": "markdown", "id": "24e1b454", "metadata": {}, "source": [ "Метрики качества (MAPE, а также MSE, MAE):" ] }, { "cell_type": "code", "execution_count": 37, "id": "ec74bb87", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'mse': 1.1769122812432413,\n", " 'mae': 0.7433282022345273,\n", " 'mape': 0.3469466962984192}" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "metrics = score_predictions(df_target_test, target_test_predicted)\n", "metrics" ] }, { "cell_type": "code", "execution_count": 38, "id": "1f6b1ca5", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9ebfedda037646158f6e4acd2cbab0e5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading artifacts: 0%| | 0/7 [00:00#sk-container-id-3 {\n", " /* Definition of color scheme common for light and dark mode */\n", " --sklearn-color-text: #000;\n", " --sklearn-color-text-muted: #666;\n", " --sklearn-color-line: gray;\n", " /* Definition of color scheme for unfitted estimators */\n", " --sklearn-color-unfitted-level-0: #fff5e6;\n", " --sklearn-color-unfitted-level-1: 
#f6e4d2;\n", " --sklearn-color-unfitted-level-2: #ffe0b3;\n", " --sklearn-color-unfitted-level-3: chocolate;\n", " /* Definition of color scheme for fitted estimators */\n", " --sklearn-color-fitted-level-0: #f0f8ff;\n", " --sklearn-color-fitted-level-1: #d4ebff;\n", " --sklearn-color-fitted-level-2: #b3dbfd;\n", " --sklearn-color-fitted-level-3: cornflowerblue;\n", "\n", " /* Specific color for light theme */\n", " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n", " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", " --sklearn-color-icon: #696969;\n", "\n", " @media (prefers-color-scheme: dark) {\n", " /* Redefinition of color scheme for dark theme */\n", " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n", " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", " --sklearn-color-icon: #878787;\n", " }\n", "}\n", "\n", "#sk-container-id-3 {\n", " color: var(--sklearn-color-text);\n", "}\n", "\n", "#sk-container-id-3 pre {\n", " padding: 0;\n", "}\n", "\n", "#sk-container-id-3 input.sk-hidden--visually {\n", " border: 0;\n", " clip: rect(1px 1px 1px 1px);\n", " clip: rect(1px, 1px, 1px, 1px);\n", " height: 1px;\n", " margin: -1px;\n", " overflow: hidden;\n", " padding: 0;\n", " position: absolute;\n", " width: 1px;\n", "}\n", "\n", "#sk-container-id-3 div.sk-dashed-wrapped {\n", " border: 1px dashed var(--sklearn-color-line);\n", " margin: 0 0.4em 0.5em 0.4em;\n", " box-sizing: border-box;\n", " padding-bottom: 0.4em;\n", " 
background-color: var(--sklearn-color-background);\n", "}\n", "\n", "#sk-container-id-3 div.sk-container {\n", " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n", " but bootstrap.min.css set `[hidden] { display: none !important; }`\n", " so we also need the `!important` here to be able to override the\n", " default hidden behavior on the sphinx rendered scikit-learn.org.\n", " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n", " display: inline-block !important;\n", " position: relative;\n", "}\n", "\n", "#sk-container-id-3 div.sk-text-repr-fallback {\n", " display: none;\n", "}\n", "\n", "div.sk-parallel-item,\n", "div.sk-serial,\n", "div.sk-item {\n", " /* draw centered vertical line to link estimators */\n", " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n", " background-size: 2px 100%;\n", " background-repeat: no-repeat;\n", " background-position: center center;\n", "}\n", "\n", "/* Parallel-specific style estimator block */\n", "\n", "#sk-container-id-3 div.sk-parallel-item::after {\n", " content: \"\";\n", " width: 100%;\n", " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n", " flex-grow: 1;\n", "}\n", "\n", "#sk-container-id-3 div.sk-parallel {\n", " display: flex;\n", " align-items: stretch;\n", " justify-content: center;\n", " background-color: var(--sklearn-color-background);\n", " position: relative;\n", "}\n", "\n", "#sk-container-id-3 div.sk-parallel-item {\n", " display: flex;\n", " flex-direction: column;\n", "}\n", "\n", "#sk-container-id-3 div.sk-parallel-item:first-child::after {\n", " align-self: flex-end;\n", " width: 50%;\n", "}\n", "\n", "#sk-container-id-3 div.sk-parallel-item:last-child::after {\n", " align-self: flex-start;\n", " width: 50%;\n", "}\n", "\n", "#sk-container-id-3 div.sk-parallel-item:only-child::after {\n", " width: 0;\n", "}\n", "\n", "/* Serial-specific style estimator 
block */\n", "\n", "#sk-container-id-3 div.sk-serial {\n", " display: flex;\n", " flex-direction: column;\n", " align-items: center;\n", " background-color: var(--sklearn-color-background);\n", " padding-right: 1em;\n", " padding-left: 1em;\n", "}\n", "\n", "\n", "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n", "clickable and can be expanded/collapsed.\n", "- Pipeline and ColumnTransformer use this feature and define the default style\n", "- Estimators will overwrite some part of the style using the `sk-estimator` class\n", "*/\n", "\n", "/* Pipeline and ColumnTransformer style (default) */\n", "\n", "#sk-container-id-3 div.sk-toggleable {\n", " /* Default theme specific background. It is overwritten whether we have a\n", " specific estimator or a Pipeline/ColumnTransformer */\n", " background-color: var(--sklearn-color-background);\n", "}\n", "\n", "/* Toggleable label */\n", "#sk-container-id-3 label.sk-toggleable__label {\n", " cursor: pointer;\n", " display: flex;\n", " width: 100%;\n", " margin-bottom: 0;\n", " padding: 0.5em;\n", " box-sizing: border-box;\n", " text-align: center;\n", " align-items: start;\n", " justify-content: space-between;\n", " gap: 0.5em;\n", "}\n", "\n", "#sk-container-id-3 label.sk-toggleable__label .caption {\n", " font-size: 0.6rem;\n", " font-weight: lighter;\n", " color: var(--sklearn-color-text-muted);\n", "}\n", "\n", "#sk-container-id-3 label.sk-toggleable__label-arrow:before {\n", " /* Arrow on the left of the label */\n", " content: \"▸\";\n", " float: left;\n", " margin-right: 0.25em;\n", " color: var(--sklearn-color-icon);\n", "}\n", "\n", "#sk-container-id-3 label.sk-toggleable__label-arrow:hover:before {\n", " color: var(--sklearn-color-text);\n", "}\n", "\n", "/* Toggleable content - dropdown */\n", "\n", "#sk-container-id-3 div.sk-toggleable__content {\n", " display: none;\n", " text-align: left;\n", " /* unfitted */\n", " background-color: 
var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-3 div.sk-toggleable__content.fitted {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "#sk-container-id-3 div.sk-toggleable__content pre {\n", " margin: 0.2em;\n", " border-radius: 0.25em;\n", " color: var(--sklearn-color-text);\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-3 div.sk-toggleable__content.fitted pre {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "#sk-container-id-3 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n", " /* Expand drop-down */\n", " display: block;\n", " width: 100%;\n", " overflow: visible;\n", "}\n", "\n", "#sk-container-id-3 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n", " content: \"▾\";\n", "}\n", "\n", "/* Pipeline/ColumnTransformer-specific style */\n", "\n", "#sk-container-id-3 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-3 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Estimator-specific style */\n", "\n", "/* Colorize estimator box */\n", "#sk-container-id-3 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-3 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "#sk-container-id-3 div.sk-label label.sk-toggleable__label,\n", "#sk-container-id-3 div.sk-label 
label {\n", " /* The background is the default theme color */\n", " color: var(--sklearn-color-text-on-default-background);\n", "}\n", "\n", "/* On hover, darken the color of the background */\n", "#sk-container-id-3 div.sk-label:hover label.sk-toggleable__label {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "/* Label box, darken color on hover, fitted */\n", "#sk-container-id-3 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Estimator label */\n", "\n", "#sk-container-id-3 div.sk-label label {\n", " font-family: monospace;\n", " font-weight: bold;\n", " display: inline-block;\n", " line-height: 1.2em;\n", "}\n", "\n", "#sk-container-id-3 div.sk-label-container {\n", " text-align: center;\n", "}\n", "\n", "/* Estimator-specific */\n", "#sk-container-id-3 div.sk-estimator {\n", " font-family: monospace;\n", " border: 1px dotted var(--sklearn-color-border-box);\n", " border-radius: 0.25em;\n", " box-sizing: border-box;\n", " margin-bottom: 0.5em;\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-3 div.sk-estimator.fitted {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "/* on hover */\n", "#sk-container-id-3 div.sk-estimator:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-3 div.sk-estimator.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Specification for estimator info (e.g. 
\"i\" and \"?\") */\n", "\n", "/* Common style for \"i\" and \"?\" */\n", "\n", ".sk-estimator-doc-link,\n", "a:link.sk-estimator-doc-link,\n", "a:visited.sk-estimator-doc-link {\n", " float: right;\n", " font-size: smaller;\n", " line-height: 1em;\n", " font-family: monospace;\n", " background-color: var(--sklearn-color-background);\n", " border-radius: 1em;\n", " height: 1em;\n", " width: 1em;\n", " text-decoration: none !important;\n", " margin-left: 0.5em;\n", " text-align: center;\n", " /* unfitted */\n", " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-unfitted-level-1);\n", "}\n", "\n", ".sk-estimator-doc-link.fitted,\n", "a:link.sk-estimator-doc-link.fitted,\n", "a:visited.sk-estimator-doc-link.fitted {\n", " /* fitted */\n", " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-fitted-level-1);\n", "}\n", "\n", "/* On hover */\n", "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n", ".sk-estimator-doc-link:hover,\n", "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n", ".sk-estimator-doc-link:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n", ".sk-estimator-doc-link.fitted:hover,\n", "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n", ".sk-estimator-doc-link.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "/* Span, style for the box shown on hovering the info icon */\n", ".sk-estimator-doc-link span {\n", " display: none;\n", " z-index: 9999;\n", " position: relative;\n", " font-weight: normal;\n", " right: .2ex;\n", " padding: .5ex;\n", " margin: .5ex;\n", " width: min-content;\n", " min-width: 20ex;\n", " 
max-width: 50ex;\n", " color: var(--sklearn-color-text);\n", " box-shadow: 2pt 2pt 4pt #999;\n", " /* unfitted */\n", " background: var(--sklearn-color-unfitted-level-0);\n", " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n", "}\n", "\n", ".sk-estimator-doc-link.fitted span {\n", " /* fitted */\n", " background: var(--sklearn-color-fitted-level-0);\n", " border: var(--sklearn-color-fitted-level-3);\n", "}\n", "\n", ".sk-estimator-doc-link:hover span {\n", " display: block;\n", "}\n", "\n", "/* \"?\"-specific style due to the `` HTML tag */\n", "\n", "#sk-container-id-3 a.estimator_doc_link {\n", " float: right;\n", " font-size: 1rem;\n", " line-height: 1em;\n", " font-family: monospace;\n", " background-color: var(--sklearn-color-background);\n", " border-radius: 1rem;\n", " height: 1rem;\n", " width: 1rem;\n", " text-decoration: none;\n", " /* unfitted */\n", " color: var(--sklearn-color-unfitted-level-1);\n", " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", "}\n", "\n", "#sk-container-id-3 a.estimator_doc_link.fitted {\n", " /* fitted */\n", " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-fitted-level-1);\n", "}\n", "\n", "/* On hover */\n", "#sk-container-id-3 a.estimator_doc_link:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "#sk-container-id-3 a.estimator_doc_link.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-3);\n", "}\n", "\n", ".estimator-table summary {\n", " padding: .5rem;\n", " font-family: monospace;\n", " cursor: pointer;\n", "}\n", "\n", ".estimator-table details[open] {\n", " padding-left: 0.1rem;\n", " padding-right: 0.1rem;\n", " padding-bottom: 0.3rem;\n", "}\n", "\n", ".estimator-table .parameters-table {\n", " margin-left: auto !important;\n", " margin-right: auto !important;\n", "}\n", "\n", 
".estimator-table .parameters-table tr:nth-child(odd) {\n", " background-color: #fff;\n", "}\n", "\n", ".estimator-table .parameters-table tr:nth-child(even) {\n", " background-color: #f6f6f6;\n", "}\n", "\n", ".estimator-table .parameters-table tr:hover {\n", " background-color: #e0e0e0;\n", "}\n", "\n", ".estimator-table table td {\n", " border: 1px solid rgba(106, 105, 104, 0.232);\n", "}\n", "\n", ".user-set td {\n", " color:rgb(255, 94, 0);\n", " text-align: left;\n", "}\n", "\n", ".user-set td.value pre {\n", " color:rgb(255, 94, 0) !important;\n", " background-color: transparent !important;\n", "}\n", "\n", ".default td {\n", " color: black;\n", " text-align: left;\n", "}\n", "\n", ".user-set td i,\n", ".default td i {\n", " color: black;\n", "}\n", "\n", ".copy-paste-icon {\n", " background-image: url();\n", " background-repeat: no-repeat;\n", " background-size: 14px 14px;\n", " background-position: 0;\n", " display: inline-block;\n", " width: 14px;\n", " height: 14px;\n", " cursor: pointer;\n", "}\n", "
ColumnTransformer(transformers=[('extend_features_as_polynomial',\n",
       "                                 Pipeline(steps=[('extend_features',\n",
       "                                                  PolynomialFeatures(include_bias=False)),\n",
       "                                                 ('scale_to_standard',\n",
       "                                                  StandardScaler())]),\n",
       "                                 ('selling_price', 'driven_kms')),\n",
       "                                ('extend_features_as_spline',\n",
       "                                 SplineTransformer(include_bias=False,\n",
       "                                                   knots='quantile',\n",
       "                                                   n_knots=4),\n",
       "                                 ('age',)),\n",
       "                                ('scale_to_standard', StandardScaler(),\n",
       "                                 ('age',)),\n",
       "                                ('encode_categoricals_wrt_target',\n",
       "                                 TargetEncoder(random_state=11990,\n",
       "                                               target_type='continuous'),\n",
       "                                 ('fuel_type', 'selling_type',\n",
       "                                  'transmission'))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "ColumnTransformer(transformers=[('extend_features_as_polynomial',\n", " Pipeline(steps=[('extend_features',\n", " PolynomialFeatures(include_bias=False)),\n", " ('scale_to_standard',\n", " StandardScaler())]),\n", " ('selling_price', 'driven_kms')),\n", " ('extend_features_as_spline',\n", " SplineTransformer(include_bias=False,\n", " knots='quantile',\n", " n_knots=4),\n", " ('age',)),\n", " ('scale_to_standard', StandardScaler(),\n", " ('age',)),\n", " ('encode_categoricals_wrt_target',\n", " TargetEncoder(random_state=11990,\n", " target_type='continuous'),\n", " ('fuel_type', 'selling_type',\n", " 'transmission'))])" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preprocess_transformer = build_preprocess_augmenting_transformer()\n", "preprocess_transformer" ] }, { "cell_type": "markdown", "id": "0c041b34-bd18-4f26-b6cb-9f4567a4bc65", "metadata": {}, "source": [ "Демонстрация предобработки данных:" ] }, { "cell_type": "code", "execution_count": 43, "id": "df3207ab-36ea-417d-b10c-05145d6e3777", "metadata": {}, "outputs": [], "source": [ "preprocess_transformer_tmp = build_preprocess_augmenting_transformer()\n", "df_augd_features_matrix_train = preprocess_transformer_tmp.fit_transform(df_orig_features_train, df_target_train.iloc[:, 0])\n", "df_augd_features_train = pandas_dataframe_from_transformed_artifacts(df_augd_features_matrix_train, preprocess_transformer_tmp)\n", "del preprocess_transformer_tmp" ] }, { "cell_type": "markdown", "id": "41cc8af7-56d8-4d2d-b536-cbff20bb2545", "metadata": {}, "source": [ "Обзор предобработанного датасета:" ] }, { "cell_type": "code", "execution_count": 44, "id": "ec3f4c72-edea-4260-9a26-9b0d57628e9e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 224 entries, 0 to 223\n", "Data columns (total 14 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 
extend_features_as_polynomial__selling_price 224 non-null float64\n", " 1 extend_features_as_polynomial__driven_kms 224 non-null float64\n", " 2 extend_features_as_polynomial__selling_price^2 224 non-null float64\n", " 3 extend_features_as_polynomial__selling_price driven_kms 224 non-null float64\n", " 4 extend_features_as_polynomial__driven_kms^2 224 non-null float64\n", " 5 extend_features_as_spline__age_sp_0 224 non-null float64\n", " 6 extend_features_as_spline__age_sp_1 224 non-null float64\n", " 7 extend_features_as_spline__age_sp_2 224 non-null float64\n", " 8 extend_features_as_spline__age_sp_3 224 non-null float64\n", " 9 extend_features_as_spline__age_sp_4 224 non-null float64\n", " 10 scale_to_standard__age 224 non-null float64\n", " 11 encode_categoricals_wrt_target__fuel_type 224 non-null float64\n", " 12 encode_categoricals_wrt_target__selling_type 224 non-null float64\n", " 13 encode_categoricals_wrt_target__transmission 224 non-null float64\n", "dtypes: float64(14)\n", "memory usage: 24.6 KB\n" ] } ], "source": [ "df_augd_features_train.info()" ] }, { "cell_type": "code", "execution_count": 45, "id": "393d1826-963c-4479-b004-6fedf2f6dc77", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
extend_features_as_polynomial__selling_priceextend_features_as_polynomial__driven_kmsextend_features_as_polynomial__selling_price^2extend_features_as_polynomial__selling_price driven_kmsextend_features_as_polynomial__driven_kms^2extend_features_as_spline__age_sp_0extend_features_as_spline__age_sp_1extend_features_as_spline__age_sp_2extend_features_as_spline__age_sp_3extend_features_as_spline__age_sp_4scale_to_standard__ageencode_categoricals_wrt_target__fuel_typeencode_categoricals_wrt_target__selling_typeencode_categoricals_wrt_target__transmission
0-0.104244-0.059337-0.160142-0.184156-0.2133920.0000000.0000000.2844440.6143430.0998790.9831593.4180666.7230444.251590
10.524405-0.9309840.023111-0.341051-0.4670470.0493830.5283950.4177780.0044440.000000-1.1412239.3746556.4008213.750236
2-0.364071-0.699614-0.204196-0.411821-0.4272500.0061730.3035490.6547220.0355560.000000-0.7871593.3134047.0181164.015122
3-0.686652-0.942552-0.233103-0.493887-0.4685140.0061730.3035490.6547220.0355560.000000-0.7871593.5320720.6731514.202766
4-0.2914070.090899-0.193742-0.236248-0.1411380.0000000.1000000.7800000.1200000.000000-0.4330964.9681117.1611094.059384
5-0.747205-0.236874-0.235345-0.474524-0.2879600.0000000.0000000.1905560.6402020.1647421.3372223.0806850.6971193.750236
60.0267711.112782-0.1309000.2464120.5729310.0000000.0000000.0000000.2272730.6060613.8156673.5320726.6754064.202766
7-0.180210-0.066162-0.174939-0.219328-0.2164750.0000000.1000000.7800000.1200000.000000-0.4330963.2843267.1611094.059384
\n", "
" ], "text/plain": [ " extend_features_as_polynomial__selling_price \\\n", "0 -0.104244 \n", "1 0.524405 \n", "2 -0.364071 \n", "3 -0.686652 \n", "4 -0.291407 \n", "5 -0.747205 \n", "6 0.026771 \n", "7 -0.180210 \n", "\n", " extend_features_as_polynomial__driven_kms \\\n", "0 -0.059337 \n", "1 -0.930984 \n", "2 -0.699614 \n", "3 -0.942552 \n", "4 0.090899 \n", "5 -0.236874 \n", "6 1.112782 \n", "7 -0.066162 \n", "\n", " extend_features_as_polynomial__selling_price^2 \\\n", "0 -0.160142 \n", "1 0.023111 \n", "2 -0.204196 \n", "3 -0.233103 \n", "4 -0.193742 \n", "5 -0.235345 \n", "6 -0.130900 \n", "7 -0.174939 \n", "\n", " extend_features_as_polynomial__selling_price driven_kms \\\n", "0 -0.184156 \n", "1 -0.341051 \n", "2 -0.411821 \n", "3 -0.493887 \n", "4 -0.236248 \n", "5 -0.474524 \n", "6 0.246412 \n", "7 -0.219328 \n", "\n", " extend_features_as_polynomial__driven_kms^2 \\\n", "0 -0.213392 \n", "1 -0.467047 \n", "2 -0.427250 \n", "3 -0.468514 \n", "4 -0.141138 \n", "5 -0.287960 \n", "6 0.572931 \n", "7 -0.216475 \n", "\n", " extend_features_as_spline__age_sp_0 extend_features_as_spline__age_sp_1 \\\n", "0 0.000000 0.000000 \n", "1 0.049383 0.528395 \n", "2 0.006173 0.303549 \n", "3 0.006173 0.303549 \n", "4 0.000000 0.100000 \n", "5 0.000000 0.000000 \n", "6 0.000000 0.000000 \n", "7 0.000000 0.100000 \n", "\n", " extend_features_as_spline__age_sp_2 extend_features_as_spline__age_sp_3 \\\n", "0 0.284444 0.614343 \n", "1 0.417778 0.004444 \n", "2 0.654722 0.035556 \n", "3 0.654722 0.035556 \n", "4 0.780000 0.120000 \n", "5 0.190556 0.640202 \n", "6 0.000000 0.227273 \n", "7 0.780000 0.120000 \n", "\n", " extend_features_as_spline__age_sp_4 scale_to_standard__age \\\n", "0 0.099879 0.983159 \n", "1 0.000000 -1.141223 \n", "2 0.000000 -0.787159 \n", "3 0.000000 -0.787159 \n", "4 0.000000 -0.433096 \n", "5 0.164742 1.337222 \n", "6 0.606061 3.815667 \n", "7 0.000000 -0.433096 \n", "\n", " encode_categoricals_wrt_target__fuel_type \\\n", "0 3.418066 \n", "1 9.374655 
\n", "2 3.313404 \n", "3 3.532072 \n", "4 4.968111 \n", "5 3.080685 \n", "6 3.532072 \n", "7 3.284326 \n", "\n", " encode_categoricals_wrt_target__selling_type \\\n", "0 6.723044 \n", "1 6.400821 \n", "2 7.018116 \n", "3 0.673151 \n", "4 7.161109 \n", "5 0.697119 \n", "6 6.675406 \n", "7 7.161109 \n", "\n", " encode_categoricals_wrt_target__transmission \n", "0 4.251590 \n", "1 3.750236 \n", "2 4.015122 \n", "3 4.202766 \n", "4 4.059384 \n", "5 3.750236 \n", "6 4.202766 \n", "7 4.059384 " ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_augd_features_train.head(0x8)" ] }, { "cell_type": "code", "execution_count": 46, "id": "2bb56d09", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RandomForestRegressor(max_depth=8, max_features='sqrt', n_estimators=10)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "RandomForestRegressor(max_depth=8, max_features='sqrt', n_estimators=10)" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "regressor = build_regressor_baseline(random_state=0x3AEF)\n", "regressor" ] }, { "cell_type": "markdown", "id": "dd34c150", "metadata": {}, "source": [ "Составной пайплайн:" ] }, { "cell_type": "code", "execution_count": 47, "id": "ff9d2a85", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('preprocess',\n",
       "                 ColumnTransformer(transformers=[('extend_features_as_polynomial',\n",
       "                                                  Pipeline(steps=[('extend_features',\n",
       "                                                                   PolynomialFeatures(include_bias=False)),\n",
       "                                                                  ('scale_to_standard',\n",
       "                                                                   StandardScaler())]),\n",
       "                                                  ('selling_price',\n",
       "                                                   'driven_kms')),\n",
       "                                                 ('extend_features_as_spline',\n",
       "                                                  SplineTransformer(include_bias=False,\n",
       "                                                                    knots='quantile',\n",
       "                                                                    n_knots=4),\n",
       "                                                  ('age',)),\n",
       "                                                 ('scale_to_standard',\n",
       "                                                  StandardScaler(), ('age',)),\n",
       "                                                 ('encode_categoricals_wrt_target',\n",
       "                                                  TargetEncoder(random_state=11990,\n",
       "                                                                target_type='continuous'),\n",
       "                                                  ('fuel_type', 'selling_type',\n",
       "                                                   'transmission'))])),\n",
       "                ('regress',\n",
       "                 RandomForestRegressor(max_depth=8, max_features='sqrt',\n",
       "                                       n_estimators=10))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('preprocess',\n", " ColumnTransformer(transformers=[('extend_features_as_polynomial',\n", " Pipeline(steps=[('extend_features',\n", " PolynomialFeatures(include_bias=False)),\n", " ('scale_to_standard',\n", " StandardScaler())]),\n", " ('selling_price',\n", " 'driven_kms')),\n", " ('extend_features_as_spline',\n", " SplineTransformer(include_bias=False,\n", " knots='quantile',\n", " n_knots=4),\n", " ('age',)),\n", " ('scale_to_standard',\n", " StandardScaler(), ('age',)),\n", " ('encode_categoricals_wrt_target',\n", " TargetEncoder(random_state=11990,\n", " target_type='continuous'),\n", " ('fuel_type', 'selling_type',\n", " 'transmission'))])),\n", " ('regress',\n", " RandomForestRegressor(max_depth=8, max_features='sqrt',\n", " n_estimators=10))])" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline = sklearn.pipeline.Pipeline([\n", " ('preprocess', preprocess_transformer),\n", " ('regress', regressor),\n", "])\n", "pipeline" ] }, { "cell_type": "code", "execution_count": 48, "id": "eec22b97", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'preprocess__remainder': 'drop',\n", " 'preprocess__sparse_threshold': 0.3,\n", " 'preprocess__transformer_weights': None,\n", " 'preprocess__extend_features_as_spline': SplineTransformer(include_bias=False, knots='quantile', n_knots=4),\n", " 'preprocess__extend_features_as_polynomial__extend_features': PolynomialFeatures(include_bias=False),\n", " 'preprocess__extend_features_as_polynomial__extend_features__degree': 2,\n", " 'preprocess__extend_features_as_polynomial__extend_features__include_bias': False,\n", " 'preprocess__extend_features_as_polynomial__extend_features__interaction_only': False,\n", " 'preprocess__extend_features_as_polynomial__extend_features__order': 'C',\n", " 'preprocess__extend_features_as_polynomial__scale_to_standard__with_mean': True,\n", " 
'preprocess__extend_features_as_polynomial__scale_to_standard__with_std': True,\n", " 'preprocess__extend_features_as_spline__degree': 3,\n", " 'preprocess__extend_features_as_spline__extrapolation': 'constant',\n", " 'preprocess__extend_features_as_spline__include_bias': False,\n", " 'preprocess__extend_features_as_spline__knots': 'quantile',\n", " 'preprocess__extend_features_as_spline__n_knots': 4,\n", " 'preprocess__extend_features_as_spline__order': 'C',\n", " 'preprocess__extend_features_as_spline__sparse_output': False,\n", " 'preprocess__scale_to_standard__with_mean': True,\n", " 'preprocess__scale_to_standard__with_std': True,\n", " 'regress__bootstrap': True,\n", " 'regress__ccp_alpha': 0.0,\n", " 'regress__criterion': 'squared_error',\n", " 'regress__max_depth': 8,\n", " 'regress__max_features': 'sqrt',\n", " 'regress__max_leaf_nodes': None,\n", " 'regress__max_samples': None,\n", " 'regress__min_impurity_decrease': 0.0,\n", " 'regress__min_samples_leaf': 1,\n", " 'regress__min_samples_split': 2,\n", " 'regress__min_weight_fraction_leaf': 0.0,\n", " 'regress__monotonic_cst': None,\n", " 'regress__n_estimators': 10,\n", " 'regress__oob_score': False,\n", " 'regress__random_state': None}" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_params = filter_params(\n", " pipeline.get_params(),\n", " include={\n", " 'preprocess': (False, PREPROCESS_AUGMENTING_TRANSFORMER_PARAMS_COMMON_INCLUDE.copy()),\n", " 'regress': (False, True),\n", " },\n", " exclude={\n", " 'preprocess': PREPROCESS_AUGMENTING_TRANSFORMER_PARAMS_COMMON_EXCLUDE.copy(),\n", " 'regress': RANDOM_FOREST_REGRESSOR_PARAMS_COMMON_EXCLUDE,\n", " },\n", ")\n", "model_params" ] }, { "cell_type": "markdown", "id": "23519fd3", "metadata": {}, "source": [ "Обучение модели:" ] }, { "cell_type": "code", "execution_count": 49, "id": "95f8079b", "metadata": {}, "outputs": [], "source": [ "_ = pipeline.fit(df_orig_features_train, df_target_train.iloc[:, 0])" 
] }, { "cell_type": "markdown", "id": "5627ba08-b0a9-4316-a4a9-630047cec1cc", "metadata": {}, "source": [ "Оценка качества:" ] }, { "cell_type": "code", "execution_count": 50, "id": "a88ccce4-cb32-4810-982e-9b126778d611", "metadata": {}, "outputs": [], "source": [ "target_test_predicted = pipeline.predict(df_orig_features_test)" ] }, { "cell_type": "markdown", "id": "5cef4314-5872-4f6e-9355-316c0419158a", "metadata": {}, "source": [ "Метрики качества (MAPE, а также MSE, MAE):" ] }, { "cell_type": "code", "execution_count": 51, "id": "bfd94c07-05ec-45c9-b49e-ddf861d39d06", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'mse': 1.5006829920671902,\n", " 'mae': 0.7582020656775502,\n", " 'mape': 0.30794862210624835}" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "metrics = score_predictions(df_target_test, target_test_predicted)\n", "metrics" ] }, { "cell_type": "code", "execution_count": 52, "id": "80b5ab1d-234c-4b18-98a6-eef0520677ec", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5821a1adbbe242a882fed4dd765843c8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading artifacts: 0%| | 0/7 [00:00#sk-container-id-6 {\n", " /* Definition of color scheme common for light and dark mode */\n", " --sklearn-color-text: #000;\n", " --sklearn-color-text-muted: #666;\n", " --sklearn-color-line: gray;\n", " /* Definition of color scheme for unfitted estimators */\n", " --sklearn-color-unfitted-level-0: #fff5e6;\n", " --sklearn-color-unfitted-level-1: #f6e4d2;\n", " --sklearn-color-unfitted-level-2: #ffe0b3;\n", " --sklearn-color-unfitted-level-3: chocolate;\n", " /* Definition of color scheme for fitted estimators */\n", " --sklearn-color-fitted-level-0: #f0f8ff;\n", " --sklearn-color-fitted-level-1: #d4ebff;\n", " --sklearn-color-fitted-level-2: #b3dbfd;\n", " --sklearn-color-fitted-level-3: cornflowerblue;\n", "\n", " /* Specific color for light 
theme */\n", " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n", " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", " --sklearn-color-icon: #696969;\n", "\n", " @media (prefers-color-scheme: dark) {\n", " /* Redefinition of color scheme for dark theme */\n", " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n", " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", " --sklearn-color-icon: #878787;\n", " }\n", "}\n", "\n", "#sk-container-id-6 {\n", " color: var(--sklearn-color-text);\n", "}\n", "\n", "#sk-container-id-6 pre {\n", " padding: 0;\n", "}\n", "\n", "#sk-container-id-6 input.sk-hidden--visually {\n", " border: 0;\n", " clip: rect(1px 1px 1px 1px);\n", " clip: rect(1px, 1px, 1px, 1px);\n", " height: 1px;\n", " margin: -1px;\n", " overflow: hidden;\n", " padding: 0;\n", " position: absolute;\n", " width: 1px;\n", "}\n", "\n", "#sk-container-id-6 div.sk-dashed-wrapped {\n", " border: 1px dashed var(--sklearn-color-line);\n", " margin: 0 0.4em 0.5em 0.4em;\n", " box-sizing: border-box;\n", " padding-bottom: 0.4em;\n", " background-color: var(--sklearn-color-background);\n", "}\n", "\n", "#sk-container-id-6 div.sk-container {\n", " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n", " but bootstrap.min.css set `[hidden] { display: none !important; }`\n", " so we also need the `!important` here to be able to override the\n", " default hidden behavior on the sphinx rendered scikit-learn.org.\n", " See: 
https://github.com/scikit-learn/scikit-learn/issues/21755 */\n", " display: inline-block !important;\n", " position: relative;\n", "}\n", "\n", "#sk-container-id-6 div.sk-text-repr-fallback {\n", " display: none;\n", "}\n", "\n", "div.sk-parallel-item,\n", "div.sk-serial,\n", "div.sk-item {\n", " /* draw centered vertical line to link estimators */\n", " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n", " background-size: 2px 100%;\n", " background-repeat: no-repeat;\n", " background-position: center center;\n", "}\n", "\n", "/* Parallel-specific style estimator block */\n", "\n", "#sk-container-id-6 div.sk-parallel-item::after {\n", " content: \"\";\n", " width: 100%;\n", " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n", " flex-grow: 1;\n", "}\n", "\n", "#sk-container-id-6 div.sk-parallel {\n", " display: flex;\n", " align-items: stretch;\n", " justify-content: center;\n", " background-color: var(--sklearn-color-background);\n", " position: relative;\n", "}\n", "\n", "#sk-container-id-6 div.sk-parallel-item {\n", " display: flex;\n", " flex-direction: column;\n", "}\n", "\n", "#sk-container-id-6 div.sk-parallel-item:first-child::after {\n", " align-self: flex-end;\n", " width: 50%;\n", "}\n", "\n", "#sk-container-id-6 div.sk-parallel-item:last-child::after {\n", " align-self: flex-start;\n", " width: 50%;\n", "}\n", "\n", "#sk-container-id-6 div.sk-parallel-item:only-child::after {\n", " width: 0;\n", "}\n", "\n", "/* Serial-specific style estimator block */\n", "\n", "#sk-container-id-6 div.sk-serial {\n", " display: flex;\n", " flex-direction: column;\n", " align-items: center;\n", " background-color: var(--sklearn-color-background);\n", " padding-right: 1em;\n", " padding-left: 1em;\n", "}\n", "\n", "\n", "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n", "clickable and can be expanded/collapsed.\n", "- 
Pipeline and ColumnTransformer use this feature and define the default style\n", "- Estimators will overwrite some part of the style using the `sk-estimator` class\n", "*/\n", "\n", "/* Pipeline and ColumnTransformer style (default) */\n", "\n", "#sk-container-id-6 div.sk-toggleable {\n", " /* Default theme specific background. It is overwritten whether we have a\n", " specific estimator or a Pipeline/ColumnTransformer */\n", " background-color: var(--sklearn-color-background);\n", "}\n", "\n", "/* Toggleable label */\n", "#sk-container-id-6 label.sk-toggleable__label {\n", " cursor: pointer;\n", " display: flex;\n", " width: 100%;\n", " margin-bottom: 0;\n", " padding: 0.5em;\n", " box-sizing: border-box;\n", " text-align: center;\n", " align-items: start;\n", " justify-content: space-between;\n", " gap: 0.5em;\n", "}\n", "\n", "#sk-container-id-6 label.sk-toggleable__label .caption {\n", " font-size: 0.6rem;\n", " font-weight: lighter;\n", " color: var(--sklearn-color-text-muted);\n", "}\n", "\n", "#sk-container-id-6 label.sk-toggleable__label-arrow:before {\n", " /* Arrow on the left of the label */\n", " content: \"▸\";\n", " float: left;\n", " margin-right: 0.25em;\n", " color: var(--sklearn-color-icon);\n", "}\n", "\n", "#sk-container-id-6 label.sk-toggleable__label-arrow:hover:before {\n", " color: var(--sklearn-color-text);\n", "}\n", "\n", "/* Toggleable content - dropdown */\n", "\n", "#sk-container-id-6 div.sk-toggleable__content {\n", " display: none;\n", " text-align: left;\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-6 div.sk-toggleable__content.fitted {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "#sk-container-id-6 div.sk-toggleable__content pre {\n", " margin: 0.2em;\n", " border-radius: 0.25em;\n", " color: var(--sklearn-color-text);\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", 
"}\n", "\n", "#sk-container-id-6 div.sk-toggleable__content.fitted pre {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "#sk-container-id-6 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n", " /* Expand drop-down */\n", " display: block;\n", " width: 100%;\n", " overflow: visible;\n", "}\n", "\n", "#sk-container-id-6 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n", " content: \"▾\";\n", "}\n", "\n", "/* Pipeline/ColumnTransformer-specific style */\n", "\n", "#sk-container-id-6 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-6 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Estimator-specific style */\n", "\n", "/* Colorize estimator box */\n", "#sk-container-id-6 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-6 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "#sk-container-id-6 div.sk-label label.sk-toggleable__label,\n", "#sk-container-id-6 div.sk-label label {\n", " /* The background is the default theme color */\n", " color: var(--sklearn-color-text-on-default-background);\n", "}\n", "\n", "/* On hover, darken the color of the background */\n", "#sk-container-id-6 div.sk-label:hover label.sk-toggleable__label {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "/* Label box, darken color on hover, fitted */\n", 
"#sk-container-id-6 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Estimator label */\n", "\n", "#sk-container-id-6 div.sk-label label {\n", " font-family: monospace;\n", " font-weight: bold;\n", " display: inline-block;\n", " line-height: 1.2em;\n", "}\n", "\n", "#sk-container-id-6 div.sk-label-container {\n", " text-align: center;\n", "}\n", "\n", "/* Estimator-specific */\n", "#sk-container-id-6 div.sk-estimator {\n", " font-family: monospace;\n", " border: 1px dotted var(--sklearn-color-border-box);\n", " border-radius: 0.25em;\n", " box-sizing: border-box;\n", " margin-bottom: 0.5em;\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-6 div.sk-estimator.fitted {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "/* on hover */\n", "#sk-container-id-6 div.sk-estimator:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-6 div.sk-estimator.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Specification for estimator info (e.g. 
\"i\" and \"?\") */\n", "\n", "/* Common style for \"i\" and \"?\" */\n", "\n", ".sk-estimator-doc-link,\n", "a:link.sk-estimator-doc-link,\n", "a:visited.sk-estimator-doc-link {\n", " float: right;\n", " font-size: smaller;\n", " line-height: 1em;\n", " font-family: monospace;\n", " background-color: var(--sklearn-color-background);\n", " border-radius: 1em;\n", " height: 1em;\n", " width: 1em;\n", " text-decoration: none !important;\n", " margin-left: 0.5em;\n", " text-align: center;\n", " /* unfitted */\n", " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-unfitted-level-1);\n", "}\n", "\n", ".sk-estimator-doc-link.fitted,\n", "a:link.sk-estimator-doc-link.fitted,\n", "a:visited.sk-estimator-doc-link.fitted {\n", " /* fitted */\n", " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-fitted-level-1);\n", "}\n", "\n", "/* On hover */\n", "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n", ".sk-estimator-doc-link:hover,\n", "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n", ".sk-estimator-doc-link:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n", ".sk-estimator-doc-link.fitted:hover,\n", "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n", ".sk-estimator-doc-link.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "/* Span, style for the box shown on hovering the info icon */\n", ".sk-estimator-doc-link span {\n", " display: none;\n", " z-index: 9999;\n", " position: relative;\n", " font-weight: normal;\n", " right: .2ex;\n", " padding: .5ex;\n", " margin: .5ex;\n", " width: min-content;\n", " min-width: 20ex;\n", " 
max-width: 50ex;\n", " color: var(--sklearn-color-text);\n", " box-shadow: 2pt 2pt 4pt #999;\n", " /* unfitted */\n", " background: var(--sklearn-color-unfitted-level-0);\n", " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n", "}\n", "\n", ".sk-estimator-doc-link.fitted span {\n", " /* fitted */\n", " background: var(--sklearn-color-fitted-level-0);\n", " border: var(--sklearn-color-fitted-level-3);\n", "}\n", "\n", ".sk-estimator-doc-link:hover span {\n", " display: block;\n", "}\n", "\n", "/* \"?\"-specific style due to the `` HTML tag */\n", "\n", "#sk-container-id-6 a.estimator_doc_link {\n", " float: right;\n", " font-size: 1rem;\n", " line-height: 1em;\n", " font-family: monospace;\n", " background-color: var(--sklearn-color-background);\n", " border-radius: 1rem;\n", " height: 1rem;\n", " width: 1rem;\n", " text-decoration: none;\n", " /* unfitted */\n", " color: var(--sklearn-color-unfitted-level-1);\n", " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", "}\n", "\n", "#sk-container-id-6 a.estimator_doc_link.fitted {\n", " /* fitted */\n", " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-fitted-level-1);\n", "}\n", "\n", "/* On hover */\n", "#sk-container-id-6 a.estimator_doc_link:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "#sk-container-id-6 a.estimator_doc_link.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-3);\n", "}\n", "\n", ".estimator-table summary {\n", " padding: .5rem;\n", " font-family: monospace;\n", " cursor: pointer;\n", "}\n", "\n", ".estimator-table details[open] {\n", " padding-left: 0.1rem;\n", " padding-right: 0.1rem;\n", " padding-bottom: 0.3rem;\n", "}\n", "\n", ".estimator-table .parameters-table {\n", " margin-left: auto !important;\n", " margin-right: auto !important;\n", "}\n", "\n", 
".estimator-table .parameters-table tr:nth-child(odd) {\n", " background-color: #fff;\n", "}\n", "\n", ".estimator-table .parameters-table tr:nth-child(even) {\n", " background-color: #f6f6f6;\n", "}\n", "\n", ".estimator-table .parameters-table tr:hover {\n", " background-color: #e0e0e0;\n", "}\n", "\n", ".estimator-table table td {\n", " border: 1px solid rgba(106, 105, 104, 0.232);\n", "}\n", "\n", ".user-set td {\n", " color:rgb(255, 94, 0);\n", " text-align: left;\n", "}\n", "\n", ".user-set td.value pre {\n", " color:rgb(255, 94, 0) !important;\n", " background-color: transparent !important;\n", "}\n", "\n", ".default td {\n", " color: black;\n", " text-align: left;\n", "}\n", "\n", ".user-set td i,\n", ".default td i {\n", " color: black;\n", "}\n", "\n", ".copy-paste-icon {\n", " background-image: url();\n", " background-repeat: no-repeat;\n", " background-size: 14px 14px;\n", " background-position: 0;\n", " display: inline-block;\n", " width: 14px;\n", " height: 14px;\n", " cursor: pointer;\n", "}\n", "
RandomForestRegressor(max_depth=8, max_features='sqrt', n_estimators=10)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "RandomForestRegressor(max_depth=8, max_features='sqrt', n_estimators=10)" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "regressor = build_regressor_baseline(random_state=0x8EDD)\n", "regressor" ] }, { "cell_type": "markdown", "id": "36ebb8f4-bb00-4d20-82e8-940a8798f4b1", "metadata": {}, "source": [ "Выбор признаков среди дополненного набора по минимизации MAPE:" ] }, { "cell_type": "code", "execution_count": 56, "id": "3cc243e6-f4e0-4b03-a4f2-8e8e4835466a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "14" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(df_augd_features_train.columns)" ] }, { "cell_type": "code", "execution_count": 57, "id": "9c5645b1-e052-43a8-9766-9ed7e62f7ebc", "metadata": {}, "outputs": [], "source": [ "FILTERED_FEATURES_NUM = (4, 8)" ] }, { "cell_type": "code", "execution_count": 58, "id": "8e897be5-ba9a-427a-9c36-dd1cdfa2727e", "metadata": {}, "outputs": [], "source": [ "def build_feature_selector(*, verbose=0):\n", " return build_sequential_feature_selector(\n", " regressor, k_features=FILTERED_FEATURES_NUM, forward=True, floating=True, cv=4, scoring='neg_mean_absolute_percentage_error',\n", " verbose=verbose,\n", " )" ] }, { "cell_type": "code", "execution_count": 59, "id": "e2294a0a-5a8d-4daf-aa55-c099dd5085d0", "metadata": {}, "outputs": [], "source": [ "FEATURE_SELECTOR_PARAMS_COMMON_INCLUDE = {\n", " **{k: True for k in SEQUENTIAL_FEATURE_SELECTOR_PARAMS_COMMON_INCLUDE},\n", " 'estimator': False,\n", "}\n", "FEATURE_SELECTOR_PARAMS_COMMON_EXCLUDE = () # TODO: ай-яй-яй" ] }, { "cell_type": "code", "execution_count": 60, "id": "dfd187a8-5b32-42c3-bc38-987b75bb2a2d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
SequentialFeatureSelector(cv=4,\n",
       "                          estimator=RandomForestRegressor(max_depth=8,\n",
       "                                                          max_features='sqrt',\n",
       "                                                          n_estimators=10),\n",
       "                          floating=True, k_features=(4, 8),\n",
       "                          scoring='neg_mean_absolute_percentage_error',\n",
       "                          verbose=1)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "SequentialFeatureSelector(cv=4,\n", " estimator=RandomForestRegressor(max_depth=8,\n", " max_features='sqrt',\n", " n_estimators=10),\n", " floating=True, k_features=(4, 8),\n", " scoring='neg_mean_absolute_percentage_error',\n", " verbose=1)" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "feature_selector = build_feature_selector(verbose=1)\n", "feature_selector" ] }, { "cell_type": "code", "execution_count": 61, "id": "e8bf4bb3-e3af-4bc4-8ab5-7625e1e428e3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Done 14 out of 14 | elapsed: 0.6s finished\n", "Features: 1/8[Parallel(n_jobs=1)]: Done 13 out of 13 | elapsed: 0.5s finished\n", "Features: 2/8[Parallel(n_jobs=1)]: Done 12 out of 12 | elapsed: 0.5s finished\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s finished\n", "Features: 3/8[Parallel(n_jobs=1)]: Done 11 out of 11 | elapsed: 0.4s finished\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s finished\n", "Features: 4/8[Parallel(n_jobs=1)]: Done 10 out of 10 | elapsed: 0.4s finished\n", "[Parallel(n_jobs=1)]: Done 4 out of 4 | elapsed: 0.1s finished\n", "Features: 5/8[Parallel(n_jobs=1)]: Done 9 out of 9 | elapsed: 0.3s finished\n", "[Parallel(n_jobs=1)]: Done 5 out of 5 | elapsed: 0.1s finished\n", "Features: 6/8[Parallel(n_jobs=1)]: Done 8 out of 8 | elapsed: 0.3s finished\n", "[Parallel(n_jobs=1)]: Done 6 out of 6 | elapsed: 0.2s finished\n", "[Parallel(n_jobs=1)]: Done 5 out of 5 | elapsed: 0.1s finished\n", "[Parallel(n_jobs=1)]: Done 4 out of 4 | elapsed: 0.1s finished\n", "Features: 5/8[Parallel(n_jobs=1)]: Done 9 out of 9 | elapsed: 0.3s finished\n", "[Parallel(n_jobs=1)]: Done 5 out of 5 | elapsed: 0.1s finished\n", "Features: 6/8[Parallel(n_jobs=1)]: Done 8 out of 8 | elapsed: 0.3s finished\n", "[Parallel(n_jobs=1)]: Done 6 out of 6 | elapsed: 0.2s finished\n", "Features: 7/8[Parallel(n_jobs=1)]: Done 7 
out of 7 | elapsed: 0.2s finished\n", "[Parallel(n_jobs=1)]: Done 7 out of 7 | elapsed: 0.3s finished\n", "[Parallel(n_jobs=1)]: Done 6 out of 6 | elapsed: 0.2s finished\n", "Features: 7/8[Parallel(n_jobs=1)]: Done 7 out of 7 | elapsed: 0.2s finished\n", "[Parallel(n_jobs=1)]: Done 7 out of 7 | elapsed: 0.2s finished\n", "Features: 8/8" ] } ], "source": [ "_ = feature_selector.fit(df_augd_features_train, df_target_train.iloc[:, 0])" ] }, { "cell_type": "markdown", "id": "ed67ab27-023f-4639-85a0-cd4c3ef85dc8", "metadata": {}, "source": [ "Выбранные признаки (имена и индексы):" ] }, { "cell_type": "code", "execution_count": 62, "id": "66be5774-0ff7-43be-99df-b56c9165a4f7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'names': ('extend_features_as_polynomial__selling_price',\n", " 'extend_features_as_polynomial__selling_price^2',\n", " 'extend_features_as_spline__age_sp_1',\n", " 'extend_features_as_spline__age_sp_2',\n", " 'scale_to_standard__age'),\n", " 'indices': (0, 2, 6, 7, 10)}" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "build_selected_columns_info_for_mlflow_from_sequential_feature_selector(feature_selector)" ] }, { "cell_type": "markdown", "id": "1c7498fd-669b-4fec-83ad-7d688fd23698", "metadata": {}, "source": [ "MAPE в зависимости от количества выбранных признаков (указан регион выбора, ограниченный `FILTERED_FEATURES_NUM`):" ] }, { "cell_type": "code", "execution_count": 63, "id": "0180f3da-9775-451f-8262-b825f69228fc", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plot_sequential_feature_selection(feature_selector, kind='std_dev')\n", "ax.grid(True)\n", "if isinstance(FILTERED_FEATURES_NUM, Sequence):\n", " _ = ax.axvspan(min(FILTERED_FEATURES_NUM), max(FILTERED_FEATURES_NUM), color=matplotlib.colormaps.get_cmap('tab10')(6), alpha=0.15)\n", "# хотелось бы поставить верхнюю границу `len(df_augd_features_train.columns)`, но SequentialFeatureSelector до неё не досчитывает-то\n", "_ = ax.set_xlim((1, (max(FILTERED_FEATURES_NUM) if isinstance(FILTERED_FEATURES_NUM, Sequence) else FILTERED_FEATURES_NUM)))\n", "_ = ax.set_ylim((None, 0.))" ] }, { "cell_type": "markdown", "id": "1fc207ba-f324-4980-9f6f-b0c83ef2e127", "metadata": {}, "source": [ "Составной пайплайн:" ] }, { "cell_type": "code", "execution_count": 64, "id": "1ff048d8-63a9-45cc-b613-50891aab4612", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('preprocess',\n",
       "                 ColumnTransformer(transformers=[('extend_features_as_polynomial',\n",
       "                                                  Pipeline(steps=[('extend_features',\n",
       "                                                                   PolynomialFeatures(include_bias=False)),\n",
       "                                                                  ('scale_to_standard',\n",
       "                                                                   StandardScaler())]),\n",
       "                                                  ('selling_price',\n",
       "                                                   'driven_kms')),\n",
       "                                                 ('extend_features_as_spline',\n",
       "                                                  SplineTransformer(include_bias=False,\n",
       "                                                                    knots='quantile',\n",
       "                                                                    n_knots=4),\n",
       "                                                  ('age',)),\n",
       "                                                 ('s...\n",
       "                                                  ('fuel_type', 'selling_type',\n",
       "                                                   'transmission'))])),\n",
       "                ('select_features',\n",
       "                 SequentialFeatureSelector(cv=4,\n",
       "                                           estimator=RandomForestRegressor(max_depth=8,\n",
       "                                                                           max_features='sqrt',\n",
       "                                                                           n_estimators=10),\n",
       "                                           floating=True, k_features=(4, 8),\n",
       "                                           scoring='neg_mean_absolute_percentage_error',\n",
       "                                           verbose=1)),\n",
       "                ('regress',\n",
       "                 RandomForestRegressor(max_depth=8, max_features='sqrt',\n",
       "                                       n_estimators=10))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('preprocess',\n", " ColumnTransformer(transformers=[('extend_features_as_polynomial',\n", " Pipeline(steps=[('extend_features',\n", " PolynomialFeatures(include_bias=False)),\n", " ('scale_to_standard',\n", " StandardScaler())]),\n", " ('selling_price',\n", " 'driven_kms')),\n", " ('extend_features_as_spline',\n", " SplineTransformer(include_bias=False,\n", " knots='quantile',\n", " n_knots=4),\n", " ('age',)),\n", " ('s...\n", " ('fuel_type', 'selling_type',\n", " 'transmission'))])),\n", " ('select_features',\n", " SequentialFeatureSelector(cv=4,\n", " estimator=RandomForestRegressor(max_depth=8,\n", " max_features='sqrt',\n", " n_estimators=10),\n", " floating=True, k_features=(4, 8),\n", " scoring='neg_mean_absolute_percentage_error',\n", " verbose=1)),\n", " ('regress',\n", " RandomForestRegressor(max_depth=8, max_features='sqrt',\n", " n_estimators=10))])" ] }, "execution_count": 64, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline = sklearn.pipeline.Pipeline([\n", " ('preprocess', build_preprocess_augmenting_transformer()),\n", " ('select_features', feature_selector),\n", " ('regress', regressor),\n", "])\n", "pipeline" ] }, { "cell_type": "code", "execution_count": 65, "id": "857ca3e3-39c2-4bea-99fa-526f9bb4fcf3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'preprocess__remainder': 'drop',\n", " 'preprocess__sparse_threshold': 0.3,\n", " 'preprocess__transformer_weights': None,\n", " 'preprocess__extend_features_as_spline': SplineTransformer(include_bias=False, knots='quantile', n_knots=4),\n", " 'preprocess__extend_features_as_polynomial__extend_features': PolynomialFeatures(include_bias=False),\n", " 'preprocess__extend_features_as_polynomial__extend_features__degree': 2,\n", " 'preprocess__extend_features_as_polynomial__extend_features__include_bias': False,\n", " 'preprocess__extend_features_as_polynomial__extend_features__interaction_only': False,\n", " 
'preprocess__extend_features_as_polynomial__extend_features__order': 'C',\n", " 'preprocess__extend_features_as_polynomial__scale_to_standard__with_mean': True,\n", " 'preprocess__extend_features_as_polynomial__scale_to_standard__with_std': True,\n", " 'preprocess__extend_features_as_spline__degree': 3,\n", " 'preprocess__extend_features_as_spline__extrapolation': 'constant',\n", " 'preprocess__extend_features_as_spline__include_bias': False,\n", " 'preprocess__extend_features_as_spline__knots': 'quantile',\n", " 'preprocess__extend_features_as_spline__n_knots': 4,\n", " 'preprocess__extend_features_as_spline__order': 'C',\n", " 'preprocess__extend_features_as_spline__sparse_output': False,\n", " 'preprocess__scale_to_standard__with_mean': True,\n", " 'preprocess__scale_to_standard__with_std': True,\n", " 'select_features__cv': 4,\n", " 'select_features__feature_groups': None,\n", " 'select_features__fixed_features': None,\n", " 'select_features__floating': True,\n", " 'select_features__forward': True,\n", " 'select_features__k_features': (4, 8),\n", " 'select_features__scoring': 'neg_mean_absolute_percentage_error',\n", " 'regress__bootstrap': True,\n", " 'regress__ccp_alpha': 0.0,\n", " 'regress__criterion': 'squared_error',\n", " 'regress__max_depth': 8,\n", " 'regress__max_features': 'sqrt',\n", " 'regress__max_leaf_nodes': None,\n", " 'regress__max_samples': None,\n", " 'regress__min_impurity_decrease': 0.0,\n", " 'regress__min_samples_leaf': 1,\n", " 'regress__min_samples_split': 2,\n", " 'regress__min_weight_fraction_leaf': 0.0,\n", " 'regress__monotonic_cst': None,\n", " 'regress__n_estimators': 10,\n", " 'regress__oob_score': False,\n", " 'regress__random_state': None}" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_params = filter_params(\n", " pipeline.get_params(),\n", " include={\n", " 'preprocess': (False, PREPROCESS_AUGMENTING_TRANSFORMER_PARAMS_COMMON_INCLUDE.copy()),\n", " 'select_features': 
(False, FEATURE_SELECTOR_PARAMS_COMMON_INCLUDE.copy()),\n", " 'regress': (False, True),\n", " },\n", " exclude={\n", " 'preprocess': PREPROCESS_AUGMENTING_TRANSFORMER_PARAMS_COMMON_EXCLUDE.copy(),\n", " 'select_features': FEATURE_SELECTOR_PARAMS_COMMON_EXCLUDE,\n", " 'regress': RANDOM_FOREST_REGRESSOR_PARAMS_COMMON_EXCLUDE,\n", " },\n", ")\n", "model_params" ] }, { "cell_type": "markdown", "id": "f05a1163-dced-4f54-be05-8da7fac7d611", "metadata": {}, "source": [ "Обучение модели:" ] }, { "cell_type": "code", "execution_count": 66, "id": "1a22889d-8cc7-42a4-a3f0-51af14723db8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Done 14 out of 14 | elapsed: 0.5s finished\n", "Features: 1/8[Parallel(n_jobs=1)]: Done 13 out of 13 | elapsed: 0.5s finished\n", "Features: 2/8[Parallel(n_jobs=1)]: Done 12 out of 12 | elapsed: 0.6s finished\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s finished\n", "Features: 3/8[Parallel(n_jobs=1)]: Done 11 out of 11 | elapsed: 0.4s finished\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s finished\n", "Features: 4/8[Parallel(n_jobs=1)]: Done 10 out of 10 | elapsed: 0.4s finished\n", "[Parallel(n_jobs=1)]: Done 4 out of 4 | elapsed: 0.1s finished\n", "Features: 5/8[Parallel(n_jobs=1)]: Done 9 out of 9 | elapsed: 0.3s finished\n", "[Parallel(n_jobs=1)]: Done 5 out of 5 | elapsed: 0.1s finished\n", "Features: 6/8[Parallel(n_jobs=1)]: Done 8 out of 8 | elapsed: 0.3s finished\n", "[Parallel(n_jobs=1)]: Done 6 out of 6 | elapsed: 0.2s finished\n", "Features: 7/8[Parallel(n_jobs=1)]: Done 7 out of 7 | elapsed: 0.3s finished\n", "[Parallel(n_jobs=1)]: Done 7 out of 7 | elapsed: 0.2s finished\n", "Features: 8/8" ] } ], "source": [ "# XXX: SequentialFeatureSelector обучается опять!?\n", "_ = pipeline.fit(df_orig_features_train, df_target_train.iloc[:, 0])" ] }, { "cell_type": "markdown", "id": "8ebb22dc-2bb2-48b1-a80c-5cf73b414fd8", "metadata": {}, "source": [ "Оценка 
качества:" ] }, { "cell_type": "code", "execution_count": 67, "id": "f6a1ebfb-13b0-4c40-896c-dc6e5c588d11", "metadata": {}, "outputs": [], "source": [ "target_test_predicted = pipeline.predict(df_orig_features_test)" ] }, { "cell_type": "markdown", "id": "fca0ac78-1371-43e3-8b57-4f5921ccedbe", "metadata": {}, "source": [ "Метрики качества (MAPE, а также MSE, MAE):" ] }, { "cell_type": "code", "execution_count": 68, "id": "2690f68f-4e4e-456e-880d-7a13cf60b0ea", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'mse': 1.0194872911964548,\n", " 'mae': 0.6263087407494466,\n", " 'mape': 0.20033337884798225}" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "metrics = score_predictions(df_target_test, target_test_predicted)\n", "metrics" ] }, { "cell_type": "code", "execution_count": 69, "id": "1f3d069c-2c5b-4214-9bed-17e6fe92a8d3", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "15d75fa1d12046c8b197bf0ac21439b9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading artifacts: 0%| | 0/7 [00:00#sk-container-id-9 {\n", " /* Definition of color scheme common for light and dark mode */\n", " --sklearn-color-text: #000;\n", " --sklearn-color-text-muted: #666;\n", " --sklearn-color-line: gray;\n", " /* Definition of color scheme for unfitted estimators */\n", " --sklearn-color-unfitted-level-0: #fff5e6;\n", " --sklearn-color-unfitted-level-1: #f6e4d2;\n", " --sklearn-color-unfitted-level-2: #ffe0b3;\n", " --sklearn-color-unfitted-level-3: chocolate;\n", " /* Definition of color scheme for fitted estimators */\n", " --sklearn-color-fitted-level-0: #f0f8ff;\n", " --sklearn-color-fitted-level-1: #d4ebff;\n", " --sklearn-color-fitted-level-2: #b3dbfd;\n", " --sklearn-color-fitted-level-3: cornflowerblue;\n", "\n", " /* Specific color for light theme */\n", " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, 
var(--jp-content-font-color1, black)));\n", " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n", " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", " --sklearn-color-icon: #696969;\n", "\n", " @media (prefers-color-scheme: dark) {\n", " /* Redefinition of color scheme for dark theme */\n", " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n", " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", " --sklearn-color-icon: #878787;\n", " }\n", "}\n", "\n", "#sk-container-id-9 {\n", " color: var(--sklearn-color-text);\n", "}\n", "\n", "#sk-container-id-9 pre {\n", " padding: 0;\n", "}\n", "\n", "#sk-container-id-9 input.sk-hidden--visually {\n", " border: 0;\n", " clip: rect(1px 1px 1px 1px);\n", " clip: rect(1px, 1px, 1px, 1px);\n", " height: 1px;\n", " margin: -1px;\n", " overflow: hidden;\n", " padding: 0;\n", " position: absolute;\n", " width: 1px;\n", "}\n", "\n", "#sk-container-id-9 div.sk-dashed-wrapped {\n", " border: 1px dashed var(--sklearn-color-line);\n", " margin: 0 0.4em 0.5em 0.4em;\n", " box-sizing: border-box;\n", " padding-bottom: 0.4em;\n", " background-color: var(--sklearn-color-background);\n", "}\n", "\n", "#sk-container-id-9 div.sk-container {\n", " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n", " but bootstrap.min.css set `[hidden] { display: none !important; }`\n", " so we also need the `!important` here to be able to override the\n", " default hidden behavior on the sphinx rendered scikit-learn.org.\n", " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n", " display: inline-block !important;\n", " 
position: relative;\n", "}\n", "\n", "#sk-container-id-9 div.sk-text-repr-fallback {\n", " display: none;\n", "}\n", "\n", "div.sk-parallel-item,\n", "div.sk-serial,\n", "div.sk-item {\n", " /* draw centered vertical line to link estimators */\n", " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n", " background-size: 2px 100%;\n", " background-repeat: no-repeat;\n", " background-position: center center;\n", "}\n", "\n", "/* Parallel-specific style estimator block */\n", "\n", "#sk-container-id-9 div.sk-parallel-item::after {\n", " content: \"\";\n", " width: 100%;\n", " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n", " flex-grow: 1;\n", "}\n", "\n", "#sk-container-id-9 div.sk-parallel {\n", " display: flex;\n", " align-items: stretch;\n", " justify-content: center;\n", " background-color: var(--sklearn-color-background);\n", " position: relative;\n", "}\n", "\n", "#sk-container-id-9 div.sk-parallel-item {\n", " display: flex;\n", " flex-direction: column;\n", "}\n", "\n", "#sk-container-id-9 div.sk-parallel-item:first-child::after {\n", " align-self: flex-end;\n", " width: 50%;\n", "}\n", "\n", "#sk-container-id-9 div.sk-parallel-item:last-child::after {\n", " align-self: flex-start;\n", " width: 50%;\n", "}\n", "\n", "#sk-container-id-9 div.sk-parallel-item:only-child::after {\n", " width: 0;\n", "}\n", "\n", "/* Serial-specific style estimator block */\n", "\n", "#sk-container-id-9 div.sk-serial {\n", " display: flex;\n", " flex-direction: column;\n", " align-items: center;\n", " background-color: var(--sklearn-color-background);\n", " padding-right: 1em;\n", " padding-left: 1em;\n", "}\n", "\n", "\n", "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n", "clickable and can be expanded/collapsed.\n", "- Pipeline and ColumnTransformer use this feature and define the default style\n", "- Estimators will overwrite 
some part of the style using the `sk-estimator` class\n", "*/\n", "\n", "/* Pipeline and ColumnTransformer style (default) */\n", "\n", "#sk-container-id-9 div.sk-toggleable {\n", " /* Default theme specific background. It is overwritten whether we have a\n", " specific estimator or a Pipeline/ColumnTransformer */\n", " background-color: var(--sklearn-color-background);\n", "}\n", "\n", "/* Toggleable label */\n", "#sk-container-id-9 label.sk-toggleable__label {\n", " cursor: pointer;\n", " display: flex;\n", " width: 100%;\n", " margin-bottom: 0;\n", " padding: 0.5em;\n", " box-sizing: border-box;\n", " text-align: center;\n", " align-items: start;\n", " justify-content: space-between;\n", " gap: 0.5em;\n", "}\n", "\n", "#sk-container-id-9 label.sk-toggleable__label .caption {\n", " font-size: 0.6rem;\n", " font-weight: lighter;\n", " color: var(--sklearn-color-text-muted);\n", "}\n", "\n", "#sk-container-id-9 label.sk-toggleable__label-arrow:before {\n", " /* Arrow on the left of the label */\n", " content: \"▸\";\n", " float: left;\n", " margin-right: 0.25em;\n", " color: var(--sklearn-color-icon);\n", "}\n", "\n", "#sk-container-id-9 label.sk-toggleable__label-arrow:hover:before {\n", " color: var(--sklearn-color-text);\n", "}\n", "\n", "/* Toggleable content - dropdown */\n", "\n", "#sk-container-id-9 div.sk-toggleable__content {\n", " display: none;\n", " text-align: left;\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-9 div.sk-toggleable__content.fitted {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "#sk-container-id-9 div.sk-toggleable__content pre {\n", " margin: 0.2em;\n", " border-radius: 0.25em;\n", " color: var(--sklearn-color-text);\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-9 div.sk-toggleable__content.fitted pre {\n", " /* unfitted */\n", " 
background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "#sk-container-id-9 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n", " /* Expand drop-down */\n", " display: block;\n", " width: 100%;\n", " overflow: visible;\n", "}\n", "\n", "#sk-container-id-9 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n", " content: \"▾\";\n", "}\n", "\n", "/* Pipeline/ColumnTransformer-specific style */\n", "\n", "#sk-container-id-9 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-9 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Estimator-specific style */\n", "\n", "/* Colorize estimator box */\n", "#sk-container-id-9 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-9 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "#sk-container-id-9 div.sk-label label.sk-toggleable__label,\n", "#sk-container-id-9 div.sk-label label {\n", " /* The background is the default theme color */\n", " color: var(--sklearn-color-text-on-default-background);\n", "}\n", "\n", "/* On hover, darken the color of the background */\n", "#sk-container-id-9 div.sk-label:hover label.sk-toggleable__label {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "/* Label box, darken color on hover, fitted */\n", "#sk-container-id-9 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n", " color: 
var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Estimator label */\n", "\n", "#sk-container-id-9 div.sk-label label {\n", " font-family: monospace;\n", " font-weight: bold;\n", " display: inline-block;\n", " line-height: 1.2em;\n", "}\n", "\n", "#sk-container-id-9 div.sk-label-container {\n", " text-align: center;\n", "}\n", "\n", "/* Estimator-specific */\n", "#sk-container-id-9 div.sk-estimator {\n", " font-family: monospace;\n", " border: 1px dotted var(--sklearn-color-border-box);\n", " border-radius: 0.25em;\n", " box-sizing: border-box;\n", " margin-bottom: 0.5em;\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-9 div.sk-estimator.fitted {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "/* on hover */\n", "#sk-container-id-9 div.sk-estimator:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-9 div.sk-estimator.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Specification for estimator info (e.g. 
\"i\" and \"?\") */\n", "\n", "/* Common style for \"i\" and \"?\" */\n", "\n", ".sk-estimator-doc-link,\n", "a:link.sk-estimator-doc-link,\n", "a:visited.sk-estimator-doc-link {\n", " float: right;\n", " font-size: smaller;\n", " line-height: 1em;\n", " font-family: monospace;\n", " background-color: var(--sklearn-color-background);\n", " border-radius: 1em;\n", " height: 1em;\n", " width: 1em;\n", " text-decoration: none !important;\n", " margin-left: 0.5em;\n", " text-align: center;\n", " /* unfitted */\n", " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-unfitted-level-1);\n", "}\n", "\n", ".sk-estimator-doc-link.fitted,\n", "a:link.sk-estimator-doc-link.fitted,\n", "a:visited.sk-estimator-doc-link.fitted {\n", " /* fitted */\n", " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-fitted-level-1);\n", "}\n", "\n", "/* On hover */\n", "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n", ".sk-estimator-doc-link:hover,\n", "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n", ".sk-estimator-doc-link:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n", ".sk-estimator-doc-link.fitted:hover,\n", "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n", ".sk-estimator-doc-link.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "/* Span, style for the box shown on hovering the info icon */\n", ".sk-estimator-doc-link span {\n", " display: none;\n", " z-index: 9999;\n", " position: relative;\n", " font-weight: normal;\n", " right: .2ex;\n", " padding: .5ex;\n", " margin: .5ex;\n", " width: min-content;\n", " min-width: 20ex;\n", " 
max-width: 50ex;\n", " color: var(--sklearn-color-text);\n", " box-shadow: 2pt 2pt 4pt #999;\n", " /* unfitted */\n", " background: var(--sklearn-color-unfitted-level-0);\n", " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n", "}\n", "\n", ".sk-estimator-doc-link.fitted span {\n", " /* fitted */\n", " background: var(--sklearn-color-fitted-level-0);\n", " border: var(--sklearn-color-fitted-level-3);\n", "}\n", "\n", ".sk-estimator-doc-link:hover span {\n", " display: block;\n", "}\n", "\n", "/* \"?\"-specific style due to the `` HTML tag */\n", "\n", "#sk-container-id-9 a.estimator_doc_link {\n", " float: right;\n", " font-size: 1rem;\n", " line-height: 1em;\n", " font-family: monospace;\n", " background-color: var(--sklearn-color-background);\n", " border-radius: 1rem;\n", " height: 1rem;\n", " width: 1rem;\n", " text-decoration: none;\n", " /* unfitted */\n", " color: var(--sklearn-color-unfitted-level-1);\n", " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", "}\n", "\n", "#sk-container-id-9 a.estimator_doc_link.fitted {\n", " /* fitted */\n", " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-fitted-level-1);\n", "}\n", "\n", "/* On hover */\n", "#sk-container-id-9 a.estimator_doc_link:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "#sk-container-id-9 a.estimator_doc_link.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-3);\n", "}\n", "\n", ".estimator-table summary {\n", " padding: .5rem;\n", " font-family: monospace;\n", " cursor: pointer;\n", "}\n", "\n", ".estimator-table details[open] {\n", " padding-left: 0.1rem;\n", " padding-right: 0.1rem;\n", " padding-bottom: 0.3rem;\n", "}\n", "\n", ".estimator-table .parameters-table {\n", " margin-left: auto !important;\n", " margin-right: auto !important;\n", "}\n", "\n", 
".estimator-table .parameters-table tr:nth-child(odd) {\n", " background-color: #fff;\n", "}\n", "\n", ".estimator-table .parameters-table tr:nth-child(even) {\n", " background-color: #f6f6f6;\n", "}\n", "\n", ".estimator-table .parameters-table tr:hover {\n", " background-color: #e0e0e0;\n", "}\n", "\n", ".estimator-table table td {\n", " border: 1px solid rgba(106, 105, 104, 0.232);\n", "}\n", "\n", ".user-set td {\n", " color:rgb(255, 94, 0);\n", " text-align: left;\n", "}\n", "\n", ".user-set td.value pre {\n", " color:rgb(255, 94, 0) !important;\n", " background-color: transparent !important;\n", "}\n", "\n", ".default td {\n", " color: black;\n", " text-align: left;\n", "}\n", "\n", ".user-set td i,\n", ".default td i {\n", " color: black;\n", "}\n", "\n", ".copy-paste-icon {\n", " background-image: url();\n", " background-repeat: no-repeat;\n", " background-size: 14px 14px;\n", " background-position: 0;\n", " display: inline-block;\n", " width: 14px;\n", " height: 14px;\n", " cursor: pointer;\n", "}\n", "
Pipeline(steps=[('preprocess',\n",
       "                 ColumnTransformer(transformers=[('extend_features_as_polynomial',\n",
       "                                                  Pipeline(steps=[('extend_features',\n",
       "                                                                   PolynomialFeatures(include_bias=False)),\n",
       "                                                                  ('scale_to_standard',\n",
       "                                                                   StandardScaler())]),\n",
       "                                                  ('selling_price',\n",
       "                                                   'driven_kms')),\n",
       "                                                 ('extend_features_as_spline',\n",
       "                                                  SplineTransformer(include_bias=False,\n",
       "                                                                    knots='quantile',\n",
       "                                                                    n_knots=4),\n",
       "                                                  ('age',)),\n",
       "                                                 ('s...\n",
       "                                                  ('fuel_type', 'selling_type',\n",
       "                                                   'transmission'))])),\n",
       "                ('select_features',\n",
       "                 SequentialFeatureSelector(cv=4,\n",
       "                                           estimator=RandomForestRegressor(max_depth=8,\n",
       "                                                                           max_features='sqrt',\n",
       "                                                                           n_estimators=10),\n",
       "                                           floating=True, k_features=(4, 8),\n",
       "                                           scoring='neg_mean_absolute_percentage_error')),\n",
       "                ('regress',\n",
       "                 RandomForestRegressor(max_depth=10,\n",
       "                                       max_features=0.4752873867901817,\n",
       "                                       n_estimators=78))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('preprocess',\n", " ColumnTransformer(transformers=[('extend_features_as_polynomial',\n", " Pipeline(steps=[('extend_features',\n", " PolynomialFeatures(include_bias=False)),\n", " ('scale_to_standard',\n", " StandardScaler())]),\n", " ('selling_price',\n", " 'driven_kms')),\n", " ('extend_features_as_spline',\n", " SplineTransformer(include_bias=False,\n", " knots='quantile',\n", " n_knots=4),\n", " ('age',)),\n", " ('s...\n", " ('fuel_type', 'selling_type',\n", " 'transmission'))])),\n", " ('select_features',\n", " SequentialFeatureSelector(cv=4,\n", " estimator=RandomForestRegressor(max_depth=8,\n", " max_features='sqrt',\n", " n_estimators=10),\n", " floating=True, k_features=(4, 8),\n", " scoring='neg_mean_absolute_percentage_error')),\n", " ('regress',\n", " RandomForestRegressor(max_depth=10,\n", " max_features=0.4752873867901817,\n", " n_estimators=78))])" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline = build_pipeline_optimized_best()\n", "pipeline" ] }, { "cell_type": "code", "execution_count": 78, "id": "445380bd-a56f-41f6-b148-9e4fee189a09", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'preprocess__remainder': 'drop',\n", " 'preprocess__sparse_threshold': 0.3,\n", " 'preprocess__transformer_weights': None,\n", " 'preprocess__extend_features_as_spline': SplineTransformer(include_bias=False, knots='quantile', n_knots=4),\n", " 'preprocess__extend_features_as_polynomial__extend_features': PolynomialFeatures(include_bias=False),\n", " 'preprocess__extend_features_as_polynomial__extend_features__degree': 2,\n", " 'preprocess__extend_features_as_polynomial__extend_features__include_bias': False,\n", " 'preprocess__extend_features_as_polynomial__extend_features__interaction_only': False,\n", " 'preprocess__extend_features_as_polynomial__extend_features__order': 'C',\n", " 'preprocess__extend_features_as_polynomial__scale_to_standard__with_mean': 
True,\n", " 'preprocess__extend_features_as_polynomial__scale_to_standard__with_std': True,\n", " 'preprocess__extend_features_as_spline__degree': 3,\n", " 'preprocess__extend_features_as_spline__extrapolation': 'constant',\n", " 'preprocess__extend_features_as_spline__include_bias': False,\n", " 'preprocess__extend_features_as_spline__knots': 'quantile',\n", " 'preprocess__extend_features_as_spline__n_knots': 4,\n", " 'preprocess__extend_features_as_spline__order': 'C',\n", " 'preprocess__extend_features_as_spline__sparse_output': False,\n", " 'preprocess__scale_to_standard__with_mean': True,\n", " 'preprocess__scale_to_standard__with_std': True,\n", " 'select_features__cv': 4,\n", " 'select_features__feature_groups': None,\n", " 'select_features__fixed_features': None,\n", " 'select_features__floating': True,\n", " 'select_features__forward': True,\n", " 'select_features__k_features': (4, 8),\n", " 'select_features__scoring': 'neg_mean_absolute_percentage_error',\n", " 'regress__bootstrap': True,\n", " 'regress__ccp_alpha': 0.0,\n", " 'regress__criterion': 'squared_error',\n", " 'regress__max_depth': 10,\n", " 'regress__max_features': 0.4752873867901817,\n", " 'regress__max_leaf_nodes': None,\n", " 'regress__max_samples': None,\n", " 'regress__min_impurity_decrease': 0.0,\n", " 'regress__min_samples_leaf': 1,\n", " 'regress__min_samples_split': 2,\n", " 'regress__min_weight_fraction_leaf': 0.0,\n", " 'regress__monotonic_cst': None,\n", " 'regress__n_estimators': 78,\n", " 'regress__oob_score': False,\n", " 'regress__random_state': None}" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_params = filter_params(\n", " pipeline.get_params(),\n", " include={\n", " 'preprocess': (False, PREPROCESS_AUGMENTING_TRANSFORMER_PARAMS_COMMON_INCLUDE.copy()),\n", " 'select_features': (False, FEATURE_SELECTOR_PARAMS_COMMON_INCLUDE.copy()),\n", " 'regress': (False, True),\n", " },\n", " exclude={\n", " 'preprocess': 
PREPROCESS_AUGMENTING_TRANSFORMER_PARAMS_COMMON_EXCLUDE.copy(),\n", " 'select_features': FEATURE_SELECTOR_PARAMS_COMMON_EXCLUDE,\n", " 'regress': RANDOM_FOREST_REGRESSOR_PARAMS_COMMON_EXCLUDE,\n", " },\n", ")\n", "model_params" ] }, { "cell_type": "markdown", "id": "3f30dacc-3edd-4821-b45b-5dbb06327cbd", "metadata": {}, "source": [ "Обучение модели:" ] }, { "cell_type": "code", "execution_count": 79, "id": "3b4d37f6-e3e0-4dbf-98f4-2993b5e2216d", "metadata": {}, "outputs": [], "source": [ "_ = pipeline.fit(df_orig_features_train, df_target_train.iloc[:, 0])" ] }, { "cell_type": "markdown", "id": "dc586b98-7431-4fa6-848d-fa50c03d4952", "metadata": {}, "source": [ "Оценка качества:" ] }, { "cell_type": "code", "execution_count": 80, "id": "99b16840-f368-4b38-b3d5-e98cfd52ace8", "metadata": {}, "outputs": [], "source": [ "target_test_predicted = pipeline.predict(df_orig_features_test)" ] }, { "cell_type": "markdown", "id": "e4601f93-a431-494f-b047-6bcffb406c90", "metadata": {}, "source": [ "Метрики качества (MAPE, а также MSE, MAE):" ] }, { "cell_type": "code", "execution_count": 81, "id": "29bb2b58-fd88-40d0-9998-8376f72a83fb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'mse': 0.9370236080018509,\n", " 'mae': 0.6048078379366015,\n", " 'mape': 0.19721535277529492}" ] }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ], "source": [ "metrics = score_predictions(df_target_test, target_test_predicted)\n", "metrics" ] }, { "cell_type": "code", "execution_count": 82, "id": "57c13865-8763-41d6-9b1d-3103070be086", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6f4a84b68c834b93bc62c1982114ddea", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading artifacts: 0%| | 0/7 [00:00#sk-container-id-10 {\n", " /* Definition of color scheme common for light and dark mode */\n", " --sklearn-color-text: #000;\n", " --sklearn-color-text-muted: #666;\n", " --sklearn-color-line: 
gray;\n", " /* Definition of color scheme for unfitted estimators */\n", " --sklearn-color-unfitted-level-0: #fff5e6;\n", " --sklearn-color-unfitted-level-1: #f6e4d2;\n", " --sklearn-color-unfitted-level-2: #ffe0b3;\n", " --sklearn-color-unfitted-level-3: chocolate;\n", " /* Definition of color scheme for fitted estimators */\n", " --sklearn-color-fitted-level-0: #f0f8ff;\n", " --sklearn-color-fitted-level-1: #d4ebff;\n", " --sklearn-color-fitted-level-2: #b3dbfd;\n", " --sklearn-color-fitted-level-3: cornflowerblue;\n", "\n", " /* Specific color for light theme */\n", " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n", " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", " --sklearn-color-icon: #696969;\n", "\n", " @media (prefers-color-scheme: dark) {\n", " /* Redefinition of color scheme for dark theme */\n", " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n", " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", " --sklearn-color-icon: #878787;\n", " }\n", "}\n", "\n", "#sk-container-id-10 {\n", " color: var(--sklearn-color-text);\n", "}\n", "\n", "#sk-container-id-10 pre {\n", " padding: 0;\n", "}\n", "\n", "#sk-container-id-10 input.sk-hidden--visually {\n", " border: 0;\n", " clip: rect(1px 1px 1px 1px);\n", " clip: rect(1px, 1px, 1px, 1px);\n", " height: 1px;\n", " margin: -1px;\n", " overflow: hidden;\n", " padding: 0;\n", " position: absolute;\n", " width: 1px;\n", "}\n", "\n", "#sk-container-id-10 
div.sk-dashed-wrapped {\n", " border: 1px dashed var(--sklearn-color-line);\n", " margin: 0 0.4em 0.5em 0.4em;\n", " box-sizing: border-box;\n", " padding-bottom: 0.4em;\n", " background-color: var(--sklearn-color-background);\n", "}\n", "\n", "#sk-container-id-10 div.sk-container {\n", " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n", " but bootstrap.min.css set `[hidden] { display: none !important; }`\n", " so we also need the `!important` here to be able to override the\n", " default hidden behavior on the sphinx rendered scikit-learn.org.\n", " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n", " display: inline-block !important;\n", " position: relative;\n", "}\n", "\n", "#sk-container-id-10 div.sk-text-repr-fallback {\n", " display: none;\n", "}\n", "\n", "div.sk-parallel-item,\n", "div.sk-serial,\n", "div.sk-item {\n", " /* draw centered vertical line to link estimators */\n", " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n", " background-size: 2px 100%;\n", " background-repeat: no-repeat;\n", " background-position: center center;\n", "}\n", "\n", "/* Parallel-specific style estimator block */\n", "\n", "#sk-container-id-10 div.sk-parallel-item::after {\n", " content: \"\";\n", " width: 100%;\n", " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n", " flex-grow: 1;\n", "}\n", "\n", "#sk-container-id-10 div.sk-parallel {\n", " display: flex;\n", " align-items: stretch;\n", " justify-content: center;\n", " background-color: var(--sklearn-color-background);\n", " position: relative;\n", "}\n", "\n", "#sk-container-id-10 div.sk-parallel-item {\n", " display: flex;\n", " flex-direction: column;\n", "}\n", "\n", "#sk-container-id-10 div.sk-parallel-item:first-child::after {\n", " align-self: flex-end;\n", " width: 50%;\n", "}\n", "\n", "#sk-container-id-10 div.sk-parallel-item:last-child::after {\n", " 
align-self: flex-start;\n", " width: 50%;\n", "}\n", "\n", "#sk-container-id-10 div.sk-parallel-item:only-child::after {\n", " width: 0;\n", "}\n", "\n", "/* Serial-specific style estimator block */\n", "\n", "#sk-container-id-10 div.sk-serial {\n", " display: flex;\n", " flex-direction: column;\n", " align-items: center;\n", " background-color: var(--sklearn-color-background);\n", " padding-right: 1em;\n", " padding-left: 1em;\n", "}\n", "\n", "\n", "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n", "clickable and can be expanded/collapsed.\n", "- Pipeline and ColumnTransformer use this feature and define the default style\n", "- Estimators will overwrite some part of the style using the `sk-estimator` class\n", "*/\n", "\n", "/* Pipeline and ColumnTransformer style (default) */\n", "\n", "#sk-container-id-10 div.sk-toggleable {\n", " /* Default theme specific background. It is overwritten whether we have a\n", " specific estimator or a Pipeline/ColumnTransformer */\n", " background-color: var(--sklearn-color-background);\n", "}\n", "\n", "/* Toggleable label */\n", "#sk-container-id-10 label.sk-toggleable__label {\n", " cursor: pointer;\n", " display: flex;\n", " width: 100%;\n", " margin-bottom: 0;\n", " padding: 0.5em;\n", " box-sizing: border-box;\n", " text-align: center;\n", " align-items: start;\n", " justify-content: space-between;\n", " gap: 0.5em;\n", "}\n", "\n", "#sk-container-id-10 label.sk-toggleable__label .caption {\n", " font-size: 0.6rem;\n", " font-weight: lighter;\n", " color: var(--sklearn-color-text-muted);\n", "}\n", "\n", "#sk-container-id-10 label.sk-toggleable__label-arrow:before {\n", " /* Arrow on the left of the label */\n", " content: \"▸\";\n", " float: left;\n", " margin-right: 0.25em;\n", " color: var(--sklearn-color-icon);\n", "}\n", "\n", "#sk-container-id-10 label.sk-toggleable__label-arrow:hover:before {\n", " color: var(--sklearn-color-text);\n", "}\n", "\n", "/* Toggleable content - 
dropdown */\n", "\n", "#sk-container-id-10 div.sk-toggleable__content {\n", " display: none;\n", " text-align: left;\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-10 div.sk-toggleable__content.fitted {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "#sk-container-id-10 div.sk-toggleable__content pre {\n", " margin: 0.2em;\n", " border-radius: 0.25em;\n", " color: var(--sklearn-color-text);\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-10 div.sk-toggleable__content.fitted pre {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "#sk-container-id-10 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n", " /* Expand drop-down */\n", " display: block;\n", " width: 100%;\n", " overflow: visible;\n", "}\n", "\n", "#sk-container-id-10 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n", " content: \"▾\";\n", "}\n", "\n", "/* Pipeline/ColumnTransformer-specific style */\n", "\n", "#sk-container-id-10 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-10 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Estimator-specific style */\n", "\n", "/* Colorize estimator box */\n", "#sk-container-id-10 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-10 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " /* fitted */\n", " 
background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "#sk-container-id-10 div.sk-label label.sk-toggleable__label,\n", "#sk-container-id-10 div.sk-label label {\n", " /* The background is the default theme color */\n", " color: var(--sklearn-color-text-on-default-background);\n", "}\n", "\n", "/* On hover, darken the color of the background */\n", "#sk-container-id-10 div.sk-label:hover label.sk-toggleable__label {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "/* Label box, darken color on hover, fitted */\n", "#sk-container-id-10 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Estimator label */\n", "\n", "#sk-container-id-10 div.sk-label label {\n", " font-family: monospace;\n", " font-weight: bold;\n", " display: inline-block;\n", " line-height: 1.2em;\n", "}\n", "\n", "#sk-container-id-10 div.sk-label-container {\n", " text-align: center;\n", "}\n", "\n", "/* Estimator-specific */\n", "#sk-container-id-10 div.sk-estimator {\n", " font-family: monospace;\n", " border: 1px dotted var(--sklearn-color-border-box);\n", " border-radius: 0.25em;\n", " box-sizing: border-box;\n", " margin-bottom: 0.5em;\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-10 div.sk-estimator.fitted {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "/* on hover */\n", "#sk-container-id-10 div.sk-estimator:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-10 div.sk-estimator.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Specification for estimator info (e.g. 
\"i\" and \"?\") */\n", "\n", "/* Common style for \"i\" and \"?\" */\n", "\n", ".sk-estimator-doc-link,\n", "a:link.sk-estimator-doc-link,\n", "a:visited.sk-estimator-doc-link {\n", " float: right;\n", " font-size: smaller;\n", " line-height: 1em;\n", " font-family: monospace;\n", " background-color: var(--sklearn-color-background);\n", " border-radius: 1em;\n", " height: 1em;\n", " width: 1em;\n", " text-decoration: none !important;\n", " margin-left: 0.5em;\n", " text-align: center;\n", " /* unfitted */\n", " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-unfitted-level-1);\n", "}\n", "\n", ".sk-estimator-doc-link.fitted,\n", "a:link.sk-estimator-doc-link.fitted,\n", "a:visited.sk-estimator-doc-link.fitted {\n", " /* fitted */\n", " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-fitted-level-1);\n", "}\n", "\n", "/* On hover */\n", "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n", ".sk-estimator-doc-link:hover,\n", "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n", ".sk-estimator-doc-link:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n", ".sk-estimator-doc-link.fitted:hover,\n", "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n", ".sk-estimator-doc-link.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "/* Span, style for the box shown on hovering the info icon */\n", ".sk-estimator-doc-link span {\n", " display: none;\n", " z-index: 9999;\n", " position: relative;\n", " font-weight: normal;\n", " right: .2ex;\n", " padding: .5ex;\n", " margin: .5ex;\n", " width: min-content;\n", " min-width: 20ex;\n", " 
max-width: 50ex;\n", " color: var(--sklearn-color-text);\n", " box-shadow: 2pt 2pt 4pt #999;\n", " /* unfitted */\n", " background: var(--sklearn-color-unfitted-level-0);\n", " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n", "}\n", "\n", ".sk-estimator-doc-link.fitted span {\n", " /* fitted */\n", " background: var(--sklearn-color-fitted-level-0);\n", " border: var(--sklearn-color-fitted-level-3);\n", "}\n", "\n", ".sk-estimator-doc-link:hover span {\n", " display: block;\n", "}\n", "\n", "/* \"?\"-specific style due to the `` HTML tag */\n", "\n", "#sk-container-id-10 a.estimator_doc_link {\n", " float: right;\n", " font-size: 1rem;\n", " line-height: 1em;\n", " font-family: monospace;\n", " background-color: var(--sklearn-color-background);\n", " border-radius: 1rem;\n", " height: 1rem;\n", " width: 1rem;\n", " text-decoration: none;\n", " /* unfitted */\n", " color: var(--sklearn-color-unfitted-level-1);\n", " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", "}\n", "\n", "#sk-container-id-10 a.estimator_doc_link.fitted {\n", " /* fitted */\n", " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-fitted-level-1);\n", "}\n", "\n", "/* On hover */\n", "#sk-container-id-10 a.estimator_doc_link:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "#sk-container-id-10 a.estimator_doc_link.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-3);\n", "}\n", "\n", ".estimator-table summary {\n", " padding: .5rem;\n", " font-family: monospace;\n", " cursor: pointer;\n", "}\n", "\n", ".estimator-table details[open] {\n", " padding-left: 0.1rem;\n", " padding-right: 0.1rem;\n", " padding-bottom: 0.3rem;\n", "}\n", "\n", ".estimator-table .parameters-table {\n", " margin-left: auto !important;\n", " margin-right: auto !important;\n", "}\n", "\n", 
".estimator-table .parameters-table tr:nth-child(odd) {\n", " background-color: #fff;\n", "}\n", "\n", ".estimator-table .parameters-table tr:nth-child(even) {\n", " background-color: #f6f6f6;\n", "}\n", "\n", ".estimator-table .parameters-table tr:hover {\n", " background-color: #e0e0e0;\n", "}\n", "\n", ".estimator-table table td {\n", " border: 1px solid rgba(106, 105, 104, 0.232);\n", "}\n", "\n", ".user-set td {\n", " color:rgb(255, 94, 0);\n", " text-align: left;\n", "}\n", "\n", ".user-set td.value pre {\n", " color:rgb(255, 94, 0) !important;\n", " background-color: transparent !important;\n", "}\n", "\n", ".default td {\n", " color: black;\n", " text-align: left;\n", "}\n", "\n", ".user-set td i,\n", ".default td i {\n", " color: black;\n", "}\n", "\n", ".copy-paste-icon {\n", " background-image: url();\n", " background-repeat: no-repeat;\n", " background-size: 14px 14px;\n", " background-position: 0;\n", " display: inline-block;\n", " width: 14px;\n", " height: 14px;\n", " cursor: pointer;\n", "}\n", "
Pipeline(steps=[('preprocess',\n",
       "                 ColumnTransformer(transformers=[('extend_features_as_polynomial',\n",
       "                                                  Pipeline(steps=[('extend_features',\n",
       "                                                                   PolynomialFeatures(include_bias=False)),\n",
       "                                                                  ('scale_to_standard',\n",
       "                                                                   StandardScaler())]),\n",
       "                                                  ('selling_price',\n",
       "                                                   'driven_kms')),\n",
       "                                                 ('extend_features_as_spline',\n",
       "                                                  SplineTransformer(include_bias=False,\n",
       "                                                                    knots='quantile',\n",
       "                                                                    n_knots=4),\n",
       "                                                  ('age',)),\n",
       "                                                 ('s...\n",
       "                                                  ('fuel_type', 'selling_type',\n",
       "                                                   'transmission'))])),\n",
       "                ('select_features',\n",
       "                 SequentialFeatureSelector(cv=4,\n",
       "                                           estimator=RandomForestRegressor(max_depth=8,\n",
       "                                                                           max_features='sqrt',\n",
       "                                                                           n_estimators=10),\n",
       "                                           floating=True, k_features=(4, 8),\n",
       "                                           scoring='neg_mean_absolute_percentage_error')),\n",
       "                ('regress',\n",
       "                 RandomForestRegressor(max_depth=10,\n",
       "                                       max_features=0.4752873867901817,\n",
       "                                       n_estimators=78))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('preprocess',\n", " ColumnTransformer(transformers=[('extend_features_as_polynomial',\n", " Pipeline(steps=[('extend_features',\n", " PolynomialFeatures(include_bias=False)),\n", " ('scale_to_standard',\n", " StandardScaler())]),\n", " ('selling_price',\n", " 'driven_kms')),\n", " ('extend_features_as_spline',\n", " SplineTransformer(include_bias=False,\n", " knots='quantile',\n", " n_knots=4),\n", " ('age',)),\n", " ('s...\n", " ('fuel_type', 'selling_type',\n", " 'transmission'))])),\n", " ('select_features',\n", " SequentialFeatureSelector(cv=4,\n", " estimator=RandomForestRegressor(max_depth=8,\n", " max_features='sqrt',\n", " n_estimators=10),\n", " floating=True, k_features=(4, 8),\n", " scoring='neg_mean_absolute_percentage_error')),\n", " ('regress',\n", " RandomForestRegressor(max_depth=10,\n", " max_features=0.4752873867901817,\n", " n_estimators=78))])" ] }, "execution_count": 83, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline = build_pipeline_optimized_best()\n", "pipeline" ] }, { "cell_type": "code", "execution_count": 84, "id": "02ed0ad8-4068-4007-97a1-ffad1a79839e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'preprocess__remainder': 'drop',\n", " 'preprocess__sparse_threshold': 0.3,\n", " 'preprocess__transformer_weights': None,\n", " 'preprocess__extend_features_as_spline': SplineTransformer(include_bias=False, knots='quantile', n_knots=4),\n", " 'preprocess__extend_features_as_polynomial__extend_features': PolynomialFeatures(include_bias=False),\n", " 'preprocess__extend_features_as_polynomial__extend_features__degree': 2,\n", " 'preprocess__extend_features_as_polynomial__extend_features__include_bias': False,\n", " 'preprocess__extend_features_as_polynomial__extend_features__interaction_only': False,\n", " 'preprocess__extend_features_as_polynomial__extend_features__order': 'C',\n", " 'preprocess__extend_features_as_polynomial__scale_to_standard__with_mean': 
True,\n", " 'preprocess__extend_features_as_polynomial__scale_to_standard__with_std': True,\n", " 'preprocess__extend_features_as_spline__degree': 3,\n", " 'preprocess__extend_features_as_spline__extrapolation': 'constant',\n", " 'preprocess__extend_features_as_spline__include_bias': False,\n", " 'preprocess__extend_features_as_spline__knots': 'quantile',\n", " 'preprocess__extend_features_as_spline__n_knots': 4,\n", " 'preprocess__extend_features_as_spline__order': 'C',\n", " 'preprocess__extend_features_as_spline__sparse_output': False,\n", " 'preprocess__scale_to_standard__with_mean': True,\n", " 'preprocess__scale_to_standard__with_std': True,\n", " 'select_features__cv': 4,\n", " 'select_features__feature_groups': None,\n", " 'select_features__fixed_features': None,\n", " 'select_features__floating': True,\n", " 'select_features__forward': True,\n", " 'select_features__k_features': (4, 8),\n", " 'select_features__scoring': 'neg_mean_absolute_percentage_error',\n", " 'regress__bootstrap': True,\n", " 'regress__ccp_alpha': 0.0,\n", " 'regress__criterion': 'squared_error',\n", " 'regress__max_depth': 10,\n", " 'regress__max_features': 0.4752873867901817,\n", " 'regress__max_leaf_nodes': None,\n", " 'regress__max_samples': None,\n", " 'regress__min_impurity_decrease': 0.0,\n", " 'regress__min_samples_leaf': 1,\n", " 'regress__min_samples_split': 2,\n", " 'regress__min_weight_fraction_leaf': 0.0,\n", " 'regress__monotonic_cst': None,\n", " 'regress__n_estimators': 78,\n", " 'regress__oob_score': False,\n", " 'regress__random_state': None}" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_params = filter_params(\n", " pipeline.get_params(),\n", " include={\n", " 'preprocess': (False, PREPROCESS_AUGMENTING_TRANSFORMER_PARAMS_COMMON_INCLUDE.copy()),\n", " 'select_features': (False, FEATURE_SELECTOR_PARAMS_COMMON_INCLUDE.copy()),\n", " 'regress': (False, True),\n", " },\n", " exclude={\n", " 'preprocess': 
PREPROCESS_AUGMENTING_TRANSFORMER_PARAMS_COMMON_EXCLUDE.copy(),\n", " 'select_features': FEATURE_SELECTOR_PARAMS_COMMON_EXCLUDE,\n", " 'regress': RANDOM_FOREST_REGRESSOR_PARAMS_COMMON_EXCLUDE,\n", " },\n", ")\n", "model_params" ] }, { "cell_type": "code", "execution_count": 85, "id": "5c8b4d0c-f777-4c2a-8263-b2dda900a577", "metadata": {}, "outputs": [], "source": [ "_ = pipeline.fit(df_orig_features, df_target.iloc[:, 0])" ] }, { "cell_type": "code", "execution_count": 86, "id": "e01a1fe1-0e58-418a-9f05-7f4d304da7e5", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "877854c58cbf4e3c959298d0959eea39", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading artifacts: 0%| | 0/7 [00:00