From 5309527c3eab0fd274bd9877f593d06ebef3d31a Mon Sep 17 00:00:00 2001 From: Paul Bienkowski Date: Thu, 2 Dec 2021 21:00:21 +0100 Subject: [PATCH] fix: read API key for posting tracks, fixes upload from OBS --- api/obs/api/app.py | 38 ++++++++++++++++++++++++++++++++++++ api/obs/api/routes/tracks.py | 3 ++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/api/obs/api/app.py b/api/obs/api/app.py index 05c9e9d..aeece95 100644 --- a/api/obs/api/app.py +++ b/api/obs/api/app.py @@ -173,6 +173,44 @@ def require_auth(fn): return wrapper +def read_api_key(fn): + """ + A middleware decorator to read the API Key of a user. It is an opt-in to + allow usage with API Keys on certain urls. Combine with require_auth to + actually check whether a user was authenticated through this. If a login + session exists, the api key is ignored. + """ + + @wraps(fn) + async def wrapper(req, *args, **kwargs): + # try to parse a token if one exists, unless a user is already authenticated + if ( + not req.ctx.user + and isinstance(req.token, str) + and req.token.lower().startswith("obsuserid ") + ): + try: + api_key = req.token.split()[1] + except LookupError: + api_key = None + + if api_key: + user = ( + await req.ctx.db.execute( + select(User).where(User.api_key == api_key.strip()) + ) + ).scalar() + + if not user: + raise Unauthorized("invalid OBSUserId token") + + req.ctx.user = user + + return await fn(req, *args, **kwargs) + + return wrapper + + class CustomJsonEncoder(JSONEncoder): def default(self, obj): if isinstance(obj, (datetime, date)): diff --git a/api/obs/api/routes/tracks.py b/api/obs/api/routes/tracks.py index b0c01cb..1954fa8 100644 --- a/api/obs/api/routes/tracks.py +++ b/api/obs/api/routes/tracks.py @@ -7,7 +7,7 @@ from sqlalchemy import select, func from sqlalchemy.orm import joinedload from obs.api.db import Track, User, Comment -from obs.api.app import api, require_auth, json +from obs.api.app import api, require_auth, read_api_key, json from sanic.response import file_stream, empty from sanic.exceptions import InvalidUsage, NotFound, Forbidden @@ -85,6 +85,7 @@ async def get_feed(req, limit: int = 20, offset: int = 0): @api.post("/tracks") +@read_api_key @require_auth async def post_track(req): try: