"""Periodically POST randomly generated feature payloads to a price-prediction endpoint.

The API base URL comes from the positional CLI argument or, failing that, from
the ``API_BASE_URL`` environment variable. The request loop stops on
SIGINT / SIGTERM.
"""

from argparse import ArgumentParser, ArgumentTypeError
from collections.abc import Callable, MutableMapping
from dataclasses import dataclass, asdict
from enum import Enum
import logging
from os import getenv
from random import randint, uniform, expovariate, choice
from signal import SIGINT, SIGTERM, signal
import sys
from time import sleep
from types import FrameType
from typing import Any, cast

from requests import RequestException, Response, Session


def fixup_payload_enum_value(mapping: MutableMapping[str, Any], key: str) -> None:
    """Replace the Enum member stored at ``mapping[key]`` with its ``.value``, in place.

    ``dataclasses.asdict`` keeps Enum members as-is; converting them to their
    plain values lets the payload be JSON-encoded by ``requests``.
    """
    mapping[key] = mapping[key].value


# Path of the prediction endpoint, appended to the API base URL.
ENDPOINT_URL: str = '/predict'


class FuelType(Enum):
    """Allowed values of the ``fuel_type`` feature (sent as the string value)."""

    PETROL = 'petrol'
    DIESEL = 'diesel'
    CNG = 'cng'


class SellingType(Enum):
    """Allowed values of the ``selling_type`` feature (sent as the string value)."""

    DEALER = 'dealer'
    INDIVIDUAL = 'individual'


class TransmissionType(Enum):
    """Allowed values of the ``transmission_type`` feature (sent as the string value)."""

    MANUAL = 'manual'
    AUTOMATIC = 'automatic'


@dataclass
class PricePredictionFeatures:
    """Feature set sent as the JSON body of one prediction request."""

    selling_price: float
    driven_kms: float
    age: float
    fuel_type: FuelType
    selling_type: SellingType
    transmission_type: TransmissionType


MAX_RETRIES_DEFAULT = 3


def exp_delay_from_attempt_number(attempt_i: int) -> float:
    """Return the back-off delay in seconds after failed attempt ``attempt_i`` (0-based).

    The delay doubles with each attempt: 0.2, 0.4, 0.8, ...
    """
    return 0.2 * (2 ** attempt_i)


def post_item(
    session: Session,
    url: str,
    item_id: int,
    features: PricePredictionFeatures,
    *,
    max_retries: int = MAX_RETRIES_DEFAULT,
) -> Response:
    """POST one feature payload to ``url``, passing ``item_id`` as a query parameter.

    Only transport-level failures (``requests.RequestException``) are retried,
    with exponential back-off between attempts. Any HTTP response, including
    4xx/5xx, is returned unchanged.

    Args:
        session: HTTP session used for the request.
        url: Full endpoint URL.
        item_id: Value for the ``item_id`` query parameter.
        features: Features serialized as the JSON body.
        max_retries: Number of retries after the first attempt (must be >= 0).

    Returns:
        The first response received.

    Raises:
        ValueError: If ``max_retries`` is negative.
        IOError: If all ``max_retries + 1`` attempts failed. The last
            ``RequestException`` is chained as ``__cause__``.
    """
    if max_retries < 0:
        raise ValueError('max_retries must be >= 0')

    payload = asdict(features)
    for k in ('fuel_type', 'selling_type', 'transmission_type'):
        fixup_payload_enum_value(payload, k)

    last_err: RequestException | None = None
    for attempt_i in range(max_retries + 1):
        if attempt_i > 0:
            # Back off only *between* attempts: there is no point sleeping
            # after the final failure right before raising.
            sleep(exp_delay_from_attempt_number(attempt_i - 1))
        try:
            return session.post(url, params={'item_id': item_id}, json=payload, timeout=10)
        except RequestException as err:
            last_err = err

    # At least one attempt always runs (max_retries >= 0), so last_err is set here.
    raise IOError(
        f'Failed to post an item in {max_retries + 1} attempts;'
        ' see the latest exception in __cause__'
    ) from last_err


def generate_request_data() -> tuple[int, PricePredictionFeatures]:
    """Return a random ``(item_id, features)`` pair; ``item_id`` is drawn from 1..100."""
    item_id = randint(1, 100)
    features = PricePredictionFeatures(
        selling_price=round(uniform(2.0, 16.0), 2),
        driven_kms=round(uniform(0.0, 100000.0), 0),
        age=round(uniform(0.0, 10.0), 1),
        fuel_type=choice(list(FuelType)),
        selling_type=choice(list(SellingType)),
        transmission_type=choice(list(TransmissionType)),
    )
    return (item_id, features)


