From 3f74df535520bf344ad4795c673cba95bd69cb08 Mon Sep 17 00:00:00 2001 From: Kristian Krsnik Date: Sat, 4 Jan 2025 17:12:40 +0100 Subject: [PATCH] reworked api definition --- src/testdata/testdata.py | 193 +++++++++++++++++++++++++++------------ 1 file changed, 136 insertions(+), 57 deletions(-) diff --git a/src/testdata/testdata.py b/src/testdata/testdata.py index 63c12b0..09c001d 100644 --- a/src/testdata/testdata.py +++ b/src/testdata/testdata.py @@ -1,10 +1,14 @@ import os import json import asyncio +import inspect +import functools +import random -from typing_extensions import Annotated import uvicorn -from fastapi import FastAPI, Request, status, HTTPException +from typing_extensions import Annotated +from fastapi import FastAPI, Request, Security, status, HTTPException +from fastapi.security import APIKeyHeader, APIKeyQuery from fastapi.responses import StreamingResponse from pydantic import BaseModel, ConfigDict, Field, BeforeValidator, ValidationError @@ -54,72 +58,147 @@ class Testdata: def __init__(self, config: Config): self._config = config - self._api = FastAPI(docs_url=None, redoc_url=None) self._logger = logger.getLogger('testdata') + self._api = self._setup_api() # Store internal state self._state = {'data-used': 0} - @self._api.get('/zeros') - async def zeros(api_key: str, size: int | str, request: Request) -> StreamingResponse: - try: - extra = {'api_key': api_key, 'ip': request.client.host if request.client is not None else None, 'size': size} - self._logger.debug('Initiated request.', extra=extra) + def _setup_api(self) -> FastAPI: + api = FastAPI(docs_url='/', redoc_url=None) - if api_key not in config.authorized_keys: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail='Invalid API Key.' - ) - try: - size = convert_to_bytes(size) - except ValueError as err: - self._logger.warning('Invalid format for size.', extra=extra) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail='Invalid format for size.' - ) from err + # Security + def get_api_key( + api_key_query: str = Security(APIKeyQuery(name="api_key", auto_error=False)), + api_key_header: str = Security(APIKeyHeader(name="x-api-key", auto_error=False)) + ) -> str: + # https://joshdimella.com/blog/adding-api-key-auth-to-fast-api - if size < 0: - raise MinSizePerRequestError - if config.max_size < size: - raise MaxSizePerRequestError + if api_key_query in self._config.authorized_keys: + return api_key_query + if api_key_header in self._config.authorized_keys: + return api_key_header - # update internal state - if config.max_data < self._state['data-used'] + size: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail='Service not available.' - ) - self._state['data-used'] += size + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='Invalid or missing API Key' + ) - self._logger.debug('Successfully processed request.', extra=extra) - return StreamingResponse( - status_code=status.HTTP_200_OK, - content=generate_data(size, config.buffer_size), - media_type='application/octet-stream', - headers={ - 'Content-Length': str(size) - } + # A wrapper to set the function signature to accept the api key dependency + def secure(func): + # Get old signature + positional_only, positional_or_keyword, variadic_positional, keyword_only, variadic_keyword = [], [], [], [], [] + for value in inspect.signature(func).parameters.values(): + if value.kind == inspect.Parameter.POSITIONAL_ONLY: + positional_only.append(value) + elif value.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: + positional_or_keyword.append(value) + elif value.kind == inspect.Parameter.VAR_POSITIONAL: + variadic_positional.append(value) + elif value.kind == inspect.Parameter.KEYWORD_ONLY: + keyword_only.append(value) + elif value.kind == inspect.Parameter.VAR_KEYWORD: + variadic_keyword.append(value) + + # Avoid passing an unrecognized keyword + if inspect.iscoroutinefunction(func): + async def wrapper(*args, **kwargs): + if len(variadic_keyword) == 0: + if 'api_key' in kwargs: + del kwargs['api_key'] + + return await func(*args, **kwargs) + else: + def wrapper(*args, **kwargs): + if len(variadic_keyword) == 0: + if 'api_key' in kwargs: + del kwargs['api_key'] + + return func(*args, **kwargs) + + # Override signature + wrapper.__signature__ = inspect.signature(func).replace( + parameters=( + *positional_only, + *positional_or_keyword, + *variadic_positional, + *keyword_only, + inspect.Parameter('api_key', inspect.Parameter.POSITIONAL_OR_KEYWORD, default=Security(get_api_key)), + *variadic_keyword ) + ) - except MinSizePerRequestError as err: - self._logger.warning('Size if negative.', extra=extra) - raise HTTPException( - status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE, - detail='Size has to be non-negative.' - ) from err - except MaxSizePerRequestError as err: - self._logger.warning('Exceeded max size per request.', extra=extra) - raise HTTPException( - status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE, - detail=f'Exceeded max size per request of {config.max_size} Bytes.' - ) from err - except Exception as err: - self._logger.exception(err) - raise err + return functools.wraps(func)(wrapper) + + # Routes + api.get('/zeros')(secure(self._zeros)) + + return api + + async def _zeros(self, size: int | str, request: Request, filename: str = 'zeros.bin') -> StreamingResponse: + try: + extra = {'id': f'{random.randint(0, 2 ** 32 - 1):08X}'} + self._logger.debug( + 'Initiated request.', + extra=extra | { + 'ip': request.client.host if request.client is not None else None, + 'query-params': dict(request.query_params), + 'headers': dict(request.headers) + } + ) + + try: + size = convert_to_bytes(size) + except ValueError as err: + self._logger.warning('Invalid format for size.', extra=extra) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='Invalid format for size.' + ) from err + + if size < 0: + raise MinSizePerRequestError + if self._config.max_size < size: + raise MaxSizePerRequestError + + # update internal state + if self._config.max_data < self._state['data-used'] + size: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='Service not available.' + ) + self._state['data-used'] += size + + self._logger.debug('Successfully processed request.', extra=extra) + return StreamingResponse( + status_code=status.HTTP_200_OK, + content=generate_data(size, self._config.buffer_size), + media_type='application/octet-stream', + headers={ + 'Content-Length': str(size), + 'Content-Disposition': f'attachment; filename="{filename}"' + } + ) + + except MinSizePerRequestError as err: + self._logger.warning('Size if negative.', extra=extra) + raise HTTPException( + status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE, + detail='Size has to be non-negative.' + ) from err + except MaxSizePerRequestError as err: + self._logger.warning('Exceeded max size per request.', extra=extra) + raise HTTPException( + status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE, + detail=f'Exceeded max size per request of {self._config.max_size} Bytes.' + ) from err + except Exception as err: + self._logger.exception(err) + raise err + + async def _update_state(self) -> None: + assert self._config.database is not None - async def _update_state(self): mode = 'r+' if os.path.exists(self._config.database) else 'w+' with open(self._config.database, mode, encoding='utf-8') as file: @@ -142,7 +221,7 @@ class Testdata: self._logger = logger.getLogger('testdata') self._logger.info('Server started.') - coroutines = [asyncio.create_task(uvicorn.Server(uvicorn.Config(self._api, host, port)).serve())] + coroutines = [uvicorn.Server(uvicorn.Config(self._api, host, port)).serve()] if self._config.database is not None: coroutines.append(self._update_state())