Вы не можете выбрать более 25 тем
Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.
112 строки
3.3 KiB
Python
112 строки
3.3 KiB
Python
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from itertools import chain
|
|
from pandas import DataFrame
|
|
from pickle import load
|
|
from prometheus_client import Counter, Histogram
|
|
|
|
|
|
def open_model_file(file, *, buffering=-1, opener=None, **kwargs_extra):
    """Open a serialized model file for binary reading.

    Args:
        file: Path or file descriptor, passed to the built-in ``open``.
        buffering: Buffering policy, passed to ``open``.
        opener: Optional custom opener, passed to ``open``.
        **kwargs_extra: Only ``closefd`` is accepted; it is forwarded
            to ``open``.

    Returns:
        The file object returned by ``open(file, 'rb', ...)``.

    Raises:
        TypeError: If any keyword argument other than ``closefd`` is given.
    """
    open_kwargs_extra = {}
    if 'closefd' in kwargs_extra:
        open_kwargs_extra['closefd'] = kwargs_extra.pop('closefd')
    if kwargs_extra:
        raise TypeError(
            'Unexpected keyword arguments given: {}'
            .format(', '.join(map(repr, kwargs_extra)))
        )
    # The collected options must actually reach open(); without this a
    # caller-supplied closefd=False would be ignored and the caller's
    # descriptor would be closed together with the returned file object.
    return open(
        file, 'rb', buffering=buffering, opener=opener, **open_kwargs_extra,
    )
|
|
|
|
|
|
def load_model_from_file(file):
    """Unpickle and return the object stored in an open binary file.

    Args:
        file: A binary file object positioned at the start of a pickle.

    Returns:
        Whatever object ``pickle.load`` reconstructs from ``file``.

    Note:
        Unpickling can execute arbitrary code, so this must only be used
        on model files that come from a trusted source.
    """
    model = load(file)
    return model
|
|
|
|
|
|
def load_model_from_path(path, *, buffering=-1, opener=None, **kwargs_extra):
    """Open the model file at ``path`` and return the unpickled object.

    Args:
        path: Location of the model file, handed to ``open_model_file``.
        buffering: Buffering policy, handed to ``open_model_file``.
        opener: Optional custom opener, handed to ``open_model_file``.
        **kwargs_extra: Only ``closefd`` is accepted; it is handed to
            ``open_model_file`` as well.

    Returns:
        The object produced by ``load_model_from_file``.

    Raises:
        TypeError: If any keyword argument other than ``closefd`` is given.
    """
    forwarded_keys = ('closefd',)
    # Move the recognised options out of kwargs_extra; anything that is
    # left afterwards is an unsupported argument.
    open_kwargs_extra = {
        key: kwargs_extra.pop(key)
        for key in forwarded_keys
        if key in kwargs_extra
    }
    if kwargs_extra:
        unexpected = ', '.join(kwargs_extra)
        raise TypeError(
            'Unexpected keyword arguments given: {}'.format(unexpected)
        )
    model_file = open_model_file(
        path, buffering=buffering, opener=opener, **open_kwargs_extra,
    )
    with model_file:
        return load_model_from_file(model_file)
|
|
|
|
|
|
class FuelType(Enum):
    """Fuel kinds a car can be described with.

    The string value of a member is what ``PricePredictor.predict`` writes
    into the ``fuel_type`` column of the model input.
    """

    PETROL = 'petrol'
    DIESEL = 'diesel'
    CNG = 'cng'
|
|
|
|
|
|
class SellingType(Enum):
    """Kinds of seller a car can be offered by.

    The string value of a member is what ``PricePredictor.predict`` writes
    into the ``selling_type`` column of the model input.
    """

    DEALER = 'dealer'
    INDIVIDUAL = 'individual'
|
|
|
|
|
|
class TransmissionType(Enum):
    """Gearbox kinds a car can be described with.

    The string value of a member is what ``PricePredictor.predict`` writes
    into the ``transmission`` column of the model input.
    """

    MANUAL = 'manual'
    AUTOMATIC = 'automatic'
|
|
|
|
|
|
@dataclass
class PricePredictionFeatures:
    """Input record consumed by ``PricePredictor.predict``.

    Each field becomes one column of the single-row frame passed to the
    model; the enum fields are passed as their string values.
    """

    # Numeric features. NOTE(review): units (currency, km, years) are not
    # visible in this file -- confirm against the training data.
    selling_price: float
    driven_kms: float
    age: float

    # Categorical features.
    fuel_type: FuelType
    selling_type: SellingType
    transmission_type: TransmissionType
|
|
|
|
|
|
# Histogram of how long a model prediction call takes
# (help text: "time spent computing in the model").
# Bucket bounds follow a 1-2-5 progression per decade, from 1e-4 up to 50,
# followed by an explicit +inf bucket.
metric_prediction_latency = Histogram(
    'model_prediction_seconds', 'Время вычислений в модели',
    buckets=(
        [
            mantissa * (10 ** exponent)
            for exponent in range(-4, (1 + 1))
            for mantissa in (1, 2, 5)
        ]
        + [float('+inf')]
    ),
)
|
|
|
|
# Counter of failed model predictions (help text: "model computation errors
# by type"); the 'error_type' label carries the exception class name.
metric_prediction_errors = Counter(
    'model_prediction_errors_total',
    'Ошибки вычислений в модели по типу',
    labelnames=('error_type',),
)
|
|
|
|
# Histogram of the values returned by the model
# (help text: "predicted price value").
# Bucket bounds follow a 1-2-5 progression per decade, from 0.1 up to 500,
# followed by an explicit +inf bucket.
metric_prediction_value = Histogram(
    'model_prediction_value', 'Предсказанное значение цены',
    buckets=(
        [
            mantissa * (10 ** exponent)
            for exponent in range(-1, (2 + 1))
            for mantissa in (1, 2, 5)
        ]
        + [float('+inf')]
    ),
)
|
|
|
|
|
|
class PricePredictor:
    """Wraps a pickled model and exposes single-record price prediction.

    Every prediction is timed, failures of the model call are counted by
    exception type, and the predicted value is recorded in a histogram.
    """

    def __init__(self, model_path):
        """Load the model stored at ``model_path``."""
        self._model = load_model_from_path(model_path)

    def predict(self, features):
        """Return the model's prediction for one feature record as a float.

        Args:
            features: Object exposing ``selling_price``, ``driven_kms``,
                ``age``, ``fuel_type``, ``selling_type`` and
                ``transmission_type`` (the enum fields are read via
                ``.value``), e.g. a ``PricePredictionFeatures``.

        Returns:
            The single predicted value converted to ``float``.

        Raises:
            Exception: Whatever the underlying model raises; the error is
                counted in ``metric_prediction_errors`` and re-raised.
            AssertionError: If the model does not return exactly one
                prediction.
        """
        # WARN: the column order appears to matter
        row = {
            'selling_price': features.selling_price,
            'driven_kms': features.driven_kms,
            'fuel_type': features.fuel_type.value,
            'selling_type': features.selling_type.value,
            'transmission': features.transmission_type.value,
            'age': features.age,
        }
        features_df = DataFrame([row])

        try:
            # Only the model call itself is timed and error-counted.
            with metric_prediction_latency.time():
                predictions = self._model.predict(features_df)
        except Exception as err:
            error_type = type(err).__name__
            metric_prediction_errors.labels(error_type=error_type).inc()
            raise

        # One input row must yield exactly one prediction.
        assert len(predictions) == 1
        value = float(predictions[0])
        metric_prediction_value.observe(value)
        return value
|