# Mean of the exponentially distributed pause between requests, in seconds.
INTERVAL_MEAN_DEFAULT = 4.0
# (min, max) clamp applied to each pause; None disables that side of the clamp.
INTERVAL_BOUNDS_DEFAULT: tuple[float | None, float | None] = (0.5, 10.0)


class Requester:
    """Send randomly generated prediction requests in a loop until stopped.

    Owns an HTTP ``Session``. Call :meth:`close` (or use the instance as a
    context manager) to release it.
    """

    def __init__(
        self,
        base_url: str,
        interval_mean: float = INTERVAL_MEAN_DEFAULT,
        interval_bounds: tuple[float | None, float | None] = INTERVAL_BOUNDS_DEFAULT,
        *,
        max_retries: int = MAX_RETRIES_DEFAULT,
    ):
        """
        Args:
            base_url: API base URL; ``ENDPOINT_URL`` is appended to it.
            interval_mean: Mean pause between requests, in seconds.
            interval_bounds: ``(min, max)`` clamp for each pause; ``None`` = unbounded.
            max_retries: Retries per request, forwarded to :func:`post_item`.
        """
        self.base_url = base_url
        self.interval_mean = interval_mean
        self.interval_bounds = interval_bounds
        self.max_retries = max_retries
        self._session = Session()
        self._stop_requested: bool = False

    def __enter__(self) -> 'Requester':
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    @property
    def endpoint(self) -> str:
        """Full endpoint URL: ``base_url`` joined with ``ENDPOINT_URL``.

        A trailing ``/`` on ``base_url`` is dropped so the join does not
        produce a ``//`` in the path.
        """
        endpoint_url = ENDPOINT_URL
        if not endpoint_url:
            return self.base_url
        if not endpoint_url.startswith('/'):
            endpoint_url = '/' + endpoint_url
        return self.base_url.rstrip('/') + endpoint_url

    @property
    def session(self) -> Session:
        return self._session

    @property
    def stop_requested(self) -> bool:
        return self._stop_requested

    def stop(self) -> None:
        """Ask :meth:`run` to exit at its next check of the stop flag."""
        self._stop_requested = True

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self._session.close()

    def _decide_delay(self) -> float:
        """Draw an exponential pause with mean ``interval_mean``, clamped to ``interval_bounds``."""
        low, high = self.interval_bounds
        val = expovariate(1. / self.interval_mean)
        if low is not None:
            val = max(val, low)
        if high is not None:
            val = min(val, high)
        return val

    def run(self) -> None:
        """Send requests until :meth:`stop` is called.

        A stop requested during the pause takes effect once the current pause
        elapses.

        Raises:
            IOError: If a request fails after all retries. The failure is
                logged first.
        """
        while not self._stop_requested:
            item_id, features = generate_request_data()
            try:
                response = post_item(
                    self._session,
                    self.endpoint,
                    item_id,
                    features,
                    max_retries=self.max_retries,
                )
            except IOError as err:
                logging.warning('%s: %s', str(err), str(err.__cause__))
                raise
            if response.ok:
                logging.debug('Success: %s %s', response.status_code, response.reason)
            else:
                # A 4xx/5xx is a completed exchange, not a success; surface it.
                logging.warning('Server error: %s %s', response.status_code, response.reason)
            if self._stop_requested:
                # A stop arrived while the request was in flight: don't sleep first.
                break
            sleep(self._decide_delay())


def _build_termination_handler(requester: Requester) -> Callable[[int, FrameType | None], None]:
    """Return a signal handler that asks ``requester`` to stop gracefully."""
    def termination_handler(sig: int, frame: FrameType | None) -> None:
        del sig, frame  # required by the signal-handler signature, unused
        requester.stop()

    return termination_handler


def _configure_logging(level: int, quiet: bool) -> None:
    """Configure root logging to stderr; ``quiet`` raises the threshold above CRITICAL."""
    if quiet:
        level = logging.CRITICAL + 1
    logging.basicConfig(
        level=level,
        format='%(asctime)s %(levelname)s %(message)s',
        stream=sys.stderr,
    )


def _setup_signal_handlers(requester: Requester) -> None:
    """Make SIGINT and SIGTERM request a graceful stop of ``requester``."""
    termination_handler = _build_termination_handler(requester)
    for sig in (SIGINT, SIGTERM):
        signal(sig, termination_handler)


def _validate_cli_interval_bound(string: str) -> float | None:
    """Parse one interval bound: a non-negative float, or ``''``/``null``/``none`` for no bound."""
    string = string.strip().lower()
    if string in ('', 'null', 'none'):
        return None
    val = float(string)
    # `not val >= 0` also rejects NaN; a negative upper bound would make sleep() raise.
    if not val >= 0:
        raise ArgumentTypeError(f'Interval bound should be >= 0, given {string!r}')
    return val


