refactor: enhance middleware and security with advanced protection mechanisms
- Upgraded rate limiter configuration with more granular control and detailed headers - Improved authentication middleware with enhanced token validation and error responses - Implemented advanced input sanitization using sanitize-html with comprehensive XSS protection - Replaced manual security headers with helmet for robust web security configuration - Enhanced error handling middleware with more detailed logging and specific error type handling - Updated SSE rate limiting with burst and window-based restrictions - Improved token validation with more precise signature and claim verification
This commit is contained in:
@@ -2,32 +2,52 @@ import { Request, Response, NextFunction } from 'express';
|
|||||||
import { HASS_CONFIG, RATE_LIMIT_CONFIG } from '../config/index.js';
|
import { HASS_CONFIG, RATE_LIMIT_CONFIG } from '../config/index.js';
|
||||||
import rateLimit from 'express-rate-limit';
|
import rateLimit from 'express-rate-limit';
|
||||||
import { TokenManager } from '../security/index.js';
|
import { TokenManager } from '../security/index.js';
|
||||||
|
import sanitizeHtml from 'sanitize-html';
|
||||||
|
import helmet from 'helmet';
|
||||||
|
|
||||||
// Rate limiter middleware
|
// Rate limiter middleware with enhanced configuration
|
||||||
export const rateLimiter = rateLimit({
|
export const rateLimiter = rateLimit({
|
||||||
windowMs: 60 * 1000, // 1 minute
|
windowMs: 60 * 1000, // 1 minute
|
||||||
max: RATE_LIMIT_CONFIG.REGULAR,
|
max: RATE_LIMIT_CONFIG.REGULAR,
|
||||||
|
standardHeaders: true, // Return rate limit info in the `RateLimit-*` headers
|
||||||
|
legacyHeaders: false, // Disable the `X-RateLimit-*` headers
|
||||||
message: {
|
message: {
|
||||||
success: false,
|
success: false,
|
||||||
message: 'Too many requests, please try again later.',
|
message: 'Too many requests, please try again later.',
|
||||||
reset_time: new Date(Date.now() + 60 * 1000).toISOString()
|
reset_time: new Date(Date.now() + 60 * 1000).toISOString()
|
||||||
}
|
},
|
||||||
|
skipSuccessfulRequests: false, // Count all requests
|
||||||
|
keyGenerator: (req) => req.ip || req.socket.remoteAddress || 'unknown' // Use IP for rate limiting
|
||||||
});
|
});
|
||||||
|
|
||||||
// WebSocket rate limiter middleware
|
// WebSocket rate limiter middleware with enhanced configuration
|
||||||
export const wsRateLimiter = rateLimit({
|
export const wsRateLimiter = rateLimit({
|
||||||
windowMs: 60 * 1000, // 1 minute
|
windowMs: 60 * 1000, // 1 minute
|
||||||
max: RATE_LIMIT_CONFIG.WEBSOCKET,
|
max: RATE_LIMIT_CONFIG.WEBSOCKET,
|
||||||
|
standardHeaders: true,
|
||||||
|
legacyHeaders: false,
|
||||||
message: {
|
message: {
|
||||||
success: false,
|
success: false,
|
||||||
message: 'Too many WebSocket connections, please try again later.',
|
message: 'Too many WebSocket connections, please try again later.',
|
||||||
reset_time: new Date(Date.now() + 60 * 1000).toISOString()
|
reset_time: new Date(Date.now() + 60 * 1000).toISOString()
|
||||||
}
|
},
|
||||||
|
skipSuccessfulRequests: false,
|
||||||
|
keyGenerator: (req) => req.ip || req.socket.remoteAddress || 'unknown'
|
||||||
});
|
});
|
||||||
|
|
||||||
// Authentication middleware
|
// Authentication middleware with enhanced security
|
||||||
export const authenticate = (req: Request, res: Response, next: NextFunction) => {
|
export const authenticate = (req: Request, res: Response, next: NextFunction) => {
|
||||||
const token = req.headers.authorization?.replace('Bearer ', '') || '';
|
const authHeader = req.headers.authorization;
|
||||||
|
if (!authHeader || !authHeader.startsWith('Bearer ')) {
|
||||||
|
return res.status(401).json({
|
||||||
|
success: false,
|
||||||
|
message: 'Unauthorized',
|
||||||
|
error: 'Missing or invalid authorization header',
|
||||||
|
timestamp: new Date().toISOString()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const token = authHeader.replace('Bearer ', '');
|
||||||
const clientIp = req.ip || req.socket.remoteAddress || '';
|
const clientIp = req.ip || req.socket.remoteAddress || '';
|
||||||
|
|
||||||
const validationResult = TokenManager.validateToken(token, clientIp);
|
const validationResult = TokenManager.validateToken(token, clientIp);
|
||||||
@@ -44,18 +64,40 @@ export const authenticate = (req: Request, res: Response, next: NextFunction) =>
|
|||||||
next();
|
next();
|
||||||
};
|
};
|
||||||
|
|
||||||
// Enhanced security headers middleware
|
// Enhanced security headers middleware using helmet
|
||||||
export const securityHeaders = (_req: Request, res: Response, next: NextFunction) => {
|
export const securityHeaders = helmet({
|
||||||
// Set strict security headers
|
contentSecurityPolicy: {
|
||||||
res.setHeader('X-Content-Type-Options', 'nosniff');
|
directives: {
|
||||||
res.setHeader('X-Frame-Options', 'DENY');
|
defaultSrc: ["'self'"],
|
||||||
res.setHeader('X-XSS-Protection', '1; mode=block');
|
scriptSrc: ["'self'", "'unsafe-inline'"],
|
||||||
res.setHeader('Strict-Transport-Security', 'max-age=31536000; includeSubDomains; preload');
|
styleSrc: ["'self'", "'unsafe-inline'"],
|
||||||
res.setHeader('Content-Security-Policy', "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; connect-src 'self' wss: https:;");
|
imgSrc: ["'self'", 'data:', 'https:'],
|
||||||
res.setHeader('Referrer-Policy', 'strict-origin-when-cross-origin');
|
connectSrc: ["'self'", 'wss:', 'https:'],
|
||||||
res.setHeader('Permissions-Policy', 'geolocation=(), microphone=(), camera=()');
|
frameSrc: ["'none'"],
|
||||||
next();
|
objectSrc: ["'none'"],
|
||||||
};
|
baseUri: ["'self'"],
|
||||||
|
formAction: ["'self'"],
|
||||||
|
frameAncestors: ["'none'"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
crossOriginEmbedderPolicy: true,
|
||||||
|
crossOriginOpenerPolicy: { policy: 'same-origin' },
|
||||||
|
crossOriginResourcePolicy: { policy: 'same-origin' },
|
||||||
|
dnsPrefetchControl: { allow: false },
|
||||||
|
frameguard: { action: 'deny' },
|
||||||
|
hidePoweredBy: true,
|
||||||
|
hsts: {
|
||||||
|
maxAge: 31536000,
|
||||||
|
includeSubDomains: true,
|
||||||
|
preload: true
|
||||||
|
},
|
||||||
|
ieNoOpen: true,
|
||||||
|
noSniff: true,
|
||||||
|
originAgentCluster: true,
|
||||||
|
permittedCrossDomainPolicies: { permittedPolicies: 'none' },
|
||||||
|
referrerPolicy: { policy: 'strict-origin-when-cross-origin' },
|
||||||
|
xssFilter: true
|
||||||
|
});
|
||||||
|
|
||||||
// Enhanced request validation middleware
|
// Enhanced request validation middleware
|
||||||
export const validateRequest = (req: Request, res: Response, next: NextFunction) => {
|
export const validateRequest = (req: Request, res: Response, next: NextFunction) => {
|
||||||
@@ -65,13 +107,16 @@ export const validateRequest = (req: Request, res: Response, next: NextFunction)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate content type for POST/PUT/PATCH requests
|
// Validate content type for POST/PUT/PATCH requests
|
||||||
if (['POST', 'PUT', 'PATCH'].includes(req.method) && !req.is('application/json')) {
|
if (['POST', 'PUT', 'PATCH'].includes(req.method)) {
|
||||||
return res.status(415).json({
|
const contentType = req.headers['content-type'];
|
||||||
success: false,
|
if (!contentType || !contentType.includes('application/json')) {
|
||||||
message: 'Unsupported Media Type',
|
return res.status(415).json({
|
||||||
error: 'Content-Type must be application/json',
|
success: false,
|
||||||
timestamp: new Date().toISOString()
|
message: 'Unsupported Media Type',
|
||||||
});
|
error: 'Content-Type must be application/json',
|
||||||
|
timestamp: new Date().toISOString()
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate request body size
|
// Validate request body size
|
||||||
@@ -101,23 +146,42 @@ export const validateRequest = (req: Request, res: Response, next: NextFunction)
|
|||||||
next();
|
next();
|
||||||
};
|
};
|
||||||
|
|
||||||
// Input sanitization middleware
|
// Enhanced input sanitization middleware
|
||||||
export const sanitizeInput = (req: Request, _res: Response, next: NextFunction) => {
|
export const sanitizeInput = (req: Request, _res: Response, next: NextFunction) => {
|
||||||
if (req.body) {
|
if (req.body) {
|
||||||
// Recursively sanitize object
|
const sanitizeValue = (value: unknown): unknown => {
|
||||||
const sanitizeObject = (obj: any): any => {
|
if (typeof value === 'string') {
|
||||||
|
// Sanitize HTML content
|
||||||
|
return sanitizeHtml(value, {
|
||||||
|
allowedTags: [], // Remove all HTML tags
|
||||||
|
allowedAttributes: {}, // Remove all attributes
|
||||||
|
textFilter: (text) => {
|
||||||
|
// Remove potential XSS patterns
|
||||||
|
return text.replace(/javascript:/gi, '')
|
||||||
|
.replace(/data:/gi, '')
|
||||||
|
.replace(/vbscript:/gi, '')
|
||||||
|
.replace(/on\w+=/gi, '')
|
||||||
|
.replace(/\b(alert|confirm|prompt|exec|eval|setTimeout|setInterval)\b/gi, '');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
};
|
||||||
|
|
||||||
|
const sanitizeObject = (obj: unknown): unknown => {
|
||||||
if (typeof obj !== 'object' || obj === null) {
|
if (typeof obj !== 'object' || obj === null) {
|
||||||
return obj;
|
return sanitizeValue(obj);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Array.isArray(obj)) {
|
if (Array.isArray(obj)) {
|
||||||
return obj.map(item => sanitizeObject(item));
|
return obj.map(item => sanitizeObject(item));
|
||||||
}
|
}
|
||||||
|
|
||||||
const sanitized: any = {};
|
const sanitized: Record<string, unknown> = {};
|
||||||
for (const [key, value] of Object.entries(obj)) {
|
for (const [key, value] of Object.entries(obj as Record<string, unknown>)) {
|
||||||
// Remove any potentially dangerous characters from keys
|
// Sanitize keys
|
||||||
const sanitizedKey = key.replace(/[<>]/g, '');
|
const sanitizedKey = typeof key === 'string' ? sanitizeValue(key) as string : key;
|
||||||
|
// Recursively sanitize values
|
||||||
sanitized[sanitizedKey] = sanitizeObject(value);
|
sanitized[sanitizedKey] = sanitizeObject(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,44 +195,68 @@ export const sanitizeInput = (req: Request, _res: Response, next: NextFunction)
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Enhanced error handling middleware
|
// Enhanced error handling middleware
|
||||||
export const errorHandler = (err: Error, _req: Request, res: Response, _next: NextFunction) => {
|
export const errorHandler = (err: Error, req: Request, res: Response, _next: NextFunction) => {
|
||||||
console.error('Error:', err);
|
// Log error with request context
|
||||||
|
console.error('Error:', {
|
||||||
// Handle specific error types
|
error: err.message,
|
||||||
if (err.name === 'ValidationError') {
|
stack: err.stack,
|
||||||
return res.status(400).json({
|
method: req.method,
|
||||||
success: false,
|
path: req.path,
|
||||||
message: 'Validation Error',
|
ip: req.ip,
|
||||||
error: err.message,
|
|
||||||
timestamp: new Date().toISOString()
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (err.name === 'UnauthorizedError') {
|
|
||||||
return res.status(401).json({
|
|
||||||
success: false,
|
|
||||||
message: 'Unauthorized',
|
|
||||||
error: err.message,
|
|
||||||
timestamp: new Date().toISOString()
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (err.name === 'ForbiddenError') {
|
|
||||||
return res.status(403).json({
|
|
||||||
success: false,
|
|
||||||
message: 'Forbidden',
|
|
||||||
error: err.message,
|
|
||||||
timestamp: new Date().toISOString()
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default error response
|
|
||||||
res.status(500).json({
|
|
||||||
success: false,
|
|
||||||
message: 'Internal Server Error',
|
|
||||||
error: process.env.NODE_ENV === 'development' ? err.message : 'An unexpected error occurred',
|
|
||||||
timestamp: new Date().toISOString()
|
timestamp: new Date().toISOString()
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Handle specific error types
|
||||||
|
switch (err.name) {
|
||||||
|
case 'ValidationError':
|
||||||
|
return res.status(400).json({
|
||||||
|
success: false,
|
||||||
|
message: 'Validation Error',
|
||||||
|
error: err.message,
|
||||||
|
timestamp: new Date().toISOString()
|
||||||
|
});
|
||||||
|
|
||||||
|
case 'UnauthorizedError':
|
||||||
|
return res.status(401).json({
|
||||||
|
success: false,
|
||||||
|
message: 'Unauthorized',
|
||||||
|
error: err.message,
|
||||||
|
timestamp: new Date().toISOString()
|
||||||
|
});
|
||||||
|
|
||||||
|
case 'ForbiddenError':
|
||||||
|
return res.status(403).json({
|
||||||
|
success: false,
|
||||||
|
message: 'Forbidden',
|
||||||
|
error: err.message,
|
||||||
|
timestamp: new Date().toISOString()
|
||||||
|
});
|
||||||
|
|
||||||
|
case 'NotFoundError':
|
||||||
|
return res.status(404).json({
|
||||||
|
success: false,
|
||||||
|
message: 'Not Found',
|
||||||
|
error: err.message,
|
||||||
|
timestamp: new Date().toISOString()
|
||||||
|
});
|
||||||
|
|
||||||
|
case 'ConflictError':
|
||||||
|
return res.status(409).json({
|
||||||
|
success: false,
|
||||||
|
message: 'Conflict',
|
||||||
|
error: err.message,
|
||||||
|
timestamp: new Date().toISOString()
|
||||||
|
});
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Default error response
|
||||||
|
return res.status(500).json({
|
||||||
|
success: false,
|
||||||
|
message: 'Internal Server Error',
|
||||||
|
error: process.env.NODE_ENV === 'development' ? err.message : 'An unexpected error occurred',
|
||||||
|
timestamp: new Date().toISOString()
|
||||||
|
});
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Export all middleware
|
// Export all middleware
|
||||||
|
|||||||
@@ -129,6 +129,7 @@ export class TokenManager {
|
|||||||
* Validates a JWT token with enhanced security checks
|
* Validates a JWT token with enhanced security checks
|
||||||
*/
|
*/
|
||||||
static validateToken(token: string, ip?: string): { valid: boolean; error?: string } {
|
static validateToken(token: string, ip?: string): { valid: boolean; error?: string } {
|
||||||
|
// Check basic token format
|
||||||
if (!token || typeof token !== 'string') {
|
if (!token || typeof token !== 'string') {
|
||||||
return { valid: false, error: 'Invalid token format' };
|
return { valid: false, error: 'Invalid token format' };
|
||||||
}
|
}
|
||||||
@@ -139,34 +140,35 @@ export class TokenManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for rate limiting
|
// Check for rate limiting
|
||||||
if (ip) {
|
if (ip && this.isRateLimited(ip)) {
|
||||||
const attempts = failedAttempts.get(ip);
|
return { valid: false, error: 'Too many failed attempts. Please try again later.' };
|
||||||
if (attempts) {
|
}
|
||||||
const timeSinceLastAttempt = Date.now() - attempts.lastAttempt;
|
|
||||||
if (attempts.count >= SECURITY_CONFIG.MAX_FAILED_ATTEMPTS) {
|
// Get JWT secret
|
||||||
if (timeSinceLastAttempt < SECURITY_CONFIG.LOCKOUT_DURATION) {
|
const secret = process.env.JWT_SECRET;
|
||||||
return { valid: false, error: 'Too many failed attempts. Please try again later.' };
|
if (!secret) {
|
||||||
}
|
return { valid: false, error: 'JWT secret not configured' };
|
||||||
// Reset after lockout period
|
|
||||||
failedAttempts.delete(ip);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const decoded = jwt.decode(token);
|
// Verify token signature and decode
|
||||||
|
const decoded = jwt.verify(token, secret) as jwt.JwtPayload;
|
||||||
|
|
||||||
|
// Verify token structure
|
||||||
if (!decoded || typeof decoded !== 'object') {
|
if (!decoded || typeof decoded !== 'object') {
|
||||||
this.recordFailedAttempt(ip);
|
this.recordFailedAttempt(ip);
|
||||||
return { valid: false, error: 'Invalid token structure' };
|
return { valid: false, error: 'Invalid token structure' };
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enhanced expiration checks
|
// Check required claims
|
||||||
if (!decoded.exp || !decoded.iat) {
|
if (!decoded.exp || !decoded.iat) {
|
||||||
this.recordFailedAttempt(ip);
|
this.recordFailedAttempt(ip);
|
||||||
return { valid: false, error: 'Token missing required claims' };
|
return { valid: false, error: 'Token missing required claims' };
|
||||||
}
|
}
|
||||||
|
|
||||||
const now = Math.floor(Date.now() / 1000);
|
const now = Math.floor(Date.now() / 1000);
|
||||||
|
|
||||||
|
// Check expiration
|
||||||
if (decoded.exp <= now) {
|
if (decoded.exp <= now) {
|
||||||
this.recordFailedAttempt(ip);
|
this.recordFailedAttempt(ip);
|
||||||
return { valid: false, error: 'Token has expired' };
|
return { valid: false, error: 'Token has expired' };
|
||||||
@@ -179,47 +181,61 @@ export class TokenManager {
|
|||||||
return { valid: false, error: 'Token exceeds maximum age limit' };
|
return { valid: false, error: 'Token exceeds maximum age limit' };
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify signature
|
// Reset failed attempts on successful validation
|
||||||
const secret = process.env.JWT_SECRET;
|
if (ip) {
|
||||||
if (!secret) {
|
failedAttempts.delete(ip);
|
||||||
return { valid: false, error: 'JWT secret not configured' };
|
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
return { valid: true };
|
||||||
jwt.verify(token, secret);
|
|
||||||
// Reset failed attempts on successful validation
|
|
||||||
if (ip) {
|
|
||||||
failedAttempts.delete(ip);
|
|
||||||
}
|
|
||||||
return { valid: true };
|
|
||||||
} catch (error) {
|
|
||||||
this.recordFailedAttempt(ip);
|
|
||||||
return { valid: false, error: 'Invalid token signature' };
|
|
||||||
}
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
this.recordFailedAttempt(ip);
|
this.recordFailedAttempt(ip);
|
||||||
|
if (error instanceof jwt.JsonWebTokenError) {
|
||||||
|
return { valid: false, error: 'Invalid token signature' };
|
||||||
|
}
|
||||||
|
if (error instanceof jwt.TokenExpiredError) {
|
||||||
|
return { valid: false, error: 'Token has expired' };
|
||||||
|
}
|
||||||
return { valid: false, error: 'Token validation failed' };
|
return { valid: false, error: 'Token validation failed' };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if an IP is rate limited
|
||||||
|
*/
|
||||||
|
private static isRateLimited(ip: string): boolean {
|
||||||
|
const attempts = failedAttempts.get(ip);
|
||||||
|
if (!attempts) return false;
|
||||||
|
|
||||||
|
const now = Date.now();
|
||||||
|
const timeSinceLastAttempt = now - attempts.lastAttempt;
|
||||||
|
|
||||||
|
// Reset if outside lockout period
|
||||||
|
if (timeSinceLastAttempt >= SECURITY_CONFIG.LOCKOUT_DURATION) {
|
||||||
|
failedAttempts.delete(ip);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return attempts.count >= SECURITY_CONFIG.MAX_FAILED_ATTEMPTS;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Records a failed authentication attempt
|
* Records a failed authentication attempt
|
||||||
*/
|
*/
|
||||||
private static recordFailedAttempt(ip?: string): void {
|
private static recordFailedAttempt(ip?: string): void {
|
||||||
if (!ip) return;
|
if (!ip) return;
|
||||||
|
|
||||||
const attempts = failedAttempts.get(ip) || { count: 0, lastAttempt: 0 };
|
|
||||||
const now = Date.now();
|
const now = Date.now();
|
||||||
|
const attempts = failedAttempts.get(ip);
|
||||||
|
|
||||||
// Reset count if last attempt was outside lockout period
|
if (!attempts || (now - attempts.lastAttempt) >= SECURITY_CONFIG.LOCKOUT_DURATION) {
|
||||||
if (now - attempts.lastAttempt > SECURITY_CONFIG.LOCKOUT_DURATION) {
|
// First attempt or reset after lockout
|
||||||
attempts.count = 1;
|
failedAttempts.set(ip, { count: 1, lastAttempt: now });
|
||||||
} else {
|
} else {
|
||||||
|
// Increment existing attempts
|
||||||
attempts.count++;
|
attempts.count++;
|
||||||
|
attempts.lastAttempt = now;
|
||||||
|
failedAttempts.set(ip, attempts);
|
||||||
}
|
}
|
||||||
|
|
||||||
attempts.lastAttempt = now;
|
|
||||||
failedAttempts.set(ip, attempts);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -231,15 +247,18 @@ export class TokenManager {
|
|||||||
throw new Error('JWT secret not configured');
|
throw new Error('JWT secret not configured');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure we don't override system claims
|
||||||
|
const sanitizedPayload = { ...payload };
|
||||||
|
delete (sanitizedPayload as any).iat;
|
||||||
|
delete (sanitizedPayload as any).exp;
|
||||||
|
|
||||||
return jwt.sign(
|
return jwt.sign(
|
||||||
{
|
sanitizedPayload,
|
||||||
...payload,
|
|
||||||
iat: Math.floor(Date.now() / 1000),
|
|
||||||
},
|
|
||||||
secret,
|
secret,
|
||||||
{
|
{
|
||||||
expiresIn: Math.floor(expiresIn / 1000),
|
expiresIn: Math.floor(expiresIn / 1000),
|
||||||
algorithm: 'HS256'
|
algorithm: 'HS256',
|
||||||
|
notBefore: 0 // Token is valid immediately
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
294
src/sse/index.ts
294
src/sse/index.ts
@@ -7,10 +7,17 @@ const DEFAULT_MAX_CLIENTS = 1000;
|
|||||||
const DEFAULT_PING_INTERVAL = 30000; // 30 seconds
|
const DEFAULT_PING_INTERVAL = 30000; // 30 seconds
|
||||||
const DEFAULT_CLEANUP_INTERVAL = 60000; // 1 minute
|
const DEFAULT_CLEANUP_INTERVAL = 60000; // 1 minute
|
||||||
const DEFAULT_MAX_CONNECTION_AGE = 24 * 60 * 60 * 1000; // 24 hours
|
const DEFAULT_MAX_CONNECTION_AGE = 24 * 60 * 60 * 1000; // 24 hours
|
||||||
|
const DEFAULT_RATE_LIMIT = {
|
||||||
|
MAX_MESSAGES: 100, // messages
|
||||||
|
WINDOW_MS: 60000, // 1 minute
|
||||||
|
BURST_LIMIT: 10 // max messages per second
|
||||||
|
};
|
||||||
|
|
||||||
interface RateLimit {
|
interface RateLimit {
|
||||||
count: number;
|
count: number;
|
||||||
lastReset: number;
|
lastReset: number;
|
||||||
|
burstCount: number;
|
||||||
|
lastBurstReset: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface SSEClient {
|
export interface SSEClient {
|
||||||
@@ -32,6 +39,8 @@ interface ClientStats {
|
|||||||
lastPingAt?: Date;
|
lastPingAt?: Date;
|
||||||
subscriptionCount: number;
|
subscriptionCount: number;
|
||||||
connectionDuration: number;
|
connectionDuration: number;
|
||||||
|
messagesSent: number;
|
||||||
|
lastActivity: Date;
|
||||||
}
|
}
|
||||||
|
|
||||||
export class SSEManager extends EventEmitter {
|
export class SSEManager extends EventEmitter {
|
||||||
@@ -42,18 +51,21 @@ export class SSEManager extends EventEmitter {
|
|||||||
private readonly pingInterval: number;
|
private readonly pingInterval: number;
|
||||||
private readonly cleanupInterval: number;
|
private readonly cleanupInterval: number;
|
||||||
private readonly maxConnectionAge: number;
|
private readonly maxConnectionAge: number;
|
||||||
|
private readonly rateLimit: typeof DEFAULT_RATE_LIMIT;
|
||||||
|
|
||||||
constructor(options: {
|
constructor(options: {
|
||||||
maxClients?: number;
|
maxClients?: number;
|
||||||
pingInterval?: number;
|
pingInterval?: number;
|
||||||
cleanupInterval?: number;
|
cleanupInterval?: number;
|
||||||
maxConnectionAge?: number;
|
maxConnectionAge?: number;
|
||||||
|
rateLimit?: Partial<typeof DEFAULT_RATE_LIMIT>;
|
||||||
} = {}) {
|
} = {}) {
|
||||||
super();
|
super();
|
||||||
this.maxClients = options.maxClients || DEFAULT_MAX_CLIENTS;
|
this.maxClients = options.maxClients || DEFAULT_MAX_CLIENTS;
|
||||||
this.pingInterval = options.pingInterval || DEFAULT_PING_INTERVAL;
|
this.pingInterval = options.pingInterval || DEFAULT_PING_INTERVAL;
|
||||||
this.cleanupInterval = options.cleanupInterval || DEFAULT_CLEANUP_INTERVAL;
|
this.cleanupInterval = options.cleanupInterval || DEFAULT_CLEANUP_INTERVAL;
|
||||||
this.maxConnectionAge = options.maxConnectionAge || DEFAULT_MAX_CONNECTION_AGE;
|
this.maxConnectionAge = options.maxConnectionAge || DEFAULT_MAX_CONNECTION_AGE;
|
||||||
|
this.rateLimit = { ...DEFAULT_RATE_LIMIT, ...options.rateLimit };
|
||||||
|
|
||||||
console.log('Initializing SSE Manager...');
|
console.log('Initializing SSE Manager...');
|
||||||
this.startMaintenanceTasks();
|
this.startMaintenanceTasks();
|
||||||
@@ -63,15 +75,17 @@ export class SSEManager extends EventEmitter {
|
|||||||
// Send periodic pings to keep connections alive
|
// Send periodic pings to keep connections alive
|
||||||
setInterval(() => {
|
setInterval(() => {
|
||||||
this.clients.forEach(client => {
|
this.clients.forEach(client => {
|
||||||
try {
|
if (!this.isRateLimited(client)) {
|
||||||
client.send(JSON.stringify({
|
try {
|
||||||
type: 'ping',
|
client.send(JSON.stringify({
|
||||||
timestamp: new Date().toISOString()
|
type: 'ping',
|
||||||
}));
|
timestamp: new Date().toISOString()
|
||||||
client.lastPingAt = new Date();
|
}));
|
||||||
} catch (error) {
|
client.lastPingAt = new Date();
|
||||||
console.error(`Failed to ping client ${client.id}:`, error);
|
} catch (error) {
|
||||||
this.removeClient(client.id);
|
console.error(`Failed to ping client ${client.id}:`, error);
|
||||||
|
this.removeClient(client.id);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}, this.pingInterval);
|
}, this.pingInterval);
|
||||||
@@ -98,7 +112,7 @@ export class SSEManager extends EventEmitter {
|
|||||||
return SSEManager.instance;
|
return SSEManager.instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
addClient(client: Omit<SSEClient, 'authenticated' | 'subscriptions'>, token: string): SSEClient | null {
|
addClient(client: Omit<SSEClient, 'authenticated' | 'subscriptions' | 'rateLimit'>, token: string): SSEClient | null {
|
||||||
// Validate token
|
// Validate token
|
||||||
const validationResult = TokenManager.validateToken(token, client.ip);
|
const validationResult = TokenManager.validateToken(token, client.ip);
|
||||||
if (!validationResult.valid) {
|
if (!validationResult.valid) {
|
||||||
@@ -117,7 +131,13 @@ export class SSEManager extends EventEmitter {
|
|||||||
...client,
|
...client,
|
||||||
authenticated: true,
|
authenticated: true,
|
||||||
subscriptions: new Set(),
|
subscriptions: new Set(),
|
||||||
lastPingAt: new Date()
|
lastPingAt: new Date(),
|
||||||
|
rateLimit: {
|
||||||
|
count: 0,
|
||||||
|
lastReset: Date.now(),
|
||||||
|
burstCount: 0,
|
||||||
|
lastBurstReset: Date.now()
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
this.clients.set(client.id, newClient);
|
this.clients.set(client.id, newClient);
|
||||||
@@ -126,22 +146,46 @@ export class SSEManager extends EventEmitter {
|
|||||||
return newClient;
|
return newClient;
|
||||||
}
|
}
|
||||||
|
|
||||||
private startClientPing(clientId: string) {
|
private isRateLimited(client: SSEClient): boolean {
|
||||||
const interval = setInterval(() => {
|
const now = Date.now();
|
||||||
const client = this.clients.get(clientId);
|
|
||||||
if (!client) {
|
|
||||||
clearInterval(interval);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.sendToClient(client, {
|
// Reset window counters if needed
|
||||||
type: 'ping',
|
if (now - client.rateLimit.lastReset >= this.rateLimit.WINDOW_MS) {
|
||||||
timestamp: new Date().toISOString()
|
client.rateLimit.count = 0;
|
||||||
});
|
client.rateLimit.lastReset = now;
|
||||||
}, this.pingInterval);
|
}
|
||||||
|
|
||||||
|
// Reset burst counters if needed (every second)
|
||||||
|
if (now - client.rateLimit.lastBurstReset >= 1000) {
|
||||||
|
client.rateLimit.burstCount = 0;
|
||||||
|
client.rateLimit.lastBurstReset = now;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check both window and burst limits
|
||||||
|
return (
|
||||||
|
client.rateLimit.count >= this.rateLimit.MAX_MESSAGES ||
|
||||||
|
client.rateLimit.burstCount >= this.rateLimit.BURST_LIMIT
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
removeClient(clientId: string) {
|
private updateRateLimit(client: SSEClient): void {
|
||||||
|
const now = Date.now();
|
||||||
|
client.rateLimit.count++;
|
||||||
|
client.rateLimit.burstCount++;
|
||||||
|
|
||||||
|
// Update timestamps if needed
|
||||||
|
if (now - client.rateLimit.lastReset >= this.rateLimit.WINDOW_MS) {
|
||||||
|
client.rateLimit.lastReset = now;
|
||||||
|
client.rateLimit.count = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (now - client.rateLimit.lastBurstReset >= 1000) {
|
||||||
|
client.rateLimit.lastBurstReset = now;
|
||||||
|
client.rateLimit.burstCount = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
removeClient(clientId: string): void {
|
||||||
if (this.clients.has(clientId)) {
|
if (this.clients.has(clientId)) {
|
||||||
this.clients.delete(clientId);
|
this.clients.delete(clientId);
|
||||||
console.log(`SSE client disconnected: ${clientId}`);
|
console.log(`SSE client disconnected: ${clientId}`);
|
||||||
@@ -152,46 +196,55 @@ export class SSEManager extends EventEmitter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
subscribeToEntity(clientId: string, entityId: string) {
|
subscribeToEntity(clientId: string, entityId: string): void {
|
||||||
const client = this.clients.get(clientId);
|
const client = this.clients.get(clientId);
|
||||||
if (client?.authenticated) {
|
if (!client?.authenticated) {
|
||||||
client.subscriptions.add(`entity:${entityId}`);
|
console.warn(`Unauthenticated client ${clientId} attempted to subscribe to entity: ${entityId}`);
|
||||||
console.log(`Client ${clientId} subscribed to entity: ${entityId}`);
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Send current state if available
|
client.subscriptions.add(`entity:${entityId}`);
|
||||||
const currentState = this.entityStates.get(entityId);
|
console.log(`Client ${clientId} subscribed to entity: ${entityId}`);
|
||||||
if (currentState) {
|
|
||||||
this.sendToClient(client, {
|
// Send current state if available
|
||||||
type: 'state_changed',
|
const currentState = this.entityStates.get(entityId);
|
||||||
data: {
|
if (currentState && !this.isRateLimited(client)) {
|
||||||
entity_id: currentState.entity_id,
|
this.sendToClient(client, {
|
||||||
state: currentState.state,
|
type: 'state_changed',
|
||||||
attributes: currentState.attributes,
|
data: {
|
||||||
last_changed: currentState.last_changed,
|
entity_id: currentState.entity_id,
|
||||||
last_updated: currentState.last_updated
|
state: currentState.state,
|
||||||
}
|
attributes: currentState.attributes,
|
||||||
});
|
last_changed: currentState.last_changed,
|
||||||
}
|
last_updated: currentState.last_updated
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
subscribeToDomain(clientId: string, domain: string) {
|
subscribeToDomain(clientId: string, domain: string): void {
|
||||||
const client = this.clients.get(clientId);
|
const client = this.clients.get(clientId);
|
||||||
if (client?.authenticated) {
|
if (!client?.authenticated) {
|
||||||
client.subscriptions.add(`domain:${domain}`);
|
console.warn(`Unauthenticated client ${clientId} attempted to subscribe to domain: ${domain}`);
|
||||||
console.log(`Client ${clientId} subscribed to domain: ${domain}`);
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client.subscriptions.add(`domain:${domain}`);
|
||||||
|
console.log(`Client ${clientId} subscribed to domain: ${domain}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
subscribeToEvent(clientId: string, eventType: string) {
|
subscribeToEvent(clientId: string, eventType: string): void {
|
||||||
const client = this.clients.get(clientId);
|
const client = this.clients.get(clientId);
|
||||||
if (client?.authenticated) {
|
if (!client?.authenticated) {
|
||||||
client.subscriptions.add(`event:${eventType}`);
|
console.warn(`Unauthenticated client ${clientId} attempted to subscribe to event: ${eventType}`);
|
||||||
console.log(`Client ${clientId} subscribed to event: ${eventType}`);
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client.subscriptions.add(`event:${eventType}`);
|
||||||
|
console.log(`Client ${clientId} subscribed to event: ${eventType}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
broadcastStateChange(entity: HassEntity) {
|
broadcastStateChange(entity: HassEntity): void {
|
||||||
// Update stored state
|
// Update stored state
|
||||||
this.entityStates.set(entity.entity_id, entity);
|
this.entityStates.set(entity.entity_id, entity);
|
||||||
|
|
||||||
@@ -211,8 +264,8 @@ export class SSEManager extends EventEmitter {
|
|||||||
console.log(`Broadcasting state change for ${entity.entity_id}`);
|
console.log(`Broadcasting state change for ${entity.entity_id}`);
|
||||||
|
|
||||||
// Send to relevant subscribers only
|
// Send to relevant subscribers only
|
||||||
for (const client of this.clients.values()) {
|
this.clients.forEach(client => {
|
||||||
if (!client.authenticated) continue;
|
if (!client.authenticated || this.isRateLimited(client)) return;
|
||||||
|
|
||||||
if (
|
if (
|
||||||
client.subscriptions.has(`entity:${entity.entity_id}`) ||
|
client.subscriptions.has(`entity:${entity.entity_id}`) ||
|
||||||
@@ -221,10 +274,10 @@ export class SSEManager extends EventEmitter {
|
|||||||
) {
|
) {
|
||||||
this.sendToClient(client, message);
|
this.sendToClient(client, message);
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
broadcastEvent(event: HassEvent) {
|
broadcastEvent(event: HassEvent): void {
|
||||||
const message = {
|
const message = {
|
||||||
type: event.event_type,
|
type: event.event_type,
|
||||||
data: event.data,
|
data: event.data,
|
||||||
@@ -237,117 +290,36 @@ export class SSEManager extends EventEmitter {
|
|||||||
console.log(`Broadcasting event: ${event.event_type}`);
|
console.log(`Broadcasting event: ${event.event_type}`);
|
||||||
|
|
||||||
// Send to relevant subscribers only
|
// Send to relevant subscribers only
|
||||||
for (const client of this.clients.values()) {
|
this.clients.forEach(client => {
|
||||||
if (!client.authenticated) continue;
|
if (!client.authenticated || this.isRateLimited(client)) return;
|
||||||
|
|
||||||
if (client.subscriptions.has(`event:${event.event_type}`)) {
|
if (client.subscriptions.has(`event:${event.event_type}`)) {
|
||||||
this.sendToClient(client, message);
|
this.sendToClient(client, message);
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private sendToClient(client: SSEClient, data: any) {
|
private sendToClient(client: SSEClient, data: unknown): void {
|
||||||
try {
|
try {
|
||||||
// Check rate limit
|
if (!client.authenticated) {
|
||||||
const now = Date.now();
|
console.warn(`Attempted to send message to unauthenticated client ${client.id}`);
|
||||||
if (now - client.rateLimit.lastReset > this.cleanupInterval) {
|
|
||||||
client.rateLimit.count = 0;
|
|
||||||
client.rateLimit.lastReset = now;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (client.rateLimit.count >= 1000) {
|
|
||||||
console.warn(`Rate limit exceeded for client ${client.id}`);
|
|
||||||
this.sendToClient(client, {
|
|
||||||
type: 'error',
|
|
||||||
error: 'rate_limit_exceeded',
|
|
||||||
message: 'Too many requests, please try again later',
|
|
||||||
timestamp: new Date().toISOString()
|
|
||||||
});
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
client.rateLimit.count++;
|
if (this.isRateLimited(client)) {
|
||||||
client.lastPingAt = new Date();
|
console.warn(`Rate limit exceeded for client ${client.id}`);
|
||||||
client.send(JSON.stringify(data));
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const message = typeof data === 'string' ? data : JSON.stringify(data);
|
||||||
|
client.send(message);
|
||||||
|
this.updateRateLimit(client);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`Error sending message to client ${client.id}:`, error);
|
console.error(`Failed to send message to client ${client.id}:`, error);
|
||||||
this.removeClient(client.id);
|
this.removeClient(client.id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private validateToken(token?: string): boolean {
|
|
||||||
if (!token) return false;
|
|
||||||
const validationResult = TokenManager.validateToken(token);
|
|
||||||
return validationResult.valid;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Utility methods
|
|
||||||
getConnectedClients(): number {
|
|
||||||
return this.clients.size;
|
|
||||||
}
|
|
||||||
|
|
||||||
getClientSubscriptions(clientId: string) {
|
|
||||||
return this.clients.get(clientId)?.subscriptions;
|
|
||||||
}
|
|
||||||
|
|
||||||
getEntityState(entityId: string): HassEntity | undefined {
|
|
||||||
return this.entityStates.get(entityId);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add new event types
|
|
||||||
broadcastServiceCall(domain: string, service: string, data: any) {
|
|
||||||
const message = {
|
|
||||||
type: 'service_called',
|
|
||||||
data: {
|
|
||||||
domain,
|
|
||||||
service,
|
|
||||||
service_data: data
|
|
||||||
},
|
|
||||||
timestamp: new Date().toISOString()
|
|
||||||
};
|
|
||||||
|
|
||||||
this.broadcastToSubscribers('service_called', message);
|
|
||||||
}
|
|
||||||
|
|
||||||
broadcastAutomationTriggered(automationId: string, trigger: any) {
|
|
||||||
const message = {
|
|
||||||
type: 'automation_triggered',
|
|
||||||
data: {
|
|
||||||
automation_id: automationId,
|
|
||||||
trigger
|
|
||||||
},
|
|
||||||
timestamp: new Date().toISOString()
|
|
||||||
};
|
|
||||||
|
|
||||||
this.broadcastToSubscribers('automation_triggered', message);
|
|
||||||
}
|
|
||||||
|
|
||||||
broadcastScriptExecuted(scriptId: string, data: any) {
|
|
||||||
const message = {
|
|
||||||
type: 'script_executed',
|
|
||||||
data: {
|
|
||||||
script_id: scriptId,
|
|
||||||
execution_data: data
|
|
||||||
},
|
|
||||||
timestamp: new Date().toISOString()
|
|
||||||
};
|
|
||||||
|
|
||||||
this.broadcastToSubscribers('script_executed', message);
|
|
||||||
}
|
|
||||||
|
|
||||||
private broadcastToSubscribers(eventType: string, message: any) {
|
|
||||||
for (const client of this.clients.values()) {
|
|
||||||
if (!client.authenticated) continue;
|
|
||||||
|
|
||||||
if (client.subscriptions.has(`event:${eventType}`) ||
|
|
||||||
client.subscriptions.has(`entity:${eventType}`) ||
|
|
||||||
client.subscriptions.has(`domain:${eventType.split('.')[0]}`)) {
|
|
||||||
this.sendToClient(client, message);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add statistics methods
|
|
||||||
getStatistics(): {
|
getStatistics(): {
|
||||||
totalClients: number;
|
totalClients: number;
|
||||||
authenticatedClients: number;
|
authenticatedClients: number;
|
||||||
@@ -356,31 +328,35 @@ export class SSEManager extends EventEmitter {
|
|||||||
} {
|
} {
|
||||||
const now = Date.now();
|
const now = Date.now();
|
||||||
const clientStats: ClientStats[] = [];
|
const clientStats: ClientStats[] = [];
|
||||||
const subscriptionCounts: { [key: string]: number } = {};
|
const subscriptionStats: { [key: string]: number } = {};
|
||||||
|
let authenticatedClients = 0;
|
||||||
|
|
||||||
this.clients.forEach(client => {
|
this.clients.forEach(client => {
|
||||||
// Collect client statistics
|
if (client.authenticated) {
|
||||||
|
authenticatedClients++;
|
||||||
|
}
|
||||||
|
|
||||||
clientStats.push({
|
clientStats.push({
|
||||||
id: client.id,
|
id: client.id,
|
||||||
ip: client.ip,
|
ip: client.ip,
|
||||||
connectedAt: client.connectedAt,
|
connectedAt: client.connectedAt,
|
||||||
lastPingAt: client.lastPingAt,
|
lastPingAt: client.lastPingAt,
|
||||||
subscriptionCount: client.subscriptions.size,
|
subscriptionCount: client.subscriptions.size,
|
||||||
connectionDuration: now - client.connectedAt.getTime()
|
connectionDuration: now - client.connectedAt.getTime(),
|
||||||
|
messagesSent: client.rateLimit.count,
|
||||||
|
lastActivity: new Date(client.rateLimit.lastReset)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Count subscriptions by type
|
|
||||||
client.subscriptions.forEach(sub => {
|
client.subscriptions.forEach(sub => {
|
||||||
const [type] = sub.split(':');
|
subscriptionStats[sub] = (subscriptionStats[sub] || 0) + 1;
|
||||||
subscriptionCounts[type] = (subscriptionCounts[type] || 0) + 1;
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
return {
|
return {
|
||||||
totalClients: this.clients.size,
|
totalClients: this.clients.size,
|
||||||
authenticatedClients: Array.from(this.clients.values()).filter(c => c.authenticated).length,
|
authenticatedClients,
|
||||||
clientStats,
|
clientStats,
|
||||||
subscriptionStats: subscriptionCounts
|
subscriptionStats
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user