Python WebSockets: Real-Time Communication with FastAPI
WebSockets enable full-duplex communication between browser and server over a persistent connection. Unlike HTTP polling, a WebSocket connection stays open, allowing the server to push data to clients the instant it's available — essential for chat applications, live dashboards, collaborative tools, and real-time notifications. FastAPI has native WebSocket support built on Starlette's ASGI foundation, making it straightforward to build production-grade real-time services.
Table of Contents
WebSocket Basics in FastAPI
FastAPI exposes WebSocket endpoints with the @app.websocket() decorator. The WebSocket object provides accept(), send_text(), send_json(), send_bytes(), and receive_text()/receive_json() methods. The endpoint runs until the connection closes — either the client disconnects or you call close(). Always handle WebSocketDisconnect to clean up resources.
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
import json
app = FastAPI()
@app.websocket("/ws/echo")
async def echo_endpoint(websocket: WebSocket):
    """Echo each text frame back to the sender, prefixed with ``Echo: ``.

    The handler accepts the connection and loops until the client
    disconnects, at which point a notice is printed.
    """
    await websocket.accept()
    try:
        while True:
            incoming = await websocket.receive_text()
            reply = f"Echo: {incoming}"
            await websocket.send_text(reply)
    except WebSocketDisconnect:
        print("Client disconnected")
@app.websocket("/ws/ping")
async def ping_endpoint(websocket: WebSocket):
    """Answer JSON ``ping`` messages with a ``pong`` and honour ``close`` requests.

    After accepting, a ``connected`` notice is sent. Each ``{"type": "ping"}``
    message is answered with the client's own ``timestamp`` plus the server
    clock. ``{"type": "close"}`` closes the socket and ends the handler.
    Other message types are ignored.
    """
    # Function-scope import instead of an inline __import__('time') call.
    import time

    await websocket.accept()
    await websocket.send_json({"type": "connected", "message": "WebSocket ready"})
    try:
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                # Valid JSON that is not an object (list, number, ...) has no
                # "type" field; skip it instead of failing on `.get`.
                continue
            message_type = data.get("type")
            if message_type == "ping":
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": data.get("timestamp"),
                    "server_time": time.time(),
                })
            elif message_type == "close":
                await websocket.close()
                break
    except WebSocketDisconnect as e:
        print(f"Disconnected: code={e.code}")
@app.get("/")
async def get_test_page():
    """Return the static test page telling the user to check the console."""
    page_body = """
Check console for WebSocket messages
"""
    return HTMLResponse(page_body)
Connection Manager: Broadcasting
A connection manager tracks all active WebSocket connections and provides broadcast methods. This is the core data structure for any multi-user real-time application. Thread safety is not an issue because FastAPI runs these handlers on asyncio's single-threaded event loop, but coroutines can still interleave at every await, which is why the manager guards its connection list with an asyncio.Lock. You must also handle WebSocketDisconnect promptly to remove stale connections.
import asyncio
from typing import Optional
class ConnectionManager:
    """Manages active WebSocket connections with broadcast support.

    Changes to ``active_connections`` are made while holding an
    ``asyncio.Lock``. A connection whose send fails is removed from the list.
    """

    def __init__(self):
        self.active_connections: list[WebSocket] = []
        self._lock = asyncio.Lock()

    async def connect(self, websocket: WebSocket):
        """Accept the handshake, then register the connection."""
        await websocket.accept()
        async with self._lock:
            self.active_connections.append(websocket)

    async def disconnect(self, websocket: WebSocket):
        """Unregister the connection; unknown connections are ignored."""
        async with self._lock:
            try:
                self.active_connections.remove(websocket)
            except ValueError:
                # Already removed (e.g. dropped by a failed broadcast).
                pass

    async def _try_send(self, websocket: WebSocket, message: dict) -> bool:
        """Send `message`; return False instead of raising when the send fails."""
        try:
            await websocket.send_json(message)
        except Exception:
            # Best effort: any send failure marks the connection as dead.
            return False
        return True

    async def send_personal(self, message: dict, websocket: WebSocket):
        """Send `message` to one connection, dropping it if the send fails."""
        if not await self._try_send(websocket, message):
            await self.disconnect(websocket)

    async def broadcast(self, message: dict, exclude: Optional[WebSocket] = None):
        """Send message to all connected clients except `exclude`.

        Sends run concurrently so one slow client does not delay the others.
        Connections whose send fails are removed afterwards.
        """
        # Snapshot under the lock, send outside it so connect/disconnect
        # are not blocked for the duration of the network writes.
        async with self._lock:
            targets = [ws for ws in self.active_connections if ws is not exclude]
        results = await asyncio.gather(
            *(self._try_send(ws, message) for ws in targets)
        )
        for ws, delivered in zip(targets, results):
            if not delivered:
                await self.disconnect(ws)

    @property
    def connection_count(self) -> int:
        """Number of currently registered connections."""
        return len(self.active_connections)
