reworked api definition
This commit is contained in:
parent
5050996547
commit
3f74df5355
193
src/testdata/testdata.py
vendored
193
src/testdata/testdata.py
vendored
@ -1,10 +1,14 @@
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import inspect
|
||||
import functools
|
||||
import random
|
||||
|
||||
from typing_extensions import Annotated
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request, status, HTTPException
|
||||
from typing_extensions import Annotated
|
||||
from fastapi import FastAPI, Request, Security, status, HTTPException
|
||||
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, ConfigDict, Field, BeforeValidator, ValidationError
|
||||
|
||||
@ -54,72 +58,147 @@ class Testdata:
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self._config = config
|
||||
self._api = FastAPI(docs_url=None, redoc_url=None)
|
||||
self._logger = logger.getLogger('testdata')
|
||||
self._api = self._setup_api()
|
||||
|
||||
# Store internal state
|
||||
self._state = {'data-used': 0}
|
||||
|
||||
@self._api.get('/zeros')
|
||||
async def zeros(api_key: str, size: int | str, request: Request) -> StreamingResponse:
|
||||
try:
|
||||
extra = {'api_key': api_key, 'ip': request.client.host if request.client is not None else None, 'size': size}
|
||||
self._logger.debug('Initiated request.', extra=extra)
|
||||
def _setup_api(self) -> FastAPI:
|
||||
api = FastAPI(docs_url='/', redoc_url=None)
|
||||
|
||||
if api_key not in config.authorized_keys:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='Invalid API Key.'
|
||||
)
|
||||
try:
|
||||
size = convert_to_bytes(size)
|
||||
except ValueError as err:
|
||||
self._logger.warning('Invalid format for size.', extra=extra)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid format for size.'
|
||||
) from err
|
||||
# Security
|
||||
def get_api_key(
|
||||
api_key_query: str = Security(APIKeyQuery(name="api_key", auto_error=False)),
|
||||
api_key_header: str = Security(APIKeyHeader(name="x-api-key", auto_error=False))
|
||||
) -> str:
|
||||
# https://joshdimella.com/blog/adding-api-key-auth-to-fast-api
|
||||
|
||||
if size < 0:
|
||||
raise MinSizePerRequestError
|
||||
if config.max_size < size:
|
||||
raise MaxSizePerRequestError
|
||||
if api_key_query in self._config.authorized_keys:
|
||||
return api_key_query
|
||||
if api_key_header in self._config.authorized_keys:
|
||||
return api_key_header
|
||||
|
||||
# update internal state
|
||||
if config.max_data < self._state['data-used'] + size:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Service not available.'
|
||||
)
|
||||
self._state['data-used'] += size
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='Invalid or missing API Key'
|
||||
)
|
||||
|
||||
self._logger.debug('Successfully processed request.', extra=extra)
|
||||
return StreamingResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content=generate_data(size, config.buffer_size),
|
||||
media_type='application/octet-stream',
|
||||
headers={
|
||||
'Content-Length': str(size)
|
||||
}
|
||||
# A wrapper to set the function signature to accept the api key dependency
|
||||
def secure(func):
|
||||
# Get old signature
|
||||
positional_only, positional_or_keyword, variadic_positional, keyword_only, variadic_keyword = [], [], [], [], []
|
||||
for value in inspect.signature(func).parameters.values():
|
||||
if value.kind == inspect.Parameter.POSITIONAL_ONLY:
|
||||
positional_only.append(value)
|
||||
elif value.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
|
||||
positional_or_keyword.append(value)
|
||||
elif value.kind == inspect.Parameter.VAR_POSITIONAL:
|
||||
variadic_positional.append(value)
|
||||
elif value.kind == inspect.Parameter.KEYWORD_ONLY:
|
||||
keyword_only.append(value)
|
||||
elif value.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
variadic_keyword.append(value)
|
||||
|
||||
# Avoid passing an unrecognized keyword
|
||||
if inspect.iscoroutinefunction(func):
|
||||
async def wrapper(*args, **kwargs):
|
||||
if len(variadic_keyword) == 0:
|
||||
if 'api_key' in kwargs:
|
||||
del kwargs['api_key']
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
else:
|
||||
def wrapper(*args, **kwargs):
|
||||
if len(variadic_keyword) == 0:
|
||||
if 'api_key' in kwargs:
|
||||
del kwargs['api_key']
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Override signature
|
||||
wrapper.__signature__ = inspect.signature(func).replace(
|
||||
parameters=(
|
||||
*positional_only,
|
||||
*positional_or_keyword,
|
||||
*variadic_positional,
|
||||
*keyword_only,
|
||||
inspect.Parameter('api_key', inspect.Parameter.POSITIONAL_OR_KEYWORD, default=Security(get_api_key)),
|
||||
*variadic_keyword
|
||||
)
|
||||
)
|
||||
|
||||
except MinSizePerRequestError as err:
|
||||
self._logger.warning('Size if negative.', extra=extra)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
|
||||
detail='Size has to be non-negative.'
|
||||
) from err
|
||||
except MaxSizePerRequestError as err:
|
||||
self._logger.warning('Exceeded max size per request.', extra=extra)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
|
||||
detail=f'Exceeded max size per request of {config.max_size} Bytes.'
|
||||
) from err
|
||||
except Exception as err:
|
||||
self._logger.exception(err)
|
||||
raise err
|
||||
return functools.wraps(func)(wrapper)
|
||||
|
||||
# Routes
|
||||
api.get('/zeros')(secure(self._zeros))
|
||||
|
||||
return api
|
||||
|
||||
async def _zeros(self, size: int | str, request: Request, filename: str = 'zeros.bin') -> StreamingResponse:
|
||||
try:
|
||||
extra = {'id': f'{random.randint(0, 2 ** 32 - 1):08X}'}
|
||||
self._logger.debug(
|
||||
'Initiated request.',
|
||||
extra=extra | {
|
||||
'ip': request.client.host if request.client is not None else None,
|
||||
'query-params': dict(request.query_params),
|
||||
'headers': dict(request.headers)
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
size = convert_to_bytes(size)
|
||||
except ValueError as err:
|
||||
self._logger.warning('Invalid format for size.', extra=extra)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid format for size.'
|
||||
) from err
|
||||
|
||||
if size < 0:
|
||||
raise MinSizePerRequestError
|
||||
if self._config.max_size < size:
|
||||
raise MaxSizePerRequestError
|
||||
|
||||
# update internal state
|
||||
if self._config.max_data < self._state['data-used'] + size:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Service not available.'
|
||||
)
|
||||
self._state['data-used'] += size
|
||||
|
||||
self._logger.debug('Successfully processed request.', extra=extra)
|
||||
return StreamingResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content=generate_data(size, self._config.buffer_size),
|
||||
media_type='application/octet-stream',
|
||||
headers={
|
||||
'Content-Length': str(size),
|
||||
'Content-Disposition': f'attachment; filename="{filename}"'
|
||||
}
|
||||
)
|
||||
|
||||
except MinSizePerRequestError as err:
|
||||
self._logger.warning('Size if negative.', extra=extra)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
|
||||
detail='Size has to be non-negative.'
|
||||
) from err
|
||||
except MaxSizePerRequestError as err:
|
||||
self._logger.warning('Exceeded max size per request.', extra=extra)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
|
||||
detail=f'Exceeded max size per request of {self._config.max_size} Bytes.'
|
||||
) from err
|
||||
except Exception as err:
|
||||
self._logger.exception(err)
|
||||
raise err
|
||||
|
||||
async def _update_state(self) -> None:
|
||||
assert self._config.database is not None
|
||||
|
||||
async def _update_state(self):
|
||||
mode = 'r+' if os.path.exists(self._config.database) else 'w+'
|
||||
|
||||
with open(self._config.database, mode, encoding='utf-8') as file:
|
||||
@ -142,7 +221,7 @@ class Testdata:
|
||||
self._logger = logger.getLogger('testdata')
|
||||
self._logger.info('Server started.')
|
||||
|
||||
coroutines = [asyncio.create_task(uvicorn.Server(uvicorn.Config(self._api, host, port)).serve())]
|
||||
coroutines = [uvicorn.Server(uvicorn.Config(self._api, host, port)).serve()]
|
||||
if self._config.database is not None:
|
||||
coroutines.append(self._update_state())
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user