Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.
iis-project/services/load_tester/tester.py

281 строка
8.5 KiB
Python

from argparse import ArgumentParser
from collections.abc import Callable, MutableMapping
from dataclasses import dataclass, asdict
from enum import Enum
import logging
from os import getenv
from random import randint, uniform, expovariate, choice
from signal import SIGINT, SIGTERM, signal
import sys
from time import sleep
from types import FrameType
from typing import Any, cast
from requests import RequestException, Response, Session
def fixup_payload_enum_value(mapping: MutableMapping[str, Any], key: str) -> None:
    """Replace the Enum member stored under ``key`` with its ``.value``, in place.

    ``dataclasses.asdict`` keeps Enum members untouched, while a plain ``Enum``
    member is not JSON-serializable, so the raw value is substituted.
    """
    member = mapping[key]
    mapping[key] = member.value
# Path of the price-prediction endpoint; Requester.endpoint appends it to the base URL.
ENDPOINT_URL: str = '/predict'
class FuelType(Enum):
    """Fuel type feature; sent in the request payload by its string value."""

    PETROL = 'petrol'
    DIESEL = 'diesel'
    CNG = 'cng'
class SellingType(Enum):
    """Seller kind feature; sent in the request payload by its string value."""

    DEALER = 'dealer'
    INDIVIDUAL = 'individual'
class TransmissionType(Enum):
    """Gearbox kind feature; sent in the request payload by its string value."""

    MANUAL = 'manual'
    AUTOMATIC = 'automatic'
@dataclass
class PricePredictionFeatures:
    """Features of one car, serialized (via ``dataclasses.asdict``) as the JSON body of a request.

    The three enum fields are replaced by their string values by ``post_item``
    before sending. Units are not stated anywhere in this file — the notes below
    are inferred from names only.
    """

    selling_price: float
    driven_kms: float  # presumably total kilometres driven — TODO confirm
    age: float  # presumably in years — TODO confirm
    fuel_type: FuelType
    selling_type: SellingType
    transmission_type: TransmissionType
# Extra attempts per request after the first one fails with a transport error.
MAX_RETRIES_DEFAULT: int = 3
def exp_delay_from_attempt_number(attempt_i: int) -> float:
    """Return the back-off delay, in seconds, to wait after the 0-based attempt ``attempt_i``.

    The delay doubles with every attempt: 0.2, 0.4, 0.8, ...
    """
    base_delay = 0.2
    return base_delay * 2 ** attempt_i
def post_item(
    session: Session, url: str, item_id: int, features: PricePredictionFeatures,
    *, max_retries: int = MAX_RETRIES_DEFAULT, timeout: float = 10,
) -> Response:
    """POST one prediction request, retrying on transport-level errors.

    The features are converted with ``dataclasses.asdict`` (enum fields replaced
    by their values) and sent as the JSON body; ``item_id`` goes into the query
    string.

    Args:
        session: HTTP session used to send the request.
        url: Full endpoint URL.
        item_id: Value of the ``item_id`` query parameter.
        features: Request payload.
        max_retries: Additional attempts after the first one fails (>= 0).
        timeout: Per-attempt timeout passed to ``session.post``, in seconds.

    Returns:
        The response of the first attempt that did not raise. Its HTTP status is
        NOT checked: 4xx/5xx responses are returned as-is, not retried.

    Raises:
        ValueError: If ``max_retries`` is negative.
        IOError: If every attempt raised ``RequestException``; the last such
            exception is attached as ``__cause__``.
    """
    if max_retries < 0:
        raise ValueError('max_retries must be >= 0')
    payload = asdict(features)
    for k in ('fuel_type', 'selling_type', 'transmission_type'):
        fixup_payload_enum_value(payload, k)

    attempts = max_retries + 1
    last_error: RequestException | None = None
    for attempt_i in range(attempts):
        try:
            return session.post(url, params={'item_id': item_id}, json=payload, timeout=timeout)
        except RequestException as err:
            last_error = err
            # Back off only when another attempt follows; sleeping after the
            # final failure would just delay the error for the caller.
            if attempt_i + 1 < attempts:
                sleep(exp_delay_from_attempt_number(attempt_i))

    raise IOError(
        f'Failed to post an item in {attempts} attempts;'
        ' see the latest exception in __cause__'
    ) from last_error
def generate_request_data() -> tuple[int, PricePredictionFeatures]:
    """Produce a random ``(item_id, features)`` pair for one request.

    ``item_id`` is uniform over 1..100 inclusive. The numeric features are drawn
    uniformly from fixed ranges and rounded. Each categorical feature is a
    uniformly chosen member of its enum.
    """
    item_id = randint(1, 100)
    # The draws happen in the same order as the fields, keeping the random stream consistent.
    selling_price = round(uniform(2.0, 16.0), 2)
    driven_kms = round(uniform(0.0, 100000.0), 0)
    age = round(uniform(0.0, 10.0), 1)
    fuel_type = choice(list(FuelType))
    selling_type = choice(list(SellingType))
    transmission_type = choice(list(TransmissionType))
    features = PricePredictionFeatures(
        selling_price=selling_price,
        driven_kms=driven_kms,
        age=age,
        fuel_type=fuel_type,
        selling_type=selling_type,
        transmission_type=transmission_type,
    )
    return item_id, features
