Refactor socket.io middleware for future uws compatibility

This commit is contained in:
Calvin Montgomery 2018-06-17 16:25:40 -07:00
parent 1021cc706a
commit 61b856c2c9
2 changed files with 30 additions and 16 deletions

View File

@ -34,6 +34,20 @@ const authFailureCount = new Counter({
help: 'Number of failed authentications from session middleware' help: 'Number of failed authentications from session middleware'
}); });
class SocketIOContext {
constructor(socket) {
socket.handshake.connection = {
remoteAddress: socket.handshake.address
};
this.upgradeReq = socket.handshake;
this.ipAddress = null;
this.torConnection = null;
this.ipSessionFirstSeen = null;
this.user = null;
}
}
class IOServer { class IOServer {
constructor(options = { constructor(options = {
proxyTrustFn: proxyaddr.compile('127.0.0.1') proxyTrustFn: proxyaddr.compile('127.0.0.1')
@ -49,21 +63,16 @@ class IOServer {
// Map proxied sockets to the real IP address via X-Forwarded-For // Map proxied sockets to the real IP address via X-Forwarded-For
// If the resulting address is a known Tor exit, flag it as such // If the resulting address is a known Tor exit, flag it as such
ipProxyMiddleware(socket, next) { ipProxyMiddleware(socket, next) {
if (!socket.context) socket.context = {};
try { try {
socket.handshake.connection = {
remoteAddress: socket.handshake.address
};
socket.context.ipAddress = proxyaddr( socket.context.ipAddress = proxyaddr(
socket.handshake, socket.context.upgradeReq,
this.proxyTrustFn this.proxyTrustFn
); );
if (!socket.context.ipAddress) { if (!socket.context.ipAddress) {
throw new Error( throw new Error(
`Assertion failed: unexpected IP ${socket.context.ipAddress}` 'Could not determine IP address from ' +
socket.context.upgradeReq.connection.remoteAddress
); );
} }
} catch (error) { } catch (error) {
@ -163,7 +172,7 @@ class IOServer {
// Parse cookies // Parse cookies
cookieParsingMiddleware(socket, next) { cookieParsingMiddleware(socket, next) {
const req = socket.handshake; const req = socket.context.upgradeReq;
if (req.headers.cookie) { if (req.headers.cookie) {
cookieParser(req, null, () => next()); cookieParser(req, null, () => next());
} else { } else {
@ -176,7 +185,7 @@ class IOServer {
// Determine session age from ip-session cookie // Determine session age from ip-session cookie
// (Used for restricting chat) // (Used for restricting chat)
ipSessionCookieMiddleware(socket, next) { ipSessionCookieMiddleware(socket, next) {
const cookie = socket.handshake.signedCookies['ip-session']; const cookie = socket.context.upgradeReq.signedCookies['ip-session'];
if (!cookie) { if (!cookie) {
socket.context.ipSessionFirstSeen = new Date(); socket.context.ipSessionFirstSeen = new Date();
next(); next();
@ -197,7 +206,7 @@ class IOServer {
socket.context.aliases = []; socket.context.aliases = [];
const promises = []; const promises = [];
const auth = socket.handshake.signedCookies.auth; const auth = socket.context.upgradeReq.signedCookies.auth;
if (auth) { if (auth) {
promises.push(verifySession(auth).then(user => { promises.push(verifySession(auth).then(user => {
socket.context.user = Object.assign({}, user); socket.context.user = Object.assign({}, user);
@ -245,6 +254,10 @@ class IOServer {
patchTypecheckedFunctions(); patchTypecheckedFunctions();
const io = this.io = sio.instance = sio(); const io = this.io = sio.instance = sio();
io.use((socket, next) => {
socket.context = new SocketIOContext(socket);
next();
});
io.use(this.ipProxyMiddleware.bind(this)); io.use(this.ipProxyMiddleware.bind(this));
io.use(this.ipBanMiddleware.bind(this)); io.use(this.ipBanMiddleware.bind(this));
io.use(this.ipThrottleMiddleware.bind(this)); io.use(this.ipThrottleMiddleware.bind(this));
@ -422,7 +435,9 @@ module.exports = {
ioServer.bindTo(servers); ioServer.bindTo(servers);
}, },
IOServer: IOServer IOServer: IOServer,
SocketIOContext: SocketIOContext
}; };
/* Clean out old rate limiters */ /* Clean out old rate limiters */

View File

@ -1,5 +1,6 @@
const assert = require('assert'); const assert = require('assert');
const IOServer = require('../../lib/io/ioserver').IOServer; const IOServer = require('../../lib/io/ioserver').IOServer;
const SocketIOContext = require('../../lib/io/ioserver').SocketIOContext;
describe('IOServer', () => { describe('IOServer', () => {
let server; let server;
@ -7,9 +8,6 @@ describe('IOServer', () => {
beforeEach(() => { beforeEach(() => {
server = new IOServer(); server = new IOServer();
socket = { socket = {
context: {
ipAddress: '9.9.9.9'
},
handshake: { handshake: {
address: '127.0.0.1', address: '127.0.0.1',
headers: { headers: {
@ -17,6 +15,7 @@ describe('IOServer', () => {
} }
} }
}; };
socket.context = new SocketIOContext(socket);
}); });
describe('#ipProxyMiddleware', () => { describe('#ipProxyMiddleware', () => {
@ -29,7 +28,7 @@ describe('IOServer', () => {
}); });
it('does not proxy from a non-trusted address', done => { it('does not proxy from a non-trusted address', done => {
socket.handshake.address = '5.6.7.8'; socket.context.upgradeReq.connection.remoteAddress = '5.6.7.8';
server.ipProxyMiddleware(socket, error => { server.ipProxyMiddleware(socket, error => {
assert(!error); assert(!error);
assert.strictEqual(socket.context.ipAddress, '5.6.7.8'); assert.strictEqual(socket.context.ipAddress, '5.6.7.8');