#!/usr/bin/env python3
"""
lb_proxy.py — minimal streaming reverse proxy / load balancer for the two
vLLM 122B replicas.

Why not nginx: this needs to stream Server-Sent Events (token-by-token chat
completions) without buffering, route only to replicas that are actually up,
and have zero install footprint. A ~120-line asyncio proxy does exactly that
and runs inside the existing vLLM container (which has Python + the stdlib; we
also use httpx, starlette and uvicorn, which ship in the vLLM image).

Routing: least-busy of the healthy replicas. We track in-flight requests per
backend and send each new request to the one with the fewest. This beats blind
round-robin when one user fires a huge 128k prompt — the other replica keeps
serving short requests.

Health: a backend that fails to connect is marked down and skipped; we
re-probe it on the next request after a short cooldown.

Env vars:
  PUBLIC_PORT   public port to listen on (default 7080)
  BACKENDS      comma-separated backend base URLs
                (default http://127.0.0.1:7091,http://127.0.0.1:7092)
"""
import asyncio
import os
import time

import httpx
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse, JSONResponse
from starlette.routing import Route
import uvicorn

# Backend base URLs, in configuration order. Blank entries (e.g. from a
# trailing comma) are ignored.
BACKENDS = [
    b.strip()
    for b in os.environ.get(
        "BACKENDS", "http://127.0.0.1:7091,http://127.0.0.1:7092"
    ).split(",")
    if b.strip()
]
PUBLIC_PORT = int(os.environ.get("PUBLIC_PORT", "7080"))
COOLDOWN_S = 10.0  # how long to keep a backend marked down before retry

# Per-backend state: in-flight count + "down until" timestamp.
# The timestamps are compared against time.monotonic(), so 0.0 means "up".
inflight = {b: 0 for b in BACKENDS}
down_until = {b: 0.0 for b in BACKENDS}

# Long timeout: 128k-context generations can run for minutes.
# Shared upstream client. Only the connect phase is bounded (10 s); read,
# write and pool waits are unlimited so a slow token stream is never cut off.
client = httpx.AsyncClient(timeout=httpx.Timeout(None, connect=10.0))

# Headers that describe a single connection rather than the message itself;
# a proxy must not forward them in either direction.
_HOP_BY_HOP = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)
# Host and Content-Length are dropped from the request as well: httpx sets
# both itself from the target URL and the buffered body.
_DROP_REQUEST_HEADERS = _HOP_BY_HOP | {"host", "content-length"}
# Content-Length is dropped from the response because the body is re-streamed
# chunk by chunk.
_DROP_RESPONSE_HEADERS = _HOP_BY_HOP | {"content-length"}


def pick_backend(exclude: set[str] | frozenset[str] = frozenset()) -> str | None:
    """Return the least-busy usable backend, or None when there is none.

    Args:
        exclude: backend URLs that must not be returned (for example the ones
            already tried for the current request).

    Returns:
        The candidate with the fewest in-flight requests among those not
        currently in cooldown. If every candidate is cooling down, the
        least-busy of all candidates is returned anyway so that it gets
        re-probed. None if BACKENDS is empty or fully excluded.
    """
    candidates = [b for b in BACKENDS if b not in exclude]
    if not candidates:
        return None
    now = time.monotonic()
    healthy = [b for b in candidates if down_until[b] <= now]
    # Ties are broken by configuration order (min() keeps the first minimum).
    return min(healthy or candidates, key=lambda b: inflight[b])


async def proxy(request: Request) -> Response:
    """Forward the request to a backend and stream the response back.

    Each backend is attempted at most once per request. A backend that cannot
    be connected to is put in cooldown for COOLDOWN_S seconds and the next
    candidate is tried.

    Returns:
        A StreamingResponse relaying the upstream status, headers and raw
        body; a 503 JSON error when no backend is configured; a 502 JSON
        error when every backend refused or timed out on connect.

    Raises:
        Any non-connect error from the upstream call is propagated unchanged
        (the in-flight counter is released first).
    """
    body = await request.body()

    # Drop hop-by-hop / host headers; pass the rest (incl. Authorization).
    headers = {
        k: v
        for k, v in request.headers.items()
        if k.lower() not in _DROP_REQUEST_HEADERS
    }

    target = request.url.path
    if request.url.query:
        target += "?" + request.url.query

    tried: set[str] = set()
    # pick_backend() returns None once every backend is in `tried`, which
    # bounds this loop to len(BACKENDS) attempts.
    while (backend := pick_backend(tried)) is not None:
        tried.add(backend)
        inflight[backend] += 1
        try:
            req = client.build_request(
                request.method,
                # Tolerate a base URL configured with a trailing slash.
                backend.rstrip("/") + target,
                headers=headers,
                content=body,
            )
            upstream = await client.send(req, stream=True)
        except (httpx.ConnectError, httpx.ConnectTimeout):
            inflight[backend] -= 1
            down_until[backend] = time.monotonic() + COOLDOWN_S
            continue
        except BaseException:
            # Any other failure (including task cancellation) must not leave
            # the counter permanently incremented.
            inflight[backend] -= 1
            raise
        break
    else:
        # The loop ended without a successful connection.
        if not tried:
            return JSONResponse(
                {"error": "no backend available"}, status_code=503
            )
        return JSONResponse({"error": "backend unreachable"}, status_code=502)

    resp_headers = {
        k: v
        for k, v in upstream.headers.items()
        if k.lower() not in _DROP_RESPONSE_HEADERS
    }

    async def stream():
        try:
            # Raw bytes: no decoding, so Content-Encoding stays valid.
            async for chunk in upstream.aiter_raw():
                yield chunk
        finally:
            # Release the slot before awaiting the close, so that an error or
            # cancellation during aclose() cannot skip the decrement.
            inflight[backend] -= 1
            await upstream.aclose()

    return StreamingResponse(
        stream(),
        status_code=upstream.status_code,
        headers=resp_headers,
        media_type=upstream.headers.get("content-type"),
    )


async def health(request: Request) -> Response:
    """Report each backend's URL, in-flight count and up/down state as JSON."""
    now = time.monotonic()
    return JSONResponse(
        {
            "backends": [
                {
                    "url": b,
                    "inflight": inflight[b],
                    "up": down_until[b] <= now,
                }
                for b in BACKENDS
            ]
        }
    )


# The health route is listed first so the catch-all does not shadow it.
app = Starlette(
    routes=[
        Route("/_lb_health", health, methods=["GET"]),
        Route("/{path:path}", proxy, methods=["GET", "POST", "PUT", "DELETE"]),
    ]
)


if __name__ == "__main__":
    print(f"Load balancer on :{PUBLIC_PORT} -> {BACKENDS}", flush=True)
    uvicorn.run(app, host="0.0.0.0", port=PUBLIC_PORT, log_level="warning")