Use custom get_single arg everywhere, remove sanicargs (fixes #193)

This commit is contained in:
Paul Bienkowski 2022-02-18 12:03:45 +01:00
parent 8bb5d71186
commit 7fc9558e42
9 changed files with 44 additions and 26 deletions

View file

@ -22,6 +22,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from obs.api.db import User, make_session, connect_db from obs.api.db import User, make_session, connect_db
from obs.api.utils import get_single_arg
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -128,6 +129,11 @@ def remove_right(l, r):
return l return l
@app.middleware("request")
async def inject_arg_getter(req):
req.ctx.get_single_arg = partial(get_single_arg, req)
@app.middleware("request") @app.middleware("request")
async def inject_urls(req): async def inject_urls(req):
if req.app.config.FRONTEND_HTTPS: if req.app.config.FRONTEND_HTTPS:

View file

@ -12,7 +12,6 @@ from sanic.response import raw
from sanic.exceptions import InvalidUsage from sanic.exceptions import InvalidUsage
from obs.api.app import app, json as json_response from obs.api.app import app, json as json_response
from obs.api.utils import get_single_arg
class ExportFormat(str, Enum): class ExportFormat(str, Enum):
@ -61,10 +60,10 @@ def shapefile_zip():
@app.get(r"/export/events") @app.get(r"/export/events")
async def export_events(req): async def export_events(req):
bbox = get_single_arg( bbox = req.ctx.get_single_arg(
req, "bbox", default="-180,-90,180,90", convert=parse_bounding_box "bbox", default="-180,-90,180,90", convert=parse_bounding_box
) )
fmt = get_single_arg(req, "fmt", convert=ExportFormat) fmt = req.ctx.get_single_arg("fmt", convert=ExportFormat)
events = await req.ctx.db.stream_scalars( events = await req.ctx.db.stream_scalars(
select(OvertakingEvent).where(OvertakingEvent.geometry.bool_op("&&")(bbox)) select(OvertakingEvent).where(OvertakingEvent.geometry.bool_op("&&")(bbox))

View file

@ -14,14 +14,14 @@ from obs.api.app import auth
from obs.api.db import User from obs.api.db import User
from sanic.response import json, redirect from sanic.response import json, redirect
from sanicargs import parse_parameters
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
client = Client(client_authn_method=CLIENT_AUTHN_METHOD) client = Client(client_authn_method=CLIENT_AUTHN_METHOD)
# Do not show verbose library output, even when the appliaction is in debug mode # Do not show verbose library output, even when the appliaction is in debug mode
logging.getLogger('oic').setLevel(logging.INFO) logging.getLogger("oic").setLevel(logging.INFO)
@auth.before_server_start @auth.before_server_start
async def connect_auth_client(app, loop): async def connect_auth_client(app, loop):
@ -43,12 +43,13 @@ async def connect_auth_client(app, loop):
@auth.route("/login") @auth.route("/login")
@parse_parameters async def login(req):
async def login(req, next: str = None): next_url = req.ctx.get_single_arg("next", default=None)
session = req.ctx.session session = req.ctx.session
session["state"] = rndstr() session["state"] = rndstr()
session["nonce"] = rndstr() session["nonce"] = rndstr()
session["next"] = next session["next"] = next_url
args = { args = {
"client_id": client.client_id, "client_id": client.client_id,
"response_type": "code", "response_type": "code",

View file

@ -10,7 +10,7 @@ from sanic.exceptions import InvalidUsage
from obs.api.app import api from obs.api.app import api
from obs.api.db import Road, OvertakingEvent, Track from obs.api.db import Road, OvertakingEvent, Track
from obs.api.utils import round_to, get_single_arg from obs.api.utils import round_to
round_distance = partial(round_to, multiples=0.001) round_distance = partial(round_to, multiples=0.001)
round_speed = partial(round_to, multiples=0.1) round_speed = partial(round_to, multiples=0.1)
@ -28,9 +28,9 @@ def get_bearing(a, b):
@api.route("/mapdetails/road", methods=["GET"]) @api.route("/mapdetails/road", methods=["GET"])
async def mapdetails_road(req): async def mapdetails_road(req):
longitude = get_single_arg(req, "longitude", convert=float) longitude = req.ctx.get_single_arg("longitude", convert=float)
latitude = get_single_arg(req, "latitude", convert=float) latitude = req.ctx.get_single_arg("latitude", convert=float)
radius = get_single_arg(req, "radius", default=100, convert=float) radius = req.ctx.get_single_arg("radius", default=100, convert=float)
if not (1 <= radius <= 1000): if not (1 <= radius <= 1000):
raise InvalidUsage("`radius` parameter must be between 1 and 1000") raise InvalidUsage("`radius` parameter must be between 1 and 1000")

View file

@ -7,7 +7,6 @@ from functools import reduce
from sqlalchemy import select, func from sqlalchemy import select, func
from sanic.response import json from sanic.response import json
from sanicargs import parse_parameters
from obs.api.app import api from obs.api.app import api
from obs.api.db import Track, OvertakingEvent, User from obs.api.db import Track, OvertakingEvent, User
@ -28,8 +27,11 @@ MINUMUM_RECORDING_DATE = datetime(2010, 1, 1)
@api.route("/stats") @api.route("/stats")
@parse_parameters async def stats(req):
async def stats(req, user: str = None, start: datetime = None, end: datetime = None): user = req.ctx.get_single_arg("user", default=None)
start = req.ctx.get_single_arg("start", default=None, convert=datetime)
end = req.ctx.get_single_arg("end", default=None, convert=datetime)
conditions = [ conditions = [
Track.recorded_at != None, Track.recorded_at != None,
Track.recorded_at > MINUMUM_RECORDING_DATE, Track.recorded_at > MINUMUM_RECORDING_DATE,

View file

@ -11,7 +11,6 @@ from obs.api.app import api, require_auth, read_api_key, json
from sanic.response import file_stream, empty from sanic.response import file_stream, empty
from sanic.exceptions import InvalidUsage, NotFound, Forbidden from sanic.exceptions import InvalidUsage, NotFound, Forbidden
from sanicargs import parse_parameters
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -61,8 +60,11 @@ async def _return_tracks(req, extend_query, limit, offset):
@api.get("/tracks") @api.get("/tracks")
@parse_parameters async def get_tracks(req):
async def get_tracks(req, limit: int = 20, offset: int = 0, author: str = None): limit = req.ctx.get_single_arg("limit", default=20, convert=int)
offset = req.ctx.get_single_arg("offset", default=0, convert=int)
author = req.ctx.get_single_arg("author", default=None, convert=str)
def extend_query(q): def extend_query(q):
q = q.where(Track.public) q = q.where(Track.public)
@ -76,8 +78,10 @@ async def get_tracks(req, limit: int = 20, offset: int = 0, author: str = None):
@api.get("/tracks/feed") @api.get("/tracks/feed")
@require_auth @require_auth
@parse_parameters async def get_feed(req):
async def get_feed(req, limit: int = 20, offset: int = 0): limit = req.ctx.get_single_arg("limit", default=20, convert=int)
offset = req.ctx.get_single_arg("offset", default=0, convert=int)
def extend_query(q): def extend_query(q):
return q.where(Track.author_id == req.ctx.user.id) return q.where(Track.author_id == req.ctx.user.id)
@ -260,8 +264,11 @@ async def put_track(req, slug: str):
@api.get("/tracks/<slug:str>/comments") @api.get("/tracks/<slug:str>/comments")
@parse_parameters async def get_track_comments(req):
async def get_track_comments(req, slug: str, limit: int = 20, offset: int = 0): slug = req.ctx.get_single_arg("slug")
limit = req.ctx.get_single_arg("limit", default=20, convert=int)
offset = req.ctx.get_single_arg("offset", default=0, convert=int)
track = await _load_track(req, slug) track = await _load_track(req, slug)
comment_count = await req.ctx.db.scalar( comment_count = await req.ctx.db.scalar(

View file

@ -1,3 +1,5 @@
from datetime import datetime
import dateutil.parser
from sanic.exceptions import InvalidUsage from sanic.exceptions import InvalidUsage
RAISE = object() RAISE = object()
@ -12,7 +14,10 @@ def get_single_arg(req, name, default=RAISE, convert=None):
value = default value = default
if convert is not None: if convert is not None and value is not None:
if convert is datetime or convert in ("date", "datetime"):
convert = lambda s: dateutil.parser.parse(s)
try: try:
value = convert(value) value = convert(value)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:

View file

@ -2,7 +2,6 @@ coloredlogs~=15.0.1
sanic~=21.9.3 sanic~=21.9.3
oic~=1.3.0 oic~=1.3.0
sanic-session~=0.8.0 sanic-session~=0.8.0
sanicargs~=2.1.0
sanic-cors~=2.0.1 sanic-cors~=2.0.1
python-slugify~=5.0.2 python-slugify~=5.0.2
motor~=2.5.1 motor~=2.5.1

View file

@ -14,7 +14,6 @@ setup(
"sanic~=21.9.3", "sanic~=21.9.3",
"oic>=1.3.0, <2", "oic>=1.3.0, <2",
"sanic-session~=0.8.0", "sanic-session~=0.8.0",
"sanicargs~=2.1.0",
"sanic-cors~=2.0.1", "sanic-cors~=2.0.1",
"python-slugify~=5.0.2", "python-slugify~=5.0.2",
"motor~=2.5.1", "motor~=2.5.1",