manager = ConnectionManager()


@app.websocket("/ws/chat")
async def chat_endpoint(websocket: WebSocket, username: str = "Anonymous"):
    """Chat room endpoint: relay each client's messages to everyone else.

    Announces the join to all clients, then forwards every received
    ``content`` to all other clients. On exit the connection is unregistered
    and a "left" notice is broadcast.
    """
    # Function-scope import instead of an inline __import__('time') call.
    import time

    await manager.connect(websocket)
    try:
        await manager.broadcast({
            "type": "system",
            "message": f"{username} joined the chat",
            "users_online": manager.connection_count,
        })
        while True:
            data = await websocket.receive_json()
            await manager.broadcast({
                "type": "message",
                "from": username,
                "content": data.get("content", ""),
                "timestamp": time.time(),
            }, exclude=websocket)
    except WebSocketDisconnect:
        # Normal termination; cleanup happens in `finally`.
        pass
    finally:
        # Runs for every exit path (disconnect, malformed payload,
        # cancellation) so a failed handler never leaves a stale connection.
        await manager.disconnect(websocket)
        await manager.broadcast({
            "type": "system",
            "message": f"{username} left the chat",
            "users_online": manager.connection_count,
        })
Rooms and Channels
Rooms (also called channels) group connections so broadcasts only go to members of that room. This pattern enables chat rooms, collaborative document editing where each document is a room, and dashboards where each tenant sees only their data.
from collections import defaultdict
class RoomManager:
    """WebSocket room management with per-room broadcasting.

    ``rooms`` maps a room id to its member connections.
    ``connection_rooms`` maps each connection back to the room it joined.
    """

    def __init__(self):
        self.rooms: dict[str, set[WebSocket]] = defaultdict(set)
        self.connection_rooms: dict[WebSocket, str] = {}

    async def join(self, room_id: str, websocket: WebSocket):
        """Accept the connection, add it to `room_id`, and notify the other members."""
        await websocket.accept()
        self.rooms[room_id].add(websocket)
        self.connection_rooms[websocket] = room_id
        count = len(self.rooms[room_id])
        await self.broadcast_to_room(room_id, {
            "type": "user_joined",
            "room": room_id,
            "members": count,
        }, exclude=websocket)

    async def leave(self, websocket: WebSocket):
        """Remove the connection from its room and notify the remaining members.

        Does nothing when the connection is not registered. An emptied room
        is deleted.
        """
        room_id = self.connection_rooms.pop(websocket, None)
        # Compare against None explicitly: "" is a legal room id and must
        # still be cleaned up.
        if room_id is None:
            return
        members = self.rooms[room_id]
        members.discard(websocket)
        if not members:
            del self.rooms[room_id]
            return
        await self.broadcast_to_room(room_id, {
            "type": "user_left",
            "room": room_id,
            "members": len(members),
        })

    async def broadcast_to_room(self, room_id: str, message: dict, exclude: Optional[WebSocket] = None):
        """Send `message` to every member of `room_id` except `exclude`.

        Members whose send fails are removed via `leave`, which in turn
        notifies the members still in the room.
        """
        dead = []
        # Iterate over a snapshot: `leave` may mutate the set during awaits.
        for ws in list(self.rooms.get(room_id, ())):
            if ws is exclude:
                continue
            try:
                await ws.send_json(message)
            except Exception:
                # Best effort: any send failure marks the member as dead.
                dead.append(ws)
        for ws in dead:
            await self.leave(ws)

    def get_rooms(self) -> dict[str, int]:
        """Return a mapping of room id to current member count."""
        return {room: len(members) for room, members in self.rooms.items()}
room_manager = RoomManager()


@app.websocket("/ws/room/{room_id}")
async def room_endpoint(websocket: WebSocket, room_id: str):
    """Join `room_id` and relay each received message to the other members.

    The connection is removed from the room on every exit path, not only
    on a clean disconnect.
    """
    await room_manager.join(room_id, websocket)
    try:
        while True:
            data = await websocket.receive_json()
            await room_manager.broadcast_to_room(room_id, {
                "type": "message",
                "room": room_id,
                "content": data.get("content", ""),
            }, exclude=websocket)
    except WebSocketDisconnect:
        # Normal termination; cleanup happens in `finally`.
        pass
    finally:
        # Also covers malformed payloads and cancellation, which would
        # otherwise leave a stale member in the room.
        await room_manager.leave(websocket)
WebSocket Authentication
The browser WebSocket API doesn't let you set custom headers (such as Authorization) on the connection request, so JWT tokens must be passed as query parameters or as the first message after connection. The query parameter approach is simpler; the first-message approach avoids tokens appearing in server logs.
import jwt
from fastapi import WebSocket, WebSocketDisconnect, Query, status
async def get_ws_user(token: str) -> Optional[dict]:
    """Validate a JWT for WebSocket authentication.

    Returns:
        ``{"user_id": ..., "name": ...}`` built from the ``sub`` and ``name``
        claims (``name`` defaults to ``"User"``), or ``None`` when the token
        is invalid or carries no ``sub`` claim.
    """
    try:
        # NOTE(review): hard-coded signing key — load it from configuration
        # before using this outside a demo.
        payload = jwt.decode(token, "secret", algorithms=["HS256"])
        user_id = payload["sub"]
    except (jwt.InvalidTokenError, KeyError):
        # KeyError: a correctly signed token without "sub" is treated as
        # unauthenticated rather than crashing the handler.
        return None
    return {"user_id": user_id, "name": payload.get("name", "User")}
# Method 1: Token as query parameter
@app.websocket("/ws/secure")
async def secure_ws(websocket: WebSocket, token: str = Query(None)):
    """Echo endpoint that authenticates via the ``token`` query parameter.

    Without a token that `get_ws_user` accepts, the socket is closed with
    the policy-violation status before being accepted. Otherwise each
    received message's ``content`` is echoed back, tagged with the user's name.
    """
    user = None
    if token:
        user = await get_ws_user(token)
    if not user:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await websocket.accept()
    display_name = user["name"]
    await websocket.send_json({"type": "auth_success", "user": display_name})
    try:
        while True:
            payload = await websocket.receive_json()
            reply = {
                "type": "echo",
                "from": display_name,
                "content": payload.get("content"),
            }
            await websocket.send_json(reply)
    except WebSocketDisconnect:
        pass
# Method 2: Authenticate via first message (token not in URL/logs)
@app.websocket("/ws/auth-first")
async def auth_first_ws(websocket: WebSocket):
    """Accept the socket, then require an ``auth`` message within 10 seconds.

    The first message must be ``{"type": "auth", "token": ...}`` with a token
    that `get_ws_user` accepts. Otherwise an error is sent where possible and
    the socket is closed with the policy-violation status. After
    authentication, every received message is acknowledged with the user's name.
    """
    await websocket.accept()  # Accept but don't process yet
    await websocket.send_json({"type": "auth_required"})

    # Wait for auth message
    try:
        auth_data = await asyncio.wait_for(websocket.receive_json(), timeout=10.0)
    except WebSocketDisconnect:
        # The peer is already gone; there is nothing left to close.
        return
    except asyncio.TimeoutError:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return
    except ValueError:
        # The first frame was not valid JSON (JSONDecodeError is a ValueError).
        await websocket.send_json({"type": "error", "message": "Auth required"})
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    # A JSON list or scalar has no .get(); reject it like a missing auth.
    is_auth_message = (
        isinstance(auth_data, dict)
        and auth_data.get("type") == "auth"
        and bool(auth_data.get("token"))
    )
    if not is_auth_message:
        await websocket.send_json({"type": "error", "message": "Auth required"})
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    user = await get_ws_user(auth_data["token"])
    if not user:
        await websocket.send_json({"type": "error", "message": "Invalid token"})
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await websocket.send_json({"type": "auth_success"})

    # Proceed with authenticated session
    try:
        while True:
            # The payload itself is not used; each message is just acknowledged.
            await websocket.receive_json()
            await websocket.send_json({"type": "ack", "user": user["name"]})
    except WebSocketDisconnect:
        pass
