Родитель
7a3ba02966
Сommit
7bb2455d4c
@ -0,0 +1,3 @@
|
||||
### Python
|
||||
__pycache__/
|
||||
*.pyc
|
||||
@ -0,0 +1,20 @@
|
||||
# Image for the ML service: a Python application served by uvicorn on port 8000.
FROM python:3.11-slim

WORKDIR /service

# Install the dependencies before copying the sources, so this layer is reused
# from the build cache whenever only application code has changed.
COPY requirements.txt /service/requirements.txt
RUN pip install --no-cache-dir -r /service/requirements.txt

COPY . /service

# Mount point for the directory containing the serialized model.
VOLUME /models

EXPOSE 8000/tcp

ENV MODELS_PATH=/models

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]

# docker build -t ml_service services/ml_service/
# docker run -v "$(pwd)/services/models:/models" -p 8000:8000 ml_service
|
||||
@ -0,0 +1,2 @@
|
||||
from ._meta import PACKAGE_PATH
|
||||
from .main import app
|
||||
@ -0,0 +1,4 @@
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Directory that contains this module, i.e. the package directory itself.
PACKAGE_PATH = Path(__file__).parent
|
||||
@ -0,0 +1,64 @@
|
||||
from os import getenv
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ._meta import PACKAGE_PATH
|
||||
from .predictor import (
|
||||
FuelType, SellingType, TransmissionType, PricePredictionFeatures, PricePredictor,
|
||||
)
|
||||
|
||||
|
||||
# Resolve the directory that holds the serialized model.  An explicit
# MODELS_PATH environment variable wins; without it the directory is derived
# from the package location, whose grandparent directory is expected to be
# named "services".
MODELS_PATH = getenv('MODELS_PATH')
if MODELS_PATH is None:
    SERVICES_PATH = PACKAGE_PATH.parents[1]
    assert SERVICES_PATH.name == 'services'
    MODELS_PATH = SERVICES_PATH / 'models'
else:
    MODELS_PATH = Path(MODELS_PATH)

# The model file itself always has this fixed name inside MODELS_PATH.
MODEL_PATH = MODELS_PATH.joinpath('model.pkl')
|
||||
|
||||
|
||||
# The model is loaded once, at import time; the /predict handler below uses
# this module-level instance.
predictor = PricePredictor(MODEL_PATH)

API_BASE_PATH = '/api'

# ReDoc is left enabled; pass redoc_url=None here to switch it off.
app = FastAPI(title='Сервис ML', version='0.1.0', root_path=API_BASE_PATH)
|
||||
|
||||
|
||||
@app.get('/', summary='Тестовый эндпоинт')
async def root():
    # Test endpoint: always answers with the same fixed payload.
    # (Documented with comments rather than a docstring, because FastAPI
    # would publish a docstring as the operation description.)
    payload = {'Hello': 'World'}
    return payload
|
||||
|
||||
|
||||
class PricePredictionRequest(BaseModel):
    # Request body of the price prediction endpoint.  Documented with comments
    # rather than a docstring, because pydantic would copy a class docstring
    # into the generated schema description.

    # Numeric inputs: the price must be strictly positive; the mileage and the
    # age may be zero but not negative.
    selling_price: float = Field(..., gt=0)
    driven_kms: float = Field(..., ge=0)
    age: float = Field(..., ge=0)

    # Categorical inputs, restricted to the values of the respective enums.
    fuel_type: FuelType
    selling_type: SellingType
    transmission_type: TransmissionType
|
||||
|
||||
|
||||
@app.post('/predict', summary='Предсказать цену подержанного автомобиля')
def predict_price(item_id: int, req: PricePredictionRequest):
    # Copy the validated request fields into the predictor's feature record,
    # run the model, and echo item_id back next to the predicted price.
    # (Comments instead of a docstring: FastAPI would expose a docstring as
    # the operation description.)
    model_input = PricePredictionFeatures(
        selling_price=req.selling_price,
        driven_kms=req.driven_kms,
        age=req.age,
        fuel_type=req.fuel_type,
        selling_type=req.selling_type,
        transmission_type=req.transmission_type,
    )
    predicted_price = predictor.predict(model_input)
    return {'item_id': item_id, 'price': predicted_price}
|
||||
@ -0,0 +1,81 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pandas import DataFrame
|
||||
from pickle import load
|
||||
|
||||
|
||||
def open_model_file(file, *, buffering=-1, opener=None, **kwargs_extra):
    """Open a serialized model for reading in binary mode.

    Args:
        file: Path or file descriptor, as accepted by the builtin ``open``.
        buffering: Forwarded to ``open``.
        opener: Forwarded to ``open``.
        **kwargs_extra: Only ``closefd`` is accepted; it is forwarded to
            ``open`` as well.

    Returns:
        The binary file object returned by ``open``.

    Raises:
        TypeError: If any keyword argument other than ``closefd`` is given.
    """
    open_kwargs_extra = {}
    if 'closefd' in kwargs_extra:
        open_kwargs_extra['closefd'] = kwargs_extra.pop('closefd')
    if kwargs_extra:
        raise TypeError(
            'Unexpected keyword arguments given: {}'
            .format(', '.join(map(repr, kwargs_extra)))
        )
    # The collected extras must actually reach open(); otherwise an explicit
    # closefd=False is silently ignored and the caller's descriptor is closed.
    return open(
        file, 'rb', buffering=buffering, opener=opener, **open_kwargs_extra,
    )
