Use custom get_single arg everywhere, remove sanicargs (fixes #193)

This commit is contained in:
Paul Bienkowski 2022-02-18 12:03:45 +01:00
parent 8bb5d71186
commit 7fc9558e42
9 changed files with 44 additions and 26 deletions

View file

@ -22,6 +22,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import sessionmaker
from obs.api.db import User, make_session, connect_db
from obs.api.utils import get_single_arg
log = logging.getLogger(__name__)
@ -128,6 +129,11 @@ def remove_right(l, r):
return l
@app.middleware("request")
async def inject_arg_getter(req):
req.ctx.get_single_arg = partial(get_single_arg, req)
@app.middleware("request")
async def inject_urls(req):
if req.app.config.FRONTEND_HTTPS:

View file

@ -12,7 +12,6 @@ from sanic.response import raw
from sanic.exceptions import InvalidUsage
from obs.api.app import app, json as json_response
from obs.api.utils import get_single_arg
class ExportFormat(str, Enum):
@ -61,10 +60,10 @@ def shapefile_zip():
@app.get(r"/export/events")
async def export_events(req):
bbox = get_single_arg(
req, "bbox", default="-180,-90,180,90", convert=parse_bounding_box
bbox = req.ctx.get_single_arg(
"bbox", default="-180,-90,180,90", convert=parse_bounding_box
)
fmt = get_single_arg(req, "fmt", convert=ExportFormat)
fmt = req.ctx.get_single_arg("fmt", convert=ExportFormat)
events = await req.ctx.db.stream_scalars(
select(OvertakingEvent).where(OvertakingEvent.geometry.bool_op("&&")(bbox))

View file

@ -14,14 +14,14 @@ from obs.api.app import auth
from obs.api.db import User
from sanic.response import json, redirect
from sanicargs import parse_parameters
log = logging.getLogger(__name__)
client = Client(client_authn_method=CLIENT_AUTHN_METHOD)
# Do not show verbose library output, even when the appliaction is in debug mode
logging.getLogger('oic').setLevel(logging.INFO)
logging.getLogger("oic").setLevel(logging.INFO)
@auth.before_server_start
async def connect_auth_client(app, loop):
@ -43,12 +43,13 @@ async def connect_auth_client(app, loop):
@auth.route("/login")
@parse_parameters
async def login(req, next: str = None):
async def login(req):
next_url = req.ctx.get_single_arg("next", default=None)
session = req.ctx.session
session["state"] = rndstr()
session["nonce"] = rndstr()
session["next"] = next
session["next"] = next_url
args = {
"client_id": client.client_id,
"response_type": "code",

View file

@ -10,7 +10,7 @@ from sanic.exceptions import InvalidUsage
from obs.api.app import api
from obs.api.db import Road, OvertakingEvent, Track
from obs.api.utils import round_to, get_single_arg
from obs.api.utils import round_to
round_distance = partial(round_to, multiples=0.001)
round_speed = partial(round_to, multiples=0.1)
@ -28,9 +28,9 @@ def get_bearing(a, b):
@api.route("/mapdetails/road", methods=["GET"])
async def mapdetails_road(req):
longitude = get_single_arg(req, "longitude", convert=float)
latitude = get_single_arg(req, "latitude", convert=float)
radius = get_single_arg(req, "radius", default=100, convert=float)
longitude = req.ctx.get_single_arg("longitude", convert=float)
latitude = req.ctx.get_single_arg("latitude", convert=float)
radius = req.ctx.get_single_arg("radius", default=100, convert=float)
if not (1 <= radius <= 1000):
raise InvalidUsage("`radius` parameter must be between 1 and 1000")

View file

@ -7,7 +7,6 @@ from functools import reduce
from sqlalchemy import select, func
from sanic.response import json
from sanicargs import parse_parameters
from obs.api.app import api
from obs.api.db import Track, OvertakingEvent, User
@ -28,8 +27,11 @@ MINUMUM_RECORDING_DATE = datetime(2010, 1, 1)
@api.route("/stats")
@parse_parameters
async def stats(req, user: str = None, start: datetime = None, end: datetime = None):
async def stats(req):
user = req.ctx.get_single_arg("user", default=None)
start = req.ctx.get_single_arg("start", default=None, convert=datetime)
end = req.ctx.get_single_arg("end", default=None, convert=datetime)
conditions = [
Track.recorded_at != None,
Track.recorded_at > MINUMUM_RECORDING_DATE,

View file

@ -11,7 +11,6 @@ from obs.api.app import api, require_auth, read_api_key, json
from sanic.response import file_stream, empty
from sanic.exceptions import InvalidUsage, NotFound, Forbidden
from sanicargs import parse_parameters
log = logging.getLogger(__name__)
@ -61,8 +60,11 @@ async def _return_tracks(req, extend_query, limit, offset):
@api.get("/tracks")
@parse_parameters
async def get_tracks(req, limit: int = 20, offset: int = 0, author: str = None):
async def get_tracks(req):
limit = req.ctx.get_single_arg("limit", default=20, convert=int)
offset = req.ctx.get_single_arg("offset", default=0, convert=int)
author = req.ctx.get_single_arg("author", default=None, convert=str)
def extend_query(q):
q = q.where(Track.public)
@ -76,8 +78,10 @@ async def get_tracks(req, limit: int = 20, offset: int = 0, author: str = None):
@api.get("/tracks/feed")
@require_auth
@parse_parameters
async def get_feed(req, limit: int = 20, offset: int = 0):
async def get_feed(req):
limit = req.ctx.get_single_arg("limit", default=20, convert=int)
offset = req.ctx.get_single_arg("offset", default=0, convert=int)
def extend_query(q):
return q.where(Track.author_id == req.ctx.user.id)
@ -260,8 +264,11 @@ async def put_track(req, slug: str):
@api.get("/tracks/<slug:str>/comments")
@parse_parameters
async def get_track_comments(req, slug: str, limit: int = 20, offset: int = 0):
async def get_track_comments(req):
slug = req.ctx.get_single_arg("slug")
limit = req.ctx.get_single_arg("limit", default=20, convert=int)
offset = req.ctx.get_single_arg("offset", default=0, convert=int)
track = await _load_track(req, slug)
comment_count = await req.ctx.db.scalar(

View file

@ -1,3 +1,5 @@
from datetime import datetime
import dateutil.parser
from sanic.exceptions import InvalidUsage
RAISE = object()
@ -12,7 +14,10 @@ def get_single_arg(req, name, default=RAISE, convert=None):
value = default
if convert is not None:
if convert is not None and value is not None:
if convert is datetime or convert in ("date", "datetime"):
convert = lambda s: dateutil.parser.parse(s)
try:
value = convert(value)
except (ValueError, TypeError) as e:

View file

@ -2,7 +2,6 @@ coloredlogs~=15.0.1
sanic~=21.9.3
oic~=1.3.0
sanic-session~=0.8.0
sanicargs~=2.1.0
sanic-cors~=2.0.1
python-slugify~=5.0.2
motor~=2.5.1

View file

@ -14,7 +14,6 @@ setup(
"sanic~=21.9.3",
"oic>=1.3.0, <2",
"sanic-session~=0.8.0",
"sanicargs~=2.1.0",
"sanic-cors~=2.0.1",
"python-slugify~=5.0.2",
"motor~=2.5.1",