From 7fc9558e42fe2fab50e9b0aa33ca9cacbaf61c73 Mon Sep 17 00:00:00 2001 From: Paul Bienkowski Date: Fri, 18 Feb 2022 12:03:45 +0100 Subject: [PATCH] Use custom get_single arg everywhere, remove sanicargs (fixes #193) --- api/obs/api/app.py | 6 ++++++ api/obs/api/routes/exports.py | 7 +++---- api/obs/api/routes/login.py | 11 ++++++----- api/obs/api/routes/mapdetails.py | 8 ++++---- api/obs/api/routes/stats.py | 8 +++++--- api/obs/api/routes/tracks.py | 21 ++++++++++++++------- api/obs/api/utils.py | 7 ++++++- api/requirements.txt | 1 - api/setup.py | 1 - 9 files changed, 44 insertions(+), 26 deletions(-) diff --git a/api/obs/api/app.py b/api/obs/api/app.py index 05f0a39..68f35ac 100644 --- a/api/obs/api/app.py +++ b/api/obs/api/app.py @@ -22,6 +22,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import sessionmaker from obs.api.db import User, make_session, connect_db +from obs.api.utils import get_single_arg log = logging.getLogger(__name__) @@ -128,6 +129,11 @@ def remove_right(l, r): return l +@app.middleware("request") +async def inject_arg_getter(req): + req.ctx.get_single_arg = partial(get_single_arg, req) + + @app.middleware("request") async def inject_urls(req): if req.app.config.FRONTEND_HTTPS: diff --git a/api/obs/api/routes/exports.py b/api/obs/api/routes/exports.py index c71670b..974b05d 100644 --- a/api/obs/api/routes/exports.py +++ b/api/obs/api/routes/exports.py @@ -12,7 +12,6 @@ from sanic.response import raw from sanic.exceptions import InvalidUsage from obs.api.app import app, json as json_response -from obs.api.utils import get_single_arg class ExportFormat(str, Enum): @@ -61,10 +60,10 @@ def shapefile_zip(): @app.get(r"/export/events") async def export_events(req): - bbox = get_single_arg( - req, "bbox", default="-180,-90,180,90", convert=parse_bounding_box + bbox = req.ctx.get_single_arg( + "bbox", default="-180,-90,180,90", convert=parse_bounding_box ) - fmt = get_single_arg(req, "fmt", convert=ExportFormat) + fmt = req.ctx.get_single_arg("fmt", convert=ExportFormat) events = await req.ctx.db.stream_scalars( select(OvertakingEvent).where(OvertakingEvent.geometry.bool_op("&&")(bbox)) diff --git a/api/obs/api/routes/login.py b/api/obs/api/routes/login.py index bd6cb32..2c95dcc 100644 --- a/api/obs/api/routes/login.py +++ b/api/obs/api/routes/login.py @@ -14,14 +14,14 @@ from obs.api.app import auth from obs.api.db import User from sanic.response import json, redirect -from sanicargs import parse_parameters log = logging.getLogger(__name__) client = Client(client_authn_method=CLIENT_AUTHN_METHOD) # Do not show verbose library output, even when the appliaction is in debug mode -logging.getLogger('oic').setLevel(logging.INFO) +logging.getLogger("oic").setLevel(logging.INFO) + @auth.before_server_start async def connect_auth_client(app, loop): @@ -43,12 +43,13 @@ async def connect_auth_client(app, loop): @auth.route("/login") -@parse_parameters -async def login(req, next: str = None): +async def login(req): + next_url = req.ctx.get_single_arg("next", default=None) + session = req.ctx.session session["state"] = rndstr() session["nonce"] = rndstr() - session["next"] = next + session["next"] = next_url args = { "client_id": client.client_id, "response_type": "code", diff --git a/api/obs/api/routes/mapdetails.py b/api/obs/api/routes/mapdetails.py index 9d08dd8..45be46b 100644 --- a/api/obs/api/routes/mapdetails.py +++ b/api/obs/api/routes/mapdetails.py @@ -10,7 +10,7 @@ from sanic.exceptions import InvalidUsage from obs.api.app import api from obs.api.db import Road, OvertakingEvent, Track -from obs.api.utils import round_to, get_single_arg +from obs.api.utils import round_to round_distance = partial(round_to, multiples=0.001) round_speed = partial(round_to, multiples=0.1) @@ -28,9 +28,9 @@ def get_bearing(a, b): @api.route("/mapdetails/road", methods=["GET"]) async def mapdetails_road(req): - longitude = get_single_arg(req, "longitude", convert=float) - latitude = get_single_arg(req, "latitude", convert=float) - radius = get_single_arg(req, "radius", default=100, convert=float) + longitude = req.ctx.get_single_arg("longitude", convert=float) + latitude = req.ctx.get_single_arg("latitude", convert=float) + radius = req.ctx.get_single_arg("radius", default=100, convert=float) if not (1 <= radius <= 1000): raise InvalidUsage("`radius` parameter must be between 1 and 1000") diff --git a/api/obs/api/routes/stats.py b/api/obs/api/routes/stats.py index 1e7b248..8f5603c 100644 --- a/api/obs/api/routes/stats.py +++ b/api/obs/api/routes/stats.py @@ -7,7 +7,6 @@ from functools import reduce from sqlalchemy import select, func from sanic.response import json -from sanicargs import parse_parameters from obs.api.app import api from obs.api.db import Track, OvertakingEvent, User @@ -28,8 +27,11 @@ MINUMUM_RECORDING_DATE = datetime(2010, 1, 1) @api.route("/stats") -@parse_parameters -async def stats(req, user: str = None, start: datetime = None, end: datetime = None): +async def stats(req): + user = req.ctx.get_single_arg("user", default=None) + start = req.ctx.get_single_arg("start", default=None, convert=datetime) + end = req.ctx.get_single_arg("end", default=None, convert=datetime) + conditions = [ Track.recorded_at != None, Track.recorded_at > MINUMUM_RECORDING_DATE, diff --git a/api/obs/api/routes/tracks.py b/api/obs/api/routes/tracks.py index 8697724..3de5bf8 100644 --- a/api/obs/api/routes/tracks.py +++ b/api/obs/api/routes/tracks.py @@ -11,7 +11,6 @@ from obs.api.app import api, require_auth, read_api_key, json from sanic.response import file_stream, empty from sanic.exceptions import InvalidUsage, NotFound, Forbidden -from sanicargs import parse_parameters log = logging.getLogger(__name__) @@ -61,8 +60,11 @@ async def _return_tracks(req, extend_query, limit, offset): @api.get("/tracks") -@parse_parameters -async def get_tracks(req, limit: int = 20, offset: int = 0, author: str = None): +async def get_tracks(req): + limit = req.ctx.get_single_arg("limit", default=20, convert=int) + offset = req.ctx.get_single_arg("offset", default=0, convert=int) + author = req.ctx.get_single_arg("author", default=None, convert=str) + def extend_query(q): q = q.where(Track.public) @@ -76,8 +78,10 @@ async def get_tracks(req, limit: int = 20, offset: int = 0, author: str = None): @api.get("/tracks/feed") @require_auth -@parse_parameters -async def get_feed(req, limit: int = 20, offset: int = 0): +async def get_feed(req): + limit = req.ctx.get_single_arg("limit", default=20, convert=int) + offset = req.ctx.get_single_arg("offset", default=0, convert=int) + def extend_query(q): return q.where(Track.author_id == req.ctx.user.id) @@ -260,8 +264,11 @@ async def put_track(req, slug: str): @api.get("/tracks//comments") -@parse_parameters -async def get_track_comments(req, slug: str, limit: int = 20, offset: int = 0): +async def get_track_comments(req): + slug = req.ctx.get_single_arg("slug") + limit = req.ctx.get_single_arg("limit", default=20, convert=int) + offset = req.ctx.get_single_arg("offset", default=0, convert=int) + track = await _load_track(req, slug) comment_count = await req.ctx.db.scalar( diff --git a/api/obs/api/utils.py b/api/obs/api/utils.py index 8175e9b..b9a50e3 100644 --- a/api/obs/api/utils.py +++ b/api/obs/api/utils.py @@ -1,3 +1,5 @@ +from datetime import datetime +import dateutil.parser from sanic.exceptions import InvalidUsage RAISE = object() @@ -12,7 +14,10 @@ def get_single_arg(req, name, default=RAISE, convert=None): value = default - if convert is not None: + if convert is not None and value is not None: + if convert is datetime or convert in ("date", "datetime"): + convert = lambda s: dateutil.parser.parse(s) + try: value = convert(value) except (ValueError, TypeError) as e: diff --git a/api/requirements.txt b/api/requirements.txt index 1190ac1..4100778 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -2,7 +2,6 @@ coloredlogs~=15.0.1 sanic~=21.9.3 oic~=1.3.0 sanic-session~=0.8.0 -sanicargs~=2.1.0 sanic-cors~=2.0.1 python-slugify~=5.0.2 motor~=2.5.1 diff --git a/api/setup.py b/api/setup.py index acd58ce..640d051 100644 --- a/api/setup.py +++ b/api/setup.py @@ -14,7 +14,6 @@ setup( "sanic~=21.9.3", "oic>=1.3.0, <2", "sanic-session~=0.8.0", - "sanicargs~=2.1.0", "sanic-cors~=2.0.1", "python-slugify~=5.0.2", "motor~=2.5.1",