diff --git a/assets/mlflow/requirements.txt b/assets/mlflow/requirements.txt
index 24d3659..f7274ec 100644
--- a/assets/mlflow/requirements.txt
+++ b/assets/mlflow/requirements.txt
@@ -1,4 +1,6 @@
mlflow==2.16
scikit-learn
catboost
-numpy
\ No newline at end of file
+numpy
+mlxtend==0.23.1
+optuna==4.0.0
diff --git a/assets/mlflow/research.ipynb b/assets/mlflow/research.ipynb
index fa97d46..cd10d3c 100644
--- a/assets/mlflow/research.ipynb
+++ b/assets/mlflow/research.ipynb
@@ -731,9 +731,285 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "execution_count": 80,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " num__geo_lat | \n",
+ " num__geo_lon | \n",
+ " num__level | \n",
+ " num__levels | \n",
+ " num__rooms | \n",
+ " num__area | \n",
+ " num__kitchen_area | \n",
+ " cat__region | \n",
+ " cat__building_type | \n",
+ " cat__object_type | \n",
+ " quantile__geo_lat | \n",
+ " quantile__geo_lon | \n",
+ " quantile__level | \n",
+ " quantile__levels | \n",
+ " quantile__rooms | \n",
+ " quantile__area | \n",
+ " quantile__kitchen_area | \n",
+ " poly__1 | \n",
+ " poly__area | \n",
+ " poly__kitchen_area | \n",
+ " poly__area^2 | \n",
+ " poly__area kitchen_area | \n",
+ " poly__kitchen_area^2 | \n",
+ " spline__area_sp_0 | \n",
+ " spline__area_sp_1 | \n",
+ " spline__area_sp_2 | \n",
+ " spline__area_sp_3 | \n",
+ " spline__area_sp_4 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.495902 | \n",
+ " -0.449742 | \n",
+ " 0.359235 | \n",
+ " -0.214789 | \n",
+ " 0.253413 | \n",
+ " 0.063735 | \n",
+ " -0.186285 | \n",
+ " 20.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.766257 | \n",
+ " 0.511028 | \n",
+ " 0.717217 | \n",
+ " 0.536537 | \n",
+ " 0.600601 | \n",
+ " 0.623624 | \n",
+ " 0.374875 | \n",
+ " 0.0 | \n",
+ " 0.063735 | \n",
+ " -0.186285 | \n",
+ " -0.010002 | \n",
+ " -0.132188 | \n",
+ " -0.002792 | \n",
+ " 0.155806 | \n",
+ " 0.666179 | \n",
+ " 0.178013 | \n",
+ " 0.000002 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.177806 | \n",
+ " 1.433673 | \n",
+ " -0.246529 | \n",
+ " -0.367718 | \n",
+ " 0.253413 | \n",
+ " -0.114293 | \n",
+ " -0.186285 | \n",
+ " 70.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.297142 | \n",
+ " 0.867999 | \n",
+ " 0.522022 | \n",
+ " 0.386887 | \n",
+ " 0.600601 | \n",
+ " 0.541542 | \n",
+ " 0.374875 | \n",
+ " 0.0 | \n",
+ " -0.114293 | \n",
+ " -0.186285 | \n",
+ " -0.017375 | \n",
+ " -0.169370 | \n",
+ " -0.002792 | \n",
+ " 0.156921 | \n",
+ " 0.666275 | \n",
+ " 0.176803 | \n",
+ " 0.000001 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 410773 | \n",
+ " -0.748366 | \n",
+ " -0.804077 | \n",
+ " -0.650371 | \n",
+ " 0.702788 | \n",
+ " 0.253413 | \n",
+ " 1.365441 | \n",
+ " 1.501833 | \n",
+ " 52.0 | \n",
+ " 3.0 | \n",
+ " 0.0 | \n",
+ " 0.193143 | \n",
+ " 0.114753 | \n",
+ " 0.309810 | \n",
+ " 0.741742 | \n",
+ " 0.600601 | \n",
+ " 0.961367 | \n",
+ " 0.984535 | \n",
+ " 0.0 | \n",
+ " 1.365441 | \n",
+ " 1.501833 | \n",
+ " 0.068438 | \n",
+ " 1.570163 | \n",
+ " 0.008616 | \n",
+ " 0.147820 | \n",
+ " 0.665159 | \n",
+ " 0.187011 | \n",
+ " 0.000010 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 410774 | \n",
+ " 1.257769 | \n",
+ " -1.101815 | \n",
+ " -0.044608 | \n",
+ " 0.091070 | \n",
+ " 1.175911 | \n",
+ " 0.553789 | \n",
+ " -0.142544 | \n",
+ " 14.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.908036 | \n",
+ " 0.075725 | \n",
+ " 0.604605 | \n",
+ " 0.645646 | \n",
+ " 0.867367 | \n",
+ " 0.841842 | \n",
+ " 0.436436 | \n",
+ " 0.0 | \n",
+ " 0.553789 | \n",
+ " -0.142544 | \n",
+ " 0.014463 | \n",
+ " -0.002742 | \n",
+ " -0.002649 | \n",
+ " 0.152767 | \n",
+ " 0.665860 | \n",
+ " 0.181370 | \n",
+ " 0.000004 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
410775 rows × 28 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " num__geo_lat num__geo_lon num__level num__levels num__rooms \\\n",
+ "0 0.495902 -0.449742 0.359235 -0.214789 0.253413 \n",
+ "1 0.177806 1.433673 -0.246529 -0.367718 0.253413 \n",
+ "... ... ... ... ... ... \n",
+ "410773 -0.748366 -0.804077 -0.650371 0.702788 0.253413 \n",
+ "410774 1.257769 -1.101815 -0.044608 0.091070 1.175911 \n",
+ "\n",
+ " num__area num__kitchen_area cat__region cat__building_type \\\n",
+ "0 0.063735 -0.186285 20.0 1.0 \n",
+ "1 -0.114293 -0.186285 70.0 1.0 \n",
+ "... ... ... ... ... \n",
+ "410773 1.365441 1.501833 52.0 3.0 \n",
+ "410774 0.553789 -0.142544 14.0 1.0 \n",
+ "\n",
+ " cat__object_type quantile__geo_lat quantile__geo_lon \\\n",
+ "0 0.0 0.766257 0.511028 \n",
+ "1 0.0 0.297142 0.867999 \n",
+ "... ... ... ... \n",
+ "410773 0.0 0.193143 0.114753 \n",
+ "410774 0.0 0.908036 0.075725 \n",
+ "\n",
+ " quantile__level quantile__levels quantile__rooms quantile__area \\\n",
+ "0 0.717217 0.536537 0.600601 0.623624 \n",
+ "1 0.522022 0.386887 0.600601 0.541542 \n",
+ "... ... ... ... ... \n",
+ "410773 0.309810 0.741742 0.600601 0.961367 \n",
+ "410774 0.604605 0.645646 0.867367 0.841842 \n",
+ "\n",
+ " quantile__kitchen_area poly__1 poly__area poly__kitchen_area \\\n",
+ "0 0.374875 0.0 0.063735 -0.186285 \n",
+ "1 0.374875 0.0 -0.114293 -0.186285 \n",
+ "... ... ... ... ... \n",
+ "410773 0.984535 0.0 1.365441 1.501833 \n",
+ "410774 0.436436 0.0 0.553789 -0.142544 \n",
+ "\n",
+ " poly__area^2 poly__area kitchen_area poly__kitchen_area^2 \\\n",
+ "0 -0.010002 -0.132188 -0.002792 \n",
+ "1 -0.017375 -0.169370 -0.002792 \n",
+ "... ... ... ... \n",
+ "410773 0.068438 1.570163 0.008616 \n",
+ "410774 0.014463 -0.002742 -0.002649 \n",
+ "\n",
+ " spline__area_sp_0 spline__area_sp_1 spline__area_sp_2 \\\n",
+ "0 0.155806 0.666179 0.178013 \n",
+ "1 0.156921 0.666275 0.176803 \n",
+ "... ... ... ... \n",
+ "410773 0.147820 0.665159 0.187011 \n",
+ "410774 0.152767 0.665860 0.181370 \n",
+ "\n",
+ " spline__area_sp_3 spline__area_sp_4 \n",
+ "0 0.000002 0.0 \n",
+ "1 0.000001 0.0 \n",
+ "... ... ... \n",
+ "410773 0.000010 0.0 \n",
+ "410774 0.000004 0.0 \n",
+ "\n",
+ "[410775 rows x 28 columns]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
"source": [
"# Удобно использовать для отображения всех строк\\столбцов в DataFrame\n",
"with pd.option_context('display.max_rows', 5, 'display.max_columns', None):\n",
@@ -914,9 +1190,368 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "execution_count": 81,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " num__geo_lat | \n",
+ " num__geo_lon | \n",
+ " num__level | \n",
+ " num__levels | \n",
+ " num__rooms | \n",
+ " num__area | \n",
+ " num__kitchen_area | \n",
+ " cat__region | \n",
+ " cat__building_type | \n",
+ " cat__object_type | \n",
+ " afr__geo_lat | \n",
+ " afr__geo_lon | \n",
+ " afr__level | \n",
+ " afr__levels | \n",
+ " afr__rooms | \n",
+ " afr__area | \n",
+ " afr__kitchen_area | \n",
+ " afr__area*rooms | \n",
+ " afr__area*geo_lon | \n",
+ " afr__levels*rooms | \n",
+ " afr__area*kitchen_area | \n",
+ " afr__sqrt(area)*geo_lat | \n",
+ " afr__sqrt(area)*log(level) | \n",
+ " afr__kitchen_area*log(level) | \n",
+ " afr__sqrt(area)*kitchen_area | \n",
+ " afr__geo_lon*log(kitchen_area) | \n",
+ " afr__sqrt(area)*sqrt(kitchen_area) | \n",
+ " afr__sqrt(geo_lon)*sqrt(kitchen_area) | \n",
+ " afr__log(area) | \n",
+ " afr__rooms*log(level) | \n",
+ " afr__kitchen_area*rooms | \n",
+ " afr__kitchen_area*levels | \n",
+ " afr__sqrt(geo_lon)*sqrt(level) | \n",
+ " afr__area**(3/2) | \n",
+ " afr__geo_lat*log(kitchen_area) | \n",
+ " afr__geo_lat*log(geo_lon) | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.495902 | \n",
+ " -0.449742 | \n",
+ " 0.359235 | \n",
+ " -0.214789 | \n",
+ " 0.253413 | \n",
+ " 0.063735 | \n",
+ " -0.186285 | \n",
+ " 20.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.495902 | \n",
+ " -0.449742 | \n",
+ " 0.359235 | \n",
+ " -0.214789 | \n",
+ " 0.253413 | \n",
+ " 0.063735 | \n",
+ " -0.186285 | \n",
+ " 0.006208 | \n",
+ " -0.195129 | \n",
+ " 0.060916 | \n",
+ " -0.132188 | \n",
+ " 0.373151 | \n",
+ " 0.688076 | \n",
+ " 0.044178 | \n",
+ " -0.211335 | \n",
+ " -0.481294 | \n",
+ " -0.153548 | \n",
+ " -0.490805 | \n",
+ " 0.307835 | \n",
+ " 0.690329 | \n",
+ " -0.132529 | \n",
+ " -0.352834 | \n",
+ " 0.323880 | \n",
+ " -0.008748 | \n",
+ " -0.031529 | \n",
+ " 0.068167 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.177806 | \n",
+ " 1.433673 | \n",
+ " -0.246529 | \n",
+ " -0.367718 | \n",
+ " 0.253413 | \n",
+ " -0.114293 | \n",
+ " -0.186285 | \n",
+ " 70.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.177806 | \n",
+ " 1.433673 | \n",
+ " -0.246529 | \n",
+ " -0.367718 | \n",
+ " 0.253413 | \n",
+ " -0.114293 | \n",
+ " -0.186285 | \n",
+ " -0.083402 | \n",
+ " 0.655053 | \n",
+ " -0.054279 | \n",
+ " -0.169370 | \n",
+ " 0.005114 | \n",
+ " 0.071369 | \n",
+ " -0.173647 | \n",
+ " -0.252775 | \n",
+ " 1.191304 | \n",
+ " -0.267268 | \n",
+ " 0.615798 | \n",
+ " 0.031907 | \n",
+ " 0.282625 | \n",
+ " -0.132529 | \n",
+ " -0.418643 | \n",
+ " 0.552794 | \n",
+ " -0.056540 | \n",
+ " -0.143829 | \n",
+ " 1.129118 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 410773 | \n",
+ " -0.748366 | \n",
+ " -0.804077 | \n",
+ " -0.650371 | \n",
+ " 0.702788 | \n",
+ " 0.253413 | \n",
+ " 1.365441 | \n",
+ " 1.501833 | \n",
+ " 52.0 | \n",
+ " 3.0 | \n",
+ " 0.0 | \n",
+ " -0.748366 | \n",
+ " -0.804077 | \n",
+ " -0.650371 | \n",
+ " 0.702788 | \n",
+ " 0.253413 | \n",
+ " 1.365441 | \n",
+ " 1.501833 | \n",
+ " 0.661427 | \n",
+ " 0.375199 | \n",
+ " 0.752088 | \n",
+ " 1.570163 | \n",
+ " 1.274445 | \n",
+ " -0.002521 | \n",
+ " 0.745507 | \n",
+ " 2.382258 | \n",
+ " 0.071599 | \n",
+ " 2.828890 | \n",
+ " 1.431272 | \n",
+ " 1.729715 | \n",
+ " -0.160491 | \n",
+ " 1.581436 | \n",
+ " 2.432437 | \n",
+ " -0.843150 | \n",
+ " 0.411475 | \n",
+ " 1.671069 | \n",
+ " -1.052343 | \n",
+ "
\n",
+ " \n",
+ " 410774 | \n",
+ " 1.257769 | \n",
+ " -1.101815 | \n",
+ " -0.044608 | \n",
+ " 0.091070 | \n",
+ " 1.175911 | \n",
+ " 0.553789 | \n",
+ " -0.142544 | \n",
+ " 14.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.257769 | \n",
+ " -1.101815 | \n",
+ " -0.044608 | \n",
+ " 0.091070 | \n",
+ " 1.175911 | \n",
+ " 0.553789 | \n",
+ " -0.142544 | \n",
+ " 0.807887 | \n",
+ " -0.330070 | \n",
+ " 0.982478 | \n",
+ " -0.002742 | \n",
+ " 1.338996 | \n",
+ " 0.635065 | \n",
+ " -0.040302 | \n",
+ " -0.055435 | \n",
+ " -1.025588 | \n",
+ " 0.202136 | \n",
+ " -0.916054 | \n",
+ " 0.940624 | \n",
+ " 1.217910 | \n",
+ " 0.311575 | \n",
+ " -0.174762 | \n",
+ " -0.415359 | \n",
+ " 0.135617 | \n",
+ " 0.359680 | \n",
+ " -0.246790 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
410775 rows × 36 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " num__geo_lat num__geo_lon num__level num__levels num__rooms \\\n",
+ "0 0.495902 -0.449742 0.359235 -0.214789 0.253413 \n",
+ "1 0.177806 1.433673 -0.246529 -0.367718 0.253413 \n",
+ "... ... ... ... ... ... \n",
+ "410773 -0.748366 -0.804077 -0.650371 0.702788 0.253413 \n",
+ "410774 1.257769 -1.101815 -0.044608 0.091070 1.175911 \n",
+ "\n",
+ " num__area num__kitchen_area cat__region cat__building_type \\\n",
+ "0 0.063735 -0.186285 20.0 1.0 \n",
+ "1 -0.114293 -0.186285 70.0 1.0 \n",
+ "... ... ... ... ... \n",
+ "410773 1.365441 1.501833 52.0 3.0 \n",
+ "410774 0.553789 -0.142544 14.0 1.0 \n",
+ "\n",
+ " cat__object_type afr__geo_lat afr__geo_lon afr__level afr__levels \\\n",
+ "0 0.0 0.495902 -0.449742 0.359235 -0.214789 \n",
+ "1 0.0 0.177806 1.433673 -0.246529 -0.367718 \n",
+ "... ... ... ... ... ... \n",
+ "410773 0.0 -0.748366 -0.804077 -0.650371 0.702788 \n",
+ "410774 0.0 1.257769 -1.101815 -0.044608 0.091070 \n",
+ "\n",
+ " afr__rooms afr__area afr__kitchen_area afr__area*rooms \\\n",
+ "0 0.253413 0.063735 -0.186285 0.006208 \n",
+ "1 0.253413 -0.114293 -0.186285 -0.083402 \n",
+ "... ... ... ... ... \n",
+ "410773 0.253413 1.365441 1.501833 0.661427 \n",
+ "410774 1.175911 0.553789 -0.142544 0.807887 \n",
+ "\n",
+ " afr__area*geo_lon afr__levels*rooms afr__area*kitchen_area \\\n",
+ "0 -0.195129 0.060916 -0.132188 \n",
+ "1 0.655053 -0.054279 -0.169370 \n",
+ "... ... ... ... \n",
+ "410773 0.375199 0.752088 1.570163 \n",
+ "410774 -0.330070 0.982478 -0.002742 \n",
+ "\n",
+ " afr__sqrt(area)*geo_lat afr__sqrt(area)*log(level) \\\n",
+ "0 0.373151 0.688076 \n",
+ "1 0.005114 0.071369 \n",
+ "... ... ... \n",
+ "410773 1.274445 -0.002521 \n",
+ "410774 1.338996 0.635065 \n",
+ "\n",
+ " afr__kitchen_area*log(level) afr__sqrt(area)*kitchen_area \\\n",
+ "0 0.044178 -0.211335 \n",
+ "1 -0.173647 -0.252775 \n",
+ "... ... ... \n",
+ "410773 0.745507 2.382258 \n",
+ "410774 -0.040302 -0.055435 \n",
+ "\n",
+ " afr__geo_lon*log(kitchen_area) afr__sqrt(area)*sqrt(kitchen_area) \\\n",
+ "0 -0.481294 -0.153548 \n",
+ "1 1.191304 -0.267268 \n",
+ "... ... ... \n",
+ "410773 0.071599 2.828890 \n",
+ "410774 -1.025588 0.202136 \n",
+ "\n",
+ " afr__sqrt(geo_lon)*sqrt(kitchen_area) afr__log(area) \\\n",
+ "0 -0.490805 0.307835 \n",
+ "1 0.615798 0.031907 \n",
+ "... ... ... \n",
+ "410773 1.431272 1.729715 \n",
+ "410774 -0.916054 0.940624 \n",
+ "\n",
+ " afr__rooms*log(level) afr__kitchen_area*rooms \\\n",
+ "0 0.690329 -0.132529 \n",
+ "1 0.282625 -0.132529 \n",
+ "... ... ... \n",
+ "410773 -0.160491 1.581436 \n",
+ "410774 1.217910 0.311575 \n",
+ "\n",
+ " afr__kitchen_area*levels afr__sqrt(geo_lon)*sqrt(level) \\\n",
+ "0 -0.352834 0.323880 \n",
+ "1 -0.418643 0.552794 \n",
+ "... ... ... \n",
+ "410773 2.432437 -0.843150 \n",
+ "410774 -0.174762 -0.415359 \n",
+ "\n",
+ " afr__area**(3/2) afr__geo_lat*log(kitchen_area) \\\n",
+ "0 -0.008748 -0.031529 \n",
+ "1 -0.056540 -0.143829 \n",
+ "... ... ... \n",
+ "410773 0.411475 1.671069 \n",
+ "410774 0.135617 0.359680 \n",
+ "\n",
+ " afr__geo_lat*log(geo_lon) \n",
+ "0 0.068167 \n",
+ "1 1.129118 \n",
+ "... ... \n",
+ "410773 -1.052343 \n",
+ "410774 -0.246790 \n",
+ "\n",
+ "[410775 rows x 36 columns]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
"source": [
"with pd.option_context('display.max_rows', 5, 'display.max_columns', None):\n",
" display (X_train_afr)\n"
@@ -983,6 +1618,1679 @@
"assert (run.info.status =='FINISHED')"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# FEATURE SELECTION\n",
+ "## RFE\n",
+ "### Используем autofeat признаки\n",
+ "Поскольку autofeat даёт разные совокупности сгенерированных признаков, отбор информативных признаков можно добавить только как шаг пайплайна"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 294,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " num__geo_lat | \n",
+ " num__geo_lon | \n",
+ " num__level | \n",
+ " num__levels | \n",
+ " num__rooms | \n",
+ " num__area | \n",
+ " num__kitchen_area | \n",
+ " cat__region | \n",
+ " cat__building_type | \n",
+ " cat__object_type | \n",
+ " ... | \n",
+ " afr__sqrt(area)*sqrt(kitchen_area) | \n",
+ " afr__sqrt(geo_lon)*sqrt(kitchen_area) | \n",
+ " afr__log(area) | \n",
+ " afr__rooms*log(level) | \n",
+ " afr__kitchen_area*rooms | \n",
+ " afr__kitchen_area*levels | \n",
+ " afr__sqrt(geo_lon)*sqrt(level) | \n",
+ " afr__area**(3/2) | \n",
+ " afr__geo_lat*log(kitchen_area) | \n",
+ " afr__geo_lat*log(geo_lon) | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.495902 | \n",
+ " -0.449742 | \n",
+ " 0.359235 | \n",
+ " -0.214789 | \n",
+ " 0.253413 | \n",
+ " 0.063735 | \n",
+ " -0.186285 | \n",
+ " 20.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " -0.153548 | \n",
+ " -0.490805 | \n",
+ " 0.307835 | \n",
+ " 0.690329 | \n",
+ " -0.132529 | \n",
+ " -0.352834 | \n",
+ " 0.323880 | \n",
+ " -0.008748 | \n",
+ " -0.031529 | \n",
+ " 0.068167 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.177806 | \n",
+ " 1.433673 | \n",
+ " -0.246529 | \n",
+ " -0.367718 | \n",
+ " 0.253413 | \n",
+ " -0.114293 | \n",
+ " -0.186285 | \n",
+ " 70.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " -0.267268 | \n",
+ " 0.615798 | \n",
+ " 0.031907 | \n",
+ " 0.282625 | \n",
+ " -0.132529 | \n",
+ " -0.418643 | \n",
+ " 0.552794 | \n",
+ " -0.056540 | \n",
+ " -0.143829 | \n",
+ " 1.129118 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.440548 | \n",
+ " 0.047222 | \n",
+ " -0.448450 | \n",
+ " -0.367718 | \n",
+ " -0.669085 | \n",
+ " -0.456947 | \n",
+ " -0.142544 | \n",
+ " 15.0 | \n",
+ " 3.0 | \n",
+ " 1.0 | \n",
+ " ... | \n",
+ " -0.454880 | \n",
+ " -0.067183 | \n",
+ " -0.603122 | \n",
+ " -0.512211 | \n",
+ " -0.487813 | \n",
+ " -0.383803 | \n",
+ " -0.243092 | \n",
+ " -0.140800 | \n",
+ " 0.063464 | \n",
+ " 0.460495 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " -1.588818 | \n",
+ " -0.722477 | \n",
+ " -0.246529 | \n",
+ " -0.979436 | \n",
+ " 0.253413 | \n",
+ " -0.181292 | \n",
+ " -0.142544 | \n",
+ " 18.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " -0.254514 | \n",
+ " -0.607607 | \n",
+ " -0.080304 | \n",
+ " 0.282625 | \n",
+ " -0.088119 | \n",
+ " -0.662523 | \n",
+ " -0.369355 | \n",
+ " -0.073838 | \n",
+ " -0.672113 | \n",
+ " -1.481033 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1.493662 | \n",
+ " 1.125819 | \n",
+ " 0.157313 | \n",
+ " 0.549858 | \n",
+ " 0.253413 | \n",
+ " 0.615045 | \n",
+ " -0.011322 | \n",
+ " 10.0 | \n",
+ " 2.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.438600 | \n",
+ " 0.891383 | \n",
+ " 1.009612 | \n",
+ " 0.574497 | \n",
+ " 0.045112 | \n",
+ " 0.208478 | \n",
+ " 0.945981 | \n",
+ " 0.154902 | \n",
+ " 0.780855 | \n",
+ " 1.923382 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 410770 | \n",
+ " 0.592011 | \n",
+ " 0.355014 | \n",
+ " 0.561156 | \n",
+ " 1.008646 | \n",
+ " 0.253413 | \n",
+ " -0.079836 | \n",
+ " -0.092653 | \n",
+ " 54.0 | \n",
+ " 2.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " -0.120035 | \n",
+ " 0.237580 | \n",
+ " 0.087725 | \n",
+ " 0.792500 | \n",
+ " -0.037463 | \n",
+ " 0.322797 | \n",
+ " 0.974381 | \n",
+ " -0.047496 | \n",
+ " 0.243018 | \n",
+ " 0.789871 | \n",
+ "
\n",
+ " \n",
+ " 410771 | \n",
+ " 0.240478 | \n",
+ " 0.392697 | \n",
+ " -0.650371 | \n",
+ " -0.979436 | \n",
+ " 0.253413 | \n",
+ " -0.334434 | \n",
+ " -0.404989 | \n",
+ " 45.0 | \n",
+ " 3.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " -0.716150 | \n",
+ " -0.510766 | \n",
+ " -0.357277 | \n",
+ " -0.160491 | \n",
+ " -0.354582 | \n",
+ " -0.778657 | \n",
+ " -0.406361 | \n",
+ " -0.111897 | \n",
+ " -0.808157 | \n",
+ " 0.574534 | \n",
+ "
\n",
+ " \n",
+ " 410772 | \n",
+ " -1.936771 | \n",
+ " -0.688830 | \n",
+ " 0.359235 | \n",
+ " 0.855717 | \n",
+ " -0.669085 | \n",
+ " -0.456947 | \n",
+ " -0.142544 | \n",
+ " 18.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " ... | \n",
+ " -0.454880 | \n",
+ " -0.581851 | \n",
+ " -0.603122 | \n",
+ " -0.211576 | \n",
+ " -0.487813 | \n",
+ " 0.173638 | \n",
+ " 0.170166 | \n",
+ " -0.140800 | \n",
+ " -0.798234 | \n",
+ " -1.663294 | \n",
+ "
\n",
+ " \n",
+ " 410773 | \n",
+ " -0.748366 | \n",
+ " -0.804077 | \n",
+ " -0.650371 | \n",
+ " 0.702788 | \n",
+ " 0.253413 | \n",
+ " 1.365441 | \n",
+ " 1.501833 | \n",
+ " 52.0 | \n",
+ " 3.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 2.828890 | \n",
+ " 1.431272 | \n",
+ " 1.729715 | \n",
+ " -0.160491 | \n",
+ " 1.581436 | \n",
+ " 2.432437 | \n",
+ " -0.843150 | \n",
+ " 0.411475 | \n",
+ " 1.671069 | \n",
+ " -1.052343 | \n",
+ "
\n",
+ " \n",
+ " 410774 | \n",
+ " 1.257769 | \n",
+ " -1.101815 | \n",
+ " -0.044608 | \n",
+ " 0.091070 | \n",
+ " 1.175911 | \n",
+ " 0.553789 | \n",
+ " -0.142544 | \n",
+ " 14.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.202136 | \n",
+ " -0.916054 | \n",
+ " 0.940624 | \n",
+ " 1.217910 | \n",
+ " 0.311575 | \n",
+ " -0.174762 | \n",
+ " -0.415359 | \n",
+ " 0.135617 | \n",
+ " 0.359680 | \n",
+ " -0.246790 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
410775 rows × 36 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " num__geo_lat num__geo_lon num__level num__levels num__rooms \\\n",
+ "0 0.495902 -0.449742 0.359235 -0.214789 0.253413 \n",
+ "1 0.177806 1.433673 -0.246529 -0.367718 0.253413 \n",
+ "2 0.440548 0.047222 -0.448450 -0.367718 -0.669085 \n",
+ "3 -1.588818 -0.722477 -0.246529 -0.979436 0.253413 \n",
+ "4 1.493662 1.125819 0.157313 0.549858 0.253413 \n",
+ "... ... ... ... ... ... \n",
+ "410770 0.592011 0.355014 0.561156 1.008646 0.253413 \n",
+ "410771 0.240478 0.392697 -0.650371 -0.979436 0.253413 \n",
+ "410772 -1.936771 -0.688830 0.359235 0.855717 -0.669085 \n",
+ "410773 -0.748366 -0.804077 -0.650371 0.702788 0.253413 \n",
+ "410774 1.257769 -1.101815 -0.044608 0.091070 1.175911 \n",
+ "\n",
+ " num__area num__kitchen_area cat__region cat__building_type \\\n",
+ "0 0.063735 -0.186285 20.0 1.0 \n",
+ "1 -0.114293 -0.186285 70.0 1.0 \n",
+ "2 -0.456947 -0.142544 15.0 3.0 \n",
+ "3 -0.181292 -0.142544 18.0 1.0 \n",
+ "4 0.615045 -0.011322 10.0 2.0 \n",
+ "... ... ... ... ... \n",
+ "410770 -0.079836 -0.092653 54.0 2.0 \n",
+ "410771 -0.334434 -0.404989 45.0 3.0 \n",
+ "410772 -0.456947 -0.142544 18.0 0.0 \n",
+ "410773 1.365441 1.501833 52.0 3.0 \n",
+ "410774 0.553789 -0.142544 14.0 1.0 \n",
+ "\n",
+ " cat__object_type ... afr__sqrt(area)*sqrt(kitchen_area) \\\n",
+ "0 0.0 ... -0.153548 \n",
+ "1 0.0 ... -0.267268 \n",
+ "2 1.0 ... -0.454880 \n",
+ "3 0.0 ... -0.254514 \n",
+ "4 0.0 ... 0.438600 \n",
+ "... ... ... ... \n",
+ "410770 0.0 ... -0.120035 \n",
+ "410771 0.0 ... -0.716150 \n",
+ "410772 1.0 ... -0.454880 \n",
+ "410773 0.0 ... 2.828890 \n",
+ "410774 0.0 ... 0.202136 \n",
+ "\n",
+ " afr__sqrt(geo_lon)*sqrt(kitchen_area) afr__log(area) \\\n",
+ "0 -0.490805 0.307835 \n",
+ "1 0.615798 0.031907 \n",
+ "2 -0.067183 -0.603122 \n",
+ "3 -0.607607 -0.080304 \n",
+ "4 0.891383 1.009612 \n",
+ "... ... ... \n",
+ "410770 0.237580 0.087725 \n",
+ "410771 -0.510766 -0.357277 \n",
+ "410772 -0.581851 -0.603122 \n",
+ "410773 1.431272 1.729715 \n",
+ "410774 -0.916054 0.940624 \n",
+ "\n",
+ " afr__rooms*log(level) afr__kitchen_area*rooms \\\n",
+ "0 0.690329 -0.132529 \n",
+ "1 0.282625 -0.132529 \n",
+ "2 -0.512211 -0.487813 \n",
+ "3 0.282625 -0.088119 \n",
+ "4 0.574497 0.045112 \n",
+ "... ... ... \n",
+ "410770 0.792500 -0.037463 \n",
+ "410771 -0.160491 -0.354582 \n",
+ "410772 -0.211576 -0.487813 \n",
+ "410773 -0.160491 1.581436 \n",
+ "410774 1.217910 0.311575 \n",
+ "\n",
+ " afr__kitchen_area*levels afr__sqrt(geo_lon)*sqrt(level) \\\n",
+ "0 -0.352834 0.323880 \n",
+ "1 -0.418643 0.552794 \n",
+ "2 -0.383803 -0.243092 \n",
+ "3 -0.662523 -0.369355 \n",
+ "4 0.208478 0.945981 \n",
+ "... ... ... \n",
+ "410770 0.322797 0.974381 \n",
+ "410771 -0.778657 -0.406361 \n",
+ "410772 0.173638 0.170166 \n",
+ "410773 2.432437 -0.843150 \n",
+ "410774 -0.174762 -0.415359 \n",
+ "\n",
+ " afr__area**(3/2) afr__geo_lat*log(kitchen_area) \\\n",
+ "0 -0.008748 -0.031529 \n",
+ "1 -0.056540 -0.143829 \n",
+ "2 -0.140800 0.063464 \n",
+ "3 -0.073838 -0.672113 \n",
+ "4 0.154902 0.780855 \n",
+ "... ... ... \n",
+ "410770 -0.047496 0.243018 \n",
+ "410771 -0.111897 -0.808157 \n",
+ "410772 -0.140800 -0.798234 \n",
+ "410773 0.411475 1.671069 \n",
+ "410774 0.135617 0.359680 \n",
+ "\n",
+ " afr__geo_lat*log(geo_lon) \n",
+ "0 0.068167 \n",
+ "1 1.129118 \n",
+ "2 0.460495 \n",
+ "3 -1.481033 \n",
+ "4 1.923382 \n",
+ "... ... \n",
+ "410770 0.789871 \n",
+ "410771 0.574534 \n",
+ "410772 -1.663294 \n",
+ "410773 -1.052343 \n",
+ "410774 -0.246790 \n",
+ "\n",
+ "[410775 rows x 36 columns]"
+ ]
+ },
+ "execution_count": 294,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from sklearn.feature_selection import RFE\n",
+ "X_train_afr"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Recursive feature elimination over the autofeat feature matrix:\n",
+ "# keep 12 features, dropping 20% of the features at each iteration.\n",
+ "rfe_selector = RFE(\n",
+ "    estimator=regressor,\n",
+ "    n_features_to_select=12,\n",
+ "    step=0.2,\n",
+ ")\n",
+ "X_train_rfe = rfe_selector.fit_transform(X_train_afr, y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 297,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " num__geo_lat | \n",
+ " num__geo_lon | \n",
+ " afr__geo_lon | \n",
+ " afr__area*kitchen_area | \n",
+ " afr__sqrt(area)*geo_lat | \n",
+ " afr__sqrt(area)*log(level) | \n",
+ " afr__kitchen_area*log(level) | \n",
+ " afr__sqrt(area)*sqrt(kitchen_area) | \n",
+ " afr__rooms*log(level) | \n",
+ " afr__kitchen_area*rooms | \n",
+ " afr__sqrt(geo_lon)*sqrt(level) | \n",
+ " afr__geo_lat*log(geo_lon) | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.495902 | \n",
+ " -0.449742 | \n",
+ " -0.449742 | \n",
+ " -0.132188 | \n",
+ " 0.373151 | \n",
+ " 0.688076 | \n",
+ " 0.044178 | \n",
+ " -0.153548 | \n",
+ " 0.690329 | \n",
+ " -0.132529 | \n",
+ " 0.323880 | \n",
+ " 0.068167 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.177806 | \n",
+ " 1.433673 | \n",
+ " 1.433673 | \n",
+ " -0.169370 | \n",
+ " 0.005114 | \n",
+ " 0.071369 | \n",
+ " -0.173647 | \n",
+ " -0.267268 | \n",
+ " 0.282625 | \n",
+ " -0.132529 | \n",
+ " 0.552794 | \n",
+ " 1.129118 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.440548 | \n",
+ " 0.047222 | \n",
+ " 0.047222 | \n",
+ " -0.226261 | \n",
+ " -0.425530 | \n",
+ " -0.335537 | \n",
+ " -0.239271 | \n",
+ " -0.454880 | \n",
+ " -0.512211 | \n",
+ " -0.487813 | \n",
+ " -0.243092 | \n",
+ " 0.460495 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " -1.588818 | \n",
+ " -0.722477 | \n",
+ " -0.722477 | \n",
+ " -0.165302 | \n",
+ " -0.723225 | \n",
+ " 0.034116 | \n",
+ " -0.129771 | \n",
+ " -0.254514 | \n",
+ " 0.282625 | \n",
+ " -0.088119 | \n",
+ " -0.369355 | \n",
+ " -1.481033 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1.493662 | \n",
+ " 1.125819 | \n",
+ " 1.125819 | \n",
+ " 0.094342 | \n",
+ " 1.522265 | \n",
+ " 0.862773 | \n",
+ " 0.194490 | \n",
+ " 0.438600 | \n",
+ " 0.574497 | \n",
+ " 0.045112 | \n",
+ " 0.945981 | \n",
+ " 1.923382 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 410770 | \n",
+ " 0.592011 | \n",
+ " 0.355014 | \n",
+ " 0.355014 | \n",
+ " -0.120841 | \n",
+ " 0.206926 | \n",
+ " 0.714499 | \n",
+ " 0.226990 | \n",
+ " -0.120035 | \n",
+ " 0.792500 | \n",
+ " -0.037463 | \n",
+ " 0.974381 | \n",
+ " 0.789871 | \n",
+ "
\n",
+ " \n",
+ " 410771 | \n",
+ " 0.240478 | \n",
+ " 0.392697 | \n",
+ " 0.392697 | \n",
+ " -0.296252 | \n",
+ " -0.297209 | \n",
+ " -0.551021 | \n",
+ " -0.560144 | \n",
+ " -0.716150 | \n",
+ " -0.160491 | \n",
+ " -0.354582 | \n",
+ " -0.406361 | \n",
+ " 0.574534 | \n",
+ "
\n",
+ " \n",
+ " 410772 | \n",
+ " -1.936771 | \n",
+ " -0.688830 | \n",
+ " -0.688830 | \n",
+ " -0.226261 | \n",
+ " -1.192706 | \n",
+ " 0.306280 | \n",
+ " 0.100868 | \n",
+ " -0.454880 | \n",
+ " -0.211576 | \n",
+ " -0.487813 | \n",
+ " 0.170166 | \n",
+ " -1.663294 | \n",
+ "
\n",
+ " \n",
+ " 410773 | \n",
+ " -0.748366 | \n",
+ " -0.804077 | \n",
+ " -0.804077 | \n",
+ " 1.570163 | \n",
+ " 1.274445 | \n",
+ " -0.002521 | \n",
+ " 0.745507 | \n",
+ " 2.828890 | \n",
+ " -0.160491 | \n",
+ " 1.581436 | \n",
+ " -0.843150 | \n",
+ " -1.052343 | \n",
+ "
\n",
+ " \n",
+ " 410774 | \n",
+ " 1.257769 | \n",
+ " -1.101815 | \n",
+ " -1.101815 | \n",
+ " -0.002742 | \n",
+ " 1.338996 | \n",
+ " 0.635065 | \n",
+ " -0.040302 | \n",
+ " 0.202136 | \n",
+ " 1.217910 | \n",
+ " 0.311575 | \n",
+ " -0.415359 | \n",
+ " -0.246790 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
410775 rows × 12 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " num__geo_lat num__geo_lon afr__geo_lon afr__area*kitchen_area \\\n",
+ "0 0.495902 -0.449742 -0.449742 -0.132188 \n",
+ "1 0.177806 1.433673 1.433673 -0.169370 \n",
+ "2 0.440548 0.047222 0.047222 -0.226261 \n",
+ "3 -1.588818 -0.722477 -0.722477 -0.165302 \n",
+ "4 1.493662 1.125819 1.125819 0.094342 \n",
+ "... ... ... ... ... \n",
+ "410770 0.592011 0.355014 0.355014 -0.120841 \n",
+ "410771 0.240478 0.392697 0.392697 -0.296252 \n",
+ "410772 -1.936771 -0.688830 -0.688830 -0.226261 \n",
+ "410773 -0.748366 -0.804077 -0.804077 1.570163 \n",
+ "410774 1.257769 -1.101815 -1.101815 -0.002742 \n",
+ "\n",
+ " afr__sqrt(area)*geo_lat afr__sqrt(area)*log(level) \\\n",
+ "0 0.373151 0.688076 \n",
+ "1 0.005114 0.071369 \n",
+ "2 -0.425530 -0.335537 \n",
+ "3 -0.723225 0.034116 \n",
+ "4 1.522265 0.862773 \n",
+ "... ... ... \n",
+ "410770 0.206926 0.714499 \n",
+ "410771 -0.297209 -0.551021 \n",
+ "410772 -1.192706 0.306280 \n",
+ "410773 1.274445 -0.002521 \n",
+ "410774 1.338996 0.635065 \n",
+ "\n",
+ " afr__kitchen_area*log(level) afr__sqrt(area)*sqrt(kitchen_area) \\\n",
+ "0 0.044178 -0.153548 \n",
+ "1 -0.173647 -0.267268 \n",
+ "2 -0.239271 -0.454880 \n",
+ "3 -0.129771 -0.254514 \n",
+ "4 0.194490 0.438600 \n",
+ "... ... ... \n",
+ "410770 0.226990 -0.120035 \n",
+ "410771 -0.560144 -0.716150 \n",
+ "410772 0.100868 -0.454880 \n",
+ "410773 0.745507 2.828890 \n",
+ "410774 -0.040302 0.202136 \n",
+ "\n",
+ " afr__rooms*log(level) afr__kitchen_area*rooms \\\n",
+ "0 0.690329 -0.132529 \n",
+ "1 0.282625 -0.132529 \n",
+ "2 -0.512211 -0.487813 \n",
+ "3 0.282625 -0.088119 \n",
+ "4 0.574497 0.045112 \n",
+ "... ... ... \n",
+ "410770 0.792500 -0.037463 \n",
+ "410771 -0.160491 -0.354582 \n",
+ "410772 -0.211576 -0.487813 \n",
+ "410773 -0.160491 1.581436 \n",
+ "410774 1.217910 0.311575 \n",
+ "\n",
+ " afr__sqrt(geo_lon)*sqrt(level) afr__geo_lat*log(geo_lon) \n",
+ "0 0.323880 0.068167 \n",
+ "1 0.552794 1.129118 \n",
+ "2 -0.243092 0.460495 \n",
+ "3 -0.369355 -1.481033 \n",
+ "4 0.945981 1.923382 \n",
+ "... ... ... \n",
+ "410770 0.974381 0.789871 \n",
+ "410771 -0.406361 0.574534 \n",
+ "410772 0.170166 -1.663294 \n",
+ "410773 -0.843150 -1.052343 \n",
+ "410774 -0.415359 -0.246790 \n",
+ "\n",
+ "[410775 rows x 12 columns]"
+ ]
+ },
+ "execution_count": 297,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Re-attach the names of the columns kept by RFE to the selected training data.\n",
+ "X_train_afr_rfe = pd.DataFrame(\n",
+ "    data=X_train_rfe,\n",
+ "    columns=rfe_selector.get_feature_names_out(),\n",
+ ")\n",
+ "# NOTE(review): in the saved output both num__geo_lon and afr__geo_lon are selected and show identical values in the displayed rows - confirm the duplicate is intended.\n",
+ "X_train_afr_rfe"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# End-to-end pipeline: preprocessing -> RFE (12 features, 20% dropped per iteration) -> final model.\n",
+ "# The same `regressor` object is passed both to RFE and as the final 'model' step.\n",
+ "rfe_pipeline = Pipeline(\n",
+ "    steps=[\n",
+ "        ('preprocessor', preprocessor_afr),\n",
+ "        ('rfe_extractor', RFE(estimator=regressor, n_features_to_select=12, step=0.2)),\n",
+ "        ('model', regressor),\n",
+ "    ]\n",
+ ")\n",
+ "\n",
+ "rfe_pipeline.fit(X_train, y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 301,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'mae': 1431925.3203264712,\n",
+ " 'mape': 1.239752923791043e+18,\n",
+ " 'mse': 261947924998018.2}"
+ ]
+ },
+ "execution_count": 301,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "predictions_rfe = rfe_pipeline.predict(X_test)\n",
+ "\n",
+ "metrics = {}\n",
+ "metrics[\"mae\"] = mean_absolute_error(y_test, predictions_rfe) \n",
+ "metrics[\"mape\"] = mean_absolute_percentage_error(y_test, predictions_rfe)\n",
+ "metrics[\"mse\"] = mean_squared_error(y_test, predictions_rfe)\n",
+ "\n",
+ "metrics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 302,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading artifacts: 100%|██████████| 7/7 [00:00<00:00, 40.15it/s]\n",
+ "2024/10/17 14:26:50 INFO mlflow.tracking._tracking_service.client: 🏃 View run rfe_feature_selection at: http://127.0.0.1:5000/#/experiments/1/runs/96f0bbcd6d88466abcf38f3b53f06ff1.\n",
+ "2024/10/17 14:26:50 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.\n"
+ ]
+ }
+ ],
+ "source": [
+ "experiment_id = mlflow.get_experiment_by_name(EXPERIMENT_NAME).experiment_id\n",
+ "RUN_NAME = 'rfe_feature_selection'\n",
+ "\n",
+ "with mlflow.start_run(run_name=RUN_NAME, experiment_id=experiment_id) as run:\n",
+ " # получаем уникальный идентификатор запуска эксперимента\n",
+ " run_id = run.info.run_id \n",
+ " mlflow.sklearn.log_model(rfe_pipeline, \n",
+ " artifact_path=\"models\",\n",
+ " signature=signature,\n",
+ " input_example=input_example,\n",
+ " pip_requirements=req_file\n",
+ " )\n",
+ " mlflow.log_metrics(metrics)\n",
+ " mlflow.log_artifact(art)\n",
+ " mlflow.log_params(model_sklearn.get_params())\n",
+ "\n",
+ "run = mlflow.get_run(run_id) \n",
+ "assert (run.info.status =='FINISHED')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Используем sklearn признаки\n",
+ "Тут мы можем отобрать признаки один раз на обучении, а далее в качестве шага пайплайна использовать написанный класс ColumnExtractor для выбора нужных столбцов"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "rfe_skl_selector = RFE(estimator=regressor, n_features_to_select=12, step = 0.2) #drop 20% of features each iteration\n",
+ "X_train_skl_rfe = rfe_skl_selector.fit_transform(X_train_sklearn,y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 305,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " num__geo_lat | \n",
+ " num__geo_lon | \n",
+ " num__level | \n",
+ " num__rooms | \n",
+ " num__kitchen_area | \n",
+ " cat__region | \n",
+ " quantile__geo_lat | \n",
+ " quantile__geo_lon | \n",
+ " quantile__level | \n",
+ " poly__area kitchen_area | \n",
+ " spline__area_sp_0 | \n",
+ " spline__area_sp_2 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.495902 | \n",
+ " -0.449742 | \n",
+ " 0.359235 | \n",
+ " 0.253413 | \n",
+ " -0.186285 | \n",
+ " 20.0 | \n",
+ " 0.766257 | \n",
+ " 0.511028 | \n",
+ " 0.717217 | \n",
+ " -0.132188 | \n",
+ " 0.155806 | \n",
+ " 0.178013 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.177806 | \n",
+ " 1.433673 | \n",
+ " -0.246529 | \n",
+ " 0.253413 | \n",
+ " -0.186285 | \n",
+ " 70.0 | \n",
+ " 0.297142 | \n",
+ " 0.867999 | \n",
+ " 0.522022 | \n",
+ " -0.169370 | \n",
+ " 0.156921 | \n",
+ " 0.176803 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.440548 | \n",
+ " 0.047222 | \n",
+ " -0.448450 | \n",
+ " -0.669085 | \n",
+ " -0.142544 | \n",
+ " 15.0 | \n",
+ " 0.732330 | \n",
+ " 0.629984 | \n",
+ " 0.417417 | \n",
+ " -0.226261 | \n",
+ " 0.159080 | \n",
+ " 0.174488 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " -1.588818 | \n",
+ " -0.722477 | \n",
+ " -0.246529 | \n",
+ " 0.253413 | \n",
+ " -0.142544 | \n",
+ " 18.0 | \n",
+ " 0.148789 | \n",
+ " 0.295262 | \n",
+ " 0.522022 | \n",
+ " -0.165302 | \n",
+ " 0.157341 | \n",
+ " 0.176349 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1.493662 | \n",
+ " 1.125819 | \n",
+ " 0.157313 | \n",
+ " 0.253413 | \n",
+ " -0.011322 | \n",
+ " 10.0 | \n",
+ " 0.985937 | \n",
+ " 0.758363 | \n",
+ " 0.662663 | \n",
+ " 0.094342 | \n",
+ " 0.152390 | \n",
+ " 0.181792 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 410770 | \n",
+ " 0.592011 | \n",
+ " 0.355014 | \n",
+ " 0.561156 | \n",
+ " 0.253413 | \n",
+ " -0.092653 | \n",
+ " 54.0 | \n",
+ " 0.788393 | \n",
+ " 0.686728 | \n",
+ " 0.771271 | \n",
+ " -0.120841 | \n",
+ " 0.156705 | \n",
+ " 0.177037 | \n",
+ "
\n",
+ " \n",
+ " 410771 | \n",
+ " 0.240478 | \n",
+ " 0.392697 | \n",
+ " -0.650371 | \n",
+ " 0.253413 | \n",
+ " -0.404989 | \n",
+ " 45.0 | \n",
+ " 0.494062 | \n",
+ " 0.717240 | \n",
+ " 0.309810 | \n",
+ " -0.296252 | \n",
+ " 0.158306 | \n",
+ " 0.175314 | \n",
+ "
\n",
+ " \n",
+ " 410772 | \n",
+ " -1.936771 | \n",
+ " -0.688830 | \n",
+ " 0.359235 | \n",
+ " -0.669085 | \n",
+ " -0.142544 | \n",
+ " 18.0 | \n",
+ " 0.131352 | \n",
+ " 0.327613 | \n",
+ " 0.717217 | \n",
+ " -0.226261 | \n",
+ " 0.159080 | \n",
+ " 0.174488 | \n",
+ "
\n",
+ " \n",
+ " 410773 | \n",
+ " -0.748366 | \n",
+ " -0.804077 | \n",
+ " -0.650371 | \n",
+ " 0.253413 | \n",
+ " 1.501833 | \n",
+ " 52.0 | \n",
+ " 0.193143 | \n",
+ " 0.114753 | \n",
+ " 0.309810 | \n",
+ " 1.570163 | \n",
+ " 0.147820 | \n",
+ " 0.187011 | \n",
+ "
\n",
+ " \n",
+ " 410774 | \n",
+ " 1.257769 | \n",
+ " -1.101815 | \n",
+ " -0.044608 | \n",
+ " 1.175911 | \n",
+ " -0.142544 | \n",
+ " 14.0 | \n",
+ " 0.908036 | \n",
+ " 0.075725 | \n",
+ " 0.604605 | \n",
+ " -0.002742 | \n",
+ " 0.152767 | \n",
+ " 0.181370 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
410775 rows × 12 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " num__geo_lat num__geo_lon num__level num__rooms num__kitchen_area \\\n",
+ "0 0.495902 -0.449742 0.359235 0.253413 -0.186285 \n",
+ "1 0.177806 1.433673 -0.246529 0.253413 -0.186285 \n",
+ "2 0.440548 0.047222 -0.448450 -0.669085 -0.142544 \n",
+ "3 -1.588818 -0.722477 -0.246529 0.253413 -0.142544 \n",
+ "4 1.493662 1.125819 0.157313 0.253413 -0.011322 \n",
+ "... ... ... ... ... ... \n",
+ "410770 0.592011 0.355014 0.561156 0.253413 -0.092653 \n",
+ "410771 0.240478 0.392697 -0.650371 0.253413 -0.404989 \n",
+ "410772 -1.936771 -0.688830 0.359235 -0.669085 -0.142544 \n",
+ "410773 -0.748366 -0.804077 -0.650371 0.253413 1.501833 \n",
+ "410774 1.257769 -1.101815 -0.044608 1.175911 -0.142544 \n",
+ "\n",
+ " cat__region quantile__geo_lat quantile__geo_lon quantile__level \\\n",
+ "0 20.0 0.766257 0.511028 0.717217 \n",
+ "1 70.0 0.297142 0.867999 0.522022 \n",
+ "2 15.0 0.732330 0.629984 0.417417 \n",
+ "3 18.0 0.148789 0.295262 0.522022 \n",
+ "4 10.0 0.985937 0.758363 0.662663 \n",
+ "... ... ... ... ... \n",
+ "410770 54.0 0.788393 0.686728 0.771271 \n",
+ "410771 45.0 0.494062 0.717240 0.309810 \n",
+ "410772 18.0 0.131352 0.327613 0.717217 \n",
+ "410773 52.0 0.193143 0.114753 0.309810 \n",
+ "410774 14.0 0.908036 0.075725 0.604605 \n",
+ "\n",
+ " poly__area kitchen_area spline__area_sp_0 spline__area_sp_2 \n",
+ "0 -0.132188 0.155806 0.178013 \n",
+ "1 -0.169370 0.156921 0.176803 \n",
+ "2 -0.226261 0.159080 0.174488 \n",
+ "3 -0.165302 0.157341 0.176349 \n",
+ "4 0.094342 0.152390 0.181792 \n",
+ "... ... ... ... \n",
+ "410770 -0.120841 0.156705 0.177037 \n",
+ "410771 -0.296252 0.158306 0.175314 \n",
+ "410772 -0.226261 0.159080 0.174488 \n",
+ "410773 1.570163 0.147820 0.187011 \n",
+ "410774 -0.002742 0.152767 0.181370 \n",
+ "\n",
+ "[410775 rows x 12 columns]"
+ ]
+ },
+ "execution_count": 305,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "X_train_skl_rfe = pd.DataFrame(X_train_skl_rfe, columns=rfe_skl_selector.get_feature_names_out())\n",
+ "X_train_skl_rfe"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 306,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['num__geo_lat',\n",
+ " 'num__geo_lon',\n",
+ " 'num__level',\n",
+ " 'num__rooms',\n",
+ " 'num__kitchen_area',\n",
+ " 'cat__region',\n",
+ " 'quantile__geo_lat',\n",
+ " 'quantile__geo_lon',\n",
+ " 'quantile__level',\n",
+ " 'poly__area kitchen_area',\n",
+ " 'spline__area_sp_0',\n",
+ " 'spline__area_sp_2']"
+ ]
+ },
+ "execution_count": 306,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "rfe_cols = X_train_skl_rfe.columns.tolist()\n",
+ "rfe_cols"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 307,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([ True, True, True, False, True, False, True, True, False,\n",
+ " False, True, True, True, False, False, False, False, False,\n",
+ " False, False, False, True, False, True, False, True, False,\n",
+ " False])"
+ ]
+ },
+ "execution_count": 307,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "rfe_idx = rfe_skl_selector.support_\n",
+ "rfe_idx"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 316,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The selected columns must be logged, otherwise we lose track of which features were chosen.\n",
+ "# Dump plain Python lists: str() of a large numpy array is elided with '...'.\n",
+ "with open('rfe_skl_idx.txt', 'w') as f:\n",
+ "    f.write(str(rfe_idx.tolist()))\n",
+ "with open('rfe_skl_cols.txt', 'w') as f:\n",
+ "    f.write(str(rfe_cols))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 309,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.base import BaseEstimator, TransformerMixin\n",
+ "\n",
+ "\n",
+ "class ColumnExtractor(BaseEstimator, TransformerMixin):\n",
+ "    \"\"\"Pipeline step that keeps only a pre-selected subset of columns.\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    cols : boolean mask or sequence of int\n",
+ "        Positional column selector, applied as ``X[:, cols]``.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, cols):\n",
+ "        # Store the argument untouched so get_params()/clone() can round-trip it.\n",
+ "        self.cols = cols\n",
+ "\n",
+ "    def fit(self, X, y=None):\n",
+ "        \"\"\"Nothing is learned: the selection is fixed at construction time.\"\"\"\n",
+ "        return self\n",
+ "\n",
+ "    def transform(self, X):\n",
+ "        \"\"\"Return the selected columns of ``X``.\"\"\"\n",
+ "        if hasattr(X, 'iloc'):\n",
+ "            # pandas input: positional selection, same semantics as the array branch\n",
+ "            return X.iloc[:, self.cols]\n",
+ "        return X[:, self.cols]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rfe_skl_pipeline = Pipeline(steps=[\n",
+ " ('preprocessor', preprocessor_sklearn), \n",
+ " ('rfe_extractor', ColumnExtractor(rfe_idx)),\n",
+ " ('model', regressor)\n",
+ "])\n",
+ "\n",
+ "rfe_skl_pipeline.fit(X_train, y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 311,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading artifacts: 100%|██████████| 7/7 [00:00<00:00, 193.34it/s]\n",
+ "2024/10/17 14:32:07 INFO mlflow.tracking._tracking_service.client: 🏃 View run rfe_skl_feature_selection at: http://127.0.0.1:5000/#/experiments/1/runs/e55206caeb1549e4aa0d98343d5c1d4d.\n",
+ "2024/10/17 14:32:07 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1.\n"
+ ]
+ }
+ ],
+ "source": [
+ "predictions_rfe_skl = rfe_skl_pipeline.predict(X_test)\n",
+ "\n",
+ "# Test-set metrics for the pipeline built on the sklearn-RFE column subset.\n",
+ "metrics = {\n",
+ "    \"mae\": mean_absolute_error(y_test, predictions_rfe_skl),\n",
+ "    \"mape\": mean_absolute_percentage_error(y_test, predictions_rfe_skl),\n",
+ "    \"mse\": mean_squared_error(y_test, predictions_rfe_skl),\n",
+ "}\n",
+ "\n",
+ "experiment_id = mlflow.get_experiment_by_name(EXPERIMENT_NAME).experiment_id\n",
+ "RUN_NAME = 'rfe_skl_feature_selection'\n",
+ "\n",
+ "with mlflow.start_run(run_name=RUN_NAME, experiment_id=experiment_id) as run:\n",
+ "    # keep the run id so the run status can be checked after the context exits\n",
+ "    run_id = run.info.run_id\n",
+ "    # log the pipeline evaluated above (rfe_skl_pipeline), not the earlier rfe_pipeline\n",
+ "    mlflow.sklearn.log_model(rfe_skl_pipeline,\n",
+ "                             artifact_path=\"models\",\n",
+ "                             signature=signature,\n",
+ "                             input_example=input_example,\n",
+ "                             pip_requirements=req_file\n",
+ "                             )\n",
+ "    mlflow.log_metrics(metrics)\n",
+ "    mlflow.log_artifact('rfe_skl_cols.txt')\n",
+ "    mlflow.log_artifact('rfe_skl_idx.txt')\n",
+ "    mlflow.log_params(model_sklearn.get_params())\n",
+ "\n",
+ "run = mlflow.get_run(run_id)\n",
+ "assert run.info.status == 'FINISHED'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## mlxtend\n",
+ "https://github.com/rasbt/mlxtend/blob/master/docs/sources/user_guide/feature_selection/SequentialFeatureSelector.ipynb "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 312,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mlxtend.feature_selection import SequentialFeatureSelector \n",
+ "#from sklearn.feature_selection import SequentialFeatureSelector"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sfs = SequentialFeatureSelector(RandomForestRegressor(n_estimators=3), \n",
+ " k_features=3,\n",
+ " forward=True,\n",
+ " floating=False, # True to drop selected features\n",
+ " scoring='neg_mean_absolute_error',\n",
+ " cv=2)\n",
+ "\n",
+ "sfs.fit(X_train_sklearn,y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 314,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " num__geo_lon | \n",
+ " quantile__geo_lat | \n",
+ " spline__area_sp_3 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " -0.449742 | \n",
+ " 0.766257 | \n",
+ " 1.826008e-06 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1.433673 | \n",
+ " 0.297142 | \n",
+ " 1.310449e-06 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.047222 | \n",
+ " 0.732330 | \n",
+ " 6.098363e-07 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " -0.722477 | \n",
+ " 0.148789 | \n",
+ " 1.144942e-06 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1.125819 | \n",
+ " 0.985937 | \n",
+ " 4.240047e-06 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 410770 | \n",
+ " 0.355014 | \n",
+ " 0.788393 | \n",
+ " 1.401454e-06 | \n",
+ "
\n",
+ " \n",
+ " 410771 | \n",
+ " 0.392697 | \n",
+ " 0.494062 | \n",
+ " 8.202272e-07 | \n",
+ "
\n",
+ " \n",
+ " 410772 | \n",
+ " -0.688830 | \n",
+ " 0.131352 | \n",
+ " 6.098363e-07 | \n",
+ "
\n",
+ " \n",
+ " 410773 | \n",
+ " -0.804077 | \n",
+ " 0.193143 | \n",
+ " 1.004843e-05 | \n",
+ "
\n",
+ " \n",
+ " 410774 | \n",
+ " -1.101815 | \n",
+ " 0.908036 | \n",
+ " 3.903343e-06 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
410775 rows × 3 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " num__geo_lon quantile__geo_lat spline__area_sp_3\n",
+ "0 -0.449742 0.766257 1.826008e-06\n",
+ "1 1.433673 0.297142 1.310449e-06\n",
+ "2 0.047222 0.732330 6.098363e-07\n",
+ "3 -0.722477 0.148789 1.144942e-06\n",
+ "4 1.125819 0.985937 4.240047e-06\n",
+ "... ... ... ...\n",
+ "410770 0.355014 0.788393 1.401454e-06\n",
+ "410771 0.392697 0.494062 8.202272e-07\n",
+ "410772 -0.688830 0.131352 6.098363e-07\n",
+ "410773 -0.804077 0.193143 1.004843e-05\n",
+ "410774 -1.101815 0.908036 3.903343e-06\n",
+ "\n",
+ "[410775 rows x 3 columns]"
+ ]
+ },
+ "execution_count": 314,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "selected_features_sfs = X_train_sklearn.loc[:, sfs.k_feature_names_]\n",
+ "selected_features_sfs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 315,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['num__geo_lon', 'quantile__geo_lat', 'spline__area_sp_3']"
+ ]
+ },
+ "execution_count": 315,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "rfe_sfs_idx = list(sfs.k_feature_idx_)\n",
+ "rfe_sfs_idx\n",
+ "rfe_sfs_col = list(sfs.k_feature_names_)\n",
+ "rfe_sfs_col"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 317,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "from mlxtend.plotting import plot_sequential_feature_selection as plot_sfs\n",
+ "\n",
+ "fig = plot_sfs(sfs.get_metric_dict(), kind='std_dev')\n",
+ "\n",
+ "plt.title('Sequential Forward Selection (w. StdDev)')\n",
+ "plt.grid()\n",
+ "plt.show()\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rfe_sfs_pipeline = Pipeline(steps=[\n",
+ " ('preprocessor', preprocessor_sklearn), \n",
+ " ('rfe_extractor', ColumnExtractor(rfe_sfs_idx)),\n",
+ " ('model', regressor)\n",
+ "])\n",
+ "\n",
+ "rfe_sfs_pipeline.fit(X_train, y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "predictions_sfs = rfe_sfs_pipeline.predict(X_test)\n",
+ "\n",
+ "# Test-set metrics for the pipeline built on the SFS column subset.\n",
+ "metrics = {\n",
+ "    \"mae\": mean_absolute_error(y_test, predictions_sfs),\n",
+ "    \"mape\": mean_absolute_percentage_error(y_test, predictions_sfs),\n",
+ "    \"mse\": mean_squared_error(y_test, predictions_sfs),\n",
+ "}\n",
+ "\n",
+ "# Persist the SFS selection itself: this run must not reuse the RFE artifact files.\n",
+ "with open('rfe_sfs_idx.txt', 'w') as f:\n",
+ "    f.write(str(rfe_sfs_idx))\n",
+ "with open('rfe_sfs_cols.txt', 'w') as f:\n",
+ "    f.write(str(rfe_sfs_col))\n",
+ "\n",
+ "experiment_id = mlflow.get_experiment_by_name(EXPERIMENT_NAME).experiment_id\n",
+ "RUN_NAME = 'rfe_sfs_feature_selection'\n",
+ "\n",
+ "with mlflow.start_run(run_name=RUN_NAME, experiment_id=experiment_id) as run:\n",
+ "    # keep the run id so the run status can be checked after the context exits\n",
+ "    run_id = run.info.run_id\n",
+ "    mlflow.sklearn.log_model(rfe_sfs_pipeline,\n",
+ "                             artifact_path=\"models\",\n",
+ "                             signature=signature,\n",
+ "                             input_example=input_example,\n",
+ "                             pip_requirements=req_file\n",
+ "                             )\n",
+ "    mlflow.log_metrics(metrics)\n",
+ "    mlflow.log_artifact('rfe_sfs_cols.txt')\n",
+ "    mlflow.log_artifact('rfe_sfs_idx.txt')\n",
+ "    mlflow.log_params(model_sklearn.get_params())\n",
+ "\n",
+ "run = mlflow.get_run(run_id)\n",
+ "assert run.info.status == 'FINISHED'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "Можно совмещать признаки, выбранные по sfs и sbs: брать их объединение или пересечение. Можно комбинировать с признаками, выделенными разными подходами - целое поле для исследований"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# HYPERPARAMS\n",
+ "## Gridsearch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 224,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import GridSearchCV"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "param_grid = {\n",
+ " 'model__depth': [1,3,5]\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gs = GridSearchCV(rfe_sfs_pipeline, param_grid, cv=2, scoring='neg_mean_absolute_error')\n",
+ "gs.fit(X_train, y_train)\n",
+ "print(\"Лучшие гиперпараметры:\", gs.best_params_)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Rebuild the pipeline with the depth chosen by the grid search instead of a hard-coded value.\n",
+ "gs_pipeline = Pipeline(steps=[\n",
+ "    ('preprocessor', preprocessor_sklearn),\n",
+ "    ('rfe_extractor', ColumnExtractor(rfe_sfs_idx)),\n",
+ "    ('model', CatBoostRegressor(depth=gs.best_params_['model__depth']))\n",
+ "])\n",
+ "\n",
+ "# Next: run the standard check on the test set and log the run"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Вместо GridSearch можно использовать RandomSearch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Optuna"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 292,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import optuna"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def objective(trial):\n",
+ "    \"\"\"Optuna objective: validation MAE of a CatBoost pipeline for the suggested hyperparameters.\"\"\"\n",
+ "    from sklearn.model_selection import train_test_split\n",
+ "\n",
+ "    # suggest hyperparameters\n",
+ "    depth = trial.suggest_int('depth', 1, 10)\n",
+ "    learning_rate = trial.suggest_float('learning_rate', 0.001, 0.1)\n",
+ "\n",
+ "    # Tune on a validation split carved out of the training data so the test set\n",
+ "    # stays untouched for the final evaluation. The fixed random_state gives every\n",
+ "    # trial the same split, which keeps trial scores comparable.\n",
+ "    X_fit, X_val, y_fit, y_val = train_test_split(\n",
+ "        X_train, y_train, test_size=0.2, random_state=42\n",
+ "    )\n",
+ "\n",
+ "    # build and fit the model\n",
+ "    opt_pipeline = Pipeline(steps=[\n",
+ "        ('preprocessor', preprocessor_sklearn),\n",
+ "        ('rfe_extractor', ColumnExtractor(rfe_sfs_idx)),\n",
+ "        ('model', CatBoostRegressor(depth=depth, learning_rate=learning_rate, verbose=0))\n",
+ "    ])\n",
+ "    opt_pipeline.fit(X_fit, y_fit)\n",
+ "\n",
+ "    # predict and compute MAE on the validation split\n",
+ "    preds = opt_pipeline.predict(X_val)\n",
+ "    return mean_absolute_error(y_val, preds)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "study = optuna.create_study(direction='minimize')\n",
+ "study.optimize(objective, n_trials=10)\n",
+ "\n",
+ "# выводим результаты\n",
+ "print('Number of finished trials:', len(study.trials))\n",
+ "print('Best trial:', study.best_trial.params) "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Rebuild the pipeline with the best trial's hyperparameters ('depth', 'learning_rate')\n",
+ "# instead of values copied by hand from an earlier study.\n",
+ "opt_pipeline = Pipeline(steps=[\n",
+ "    ('preprocessor', preprocessor_sklearn),\n",
+ "    ('rfe_extractor', ColumnExtractor(rfe_sfs_idx)),\n",
+ "    ('model', CatBoostRegressor(**study.best_trial.params))\n",
+ "])\n",
+ "\n",
+ "# Next: run the standard check on the test set and log the run"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Выбираем лучшую модель.\n",
+ "Обучаем ее на всей выборке (а не только на train-части). \n",
+ "Далее будем деплоить именно её"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,