v1.0.0
src/main.py (109 changed lines)
@@ -2,90 +2,97 @@ import sys
 import asyncio
 import argparse
 import json
+from os.path import exists
 
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, HTTPException
 from fastapi.responses import StreamingResponse
 from fastapi import status
 from hypercorn.config import Config
 from hypercorn.asyncio import serve
-import ipaddress
 
+from .utils import convert_to_bytes, generate_data, load_database, save_database
+
 # Setup Parser
 parser = argparse.ArgumentParser()
 parser.add_argument('-c', '--config', type=argparse.FileType('r'),
-                    default='./config', help='Path to config file in JSON format.')
+# parser.add_argument('-db', '--database', type=argparse.FileType('r'),  # TODO: read+write
+#                     default='./db.json', help='Path to database file in JSON format.')
+                    default='./config.json', help='Path to config file in JSON format.')
 
 args = parser.parse_args(sys.argv[1:])
 
 # Load Config
-config = json.load(args.config)
+CONFIG = json.load(args.config)
+BUFFER_SIZE = convert_to_bytes(CONFIG['buffer-size'])
+MAX_SIZE = convert_to_bytes(CONFIG['max-size'])
+MAX_DATA = convert_to_bytes(CONFIG['max-data'])
+AUTHORIZED_KEYS = CONFIG['keys']
+DATABASE = CONFIG['database']
+
+if not exists(DATABASE):
+    save_database(DATABASE, {'data-used': 0})
+
 
 api = FastAPI()
 
-BUFFER_SIZE = 1024 * 4  # 4KB
-HOST = '127.0.0.1'
-PORT = 9250
 
 class MaxSizePerRequestError(Exception):
     pass
 
 
-async def generate_test_data(size: int | str, max_size=1024 * 1024) -> bytes:
-    size_left = None
+class MinSizePerRequestError(Exception):
+    pass
 
-    try:
-        size_left = int(size)
-    except ValueError:  # treat as string
-        units = {
-            'GB': 10 ** 9, 'GiB': 2 ** 30,
-            'MB': 10 ** 6, 'MiB': 2 ** 20,
-            'KB': 10 ** 3, 'KiB': 2 ** 10,
-            'B': 1
-        }
-
-        for unit in units:
-            if size.endswith(unit):
-                size_left = int(float(size.removesuffix(unit)) * units[unit])
-                break
-
-    print('size_left', size_left)
-
-    # yield data
-    while size_left > BUFFER_SIZE:
-        size_left -= BUFFER_SIZE
-        yield b'\0' * BUFFER_SIZE
-    else:
-        yield b'\0' * size_left
-
-
-def check_policies(ip: str) -> None:
-    network = ipaddress.ip_network(ip)
-    print(network)
-
 @api.get('/')
-async def test_data(size: str, request: Request) -> StreamingResponse:
+def test_data(api_key: str, size: str, request: Request) -> StreamingResponse:
     try:
-        check_policies(request.client.host)
+        if api_key not in AUTHORIZED_KEYS:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail='Invalid API Key.'
+            )
+
+        size = convert_to_bytes(size)
+
+        if size < 0:
+            raise MinSizePerRequestError
+        elif MAX_SIZE < size:
+            raise MaxSizePerRequestError
+
+        database = load_database(DATABASE)
+        if MAX_DATA <= database['data-used'] + size:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail='Service not available.'
+            )
+        database['data-used'] += size
+
+        save_database(DATABASE, database)
+
         return StreamingResponse(
             status_code=status.HTTP_200_OK,
-            content=generate_test_data(size)
+            content=generate_data(size, BUFFER_SIZE)
         )
 
+    except MinSizePerRequestError:
+        raise HTTPException(
+            status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
+            detail=f'Size has to be not-negative.'
+        )
+    except MaxSizePerRequestError:
+        raise HTTPException(
+            status_code=status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
+            detail=f'Exceeded max size per request of {MAX_SIZE} Bytes.'
+        )
-    # except TooManyRequestsError as err:
-    #     ...
-    # except BlockedError as err:
-    #     ...
-    # except OutOfQuotaError as err:
-    #     ...
-    except err:
-        raise err
-        pass
 
 
 def main():
     asyncio.run(serve(
         api,
-        Config().from_object({'bind': [f'{HOST}:{PORT}']})
+        Config().from_mapping(
+            bind=[f"{CONFIG['host']}:{CONFIG['port']}"],
+            accesslog='-'
+        )
     ))
 
 
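For reference, the values read into CONFIG above come from the JSON file passed with -c/--config. Only the key names (host, port, buffer-size, max-size, max-data, keys, database) are taken from the code in this commit; the values below are an illustrative sketch, with host and port mirroring the removed HOST/PORT defaults:

{
    "host": "127.0.0.1",
    "port": 9250,
    "buffer-size": "4KiB",
    "max-size": "1GiB",
    "max-data": "10GiB",
    "keys": ["example-api-key"],
    "database": "./db.json"
}

With such a config in place, a request against the endpoint might look like GET /?api_key=example-api-key&size=10MiB, where api_key and size are the query parameters declared by test_data, and size accepts either a plain byte count or a unit suffix understood by convert_to_bytes.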
src/utils.py (new file, 43 lines added)
@@ -0,0 +1,43 @@
+import json
+
+def convert_to_bytes(size: int | str) -> int:
+    try:
+        return int(size)
+    except ValueError:  # treat as string
+        units = {
+            'TB': 10 ** 12, 'TiB': 2 ** 40,
+            'GB': 10 ** 9, 'GiB': 2 ** 30,
+            'MB': 10 ** 6, 'MiB': 2 ** 20,
+            'KB': 10 ** 3, 'KiB': 2 ** 10,
+            'B': 1
+        }
+
+        for unit in units:
+            if size.endswith(unit):
+                return int(float(size.removesuffix(unit)) * units[unit])
+                break
+
+
+async def generate_data(size: int, buffer_size: int = 4 * 1024) -> bytes:
+    size_left = size
+
+    while size_left > buffer_size:
+        size_left -= buffer_size
+        yield b'\0' * buffer_size
+    else:
+        yield b'\0' * size_left
+
+
+def check_policies(ip: str) -> None:
+    network = ipaddress.ip_network(ip)
+    print(network)
+
+
+def load_database(path: str) -> dict:
+    with open(path, 'r') as file:
+        return json.load(file)
+
+
+def save_database(path: str, database: dict) -> None:
+    with open(path, 'w') as file:
+        json.dump(database, file, indent=2)
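Note that check_policies as committed refers to ipaddress without importing it, so calling it would raise a NameError, and the break after the return inside convert_to_bytes is unreachable (the function also falls through and returns None for strings with no recognized suffix). A minimal sketch of the helper with the missing import added, assuming the intent is simply to parse and print the caller's address (not part of the commit):

import ipaddress


def check_policies(ip: str) -> None:
    # Parse the client address; a bare host such as '127.0.0.1' becomes a /32 network.
    network = ipaddress.ip_network(ip)
    print(network)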
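For illustration only (not part of the commit), the conversion and streaming helpers can be exercised on their own. This sketch assumes utils.py is importable from the working directory:

import asyncio

from utils import convert_to_bytes, generate_data


async def demo() -> None:
    size = convert_to_bytes('10KiB')   # 10 * 2**10 = 10240 bytes
    total = 0
    # generate_data is an async generator yielding zero-filled chunks of
    # buffer_size bytes; the while/else clause yields the remainder once
    # the loop condition fails.
    async for chunk in generate_data(size, buffer_size=4 * 1024):
        total += len(chunk)
    print(size, total)                 # 10240 10240


asyncio.run(demo())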