INTERVAL_MEAN_DEFAULT = 4.0
INTERVAL_BOUNDS_DEFAULT: tuple[float | None, float | None] = (0.5, 10.0)
class Requester:
    """Send randomly generated prediction requests in a loop until asked to stop.

    Consecutive requests are separated by a delay. Each delay is drawn from an
    exponential distribution with mean ``interval_mean`` and clamped to
    ``interval_bounds``.
    """

    def __init__(
        self,
        base_url: str,
        interval_mean: float = INTERVAL_MEAN_DEFAULT,
        interval_bounds: tuple[float | None, float | None] = INTERVAL_BOUNDS_DEFAULT,
        *, max_retries: int = MAX_RETRIES_DEFAULT,
    ):
        """
        Args:
            base_url: API root URL; ``ENDPOINT_URL`` is appended to it.
            interval_mean: Mean delay between requests, in seconds (must be > 0).
            interval_bounds: ``(lower, upper)`` clamp for each delay; a ``None``
                side is left unbounded.
            max_retries: Extra attempts per request, forwarded to ``post_item``.
        """
        self.base_url = base_url
        self.interval_mean = interval_mean
        self.interval_bounds = interval_bounds
        self.max_retries = max_retries
        self._session = Session()
        self._stop_requested: bool = False

    @property
    def endpoint(self) -> str:
        """Full URL of the prediction endpoint.

        Trailing slashes on ``base_url`` are dropped before joining. This way a
        base URL such as ``http://host:8000/`` does not produce
        ``http://host:8000//predict``.
        """
        endpoint_url = ENDPOINT_URL
        if not endpoint_url:
            return self.base_url
        if not endpoint_url.startswith('/'):
            endpoint_url = '/' + endpoint_url
        return self.base_url.rstrip('/') + endpoint_url

    @property
    def session(self) -> Session:
        """HTTP session shared by all requests of this requester."""
        return self._session

    @property
    def stop_requested(self) -> bool:
        """Whether ``stop()`` has been called."""
        return self._stop_requested

    def stop(self) -> None:
        """Ask ``run`` to exit before it sends the next request.

        This only sets a flag. A request or delay already in progress is not
        interrupted.
        """
        self._stop_requested = True

    def _decide_delay(self) -> float:
        """Draw the next inter-request delay and clamp it to ``interval_bounds``."""
        interval_bounds = self.interval_bounds
        val = expovariate(1. / self.interval_mean)
        if interval_bounds[0] is not None:
            val = max(val, interval_bounds[0])
        if interval_bounds[1] is not None:
            val = min(val, interval_bounds[1])
        return val

    def run(self) -> None:
        """Send requests until ``stop()`` is called.

        Raises:
            IOError: Propagated from ``post_item`` when a request fails after all
                retries. It is logged as a warning before being re-raised.
        """
        while not self._stop_requested:
            item_id, features = generate_request_data()
            try:
                response = post_item(
                    self._session, self.endpoint, item_id, features, max_retries=self.max_retries,
                )
            except IOError as err:
                logging.warning('%s: %s', str(err), str(err.__cause__))
                raise
            if response.ok:
                logging.debug('Success: %s %s', response.status_code, response.reason)
            else:
                # post_item returns HTTP error responses instead of raising, so
                # they must be reported here rather than counted as successes.
                logging.warning('Error response: %s %s', response.status_code, response.reason)
            sleep(self._decide_delay())
def _build_termination_handler(requester: Requester) -> Callable[[int, FrameType | None], None]:
    """Return a signal handler that calls ``requester.stop()``.

    The signal number and frame passed to the handler are ignored.
    """

    def termination_handler(sig: int, frame: FrameType | None) -> None:
        del sig, frame  # required by the signal-handler signature, unused
        requester.stop()

    return termination_handler
def _configure_logging(level: int, quiet: bool) -> None:
if quiet:
level = logging.CRITICAL + 1
logging.basicConfig(
level=level, format='%(asctime)s %(levelname)s %(message)s', stream=sys.stderr,
)
def _setup_signal_handlers(requester: Requester) -> None:
    """Install one shared handler for SIGINT and SIGTERM that stops ``requester``."""
    handler = _build_termination_handler(requester)
    signal(SIGINT, handler)
    signal(SIGTERM, handler)
def _validate_cli_interval_bound(string: str) -> float | None:
string = string.lower()
if string in ('', 'null', 'none'):
return None
return float(string)
def _validate_cli_interval_bounds(string: str) -> tuple[float | None, float | None]:
    """Parse ``--interval-bounds`` given as ``"<min>,<max>"``.

    An empty string, ``null`` or ``none`` (any case) for the whole value
    disables both bounds. Otherwise the value is split at the first comma, and
    each side is parsed by ``_validate_cli_interval_bound``.

    Raises:
        ValueError: If there is no comma, or a side fails to parse.
    """
    lowered = string.lower()
    if lowered in ('', 'null', 'none'):
        return (None, None)
    lower_part, upper_part = lowered.split(',', 1)
    lower = _validate_cli_interval_bound(lower_part)
    upper = _validate_cli_interval_bound(upper_part)
    return (lower, upper)
