Use custom get_single arg everywhere, remove sanicargs (fixes #193)
This commit is contained in:
parent
8bb5d71186
commit
7fc9558e42
|
@ -22,6 +22,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from obs.api.db import User, make_session, connect_db
|
||||
from obs.api.utils import get_single_arg
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -128,6 +129,11 @@ def remove_right(l, r):
|
|||
return l
|
||||
|
||||
|
||||
@app.middleware("request")
|
||||
async def inject_arg_getter(req):
|
||||
req.ctx.get_single_arg = partial(get_single_arg, req)
|
||||
|
||||
|
||||
@app.middleware("request")
|
||||
async def inject_urls(req):
|
||||
if req.app.config.FRONTEND_HTTPS:
|
||||
|
|
|
@ -12,7 +12,6 @@ from sanic.response import raw
|
|||
from sanic.exceptions import InvalidUsage
|
||||
|
||||
from obs.api.app import app, json as json_response
|
||||
from obs.api.utils import get_single_arg
|
||||
|
||||
|
||||
class ExportFormat(str, Enum):
|
||||
|
@ -61,10 +60,10 @@ def shapefile_zip():
|
|||
|
||||
@app.get(r"/export/events")
|
||||
async def export_events(req):
|
||||
bbox = get_single_arg(
|
||||
req, "bbox", default="-180,-90,180,90", convert=parse_bounding_box
|
||||
bbox = req.ctx.get_single_arg(
|
||||
"bbox", default="-180,-90,180,90", convert=parse_bounding_box
|
||||
)
|
||||
fmt = get_single_arg(req, "fmt", convert=ExportFormat)
|
||||
fmt = req.ctx.get_single_arg("fmt", convert=ExportFormat)
|
||||
|
||||
events = await req.ctx.db.stream_scalars(
|
||||
select(OvertakingEvent).where(OvertakingEvent.geometry.bool_op("&&")(bbox))
|
||||
|
|
|
@ -14,14 +14,14 @@ from obs.api.app import auth
|
|||
from obs.api.db import User
|
||||
|
||||
from sanic.response import json, redirect
|
||||
from sanicargs import parse_parameters
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
client = Client(client_authn_method=CLIENT_AUTHN_METHOD)
|
||||
|
||||
# Do not show verbose library output, even when the appliaction is in debug mode
|
||||
logging.getLogger('oic').setLevel(logging.INFO)
|
||||
logging.getLogger("oic").setLevel(logging.INFO)
|
||||
|
||||
|
||||
@auth.before_server_start
|
||||
async def connect_auth_client(app, loop):
|
||||
|
@ -43,12 +43,13 @@ async def connect_auth_client(app, loop):
|
|||
|
||||
|
||||
@auth.route("/login")
|
||||
@parse_parameters
|
||||
async def login(req, next: str = None):
|
||||
async def login(req):
|
||||
next_url = req.ctx.get_single_arg("next", default=None)
|
||||
|
||||
session = req.ctx.session
|
||||
session["state"] = rndstr()
|
||||
session["nonce"] = rndstr()
|
||||
session["next"] = next
|
||||
session["next"] = next_url
|
||||
args = {
|
||||
"client_id": client.client_id,
|
||||
"response_type": "code",
|
||||
|
|
|
@ -10,7 +10,7 @@ from sanic.exceptions import InvalidUsage
|
|||
|
||||
from obs.api.app import api
|
||||
from obs.api.db import Road, OvertakingEvent, Track
|
||||
from obs.api.utils import round_to, get_single_arg
|
||||
from obs.api.utils import round_to
|
||||
|
||||
round_distance = partial(round_to, multiples=0.001)
|
||||
round_speed = partial(round_to, multiples=0.1)
|
||||
|
@ -28,9 +28,9 @@ def get_bearing(a, b):
|
|||
|
||||
@api.route("/mapdetails/road", methods=["GET"])
|
||||
async def mapdetails_road(req):
|
||||
longitude = get_single_arg(req, "longitude", convert=float)
|
||||
latitude = get_single_arg(req, "latitude", convert=float)
|
||||
radius = get_single_arg(req, "radius", default=100, convert=float)
|
||||
longitude = req.ctx.get_single_arg("longitude", convert=float)
|
||||
latitude = req.ctx.get_single_arg("latitude", convert=float)
|
||||
radius = req.ctx.get_single_arg("radius", default=100, convert=float)
|
||||
|
||||
if not (1 <= radius <= 1000):
|
||||
raise InvalidUsage("`radius` parameter must be between 1 and 1000")
|
||||
|
|
|
@ -7,7 +7,6 @@ from functools import reduce
|
|||
from sqlalchemy import select, func
|
||||
|
||||
from sanic.response import json
|
||||
from sanicargs import parse_parameters
|
||||
|
||||
from obs.api.app import api
|
||||
from obs.api.db import Track, OvertakingEvent, User
|
||||
|
@ -28,8 +27,11 @@ MINUMUM_RECORDING_DATE = datetime(2010, 1, 1)
|
|||
|
||||
|
||||
@api.route("/stats")
|
||||
@parse_parameters
|
||||
async def stats(req, user: str = None, start: datetime = None, end: datetime = None):
|
||||
async def stats(req):
|
||||
user = req.ctx.get_single_arg("user", default=None)
|
||||
start = req.ctx.get_single_arg("start", default=None, convert=datetime)
|
||||
end = req.ctx.get_single_arg("end", default=None, convert=datetime)
|
||||
|
||||
conditions = [
|
||||
Track.recorded_at != None,
|
||||
Track.recorded_at > MINUMUM_RECORDING_DATE,
|
||||
|
|
|
@ -11,7 +11,6 @@ from obs.api.app import api, require_auth, read_api_key, json
|
|||
|
||||
from sanic.response import file_stream, empty
|
||||
from sanic.exceptions import InvalidUsage, NotFound, Forbidden
|
||||
from sanicargs import parse_parameters
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -61,8 +60,11 @@ async def _return_tracks(req, extend_query, limit, offset):
|
|||
|
||||
|
||||
@api.get("/tracks")
|
||||
@parse_parameters
|
||||
async def get_tracks(req, limit: int = 20, offset: int = 0, author: str = None):
|
||||
async def get_tracks(req):
|
||||
limit = req.ctx.get_single_arg("limit", default=20, convert=int)
|
||||
offset = req.ctx.get_single_arg("offset", default=0, convert=int)
|
||||
author = req.ctx.get_single_arg("author", default=None, convert=str)
|
||||
|
||||
def extend_query(q):
|
||||
q = q.where(Track.public)
|
||||
|
||||
|
@ -76,8 +78,10 @@ async def get_tracks(req, limit: int = 20, offset: int = 0, author: str = None):
|
|||
|
||||
@api.get("/tracks/feed")
|
||||
@require_auth
|
||||
@parse_parameters
|
||||
async def get_feed(req, limit: int = 20, offset: int = 0):
|
||||
async def get_feed(req):
|
||||
limit = req.ctx.get_single_arg("limit", default=20, convert=int)
|
||||
offset = req.ctx.get_single_arg("offset", default=0, convert=int)
|
||||
|
||||
def extend_query(q):
|
||||
return q.where(Track.author_id == req.ctx.user.id)
|
||||
|
||||
|
@ -260,8 +264,11 @@ async def put_track(req, slug: str):
|
|||
|
||||
|
||||
@api.get("/tracks/<slug:str>/comments")
|
||||
@parse_parameters
|
||||
async def get_track_comments(req, slug: str, limit: int = 20, offset: int = 0):
|
||||
async def get_track_comments(req):
|
||||
slug = req.ctx.get_single_arg("slug")
|
||||
limit = req.ctx.get_single_arg("limit", default=20, convert=int)
|
||||
offset = req.ctx.get_single_arg("offset", default=0, convert=int)
|
||||
|
||||
track = await _load_track(req, slug)
|
||||
|
||||
comment_count = await req.ctx.db.scalar(
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from datetime import datetime
|
||||
import dateutil.parser
|
||||
from sanic.exceptions import InvalidUsage
|
||||
|
||||
RAISE = object()
|
||||
|
@ -12,7 +14,10 @@ def get_single_arg(req, name, default=RAISE, convert=None):
|
|||
|
||||
value = default
|
||||
|
||||
if convert is not None:
|
||||
if convert is not None and value is not None:
|
||||
if convert is datetime or convert in ("date", "datetime"):
|
||||
convert = lambda s: dateutil.parser.parse(s)
|
||||
|
||||
try:
|
||||
value = convert(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
|
|
|
@ -2,7 +2,6 @@ coloredlogs~=15.0.1
|
|||
sanic~=21.9.3
|
||||
oic~=1.3.0
|
||||
sanic-session~=0.8.0
|
||||
sanicargs~=2.1.0
|
||||
sanic-cors~=2.0.1
|
||||
python-slugify~=5.0.2
|
||||
motor~=2.5.1
|
||||
|
|
|
@ -14,7 +14,6 @@ setup(
|
|||
"sanic~=21.9.3",
|
||||
"oic>=1.3.0, <2",
|
||||
"sanic-session~=0.8.0",
|
||||
"sanicargs~=2.1.0",
|
||||
"sanic-cors~=2.0.1",
|
||||
"python-slugify~=5.0.2",
|
||||
"motor~=2.5.1",
|
||||
|
|
Loading…
Reference in a new issue