Redis Pub/Sub for Multi-Server
When running multiple server instances behind a load balancer, each server has its own in-memory connection set. A message sent to a connection on server A never reaches clients on server B. Redis Pub/Sub solves this: servers publish to Redis channels, and each server subscribes and forwards messages to its local connections.
import asyncio
import json
import redis.asyncio as aioredis
class RedisBackedManager:
    """WebSocket manager backed by Redis Pub/Sub for multi-instance support.

    Exactly one subscriber task is kept per channel that has local
    connections. Starting one subscriber per connection would make every
    subscriber forward each message to all local sockets, so a channel with
    N clients would deliver each message N times.
    """

    def __init__(self, redis_url: str):
        self.redis = aioredis.from_url(redis_url)
        self.local_connections: dict[str, set[WebSocket]] = defaultdict(set)
        # channel -> background task running subscribe_and_forward(channel)
        self._subscribers: dict[str, asyncio.Task] = {}

    async def connect(self, channel: str, websocket: WebSocket):
        """Accept the socket, register it, and make sure the channel has a subscriber."""
        await websocket.accept()
        self.local_connections[channel].add(websocket)
        task = self._subscribers.get(channel)
        if task is None or task.done():
            self._subscribers[channel] = asyncio.create_task(
                self.subscribe_and_forward(channel)
            )

    async def disconnect(self, channel: str, websocket: WebSocket):
        """Unregister the socket; stop the channel's subscriber when it was the last one."""
        members = self.local_connections.get(channel)
        if members is None:
            return
        members.discard(websocket)
        if members:
            return
        # Last local client left: drop the empty set and stop listening.
        del self.local_connections[channel]
        task = self._subscribers.pop(channel, None)
        if task is not None:
            task.cancel()

    async def publish(self, channel: str, message: dict):
        """Publish to Redis — all server instances will receive this."""
        await self.redis.publish(channel, json.dumps(message))

    async def subscribe_and_forward(self, channel: str):
        """Subscribe to Redis channel and forward messages to local WebSockets.

        Runs until cancelled. The pubsub object is released on exit.
        """
        # `async with` releases the pubsub connection when the task is cancelled.
        async with self.redis.pubsub() as pubsub:
            await pubsub.subscribe(channel)
            async for redis_message in pubsub.listen():
                if redis_message["type"] != "message":
                    continue
                try:
                    data = json.loads(redis_message["data"])
                except ValueError:
                    # Skip a non-JSON payload rather than killing the subscriber.
                    continue
                dead = []
                for ws in list(self.local_connections.get(channel, ())):
                    try:
                        await ws.send_json(data)
                    except Exception:
                        # Best effort: any send failure marks the socket as dead.
                        dead.append(ws)
                for ws in dead:
                    await self.disconnect(channel, ws)


redis_manager = RedisBackedManager("redis://localhost:6379")


@app.websocket("/ws/global/{channel}")
async def global_channel(websocket: WebSocket, channel: str):
    """Relay each received message through Redis to every instance's clients on `channel`.

    The Redis subscriber is owned by `redis_manager` and started in
    `connect`, so this handler only publishes.
    """
    await redis_manager.connect(channel, websocket)
    try:
        while True:
            data = await websocket.receive_json()
            # Publish through Redis so ALL server instances get it
            await redis_manager.publish(channel, {
                "content": data.get("content"),
                "channel": channel,
            })
    except WebSocketDisconnect:
        # Normal termination; cleanup happens in `finally`.
        pass
    finally:
        # Runs on every exit path so the socket is always unregistered.
        await redis_manager.disconnect(channel, websocket)
Browser Client and Reconnection
Production WebSocket clients must handle disconnections gracefully with automatic reconnection and exponential backoff. Network interruptions, server restarts, and load balancer timeouts all cause WebSocket disconnections. A robust client reconnects automatically without user intervention.
// Robust WebSocket client with exponential backoff reconnection
class ReconnectingWebSocket {
  /**
   * Open a socket to `url` and reconnect after unclean closes.
   *
   * @param {string} url - WebSocket URL to connect to.
   * @param {{maxRetries?: number, baseDelay?: number, maxDelay?: number}} [options]
   *   maxRetries: reconnect attempts before giving up (default 10).
   *   baseDelay: first backoff delay in ms (default 1000).
   *   maxDelay: upper bound for the backoff delay in ms (default 30000).
   */
  constructor(url, options = {}) {
    this.url = url;
    // `??` instead of `||` so an explicit 0 (e.g. maxRetries: 0) is honoured.
    this.maxRetries = options.maxRetries ?? 10;
    this.baseDelay = options.baseDelay ?? 1000;
    this.maxDelay = options.maxDelay ?? 30000;
    this.retries = 0;
    this.ws = null;
    this.messageQueue = [];
    this.reconnectTimer = null; // id of the pending backoff timer, if any
    this.closedByUser = false; // set by close() to stop further reconnects
    this.connect();
  }

