Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.
iis-project/services/models/fetch_model_as_pickle_from_...

84 строки
2.7 KiB
Python

from argparse import ArgumentParser
from pathlib import Path
from pickle import dump
from sys import exit as sys_exit, argv as sys_argv
from mlflow import set_tracking_uri, set_registry_uri
from mlflow.sklearn import load_model
# MLflow tracking server address used when --tracking-uri is not passed.
MLFLOW_TRACKING_URI_DEFAULT: str = 'http://localhost:5000'
def open_file_for_model(file, *, buffering=-1, opener=None, **kwargs_extra):
    """Open *file* for writing a pickled model in binary mode (``'wb'``).

    Args:
        file: path-like object or integer file descriptor, as accepted by the
            built-in :func:`open`.
        buffering: forwarded to :func:`open`.
        opener: forwarded to :func:`open`.
        **kwargs_extra: only ``closefd`` is accepted; it is forwarded to
            :func:`open`.

    Returns:
        The binary file object returned by :func:`open`.

    Raises:
        TypeError: if any keyword argument other than ``closefd`` is given.
    """
    open_kwargs_extra = {}
    if 'closefd' in kwargs_extra:
        open_kwargs_extra['closefd'] = kwargs_extra.pop('closefd')
    if len(kwargs_extra) > 0:
        raise TypeError(
            'Unexpected keyword arguments given: {}'
            .format(', '.join(map(repr, kwargs_extra.keys())))
        )
    # Forward closefd so a caller passing a raw file descriptor can keep it
    # open after the returned file object is closed.
    return open(
        file, 'wb', buffering=buffering, opener=opener, **open_kwargs_extra,
    )
def dump_model_to_file(model, file):
    """Serialize *model* with :func:`pickle.dump` into the already-open *file*.

    Args:
        model: any picklable object.
        file: a writable binary file-like object.

    Returns:
        The value returned by :func:`pickle.dump`.
    """
    outcome = dump(model, file)
    return outcome
def dump_model_to_path(model, path, *, buffering=-1, opener=None, **kwargs_extra):
    """Pickle *model* into the file that ``open_file_for_model`` opens for *path*.

    Args:
        model: object to serialize.
        path: path-like object or file descriptor passed to
            ``open_file_for_model``.
        buffering: forwarded to ``open_file_for_model``.
        opener: forwarded to ``open_file_for_model``.
        **kwargs_extra: only ``closefd`` is accepted; it is forwarded to
            ``open_file_for_model``.

    Returns:
        Whatever ``dump_model_to_file`` returns.

    Raises:
        TypeError: if any keyword argument other than ``closefd`` is given.
            This check happens before anything is opened.
    """
    forwarded = {}
    if 'closefd' in kwargs_extra:
        forwarded['closefd'] = kwargs_extra.pop('closefd')
    if kwargs_extra:
        names = ', '.join(repr(name) for name in kwargs_extra)
        raise TypeError(f'Unexpected keyword arguments given: {names}')

    model_file = open_file_for_model(
        path, buffering=buffering, opener=opener, **forwarded,
    )
    # The file is closed when the with-block exits, even if pickling fails.
    with model_file:
        return dump_model_to_file(model, model_file)
def parse_args(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv: list of argument strings WITHOUT the program name, as expected
            by :meth:`argparse.ArgumentParser.parse_args`. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with these attributes:

        * ``model_uri`` and ``run_id``: exactly one of them is not ``None``.
        * ``tracking_uri``: always set.
        * ``registry_uri``: ``None`` when not given.
        * ``out_path``: a :class:`pathlib.Path`.

    On invalid arguments argparse prints usage and raises ``SystemExit``.
    """
    parser = ArgumentParser(
        description=(
            'Скачать модель с tracking server MLFlow и сохранить в локальный файл pickle'
        ),
        allow_abbrev=False,
        exit_on_error=True,
    )
    # The model must be referenced in exactly one way: a full URI or a run ID.
    model_ref_parser = parser.add_mutually_exclusive_group(required=True)
    model_ref_parser.add_argument(
        '-m', '--model', type=str, dest='model_uri',
        help='URI модели MLflow (например, runs:/<run_id>/model)',
    )
    model_ref_parser.add_argument(
        '--run', type=str, dest='run_id',
        help='ID запуска MLflow; будет загружена модель runs:/<run_id>/model',
    )
    parser.add_argument(
        '--tracking-uri', default=MLFLOW_TRACKING_URI_DEFAULT, type=str,
        dest='tracking_uri',
        help='URI tracking server MLflow (по умолчанию: %(default)s)',
    )
    parser.add_argument(
        '--registry-uri', type=str, dest='registry_uri',
        help='URI model registry MLflow (необязательно)',
    )
    # A plain positional is always required, so a default would never be
    # used. The old default Path('.') would also have named a directory,
    # not a file.
    parser.add_argument(
        'out_path', type=Path,
        help='путь к выходному файлу pickle',
    )
    return parser.parse_args(argv)
def main(argv):
    """Fetch a model from MLflow and pickle it to the requested output path.

    Args:
        argv: argument list handed unchanged to ``parse_args``.

    Returns:
        0 after the model has been loaded and written. 1 only if neither
        ``--model`` nor ``--run`` was parsed while assertions are disabled
        (``-O``). That case should be unreachable, because the
        mutually exclusive group in ``parse_args`` is required.
    """
    args = parse_args(argv)

    set_tracking_uri(args.tracking_uri)
    if args.registry_uri is not None:
        set_registry_uri(args.registry_uri)

    # Defensive guard: parse_args should already guarantee that one of the
    # two model references is present.
    if args.model_uri is None and args.run_id is None:
        assert False
        return 1

    if args.model_uri is not None:
        model_uri = args.model_uri
    else:
        # A bare run ID is resolved to the "model" artifact of that run.
        model_uri = f'runs:/{args.run_id}/model'

    fetched_model = load_model(model_uri)
    dump_model_to_path(fetched_model, args.out_path)
    return 0
if __name__ == '__main__':
    # ArgumentParser.parse_args expects the arguments without the program
    # name. Passing the full sys.argv let sys.argv[0] (this script's path) be
    # consumed as the out_path positional.
    sys_exit(int(main(sys_argv[1:]) or 0))