рефакторинг блокнота research в части логирования в mlflow

lab_2/master
syropiatovvv 3 дней назад
Родитель 070688dc68
Сommit 6038a1c566
Подписано: syropiatovvv
Идентификатор GPG ключа: 297380B8143A31BD

@ -0,0 +1,60 @@
from collections.abc import Container, Sequence, Mapping
from typing import TypeAlias, TypeVar
ParamsFilterSpec: TypeAlias = (
bool
| Container[str]
| tuple[bool, Container[str]]
| Mapping[str, 'ParamsFilterSpec']
| tuple[bool, 'ParamsFilterSpec']
)
V = TypeVar('V')
def _split_param_key(key: str) -> tuple[str, ...]:
return tuple(key.split('__'))
def _match_key_to_filter_spec(
key: Sequence[str], spec: ParamsFilterSpec, empty_default: bool,
) -> bool:
if isinstance(spec, Sequence) and (len(spec) == 2) and isinstance(spec[0], bool):
if (len(key) == 0) and (not spec[0]):
return empty_default
spec = spec[1]
if isinstance(spec, Mapping):
if len(key) == 0:
return empty_default
spec_nested = spec.get(key[0])
if spec_nested is None:
return False
return _whether_to_include_param(key[1:], spec_nested)
elif isinstance(spec, Container):
if len(key) == 0:
return True
return (key[0] in spec)
return bool(spec)
def _whether_to_include_param(
key: Sequence[str], include: ParamsFilterSpec = True, exclude: ParamsFilterSpec = False,
) -> bool:
return (
(not _match_key_to_filter_spec(key, exclude, empty_default=False))
and _match_key_to_filter_spec(key, include, empty_default=True)
)
def filter_params(
params: Mapping[str, V],
include: ParamsFilterSpec = True,
exclude: ParamsFilterSpec = False,
) -> Mapping[str, V]:
return {
k: v
for k, v in params.items()
if _whether_to_include_param(_split_param_key(k), include, exclude)
}

@ -0,0 +1,3 @@
COLUMN_TRANSFORMER_PARAMS_COMMON_INCLUDE = [
'remainder', 'sparse_threshold', 'transformer_weights',
]

@ -0,0 +1 @@
RANDOM_FOREST_REGRESSOR_PARAMS_COMMON_EXCLUDE = ['n_jobs', 'verbose', 'warm_start']

@ -0,0 +1 @@
PIPELINE_PARAMS_COMMON_INCLUDE = ['transform_input']

@ -0,0 +1 @@
STANDARD_SCALER_PARAMS_COMMON_EXCLUDE = ['copy']

@ -51,6 +51,7 @@ mlflow_run_name: str = 'Baseline model'
import os
import pathlib
import pickle
import sys
# %%
import mlflow
@ -66,6 +67,17 @@ import sklearn.preprocessing
# %%
BASE_PATH = pathlib.Path('..')
# %%
CODE_PATH = BASE_PATH
sys.path.insert(0, str(CODE_PATH.resolve()))
# %%
from iis_project.sklearn_utils import filter_params
from iis_project.sklearn_utils.compose import COLUMN_TRANSFORMER_PARAMS_COMMON_INCLUDE
from iis_project.sklearn_utils.ensemble import RANDOM_FOREST_REGRESSOR_PARAMS_COMMON_EXCLUDE
from iis_project.sklearn_utils.pipeline import PIPELINE_PARAMS_COMMON_INCLUDE
from iis_project.sklearn_utils.preprocessing import STANDARD_SCALER_PARAMS_COMMON_EXCLUDE
# %%
MODEL_INOUT_EXAMPLE_SIZE = 0x10
@ -196,6 +208,22 @@ tuple(map(len, (df_target_train, df_target_test)))
mlflow_model_signature = mlflow.models.infer_signature(model_input=df_orig_features, model_output=df_target)
mlflow_model_signature
# %% [raw] vscode={"languageId": "raw"}
# input_schema = mlflow.types.schema.Schema([
# mlflow.types.schema.ColSpec("double", "selling_price"),
# mlflow.types.schema.ColSpec("double", "driven_kms"),
# mlflow.types.schema.ColSpec("string", "fuel_type"),
# mlflow.types.schema.ColSpec("string", "selling_type"),
# mlflow.types.schema.ColSpec("string", "transmission"),
# mlflow.types.schema.ColSpec("double", "age"),
# ])
#
# output_schema = mlflow.types.schema.Schema([
# mlflow.types.schema.ColSpec("double", "present_price"),
# ])
#
# mlflow_model_signature = mlflow.models.ModelSignature(inputs=input_schema, outputs=output_schema)
# %% [markdown]
# Пайплайн предобработки признаков:
@ -235,7 +263,25 @@ pipeline = sklearn.pipeline.Pipeline([
pipeline
# %%
model_params = pipeline.get_params()
model_params = filter_params(
pipeline.get_params(),
include={
**{k: True for k in PIPELINE_PARAMS_COMMON_INCLUDE},
'preprocess': (
False,
{
**{k: True for k in COLUMN_TRANSFORMER_PARAMS_COMMON_INCLUDE},
'scale_to_standard': True,
'encode_categorical_wrt_target': True,
},
),
'regress': (False, True),
},
exclude={
'preprocess': {'scale_to_standard': STANDARD_SCALER_PARAMS_COMMON_EXCLUDE},
'regress': RANDOM_FOREST_REGRESSOR_PARAMS_COMMON_EXCLUDE,
},
)
model_params
# %% [markdown]

Загрузка…
Отмена
Сохранить