
TPU Pod Optical Interconnect System

Complete technical deep-dive into Google TPU Pod's optical circuit switching architecture, 3D torus topology, collective communication optimization, and datacenter-scale AI infrastructure

90 min read
Updated 10/1/2025
4 prerequisites

Prerequisites

Make sure you're familiar with these concepts before diving in:

Understanding of neural network training and distributed systems
Basic knowledge of computer networking and interconnects
Familiarity with optical communications fundamentals
Understanding of TPU architecture basics

Learning Objectives

By the end of this topic, you will be able to:

Understand optical circuit switching vs packet switching trade-offs
Master 3D torus topology and routing algorithms for AI workloads
Analyze dimension-ordered all-reduce collective communication
Evaluate power efficiency and scalability of optical interconnects
Compare TPU Pod architecture with alternative datacenter interconnects

Table of Contents

  1. Overview and Motivation
  2. TPU Architecture Primer
  3. Optical Circuit Switching Fundamentals
  4. 3D Torus Network Topology
  5. Physical Layer Technology
  6. Link Layer and Protocol Stack
  7. Collective Communication Optimization
  8. Evolution Across TPU Generations
  9. Performance Analysis
  10. Power and Thermal Considerations
  11. Fault Tolerance and Reliability
  12. Programming Model and Software Stack
  13. Industry Impact and Research Directions

1. Overview and Motivation

1.1 What is a TPU Pod?

A TPU Pod is Google's solution for scaling Tensor Processing Units from single chips to supercomputer-scale systems capable of training the largest machine learning models in production. The Pod's optical interconnect fabric is the critical enabler that allows thousands of specialized AI accelerators to work cohesively as a single, unified computational resource.

1.2 The Scaling Challenge

Why Traditional Interconnects Fall Short:

Modern deep learning training presents unique challenges:

  1. Bulk-Synchronous Parallelism (BSP):

    • All accelerators must synchronize at layer boundaries
    • Requires global all-reduce operations every few milliseconds
    • Stragglers kill performance (tail latency problem)
  2. Bandwidth Requirements:

    • Large language models (LLMs): 100+ billion parameters
    • Parameter updates: Multiple GB per training step
    • Activation passing: Tens of GB for model parallelism
    • Total: 10-100 Tbps aggregate bandwidth needed
  3. Power Constraints:

    • Accelerators: 300-500W each
    • Traditional interconnects: 50-100W per device
    • At 1000 devices: Interconnect power = 50-100 kW
    • Goal: Reduce to <20% of total power budget
  4. Scalability:

    • Need to scale from 256 to 4,096+ accelerators
    • Traditional electrical I/O: bandwidth-per-watt plateaus
    • Multi-tier networks: bandwidth cliffs at boundaries

1.3 Google's Solution: Optical Circuit Switching

Key Insight: AI training workloads are:

  • Predictable: Communication patterns known a priori
  • Regular: Structured collectives (all-reduce, all-gather)
  • Deterministic: Same operations every iteration
  • Bulk: Large data transfers, not tiny packets

Design Choice: Circuit-switched optical network

  • No packet headers (wasted bandwidth)
  • No store-and-forward latency
  • No congestion or contention
  • Deterministic performance
  • Superior power efficiency

2. TPU Architecture Primer

2.1 Single TPU Chip Architecture

Before understanding the interconnect, we need to understand what we're connecting:

┌─────────────────────────────────────────────────┐
│              TPU v4 Chip                        │
├─────────────────────────────────────────────────┤
│  ┌─────────────┐  ┌─────────────┐              │
│  │   MXU 0     │  │   MXU 1     │  Matrix      │
│  │  128×128    │  │  128×128    │  Multiply    │
│  │  Systolic   │  │  Systolic   │  Units       │
│  └─────────────┘  └─────────────┘              │
│         │                 │                     │
│  ┌──────┴─────────────────┴──────┐             │
│  │     Vector Units (VPU)        │             │
│  │  • FP32/BF16/INT8 operations  │             │
│  │  • 16K elements/cycle         │             │
│  └───────────────────────────────┘             │
│         │                                       │
│  ┌──────┴─────────────────────────┐            │
│  │   High Bandwidth Memory (HBM)  │            │
│  │   • 32 GB HBM2e                │            │
│  │   • 1.2 TB/s bandwidth         │            │
│  └────────────────────────────────┘            │
│         │                                       │
│  ┌──────┴─────────────────────────┐            │
│  │   Interconnect Controllers     │            │
│  │   • 6× Optical Links           │ ◄─────────┐│
│  │   • Router logic               │         I │
│  │   • DMA engines                │         C │
│  └────────────────────────────────┘         I │
└─────────────────────────────────────────────────┘

Key Specifications (TPU v4):

  • Peak compute: 275 TFLOPS (BF16)
  • Memory: 32 GB HBM2e @ 1.2 TB/s
  • Interconnect: 6 bidirectional optical links
  • Power: ~175W compute + ~25W interconnect

2.2 Interconnect Requirements

Per-chip Bandwidth Math:

Model size: 175B parameters (GPT-3 scale)
4096 TPUs, sharded across all chips
Per-chip: 175B / 4096 = 42.7M parameters
 
All-reduce requirements per iteration:
- Forward pass: Receive activations
- Backward pass: Send gradients
- Parameter sync: All-reduce (~171 MB per chip)
 
Target: 500 iterations/second
Required BW: 171 MB × 500 = 85 GB/s bidirectional
 
Actual: 6 links × 15 GB/s = 90 GB/s ✓
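
A quick way to sanity-check these figures is to redo the arithmetic in a few lines of Python. All of the inputs below (model size, iteration rate, per-link bandwidth) are the illustrative numbers from the block above, not measured TPU specifications.

# Back-of-envelope check of the bandwidth math above.
# All inputs are the illustrative numbers from the text, not measured specs.
params_total = 175e9        # GPT-3-scale parameter count
num_chips = 4096
bytes_per_param = 4         # FP32 gradients

params_per_chip = params_total / num_chips              # ~42.7M parameters
allreduce_bytes = params_per_chip * bytes_per_param     # ~171 MB per chip

iterations_per_s = 500                                  # assumed target from the text
required_bw = allreduce_bytes * iterations_per_s        # bytes per second

link_bw = 15e9                                          # 15 GB/s per link (assumed)
available_bw = 6 * link_bw                              # 6 links per chip

print(f"Required:  {required_bw / 1e9:.1f} GB/s")       # ~85.4 GB/s
print(f"Available: {available_bw / 1e9:.1f} GB/s")      # 90.0 GB/s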

3. Optical Circuit Switching Fundamentals

3.1 Circuit Switching vs. Packet Switching

Packet Switching (Traditional):

Source → [Header|Payload] → Router (store, forward) → Destination
 
Latency = Serialization + Propagation + Queueing + Switching
        = (Packet_size/BW) + (Distance/c) + (Queue_depth × Packet_time) + Switch_latency
 
Overhead:
- Headers: 5-10% of bandwidth
- Buffering: SRAM, power hungry
- Arbitration: Complex, non-deterministic

Circuit Switching (TPU Pod OCS):

Source → Optical Switch (passive) → Destination
 
Latency = Serialization + Propagation + Circuit_setup (one-time)
        = (Message_size/BW) + (Distance/c) + Setup_latency
 
Advantages:
- No packet overhead (100% efficiency)
- No buffering needed
- Deterministic latency
- Lower power
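
The two latency equations above can be turned into a tiny model. The sketch below is only illustrative: packet size, queue depth, per-hop switch latency, and fiber length are assumptions chosen for the example, not TPU Pod measurements.

# Illustrative latency model for the two switching styles above.
# All parameter values are assumptions, not measured TPU Pod numbers.
C_FIBER = 2e8   # ~2/3 the speed of light, in m/s, for propagation in fiber

def packet_switched_latency(msg_bytes, bw_Bps, dist_m, hops,
                            queue_depth=4, switch_latency=500e-9,
                            packet_bytes=1500):
    serialization = msg_bytes / bw_Bps
    propagation = dist_m / C_FIBER
    queueing = hops * queue_depth * (packet_bytes / bw_Bps)
    switching = hops * switch_latency
    return serialization + propagation + queueing + switching

def circuit_switched_latency(msg_bytes, bw_Bps, dist_m, setup_s=0.0):
    # Circuit setup is a one-time cost, amortized over the whole transfer
    return msg_bytes / bw_Bps + dist_m / C_FIBER + setup_s

msg, bw = 8 * 1024, 18.75e9     # 8 KB message over a ~150 Gbps link
print(f"Packet-switched:  {packet_switched_latency(msg, bw, 20, hops=4) * 1e6:.2f} us")
print(f"Circuit-switched: {circuit_switched_latency(msg, bw, 20) * 1e6:.2f} us")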

3.2 Optical Circuit Switch Technology

MEMS (Micro-Electro-Mechanical Systems) Switches:

The first optically switched TPU Pods likely used MEMS-based OCS:

      Input Fibers          Mirrors (actuated)        Output Fibers
         │                       ╱                         │
   ──────┼──────────────────────●──────────────────────────┼──────
         │                     ╱ ╲                         │
   ──────┼────────────────────●───●────────────────────────┼──────
         │                   ╱     ╲                       │
   ──────┼──────────────────●───────●──────────────────────┼──────

Characteristics:

  • Switch time: 10-20 ms (slow, but amortized over long data transfers)
  • Port count: 320×320 or larger
  • Insertion loss: <2 dB
  • Power: ~1-2W per 100 ports (passive after switching)

Silicon Photonics Switches (Newer TPU Pods):

           Ring Resonators (electrically tuned)
    In ────●───●───●───●──── Out 1
           │   │   │   │
           ●───●───●───●──── Out 2

Characteristics:

  • Switch time: <1 μs (thermal tuning) or <10 ns (electro-optic)
  • Port count: Limited (32×32), but can cascade
  • Insertion loss: 1-3 dB per stage
  • Power: ~10-50 mW per active port

3.3 Reconfigurability

Static vs. Dynamic Configuration:

# Pseudo-code for circuit configuration
 
# Static configuration (set at Pod initialization)
def configure_static_topology(pod_size):
    """Set up 3D torus"""
    for chip_id in range(pod_size):
        x, y, z = chip_id_to_coords(chip_id)
        
        neighbors = {
            '+X': coords_to_chip_id((x+1) % X_DIM, y, z),
            '-X': coords_to_chip_id((x-1) % X_DIM, y, z),
            '+Y': coords_to_chip_id(x, (y+1) % Y_DIM, z),
            '-Y': coords_to_chip_id(x, (y-1) % Y_DIM, z),
            '+Z': coords_to_chip_id(x, y, (z+1) % Z_DIM),
            '-Z': coords_to_chip_id(x, y, (z-1) % Z_DIM),
        }
        
        configure_optical_circuits(chip_id, neighbors)
 
# Dynamic reconfiguration (for fault tolerance)
def reconfigure_around_failure(failed_chip):
    """Reroute around failed chip"""
    affected_neighbors = get_neighbors(failed_chip)
    
    for neighbor in affected_neighbors:
        # Find alternate path (e.g., skip failed chip)
        new_route = find_alternate_path(neighbor, failed_chip)
        reconfigure_optical_switch(neighbor, new_route)

Use Cases for Reconfigurability:

  1. Fault tolerance: Route around failed chips/links
  2. Multi-tenancy: Partition Pod into independent slices
  3. Adaptive topology: Switch between 3D torus and fat-tree
  4. Power management: Power down unused links

4. 3D Torus Network Topology

4.1 Topology Definition

3D Torus Structure:

           Z-axis


         (0,0,1)───(1,0,1)
           ╱│      ╱│
          ╱ │     ╱ │
      (0,1,1)───(1,1,1)
         │  │    │  │
         │(0,0,0)┼(1,0,0)
         │ ╱     │╱
         │╱      │
      (0,1,0)───(1,1,0)──── Y-axis


    X-axis
 
Wrap-around connections (torus):
- (X_DIM-1, y, z) ↔ (0, y, z)
- (x, Y_DIM-1, z) ↔ (x, 0, z)
- (x, y, Z_DIM-1) ↔ (x, y, 0)

For TPU v4 Pod (4,096 chips):

  • Dimensions: 16×16×16
  • Links per chip: 6 (±X, ±Y, ±Z)
  • Total links: 4,096 × 6 / 2 = 12,288 bidirectional links (each link is shared by two chips)
  • Diameter: 24 hops (maximum distance between any two chips; at most 8 hops per dimension)
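
The pseudo-code in Section 3.3 assumed chip_id_to_coords and coords_to_chip_id helpers. A minimal sketch of that mapping for the 16×16×16 layout might look like this; the row-major ordering is an assumption for illustration, not Google's actual chip numbering:

# Hypothetical coordinate mapping for a 16×16×16 torus (row-major ordering assumed).
X_DIM = Y_DIM = Z_DIM = 16

def chip_id_to_coords(chip_id):
    """Map a linear chip ID to (x, y, z) torus coordinates."""
    x = chip_id % X_DIM
    y = (chip_id // X_DIM) % Y_DIM
    z = chip_id // (X_DIM * Y_DIM)
    return x, y, z

def coords_to_chip_id(x, y, z):
    """Inverse mapping: (x, y, z) back to a linear chip ID."""
    return x + y * X_DIM + z * X_DIM * Y_DIM

assert coords_to_chip_id(*chip_id_to_coords(1234)) == 1234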

4.2 Topology Properties

Degree:

  • Every chip has exactly 6 neighbors
  • Advantage: Constant per-chip complexity regardless of Pod size

Diameter:

  • D = 3 × ⌊k/2⌋ for a k×k×k torus (at most ⌊k/2⌋ hops per dimension, thanks to wraparound)
  • TPU v4 (16×16×16): 3 × 8 = 24 hops maximum
  • Advantage: Diameter grows only with the cube root of the chip count (see the distance sketch at the end of this subsection)

Bisection Bandwidth:

Bisection: minimum aggregate link capacity across a cut that splits the Pod in half
 
For a 16×16×16 torus, cut along one dimension:
Each of the 16×16 = 256 X-direction rings crosses the cut twice
(once at the cut plane, once via the wraparound link) → 512 links
Bandwidth per link: ~150 Gbps
Total bisection BW: 512 × 150 Gbps = 76.8 Tbps
 
Compare to a fat-tree with 2:1 oversubscription over the same chips:
~256 × 75 Gbps = 19.2 Tbps at the top tier
Torus: ~4× better bisection bandwidth

Path Diversity:

  • Multiple paths between any two nodes
  • Example: From (0,0,0) to (8,8,8):
    • Path 1: +X(8), +Y(8), +Z(8)
    • Path 2: +Y(8), +Z(8), +X(8)
    • Path 3: +Z(8), +X(8), +Y(8)
    • Many more...
  • Advantage: Load balancing and fault tolerance
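
The wraparound links are what keep distances short. A small sketch of the shortest-hop calculation (assuming the 16×16×16 dimensions used throughout this topic) shows both the per-dimension maximum of 8 hops and the 24-hop diameter:

# Shortest-path hop counts on a torus, with wraparound taken into account.
def torus_dim_distance(a, b, dim_size):
    """Minimal hop count along one ring dimension."""
    d = abs(a - b)
    return min(d, dim_size - d)

def torus_distance(src, dst, dims=(16, 16, 16)):
    """Total minimal hop count between two chips in a 3D torus."""
    return sum(torus_dim_distance(a, b, k) for a, b, k in zip(src, dst, dims))

print(torus_distance((0, 0, 0), (1, 0, 0)))    # 1  (nearest neighbor)
print(torus_distance((0, 0, 0), (15, 0, 0)))   # 1  (wraparound link)
print(torus_distance((0, 0, 0), (8, 8, 8)))    # 24 (maximum: 8 hops per dimension)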

4.3 Routing Algorithms

Dimension-Ordered Routing (DOR):

def route_packet(source, dest):
    """
    Dimension-ordered routing (DOR) for a 3D torus: route along X, then Y,
    then Z, taking the shorter (wraparound-aware) direction in each dimension.
    source and dest are (x, y, z) tuples; returns the list of coordinates visited.
    """
    dims = (X_DIM, Y_DIM, Z_DIM)
    current = list(source)
    path = [tuple(current)]
    
    for axis, dim_size in enumerate(dims):        # axis 0 = X, 1 = Y, 2 = Z
        while current[axis] != dest[axis]:
            # Step in the + direction if that is the shorter way around the ring
            if (dest[axis] - current[axis]) % dim_size < dim_size / 2:
                current[axis] = (current[axis] + 1) % dim_size
            else:
                current[axis] = (current[axis] - 1) % dim_size
            path.append(tuple(current))
    
    return path

DOR Advantages:

  • Deadlock-free: No cyclic dependencies
  • Simple: Minimal router logic
  • Deterministic: Same path every time (good for AI training)

Adaptive Routing (Advanced):

def adaptive_route(source, dest, link_utilization):
    """
    Choose dimension order based on current link utilization
    """
    dims_to_route = []
    if source.x != dest.x:
        dims_to_route.append('X')
    if source.y != dest.y:
        dims_to_route.append('Y')
    if source.z != dest.z:
        dims_to_route.append('Z')
    
    # Sort dimensions by current link utilization (least busy first)
    dims_to_route.sort(key=lambda d: link_utilization[d])
    
    # Route in order of least congestion
    for dim in dims_to_route:
        route_along_dimension(dim)

Adaptive Advantages:

  • Better load balancing
  • Can avoid hot spots
  • Higher aggregate throughput

Adaptive Disadvantages:

  • More complex router logic
  • Non-deterministic (problematic for BSP)
  • Deadlock potential (requires virtual channels)

TPU Pod Choice: Likely uses DOR for simplicity and determinism, sufficient for structured AI workloads.


5. Physical Layer Technology

5.1 Optical Transceiver Technology

Pluggable Optical Modules:

TPU Pods likely use:

  • QSFP-DD (Quad Small Form-factor Pluggable - Double Density)
  • OSFP (Octal Small Form-factor Pluggable)

Specifications (200 Gbps variant):

Form factor: QSFP-DD
Lanes: 8 × 25 Gbps (NRZ) = 200 Gbps
Wavelength: 850 nm (multi-mode) or 1310 nm (single-mode)
Reach: 100m (MM) or 10 km (SM)
Power: 5-8W per module

Laser Types:

  • VCSEL (Vertical-Cavity Surface-Emitting Laser): Multi-mode, short reach
  • DFB (Distributed Feedback Laser): Single-mode, long reach

5.2 Modulation and Encoding

NRZ (Non-Return-to-Zero):

  • 1 bit per symbol
  • Simple, low power
  • Limited to ~25 Gbps per lane (electrical BW limitation)

PAM4 (Pulse Amplitude Modulation 4-level):

Symbol Encoding:
  3 ─┐         ┌──  (11)
  2 ─┤   ┌─────┘    (10)
  1 ─┤   │          (01)
  0 ─└───┘          (00)
 
2 bits per symbol
→ 2× spectral efficiency
→ 50 Gbps per lane at 25 GHz bandwidth

PAM4 Challenges:

  • Reduced SNR (smaller voltage margins)
  • Requires FEC (Forward Error Correction)
  • DSP for equalization

Current Generation (TPU v4/v5):

  • 8 lanes × 50 Gbps (PAM4) = 400 Gbps per link
  • Or 4 lanes × 100 Gbps = 400 Gbps per link
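
The per-lane and per-link numbers above follow directly from the symbol rate and bits per symbol; the tiny calculation below just makes that explicit (symbol rates are the approximate values quoted in this section):

# Per-lane and per-link rates implied by the encodings above (approximate figures).
def lane_rate_gbps(symbol_rate_gbaud, bits_per_symbol):
    return symbol_rate_gbaud * bits_per_symbol

nrz_lane = lane_rate_gbps(25, 1)     # 25 Gbps per lane
pam4_lane = lane_rate_gbps(25, 2)    # 50 Gbps per lane

print(f"NRZ : {nrz_lane} Gbps/lane, 8 lanes = {8 * nrz_lane} Gbps/link")
print(f"PAM4: {pam4_lane} Gbps/lane, 8 lanes = {8 * pam4_lane} Gbps/link")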

5.3 Wavelength Division Multiplexing (WDM)

Coarse WDM (CWDM):

Wavelengths: 1270 nm, 1290 nm, 1310 nm, 1330 nm, ...
Spacing: 20 nm
Channels: Typically 4-8
 
Example:
  4 wavelengths × 100 Gbps = 400 Gbps per fiber

Dense WDM (DWDM):

Wavelengths: C-band (1530-1565 nm)
Spacing: 0.8 nm (100 GHz grid) or 0.4 nm (50 GHz)
Channels: 40-80+
 
Example:
  40 wavelengths × 100 Gbps = 4 Tbps per fiber

TPU Pod Usage:

  • Likely uses CWDM or simple WDM (4-8 wavelengths)
  • Enables higher bandwidth without more fibers
  • Trade-off: Cost and complexity vs. bandwidth

5.4 Forward Error Correction (FEC)

Reed-Solomon FEC:

RS(528, 514):
  514 data symbols
  14 parity symbols
  Overhead: 14/514 = 2.7%
  
Can correct: 7 symbol errors
Pre-FEC BER: ~1×10^-4
Post-FEC BER: <1×10^-15

KP4-FEC (400G IEEE standard):

Overhead: ~5%
Pre-FEC BER: ~2×10^-4
Post-FEC BER: <1×10^-15

Impact on Latency:

  • FEC encoding: ~10-50 ns
  • FEC decoding: ~50-200 ns
  • Pipelined, so it adds latency but does not reduce steady-state throughput (a quick overhead check follows)
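
The overhead and correction figures above follow from the Reed-Solomon parameters. The sketch below reproduces them; RS(544, 514) is used here as the assumed KP4-style code:

# Overhead and correctable-symbol counts for the RS codes discussed above.
def rs_fec_stats(n, k):
    """Reed-Solomon RS(n, k): n total symbols, k data symbols."""
    parity = n - k
    return {
        "overhead_pct": 100.0 * parity / k,
        "correctable_symbols": parity // 2,
    }

print(rs_fec_stats(528, 514))   # ~2.7% overhead, corrects up to 7 symbol errors
print(rs_fec_stats(544, 514))   # KP4-style: ~5.8% overhead, corrects up to 15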

6. Link Layer and Protocol Stack

6.1 Protocol Stack Overview

┌─────────────────────────────────────┐
│  Application Layer (XLA, JAX)      │
├─────────────────────────────────────┤
│  Collective Ops (NCCL-like)        │  ← All-reduce, All-gather
├─────────────────────────────────────┤
│  Transport Layer (Custom)           │  ← Message fragmentation, reassembly
├─────────────────────────────────────┤
│  Link Layer (Google proprietary)   │  ← Flow control, error detection
├─────────────────────────────────────┤
│  Physical Layer (Optical)           │  ← Transceivers, FEC
└─────────────────────────────────────┘

6.2 Frame Format (Hypothetical)

┌──────┬──────┬───────────┬──────────────────┬─────┬─────┐
│ SOF  │ Dest │  Source   │     Payload      │ CRC │ EOF │
│ 8b   │ 16b  │   16b     │   N bytes        │ 32b │ 8b  │
└──────┴──────┴───────────┴──────────────────┴─────┴─────┘
 
SOF: Start of Frame
Dest: Destination chip ID (16 bits → 65536 chips addressable)
Source: Source chip ID
Payload: Data (typically 4 KB - 1 MB)
CRC: Cyclic Redundancy Check
EOF: End of Frame

Minimal overhead compared to Ethernet/InfiniBand:

  • Ethernet: 26 bytes (header + CRC)
  • InfiniBand: 40+ bytes (headers + ICRC)
  • TPU custom: ~10-20 bytes (optimized for known topology)
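
For the bulk transfers these links carry, header overhead is small in every case, but the leaner custom framing still helps at small payloads. A rough comparison, using the header sizes quoted above and assumed payload sizes:

# Rough framing-efficiency comparison using the header sizes quoted above.
def frame_efficiency(payload_bytes, overhead_bytes):
    return payload_bytes / (payload_bytes + overhead_bytes)

for payload in (256, 4 * 1024, 64 * 1024):                       # assumed payload sizes
    for name, overhead in (("Ethernet", 26), ("InfiniBand", 40), ("TPU custom", 15)):
        eff = frame_efficiency(payload, overhead)
        print(f"{payload:>6} B payload, {name:<10}: {eff * 100:6.2f}% efficient")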

6.3 Flow Control

Credit-Based Flow Control:

class FlowControl:
    def __init__(self, link_id, buffer_size):
        self.link_id = link_id
        self.rx_buffer_slots = buffer_size   # local receive buffer, in max-size packets
        self.remote_credits = buffer_size    # free slots we believe the remote buffer has
    
    def can_send(self, packet_size):
        """Check whether the remote buffer has room for this packet"""
        required_credits = (packet_size + MAX_PACKET - 1) // MAX_PACKET
        return self.remote_credits >= required_credits
    
    def send_packet(self, packet):
        """Send packet and consume the corresponding credits"""
        required_credits = (len(packet) + MAX_PACKET - 1) // MAX_PACKET
        if self.remote_credits >= required_credits:
            transmit(packet)
            self.remote_credits -= required_credits
            return True
        return False
    
    def receive_credit_return(self, count=1):
        """Remote side freed buffer space; credits are returned to us"""
        self.remote_credits += count
    
    def receive_packet(self, packet):
        """Receive and process a packet, then return the credit to the sender"""
        process(packet)
        send_credit_return(self.link_id)

Advantages:

  • Lossless: No packet drops due to buffer overflow
  • Backpressure: Naturally throttles sender
  • Simple: Low overhead

6.4 Error Detection and Recovery

Error Detection:

  • CRC-32: Detects transmission errors
  • Sequence numbers: Detects dropped/reordered packets
  • Timeout: Detects stuck transfers

Error Recovery:

def reliable_send(dest, data):
    """
    Reliable send with retransmission
    """
    seq_num = get_next_sequence_number(dest)
    max_retries = 3
    timeout = 100_000  # cycles
    
    for attempt in range(max_retries):
        packet = create_packet(dest, data, seq_num)
        send_packet(packet)
        
        # Wait for ACK
        start_time = get_cycle_count()
        while get_cycle_count() - start_time < timeout:
            if check_ack_received(dest, seq_num):
                return SUCCESS
        
        # Timeout, retry
        log_warning(f"Timeout on link to {dest}, retry {attempt+1}")
    
    # Exhausted retries, escalate error
    return FAILURE

Proactive Error Handling:

  • Periodic link tests: Send test patterns when idle
  • Bit error rate monitoring: Track pre-FEC BER
  • Predictive failure: Reconfigure before link fails completely

7. Collective Communication Optimization

7.1 All-Reduce: The Critical Operation

Why All-Reduce Dominates:

In data-parallel training:

  1. Each TPU computes gradients on local mini-batch
  2. Must average gradients across all TPUs (all-reduce)
  3. Update parameters with averaged gradients
  4. Happens every iteration (~1000 times/second)

Naive Implementation:

def naive_all_reduce(local_gradient):
    """
    Terrible: O(N) communication, O(N) latency
    """
    if rank == 0:
        # Master collects from all workers
        total = local_gradient
        for worker in range(1, num_workers):
            total += receive_from(worker)
        
        # Master broadcasts result
        result = total / num_workers
        for worker in range(1, num_workers):
            send_to(worker, result)
    else:
        send_to(0, local_gradient)
        result = receive_from(0)
    
    return result

Problem: Master is bottleneck, doesn't utilize network bandwidth.

7.2 Ring All-Reduce

Algorithm:

def ring_all_reduce(local_data, num_chips):
    """
    Ring all-reduce: O(N) latency, optimal bandwidth
    
    Phase 1: Reduce-scatter
    Phase 2: All-gather
    """
    chunk_size = len(local_data) // num_chips
    chunks = split_into_chunks(local_data, chunk_size)
    
    # Phase 1: Reduce-scatter
    for step in range(num_chips - 1):
        send_chunk_idx = (rank - step) % num_chips
        recv_chunk_idx = (rank - step - 1) % num_chips
        
        # Send to next neighbor, receive from previous
        send_to_neighbor((rank + 1) % num_chips, chunks[send_chunk_idx])
        received = receive_from_neighbor((rank - 1 + num_chips) % num_chips)
        
        # Accumulate
        chunks[recv_chunk_idx] += received
    
    # Phase 2: All-gather
    for step in range(num_chips - 1):
        send_chunk_idx = (rank - step + 1) % num_chips
        recv_chunk_idx = (rank - step) % num_chips
        
        send_to_neighbor((rank + 1) % num_chips, chunks[send_chunk_idx])
        chunks[recv_chunk_idx] = receive_from_neighbor((rank - 1 + num_chips) % num_chips)
    
    return concatenate(chunks)

Complexity:

  • Latency: 2(N-1) communication steps
  • Data volume: each chip sends ~2(N-1)/N × the total data ≈ 2× the data size for large N
  • Bandwidth-optimal: every link in the ring is busy on every step

Limitation: Only uses one dimension of torus, doesn't exploit topology.

7.3 3D Torus All-Reduce

Recursive Dimension-Ordered Algorithm:

def torus_3d_all_reduce(local_data, coords):
    """
    All-reduce on a 3D torus: step count scales with X_DIM + Y_DIM + Z_DIM
    (the cube root of N) rather than with N, far better than a single ring
    
    Reduce-scatter along each dimension, then all-gather in reverse
    """
    x, y, z = coords
    X_DIM, Y_DIM, Z_DIM = get_torus_dimensions()
    
    # Phase 1: Reduce-scatter along X dimension
    data_x = reduce_scatter_dimension(local_data, dimension='X', 
                                       my_coord=x, dim_size=X_DIM)
    
    # Phase 2: Reduce-scatter along Y dimension
    data_xy = reduce_scatter_dimension(data_x, dimension='Y',
                                        my_coord=y, dim_size=Y_DIM)
    
    # Phase 3: Reduce-scatter along Z dimension
    data_xyz = reduce_scatter_dimension(data_xy, dimension='Z',
                                         my_coord=z, dim_size=Z_DIM)
    
    # Now each chip has unique portion of reduced data
    
    # Phase 4: All-gather along Z dimension
    data_xy_gathered = all_gather_dimension(data_xyz, dimension='Z',
                                             my_coord=z, dim_size=Z_DIM)
    
    # Phase 5: All-gather along Y dimension
    data_x_gathered = all_gather_dimension(data_xy_gathered, dimension='Y',
                                            my_coord=y, dim_size=Y_DIM)
    
    # Phase 6: All-gather along X dimension
    final_data = all_gather_dimension(data_x_gathered, dimension='X',
                                       my_coord=x, dim_size=X_DIM)
    
    return final_data
 
def reduce_scatter_dimension(data, dimension, my_coord, dim_size):
    """
    Reduce-scatter along one dimension of torus
    Similar to ring all-reduce, but only in one dimension
    """
    chunk_size = len(data) // dim_size
    chunks = split_into_chunks(data, chunk_size)
    
    for step in range(dim_size - 1):
        send_chunk = chunks[(my_coord - step) % dim_size]
        recv_chunk_idx = (my_coord - step - 1) % dim_size
        
        send_in_dimension(dimension, +1, send_chunk)
        received = receive_in_dimension(dimension, -1)
        
        chunks[recv_chunk_idx] += received
    
    # After dim_size - 1 steps, the fully reduced chunk is the one at index (my_coord + 1)
    return chunks[(my_coord + 1) % dim_size]
 
def all_gather_dimension(data, dimension, my_coord, dim_size):
    """
    All-gather along one dimension
    Each chip gets complete data from all chips in this dimension
    """
    chunks = [None] * dim_size
    # Each chip enters the all-gather holding its fully reduced chunk at index (my_coord + 1)
    chunks[(my_coord + 1) % dim_size] = data
    
    for step in range(dim_size - 1):
        send_chunk_idx = (my_coord - step + 1) % dim_size
        recv_chunk_idx = (my_coord - step) % dim_size
        
        send_in_dimension(dimension, +1, chunks[send_chunk_idx])
        chunks[recv_chunk_idx] = receive_in_dimension(dimension, -1)
    
    return concatenate(chunks)

Complexity Analysis:

For N = X_DIM × Y_DIM × Z_DIM chips:

Latency:
  Reduce-scatter: (X_DIM-1) + (Y_DIM-1) + (Z_DIM-1) steps
  All-gather:     (Z_DIM-1) + (Y_DIM-1) + (X_DIM-1) steps
  Total:          2 × (X_DIM + Y_DIM + Z_DIM - 3) steps
  
For 16×16×16 = 4096 chips:
  Latency: 2 × (16+16+16-3) = 90 steps
  
Compare to ring:
  Ring: 2 × (4096-1) = 8190 steps
  
Speed-up: 8190 / 90 ≈ 91×  (!!!)
 
Bandwidth Utilization:
  All 12,288 links used simultaneously
  Each link carries unique data
  → Near-optimal bandwidth efficiency

7.4 Bandwidth and Latency Model

Realistic All-Reduce Time:

def model_all_reduce_time(data_size, num_chips, link_bw, link_latency):
    """
    Simplified all-reduce time model for a 3D torus.
    Assumes every step moves a fixed chunk of data_size / num_chips bytes
    (a simplification: early reduce-scatter steps actually move larger chunks).
    """
    dim = round(num_chips ** (1 / 3))   # 16 for 4096 chips; round() avoids FP truncation
    
    # Data chunk exchanged per step (size of the fully scattered shard)
    chunk_size = data_size / num_chips
    
    # Time to send one chunk over a link
    transfer_time = chunk_size / link_bw
    
    # 2 phases (reduce-scatter + all-gather) × 3 dimensions × (dim - 1) steps each
    total_steps = 2 * 3 * (dim - 1)     # 90 steps for 16×16×16
    
    return total_steps * (link_latency + transfer_time)
 
# Example: TPU v4 Pod (illustrative numbers)
data_size = 700e9        # 175B parameters × 4 bytes
num_chips = 4096
link_bw = 150e9 / 8      # 150 Gbps ≈ 18.75 GB/s
link_latency = 1e-6      # 1 μs per hop
 
time = model_all_reduce_time(data_size, num_chips, link_bw, link_latency)
print(f"3D torus all-reduce time: {time:.2f} s")
# Output: ~0.82 s under this simplified model
 
# With a single ring over all 4096 chips (for comparison):
ring_time = 2 * (num_chips - 1) * (data_size / num_chips / link_bw + link_latency)
print(f"Ring all-reduce time: {ring_time:.1f} s")
# Output: ~74.7 s

Key Insight: Under this simplified model, the 3D torus with dimension-ordered all-reduce is roughly 90× faster than a single ring over 4,096 chips, matching the step-count ratio derived above.


8. Evolution Across TPU Generations

8.1 TPU v1 (2015)

Interconnect:

  • Custom PCIe-based
  • Single-server only (no Pod concept)
  • Used for inference, not training

8.2 TPU v2 (2017)

Interconnect:

  • First TPU Pod
  • 2D mesh topology
  • Custom electrical interconnect
  • Up to 256 chips (16×16 mesh)

Specifications:

  • Link bandwidth: ~50 Gbps per link
  • 4 links per chip (±X, ±Y)
  • Total bisection BW: ~800 Gbps

Limitations:

  • Electrical I/O limited bandwidth scaling
  • 2D mesh has higher diameter than 3D torus
  • Power consumption per bit was high

8.3 TPU v3 (2018)

Major Innovations: Larger Pods + Liquid Cooling

Specifications:

  • Up to 1,024 chips (2,048 cores) in a 32×32 2D torus
  • 4 inter-chip links per chip (±X, ±Y), ~100 Gbps per link
  • Electrical inter-chip interconnect (ICI)
  • Liquid cooling introduced for the compute chips

Improvements over v2:

  • 4× more chips per Pod
  • ~2× bandwidth per link
  • Higher per-chip compute and HBM bandwidth
  • Much higher rack density enabled by liquid cooling

8.4 TPU v4 (2021)

Major Innovation: Optical Circuit Switching + 3D Torus

Specifications:

  • Up to 4,096 chips (16×16×16 3D torus)
  • 6 inter-chip links per chip, ~150-200 Gbps per optical link
  • Optical circuit switching (OCS) fabric, reconfigurable per job or Pod slice
  • ~275 TFLOPS per chip (BF16)
  • ~1.1 exaFLOPS per Pod

Interconnect Improvements:

  • Higher-bandwidth pluggable optics (likely QSFP-DD or early 400G class) over intra-datacenter fiber
  • MEMS-based optical circuit switches
  • Better power efficiency (~5-7W per 100 Gbps)
  • Enhanced fault tolerance and monitoring

Software Stack:

  • Better integration with JAX
  • Improved collective communication libraries
  • SPMD (Single Program Multiple Data) programming model

8.5 TPU v5e and v5p (2023-2024)

TPU v5e (Efficiency):

  • Cost-optimized for inference and smaller training
  • Likely uses similar interconnect to v4
  • May use lower-cost optical modules

TPU v5p (Performance):

  • SparseCores: Optimized for sparse and MoE models
  • Enhanced interconnect bandwidth (~200-400 Gbps per link)
  • Better support for model parallelism
  • 2× performance improvement over v4

Interconnect Evolution:

  • Likely uses 400G optics (QSFP-DD or OSFP)
  • Possible silicon photonics integration
  • Enhanced reconfigurability for dynamic topologies
  • Support for heterogeneous Pod configurations

Predicted Innovations:

  1. Co-Packaged Optics (CPO):

    • Optical transceivers integrated into TPU package
    • Eliminates electrical SerDes bottleneck
    • 10× bandwidth increase potential (1-10 Tbps per chip)
  2. Silicon Photonics:

    • On-chip photonic waveguides
    • Lower power, higher density
    • Wavelength-routed networks
  3. 3D Stacking:

    • Vertical optical interconnects
    • True 3D integration (not just topology)
    • Ultra-high bandwidth, low latency
  4. Hybrid Topologies:

    • 3D torus for local, all-to-all for global
    • Adaptive topology based on workload
    • Support for disaggregated memory

9. Performance Analysis

9.1 Bandwidth Analysis

Per-Chip Bandwidth:

TPU v4 (estimated):
  6 links × 150 Gbps = 900 Gbps total
  Bidirectional: 450 Gbps each direction
  = 56.25 GB/s per direction
 
TPU v5p (estimated):
  6 links × 300 Gbps = 1.8 Tbps total
  = 225 GB/s bidirectional
 
Compare to GPU (H100):
  NVLink: 18 ports × 50 GB/s = 900 GB/s bidirectional

Pod-Level Aggregate Bandwidth:

TPU v4 Pod (4,096 chips):
  12,288 links (each chip has 6 ports; each link is shared by two chips)
  Total: 12,288 × 150 Gbps ≈ 1.8 Pbps
 
NVIDIA DGX SuperPOD (256 H100s):
  Intra-server: 32 servers × 8 GPUs × 900 GB/s ≈ 230 TB/s (~1.8 Pbps)
  Inter-server: 256 GPUs × 400 Gbps (IB) ≈ 102 Tbps
  → Bandwidth cliff at the server boundary
 
Key: TPU Pod has uniform bandwidth, no cliffs

9.2 Latency Analysis

Component Latencies:

Component                      Latency
Optical transceiver (Tx)       50-100 ns
Fiber propagation (10 m)       ~50 ns
Optical switch (if dynamic)    1-10 μs first configuration; ~0 in circuit mode
Optical transceiver (Rx)       50-100 ns
Router logic                   50-100 ns
DMA setup                      100-200 ns
Total per hop                  ~500-1000 ns
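
Summing the table's components (using midpoints of the quoted ranges) gives a per-hop figure at the low end of the ~500-1000 ns estimate:

# Per-hop latency estimate from the table above (midpoints of the quoted ranges).
hop_components_ns = {
    "tx_transceiver": 75,
    "fiber_10m": 50,
    "rx_transceiver": 75,
    "router_logic": 75,
    "dma_setup": 150,
}
per_hop_ns = sum(hop_components_ns.values())
print(f"Per hop: ~{per_hop_ns} ns (low end of the ~500-1000 ns range)")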

Multi-Hop Latency:

Nearest neighbor (1 hop):           ~1 μs
8 hops (max within one dimension):  ~8 μs
24 hops (maximum diameter):         ~24 μs
 
All-reduce (16×16×16), per the simplified model in Section 7.4:
  90 steps total (45 reduce-scatter + 45 all-gather)
  Latency component: 90 × 1 μs = 90 μs
  Bandwidth component: 90 × 171 MB / 18.75 GB/s ≈ 0.8 s
  Total: ~0.8 s (dominated by bandwidth)

9.3 Efficiency Metrics

Bandwidth Utilization:

def bandwidth_utilization(workload):
    """
    Measure effective vs. theoretical bandwidth
    """
    theoretical_bw = num_links * link_bandwidth
    
    if workload == "all_reduce":
        # All links carry unique data
        effective_bw = theoretical_bw * 0.95  # ~95% efficiency
    elif workload == "point_to_point":
        # Many links idle
        effective_bw = theoretical_bw * 0.3  # ~30% efficiency
    elif workload == "all_to_all":
        # Contention on some links
        effective_bw = theoretical_bw * 0.6  # ~60% efficiency
    
    return effective_bw / theoretical_bw

TPU Pod Advantages:

  • All-reduce: 90-95% efficiency (near-optimal)
  • Predictable, no congestion
  • Scales linearly with Pod size

9.4 Scalability Limits

Current Limits (TPU v4):

Physical:
  - Optical reach: ~100m (multi-mode fiber)
  - Power density: ~25W interconnect per chip × 4096 = 102 kW
  - Cooling capacity: Liquid cooling required
  
Architectural:
  - Diameter: 24 hops (manageable)
  - All-reduce time: ~0.8 s for 175B FP32 gradients (simplified model, Section 7.4)
  - Addressing: 16-bit chip ID (65K chips addressable)
 
Scale-out beyond 4096 chips:
  - Option 1: Larger 3D torus (e.g., 32×32×32 = 32K chips)
    - Diameter: 48 hops (still workable)
    - All-reduce: more steps per dimension, but smaller per-chip chunks (comparable time under the Section 7.4 model)
  
  - Option 2: Multi-Pod with inter-Pod links
    - 8× 4096-chip Pods = 32K chips
    - High-bandwidth inter-Pod links (1-10 Tbps)
    - Hierarchical all-reduce

10. Power and Thermal Considerations

10.1 Power Breakdown

Per-Chip Power (TPU v4):

Compute (MXUs, VPUs):        ~150W
HBM:                         ~20W
Interconnect:
  - Optical transceivers:    ~12W (6 × 2W)
  - Router logic:            ~5W
  - DMA engines:             ~3W
Total interconnect:          ~20W
-----------------------------------------
Total per chip:              ~190W

Pod-Level Power (4,096 chips):

Compute:                   ~615 kW
Interconnect:              ~82 kW
Optical switches:          ~10 kW (amortized)
-----------------------------------------
Total:                     ~707 kW
 
Interconnect fraction: 82/707 = 11.6%
(vs. 25-30% for electrical interconnect)

10.2 Power Efficiency

Optical vs. Electrical:

Electrical SerDes (50 Gbps PAM4):
  Power: ~5-8W per lane
  Efficiency: 0.1-0.16 W/Gbps
 
Optical (50 Gbps lane):
  Laser: ~0.5-1W
  Driver: ~0.5W
  Receiver: ~0.5W
  Total: ~1.5-2W per lane
  Efficiency: 0.03-0.04 W/Gbps
 
Advantage: 3-4× better power efficiency

Scaling Impact:

For 4,096 chips × 6 links = 24,576 links:
 
Electrical (150 Gbps per link):
  24,576 × 150 Gbps × 0.15 W/Gbps = 553 kW
  
Optical (150 Gbps per link):
  24,576 × 150 Gbps × 0.04 W/Gbps = 147 kW
 
Savings: 406 kW
  → $350K/year at $0.10/kWh
  → Major OPEX reduction
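
The savings figure above can be reproduced directly from the per-bit efficiency assumptions; everything in the sketch below (transceiver count, W/Gbps figures, electricity price) comes from the text:

# Reproduces the electrical-vs-optical interconnect power comparison above.
transceivers = 4_096 * 6          # one transceiver per link endpoint
gbps_per_link = 150

def interconnect_power_kw(watts_per_gbps):
    return transceivers * gbps_per_link * watts_per_gbps / 1e3

electrical_kw = interconnect_power_kw(0.15)   # ~553 kW
optical_kw = interconnect_power_kw(0.04)      # ~147 kW
savings_kw = electrical_kw - optical_kw       # ~406 kW

annual_cost = savings_kw * 24 * 365 * 0.10    # at $0.10 per kWh
print(f"Savings: {savings_kw:.0f} kW ≈ ${annual_cost / 1e3:.0f}K per year")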

10.3 Thermal Management

Liquid Cooling:

TPU v4 Pods use liquid cooling for both compute and interconnect:

Cold plate on compute side:
  - Dissipates ~170W per chip
  - Water temperature: 20-25°C inlet, 30-35°C outlet
 
Air cooling for transceivers:
  - Smaller heatsinks
  - Localized airflow
  - ~20W per chip, manageable
 
Benefits:
  - Higher density (more chips per rack)
  - Quieter operation
  - More efficient (PUE ~1.1-1.2 vs. 1.4-1.6 for air)

11. Fault Tolerance and Reliability

11.1 Failure Modes

Link Failures:

  • Optical transceiver failure
  • Fiber break or connector issue
  • Switch port failure
  • Bit error rate degradation

Chip Failures:

  • Compute element failure (MXU, VPU)
  • Memory failure (HBM)
  • Router failure
  • Complete chip failure

11.2 Detection Mechanisms

Link Monitoring:

class LinkMonitor:
    def __init__(self, link_id):
        self.link_id = link_id
        self.error_count = 0
        self.ber_history = []
    
    def monitor(self):
        """Continuous monitoring"""
        while True:
            # Check pre-FEC BER
            pre_fec_ber = measure_pre_fec_ber(self.link_id)
            self.ber_history.append(pre_fec_ber)
            
            # Check CRC errors
            crc_errors = get_crc_error_count(self.link_id)
            self.error_count += crc_errors
            
            # Predictive failure detection
            if pre_fec_ber > 1e-4:
                log_warning(f"Link {self.link_id} BER degrading: {pre_fec_ber}")
                if pre_fec_ber > 5e-4:
                    trigger_link_failover(self.link_id)
            
            # Periodic link test
            if idle_time() > 1000:  # ms
                send_test_pattern(self.link_id)
            
            sleep(10)  # ms

11.3 Fault Tolerance Strategies

1. Link-Level Redundancy:

3D torus provides multiple paths:
  From (0,0,0) to (15,15,15):
    - Primary: +X(15), +Y(15), +Z(15)
    - Alternate 1: +Y(15), +X(15), +Z(15)
    - Alternate 2: -X(1), -Y(1), -Z(1)  [wraparound]
 
On link failure:
  → Use alternate path
  → Slightly longer latency, but training continues

2. Chip-Level Redundancy:

def checkpoint_and_recover(failed_chip_id):
    """
    Recovery strategies for chip failure, tried in order of preference
    """
    # Option 1: Checkpoint-restart
    if training_supports_checkpointing():
        save_checkpoint()
        blacklist_chip(failed_chip_id)
        relaunch_on_healthy_chips()
    
    # Option 2: Spare chips
    elif spare_chips_available():
        spare = allocate_spare_chip()
        reconfigure_optical_switch(failed_chip_id, spare)
        restore_state_to_spare(spare)
    
    # Option 3: Degraded mode
    else:
        continue_training_with_fewer_chips()

3. Optical Circuit Reconfiguration:

Dynamic circuit reconfiguration allows:
  1. Isolate failed chips
  2. Reconfigure topology (e.g., 16×16×16 → 16×16×15)
  3. Continue training with slightly degraded performance
 
Reconfiguration time: 1-10 seconds (MEMS) or <1 second (silicon photonics)
  → Acceptable compared to hours of training time

11.4 Reliability Metrics

Mean Time Between Failures (MTBF):

Component MTBFs:
  TPU chip: ~1,000,000 hours
  Optical transceiver: ~500,000 hours
  Fiber/connector: ~10,000,000 hours
 
Pod-level failure rate (4,096 chips, 24,576 transceivers):
  Chip failures: 1,000,000 h / 4,096 ≈ 244 h → ~1 failure every 10 days
  Transceiver failures: 500,000 h / 24,576 ≈ 20 h → ~1 failure per day
  
Expected: roughly one component failure somewhere in the Pod every day
 
With fault tolerance:
  → Training continues (slight performance degradation)
  → Repair during scheduled maintenance
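
A couple of lines of Python reproduce the fleet-level arithmetic above (the component MTBF values are the illustrative ones from this section):

# Fleet-level failure-interval arithmetic from the MTBF figures above.
def fleet_failure_interval_hours(component_mtbf_hours, component_count):
    """Expected hours between failures across a fleet of identical components."""
    return component_mtbf_hours / component_count

chip_interval = fleet_failure_interval_hours(1_000_000, 4_096)   # ~244 h (~10 days)
xcvr_interval = fleet_failure_interval_hours(500_000, 24_576)    # ~20 h (~1 per day)

print(f"Chip failure roughly every        {chip_interval / 24:.1f} days")
print(f"Transceiver failure roughly every {xcvr_interval / 24:.1f} days")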

Availability:

Target: 99.9% availability (8.76 hours downtime/year)
 
With redundancy and fast failover:
  Achieved: >99.95% (4.38 hours/year)

12. Programming Model and Software Stack

12.1 Software Stack Overview

┌─────────────────────────────────────────┐
│  User Code (Python)                     │
│  • TensorFlow, JAX, PyTorch (via XLA)   │
└─────────────────────────────────────────┘

┌─────────────────────────────────────────┐
│  XLA (Accelerated Linear Algebra)       │
│  • Compiler and optimizer               │
│  • Collective ops lowering              │
└─────────────────────────────────────────┘

┌─────────────────────────────────────────┐
│  Collective Communication Library       │
│  • All-reduce, all-gather, reduce-scatter│
│  • Topology-aware optimization          │
└─────────────────────────────────────────┘

┌─────────────────────────────────────────┐
│  ICI (Inter-Chip Interconnect) Driver   │
│  • Link management                      │
│  • DMA operations                       │
│  • Error handling                       │
└─────────────────────────────────────────┘

┌─────────────────────────────────────────┐
│  Optical Interconnect Hardware          │
└─────────────────────────────────────────┘

12.2 Programming Model: SPMD

Single Program Multiple Data (SPMD):

import jax
import jax.numpy as jnp
from jax.experimental import maps
 
# Define mesh for 4096 TPUs arranged as 16×16×16
mesh_shape = (16, 16, 16)
mesh = maps.Mesh(jax.devices(), ('x', 'y', 'z'))
 
# Partition data across mesh
@maps.xmap(
    in_axes=(['x', 'y', 'z'], ...),
    out_axes=['x', 'y', 'z'],
    axis_resources={'x': 'x', 'y': 'y', 'z': 'z'}
)
def distributed_matmul(A, B):
    """
    Matrix multiply distributed across 3D mesh
    Collective ops automatically inserted by XLA
    """
    return jnp.dot(A, B)
 
# XLA compiler will:
# 1. Partition A and B across mesh
# 2. Insert all-reduce after local matmuls
# 3. Use topology-aware all-reduce (3D torus algorithm)
# 4. Generate efficient ICI operations

12.3 Collective Communication API

Example: Explicit All-Reduce:

import jax
from jax import lax
 
def data_parallel_training_step(params, batch):
    """
    One step of data-parallel training
    """
    # Each TPU computes gradients on local batch
    local_gradients = compute_gradients(params, batch)
    
    # All-reduce to average gradients across all TPUs
    # JAX automatically uses topology-aware all-reduce
    averaged_gradients = lax.psum(local_gradients, axis_name='batch')
    
    # Update parameters
    new_params = params - learning_rate * averaged_gradients
    
    return new_params
 
# XLA lowers 'psum' to:
# 1. 3D torus all-reduce
# 2. Direct ICI DMA operations
# 3. Optimal scheduling to overlap compute and communication

12.4 Topology-Aware Optimization

XLA Optimizations:

# Pseudo-code for XLA's collective lowering
 
def lower_allreduce_to_ici(all_reduce_op, mesh_topology):
    """
    XLA compiler optimization for all-reduce
    """
    if mesh_topology.type == '3D_TORUS':
        # Use dimension-ordered all-reduce
        schedule = [
            ('reduce_scatter', 'X'),
            ('reduce_scatter', 'Y'),
            ('reduce_scatter', 'Z'),
            ('all_gather', 'Z'),
            ('all_gather', 'Y'),
            ('all_gather', 'X'),
        ]
        
        for phase, dimension in schedule:
            if phase == 'reduce_scatter':
                emit_reduce_scatter_ici(dimension)
            else:
                emit_all_gather_ici(dimension)
    
    elif mesh_topology.type == 'RING':
        emit_ring_allreduce_ici()
    
    else:
        # Fallback to general algorithm
        emit_tree_allreduce()
 
def emit_reduce_scatter_ici(dimension):
    """
    Emit ICI operations for reduce-scatter along one dimension
    """
    neighbor_plus = get_neighbor(dimension, +1)
    neighbor_minus = get_neighbor(dimension, -1)
    
    for step in range(dimension_size - 1):
        # DMA send to +neighbor
        emit_ici_dma_send(neighbor_plus, chunk_offset, chunk_size)
        
        # DMA receive from -neighbor
        emit_ici_dma_recv(neighbor_minus, recv_buffer, chunk_size)
        
        # Local reduction
        emit_vector_add(local_buffer, recv_buffer)

13. Industry Impact and Research Directions

13.1 Influence on Industry

Direct Impact:

  1. Meta Research SuperCluster (RSC):

    • Uses NVIDIA GPUs, but interconnect philosophy influenced by TPU Pods
    • Emphasis on uniform, high-bandwidth fabric
    • RoCE-based, but considering optical for future generations
  2. Microsoft Azure AI Infrastructure:

    • Custom interconnects for large-scale training
    • Exploring optical for next-generation systems
  3. AWS Trainium/Inferentia:

    • Custom interconnects for ML accelerators
    • EFA (Elastic Fabric Adapter) optimized for collectives

Broader Impact:

  • Optical Interconnect Startups:

    • Ayar Labs: TeraPHY optical I/O
    • Lightmatter: Photonic interconnect and compute
    • Celestial AI: Photonic fabric for AI
    • Ranovus: Data center optical switching
  • Industry Consortia:

    • Ultra Ethernet Consortium (AI-optimized Ethernet)
    • Optical Internetworking Forum (OIF)
    • Co-Packaged Optics (CPO) standardization efforts

13.2 Research Directions

1. Co-Packaged Optics (CPO):

Current: Pluggable optical modules
  Chip ← (electrical) → Package ← (electrical) → Module ← (optical) → Fiber
  Bottleneck: Electrical chip-to-module link
 
Future: Co-packaged optics
  Chip ← (short electrical) → Optical transceiver (on package) ← (optical) → Fiber
  Advantages:
    - 10× bandwidth (no electrical reach limit)
    - 5× lower power
    - Smaller form factor

2. Silicon Photonics Integration:

Goal: Integrate photonic waveguides directly on silicon
 
Technologies:
  - Silicon-on-insulator (SOI) waveguides
  - Germanium photodetectors (monolithic with Si)
  - III-V laser integration (heterogeneous or bonded)
  - Microring modulators (compact, low power)
 
Benefits:
  - 100× higher I/O density
  - <1 pJ/bit energy efficiency
  - Wavelength-division multiplexing (100+ wavelengths)

3. Wavelength-Routed Networks:

Concept: Use wavelength as routing dimension
 
Traditional:
  All transceivers on same wavelength
  Optical switches route by port
 
Wavelength-routed:
  Each transceiver on unique wavelength
  Wavelength determines destination
  No switching needed (passive routing)
 
Example:
  Chip A (λ1, λ2, λ3) → Fiber → Chip B receives λ1
                                Chip C receives λ2
                                Chip D receives λ3
 
Advantages:
  - Ultra-low latency (no switching)
  - High bandwidth (100+ wavelengths × 100 Gbps)
  - Reconfigurable via wavelength tuning

4. Near-Memory and In-Network Computing:

Trend: Move computation closer to data
 
Options:
  1. Processing-in-Memory (PIM)
     - Compute inside HBM stacks
     - Reduce data movement
  
  2. In-Network Reduction
     - Perform reduce operations in optical switches
     - Reduces all-reduce latency
  
  3. Disaggregated Memory
     - Memory pools connected via optical fabric
     - Flexible capacity allocation

5. Hybrid Electrical-Optical:

Insight: Not everything needs optical
 
Design:
  - Short-reach (<1 m): Electrical (lower latency, lower cost)
  - Medium-reach (1-10 m): Optical (copper becomes impractical)
  - Long-reach (>10 m): Optical (required)
 
Example:
  Within server: Electrical (NVLink-like)
  Rack-to-rack: Optical (TPU Pod-like)
  
Optimizes for cost, power, and performance

13.3 Open Research Questions

  1. Scaling Beyond 10K Accelerators:

    • What topology works best? (4D/5D torus, Dragonfly, Slim Fly?)
    • How to manage diameter and bisection bandwidth?
    • Multi-level hierarchies vs. flat topologies?
  2. Heterogeneous Interconnects:

    • Mix of optical (high BW) and electrical (low latency)
    • Dynamic selection based on message size
    • Protocol translation between domains
  3. Energy Proportionality:

    • Optical links are "always on" (lasers)
    • How to save power during idle periods?
    • Fast wake-up from sleep modes?
  4. Fault Tolerance at Scale:

    • With 10K+ chips, failures are constant
    • How to train through failures seamlessly?
    • Speculative execution and rollback?
  5. Software-Hardware Co-Design:

    • How should frameworks (PyTorch, JAX) expose interconnect topology?
    • Automatic tuning of collective algorithms?
    • Compiler optimizations for communication-compute overlap?

14. Summary: Key Takeaways

14.1 Technical Highlights

  1. Optical Circuit Switching: Enables scale, power efficiency, and determinism
  2. 3D Torus Topology: Optimal for structured AI workloads (all-reduce)
  3. Dimension-Ordered All-Reduce: ~100× faster than naive algorithms
  4. Power Efficiency: 3-4× better than electrical at datacenter scale
  5. Scalability: Uniform performance from 256 to 4,096+ chips

14.2 Why It Matters

TPU Pod's optical interconnect represents a fundamental shift in how we scale AI infrastructure:

  • Not just faster: Different technology (optical vs. electrical)
  • Not just bigger: Better topology (3D torus vs. fat-tree)
  • Not just Google: Influencing entire industry direction

14.3 Interview Preparation

Be ready to discuss:

  • Trade-offs: optical vs. electrical, circuit vs. packet switching
  • Topology comparisons: torus vs. fat-tree, 2D vs. 3D
  • Collective algorithms: ring vs. recursive halving vs. dimension-ordered
  • Power scaling: why power efficiency matters at scale
  • Real-world constraints: cost, reliability, backwards compatibility

Connections to other topics:

  • GPU architecture (NVLink comparison)
  • HPC interconnects (InfiniBand, Cray networks)
  • Distributed training algorithms (data parallel, model parallel, pipeline parallel)
  • Optical networking (WDM, silicon photonics, CPO)
  • System architecture (memory hierarchy, I/O bandwidth, Amdahl's law)

15. References

  1. Google TPU Papers:

    • Jouppi et al., "In-Datacenter Performance Analysis of a Tensor Processing Unit" (ISCA 2017)
    • Jouppi et al., "A Domain-Specific Supercomputer for Training Deep Neural Networks" (CACM 2020)
    • Jouppi et al., "Ten Lessons From Three Generations Spent Designing TPUs" (IEEE Micro 2021)
  2. Optical Interconnects:

    • Kachris et al., "Optical Interconnects for Data Centers" (IEEE Press, 2013)
    • Sun et al., "LIONS: An RDMA-Oriented Design for Low-Latency Optical Switches" (SIGCOMM 2020)
    • Mellette et al., "RotorNet: A Scalable, Low-complexity, Optical Datacenter Network" (SIGCOMM 2017)
  3. Network Topologies:

    • Dally & Towles, "Principles and Practices of Interconnection Networks" (Morgan Kaufmann, 2004)
    • Kim et al., "Technology-Driven, Highly-Scalable Dragonfly Topology" (ISCA 2008)
  4. Collective Communication:

    • Thakur et al., "Optimization of Collective Communication Operations in MPICH" (IJHPCA 2005)
    • Patarasuk & Yuan, "Bandwidth Optimal All-reduce Algorithms for Clusters of Workstations" (JPDC 2009)
  5. Silicon Photonics:

    • Soref, "The Past, Present, and Future of Silicon Photonics" (IEEE JSAC 2006)
    • Sun et al., "A 45 nm CMOS-SOI Monolithic Photonics Platform" (JLT 2018)
  6. Industry Reports:

    • Google Cloud TPU Documentation: https://cloud.google.com/tpu
    • Omdia: "AI Infrastructure Market Analysis" (Annual reports)
    • Yole Développement: "Silicon Photonics Market Report"