Вы не можете выбрать более 25 тем
Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.
84 строки
2.7 KiB
Python
84 строки
2.7 KiB
Python
from argparse import ArgumentParser
|
|
from pathlib import Path
|
|
from pickle import dump
|
|
from sys import exit as sys_exit, argv as sys_argv
|
|
|
|
from mlflow import set_tracking_uri, set_registry_uri
|
|
from mlflow.sklearn import load_model
|
|
|
|
|
|
# Tracking server URI used by parse_args() when --tracking-uri is not passed.
MLFLOW_TRACKING_URI_DEFAULT: str = 'http://localhost:5000'
|
|
|
|
|
|
def open_file_for_model(file, *, buffering=-1, opener=None, **kwargs_extra):
    """Open *file* for writing a pickled model in binary mode (``'wb'``).

    Thin wrapper around the built-in :func:`open` that fixes the mode to
    ``'wb'`` and forwards only a restricted set of keyword arguments.

    :param file: A path-like object or an integer file descriptor, as
        accepted by :func:`open`.
    :param buffering: Forwarded to :func:`open`.
    :param opener: Forwarded to :func:`open`.
    :param kwargs_extra: Only ``closefd`` is accepted and it is forwarded to
        :func:`open`. Note that :func:`open` rejects ``closefd=False`` with
        ``ValueError`` unless *file* is a file descriptor.
    :returns: The binary file object returned by :func:`open`.
    :raises TypeError: If any keyword argument other than ``closefd`` is
        given.
    """
    open_kwargs_extra = {}
    if 'closefd' in kwargs_extra:
        open_kwargs_extra['closefd'] = kwargs_extra.pop('closefd')
    # Anything left over is unsupported: reject it explicitly rather than
    # silently ignoring it.
    if kwargs_extra:
        raise TypeError(
            'Unexpected keyword arguments given: {}'
            .format(', '.join(map(repr, kwargs_extra.keys())))
        )
    # Forward ``closefd`` so a caller that passes a raw file descriptor can
    # keep ownership of it (closefd=False) after the file object is closed.
    return open(
        file, 'wb', buffering=buffering, opener=opener, **open_kwargs_extra,
    )
|
|
|
|
|
|
def dump_model_to_file(model, file):
    """Serialize *model* with :func:`pickle.dump` into the open binary *file*.

    :param model: Any picklable object.
    :param file: A writable binary file-like object (anything exposing a
        ``write(bytes)`` method).
    :returns: The value returned by :func:`pickle.dump`, i.e. ``None``.
    """
    outcome = dump(model, file)
    return outcome
|
|
|
|
|
|
def dump_model_to_path(model, path, *, buffering=-1, opener=None, **kwargs_extra):
    """Pickle *model* into the file at *path*, creating or truncating it.

    :param model: Any picklable object.
    :param path: Path-like object or file descriptor, passed on to
        :func:`open_file_for_model`.
    :param buffering: Forwarded to :func:`open_file_for_model`.
    :param opener: Forwarded to :func:`open_file_for_model`.
    :param kwargs_extra: Extra keyword arguments forwarded unchanged to
        :func:`open_file_for_model`, which accepts only ``closefd``.
    :returns: Whatever :func:`dump_model_to_file` returns.
    :raises TypeError: (from :func:`open_file_for_model`) if any keyword
        argument other than ``closefd`` is given; no file is opened then.
    """
    # open_file_for_model already whitelists ``closefd`` and raises the same
    # TypeError for anything else. Delegating avoids keeping a second copy
    # of that whitelist in sync by hand.
    with open_file_for_model(
        path, buffering=buffering, opener=opener, **kwargs_extra,
    ) as model_file:
        return dump_model_to_file(model, model_file)
|
|
|
|
|
|
def parse_args(argv):
    """Parse the command-line options of the model download script.

    :param argv: Argument list *without* the program name; it is handed
        directly to :meth:`argparse.ArgumentParser.parse_args`.
    :returns: :class:`argparse.Namespace` with these attributes:
        ``model_uri`` and ``run_id`` (exactly one is set, the other is
        ``None``); ``tracking_uri``; ``registry_uri`` (``None`` if not
        given); ``out_path`` (:class:`pathlib.Path`).
    """
    parser = ArgumentParser(
        description=(
            'Скачать модель с tracking server MLFlow и сохранить в локальный файл pickle'
        ),
        allow_abbrev=False,
        exit_on_error=True,
    )

    # Exactly one way of referring to the model must be given:
    # a full model URI (--model) or a run ID (--run).
    model_ref_parser = parser.add_mutually_exclusive_group(required=True)
    model_ref_parser.add_argument('-m', '--model', type=str, dest='model_uri')
    model_ref_parser.add_argument('--run', type=str, dest='run_id')

    parser.add_argument(
        '--tracking-uri', default=MLFLOW_TRACKING_URI_DEFAULT, type=str, dest='tracking_uri',
    )
    parser.add_argument('--registry-uri', type=str, dest='registry_uri')

    # Required positional: destination file for the pickled model. No
    # ``default`` is given because argparse ignores it for a positional
    # without nargs='?' or '*'.
    parser.add_argument('out_path', type=Path)

    return parser.parse_args(argv)
|
|
|
|
|
|
def main(argv):
    """Download a model from an MLflow tracking server and pickle it to disk.

    :param argv: Full command line in ``sys.argv`` form; ``argv[0]`` is the
        program name and is skipped before option parsing.
    :returns: Process exit status (0 on success).
    """
    # argv[0] is the program name (the __main__ guard passes sys.argv as-is).
    # Passing it to argparse would bind it to the ``out_path`` positional:
    # without an explicit output path the pickle would overwrite argv[0]
    # (the script), and with one parsing fails on an "unrecognized" argument.
    args = parse_args(argv[1:])

    set_tracking_uri(args.tracking_uri)
    if args.registry_uri is not None:
        set_registry_uri(args.registry_uri)

    if args.model_uri is not None:
        model_uri = args.model_uri
    elif args.run_id is not None:
        model_uri = f'runs:/{args.run_id}/model'
    else:
        # Unreachable: parse_args declares --model/--run as a required,
        # mutually exclusive group, so one of them is always set.
        assert False
        return 1

    model = load_model(model_uri)
    dump_model_to_path(model, args.out_path)
    return 0
|
|
|
|
|
|
if __name__ == '__main__':
    # Use main()'s return value as the process exit status; a falsy
    # result (e.g. None) is reported as success (0).
    exit_code = main(sys_argv)
    if not exit_code:
        exit_code = 0
    sys_exit(int(exit_code))
|