Episode 3 — NodeJS MongoDB Backend Architecture / 3.15 — Realtime Communication WebSockets

3.15.e — Socket.io Middleware

Socket.io middleware intercepts every incoming connection before it reaches your event handlers, enabling authentication, logging, rate limiting, and data validation at the connection level -- similar to Express middleware but for WebSocket connections.


<< Previous: 3.15.d — Rooms & Namespaces | Next: 3.15 — Exercise Questions >>


1. How Socket.io Middleware Works

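Socket.io middleware runs once per connection attempt to a namespace, before the connection event fires. If middleware calls next(), the connection proceeds to the next middleware (or, after the last one, to the connection event). If it calls next(new Error()), the connection is rejected and no further middleware runs. Because the socket is not yet connected while middleware executes, a rejected socket never emits a disconnect event; keep this in mind before registering cleanup handlers inside middleware.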

Client connects
      │
      ▼
┌──────────────┐   next()    ┌──────────────┐   next()    ┌──────────┐
│ Middleware 1 │ ──────────> │ Middleware 2 │ ──────────> │connection│
│ (logging)    │             │ (auth)       │             │ event    │
└──────────────┘             └──────────────┘             └──────────┘
                                   │
                            next(new Error())
                                   │
                                   ▼
                           Connection rejected
                           Client gets 'connect_error'

Basic Syntax

// Register middleware with io.use()
// Arguments:
//   socket — the incoming socket (not yet fully connected)
//   next   — call next() to continue to the next middleware or connection handler,
//            or next(new Error('message')) to reject the connection
function basicMiddleware(socket, next) {
  console.log('Middleware executed for:', socket.id);
  next(); // Allow connection
}

io.use(basicMiddleware);

2. Authentication Middleware

The most common use case for Socket.io middleware is verifying JWT tokens before allowing a connection.

Server-Side: JWT Authentication Middleware

const jwt = require('jsonwebtoken');
const User = require('./models/User');

// Authentication middleware
io.use(async (socket, next) => {
  try {
    // 1. Extract the token from the handshake (the client sends it via the `auth` option)
    const token = socket.handshake.auth.token;

    if (!token) {
      return next(new Error('Authentication required'));
    }

    // 2. Verify the JWT. Pin the accepted algorithm so that a token signed with
    //    a different algorithm is rejected outright. HS256 is what jsonwebtoken
    //    uses by default when signing with a shared secret.
    const decoded = jwt.verify(token, process.env.JWT_SECRET, { algorithms: ['HS256'] });

    // 3. Find the user in the database, excluding the password hash
    const user = await User.findById(decoded.id).select('-password');

    if (!user) {
      return next(new Error('User not found'));
    }

    // 4. Attach user data to the socket for use in event handlers
    socket.user = user;
    socket.userId = user._id.toString();

    // 5. Allow the connection
    next();
  } catch (error) {
    // Map jsonwebtoken's error names to client-facing messages; anything else
    // (e.g. a database failure) becomes a generic failure.
    if (error.name === 'JsonWebTokenError') {
      return next(new Error('Invalid token'));
    }
    if (error.name === 'TokenExpiredError') {
      return next(new Error('Token expired'));
    }
    next(new Error('Authentication failed'));
  }
});

// Now socket.user is available in all event handlers
io.on('connection', (socket) => {
  console.log(`Authenticated user connected: ${socket.user.name}`);

  // Join user-specific room for targeted notifications
  socket.join(`user:${socket.userId}`);

  socket.on('send-message', (data, callback) => {
    // Acknowledge only when the client actually supplied an ack function
    const reply = typeof callback === 'function' ? callback : () => {};

    // Client payloads are untrusted: reading data.room on a missing payload
    // would throw a TypeError inside this listener.
    if (typeof data?.room !== 'string' || typeof data.text !== 'string') {
      return reply({ error: 'Invalid payload' });
    }

    // Authentication says WHO the user is, not WHERE they may post. Without
    // this check any client could broadcast into any room just by naming it.
    if (!socket.rooms.has(data.room)) {
      return reply({ error: 'Not a member of this room' });
    }

    // socket.user is available here!
    io.to(data.room).emit('new-message', {
      text: data.text,
      sender: socket.user.name,
      senderId: socket.userId,
      timestamp: Date.now()
    });
  });
});

Client-Side: Sending the Token

// Connect with authentication token
const socket = io('http://localhost:3000', {
  auth: {
    token: localStorage.getItem('jwt')
  }
});

// Handle authentication errors.
// A rejection from server middleware is not retried automatically, which is
// why the token-refresh branch calls socket.connect() itself.
socket.on('connect_error', (error) => {
  if (error.message === 'Authentication required') {
    console.log('Please log in first');
    window.location.href = '/login';
  } else if (error.message === 'Token expired') {
    console.log('Session expired, refreshing token...');
    refreshToken()
      .then((newToken) => {
        socket.auth.token = newToken;
        socket.connect(); // Retry with new token
      })
      .catch((refreshError) => {
        // Without this the failure is an unhandled rejection and the socket
        // silently stays disconnected; send the user to log in again instead.
        console.error('Token refresh failed:', refreshError);
        window.location.href = '/login';
      });
  } else {
    console.error('Connection error:', error.message);
  }
});

Alternative: Token via Query String (Less Secure)

// Client — token in query string (visible in logs, not recommended)
const socket = io('http://localhost:3000', {
  query: { token: 'your-jwt-here' }
});

// Server — extract from query
io.use((socket, next) => {
  const token = socket.handshake.query.token;
  // ... verify token ...
  // (Elided for brevity: a real middleware must finish by calling next() to
  // accept or next(new Error(...)) to reject. As written it calls neither,
  // so the connection attempt would never complete.)
});

Best practice: Always send the token via the auth option instead of the query string. Query strings are part of the request URL, so they routinely end up in server access logs, proxy logs, and monitoring tools.


3. Error Handling in Middleware

// Middleware that rejects connections

// Build an Error whose `data` payload is delivered to the client's
// connect_error handler alongside the message.
const rejectWith = (message, type) => {
  const err = new Error(message);
  err.data = { type }; // Additional data for client
  return err;
};

io.use((socket, next) => {
  const { token } = socket.handshake.auth;

  if (!token) {
    return next(rejectWith('No token provided', 'AUTH_REQUIRED'));
  }

  try {
    socket.user = jwt.verify(token, process.env.JWT_SECRET);
    next();
  } catch (error) {
    return next(rejectWith('Invalid token', 'INVALID_TOKEN'));
  }
});

// Client-side: handling middleware rejection
socket.on('connect_error', (err) => {
  console.log('Error message:', err.message);   // "No token provided"
  console.log('Error data:', err.data);          // { type: 'AUTH_REQUIRED' }

  // Custom handling based on error type
  switch (err.data?.type) {
    case 'AUTH_REQUIRED':
      redirectToLogin();
      break;
    case 'INVALID_TOKEN':
      clearStoredToken();
      redirectToLogin();
      break;
    default:
      // No custom data. This is either a transport-level failure, which the
      // client retries on its own (socket.active === true), or a rejection by
      // the server, which is NOT retried and needs a manual socket.connect().
      if (socket.active) {
        showErrorMessage('Connection failed. Retrying...');
      } else {
        showErrorMessage('Connection rejected by server.');
      }
  }
});

4. Chaining Multiple Middleware

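Middleware executes in the order it is registered. The first middleware that calls next(err) stops the chain; later middleware and the connection handler never run: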

// Middleware 1: Logging
io.use((socket, next) => {
  console.log(`[${new Date().toISOString()}] Connection attempt from ${socket.handshake.address}`);
  console.log('User-Agent:', socket.handshake.headers['user-agent']);
  socket.connectionTime = Date.now();
  next();
});

// Middleware 2: Rate limiting check
// Registered before authentication so a flood of connection attempts is
// rejected cheaply, before any JWT verification or database query runs.
// NOTE: behind a reverse proxy, handshake.address is the proxy's address;
// in that setup derive the client IP from a trusted forwarding header.
io.use((socket, next) => {
  const clientIP = socket.handshake.address;
  const connections = getActiveConnectionCount(clientIP);

  if (connections >= 5) {
    return next(new Error('Too many connections from this IP'));
  }
  next();
});

// Middleware 3: Authentication
io.use(async (socket, next) => {
  const token = socket.handshake.auth.token;
  if (!token) return next(new Error('Authentication required'));

  // Verify the token on its own so that only JWT problems are reported
  // to the client as 'Invalid token'.
  let decoded;
  try {
    decoded = jwt.verify(token, process.env.JWT_SECRET);
  } catch (err) {
    return next(new Error('Invalid token'));
  }

  try {
    // Never keep the password hash on the socket
    socket.user = await User.findById(decoded.id).select('-password');
    if (!socket.user) return next(new Error('User not found'));
    next();
  } catch (err) {
    // e.g. the database is unreachable: a server-side failure, not a bad token
    next(new Error('Authentication failed'));
  }
});

// Middleware 4: Authorization (account status check)
io.use((socket, next) => {
  if (socket.user.isBanned) {
    return next(new Error('Account suspended'));
  }
  next();
});

// Only reaches here if ALL middleware pass
io.on('connection', (socket) => {
  console.log(`Verified user ${socket.user.name} connected`);
});

5. Per-Namespace Middleware

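Each namespace can have its own middleware, independent of the main namespace. Note that io.use() registers middleware on the main namespace (/) only; it does not run for connections to /chat or any other namespace: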

// Main namespace middleware — applies to io (/) connections only
io.use(loggingMiddleware);

// Authorization guard: lets a socket through only when the user attached by
// authMiddleware has the admin role.
const requireAdmin = (socket, next) => {
  if (socket.user.role === 'admin') {
    return next();
  }
  next(new Error('Admin access required'));
};

// /chat namespace — needs user authentication
const chatNs = io.of('/chat');
chatNs.use(authMiddleware); // Only runs for /chat connections
chatNs.on('connection', (socket) => {
  // socket.user was attached by authMiddleware
  console.log(`[chat] ${socket.user.name} connected`);
});

// /admin namespace — authenticate first, then authorize
const adminNs = io.of('/admin');
adminNs
  .use(authMiddleware)  // First: verify they are a user
  .use(requireAdmin);   // Second: verify they are an admin
adminNs.on('connection', (socket) => {
  console.log(`[admin] Admin ${socket.user.name} connected`);
});

// /public namespace — no middleware registered, so no auth needed
const publicNs = io.of('/public');
publicNs.on('connection', (socket) => {
  console.log('[public] Anonymous user connected');
});

6. Logging Middleware

/**
 * Build a Socket.io middleware that logs each connection attempt and, once the
 * socket disconnects, how long it stayed connected.
 *
 * Note: middleware runs before the socket is connected. If a later middleware
 * rejects the socket, no 'disconnect' event fires, so no duration is logged
 * for rejected attempts.
 *
 * @param {string} [namespace='default'] - label used in the log prefix
 * @returns {(socket: object, next: Function) => void} the middleware
 */
function createLoggingMiddleware(namespace = 'default') {
  return (socket, next) => {
    const { address, headers, auth } = socket.handshake;
    const info = {
      namespace,
      socketId: socket.id,
      ip: address,
      userAgent: headers['user-agent'],
      timestamp: new Date().toISOString(),
      // The server fills in an empty object when the client sends no auth
      // payload, so a plain truthiness check would always report 'provided'.
      auth: auth && Object.keys(auth).length > 0 ? 'provided' : 'none'
    };

    console.log(`[WS:${namespace}] Connection attempt:`, JSON.stringify(info));

    socket.connectionTime = Date.now();

    // Track connection duration on disconnect
    socket.on('disconnect', (reason) => {
      const duration = Date.now() - socket.connectionTime;
      console.log(`[WS:${namespace}] Disconnected: ${socket.id}, ` +
        `reason: ${reason}, duration: ${duration}ms`);
    });

    next();
  };
}

// One logger per namespace: middleware on the main namespace does not cover /chat.
const mainLogger = createLoggingMiddleware('main');
const chatLogger = createLoggingMiddleware('chat');
io.use(mainLogger);
io.of('/chat').use(chatLogger);

7. Rate Limiting Concepts for WebSocket

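Unlike HTTP, where every request passes through the middleware stack, a WebSocket connection passes through Socket.io connection middleware only once; after that, the client can emit any number of events over it. WebSocket rate limiting therefore operates on two levels: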

Connection-Level Rate Limiting

// Limit concurrent connections per IP
const MAX_CONNECTIONS_PER_IP = 10;
const connectionCounts = new Map(); // ip -> number of currently connected sockets

// Check the limit in middleware, but do the bookkeeping in the 'connection'
// handler. Middleware runs before the socket is connected. If a LATER
// middleware (e.g. auth) rejects the socket, no 'disconnect' event ever fires,
// so a count taken here would never be released and the IP would eventually
// be locked out.
// The limit is approximate: handshakes that are in flight at the same time
// can all pass the check before any of them is counted.
io.use((socket, next) => {
  const ip = socket.handshake.address;
  if ((connectionCounts.get(ip) ?? 0) >= MAX_CONNECTIONS_PER_IP) {
    return next(new Error('Too many connections'));
  }
  next();
});

io.on('connection', (socket) => {
  const ip = socket.handshake.address;
  connectionCounts.set(ip, (connectionCounts.get(ip) ?? 0) + 1);

  socket.on('disconnect', () => {
    const remaining = (connectionCounts.get(ip) ?? 1) - 1;
    if (remaining <= 0) {
      connectionCounts.delete(ip); // don't keep entries for idle IPs
    } else {
      connectionCounts.set(ip, remaining);
    }
  });
});

Event-Level Rate Limiting

// Rate limit specific events per socket
/**
 * Create a sliding-window rate limiter for socket events.
 *
 * @param {number} maxEvents - maximum events accepted per window
 * @param {number} windowMs - window length in milliseconds
 * @returns {(socket: object, eventName: string, handler: Function) => void}
 *   Registers `handler(data, callback)` for `eventName` on `socket`. Each
 *   registration keeps its own window. Over-limit events are dropped. If the
 *   client supplied an acknowledgement function, it receives
 *   `{ error: 'Rate limited', retryAfter }` with retryAfter in whole seconds.
 *   `callback` is passed to the handler only when it is a function.
 */
function createEventRateLimiter(maxEvents, windowMs) {
  return (socket, eventName, handler) => {
    const timestamps = []; // accepted event times, oldest first

    socket.on(eventName, (data, callback) => {
      // A client may emit without an ack, or with a non-function last
      // argument; calling that would throw, so only keep real functions.
      const ack = typeof callback === 'function' ? callback : undefined;
      const now = Date.now();

      // Drop events that have aged out. An event exactly windowMs old is
      // outside the window, so a limited client always gets retryAfter > 0.
      while (timestamps.length > 0 && timestamps[0] <= now - windowMs) {
        timestamps.shift();
      }

      if (timestamps.length >= maxEvents) {
        // With maxEvents <= 0 the window is empty; fall back to a full window.
        const oldest = timestamps[0] ?? now;
        const retryAfterMs = oldest + windowMs - now;
        ack?.({
          error: 'Rate limited',
          retryAfter: Math.ceil(retryAfterMs / 1000)
        });
        return;
      }

      timestamps.push(now);
      handler(data, ack);
    });
  };
}

// Usage
io.on('connection', (socket) => {
  const rateLimited = createEventRateLimiter(20, 60000); // 20 events per minute

  rateLimited(socket, 'send-message', (data, callback) => {
    // Acknowledge only when the client actually supplied an ack function
    const reply = typeof callback === 'function' ? callback : () => {};

    // Untrusted payload: reading data.room on a missing payload would throw
    // inside this listener.
    if (typeof data?.room !== 'string' || typeof data.text !== 'string') {
      return reply({ error: 'Invalid payload' });
    }

    // Only let sockets post into rooms they have actually joined
    if (!socket.rooms.has(data.room)) {
      return reply({ error: 'Not a member of this room' });
    }

    // Process message normally
    io.to(data.room).emit('new-message', {
      text: data.text,
      user: socket.user.name
    });
    reply({ status: 'ok' });
  });
});

8. Disconnection Handling and Cleanup

io.on('connection', (socket) => {
  const userId = socket.userId;

  // Track user's online status
  // NOTE(review): if updateUserStatus returns a promise, its rejection is not
  // handled here — confirm against its implementation.
  updateUserStatus(userId, 'online');
  socket.join(`user:${userId}`);

  // Handle graceful disconnect
  socket.on('disconnect', (reason) => {
    console.log(`User ${userId} disconnected: ${reason}`);

    // The user was last seen now, not when the grace period ends
    const lastSeen = new Date();

    // Wait briefly to see if they reconnect (tab refresh, network blip)
    setTimeout(async () => {
      // Nothing awaits this timer callback, so any rejection (adapter,
      // database) would become an unhandled rejection; handle it here.
      try {
        const sockets = await io.in(`user:${userId}`).fetchSockets();

        // Only mark offline if no other connections exist for this user
        if (sockets.length > 0) return;

        await updateUserStatus(userId, 'offline');

        // Notify friends/contacts
        const friends = await getFriends(userId);
        for (const friendId of friends) {
          io.to(`user:${friendId}`).emit('friend-offline', { userId, lastSeen });
        }
      } catch (err) {
        console.error(`Presence update failed for user ${userId}:`, err);
      }
    }, 5000); // 5 second grace period
  });

  // Handle disconnecting event (fires before disconnect)
  socket.on('disconnecting', () => {
    // socket.rooms is still available here (not in 'disconnect')
    const rooms = Array.from(socket.rooms).filter(r => r !== socket.id);
    console.log(`User leaving rooms:`, rooms);

    rooms.forEach(room => {
      socket.to(room).emit('user-leaving', {
        userId,
        room
      });
    });
  });
});

9. Real Example: Authenticated Real-Time Notifications

// notifications.js — Complete notification system
const express = require('express');
const http = require('http');
const { Server } = require('socket.io');
const jwt = require('jsonwebtoken');

const app = express();
app.use(express.json()); // parse JSON request bodies for the REST routes

const server = http.createServer(app);
// Wide-open CORS keeps the demo simple; restrict `origin` in production.
const io = new Server(server, { cors: { origin: '*' } });

// ---------- MIDDLEWARE STACK ----------

// Resolve the JWT signing secret once. The 'secret' fallback exists only for
// local demos: if it were used in production, anyone could mint valid tokens,
// so refuse to start there without a real secret.
const JWT_SECRET =
  process.env.JWT_SECRET ||
  (process.env.NODE_ENV === 'production' ? undefined : 'secret');
if (!JWT_SECRET) {
  throw new Error('JWT_SECRET must be set in production');
}

// Middleware 1: Logging
io.use((socket, next) => {
  console.log(`[${new Date().toISOString()}] WS connection attempt`);
  next();
});

// Middleware 2: Authentication
io.use(async (socket, next) => {
  const token = socket.handshake.auth.token;

  if (!token) {
    return next(new Error('Authentication required'));
  }

  try {
    const decoded = jwt.verify(token, JWT_SECRET);

    // The id names the user's personal room. Without this check every token
    // lacking an id would join the same "user:undefined" room and receive
    // each other's notifications.
    if (decoded?.id == null) {
      return next(new Error('Invalid token payload'));
    }

    // In production, query the database:
    // socket.user = await User.findById(decoded.id);
    socket.user = { id: decoded.id, name: decoded.name, role: decoded.role };
    socket.userId = decoded.id;
    next();
  } catch (err) {
    next(new Error('Invalid or expired token'));
  }
});

// ---------- CONNECTION HANDLER ----------

// userId -> Set of socketIds. Not read elsewhere in this file (the
// `user:<id>` room already reaches every tab); presumably kept for online
// presence lookups.
const userSockets = new Map();

io.on('connection', (socket) => {
  const { userId, user } = socket;
  console.log(`${user.name} connected (${socket.id})`);

  // Track this user's sockets (they may have multiple tabs open)
  if (!userSockets.has(userId)) {
    userSockets.set(userId, new Set());
  }
  userSockets.get(userId).add(socket.id);

  // Join personal notification room
  socket.join(`user:${userId}`);

  // Send unread notification count on connect
  socket.emit('unread-count', {
    count: getUnreadCount(userId)
  });

  // Mark notifications as read
  socket.on('mark-read', async (notificationIds) => {
    // Client-supplied payload: accept only an array of ids
    if (!Array.isArray(notificationIds)) return;

    // Nothing awaits this listener, so a failure here would surface as an
    // unhandled rejection; log it and skip the count update instead.
    try {
      await markNotificationsRead(userId, notificationIds);
    } catch (err) {
      console.error(`mark-read failed for user ${userId}:`, err);
      return;
    }

    // Update count across all user's devices
    io.to(`user:${userId}`).emit('unread-count', {
      count: getUnreadCount(userId)
    });
  });

  // Cleanup on disconnect
  socket.on('disconnect', () => {
    const sockets = userSockets.get(userId);
    sockets?.delete(socket.id);
    if (sockets?.size === 0) {
      userSockets.delete(userId);
    }
    console.log(`${user.name} disconnected`);
  });
});

// ---------- NOTIFICATION API (used by other services) ----------

// Last notification id handed out. Ids stay numeric and time-based like plain
// Date.now() values, but are forced to be strictly increasing. Two
// notifications created in the same millisecond would otherwise share an id,
// and 'mark-read' on one would also mark the other.
let lastNotificationId = 0;

/**
 * Store a notification for a user and push it, together with the updated
 * unread count, to the user's personal `user:<userId>` room.
 * Can be called from any route handler or service.
 *
 * @param {*} userId - recipient's user id (the value used in the room name)
 * @param {object} notification - caller-supplied fields (type, title, body, link, ...)
 */
function sendNotification(userId, notification) {
  lastNotificationId = Math.max(lastNotificationId + 1, Date.now());

  const notif = {
    id: lastNotificationId,
    ...notification,
    read: false,
    createdAt: new Date()
  };

  // Store in database (simplified)
  storeNotification(userId, notif);

  // Push to user if they are online (emitting to an empty room is a no-op)
  io.to(`user:${userId}`).emit('notification', notif);
  io.to(`user:${userId}`).emit('unread-count', {
    count: getUnreadCount(userId)
  });
}

// Export for use in route handlers
module.exports = { app, server, io, sendNotification };

// ---------- EXAMPLE: REST endpoint triggers notification ----------

app.post('/api/posts/:id/comment', async (req, res) => {
  const { id: postId } = req.params;
  // ... save comment to database ...

  // Simplified stand-ins for the records a real handler would load
  const post = { authorId: 'user123', title: 'My Post' };
  const commenter = { name: 'Alice' };

  // Real-time notification for the post's author
  const notification = {
    type: 'comment',
    title: 'New comment on your post',
    body: `${commenter.name} commented on "${post.title}"`,
    link: `/posts/${postId}`
  };
  sendNotification(post.authorId, notification);

  res.json({ success: true });
});

// ---------- HELPERS (simplified, use database in production) ----------
const notifications = new Map(); // userId -> array of notification objects

/** Append a notification to the user's in-memory list, creating it if needed. */
function storeNotification(userId, notif) {
  if (!notifications.has(userId)) notifications.set(userId, []);
  notifications.get(userId).push(notif);
}

/** Count the user's notifications whose `read` flag is not set. */
function getUnreadCount(userId) {
  return (notifications.get(userId) || []).filter(n => !n.read).length;
}

/**
 * Flag the listed notification ids as read for this user.
 * Unknown ids are ignored. Anything other than an array marks nothing: a string
 * would otherwise be substring-matched by `includes`, and `undefined` would throw.
 */
function markNotificationsRead(userId, ids) {
  if (!Array.isArray(ids) || ids.length === 0) return;
  const wanted = new Set(ids); // O(1) lookup per notification
  for (const n of notifications.get(userId) || []) {
    if (wanted.has(n.id)) n.read = true;
  }
}

// Honour a platform-provided PORT (containers, PaaS); default to 3000 as before.
const PORT = Number(process.env.PORT) || 3000;
server.listen(PORT, () => console.log(`Notification server on :${PORT}`));

Key Takeaways

  1. Socket.io middleware runs once per connection attempt (per namespace), registered with io.use((socket, next) => {...}) or namespace.use(...)
  2. JWT authentication is the most common middleware pattern -- extract from socket.handshake.auth.token
  3. Attach user data to socket.user so all event handlers can access it
  4. Call next(new Error('message')) to reject connections; clients receive connect_error
  5. Namespace middleware allows different auth requirements per feature area
  6. Rate limiting for WebSocket operates at both connection and event levels
  7. Handle the disconnecting event (not disconnect) if you need access to socket.rooms
  8. Always implement a grace period before marking users offline to handle reconnection

Explain-It Challenge

Scenario: Your application uses Socket.io for real-time features. A security audit reveals that some users are able to emit events to rooms they have not joined, and there is no validation on the data being sent through socket events.

Design a comprehensive middleware and event-validation strategy that: (1) authenticates users on connection, (2) checks room membership before allowing room-targeted events, and (3) validates the shape and size of event payloads. Explain where each check should live (middleware vs. event handler) and why.


<< Previous: 3.15.d — Rooms & Namespaces | Next: 3.15 — Exercise Questions >>