fix cors by implementing it ourselves
This commit is contained in:
parent
ed272b4e4a
commit
84ab957aa0
|
@ -21,6 +21,7 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from obs.api.db import User, make_session, connect_db
|
from obs.api.db import User, make_session, connect_db
|
||||||
|
from obs.api.cors import setup_options, add_cors_headers
|
||||||
from obs.api.utils import get_single_arg
|
from obs.api.utils import get_single_arg
|
||||||
from sqlalchemy.util import asyncio
|
from sqlalchemy.util import asyncio
|
||||||
|
|
||||||
|
@ -84,6 +85,39 @@ class NoConnectionLostFilter(logging.Filter):
|
||||||
logging.getLogger("sanic.error").addFilter(NoConnectionLostFilter)
|
logging.getLogger("sanic.error").addFilter(NoConnectionLostFilter)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_cors(app):
|
||||||
|
frontend_url = app.config.get("FRONTEND_URL")
|
||||||
|
additional_origins = app.config.get("ADDITIONAL_CORS_ORIGINS")
|
||||||
|
if not frontend_url and not additional_origins:
|
||||||
|
# No CORS configured
|
||||||
|
return
|
||||||
|
|
||||||
|
origins = []
|
||||||
|
if frontend_url:
|
||||||
|
u = urlparse(frontend_url)
|
||||||
|
origins.append(f"{u.scheme}://{u.netloc}")
|
||||||
|
|
||||||
|
if isinstance(additional_origins, str):
|
||||||
|
origins += re.split(r"\s+", additional_origins)
|
||||||
|
elif isinstance(additional_origins, list):
|
||||||
|
origins += additional_origins
|
||||||
|
elif additional_origins is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"invalid option type for ADDITIONAL_CORS_ORIGINS, must be list or space separated str"
|
||||||
|
)
|
||||||
|
|
||||||
|
app.ctx.cors_origins = origins
|
||||||
|
|
||||||
|
# Add OPTIONS handlers to any route that is missing it
|
||||||
|
app.register_listener(setup_options, "before_server_start")
|
||||||
|
|
||||||
|
# Fill in CORS headers
|
||||||
|
app.register_middleware(add_cors_headers, "response")
|
||||||
|
|
||||||
|
|
||||||
|
setup_cors(app)
|
||||||
|
|
||||||
|
|
||||||
@app.exception(SanicException, BaseException)
|
@app.exception(SanicException, BaseException)
|
||||||
async def _handle_sanic_errors(_request, exception):
|
async def _handle_sanic_errors(_request, exception):
|
||||||
if isinstance(exception, asyncio.CancelledError):
|
if isinstance(exception, asyncio.CancelledError):
|
||||||
|
@ -120,39 +154,6 @@ def configure_paths(c):
|
||||||
configure_paths(app.config)
|
configure_paths(app.config)
|
||||||
|
|
||||||
|
|
||||||
def setup_cors(app):
|
|
||||||
frontend_url = app.config.get("FRONTEND_URL")
|
|
||||||
additional_origins = app.config.get("ADDITIONAL_CORS_ORIGINS")
|
|
||||||
if not frontend_url and not additional_origins:
|
|
||||||
# No CORS configured
|
|
||||||
return
|
|
||||||
|
|
||||||
origins = []
|
|
||||||
if frontend_url:
|
|
||||||
u = urlparse(frontend_url)
|
|
||||||
origins.append(f"{u.scheme}://{u.netloc}")
|
|
||||||
|
|
||||||
if isinstance(additional_origins, str):
|
|
||||||
origins += re.split(r"\s+", additional_origins)
|
|
||||||
elif isinstance(additional_origins, list):
|
|
||||||
origins += additional_origins
|
|
||||||
elif additional_origins is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"invalid option type for ADDITIONAL_CORS_ORIGINS, must be list or space separated str"
|
|
||||||
)
|
|
||||||
|
|
||||||
from sanic_cors import CORS
|
|
||||||
|
|
||||||
CORS(
|
|
||||||
app,
|
|
||||||
origins=origins,
|
|
||||||
supports_credentials=True,
|
|
||||||
expose_headers={"Content-Disposition"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
setup_cors(app)
|
|
||||||
|
|
||||||
# TODO: use a different interface, maybe backed by the PostgreSQL, to allow
|
# TODO: use a different interface, maybe backed by the PostgreSQL, to allow
|
||||||
# scaling the API
|
# scaling the API
|
||||||
Session(app, interface=InMemorySessionInterface())
|
Session(app, interface=InMemorySessionInterface())
|
||||||
|
|
67
api/obs/api/cors.py
Normal file
67
api/obs/api/cors.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, FrozenSet, Iterable
|
||||||
|
|
||||||
|
from sanic import Sanic, response
|
||||||
|
from sanic_routing.router import Route
|
||||||
|
|
||||||
|
|
||||||
|
def _add_cors_headers(request, response, methods: Iterable[str]) -> None:
|
||||||
|
allow_methods = list(set(methods))
|
||||||
|
|
||||||
|
if "OPTIONS" not in allow_methods:
|
||||||
|
allow_methods.append("OPTIONS")
|
||||||
|
|
||||||
|
origin = request.headers.get("origin")
|
||||||
|
if origin in request.app.ctx.cors_origins:
|
||||||
|
headers = {
|
||||||
|
"Access-Control-Allow-Methods": ",".join(allow_methods),
|
||||||
|
"Access-Control-Allow-Origin": origin,
|
||||||
|
"Access-Control-Allow-Credentials": "true",
|
||||||
|
"Access-Control-Allow-Headers": (
|
||||||
|
"origin, content-type, accept, "
|
||||||
|
"authorization, x-xsrf-token, x-request-id"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
response.headers.extend(headers)
|
||||||
|
|
||||||
|
|
||||||
|
def add_cors_headers(request, response):
|
||||||
|
if request.method != "OPTIONS":
|
||||||
|
methods = [method for method in request.route.methods]
|
||||||
|
_add_cors_headers(request, response, methods)
|
||||||
|
|
||||||
|
|
||||||
|
def _compile_routes_needing_options(routes: Dict[str, Route]) -> Dict[str, FrozenSet]:
|
||||||
|
needs_options = defaultdict(list)
|
||||||
|
# This is 21.12 and later. You will need to change this for older versions.
|
||||||
|
for route in routes.values():
|
||||||
|
if "OPTIONS" not in route.methods:
|
||||||
|
needs_options[route.uri].extend(route.methods)
|
||||||
|
|
||||||
|
return {uri: frozenset(methods) for uri, methods in dict(needs_options).items()}
|
||||||
|
|
||||||
|
|
||||||
|
def _options_wrapper(handler, methods):
|
||||||
|
def wrapped_handler(request, *args, **kwargs):
|
||||||
|
nonlocal methods
|
||||||
|
return handler(request, methods)
|
||||||
|
|
||||||
|
return wrapped_handler
|
||||||
|
|
||||||
|
|
||||||
|
async def options_handler(request, methods) -> response.HTTPResponse:
|
||||||
|
resp = response.empty()
|
||||||
|
_add_cors_headers(request, resp, methods)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
def setup_options(app: Sanic, _):
|
||||||
|
app.router.reset()
|
||||||
|
needs_options = _compile_routes_needing_options(app.router.routes_all)
|
||||||
|
for uri, methods in needs_options.items():
|
||||||
|
app.add_route(
|
||||||
|
_options_wrapper(options_handler, methods),
|
||||||
|
uri,
|
||||||
|
methods=["OPTIONS"],
|
||||||
|
)
|
||||||
|
app.router.finalize()
|
|
@ -1,8 +1,7 @@
|
||||||
coloredlogs~=15.0.1
|
coloredlogs~=15.0.1
|
||||||
sanic~=22.6.0
|
sanic==22.6.2
|
||||||
oic~=1.3.0
|
oic~=1.3.0
|
||||||
sanic-session~=0.8.0
|
sanic-session~=0.8.0
|
||||||
sanic-cors~=2.0.1
|
|
||||||
python-slugify~=6.1.2
|
python-slugify~=6.1.2
|
||||||
motor~=3.0.0
|
motor~=3.0.0
|
||||||
pyyaml<6
|
pyyaml<6
|
||||||
|
|
|
@ -11,10 +11,9 @@ setup(
|
||||||
package_data={},
|
package_data={},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"coloredlogs~=15.0.1",
|
"coloredlogs~=15.0.1",
|
||||||
"sanic>=21.9.3,<22.7.0",
|
"sanic==22.6.2",
|
||||||
"oic>=1.3.0, <2",
|
"oic>=1.3.0, <2",
|
||||||
"sanic-session~=0.8.0",
|
"sanic-session~=0.8.0",
|
||||||
"sanic-cors~=2.0.1",
|
|
||||||
"python-slugify>=5.0.2,<6.2.0",
|
"python-slugify>=5.0.2,<6.2.0",
|
||||||
"motor>=2.5.1,<3.1.0",
|
"motor>=2.5.1,<3.1.0",
|
||||||
"pyyaml<6",
|
"pyyaml<6",
|
||||||
|
|
Loading…
Reference in a new issue