TPU Pod Optical Interconnect System
Complete technical deep-dive into Google TPU Pod's optical circuit switching architecture, 3D torus topology, collective communication optimization, and datacenter-scale AI infrastructure
Table of Contents
- Overview and Motivation
- TPU Architecture Primer
- Optical Circuit Switching Fundamentals
- 3D Torus Network Topology
- Physical Layer Technology
- Link Layer and Protocol Stack
- Collective Communication Optimization
- Evolution Across TPU Generations
- Performance Analysis
- Power and Thermal Considerations
- Fault Tolerance and Reliability
- Programming Model and Software Stack
- Industry Impact and Research Directions
1. Overview and Motivation
1.1 What is a TPU Pod?
A TPU Pod is Google's solution for scaling Tensor Processing Units from single chips to supercomputer-scale systems capable of training the largest machine learning models in production. The Pod's optical interconnect fabric is the critical enabler that allows thousands of specialized AI accelerators to work cohesively as a single, unified computational resource.
1.2 The Scaling Challenge
Why Traditional Interconnects Fall Short:
Modern deep learning training presents unique challenges:
- Bulk-Synchronous Parallelism (BSP):
  - All accelerators must synchronize at layer boundaries
  - Requires global all-reduce operations every few milliseconds
  - Stragglers kill performance (tail latency problem)
- Bandwidth Requirements:
  - Large language models (LLMs): 100+ billion parameters
  - Parameter updates: Multiple GB per training step
  - Activation passing: Tens of GB for model parallelism
  - Total: 10-100 Tbps aggregate bandwidth needed
- Power Constraints:
  - Accelerators: 300-500W each
  - Traditional interconnects: 50-100W per device
  - At 1000 devices: Interconnect power = 50-100 kW
  - Goal: Reduce to <20% of total power budget
- Scalability:
  - Need to scale from 256 to 4,096+ accelerators
  - Traditional electrical I/O: bandwidth-per-watt plateaus
  - Multi-tier networks: bandwidth cliffs at boundaries
1.3 Google's Solution: Optical Circuit Switching
Key Insight: AI training workloads are:
- Predictable: Communication patterns known a priori
- Regular: Structured collectives (all-reduce, all-gather)
- Deterministic: Same operations every iteration
- Bulk: Large data transfers, not tiny packets
Design Choice: Circuit-switched optical network
- No packet headers (wasted bandwidth)
- No store-and-forward latency
- No congestion or contention
- Deterministic performance
- Superior power efficiency
2. TPU Architecture Primer
2.1 Single TPU Chip Architecture
Before understanding the interconnect, we need to understand what we're connecting:
┌─────────────────────────────────────────────────┐
│ TPU v4 Chip │
├─────────────────────────────────────────────────┤
│ ┌─────────────┐ ┌─────────────┐ │
│ │ MXU 0 │ │ MXU 1 │ Matrix │
│ │ 128×128 │ │ 128×128 │ Multiply │
│ │ Systolic │ │ Systolic │ Units │
│ └─────────────┘ └─────────────┘ │
│ │ │ │
│ ┌──────┴─────────────────┴──────┐ │
│ │ Vector Units (VPU) │ │
│ │ • FP32/BF16/INT8 operations │ │
│ │ • 16K elements/cycle │ │
│ └───────────────────────────────┘ │
│ │ │
│ ┌──────┴─────────────────────────┐ │
│ │ High Bandwidth Memory (HBM) │ │
│ │ • 32 GB HBM2e │ │
│ │ • 1.2 TB/s bandwidth │ │
│ └────────────────────────────────┘ │
│ │ │
│ ┌──────┴─────────────────────────┐ │
│ │ Interconnect Controllers │ │
│ │ • 6× Optical Links │ ◄─────────┐│
│ │ • Router logic │ I │
│ │ • DMA engines │ C │
│ └────────────────────────────────┘ I │
└─────────────────────────────────────────────────┘
Key Specifications (TPU v4):
- Peak compute: 275 TFLOPS (BF16)
- Memory: 32 GB HBM2e @ 1.2 TB/s
- Interconnect: 6 bidirectional optical links
- Power: ~175W compute + ~25W interconnect
2.2 Interconnect Requirements
Per-chip Bandwidth Math:
Model size: 175B parameters (GPT-3 scale)
4096 TPUs, sharded across all chips
Per-chip: 175B / 4096 = 42.7M parameters
All-reduce requirements per iteration:
- Forward pass: Receive activations
- Backward pass: Send gradients
- Parameter sync: All-reduce (~171 MB per chip)
Target: 500 iterations/second
Required BW: 171 MB × 500 = 85 GB/s bidirectional
Actual: 6 links × 15 GB/s = 90 GB/s ✓
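The same sizing exercise, as a quick script; the figures are the illustrative numbers above, not published specifications:
def required_allreduce_bw(params, num_chips, bytes_per_param=4, iters_per_sec=500):
    """Rough per-chip bandwidth needed to all-reduce a sharded gradient."""
    shard_bytes = params * bytes_per_param / num_chips  # per-chip gradient shard
    return shard_bytes * iters_per_sec                  # bytes per second

bw = required_allreduce_bw(175e9, 4096)
print(f"Required: {bw / 1e9:.1f} GB/s vs. available: {6 * 15} GB/s")
# Required: 85.4 GB/s vs. available: 90 GB/s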
3. Optical Circuit Switching Fundamentals
3.1 Circuit Switching vs. Packet Switching
Packet Switching (Traditional):
Source → [Header|Payload] → Router (store, forward) → Destination
Latency = Serialization + Propagation + Queueing + Switching
= (Packet_size/BW) + (Distance/c) + (Queue_depth × Packet_time) + Switch_latency
Overhead:
- Headers: 5-10% of bandwidth
- Buffering: SRAM, power hungry
- Arbitration: Complex, non-deterministic
Circuit Switching (TPU Pod OCS):
Source → Optical Switch (passive) → Destination
Latency = Serialization + Propagation + Circuit_setup (one-time)
= (Message_size/BW) + (Distance/c) + Setup_latency
Advantages:
- No packet overhead (100% efficiency)
- No buffering needed
- Deterministic latency
- Lower power
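As a rough illustration of the two latency equations above, here is a toy comparison; all constants (queue depth, per-hop switch latency, circuit setup time) are placeholders rather than measured values:
def packet_switched_latency(msg_bytes, bw, distance_m, hops,
                            queue_depth=4, switch_latency=500e-9):
    """Store-and-forward: every hop re-serializes and may queue (toy model)."""
    c_fiber = 2e8                               # ~2/3 of c in fiber, m/s
    serialization = msg_bytes / bw
    propagation = distance_m / c_fiber
    queueing = hops * queue_depth * serialization   # queueing modeled per hop
    return serialization + propagation + queueing + hops * switch_latency

def circuit_switched_latency(msg_bytes, bw, distance_m, setup=10e-3, reuse=True):
    """Optical circuit: pay setup once, then pure serialization + propagation."""
    c_fiber = 2e8
    return msg_bytes / bw + distance_m / c_fiber + (0 if reuse else setup)

# 1 MB message over 18.75 GB/s, 20 m of fiber, 3 intermediate hops
print(packet_switched_latency(1e6, 18.75e9, 20, 3))   # ≈ 0.7 ms (queueing dominates)
print(circuit_switched_latency(1e6, 18.75e9, 20))     # ≈ 0.05 ms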
3.2 Optical Circuit Switch Technology
MEMS (Micro-Electro-Mechanical Systems) Switches:
Early TPU Pods likely used MEMS-based OCS:
Input Fibers Mirrors (actuated) Output Fibers
│ ╱ │
──────┼──────────────────────●──────────────────────────┼──────
│ ╱ ╲ │
──────┼────────────────────●───●────────────────────────┼──────
│ ╱ ╲ │
──────┼──────────────────●───────●──────────────────────┼──────
Characteristics:
- Switch time: 10-20 ms (slow, but amortized over long data transfers)
- Port count: 320×320 or larger
- Insertion loss: <2 dB
- Power: ~1-2W per 100 ports (passive after switching)
Silicon Photonics Switches (Newer TPU Pods):
Ring Resonators (electrically tuned)
In ────●───●───●───●──── Out 1
│ │ │ │
●───●───●───●──── Out 2
Characteristics:
- Switch time: <1 μs (thermal tuning) or <10 ns (electro-optic)
- Port count: Limited (32×32), but can cascade
- Insertion loss: 1-3 dB per stage
- Power: ~10-50 mW per active port
3.3 Reconfigurability
Static vs. Dynamic Configuration:
# Pseudo-code for circuit configuration
# Static configuration (set at Pod initialization)
def configure_static_topology(pod_size):
"""Set up 3D torus"""
for chip_id in range(pod_size):
x, y, z = chip_id_to_coords(chip_id)
neighbors = {
'+X': coords_to_chip_id((x+1) % X_DIM, y, z),
'-X': coords_to_chip_id((x-1) % X_DIM, y, z),
'+Y': coords_to_chip_id(x, (y+1) % Y_DIM, z),
'-Y': coords_to_chip_id(x, (y-1) % Y_DIM, z),
'+Z': coords_to_chip_id(x, y, (z+1) % Z_DIM),
'-Z': coords_to_chip_id(x, y, (z-1) % Z_DIM),
}
configure_optical_circuits(chip_id, neighbors)
# Dynamic reconfiguration (for fault tolerance)
def reconfigure_around_failure(failed_chip):
"""Reroute around failed chip"""
affected_neighbors = get_neighbors(failed_chip)
for neighbor in affected_neighbors:
# Find alternate path (e.g., skip failed chip)
new_route = find_alternate_path(neighbor, failed_chip)
reconfigure_optical_switch(neighbor, new_route)
Use Cases for Reconfigurability:
- Fault tolerance: Route around failed chips/links
- Multi-tenancy: Partition Pod into independent slices
- Adaptive topology: Switch between 3D torus and fat-tree
- Power management: Power down unused links
4. 3D Torus Network Topology
4.1 Topology Definition
3D Torus Structure:
Z-axis
│
│
(0,0,1)───(1,0,1)
╱│ ╱│
╱ │ ╱ │
(0,1,1)───(1,1,1)
│ │ │ │
│(0,0,0)┼(1,0,0)
│ ╱ │╱
│╱ │
(0,1,0)───(1,1,0)──── Y-axis
╱
╱
X-axis
Wrap-around connections (torus):
- (X_DIM-1, y, z) ↔ (0, y, z)
- (x, Y_DIM-1, z) ↔ (x, 0, z)
- (x, y, Z_DIM-1) ↔ (x, y, 0)
For TPU v4 Pod (4,096 chips):
- Dimensions: 16×16×16
- Links per chip: 6 (±X, ±Y, ±Z)
- Total links: 4,096 × 3 = 12,288 bidirectional links
- Diameter: 24 hops (at most 8 hops per dimension, thanks to the wraparound links)
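One simple way to make the wraparound concrete is a linear chip-ID convention; the sketch below defines the chip_id_to_coords / coords_to_chip_id helpers that the configuration pseudo-code in Section 3.3 assumes (the x-major ordering is an assumption for illustration, not Google's actual numbering):
X_DIM = Y_DIM = Z_DIM = 16   # 16×16×16 = 4,096 chips

def chip_id_to_coords(chip_id):
    """Map a linear chip ID to (x, y, z) torus coordinates (x-major order)."""
    x = chip_id % X_DIM
    y = (chip_id // X_DIM) % Y_DIM
    z = chip_id // (X_DIM * Y_DIM)
    return (x, y, z)

def coords_to_chip_id(x, y, z):
    return x + y * X_DIM + z * X_DIM * Y_DIM

def torus_neighbors(chip_id):
    """The six neighbors (±X, ±Y, ±Z), with wraparound."""
    x, y, z = chip_id_to_coords(chip_id)
    return {
        '+X': coords_to_chip_id((x + 1) % X_DIM, y, z),
        '-X': coords_to_chip_id((x - 1) % X_DIM, y, z),
        '+Y': coords_to_chip_id(x, (y + 1) % Y_DIM, z),
        '-Y': coords_to_chip_id(x, (y - 1) % Y_DIM, z),
        '+Z': coords_to_chip_id(x, y, (z + 1) % Z_DIM),
        '-Z': coords_to_chip_id(x, y, (z - 1) % Z_DIM),
    }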
4.2 Topology Properties
Degree:
- Every chip has exactly 6 neighbors
- Advantage: Constant per-chip complexity regardless of Pod size
Diameter:
- D = 3 × ⌊k/2⌋ for a k×k×k torus (N = k³ chips)
- TPU v4: 3 × ⌊16/2⌋ = 24 hops maximum (8 hops per dimension)
- Advantage: Diameter grows only with the cube root of the chip count
Bisection Bandwidth:
Bisection: minimum link capacity crossing a cut that splits the Pod in half
For a 16×16×16 torus, cut along one dimension:
Cut between the X=7 and X=8 planes: 16×16 = 256 links, doubled to 512 by the wraparound links
Bandwidth per link: ~150 Gbps
Total bisection BW: 512 × 150 Gbps = 76.8 Tbps
Compare to a fat-tree with 2:1 oversubscription over the same 512 ports:
512 × 75 Gbps = 38.4 Tbps
Torus: 2× better bisection bandwidth
Path Diversity:
- Multiple paths between any two nodes
- Example: From (0,0,0) to (8,8,8):
- Path 1: +X(8), +Y(8), +Z(8)
- Path 2: +Y(8), +Z(8), +X(8)
- Path 3: +Z(8), +X(8), +Y(8)
- Many more...
- Advantage: Load balancing and fault tolerance
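These properties are easy to compute for any k×k×k torus; a small sketch using the per-link bandwidth estimate from this section:
def torus_properties(k, link_gbps=150):
    """Degree, diameter, and bisection bandwidth of a k×k×k torus."""
    num_chips = k ** 3
    degree = 6
    diameter = 3 * (k // 2)            # worst case: k//2 hops per dimension
    bisection_links = 2 * k * k        # one cut plane; wraparound doubles it
    bisection_tbps = bisection_links * link_gbps / 1000
    return num_chips, degree, diameter, bisection_tbps

print(torus_properties(16))   # (4096, 6, 24, 76.8)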
4.3 Routing Algorithms
Dimension-Ordered Routing (DOR):
def route_packet(source, dest):
    """
    Simple dimension-ordered routing (DOR) for a 3D torus.
    Route along X, then Y, then Z, always taking the shorter way
    around each ring (wraparound-aware).
    Coordinates are (x, y, z) tuples; returns the list of hops.
    """
    x, y, z = source
    path = [(x, y, z)]

    def step_toward(cur, target, dim_size):
        # Move one hop toward target along a ring of size dim_size
        if (target - cur) % dim_size <= dim_size // 2:
            return (cur + 1) % dim_size  # positive direction
        return (cur - 1) % dim_size      # negative direction

    # Route along X dimension
    while x != dest[0]:
        x = step_toward(x, dest[0], X_DIM)
        path.append((x, y, z))
    # Route along Y dimension
    while y != dest[1]:
        y = step_toward(y, dest[1], Y_DIM)
        path.append((x, y, z))
    # Route along Z dimension
    while z != dest[2]:
        z = step_toward(z, dest[2], Z_DIM)
        path.append((x, y, z))
    return path
DOR Advantages:
- Deadlock-free: No cyclic dependencies
- Simple: Minimal router logic
- Deterministic: Same path every time (good for AI training)
Adaptive Routing (Advanced):
def adaptive_route(source, dest, link_utilization):
"""
Choose dimension order based on current link utilization
"""
dims_to_route = []
if source.x != dest.x:
dims_to_route.append('X')
if source.y != dest.y:
dims_to_route.append('Y')
if source.z != dest.z:
dims_to_route.append('Z')
# Sort dimensions by current link utilization (least busy first)
dims_to_route.sort(key=lambda d: link_utilization[d])
# Route in order of least congestion
for dim in dims_to_route:
route_along_dimension(dim)
Adaptive Advantages:
- Better load balancing
- Can avoid hot spots
- Higher aggregate throughput
Adaptive Disadvantages:
- More complex router logic
- Non-deterministic (problematic for BSP)
- Deadlock potential (requires virtual channels)
TPU Pod Choice: Likely uses DOR for simplicity and determinism, sufficient for structured AI workloads.
5. Physical Layer Technology
5.1 Optical Transceiver Technology
Pluggable Optical Modules:
TPU Pods likely use:
- QSFP-DD (Quad Small Form-factor Pluggable - Double Density)
- OSFP (Octal Small Form-factor Pluggable)
Specifications (200 Gbps variant):
Form factor: QSFP-DD
Lanes: 8 × 25 Gbps (PAM4) = 200 Gbps
Wavelength: 850 nm (multi-mode) or 1310 nm (single-mode)
Reach: 100m (MM) or 10 km (SM)
Power: 5-8W per module
Laser Types:
- VCSEL (Vertical-Cavity Surface-Emitting Laser): Multi-mode, short reach
- DFB (Distributed Feedback Laser): Single-mode, long reach
5.2 Modulation and Encoding
NRZ (Non-Return-to-Zero):
- 1 bit per symbol
- Simple, low power
- Limited to ~25 Gbps per lane (electrical BW limitation)
PAM4 (Pulse Amplitude Modulation 4-level):
Symbol Encoding:
3 ─┐ ┌── (11)
2 ─┤ ┌─────┘ (10)
1 ─┤ │ (01)
0 ─└───┘ (00)
2 bits per symbol
→ 2× spectral efficiency
→ 50 Gbps per lane at 25 GHz bandwidth
PAM4 Challenges:
- Reduced SNR (smaller voltage margins)
- Requires FEC (Forward Error Correction)
- DSP for equalization
Current Generation (TPU v4/v5):
- 8 lanes × 50 Gbps (PAM4) = 400 Gbps per link
- Or 4 lanes × 100 Gbps = 400 Gbps per link
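The per-link figures follow directly from lane count, symbol rate, and bits per symbol; a small sketch (the FEC overhead value is illustrative):
def link_bandwidth_gbps(lanes, gbaud, bits_per_symbol, fec_overhead=0.05):
    """Raw and post-FEC usable bandwidth of an optical link."""
    raw = lanes * gbaud * bits_per_symbol   # Gbps
    usable = raw * (1 - fec_overhead)
    return raw, usable

print(link_bandwidth_gbps(8, 25, 1))   # NRZ:  (200, 190.0)
print(link_bandwidth_gbps(8, 25, 2))   # PAM4: (400, 380.0)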
5.3 Wavelength Division Multiplexing (WDM)
Coarse WDM (CWDM):
Wavelengths: 1270 nm, 1290 nm, 1310 nm, 1330 nm, ...
Spacing: 20 nm
Channels: Typically 4-8
Example:
4 wavelengths × 100 Gbps = 400 Gbps per fiber
Dense WDM (DWDM):
Wavelengths: C-band (1530-1565 nm)
Spacing: 0.8 nm (100 GHz grid) or 0.4 nm (50 GHz)
Channels: 40-80+
Example:
40 wavelengths × 100 Gbps = 4 Tbps per fiber
TPU Pod Usage:
- Likely uses CWDM or simple WDM (4-8 wavelengths)
- Enables higher bandwidth without more fibers
- Trade-off: Cost and complexity vs. bandwidth
5.4 Forward Error Correction (FEC)
Reed-Solomon FEC:
RS(528, 514):
514 data symbols
14 parity symbols
Overhead: 14/514 = 2.7%
Can correct: 7 symbol errors
Pre-FEC BER: ~1×10^-4
Post-FEC BER: <1×10^-15
KP4-FEC (400G IEEE standard):
Overhead: ~5%
Pre-FEC BER: ~2×10^-4
Post-FEC BER: <1×10^-15
Impact on Latency:
- FEC encoding: ~10-50 ns
- FEC decoding: ~50-200 ns
- Pipelined, so doesn't add to steady-state throughput
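A quick sketch of what the RS(528, 514) overhead costs in usable bandwidth, using the document's ~150 Gbps per-link estimate:
def fec_usable_gbps(raw_gbps, data_symbols=514, total_symbols=528):
    """Usable data rate after Reed-Solomon RS(528, 514) overhead."""
    return raw_gbps * data_symbols / total_symbols

print(f"{fec_usable_gbps(150):.1f} Gbps usable on a 150 Gbps link")  # ≈ 146.0 Gbps
print(f"Overhead: {14 / 514:.1%}")                                   # 2.7%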
6. Link Layer and Protocol Stack
6.1 Protocol Stack Overview
┌─────────────────────────────────────┐
│ Application Layer (XLA, JAX) │
├─────────────────────────────────────┤
│ Collective Ops (NCCL-like) │ ← All-reduce, All-gather
├─────────────────────────────────────┤
│ Transport Layer (Custom) │ ← Message fragmentation, reassembly
├─────────────────────────────────────┤
│ Link Layer (Google proprietary) │ ← Flow control, error detection
├─────────────────────────────────────┤
│ Physical Layer (Optical) │ ← Transceivers, FEC
└─────────────────────────────────────┘
6.2 Link Layer Protocol
Frame Format (Hypothetical):
┌──────┬──────┬───────────┬──────────────────┬─────┬─────┐
│ SOF │ Dest │ Source │ Payload │ CRC │ EOF │
│ 8b │ 16b │ 16b │ N bytes │ 32b │ 8b │
└──────┴──────┴───────────┴──────────────────┴─────┴─────┘
SOF: Start of Frame
Dest: Destination chip ID (16 bits → 65536 chips addressable)
Source: Source chip ID
Payload: Data (typically 4 KB - 1 MB)
CRC: Cyclic Redundancy Check
EOF: End of Frame
Minimal overhead compared to Ethernet/InfiniBand:
- Ethernet: 26 bytes (header + CRC)
- InfiniBand: 40+ bytes (headers + ICRC)
- TPU custom: ~10-20 bytes (optimized for known topology)
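Header overhead matters most for small payloads; a quick comparison across the header sizes quoted above (the 16-byte figure for the custom format is an assumed value within the stated 10-20 byte range):
def wire_efficiency(payload_bytes, header_bytes):
    """Fraction of link bandwidth carrying actual payload."""
    return payload_bytes / (payload_bytes + header_bytes)

for payload in (256, 4096, 65536):
    eth = wire_efficiency(payload, 26)   # Ethernet header + CRC
    ib = wire_efficiency(payload, 40)    # InfiniBand headers + ICRC
    tpu = wire_efficiency(payload, 16)   # hypothetical custom framing
    print(f"{payload:>6} B: Ethernet {eth:.1%}, InfiniBand {ib:.1%}, custom {tpu:.1%}")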
6.3 Flow Control
Credit-Based Flow Control:
class FlowControl:
def __init__(self, link_id, buffer_size):
self.link_id = link_id
self.credits = buffer_size # In units of max packet size
self.remote_credits = buffer_size
def can_send(self, packet_size):
"""Check if we have credits to send"""
required_credits = (packet_size + MAX_PACKET - 1) // MAX_PACKET
return self.remote_credits >= required_credits
def send_packet(self, packet):
"""Send packet and decrement credits"""
if self.can_send(len(packet)):
transmit(packet)
self.remote_credits -= 1
return True
return False
def receive_credit_return(self):
"""Remote side consumed a packet, credit returned"""
self.remote_credits += 1
def receive_packet(self, packet):
"""Receive packet and send credit back"""
process(packet)
self.credits += 1
send_credit_return(self.link_id)
Advantages:
- Lossless: No packet drops due to buffer overflow
- Backpressure: Naturally throttles sender
- Simple: Low overhead
6.4 Error Detection and Recovery
Error Detection:
- CRC-32: Detects transmission errors
- Sequence numbers: Detects dropped/reordered packets
- Timeout: Detects stuck transfers
Error Recovery:
def reliable_send(dest, data):
"""
Reliable send with retransmission
"""
seq_num = get_next_sequence_number(dest)
max_retries = 3
timeout = 100_000 # cycles
for attempt in range(max_retries):
packet = create_packet(dest, data, seq_num)
send_packet(packet)
# Wait for ACK
start_time = get_cycle_count()
while get_cycle_count() - start_time < timeout:
if check_ack_received(dest, seq_num):
return SUCCESS
# Timeout, retry
log_warning(f"Timeout on link to {dest}, retry {attempt+1}")
# Exhausted retries, escalate error
return FAILURE
Proactive Error Handling:
- Periodic link tests: Send test patterns when idle
- Bit error rate monitoring: Track pre-FEC BER
- Predictive failure: Reconfigure before link fails completely
7. Collective Communication Optimization
7.1 All-Reduce: The Critical Operation
Why All-Reduce Dominates:
In data-parallel training:
- Each TPU computes gradients on local mini-batch
- Must average gradients across all TPUs (all-reduce)
- Update parameters with averaged gradients
- Happens every iteration, so its cost recurs throughout training
Naive Implementation:
def naive_all_reduce(local_gradient):
"""
Terrible: O(N) communication, O(N) latency
"""
if rank == 0:
# Master collects from all workers
total = local_gradient
for worker in range(1, num_workers):
total += receive_from(worker)
# Master broadcasts result
result = total / num_workers
for worker in range(1, num_workers):
send_to(worker, result)
else:
send_to(0, local_gradient)
result = receive_from(0)
return result
Problem: Master is bottleneck, doesn't utilize network bandwidth.
7.2 Ring All-Reduce
Algorithm:
def ring_all_reduce(local_data, num_chips):
"""
Ring all-reduce: O(N) latency, optimal bandwidth
Phase 1: Reduce-scatter
Phase 2: All-gather
"""
chunk_size = len(local_data) // num_chips
chunks = split_into_chunks(local_data, chunk_size)
# Phase 1: Reduce-scatter
for step in range(num_chips - 1):
send_chunk_idx = (rank - step) % num_chips
recv_chunk_idx = (rank - step - 1) % num_chips
# Send to next neighbor, receive from previous
send_to_neighbor((rank + 1) % num_chips, chunks[send_chunk_idx])
received = receive_from_neighbor((rank - 1 + num_chips) % num_chips)
# Accumulate
chunks[recv_chunk_idx] += received
# Phase 2: All-gather
for step in range(num_chips - 1):
send_chunk_idx = (rank - step + 1) % num_chips
recv_chunk_idx = (rank - step) % num_chips
send_to_neighbor((rank + 1) % num_chips, chunks[send_chunk_idx])
chunks[recv_chunk_idx] = receive_from_neighbor((rank - 1 + num_chips) % num_chips)
return concatenate(chunks)
Complexity:
- Latency: 2(N-1) communication steps
- Bandwidth: Each link carries ~2(N-1)/N of total data ≈ 2 for large N
- Optimal: Achieves ring bandwidth utilization
Limitation: Only uses one dimension of torus, doesn't exploit topology.
7.3 3D Torus All-Reduce
Recursive Dimension-Ordered Algorithm:
def torus_3d_all_reduce(local_data, coords):
"""
All-reduce on 3D torus: O(N^(1/3)) communication steps, far fewer than a ring's O(N)
Reduce-scatter along each dimension, then all-gather in reverse
"""
x, y, z = coords
X_DIM, Y_DIM, Z_DIM = get_torus_dimensions()
# Phase 1: Reduce-scatter along X dimension
data_x = reduce_scatter_dimension(local_data, dimension='X',
my_coord=x, dim_size=X_DIM)
# Phase 2: Reduce-scatter along Y dimension
data_xy = reduce_scatter_dimension(data_x, dimension='Y',
my_coord=y, dim_size=Y_DIM)
# Phase 3: Reduce-scatter along Z dimension
data_xyz = reduce_scatter_dimension(data_xy, dimension='Z',
my_coord=z, dim_size=Z_DIM)
# Now each chip has unique portion of reduced data
# Phase 4: All-gather along Z dimension
data_xy_gathered = all_gather_dimension(data_xyz, dimension='Z',
my_coord=z, dim_size=Z_DIM)
# Phase 5: All-gather along Y dimension
data_x_gathered = all_gather_dimension(data_xy_gathered, dimension='Y',
my_coord=y, dim_size=Y_DIM)
# Phase 6: All-gather along X dimension
final_data = all_gather_dimension(data_x_gathered, dimension='X',
my_coord=x, dim_size=X_DIM)
return final_data
def reduce_scatter_dimension(data, dimension, my_coord, dim_size):
"""
Reduce-scatter along one dimension of torus
Similar to ring all-reduce, but only in one dimension
"""
chunk_size = len(data) // dim_size
chunks = split_into_chunks(data, chunk_size)
for step in range(dim_size - 1):
send_chunk = chunks[(my_coord - step) % dim_size]
recv_chunk_idx = (my_coord - step - 1) % dim_size
send_in_dimension(dimension, +1, send_chunk)
received = receive_in_dimension(dimension, -1)
chunks[recv_chunk_idx] += received
# Return the chunk that belongs to this coordinate
return chunks[my_coord]
def all_gather_dimension(data, dimension, my_coord, dim_size):
"""
All-gather along one dimension
Each chip gets complete data from all chips in this dimension
"""
chunks = [None] * dim_size
chunks[my_coord] = data
for step in range(dim_size - 1):
send_chunk_idx = (my_coord - step + 1) % dim_size
recv_chunk_idx = (my_coord - step) % dim_size
send_in_dimension(dimension, +1, chunks[send_chunk_idx])
chunks[recv_chunk_idx] = receive_in_dimension(dimension, -1)
return concatenate(chunks)
Complexity Analysis:
For N = X_DIM × Y_DIM × Z_DIM chips:
Latency:
Reduce-scatter: (X_DIM-1) + (Y_DIM-1) + (Z_DIM-1) steps
All-gather: (Z_DIM-1) + (Y_DIM-1) + (X_DIM-1) steps
Total: 2 × (X_DIM + Y_DIM + Z_DIM - 3) steps
For 16×16×16 = 4096 chips:
Latency: 2 × (16+16+16-3) = 90 steps
Compare to ring:
Ring: 2 × (4096-1) = 8190 steps
Speed-up: 8190 / 90 ≈ 91× (!!!)
Bandwidth Utilization:
All 12,288 links used simultaneously
Each link carries unique data
→ Near-optimal bandwidth efficiency
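The step counts above generalize to any Pod shape; a small sketch:
def allreduce_steps(x, y, z):
    """Communication steps: dimension-ordered on a torus vs. a single ring."""
    torus = 2 * ((x - 1) + (y - 1) + (z - 1))   # reduce-scatter + all-gather
    ring = 2 * (x * y * z - 1)
    return torus, ring

torus, ring = allreduce_steps(16, 16, 16)
print(torus, ring, f"{ring / torus:.0f}x")   # 90 8190 91x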
7.4 Bandwidth and Latency Model
Simplified All-Reduce Time Model:
def model_all_reduce_time(data_size, num_chips, link_bw, link_latency):
    """
    Simplified model of dimension-ordered all-reduce time on a 3D torus.
    Assumes the per-chip shard (data_size / num_chips) moves on every step,
    which slightly underestimates the X-dimension phase.
    """
    X_DIM = Y_DIM = Z_DIM = round(num_chips ** (1 / 3))
    # Data size per chip after scatter
    chunk_size = data_size / num_chips
    # Time to send one chunk over one link
    transfer_time = chunk_size / link_bw
    # Number of steps per dimension
    steps_per_dim = X_DIM - 1
    # Total time: reduce-scatter + all-gather across three dimensions
    num_dimensions = 3
    phases = 2  # Reduce-scatter + All-gather
    total_latency = phases * num_dimensions * steps_per_dim * link_latency
    total_transfer = phases * num_dimensions * steps_per_dim * transfer_time
    return total_latency + total_transfer
# Example: TPU v4 Pod
data_size = 700e9       # 175B parameters × 4 bytes
num_chips = 4096
link_bw = 150e9 / 8     # 150 Gbps = 18.75 GB/s
link_latency = 1e-6     # 1 μs per hop
time = model_all_reduce_time(data_size, num_chips, link_bw, link_latency)
print(f"All-reduce time: {time:.2f} s")
# Output: All-reduce time: ~0.82 s
# With ring (for comparison):
ring_time = 2 * (num_chips - 1) * (data_size / num_chips / link_bw + link_latency)
print(f"Ring all-reduce time: {ring_time:.1f} s")
# Output: Ring all-reduce time: ~75 seconds (!!)
Key Insight: In this simple model, dimension-ordered all-reduce on the 3D torus is roughly 90× faster than a single ring for 4,096 chips, matching the step-count ratio above.
8. Evolution Across TPU Generations
8.1 TPU v1 (2015)
Interconnect:
- Custom PCIe-based
- Single-server only (no Pod concept)
- Used for inference, not training
8.2 TPU v2 (2017)
Interconnect:
- First TPU Pod
- 2D mesh topology
- Custom electrical interconnect
- Up to 256 chips (16×16 mesh)
Specifications:
- Link bandwidth: ~50 Gbps per link
- 4 links per chip (±X, ±Y)
- Total bisection BW: ~800 Gbps
Limitations:
- Electrical I/O limited bandwidth scaling
- 2D mesh has higher diameter than 3D torus
- Power consumption per bit was high
8.3 TPU v3 (2018)
Major Innovation: Optical Interconnect + 3D Torus
Specifications:
- Up to 2,048 chips (16×16×8 torus, later expanded)
- 6 optical links per chip
- ~100 Gbps per link
- Liquid cooling (compute + interconnect)
Improvements:
- 2× more chips per Pod
- 2× bandwidth per link
- ~3× better power efficiency
- Lower diameter (3D > 2D)
Optical Technology:
- Likely QSFP28 modules (100 Gbps)
- Multi-mode fiber (cheaper, sufficient for intra-datacenter)
- MEMS-based optical switches
8.4 TPU v4 (2021)
Further Scaling:
Specifications:
- Up to 4,096 chips (16×16×16 torus)
- ~150-200 Gbps per optical link
- Enhanced optical circuit switching
- ~275 TFLOPS per chip (BF16)
- 1.1 exaFLOPS per Pod
Interconnect Improvements:
- Higher bandwidth optics (likely QSFP-DD or early 400G)
- Improved optical switch technology (faster reconfiguration)
- Better power efficiency (~5-7W per 100 Gbps)
- Enhanced fault tolerance and monitoring
Software Stack:
- Better integration with JAX
- Improved collective communication libraries
- SPMD (Single Program Multiple Data) programming model
8.5 TPU v5e and v5p (2023-2024)
TPU v5e (Efficiency):
- Cost-optimized for inference and smaller training
- Likely uses similar interconnect to v4
- May use lower-cost optical modules
TPU v5p (Performance):
- SparseCores: Optimized for sparse and MoE models
- Enhanced interconnect bandwidth (~200-400 Gbps per link)
- Better support for model parallelism
- 2× performance improvement over v4
Interconnect Evolution:
- Likely uses 400G optics (QSFP-DD or OSFP)
- Possible silicon photonics integration
- Enhanced reconfigurability for dynamic topologies
- Support for heterogeneous Pod configurations
8.6 Future Trends (TPU v6 and Beyond)
Predicted Innovations:
- Co-Packaged Optics (CPO):
  - Optical transceivers integrated into the TPU package
  - Eliminates the electrical SerDes bottleneck
  - 10× bandwidth increase potential (1-10 Tbps per chip)
- Silicon Photonics:
  - On-chip photonic waveguides
  - Lower power, higher density
  - Wavelength-routed networks
- 3D Stacking:
  - Vertical optical interconnects
  - True 3D integration (not just topology)
  - Ultra-high bandwidth, low latency
- Hybrid Topologies:
  - 3D torus for local, all-to-all for global
  - Adaptive topology based on workload
  - Support for disaggregated memory
9. Performance Analysis
9.1 Bandwidth Analysis
Per-Chip Bandwidth:
TPU v4 (estimated):
6 links × 150 Gbps = 900 Gbps total
Bidirectional: 450 Gbps each direction
= 56.25 GB/s per direction
TPU v5p (estimated):
6 links × 300 Gbps = 1.8 Tbps total
= 225 GB/s bidirectional
Compare to GPU (H100):
NVLink: 18 ports × 50 GB/s = 900 GB/s bidirectional
Pod-Level Aggregate Bandwidth:
TPU v4 Pod (4,096 chips):
12,288 links (each chip has 6; each link is shared by two chips)
Total: 12,288 × 150 Gbps ≈ 1.8 Pbps
NVIDIA DGX SuperPOD (256 H100s):
Intra-server: 32 servers × 8 GPUs × 900 GB/s ≈ 230 TB/s (≈1.8 Pbps), but only within each server
Inter-server: 256 GPUs × 400 Gbps (IB) ≈ 102 Tbps
→ Bandwidth cliff at server boundary
Key: TPU Pod has uniform bandwidth, no cliffs
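Using the per-link estimates above (all approximate), the aggregate figure comes straight from the link count:
def pod_aggregate_bw_pbps(num_chips, links_per_chip, link_gbps):
    """Aggregate fabric bandwidth; each link is shared by two chips."""
    links = num_chips * links_per_chip // 2
    return links * link_gbps / 1e6   # Pbps

print(pod_aggregate_bw_pbps(4096, 6, 150))   # ≈ 1.84 Pbps, uniform across the Pod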
9.2 Latency Analysis
Component Latencies:
Component | Latency |
---|---|
Optical transceiver (Tx) | 50-100 ns |
Fiber propagation (10m) | 50 ns |
Optical switch (if dynamic) | 1-10 μs (first time), 0 (circuit mode) |
Optical transceiver (Rx) | 50-100 ns |
Router logic | 50-100 ns |
DMA setup | 100-200 ns |
Total per hop | ~500-1000 ns |
Multi-Hop Latency:
Nearest neighbor (1 hop): ~1 μs
Across one full dimension (8 hops): ~8 μs
Maximum diameter (24 hops): ~24 μs
All-reduce (16×16×16):
90 communication steps total (dimension-ordered, reduce-scatter + all-gather)
Latency component: 90 × 1 μs = 90 μs
Bandwidth component (simplified model in Section 7.4): ~0.8 s
Total: dominated by bandwidth; per-hop latency is negligible
9.3 Efficiency Metrics
Bandwidth Utilization:
def bandwidth_utilization(workload):
"""
Measure effective vs. theoretical bandwidth
"""
theoretical_bw = num_links * link_bandwidth
if workload == "all_reduce":
# All links carry unique data
effective_bw = theoretical_bw * 0.95 # ~95% efficiency
elif workload == "point_to_point":
# Many links idle
effective_bw = theoretical_bw * 0.3 # ~30% efficiency
elif workload == "all_to_all":
# Contention on some links
effective_bw = theoretical_bw * 0.6 # ~60% efficiency
return effective_bw / theoretical_bw
TPU Pod Advantages:
- All-reduce: 90-95% efficiency (near-optimal)
- Predictable, no congestion
- Scales linearly with Pod size
9.4 Scalability Limits
Current Limits (TPU v4):
Physical:
- Optical reach: ~100m (multi-mode fiber)
- Power density: ~25W interconnect per chip × 4096 = 102 kW
- Cooling capacity: Liquid cooling required
Architectural:
- Diameter: 24 hops (manageable)
- All-reduce time: under a second for 175B parameters (bandwidth-dominated)
- Addressing: 16-bit chip ID (65K chips addressable)
Scale-out beyond 4096 chips:
- Option 1: Larger 3D torus (e.g., 32×32×32 = 32K chips)
  - Diameter: 48 hops (still workable)
  - All-reduce time: similar, since per-chip all-reduce traffic is nearly independent of Pod size
- Option 2: Multi-Pod with inter-Pod links
  - 8× 4096-chip Pods = 32K chips
  - High-bandwidth inter-Pod links (1-10 Tbps)
  - Hierarchical all-reduce
10. Power and Thermal Considerations
10.1 Power Breakdown
Per-Chip Power (TPU v4):
Compute (MXUs, VPUs): ~150W
HBM: ~20W
Interconnect:
- Optical transceivers: ~12W (6 × 2W)
- Router logic: ~5W
- DMA engines: ~3W
Total interconnect: ~20W
-----------------------------------------
Total per chip: ~190W
Pod-Level Power (4,096 chips):
Compute: ~615 kW
Interconnect: ~82 kW
Optical switches: ~10 kW (amortized)
-----------------------------------------
Total: ~707 kW
Interconnect fraction: 82/707 = 11.6%
(vs. 25-30% for electrical interconnect)
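The same breakdown as a small parameterized model, using the per-component wattages estimated above:
def pod_power_kw(num_chips=4096, compute_w=150, interconnect_w=20, switch_kw=10):
    """Rough Pod-level power split (per-chip wattages from the breakdown above)."""
    compute = num_chips * compute_w / 1000            # ~615 kW
    interconnect = num_chips * interconnect_w / 1000  # ~82 kW
    total = compute + interconnect + switch_kw
    return total, interconnect / total

total, frac = pod_power_kw()
print(f"Total ≈ {total:.0f} kW, interconnect ≈ {frac:.1%}")   # ≈ 706 kW, ≈ 11.6%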
10.2 Power Efficiency
Optical vs. Electrical:
Electrical SerDes (50 Gbps PAM4):
Power: ~5-8W per lane
Efficiency: 0.1-0.16 W/Gbps
Optical (50 Gbps lane):
Laser: ~0.5-1W
Driver: ~0.5W
Receiver: ~0.5W
Total: ~1.5-2W per lane
Efficiency: 0.03-0.04 W/Gbps
Advantage: 3-4× better power efficiency
Scaling Impact:
For 4,096 chips × 6 links = 24,576 links:
Electrical (150 Gbps per link):
24,576 × 150 Gbps × 0.15 W/Gbps = 553 kW
Optical (150 Gbps per link):
24,576 × 150 Gbps × 0.04 W/Gbps = 147 kW
Savings: 406 kW
→ $350K/year at $0.10/kWh
→ Major OPEX reduction
10.3 Thermal Management
Liquid Cooling:
TPU v4 Pods use liquid cooling for both compute and interconnect:
Cold plate on compute side:
- Dissipates ~170W per chip
- Water temperature: 20-25°C inlet, 30-35°C outlet
Air cooling for transceivers:
- Smaller heatsinks
- Localized airflow
- ~20W per chip, manageable
Benefits:
- Higher density (more chips per rack)
- Quieter operation
- More efficient (PUE ~1.1-1.2 vs. 1.4-1.6 for air)
11. Fault Tolerance and Reliability
11.1 Failure Modes
Link Failures:
- Optical transceiver failure
- Fiber break or connector issue
- Switch port failure
- Bit error rate degradation
Chip Failures:
- Compute element failure (MXU, VPU)
- Memory failure (HBM)
- Router failure
- Complete chip failure
11.2 Detection Mechanisms
Link Monitoring:
class LinkMonitor:
def __init__(self, link_id):
self.link_id = link_id
self.error_count = 0
self.ber_history = []
def monitor(self):
"""Continuous monitoring"""
while True:
# Check pre-FEC BER
pre_fec_ber = measure_pre_fec_ber(self.link_id)
self.ber_history.append(pre_fec_ber)
# Check CRC errors
crc_errors = get_crc_error_count(self.link_id)
self.error_count += crc_errors
# Predictive failure detection
if pre_fec_ber > 1e-4:
log_warning(f"Link {self.link_id} BER degrading: {pre_fec_ber}")
if pre_fec_ber > 5e-4:
trigger_link_failover(self.link_id)
# Periodic link test
if idle_time() > 1000: # ms
send_test_pattern(self.link_id)
sleep(10) # ms
11.3 Fault Tolerance Strategies
1. Link-Level Redundancy:
3D torus provides multiple paths:
From (0,0,0) to (15,15,15):
- Shortest path (via wraparound): -X(1), -Y(1), -Z(1), only 3 hops
- Alternate 1: +X(15), +Y(15), +Z(15)
- Alternate 2: +Y(15), +X(15), +Z(15)
On link failure:
→ Use alternate path
→ Slightly longer latency, but training continues
2. Chip-Level Redundancy:
def checkpoint_and_recover(failed_chip_id):
"""
Recovery strategy for chip failure
"""
# Option 1: Checkpoint-restart
if training_supports_checkpointing():
save_checkpoint()
blacklist_chip(failed_chip_id)
relaunch_on_healthy_chips()
# Option 2: Spare chips
if spare_chips_available():
spare = allocate_spare_chip()
reconfigure_optical_switch(failed_chip_id, spare)
restore_state_to_spare(spare)
# Option 3: Degraded mode
else:
continue_training_with_fewer_chips()
3. Optical Circuit Reconfiguration:
Dynamic circuit reconfiguration allows:
1. Isolate failed chips
2. Reconfigure topology (e.g., 16×16×16 → 16×16×15)
3. Continue training with slightly degraded performance
Reconfiguration time: 1-10 seconds (MEMS) or <1 second (silicon photonics)
→ Acceptable compared to hours of training time
11.4 Reliability Metrics
Mean Time Between Failures (MTBF):
Component MTBFs (illustrative):
TPU chip: ~1,000,000 hours
Optical transceiver: ~500,000 hours
Fiber/connector: ~10,000,000 hours
Pod-level failure rate (4,096 chips):
Chip failures: 1,000,000 / 4,096 ≈ 244 hours → ~1 failure every 10 days
Optical failures: 500,000 / 24,576 ≈ 20 hours → roughly 1 failure per day
Expected: around one failure per day across the Pod, dominated by the optics
With fault tolerance:
→ Training continues (slight performance degradation)
→ Repair during scheduled maintenance
Availability:
Target: 99.9% availability (8.76 hours downtime/year)
With redundancy and fast failover:
Achieved: >99.95% (<4.38 hours downtime/year)
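A minimal fleet-failure-rate sketch using the illustrative MTBFs above (assuming independent, exponentially distributed failures):
def pod_failures_per_day(num_chips=4096, chip_mtbf_h=1e6,
                         transceivers=24576, optic_mtbf_h=5e5):
    """Expected failures per day across the whole Pod."""
    rate_per_hour = num_chips / chip_mtbf_h + transceivers / optic_mtbf_h
    return rate_per_hour * 24

print(f"{pod_failures_per_day():.2f} failures/day expected")
# ≈ 1.3 failures/day, dominated by the optical transceivers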
12. Programming Model and Software Stack
12.1 Software Stack Overview
┌─────────────────────────────────────────┐
│ User Code (Python) │
│ • TensorFlow, JAX, PyTorch (via XLA) │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ XLA (Accelerated Linear Algebra) │
│ • Compiler and optimizer │
│ • Collective ops lowering │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ Collective Communication Library │
│ • All-reduce, all-gather, reduce-scatter│
│ • Topology-aware optimization │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ ICI (Inter-Chip Interconnect) Driver │
│ • Link management │
│ • DMA operations │
│ • Error handling │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ Optical Interconnect Hardware │
└─────────────────────────────────────────┘
12.2 Programming Model: SPMD
Single Program Multiple Data (SPMD):
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Define a logical 3D mesh over the Pod's devices
# (16×16×16 for a full 4,096-chip TPU v4 Pod)
devices = np.array(jax.devices()).reshape(16, 16, 16)
mesh = Mesh(devices, axis_names=('x', 'y', 'z'))

@jax.jit
def distributed_matmul(A, B):
    """
    Matrix multiply distributed across the 3D mesh.
    Collective ops are inserted automatically by XLA (GSPMD).
    """
    return jnp.dot(A, B)

# Shard the operands across the mesh; XLA propagates the sharding
A = jax.device_put(jnp.ones((8192, 8192)), NamedSharding(mesh, P('x', None)))
B = jax.device_put(jnp.ones((8192, 8192)), NamedSharding(mesh, P(None, 'y')))
C = distributed_matmul(A, B)
# XLA compiler will:
# 1. Partition A and B across mesh
# 2. Insert all-reduce after local matmuls
# 3. Use topology-aware all-reduce (3D torus algorithm)
# 4. Generate efficient ICI operations
12.3 Collective Communication API
Example: Explicit All-Reduce:
import jax
from jax import lax

def data_parallel_training_step(params, batch):
    """
    One step of data-parallel training.
    Must run under jax.pmap (or shard_map) with axis_name='batch'
    so the collective below has a device axis to reduce over.
    """
    # Each TPU computes gradients on its local mini-batch
    local_gradients = compute_gradients(params, batch)
    # All-reduce to average gradients across all TPUs
    # JAX lowers this to a topology-aware all-reduce
    averaged_gradients = lax.pmean(local_gradients, axis_name='batch')
    # Update parameters
    new_params = params - learning_rate * averaged_gradients
    return new_params

# e.g.: train_step = jax.pmap(data_parallel_training_step, axis_name='batch')
# XLA lowers 'pmean' to:
# 1. 3D torus all-reduce
# 2. Direct ICI DMA operations
# 3. Optimal scheduling to overlap compute and communication
12.4 Topology-Aware Optimization
XLA Optimizations:
# Pseudo-code for XLA's collective lowering
def lower_allreduce_to_ici(all_reduce_op, mesh_topology):
"""
XLA compiler optimization for all-reduce
"""
if mesh_topology.type == '3D_TORUS':
# Use dimension-ordered all-reduce
schedule = [
('reduce_scatter', 'X'),
('reduce_scatter', 'Y'),
('reduce_scatter', 'Z'),
('all_gather', 'Z'),
('all_gather', 'Y'),
('all_gather', 'X'),
]
for phase, dimension in schedule:
if phase == 'reduce_scatter':
emit_reduce_scatter_ici(dimension)
else:
emit_all_gather_ici(dimension)
elif mesh_topology.type == 'RING':
emit_ring_allreduce_ici()
else:
# Fallback to general algorithm
emit_tree_allreduce()
def emit_reduce_scatter_ici(dimension):
"""
Emit ICI operations for reduce-scatter along one dimension
"""
neighbor_plus = get_neighbor(dimension, +1)
neighbor_minus = get_neighbor(dimension, -1)
for step in range(dimension_size - 1):
# DMA send to +neighbor
emit_ici_dma_send(neighbor_plus, chunk_offset, chunk_size)
# DMA receive from -neighbor
emit_ici_dma_recv(neighbor_minus, recv_buffer, chunk_size)
# Local reduction
emit_vector_add(local_buffer, recv_buffer)
13. Industry Impact and Research Directions
13.1 Influence on Industry
Direct Impact:
- Meta Research SuperCluster (RSC):
  - Uses NVIDIA GPUs, but the interconnect philosophy is influenced by TPU Pods
  - Emphasis on a uniform, high-bandwidth fabric
  - RoCE-based, but considering optical for future generations
- Microsoft Azure AI Infrastructure:
  - Custom interconnects for large-scale training
  - Exploring optical for next-generation systems
- AWS Trainium/Inferentia:
  - Custom interconnects for ML accelerators
  - EFA (Elastic Fabric Adapter) optimized for collectives
Broader Impact:
- Optical Interconnect Startups:
  - Ayar Labs: TeraPHY optical I/O
  - Lightmatter: Photonic interconnect and compute
  - Celestial AI: Photonic fabric for AI
  - Ranovus: Data center optical switching
- Industry Consortia:
  - Ultra Ethernet Consortium (AI-optimized Ethernet)
  - Optical Internetworking Forum (OIF)
  - Co-Packaged Optics (CPO) standardization efforts
13.2 Research Directions
1. Co-Packaged Optics (CPO):
Current: Pluggable optical modules
Chip ← (electrical) → Package ← (electrical) → Module ← (optical) → Fiber
Bottleneck: Electrical chip-to-module link
Future: Co-packaged optics
Chip ← (short electrical) → Optical transceiver (on package) ← (optical) → Fiber
Advantages:
- 10× bandwidth (no electrical reach limit)
- 5× lower power
- Smaller form factor
2. Silicon Photonics Integration:
Goal: Integrate photonic waveguides directly on silicon
Technologies:
- Silicon-on-insulator (SOI) waveguides
- Germanium photodetectors (monolithic with Si)
- III-V laser integration (heterogeneous or bonded)
- Microring modulators (compact, low power)
Benefits:
- 100× higher I/O density
- <1 pJ/bit energy efficiency
- Wavelength-division multiplexing (100+ wavelengths)
3. Wavelength-Routed Networks:
Concept: Use wavelength as routing dimension
Traditional:
All transceivers on same wavelength
Optical switches route by port
Wavelength-routed:
Each transceiver on unique wavelength
Wavelength determines destination
No switching needed (passive routing)
Example:
Chip A (λ1, λ2, λ3) → Fiber → Chip B receives λ1
Chip C receives λ2
Chip D receives λ3
Advantages:
- Ultra-low latency (no switching)
- High bandwidth (100+ wavelengths × 100 Gbps)
- Reconfigurable via wavelength tuning
4. Near-Memory and In-Network Computing:
Trend: Move computation closer to data
Options:
1. Processing-in-Memory (PIM)
- Compute inside HBM stacks
- Reduce data movement
2. In-Network Reduction
- Perform reduce operations in optical switches
- Reduces all-reduce latency
3. Disaggregated Memory
- Memory pools connected via optical fabric
- Flexible capacity allocation
5. Hybrid Electrical-Optical:
Insight: Not everything needs optical
Design:
- Short-reach (<1m): Electrical (lower latency, lower cost)
- Medium-reach (1-10m): Optical (copper becomes impractical)
- Long-reach (>10m): Optical (required)
Example:
Within server: Electrical (NVLink-like)
Rack-to-rack: Optical (TPU Pod-like)
Optimizes for cost, power, and performance
13.3 Open Research Questions
- Scaling Beyond 10K Accelerators:
  - What topology works best? (4D/5D torus, Dragonfly, Slim Fly?)
  - How to manage diameter and bisection bandwidth?
  - Multi-level hierarchies vs. flat topologies?
- Heterogeneous Interconnects:
  - Mix of optical (high BW) and electrical (low latency)
  - Dynamic selection based on message size
  - Protocol translation between domains
- Energy Proportionality:
  - Optical links are "always on" (lasers)
  - How to save power during idle periods?
  - Fast wake-up from sleep modes?
- Fault Tolerance at Scale:
  - With 10K+ chips, failures are constant
  - How to train through failures seamlessly?
  - Speculative execution and rollback?
- Software-Hardware Co-Design:
  - How should frameworks (PyTorch, JAX) expose interconnect topology?
  - Automatic tuning of collective algorithms?
  - Compiler optimizations for communication-compute overlap?
14. Summary: Key Takeaways
14.1 Technical Highlights
- Optical Circuit Switching: Enables scale, power efficiency, and determinism
- 3D Torus Topology: Optimal for structured AI workloads (all-reduce)
- Dimension-Ordered All-Reduce: ~100× faster than naive algorithms
- Power Efficiency: 3-4× better than electrical at datacenter scale
- Scalability: Uniform performance from 256 to 4,096+ chips
14.2 Why It Matters
TPU Pod's optical interconnect represents a fundamental shift in how we scale AI infrastructure:
- Not just faster: Different technology (optical vs. electrical)
- Not just bigger: Better topology (3D torus vs. fat-tree)
- Not just Google: Influencing entire industry direction
14.3 Interview Preparation
Be ready to discuss:
- Trade-offs: optical vs. electrical, circuit vs. packet switching
- Topology comparisons: torus vs. fat-tree, 2D vs. 3D
- Collective algorithms: ring vs. recursive halving vs. dimension-ordered
- Power scaling: why power efficiency matters at scale
- Real-world constraints: cost, reliability, backwards compatibility
Connections to other topics:
- GPU architecture (NVLink comparison)
- HPC interconnects (InfiniBand, Cray networks)
- Distributed training algorithms (data parallel, model parallel, pipeline parallel)
- Optical networking (WDM, silicon photonics, CPO)
- System architecture (memory hierarchy, I/O bandwidth, Amdahl's law)