diff --git a/api/obs/api/app.py b/api/obs/api/app.py index 49519e6..fa14e1a 100644 --- a/api/obs/api/app.py +++ b/api/obs/api/app.py @@ -21,6 +21,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from obs.api.db import User, make_session, connect_db +from obs.api.cors import setup_options, add_cors_headers from obs.api.utils import get_single_arg from sqlalchemy.util import asyncio @@ -84,6 +85,39 @@ class NoConnectionLostFilter(logging.Filter): logging.getLogger("sanic.error").addFilter(NoConnectionLostFilter) +def setup_cors(app): + frontend_url = app.config.get("FRONTEND_URL") + additional_origins = app.config.get("ADDITIONAL_CORS_ORIGINS") + if not frontend_url and not additional_origins: + # No CORS configured + return + + origins = [] + if frontend_url: + u = urlparse(frontend_url) + origins.append(f"{u.scheme}://{u.netloc}") + + if isinstance(additional_origins, str): + origins += re.split(r"\s+", additional_origins) + elif isinstance(additional_origins, list): + origins += additional_origins + elif additional_origins is not None: + raise ValueError( + "invalid option type for ADDITIONAL_CORS_ORIGINS, must be list or space separated str" + ) + + app.ctx.cors_origins = origins + + # Add OPTIONS handlers to any route that is missing it + app.register_listener(setup_options, "before_server_start") + + # Fill in CORS headers + app.register_middleware(add_cors_headers, "response") + + +setup_cors(app) + + @app.exception(SanicException, BaseException) async def _handle_sanic_errors(_request, exception): if isinstance(exception, asyncio.CancelledError): @@ -120,39 +154,6 @@ def configure_paths(c): configure_paths(app.config) -def setup_cors(app): - frontend_url = app.config.get("FRONTEND_URL") - additional_origins = app.config.get("ADDITIONAL_CORS_ORIGINS") - if not frontend_url and not additional_origins: - # No CORS configured - return - - origins = [] - if frontend_url: - u = urlparse(frontend_url) - origins.append(f"{u.scheme}://{u.netloc}") - - if isinstance(additional_origins, str): - origins += re.split(r"\s+", additional_origins) - elif isinstance(additional_origins, list): - origins += additional_origins - elif additional_origins is not None: - raise ValueError( - "invalid option type for ADDITIONAL_CORS_ORIGINS, must be list or space separated str" - ) - - from sanic_cors import CORS - - CORS( - app, - origins=origins, - supports_credentials=True, - expose_headers={"Content-Disposition"}, - ) - - -setup_cors(app) - # TODO: use a different interface, maybe backed by the PostgreSQL, to allow # scaling the API Session(app, interface=InMemorySessionInterface()) diff --git a/api/obs/api/cors.py b/api/obs/api/cors.py new file mode 100644 index 0000000..e46d6f1 --- /dev/null +++ b/api/obs/api/cors.py @@ -0,0 +1,67 @@ +from collections import defaultdict +from typing import Dict, FrozenSet, Iterable + +from sanic import Sanic, response +from sanic_routing.router import Route + + +def _add_cors_headers(request, response, methods: Iterable[str]) -> None: + allow_methods = list(set(methods)) + + if "OPTIONS" not in allow_methods: + allow_methods.append("OPTIONS") + + origin = request.headers.get("origin") + if origin in request.app.ctx.cors_origins: + headers = { + "Access-Control-Allow-Methods": ",".join(allow_methods), + "Access-Control-Allow-Origin": origin, + "Access-Control-Allow-Credentials": "true", + "Access-Control-Allow-Headers": ( + "origin, content-type, accept, " + "authorization, x-xsrf-token, x-request-id" + ), + } + response.headers.extend(headers) + + +def add_cors_headers(request, response): + if request.method != "OPTIONS": + methods = [method for method in request.route.methods] + _add_cors_headers(request, response, methods) + + +def _compile_routes_needing_options(routes: Dict[str, Route]) -> Dict[str, FrozenSet]: + needs_options = defaultdict(list) + # This is 21.12 and later. You will need to change this for older versions. + for route in routes.values(): + if "OPTIONS" not in route.methods: + needs_options[route.uri].extend(route.methods) + + return {uri: frozenset(methods) for uri, methods in dict(needs_options).items()} + + +def _options_wrapper(handler, methods): + def wrapped_handler(request, *args, **kwargs): + nonlocal methods + return handler(request, methods) + + return wrapped_handler + + +async def options_handler(request, methods) -> response.HTTPResponse: + resp = response.empty() + _add_cors_headers(request, resp, methods) + return resp + + +def setup_options(app: Sanic, _): + app.router.reset() + needs_options = _compile_routes_needing_options(app.router.routes_all) + for uri, methods in needs_options.items(): + app.add_route( + _options_wrapper(options_handler, methods), + uri, + methods=["OPTIONS"], + ) + app.router.finalize() diff --git a/api/requirements.txt b/api/requirements.txt index 837e553..ee242b4 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -1,8 +1,7 @@ coloredlogs~=15.0.1 -sanic~=22.6.0 +sanic==22.6.2 oic~=1.3.0 sanic-session~=0.8.0 -sanic-cors~=2.0.1 python-slugify~=6.1.2 motor~=3.0.0 pyyaml<6 diff --git a/api/setup.py b/api/setup.py index e76b57f..2b8ce1e 100644 --- a/api/setup.py +++ b/api/setup.py @@ -11,10 +11,9 @@ setup( package_data={}, install_requires=[ "coloredlogs~=15.0.1", - "sanic>=21.9.3,<22.7.0", + "sanic==22.6.2", "oic>=1.3.0, <2", "sanic-session~=0.8.0", - "sanic-cors~=2.0.1", "python-slugify>=5.0.2,<6.2.0", "motor>=2.5.1,<3.1.0", "pyyaml<6",