def _validate_cli_interval_bounds(string: str) -> tuple[float | None, float | None]:
    """Parse ``MIN,MAX`` interval bounds; either side may be empty/``null``/``none``.

    A whole value of ``''``/``null``/``none`` disables both bounds.
    """
    string = string.strip().lower()
    if string in ('', 'null', 'none'):
        return (None, None)
    min_string, sep, max_string = string.partition(',')
    if not sep:
        raise ArgumentTypeError(f'Interval bounds should be "MIN,MAX", given {string!r}')
    return cast(
        tuple[float | None, float | None],
        tuple(map(_validate_cli_interval_bound, (min_string, max_string))),
    )


def _validate_cli_max_retries(string: str) -> int:
    """Parse a non-negative retry count."""
    val = int(string)
    if val < 0:
        # ArgumentTypeError (unlike ValueError) keeps this message in argparse's output.
        raise ArgumentTypeError(f'Max retries should be >=0, given {val}')
    return val


_LOGGING_LEVELS: dict[str, int] = {
    'debug': logging.DEBUG,
    'info': logging.INFO,
    'warning': logging.WARNING,
    'error': logging.ERROR,
    'critical': logging.CRITICAL,
}


def _validate_cli_logging_level(string: str) -> int:
    """Map a case-insensitive level name to its ``logging`` constant."""
    try:
        return _LOGGING_LEVELS[string.strip().lower()]
    except KeyError:
        # argparse does not catch KeyError from a `type=` callable; convert it
        # so the user gets a usage error instead of a traceback.
        raise ArgumentTypeError(
            f'Unknown logging level {string!r}; expected one of: {", ".join(_LOGGING_LEVELS)}'
        ) from None


def parse_args(argv):
    """Parse and validate ``argv`` (including the program name at index 0).

    Missing ``--interval-mean`` falls back to ``INTERVAL_MEAN_DEFAULT``. Missing
    ``--interval-bounds`` is derived as ``(mean / 5, mean * 5)`` when a mean was
    given, otherwise it falls back to ``INTERVAL_BOUNDS_DEFAULT``.

    Raises:
        RuntimeError: If no base URL is given and ``API_BASE_URL`` is unset.
        ValueError: If the interval mean or bounds are inconsistent.
    """
    parser = ArgumentParser(
        description=(
            'Регулярная отправка POST-запросов на эндпоинт предсказания цены.'
            ' Остановка по SIGINT / SIGTERM.'
        ),
        allow_abbrev=False,
        exit_on_error=True,
    )
    parser.add_argument('base_url', type=str, nargs='?')
    parser.add_argument('--interval-mean', type=float, dest='interval_mean')
    parser.add_argument(
        '--interval-bounds',
        type=_validate_cli_interval_bounds,
        dest='interval_bounds',
    )
    parser.add_argument(
        '--max-retries',
        type=_validate_cli_max_retries,
        default=MAX_RETRIES_DEFAULT,
        dest='max_retries',
    )
    parser.add_argument('-q', '--quiet', action='store_true', dest='quiet')
    parser.add_argument(
        '--log-level',
        default=logging.WARNING,
        type=_validate_cli_logging_level,
        dest='logging_level',
    )
    args = parser.parse_args(argv[1:])

    if args.base_url is None:
        args.base_url = getenv('API_BASE_URL')
    if args.base_url is None:
        raise RuntimeError('No API base URL specified')

    if (args.interval_mean is not None) and (args.interval_mean <= 0):
        raise ValueError(f'Interval mean should be > 0, given {args.interval_mean}')
    if (
        (args.interval_bounds is not None)
        and all((b is not None) for b in args.interval_bounds)
        and (args.interval_bounds[0] > args.interval_bounds[1])
    ):
        raise ValueError(f'Interval bounds should be b_1 <= b_2, given {args.interval_bounds!r}')

    if args.interval_mean is None:
        args.interval_mean = INTERVAL_MEAN_DEFAULT
        # Keep explicitly given bounds; only fill in the default when absent.
        if args.interval_bounds is None:
            args.interval_bounds = INTERVAL_BOUNDS_DEFAULT
    elif args.interval_bounds is None:
        args.interval_bounds = ((args.interval_mean / 5), (args.interval_mean * 5))

    return args


def main(argv):
    """Entry point: parse ``argv``, run the request loop until signalled, return 0."""
    args = parse_args(argv)
    _configure_logging(args.logging_level, args.quiet)

    logging.debug('Creating a Requester with base URL: %s', args.base_url)
    requester = Requester(
        args.base_url,
        interval_mean=args.interval_mean,
        interval_bounds=args.interval_bounds,
        max_retries=args.max_retries,
    )
    _setup_signal_handlers(requester)
    try:
        requester.run()
    finally:
        requester.close()

    return 0


if __name__ == '__main__':
    sys.exit(int(main(sys.argv) or 0))