refactor: enhance middleware and security with advanced protection mechanisms

- Upgraded rate limiter configuration with more granular control and detailed headers
- Improved authentication middleware with enhanced token validation and error responses
- Implemented advanced input sanitization using sanitize-html with comprehensive XSS protection
- Replaced manual security headers with helmet for robust web security configuration
- Enhanced error handling middleware with more detailed logging and specific error type handling
- Updated SSE rate limiting with burst and window-based restrictions
- Improved token validation with more precise signature and claim verification
This commit is contained in:
jango-blockchained
2025-02-03 22:29:41 +01:00
parent 89f2278c25
commit 10bf5919e4
3 changed files with 352 additions and 269 deletions

View File

@@ -2,32 +2,52 @@ import { Request, Response, NextFunction } from 'express';
import { HASS_CONFIG, RATE_LIMIT_CONFIG } from '../config/index.js'; import { HASS_CONFIG, RATE_LIMIT_CONFIG } from '../config/index.js';
import rateLimit from 'express-rate-limit'; import rateLimit from 'express-rate-limit';
import { TokenManager } from '../security/index.js'; import { TokenManager } from '../security/index.js';
import sanitizeHtml from 'sanitize-html';
import helmet from 'helmet';
// Rate limiter middleware // Rate limiter middleware with enhanced configuration
export const rateLimiter = rateLimit({ export const rateLimiter = rateLimit({
windowMs: 60 * 1000, // 1 minute windowMs: 60 * 1000, // 1 minute
max: RATE_LIMIT_CONFIG.REGULAR, max: RATE_LIMIT_CONFIG.REGULAR,
standardHeaders: true, // Return rate limit info in the `RateLimit-*` headers
legacyHeaders: false, // Disable the `X-RateLimit-*` headers
message: { message: {
success: false, success: false,
message: 'Too many requests, please try again later.', message: 'Too many requests, please try again later.',
reset_time: new Date(Date.now() + 60 * 1000).toISOString() reset_time: new Date(Date.now() + 60 * 1000).toISOString()
} },
skipSuccessfulRequests: false, // Count all requests
keyGenerator: (req) => req.ip || req.socket.remoteAddress || 'unknown' // Use IP for rate limiting
}); });
// WebSocket rate limiter middleware // WebSocket rate limiter middleware with enhanced configuration
export const wsRateLimiter = rateLimit({ export const wsRateLimiter = rateLimit({
windowMs: 60 * 1000, // 1 minute windowMs: 60 * 1000, // 1 minute
max: RATE_LIMIT_CONFIG.WEBSOCKET, max: RATE_LIMIT_CONFIG.WEBSOCKET,
standardHeaders: true,
legacyHeaders: false,
message: { message: {
success: false, success: false,
message: 'Too many WebSocket connections, please try again later.', message: 'Too many WebSocket connections, please try again later.',
reset_time: new Date(Date.now() + 60 * 1000).toISOString() reset_time: new Date(Date.now() + 60 * 1000).toISOString()
} },
skipSuccessfulRequests: false,
keyGenerator: (req) => req.ip || req.socket.remoteAddress || 'unknown'
}); });
// Authentication middleware // Authentication middleware with enhanced security
export const authenticate = (req: Request, res: Response, next: NextFunction) => { export const authenticate = (req: Request, res: Response, next: NextFunction) => {
const token = req.headers.authorization?.replace('Bearer ', '') || ''; const authHeader = req.headers.authorization;
if (!authHeader || !authHeader.startsWith('Bearer ')) {
return res.status(401).json({
success: false,
message: 'Unauthorized',
error: 'Missing or invalid authorization header',
timestamp: new Date().toISOString()
});
}
const token = authHeader.replace('Bearer ', '');
const clientIp = req.ip || req.socket.remoteAddress || ''; const clientIp = req.ip || req.socket.remoteAddress || '';
const validationResult = TokenManager.validateToken(token, clientIp); const validationResult = TokenManager.validateToken(token, clientIp);
@@ -44,18 +64,40 @@ export const authenticate = (req: Request, res: Response, next: NextFunction) =>
next(); next();
}; };
// Enhanced security headers middleware // Enhanced security headers middleware using helmet
export const securityHeaders = (_req: Request, res: Response, next: NextFunction) => { export const securityHeaders = helmet({
// Set strict security headers contentSecurityPolicy: {
res.setHeader('X-Content-Type-Options', 'nosniff'); directives: {
res.setHeader('X-Frame-Options', 'DENY'); defaultSrc: ["'self'"],
res.setHeader('X-XSS-Protection', '1; mode=block'); scriptSrc: ["'self'", "'unsafe-inline'"],
res.setHeader('Strict-Transport-Security', 'max-age=31536000; includeSubDomains; preload'); styleSrc: ["'self'", "'unsafe-inline'"],
res.setHeader('Content-Security-Policy', "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; connect-src 'self' wss: https:;"); imgSrc: ["'self'", 'data:', 'https:'],
res.setHeader('Referrer-Policy', 'strict-origin-when-cross-origin'); connectSrc: ["'self'", 'wss:', 'https:'],
res.setHeader('Permissions-Policy', 'geolocation=(), microphone=(), camera=()'); frameSrc: ["'none'"],
next(); objectSrc: ["'none'"],
}; baseUri: ["'self'"],
formAction: ["'self'"],
frameAncestors: ["'none'"]
}
},
crossOriginEmbedderPolicy: true,
crossOriginOpenerPolicy: { policy: 'same-origin' },
crossOriginResourcePolicy: { policy: 'same-origin' },
dnsPrefetchControl: { allow: false },
frameguard: { action: 'deny' },
hidePoweredBy: true,
hsts: {
maxAge: 31536000,
includeSubDomains: true,
preload: true
},
ieNoOpen: true,
noSniff: true,
originAgentCluster: true,
permittedCrossDomainPolicies: { permittedPolicies: 'none' },
referrerPolicy: { policy: 'strict-origin-when-cross-origin' },
xssFilter: true
});
// Enhanced request validation middleware // Enhanced request validation middleware
export const validateRequest = (req: Request, res: Response, next: NextFunction) => { export const validateRequest = (req: Request, res: Response, next: NextFunction) => {
@@ -65,13 +107,16 @@ export const validateRequest = (req: Request, res: Response, next: NextFunction)
} }
// Validate content type for POST/PUT/PATCH requests // Validate content type for POST/PUT/PATCH requests
if (['POST', 'PUT', 'PATCH'].includes(req.method) && !req.is('application/json')) { if (['POST', 'PUT', 'PATCH'].includes(req.method)) {
return res.status(415).json({ const contentType = req.headers['content-type'];
success: false, if (!contentType || !contentType.includes('application/json')) {
message: 'Unsupported Media Type', return res.status(415).json({
error: 'Content-Type must be application/json', success: false,
timestamp: new Date().toISOString() message: 'Unsupported Media Type',
}); error: 'Content-Type must be application/json',
timestamp: new Date().toISOString()
});
}
} }
// Validate request body size // Validate request body size
@@ -101,23 +146,42 @@ export const validateRequest = (req: Request, res: Response, next: NextFunction)
next(); next();
}; };
// Input sanitization middleware // Enhanced input sanitization middleware
export const sanitizeInput = (req: Request, _res: Response, next: NextFunction) => { export const sanitizeInput = (req: Request, _res: Response, next: NextFunction) => {
if (req.body) { if (req.body) {
// Recursively sanitize object const sanitizeValue = (value: unknown): unknown => {
const sanitizeObject = (obj: any): any => { if (typeof value === 'string') {
// Sanitize HTML content
return sanitizeHtml(value, {
allowedTags: [], // Remove all HTML tags
allowedAttributes: {}, // Remove all attributes
textFilter: (text) => {
// Remove potential XSS patterns
return text.replace(/javascript:/gi, '')
.replace(/data:/gi, '')
.replace(/vbscript:/gi, '')
.replace(/on\w+=/gi, '')
.replace(/\b(alert|confirm|prompt|exec|eval|setTimeout|setInterval)\b/gi, '');
}
});
}
return value;
};
const sanitizeObject = (obj: unknown): unknown => {
if (typeof obj !== 'object' || obj === null) { if (typeof obj !== 'object' || obj === null) {
return obj; return sanitizeValue(obj);
} }
if (Array.isArray(obj)) { if (Array.isArray(obj)) {
return obj.map(item => sanitizeObject(item)); return obj.map(item => sanitizeObject(item));
} }
const sanitized: any = {}; const sanitized: Record<string, unknown> = {};
for (const [key, value] of Object.entries(obj)) { for (const [key, value] of Object.entries(obj as Record<string, unknown>)) {
// Remove any potentially dangerous characters from keys // Sanitize keys
const sanitizedKey = key.replace(/[<>]/g, ''); const sanitizedKey = typeof key === 'string' ? sanitizeValue(key) as string : key;
// Recursively sanitize values
sanitized[sanitizedKey] = sanitizeObject(value); sanitized[sanitizedKey] = sanitizeObject(value);
} }
@@ -131,44 +195,68 @@ export const sanitizeInput = (req: Request, _res: Response, next: NextFunction)
}; };
// Enhanced error handling middleware // Enhanced error handling middleware
export const errorHandler = (err: Error, _req: Request, res: Response, _next: NextFunction) => { export const errorHandler = (err: Error, req: Request, res: Response, _next: NextFunction) => {
console.error('Error:', err); // Log error with request context
console.error('Error:', {
// Handle specific error types error: err.message,
if (err.name === 'ValidationError') { stack: err.stack,
return res.status(400).json({ method: req.method,
success: false, path: req.path,
message: 'Validation Error', ip: req.ip,
error: err.message,
timestamp: new Date().toISOString()
});
}
if (err.name === 'UnauthorizedError') {
return res.status(401).json({
success: false,
message: 'Unauthorized',
error: err.message,
timestamp: new Date().toISOString()
});
}
if (err.name === 'ForbiddenError') {
return res.status(403).json({
success: false,
message: 'Forbidden',
error: err.message,
timestamp: new Date().toISOString()
});
}
// Default error response
res.status(500).json({
success: false,
message: 'Internal Server Error',
error: process.env.NODE_ENV === 'development' ? err.message : 'An unexpected error occurred',
timestamp: new Date().toISOString() timestamp: new Date().toISOString()
}); });
// Handle specific error types
switch (err.name) {
case 'ValidationError':
return res.status(400).json({
success: false,
message: 'Validation Error',
error: err.message,
timestamp: new Date().toISOString()
});
case 'UnauthorizedError':
return res.status(401).json({
success: false,
message: 'Unauthorized',
error: err.message,
timestamp: new Date().toISOString()
});
case 'ForbiddenError':
return res.status(403).json({
success: false,
message: 'Forbidden',
error: err.message,
timestamp: new Date().toISOString()
});
case 'NotFoundError':
return res.status(404).json({
success: false,
message: 'Not Found',
error: err.message,
timestamp: new Date().toISOString()
});
case 'ConflictError':
return res.status(409).json({
success: false,
message: 'Conflict',
error: err.message,
timestamp: new Date().toISOString()
});
default:
// Default error response
return res.status(500).json({
success: false,
message: 'Internal Server Error',
error: process.env.NODE_ENV === 'development' ? err.message : 'An unexpected error occurred',
timestamp: new Date().toISOString()
});
}
}; };
// Export all middleware // Export all middleware

View File

@@ -129,6 +129,7 @@ export class TokenManager {
* Validates a JWT token with enhanced security checks * Validates a JWT token with enhanced security checks
*/ */
static validateToken(token: string, ip?: string): { valid: boolean; error?: string } { static validateToken(token: string, ip?: string): { valid: boolean; error?: string } {
// Check basic token format
if (!token || typeof token !== 'string') { if (!token || typeof token !== 'string') {
return { valid: false, error: 'Invalid token format' }; return { valid: false, error: 'Invalid token format' };
} }
@@ -139,34 +140,35 @@ export class TokenManager {
} }
// Check for rate limiting // Check for rate limiting
if (ip) { if (ip && this.isRateLimited(ip)) {
const attempts = failedAttempts.get(ip); return { valid: false, error: 'Too many failed attempts. Please try again later.' };
if (attempts) { }
const timeSinceLastAttempt = Date.now() - attempts.lastAttempt;
if (attempts.count >= SECURITY_CONFIG.MAX_FAILED_ATTEMPTS) { // Get JWT secret
if (timeSinceLastAttempt < SECURITY_CONFIG.LOCKOUT_DURATION) { const secret = process.env.JWT_SECRET;
return { valid: false, error: 'Too many failed attempts. Please try again later.' }; if (!secret) {
} return { valid: false, error: 'JWT secret not configured' };
// Reset after lockout period
failedAttempts.delete(ip);
}
}
} }
try { try {
const decoded = jwt.decode(token); // Verify token signature and decode
const decoded = jwt.verify(token, secret) as jwt.JwtPayload;
// Verify token structure
if (!decoded || typeof decoded !== 'object') { if (!decoded || typeof decoded !== 'object') {
this.recordFailedAttempt(ip); this.recordFailedAttempt(ip);
return { valid: false, error: 'Invalid token structure' }; return { valid: false, error: 'Invalid token structure' };
} }
// Enhanced expiration checks // Check required claims
if (!decoded.exp || !decoded.iat) { if (!decoded.exp || !decoded.iat) {
this.recordFailedAttempt(ip); this.recordFailedAttempt(ip);
return { valid: false, error: 'Token missing required claims' }; return { valid: false, error: 'Token missing required claims' };
} }
const now = Math.floor(Date.now() / 1000); const now = Math.floor(Date.now() / 1000);
// Check expiration
if (decoded.exp <= now) { if (decoded.exp <= now) {
this.recordFailedAttempt(ip); this.recordFailedAttempt(ip);
return { valid: false, error: 'Token has expired' }; return { valid: false, error: 'Token has expired' };
@@ -179,47 +181,61 @@ export class TokenManager {
return { valid: false, error: 'Token exceeds maximum age limit' }; return { valid: false, error: 'Token exceeds maximum age limit' };
} }
// Verify signature // Reset failed attempts on successful validation
const secret = process.env.JWT_SECRET; if (ip) {
if (!secret) { failedAttempts.delete(ip);
return { valid: false, error: 'JWT secret not configured' };
} }
try { return { valid: true };
jwt.verify(token, secret);
// Reset failed attempts on successful validation
if (ip) {
failedAttempts.delete(ip);
}
return { valid: true };
} catch (error) {
this.recordFailedAttempt(ip);
return { valid: false, error: 'Invalid token signature' };
}
} catch (error) { } catch (error) {
this.recordFailedAttempt(ip); this.recordFailedAttempt(ip);
if (error instanceof jwt.JsonWebTokenError) {
return { valid: false, error: 'Invalid token signature' };
}
if (error instanceof jwt.TokenExpiredError) {
return { valid: false, error: 'Token has expired' };
}
return { valid: false, error: 'Token validation failed' }; return { valid: false, error: 'Token validation failed' };
} }
} }
/**
* Checks if an IP is rate limited
*/
private static isRateLimited(ip: string): boolean {
const attempts = failedAttempts.get(ip);
if (!attempts) return false;
const now = Date.now();
const timeSinceLastAttempt = now - attempts.lastAttempt;
// Reset if outside lockout period
if (timeSinceLastAttempt >= SECURITY_CONFIG.LOCKOUT_DURATION) {
failedAttempts.delete(ip);
return false;
}
return attempts.count >= SECURITY_CONFIG.MAX_FAILED_ATTEMPTS;
}
/** /**
* Records a failed authentication attempt * Records a failed authentication attempt
*/ */
private static recordFailedAttempt(ip?: string): void { private static recordFailedAttempt(ip?: string): void {
if (!ip) return; if (!ip) return;
const attempts = failedAttempts.get(ip) || { count: 0, lastAttempt: 0 };
const now = Date.now(); const now = Date.now();
const attempts = failedAttempts.get(ip);
// Reset count if last attempt was outside lockout period if (!attempts || (now - attempts.lastAttempt) >= SECURITY_CONFIG.LOCKOUT_DURATION) {
if (now - attempts.lastAttempt > SECURITY_CONFIG.LOCKOUT_DURATION) { // First attempt or reset after lockout
attempts.count = 1; failedAttempts.set(ip, { count: 1, lastAttempt: now });
} else { } else {
// Increment existing attempts
attempts.count++; attempts.count++;
attempts.lastAttempt = now;
failedAttempts.set(ip, attempts);
} }
attempts.lastAttempt = now;
failedAttempts.set(ip, attempts);
} }
/** /**
@@ -231,15 +247,18 @@ export class TokenManager {
throw new Error('JWT secret not configured'); throw new Error('JWT secret not configured');
} }
// Ensure we don't override system claims
const sanitizedPayload = { ...payload };
delete (sanitizedPayload as any).iat;
delete (sanitizedPayload as any).exp;
return jwt.sign( return jwt.sign(
{ sanitizedPayload,
...payload,
iat: Math.floor(Date.now() / 1000),
},
secret, secret,
{ {
expiresIn: Math.floor(expiresIn / 1000), expiresIn: Math.floor(expiresIn / 1000),
algorithm: 'HS256' algorithm: 'HS256',
notBefore: 0 // Token is valid immediately
} }
); );
} }

View File

@@ -7,10 +7,17 @@ const DEFAULT_MAX_CLIENTS = 1000;
const DEFAULT_PING_INTERVAL = 30000; // 30 seconds const DEFAULT_PING_INTERVAL = 30000; // 30 seconds
const DEFAULT_CLEANUP_INTERVAL = 60000; // 1 minute const DEFAULT_CLEANUP_INTERVAL = 60000; // 1 minute
const DEFAULT_MAX_CONNECTION_AGE = 24 * 60 * 60 * 1000; // 24 hours const DEFAULT_MAX_CONNECTION_AGE = 24 * 60 * 60 * 1000; // 24 hours
const DEFAULT_RATE_LIMIT = {
MAX_MESSAGES: 100, // messages
WINDOW_MS: 60000, // 1 minute
BURST_LIMIT: 10 // max messages per second
};
interface RateLimit { interface RateLimit {
count: number; count: number;
lastReset: number; lastReset: number;
burstCount: number;
lastBurstReset: number;
} }
export interface SSEClient { export interface SSEClient {
@@ -32,6 +39,8 @@ interface ClientStats {
lastPingAt?: Date; lastPingAt?: Date;
subscriptionCount: number; subscriptionCount: number;
connectionDuration: number; connectionDuration: number;
messagesSent: number;
lastActivity: Date;
} }
export class SSEManager extends EventEmitter { export class SSEManager extends EventEmitter {
@@ -42,18 +51,21 @@ export class SSEManager extends EventEmitter {
private readonly pingInterval: number; private readonly pingInterval: number;
private readonly cleanupInterval: number; private readonly cleanupInterval: number;
private readonly maxConnectionAge: number; private readonly maxConnectionAge: number;
private readonly rateLimit: typeof DEFAULT_RATE_LIMIT;
constructor(options: { constructor(options: {
maxClients?: number; maxClients?: number;
pingInterval?: number; pingInterval?: number;
cleanupInterval?: number; cleanupInterval?: number;
maxConnectionAge?: number; maxConnectionAge?: number;
rateLimit?: Partial<typeof DEFAULT_RATE_LIMIT>;
} = {}) { } = {}) {
super(); super();
this.maxClients = options.maxClients || DEFAULT_MAX_CLIENTS; this.maxClients = options.maxClients || DEFAULT_MAX_CLIENTS;
this.pingInterval = options.pingInterval || DEFAULT_PING_INTERVAL; this.pingInterval = options.pingInterval || DEFAULT_PING_INTERVAL;
this.cleanupInterval = options.cleanupInterval || DEFAULT_CLEANUP_INTERVAL; this.cleanupInterval = options.cleanupInterval || DEFAULT_CLEANUP_INTERVAL;
this.maxConnectionAge = options.maxConnectionAge || DEFAULT_MAX_CONNECTION_AGE; this.maxConnectionAge = options.maxConnectionAge || DEFAULT_MAX_CONNECTION_AGE;
this.rateLimit = { ...DEFAULT_RATE_LIMIT, ...options.rateLimit };
console.log('Initializing SSE Manager...'); console.log('Initializing SSE Manager...');
this.startMaintenanceTasks(); this.startMaintenanceTasks();
@@ -63,15 +75,17 @@ export class SSEManager extends EventEmitter {
// Send periodic pings to keep connections alive // Send periodic pings to keep connections alive
setInterval(() => { setInterval(() => {
this.clients.forEach(client => { this.clients.forEach(client => {
try { if (!this.isRateLimited(client)) {
client.send(JSON.stringify({ try {
type: 'ping', client.send(JSON.stringify({
timestamp: new Date().toISOString() type: 'ping',
})); timestamp: new Date().toISOString()
client.lastPingAt = new Date(); }));
} catch (error) { client.lastPingAt = new Date();
console.error(`Failed to ping client ${client.id}:`, error); } catch (error) {
this.removeClient(client.id); console.error(`Failed to ping client ${client.id}:`, error);
this.removeClient(client.id);
}
} }
}); });
}, this.pingInterval); }, this.pingInterval);
@@ -98,7 +112,7 @@ export class SSEManager extends EventEmitter {
return SSEManager.instance; return SSEManager.instance;
} }
addClient(client: Omit<SSEClient, 'authenticated' | 'subscriptions'>, token: string): SSEClient | null { addClient(client: Omit<SSEClient, 'authenticated' | 'subscriptions' | 'rateLimit'>, token: string): SSEClient | null {
// Validate token // Validate token
const validationResult = TokenManager.validateToken(token, client.ip); const validationResult = TokenManager.validateToken(token, client.ip);
if (!validationResult.valid) { if (!validationResult.valid) {
@@ -117,7 +131,13 @@ export class SSEManager extends EventEmitter {
...client, ...client,
authenticated: true, authenticated: true,
subscriptions: new Set(), subscriptions: new Set(),
lastPingAt: new Date() lastPingAt: new Date(),
rateLimit: {
count: 0,
lastReset: Date.now(),
burstCount: 0,
lastBurstReset: Date.now()
}
}; };
this.clients.set(client.id, newClient); this.clients.set(client.id, newClient);
@@ -126,22 +146,46 @@ export class SSEManager extends EventEmitter {
return newClient; return newClient;
} }
private startClientPing(clientId: string) { private isRateLimited(client: SSEClient): boolean {
const interval = setInterval(() => { const now = Date.now();
const client = this.clients.get(clientId);
if (!client) {
clearInterval(interval);
return;
}
this.sendToClient(client, { // Reset window counters if needed
type: 'ping', if (now - client.rateLimit.lastReset >= this.rateLimit.WINDOW_MS) {
timestamp: new Date().toISOString() client.rateLimit.count = 0;
}); client.rateLimit.lastReset = now;
}, this.pingInterval); }
// Reset burst counters if needed (every second)
if (now - client.rateLimit.lastBurstReset >= 1000) {
client.rateLimit.burstCount = 0;
client.rateLimit.lastBurstReset = now;
}
// Check both window and burst limits
return (
client.rateLimit.count >= this.rateLimit.MAX_MESSAGES ||
client.rateLimit.burstCount >= this.rateLimit.BURST_LIMIT
);
} }
removeClient(clientId: string) { private updateRateLimit(client: SSEClient): void {
const now = Date.now();
client.rateLimit.count++;
client.rateLimit.burstCount++;
// Update timestamps if needed
if (now - client.rateLimit.lastReset >= this.rateLimit.WINDOW_MS) {
client.rateLimit.lastReset = now;
client.rateLimit.count = 1;
}
if (now - client.rateLimit.lastBurstReset >= 1000) {
client.rateLimit.lastBurstReset = now;
client.rateLimit.burstCount = 1;
}
}
removeClient(clientId: string): void {
if (this.clients.has(clientId)) { if (this.clients.has(clientId)) {
this.clients.delete(clientId); this.clients.delete(clientId);
console.log(`SSE client disconnected: ${clientId}`); console.log(`SSE client disconnected: ${clientId}`);
@@ -152,46 +196,55 @@ export class SSEManager extends EventEmitter {
} }
} }
subscribeToEntity(clientId: string, entityId: string) { subscribeToEntity(clientId: string, entityId: string): void {
const client = this.clients.get(clientId); const client = this.clients.get(clientId);
if (client?.authenticated) { if (!client?.authenticated) {
client.subscriptions.add(`entity:${entityId}`); console.warn(`Unauthenticated client ${clientId} attempted to subscribe to entity: ${entityId}`);
console.log(`Client ${clientId} subscribed to entity: ${entityId}`); return;
}
// Send current state if available client.subscriptions.add(`entity:${entityId}`);
const currentState = this.entityStates.get(entityId); console.log(`Client ${clientId} subscribed to entity: ${entityId}`);
if (currentState) {
this.sendToClient(client, { // Send current state if available
type: 'state_changed', const currentState = this.entityStates.get(entityId);
data: { if (currentState && !this.isRateLimited(client)) {
entity_id: currentState.entity_id, this.sendToClient(client, {
state: currentState.state, type: 'state_changed',
attributes: currentState.attributes, data: {
last_changed: currentState.last_changed, entity_id: currentState.entity_id,
last_updated: currentState.last_updated state: currentState.state,
} attributes: currentState.attributes,
}); last_changed: currentState.last_changed,
} last_updated: currentState.last_updated
}
});
} }
} }
subscribeToDomain(clientId: string, domain: string) { subscribeToDomain(clientId: string, domain: string): void {
const client = this.clients.get(clientId); const client = this.clients.get(clientId);
if (client?.authenticated) { if (!client?.authenticated) {
client.subscriptions.add(`domain:${domain}`); console.warn(`Unauthenticated client ${clientId} attempted to subscribe to domain: ${domain}`);
console.log(`Client ${clientId} subscribed to domain: ${domain}`); return;
} }
client.subscriptions.add(`domain:${domain}`);
console.log(`Client ${clientId} subscribed to domain: ${domain}`);
} }
subscribeToEvent(clientId: string, eventType: string) { subscribeToEvent(clientId: string, eventType: string): void {
const client = this.clients.get(clientId); const client = this.clients.get(clientId);
if (client?.authenticated) { if (!client?.authenticated) {
client.subscriptions.add(`event:${eventType}`); console.warn(`Unauthenticated client ${clientId} attempted to subscribe to event: ${eventType}`);
console.log(`Client ${clientId} subscribed to event: ${eventType}`); return;
} }
client.subscriptions.add(`event:${eventType}`);
console.log(`Client ${clientId} subscribed to event: ${eventType}`);
} }
broadcastStateChange(entity: HassEntity) { broadcastStateChange(entity: HassEntity): void {
// Update stored state // Update stored state
this.entityStates.set(entity.entity_id, entity); this.entityStates.set(entity.entity_id, entity);
@@ -211,8 +264,8 @@ export class SSEManager extends EventEmitter {
console.log(`Broadcasting state change for ${entity.entity_id}`); console.log(`Broadcasting state change for ${entity.entity_id}`);
// Send to relevant subscribers only // Send to relevant subscribers only
for (const client of this.clients.values()) { this.clients.forEach(client => {
if (!client.authenticated) continue; if (!client.authenticated || this.isRateLimited(client)) return;
if ( if (
client.subscriptions.has(`entity:${entity.entity_id}`) || client.subscriptions.has(`entity:${entity.entity_id}`) ||
@@ -221,10 +274,10 @@ export class SSEManager extends EventEmitter {
) { ) {
this.sendToClient(client, message); this.sendToClient(client, message);
} }
} });
} }
broadcastEvent(event: HassEvent) { broadcastEvent(event: HassEvent): void {
const message = { const message = {
type: event.event_type, type: event.event_type,
data: event.data, data: event.data,
@@ -237,117 +290,36 @@ export class SSEManager extends EventEmitter {
console.log(`Broadcasting event: ${event.event_type}`); console.log(`Broadcasting event: ${event.event_type}`);
// Send to relevant subscribers only // Send to relevant subscribers only
for (const client of this.clients.values()) { this.clients.forEach(client => {
if (!client.authenticated) continue; if (!client.authenticated || this.isRateLimited(client)) return;
if (client.subscriptions.has(`event:${event.event_type}`)) { if (client.subscriptions.has(`event:${event.event_type}`)) {
this.sendToClient(client, message); this.sendToClient(client, message);
} }
} });
} }
private sendToClient(client: SSEClient, data: any) { private sendToClient(client: SSEClient, data: unknown): void {
try { try {
// Check rate limit if (!client.authenticated) {
const now = Date.now(); console.warn(`Attempted to send message to unauthenticated client ${client.id}`);
if (now - client.rateLimit.lastReset > this.cleanupInterval) {
client.rateLimit.count = 0;
client.rateLimit.lastReset = now;
}
if (client.rateLimit.count >= 1000) {
console.warn(`Rate limit exceeded for client ${client.id}`);
this.sendToClient(client, {
type: 'error',
error: 'rate_limit_exceeded',
message: 'Too many requests, please try again later',
timestamp: new Date().toISOString()
});
return; return;
} }
client.rateLimit.count++; if (this.isRateLimited(client)) {
client.lastPingAt = new Date(); console.warn(`Rate limit exceeded for client ${client.id}`);
client.send(JSON.stringify(data)); return;
}
const message = typeof data === 'string' ? data : JSON.stringify(data);
client.send(message);
this.updateRateLimit(client);
} catch (error) { } catch (error) {
console.error(`Error sending message to client ${client.id}:`, error); console.error(`Failed to send message to client ${client.id}:`, error);
this.removeClient(client.id); this.removeClient(client.id);
} }
} }
private validateToken(token?: string): boolean {
if (!token) return false;
const validationResult = TokenManager.validateToken(token);
return validationResult.valid;
}
// Utility methods
getConnectedClients(): number {
return this.clients.size;
}
getClientSubscriptions(clientId: string) {
return this.clients.get(clientId)?.subscriptions;
}
getEntityState(entityId: string): HassEntity | undefined {
return this.entityStates.get(entityId);
}
// Add new event types
broadcastServiceCall(domain: string, service: string, data: any) {
const message = {
type: 'service_called',
data: {
domain,
service,
service_data: data
},
timestamp: new Date().toISOString()
};
this.broadcastToSubscribers('service_called', message);
}
broadcastAutomationTriggered(automationId: string, trigger: any) {
const message = {
type: 'automation_triggered',
data: {
automation_id: automationId,
trigger
},
timestamp: new Date().toISOString()
};
this.broadcastToSubscribers('automation_triggered', message);
}
broadcastScriptExecuted(scriptId: string, data: any) {
const message = {
type: 'script_executed',
data: {
script_id: scriptId,
execution_data: data
},
timestamp: new Date().toISOString()
};
this.broadcastToSubscribers('script_executed', message);
}
private broadcastToSubscribers(eventType: string, message: any) {
for (const client of this.clients.values()) {
if (!client.authenticated) continue;
if (client.subscriptions.has(`event:${eventType}`) ||
client.subscriptions.has(`entity:${eventType}`) ||
client.subscriptions.has(`domain:${eventType.split('.')[0]}`)) {
this.sendToClient(client, message);
}
}
}
// Add statistics methods
getStatistics(): { getStatistics(): {
totalClients: number; totalClients: number;
authenticatedClients: number; authenticatedClients: number;
@@ -356,31 +328,35 @@ export class SSEManager extends EventEmitter {
} { } {
const now = Date.now(); const now = Date.now();
const clientStats: ClientStats[] = []; const clientStats: ClientStats[] = [];
const subscriptionCounts: { [key: string]: number } = {}; const subscriptionStats: { [key: string]: number } = {};
let authenticatedClients = 0;
this.clients.forEach(client => { this.clients.forEach(client => {
// Collect client statistics if (client.authenticated) {
authenticatedClients++;
}
clientStats.push({ clientStats.push({
id: client.id, id: client.id,
ip: client.ip, ip: client.ip,
connectedAt: client.connectedAt, connectedAt: client.connectedAt,
lastPingAt: client.lastPingAt, lastPingAt: client.lastPingAt,
subscriptionCount: client.subscriptions.size, subscriptionCount: client.subscriptions.size,
connectionDuration: now - client.connectedAt.getTime() connectionDuration: now - client.connectedAt.getTime(),
messagesSent: client.rateLimit.count,
lastActivity: new Date(client.rateLimit.lastReset)
}); });
// Count subscriptions by type
client.subscriptions.forEach(sub => { client.subscriptions.forEach(sub => {
const [type] = sub.split(':'); subscriptionStats[sub] = (subscriptionStats[sub] || 0) + 1;
subscriptionCounts[type] = (subscriptionCounts[type] || 0) + 1;
}); });
}); });
return { return {
totalClients: this.clients.size, totalClients: this.clients.size,
authenticatedClients: Array.from(this.clients.values()).filter(c => c.authenticated).length, authenticatedClients,
clientStats, clientStats,
subscriptionStats: subscriptionCounts subscriptionStats
}; };
} }
} }