  /** Create the underlying socket and attach the event handlers. */
  connect() {
    this.reconnectTimer = null;
    this.ws = new WebSocket(this.url);

    this.ws.onopen = () => {
      console.log("WebSocket connected");
      this.retries = 0;
      // Flush queued messages
      while (this.messageQueue.length) {
        this.ws.send(this.messageQueue.shift());
      }
      this.onopen?.();
    };

    this.ws.onmessage = (event) => {
      let data;
      try {
        data = JSON.parse(event.data);
      } catch (err) {
        // A non-JSON frame is logged and dropped instead of throwing here.
        console.error("Discarding non-JSON WebSocket message:", err);
        return;
      }
      this.onmessage?.(data);
    };

    this.ws.onclose = (event) => {
      if (this.closedByUser || event.wasClean || this.retries >= this.maxRetries) {
        return;
      }
      const delay = Math.min(
        this.baseDelay * Math.pow(2, this.retries),
        this.maxDelay
      );
      this.retries++;
      console.log(`Reconnecting in ${delay}ms (attempt ${this.retries})`);
      this.reconnectTimer = setTimeout(() => this.connect(), delay);
    };

    this.ws.onerror = (error) => console.error("WebSocket error:", error);
  }

  /** Send `data` as JSON now, or queue it until the socket is open again. */
  send(data) {
    const message = JSON.stringify(data);
    if (this.ws.readyState === WebSocket.OPEN) {
      this.ws.send(message);
    } else {
      this.messageQueue.push(message); // Queue for when reconnected
    }
  }

  /** Close the socket for good and cancel any scheduled reconnect. */
  close() {
    this.closedByUser = true;
    this.maxRetries = 0;
    // Without this, a close() during the backoff wait would still reconnect.
    if (this.reconnectTimer !== null) {
      clearTimeout(this.reconnectTimer);
      this.reconnectTimer = null;
    }
    this.ws.close();
  }
}
// Usage: open the chat socket, log every parsed message, and send a greeting
// (send() queues the greeting if the socket is not open yet).
const ws = new ReconnectingWebSocket("wss://api.techoral.com/ws/chat");

function logIncoming(data) {
  console.log("Received:", data);
}

ws.onmessage = logIncoming;
ws.send({ type: "message", content: "Hello!" });
Production Patterns
Production WebSocket services need heartbeat pings to detect dead connections, message rate limiting to prevent abuse, proper error handling for all disconnect codes, and Nginx/load balancer configuration for WebSocket proxying. FastAPI's lifespan events are the right place to initialize shared resources like Redis connections.
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
# Heartbeat to detect dead connections
async def heartbeat(websocket: WebSocket, interval: float = 30):
    """Send ``{"type": "ping"}`` every `interval` seconds until a send fails.

    This coroutine does not track pong replies; a dead connection is detected
    only when sending the ping raises, which ends the loop. It otherwise runs
    until the task is cancelled.

    Args:
        websocket: Connection to ping.
        interval: Seconds to sleep before each ping.
    """
    while True:
        await asyncio.sleep(interval)
        try:
            await websocket.send_json({"type": "ping"})
        except Exception:
            # Best effort: any send failure means the socket is unusable.
            break
@app.websocket("/ws/with-heartbeat")
async def heartbeat_ws(websocket: WebSocket):
    """Echo endpoint that runs a background `heartbeat` task for the connection.

    ``pong`` messages are consumed silently; any other message is echoed
    back under the ``echo`` key. The heartbeat task is cancelled when the
    handler exits.
    """
    await websocket.accept()
    pinger = asyncio.create_task(heartbeat(websocket))
    try:
        while True:
            incoming = await websocket.receive_json()
            # A "pong" only acknowledges the heartbeat; nothing to echo.
            if incoming.get("type") != "pong":
                await websocket.send_json({"echo": incoming})
    except WebSocketDisconnect:
        pass
    finally:
        pinger.cancel()
# Nginx config for WebSocket proxying:
"""
location /ws/ {
proxy_pass http://localhost:8000;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
proxy_read_timeout 3600s; # Keep alive for 1 hour
proxy_send_timeout 3600s;
}
"""