diff --git a/app.js b/app.js index 0b2fd29..bdaf2b1 100644 --- a/app.js +++ b/app.js @@ -5,6 +5,7 @@ const session = require('express-session'); const cors = require('cors'); const errorhandler = require('errorhandler'); const mongoose = require('mongoose'); +const auth = require('./routes/auth'); const isProduction = process.env.NODE_ENV === 'production'; @@ -12,6 +13,7 @@ const isProduction = process.env.NODE_ENV === 'production'; const app = express(); app.use(cors()); +app.use(auth.getUserIdMiddleware); // Normal express config defaults app.use(require('morgan')('dev')); diff --git a/routes/auth.js b/routes/auth.js index 15a4e5f..1cf0dc8 100644 --- a/routes/auth.js +++ b/routes/auth.js @@ -2,30 +2,51 @@ const jwt = require('express-jwt'); const secret = require('../config').secret; function getTokenFromHeader(req) { - if ( - (req.headers.authorization && req.headers.authorization.split(' ')[0] === 'Token') || - (req.headers.authorization && req.headers.authorization.split(' ')[0] === 'Bearer') - ) { - return req.headers.authorization.split(' ')[1]; + const [tokenType, token] = req.headers.authorization?.split(' ') || []; + + if (tokenType === 'Token' || tokenType === 'Bearer') { + return token; } return null; } -const auth = { - required: jwt({ - secret: secret, - userProperty: 'payload', - getToken: getTokenFromHeader, - algorithms: ['HS256'], - }), - optional: jwt({ - secret: secret, - userProperty: 'payload', - credentialsRequired: false, - getToken: getTokenFromHeader, - algorithms: ['HS256'], - }), -}; +const jwtOptional = jwt({ + secret: secret, + userProperty: 'payload', + credentialsRequired: false, + getToken: getTokenFromHeader, + algorithms: ['HS256'], +}); -module.exports = auth; +function getUserIdMiddleware(req, res, next) { + try { + const [tokenType, token] = req.headers.authorization.split(' ') || []; + + if (tokenType === 'Token' || tokenType === 'Bearer') { + return jwtOptional(req, res, next); + } else if (tokenType === 'OBSUserId') { + req.payload = { id: token.trim() }; + next(); + } else { + req.payload = null; + next(); + } + } catch (err) { + next(err); + } +} + +module.exports = { + required(req, res, next) { + if (!req.payload) { + return res.sendStatus(403); + } else { + return next(); + } + }, + optional(req, res, next) { + return next(); + }, + getUserIdMiddleware, +};