def _validate_cli_max_retries(string: str) -> int:
val = int(string)
if val < 0:
raise ValueError(f'Max retries should be >=0, given {val}')
return val
def _validate_cli_logging_level(string: str) -> int:
return {
'debug': logging.DEBUG,
'info': logging.INFO,
'warning': logging.WARNING,
'error': logging.ERROR,
'critical': logging.CRITICAL,
}[string]
def _build_arg_parser() -> ArgumentParser:
    """Create the command-line parser; semantic checks are done afterwards by ``parse_args``."""
    parser = ArgumentParser(
        description=(
            'Регулярная отправка POST-запросов на эндпоинт предсказания цены.'
            ' Остановка по SIGINT / SIGTERM.'
        ),
        allow_abbrev=False,
        exit_on_error=True,
    )
    parser.add_argument('base_url', type=str, nargs='?')
    parser.add_argument('--interval-mean', type=float, dest='interval_mean')
    parser.add_argument(
        '--interval-bounds', type=_validate_cli_interval_bounds, dest='interval_bounds',
    )
    parser.add_argument(
        '--max-retries',
        type=_validate_cli_max_retries,
        default=MAX_RETRIES_DEFAULT,
        dest='max_retries',
    )
    parser.add_argument('-q', '--quiet', action='store_true', dest='quiet')
    parser.add_argument(
        '--log-level',
        default=logging.WARNING,
        type=_validate_cli_logging_level,
        dest='logging_level',
    )
    return parser


def _resolve_interval_args(args) -> None:
    """Validate the interval options and fill in their defaults, modifying ``args`` in place.

    Defaults:
    - Without ``--interval-mean``, the mean is ``INTERVAL_MEAN_DEFAULT``.
    - Without ``--interval-bounds``, the bounds are ``(mean / 5, mean * 5)``
      when a mean was given, else ``INTERVAL_BOUNDS_DEFAULT``.
    - Explicitly given bounds are always kept.

    Raises:
        ValueError: For a non-positive or non-finite mean, a negative bound, or
            a lower bound above the upper bound.
    """
    mean = args.interval_mean
    # The chained comparison also rejects NaN. An infinite mean would make
    # expovariate(1 / mean) divide by zero.
    if (mean is not None) and not (0 < mean < float('inf')):
        raise ValueError(f'Interval mean should be a finite number > 0, given {mean}')

    bounds = args.interval_bounds
    if bounds is not None:
        lower, upper = bounds
        # A negative upper bound would produce a negative delay, which sleep() rejects.
        if any((b is not None) and (b < 0) for b in bounds):
            raise ValueError(f'Interval bounds should be >= 0, given {bounds!r}')
        if (lower is not None) and (upper is not None) and (lower > upper):
            raise ValueError(f'Interval bounds should be b_1 <= b_2, given {bounds!r}')

    if mean is None:
        args.interval_mean = INTERVAL_MEAN_DEFAULT
        if bounds is None:
            args.interval_bounds = INTERVAL_BOUNDS_DEFAULT
    elif bounds is None:
        args.interval_bounds = ((mean / 5), (mean * 5))


def parse_args(argv):
    """Parse and validate the command line.

    Args:
        argv: Full argument vector including the program name (as ``sys.argv``).

    Returns:
        The argparse namespace with ``base_url``, ``interval_mean``,
        ``interval_bounds``, ``max_retries``, ``quiet`` and ``logging_level`` set.

    Raises:
        RuntimeError: If no base URL is given, neither positionally nor via the
            ``API_BASE_URL`` environment variable (empty values count as missing).
        ValueError: If the interval options are invalid.
    """
    args = _build_arg_parser().parse_args(argv[1:])
    if not args.base_url:
        args.base_url = getenv('API_BASE_URL')
    if not args.base_url:
        raise RuntimeError('No API base URL specified')
    _resolve_interval_args(args)
    return args
def main(argv):
    """Parse ``argv``, configure logging and run the requester until it is stopped.

    Args:
        argv: Full argument vector including the program name (as ``sys.argv``).

    Returns:
        Exit status 0 after a graceful stop (SIGINT / SIGTERM). A request that
        fails after all retries propagates as an exception.
    """
    args = parse_args(argv)
    _configure_logging(args.logging_level, args.quiet)
    logging.debug('Creating a Requester with base URL: %s', args.base_url)
    requester = Requester(
        args.base_url,
        interval_mean=args.interval_mean,
        interval_bounds=args.interval_bounds,
        max_retries=args.max_retries,
    )
    _setup_signal_handlers(requester)
    try:
        requester.run()
    finally:
        # Release pooled HTTP connections whether run() returned or raised.
        requester.session.close()
    return 0
if __name__ == '__main__':
    # main() may return None; normalize to an integer exit status.
    exit_code = main(sys.argv)
    sys.exit(int(exit_code or 0))