"""Download a model from an MLflow tracking server and pickle it to a local file."""

from argparse import ArgumentParser
from pathlib import Path
from pickle import dump
from sys import exit as sys_exit, argv as sys_argv

from mlflow import set_tracking_uri, set_registry_uri
from mlflow.sklearn import load_model


MLFLOW_TRACKING_URI_DEFAULT = 'http://localhost:5000'


def open_file_for_model(file, *, buffering=-1, opener=None, **kwargs_extra):
    """Open *file* for binary writing and return the file object.

    Args:
        file: Path or file descriptor, passed to the built-in ``open``.
        buffering: Forwarded to ``open``.
        opener: Forwarded to ``open``.
        **kwargs_extra: Only ``closefd`` is accepted; it is forwarded to
            ``open``.

    Raises:
        TypeError: If any keyword other than ``closefd`` is given.
    """
    open_kwargs_extra = {}
    if 'closefd' in kwargs_extra:
        open_kwargs_extra['closefd'] = kwargs_extra.pop('closefd')
    if kwargs_extra:
        raise TypeError(
            'Unexpected keyword arguments given: {}'
            .format(', '.join(map(repr, kwargs_extra)))
        )
    # Forward the collected options; previously ``closefd`` was accepted
    # but silently dropped before reaching ``open``.
    return open(
        file,
        'wb',
        buffering=buffering,
        opener=opener,
        **open_kwargs_extra,
    )


def dump_model_to_file(model, file):
    """Pickle *model* into the already opened binary file object *file*."""
    return dump(model, file)


def dump_model_to_path(model, path, *, buffering=-1, opener=None, **kwargs_extra):
    """Pickle *model* into a file at *path*, closing the file afterwards.

    Args:
        model: Object to serialize with ``pickle.dump``.
        path: Destination handed to ``open_file_for_model``.
        buffering: Forwarded to ``open_file_for_model``.
        opener: Forwarded to ``open_file_for_model``.
        **kwargs_extra: Only ``closefd`` is accepted.

    Raises:
        TypeError: If any keyword other than ``closefd`` is given.
    """
    # ``open_file_for_model`` validates the extra keywords (raising the same
    # TypeError) before it opens anything, so no separate check is needed.
    with open_file_for_model(
        path,
        buffering=buffering,
        opener=opener,
        **kwargs_extra,
    ) as model_file:
        return dump_model_to_file(model, model_file)


def parse_args(argv):
    """Parse command-line arguments.

    Args:
        argv: Argument strings WITHOUT the program name (i.e.
            ``sys.argv[1:]``), as expected by ``ArgumentParser.parse_args``.

    Returns:
        The parsed namespace with ``model_uri``, ``run_id``,
        ``tracking_uri``, ``registry_uri`` and ``out_path`` attributes.
    """
    parser = ArgumentParser(
        description=(
            'Скачать модель с tracking server MLFlow и сохранить в локальный файл pickle'
        ),
        allow_abbrev=False,
        exit_on_error=True,
    )

    # Exactly one way of referencing the model must be supplied.
    model_ref_parser = parser.add_mutually_exclusive_group(required=True)
    model_ref_parser.add_argument('-m', '--model', type=str, dest='model_uri')
    model_ref_parser.add_argument('--run', type=str, dest='run_id')

    parser.add_argument(
        '--tracking-uri',
        default=MLFLOW_TRACKING_URI_DEFAULT,
        type=str,
        dest='tracking_uri',
    )
    parser.add_argument('--registry-uri', type=str, dest='registry_uri')
    # NOTE: this positional has no ``nargs``, so it is required and the
    # ``default`` below is never actually used by argparse.
    parser.add_argument('out_path', default=Path('.'), type=Path)

    args = parser.parse_args(argv)
    return args


def main(argv):
    """Load the requested model and pickle it to the output path.

    Args:
        argv: Command-line arguments without the program name.

    Returns:
        Process exit status: 0 on success.
    """
    args = parse_args(argv)

    set_tracking_uri(args.tracking_uri)
    if args.registry_uri is not None:
        set_registry_uri(args.registry_uri)

    if args.model_uri is not None:
        model_uri = args.model_uri
    elif args.run_id is not None:
        model_uri = f'runs:/{args.run_id}/model'
    else:
        # Not expected to happen: the mutually exclusive group is required.
        # The ``return`` keeps a non-zero status when asserts are disabled.
        assert False
        return 1

    model = load_model(model_uri)
    dump_model_to_path(model, args.out_path)
    return 0


if __name__ == '__main__':
    # Strip the program name: ``parse_args`` hands the list straight to
    # ``ArgumentParser.parse_args``, which would otherwise consume the script
    # path as the positional ``out_path``.
    sys_exit(int(main(sys_argv[1:]) or 0))