reworked api definition

This commit is contained in:
Kristian Krsnik 2025-01-04 17:12:40 +01:00
parent 5050996547
commit 3f74df5355
Signed by: Kristian
GPG Key ID: FD1330AC9F909E85

View File

@ -1,10 +1,14 @@
import os import os
import json import json
import asyncio import asyncio
import inspect
import functools
import random
from typing_extensions import Annotated
import uvicorn import uvicorn
from fastapi import FastAPI, Request, status, HTTPException from typing_extensions import Annotated
from fastapi import FastAPI, Request, Security, status, HTTPException
from fastapi.security import APIKeyHeader, APIKeyQuery
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict, Field, BeforeValidator, ValidationError from pydantic import BaseModel, ConfigDict, Field, BeforeValidator, ValidationError
@ -54,23 +58,95 @@ class Testdata:
def __init__(self, config: Config): def __init__(self, config: Config):
self._config = config self._config = config
self._api = FastAPI(docs_url=None, redoc_url=None)
self._logger = logger.getLogger('testdata') self._logger = logger.getLogger('testdata')
self._api = self._setup_api()
# Store internal state # Store internal state
self._state = {'data-used': 0} self._state = {'data-used': 0}
@self._api.get('/zeros') def _setup_api(self) -> FastAPI:
async def zeros(api_key: str, size: int | str, request: Request) -> StreamingResponse: api = FastAPI(docs_url='/', redoc_url=None)
try:
extra = {'api_key': api_key, 'ip': request.client.host if request.client is not None else None, 'size': size} # Security
self._logger.debug('Initiated request.', extra=extra) def get_api_key(
api_key_query: str = Security(APIKeyQuery(name="api_key", auto_error=False)),
api_key_header: str = Security(APIKeyHeader(name="x-api-key", auto_error=False))
) -> str:
# https://joshdimella.com/blog/adding-api-key-auth-to-fast-api
if api_key_query in self._config.authorized_keys:
return api_key_query
if api_key_header in self._config.authorized_keys:
return api_key_header
if api_key not in config.authorized_keys:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail='Invalid API Key.' detail='Invalid or missing API Key'
) )
# A wrapper to set the function signature to accept the api key dependency
def secure(func):
# Get old signature
positional_only, positional_or_keyword, variadic_positional, keyword_only, variadic_keyword = [], [], [], [], []
for value in inspect.signature(func).parameters.values():
if value.kind == inspect.Parameter.POSITIONAL_ONLY:
positional_only.append(value)
elif value.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
positional_or_keyword.append(value)
elif value.kind == inspect.Parameter.VAR_POSITIONAL:
variadic_positional.append(value)
elif value.kind == inspect.Parameter.KEYWORD_ONLY:
keyword_only.append(value)
elif value.kind == inspect.Parameter.VAR_KEYWORD:
variadic_keyword.append(value)
# Avoid passing an unrecognized keyword
if inspect.iscoroutinefunction(func):
async def wrapper(*args, **kwargs):
if len(variadic_keyword) == 0:
if 'api_key' in kwargs:
del kwargs['api_key']
return await func(*args, **kwargs)
else:
def wrapper(*args, **kwargs):
if len(variadic_keyword) == 0:
if 'api_key' in kwargs:
del kwargs['api_key']
return func(*args, **kwargs)
# Override signature
wrapper.__signature__ = inspect.signature(func).replace(
parameters=(
*positional_only,
*positional_or_keyword,
*variadic_positional,
*keyword_only,
inspect.Parameter('api_key', inspect.Parameter.POSITIONAL_OR_KEYWORD, default=Security(get_api_key)),
*variadic_keyword
)
)
return functools.wraps(func)(wrapper)
# Routes
api.get('/zeros')(secure(self._zeros))
return api
async def _zeros(self, size: int | str, request: Request, filename: str = 'zeros.bin') -> StreamingResponse:
try:
extra = {'id': f'{random.randint(0, 2 ** 32 - 1):08X}'}
self._logger.debug(
'Initiated request.',
extra=extra | {
'ip': request.client.host if request.client is not None else None,
'query-params': dict(request.query_params),
'headers': dict(request.headers)
}
)
try: try:
size = convert_to_bytes(size) size = convert_to_bytes(size)
except ValueError as err: except ValueError as err:
@ -82,11 +158,11 @@ class Testdata:
if size < 0: if size < 0:
raise MinSizePerRequestError raise MinSizePerRequestError
if config.max_size < size: if self._config.max_size < size:
raise MaxSizePerRequestError raise MaxSizePerRequestError
# update internal state # update internal state
if config.max_data < self._state['data-used'] + size: if self._config.max_data < self._state['data-used'] + size:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Service not available.' detail='Service not available.'
@ -96,10 +172,11 @@ class Testdata:
self._logger.debug('Successfully processed request.', extra=extra) self._logger.debug('Successfully processed request.', extra=extra)
return StreamingResponse( return StreamingResponse(
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
content=generate_data(size, config.buffer_size), content=generate_data(size, self._config.buffer_size),
media_type='application/octet-stream', media_type='application/octet-stream',
headers={ headers={
'Content-Length': str(size) 'Content-Length': str(size),
'Content-Disposition': f'attachment; filename="{filename}"'
} }
) )
@ -113,13 +190,15 @@ class Testdata:
self._logger.warning('Exceeded max size per request.', extra=extra) self._logger.warning('Exceeded max size per request.', extra=extra)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE, status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
detail=f'Exceeded max size per request of {config.max_size} Bytes.' detail=f'Exceeded max size per request of {self._config.max_size} Bytes.'
) from err ) from err
except Exception as err: except Exception as err:
self._logger.exception(err) self._logger.exception(err)
raise err raise err
async def _update_state(self): async def _update_state(self) -> None:
assert self._config.database is not None
mode = 'r+' if os.path.exists(self._config.database) else 'w+' mode = 'r+' if os.path.exists(self._config.database) else 'w+'
with open(self._config.database, mode, encoding='utf-8') as file: with open(self._config.database, mode, encoding='utf-8') as file:
@ -142,7 +221,7 @@ class Testdata:
self._logger = logger.getLogger('testdata') self._logger = logger.getLogger('testdata')
self._logger.info('Server started.') self._logger.info('Server started.')
coroutines = [asyncio.create_task(uvicorn.Server(uvicorn.Config(self._api, host, port)).serve())] coroutines = [uvicorn.Server(uvicorn.Config(self._api, host, port)).serve()]
if self._config.database is not None: if self._config.database is not None:
coroutines.append(self._update_state()) coroutines.append(self._update_state())