|
||||
|
||||
|
||||
def load_model_from_file(file):
    """Unpickle and return the object stored in the binary file *file*.

    NOTE: unpickling can execute arbitrary code, so *file* must come from a
    trusted source.
    """
    model = load(file)
    return model
|
||||
|
||||
|
||||
def load_model_from_path(path, *, buffering=-1, opener=None, **kwargs_extra):
    """Open the model file at *path* and return the unpickled object.

    NOTE(security): the file is unpickled, which can execute arbitrary code;
    only load model files from a trusted source.

    Args:
        path: Location of the model file, passed to ``open_model_file``.
        buffering: Forwarded to ``open_model_file``.
        opener: Forwarded to ``open_model_file``.
        **kwargs_extra: Only ``closefd`` is accepted; it is forwarded to
            ``open_model_file``.

    Returns:
        Whatever ``load_model_from_file`` returns for the opened file.

    Raises:
        TypeError: If any keyword argument other than ``closefd`` is given.
    """
    open_kwargs_extra = {}
    for k in ('closefd',):
        if k in kwargs_extra:
            open_kwargs_extra[k] = kwargs_extra.pop(k)
    if kwargs_extra:
        # Names are repr()-quoted, matching the message of open_model_file.
        raise TypeError(
            'Unexpected keyword arguments given: {}'
            .format(', '.join(map(repr, kwargs_extra)))
        )
    with open_model_file(
        path, buffering=buffering, opener=opener, **open_kwargs_extra,
    ) as model_file:
        return load_model_from_file(model_file)
|
||||
|
||||
|
||||
class FuelType(Enum):
    """Fuel categories; member values are the lowercase string labels."""

    PETROL = 'petrol'
    DIESEL = 'diesel'
    CNG = 'cng'
|
||||
|
||||
|
||||
class SellingType(Enum):
    """Seller categories; member values are the lowercase string labels."""

    DEALER = 'dealer'
    INDIVIDUAL = 'individual'
|
||||
|
||||
|
||||
class TransmissionType(Enum):
    """Transmission categories; member values are the lowercase string labels."""

    MANUAL = 'manual'
    AUTOMATIC = 'automatic'
|
||||
|
||||
|
||||
@dataclass
class PricePredictionFeatures:
    """Plain record of the inputs consumed by ``PricePredictor.predict``."""

    # Numeric inputs.
    selling_price: float
    driven_kms: float
    age: float

    # Categorical inputs; the predictor reads each member's ``.value``.
    fuel_type: FuelType
    selling_type: SellingType
    transmission_type: TransmissionType
|
||||
|
||||
|
||||
class PricePredictor:
    """Wrap a pickled model and expose single-record price prediction."""

    def __init__(self, model_path):
        """Load the model stored at *model_path* and keep it on the instance."""
        self._model = load_model_from_path(model_path)

    def predict(self, features):
        """Return the model's prediction for one feature record as a float.

        *features* must provide ``selling_price``, ``driven_kms`` and ``age``
        plus ``fuel_type``, ``selling_type`` and ``transmission_type``
        attributes that each carry a ``.value``.
        """
        # WARN: the column order appears to matter to the model, so the keys
        # below must stay in exactly this order.
        row = {
            'selling_price': features.selling_price,
            'driven_kms': features.driven_kms,
            'fuel_type': features.fuel_type.value,
            'selling_type': features.selling_type.value,
            'transmission': features.transmission_type.value,
            'age': features.age,
        }
        single_row_frame = DataFrame([row])
        predictions = self._model.predict(single_row_frame)
        # One input row must yield exactly one prediction.
        assert len(predictions) == 1
        return float(predictions[0])
|
||||
@ -0,0 +1,5 @@
|
||||
fastapi ~=0.120.4
|
||||
mlxtend ~=0.23.4
|
||||
pandas >=2.3.1,<3
|
||||
scikit-learn >=1.7.2,<2
|
||||
uvicorn ~=0.38.0
|
||||
@ -0,0 +1 @@
|
||||
*.pkl
|
||||
@ -0,0 +1,83 @@
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from pickle import dump
|
||||
from sys import exit as sys_exit, argv as sys_argv
|
||||
|
||||
from mlflow import set_tracking_uri, set_registry_uri
|
||||
from mlflow.sklearn import load_model
|
||||
|
||||
|
||||
# Tracking server used when --tracking-uri is not given on the command line.
MLFLOW_TRACKING_URI_DEFAULT = 'http://localhost:5000'
|
||||
|
||||
|
||||
def open_file_for_model(file, *, buffering=-1, opener=None, **kwargs_extra):
    """Open a file for writing a serialized model in binary mode.

    Args:
        file: Path or file descriptor, as accepted by the builtin ``open``.
        buffering: Forwarded to ``open``.
        opener: Forwarded to ``open``.
        **kwargs_extra: Only ``closefd`` is accepted; it is forwarded to
            ``open`` as well.

    Returns:
        The binary file object returned by ``open``.

    Raises:
        TypeError: If any keyword argument other than ``closefd`` is given.
    """
    open_kwargs_extra = {}
    if 'closefd' in kwargs_extra:
        open_kwargs_extra['closefd'] = kwargs_extra.pop('closefd')
    if kwargs_extra:
        raise TypeError(
            'Unexpected keyword arguments given: {}'
            .format(', '.join(map(repr, kwargs_extra)))
        )
    # The collected extras must actually reach open(); otherwise an explicit
    # closefd=False is silently ignored and the caller's descriptor is closed.
    return open(
        file, 'wb', buffering=buffering, opener=opener, **open_kwargs_extra,
    )
|
||||
|
||||
|
||||
def dump_model_to_file(model, file):
    """Pickle *model* into the binary file *file*; returns what ``dump`` returns."""
    outcome = dump(model, file)
    return outcome
|
||||
|
||||
|
||||
def dump_model_to_path(model, path, *, buffering=-1, opener=None, **kwargs_extra):
    """Pickle *model* into the file at *path*.

    ``buffering``, ``opener`` and the optional ``closefd`` keyword are passed
    on to ``open_file_for_model``; any other keyword argument raises
    ``TypeError``.  Returns the result of ``dump_model_to_file``.
    """
    # Split off the keywords that the file-opening helper understands.
    open_kwargs_extra = {
        name: kwargs_extra.pop(name)
        for name in ('closefd',)
        if name in kwargs_extra
    }
    # Whatever is left over is unknown to this function.
    if kwargs_extra:
        unexpected = ', '.join(repr(name) for name in kwargs_extra)
        raise TypeError(
            'Unexpected keyword arguments given: {}'.format(unexpected)
        )
    model_file = open_file_for_model(
        path, buffering=buffering, opener=opener, **open_kwargs_extra,
    )
    with model_file:
        return dump_model_to_file(model, model_file)
|
||||
|
||||
|
||||
def parse_args(argv):
    """Parse the command-line arguments of the model export script.

    Args:
        argv: Argument strings handed to ``ArgumentParser.parse_args`` as is,
            so the list must not contain the program name.

    Returns:
        The parsed namespace with ``model_uri``, ``run_id``, ``tracking_uri``,
        ``registry_uri`` and ``out_path`` attributes.
    """
    parser = ArgumentParser(
        description=(
            'Скачать модель с tracking server MLFlow и сохранить в локальный файл pickle'
        ),
        allow_abbrev=False,
        exit_on_error=True,
    )
    # Exactly one way of referring to the model has to be given.
    model_ref_parser = parser.add_mutually_exclusive_group(required=True)
    model_ref_parser.add_argument('-m', '--model', type=str, dest='model_uri')
    model_ref_parser.add_argument('--run', type=str, dest='run_id')
    parser.add_argument(
        '--tracking-uri', default=MLFLOW_TRACKING_URI_DEFAULT, type=str, dest='tracking_uri',
    )
    parser.add_argument('--registry-uri', type=str, dest='registry_uri')
    # Required positional.  The former default=Path('.') was never applied,
    # because argparse ignores defaults on positionals without nargs='?'.
    parser.add_argument('out_path', type=Path)
    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def main(argv):
    """Fetch a model through MLflow and pickle it to the requested path.

    *argv* is forwarded unchanged to ``parse_args``.  Returns 0 on success.
    """
    args = parse_args(argv)

    # Point MLflow at the requested servers before loading anything.
    set_tracking_uri(args.tracking_uri)
    registry_uri = args.registry_uri
    if registry_uri is not None:
        set_registry_uri(registry_uri)

    # The model is referenced either by a full URI or by a run id.
    if args.model_uri is not None:
        model_uri = args.model_uri
    elif args.run_id is not None:
        model_uri = f'runs:/{args.run_id}/model'
    else:
        # The parser makes one of the two options mandatory, so this branch
        # is not expected to run.
        assert False
        return 1

    fetched_model = load_model(model_uri)
    dump_model_to_path(fetched_model, args.out_path)
    return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Drop the program name: the arguments end up in
    # ArgumentParser.parse_args, which would otherwise take sys.argv[0] (the
    # script path) as the positional out_path.
    sys_exit(int(main(sys_argv[1:]) or 0))
|
||||
Загрузка…
Ссылка в новой задаче