
GPU Matrix Multiplication Optimization

GEMM optimization on GPUs: tiling, memory hierarchy, coalescing, and tensor cores for peak performance.

20 min read
Updated 9/7/2024

Prerequisites

Make sure you're familiar with these concepts before diving in:

GPU Architecture Basics
CUDA Programming
Memory Hierarchy

Learning Objectives

By the end of this topic, you will be able to:

Design efficient GEMM kernels using tiling and shared memory
Optimize memory access patterns for coalescing and bank conflict avoidance
Leverage tensor cores and mixed precision for maximum throughput
Apply roofline analysis to identify compute vs memory bottlenecks
Implement advanced techniques like double buffering and async copies


GPU Matrix Multiplication Optimization

Matrix multiplication (GEMM) is the backbone of modern AI and the ultimate test of GPU optimization skills. Let's master the art of making silicon sing with perfectly orchestrated data movement and compute! 🎯

1. GEMM Fundamentals

1.1 The Operation

Compute C = αAB + βC where:

  • A: [M×K] matrix
  • B: [K×N] matrix
  • C: [M×N] matrix
  • FLOPs: ~2MNK (one multiply and one add per K step for each of the M×N output elements)

1.2 Arithmetic Intensity (AI)

The holy grail of optimization:

AI = FLOPs / Bytes moved = 2MNK / (sizeof(element) × (M×K + K×N + M×N))

Goal: Increase AI through data reuse (tiling) to become compute-bound rather than memory-bound.
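As a quick worked example (assuming each matrix crosses the memory bus exactly once), a square FP16 GEMM with M = N = K = 1024 gives:

FLOPs       = 2 × 1024³ ≈ 2.1 × 10⁹
Bytes moved = 2 bytes × 3 × 1024² ≈ 6.3 × 10⁶
AI          ≈ 341 FLOPs/byte

In practice the achieved AI is far lower unless tiling keeps sub-blocks of A and B resident on chip.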

2. GPU Mapping Strategy

2.1 Thread Hierarchy Mapping

The key insight: Map the problem hierarchy to the hardware hierarchy

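In a typical tiled GEMM kernel the mapping looks roughly like this (tile sizes are illustrative, not prescriptive):

// Grid         -> the full C matrix, partitioned into block tiles
// Thread block -> one TILE_M x TILE_N block tile of C (staged in shared memory)
// Warp         -> a warp tile inside the block tile (e.g. 64 x 32)
// Thread       -> a small register tile of outputs (e.g. 8 x 8 accumulators)
// K dimension  -> consumed in TILE_K-sized steps by every block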

2.2 Basic Tiling Strategy

// Each thread block computes one TILE_SIZE x TILE_SIZE tile of C.
// Assumes M, N, K are multiples of TILE_SIZE and a launch configuration of
// dim3(TILE_SIZE, TILE_SIZE) threads per block (see the launch sketch below).
__global__ void gemm_tiled(const float* A, const float* B, float* C,
                           int M, int N, int K) {
    const int TILE_SIZE = 16;
    
    // Thread block tile coordinates
    int bx = blockIdx.x, by = blockIdx.y;
    int tx = threadIdx.x, ty = threadIdx.y;
    
    // Shared memory for tiles of A and B
    __shared__ float As[TILE_SIZE][TILE_SIZE];
    __shared__ float Bs[TILE_SIZE][TILE_SIZE];
    
    float sum = 0.0f;
    
    // Loop over the K dimension in tiles
    for (int k = 0; k < K; k += TILE_SIZE) {
        // Load one tile of A and one tile of B into shared memory (coalesced)
        As[ty][tx] = A[(by * TILE_SIZE + ty) * K + (k + tx)];
        Bs[ty][tx] = B[(k + ty) * N + (bx * TILE_SIZE + tx)];
        
        __syncthreads();
        
        // Compute partial dot product from the shared-memory tiles
        for (int i = 0; i < TILE_SIZE; i++) {
            sum += As[ty][i] * Bs[i][tx];
        }
        
        __syncthreads();
    }
    
    // Write result (computes C = A*B; alpha/beta scaling omitted for clarity)
    C[(by * TILE_SIZE + ty) * N + (bx * TILE_SIZE + tx)] = sum;
}
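A minimal host-side launch sketch for this kernel (d_A, d_B, d_C are assumed device pointers, with M, N, K multiples of 16):

dim3 block(16, 16);                 // one thread per element of a 16x16 C tile
dim3 grid(N / 16, M / 16);          // one block per 16x16 tile of C
gemm_tiled<<<grid, block>>>(d_A, d_B, d_C, M, N, K);
cudaDeviceSynchronize();            // wait before reading d_C back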

3. Memory Optimization Techniques

3.1 Shared Memory & Register Tiling

The performance secret: Exploit the memory hierarchy aggressively

// Schematic register-tiled GEMM: helpers such as load_tile_A and compute_tile
// stand in for the actual index arithmetic.
template<int TILE_M, int TILE_N, int TILE_K,
         int REG_TILE_M, int REG_TILE_N>
__global__ void gemm_optimized(const float* A, const float* B, float* C,
                               int M, int N, int K) {
    // Shared memory staging
    __shared__ float As[TILE_M][TILE_K + 1];  // +1 padding to avoid bank conflicts
    __shared__ float Bs[TILE_K][TILE_N + 1];
    
    // Register tiles for accumulation (each thread owns a small sub-tile of C)
    float C_reg[REG_TILE_M][REG_TILE_N] = {0};
    float A_reg[REG_TILE_M];
    float B_reg[REG_TILE_N];
    
    // K-loop with register blocking
    for (int k = 0; k < K; k += TILE_K) {
        // Load tiles into shared memory (coalesced global loads)
        load_tile_A(As, A, k);
        load_tile_B(Bs, B, k);
        __syncthreads();
        
        // Inner loop: shared memory -> registers
        for (int kk = 0; kk < TILE_K; kk++) {
            // Load one column of As and one row of Bs into registers
            load_A_registers(A_reg, As, kk);
            load_B_registers(B_reg, Bs, kk);
            
            // Compute: operands live entirely in registers
            compute_tile(C_reg, A_reg, B_reg);
        }
        __syncthreads();
    }
    
    // Store the per-thread register tile back to C
    store_C_tile(C, C_reg);
}
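The innermost step, compute_tile, is a pure register-level outer product. A sketch of what that helper might look like (the helper name and layout follow the schematic above, not a library API):

// Outer-product update: C_reg += A_reg (column) x B_reg (row)
template<int REG_TILE_M, int REG_TILE_N>
__device__ void compute_tile(float (&C_reg)[REG_TILE_M][REG_TILE_N],
                             const float (&A_reg)[REG_TILE_M],
                             const float (&B_reg)[REG_TILE_N]) {
    #pragma unroll
    for (int i = 0; i < REG_TILE_M; i++) {
        #pragma unroll
        for (int j = 0; j < REG_TILE_N; j++) {
            C_reg[i][j] = fmaf(A_reg[i], B_reg[j], C_reg[i][j]);  // one FMA per output
        }
    }
}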

3.2 Bank Conflict Avoidance

Shared memory is split into 32 banks; simultaneous accesses to the same bank serialize, so avoid conflicts for peak bandwidth:

// Bad: with 32 columns, column 0 of every row maps to bank 0
__shared__ float shared[32][32];
float val = shared[threadIdx.x][0];  // 32-way conflict on bank 0!
 
// Good: padding by one column changes the row stride to 33 floats
__shared__ float shared[32][33];     // extra column breaks the bank pattern
float val = shared[threadIdx.x][threadIdx.y];
 
// Advanced: swizzling spreads accesses across banks without padding
__shared__ float shared[32][32];
int swizzled_col = (col + row) % 32;  // or an XOR-based swizzle: col ^ row
float val = shared[row][swizzled_col];

3.3 Coalescing Optimization

The golden rule: Consecutive threads should access consecutive memory locations.

// Excellent: vectorized, coalesced loads
__global__ void gemm_vectorized() {
    // Load 4 consecutive floats per thread with one 128-bit instruction
    // (requires 16-byte alignment; global_idx counts float4 elements)
    float4 A_vec = reinterpret_cast<float4*>(A)[global_idx];
    
    // Unpack for computation
    float A_vals[4] = {A_vec.x, A_vec.y, A_vec.z, A_vec.w};
}
 
// Memory transaction analysis:
// 32 threads × 4 floats = 128 floats = 512 bytes = 4 × 128-byte cache lines
// Optimal: 4 transactions serve the entire warp
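The same rule determines how threads should be assigned to matrix elements in the first place. A minimal illustrative kernel (names and indices are hypothetical; assumes a square row-major N×N matrix and a launch that keeps row and col in range):

__global__ void access_patterns(const float* A, float* out, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    // Coalesced: threads with consecutive threadIdx.x touch consecutive
    // addresses inside one row -> one wide transaction per warp
    float good = A[row * N + col];

    // Strided: consecutive threads touch addresses N floats apart
    // -> up to 32 separate transactions per warp
    float bad = A[col * N + row];

    out[row * N + col] = good + bad;  // keep both loads live
}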

4. Tensor Cores & Mixed Precision

4.1 Tensor Core Programming

Modern GPUs have specialized matrix units for AI workloads:

#include <mma.h>
using namespace nvcuda;
 
// One warp computes a single 16x16 tile of C (here the tile at the origin;
// a full kernel would offset A, B, and C by the warp's tile coordinates).
// A and B are FP16 and stored row-major; accumulation is FP32.
__global__ void gemm_tensor_core(const half* A, const half* B, float* C,
                                 int N, int K) {
    // Tensor Core fragments (owned cooperatively by the whole warp)
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
    
    // Initialize accumulator
    wmma::fill_fragment(c_frag, 0.0f);
    
    for (int k = 0; k < K; k += 16) {
        // Load 16x16 tiles (leading dimensions: K for A, N for B)
        wmma::load_matrix_sync(a_frag, A + k, K);
        wmma::load_matrix_sync(b_frag, B + k * N, N);
        
        // Warp-synchronous matrix multiply-accumulate on the Tensor Cores
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }
    
    // Store the accumulated 16x16 tile of C
    wmma::store_matrix_sync(C, c_frag, N, wmma::mem_row_major);
}
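A minimal, hypothetical compile-and-launch sketch for the kernel above (one warp per block, so it only produces the first 16×16 tile; d_A, d_B, d_C are assumed device pointers):

// nvcc -arch=sm_80 gemm_wmma.cu   (WMMA requires sm_70 or newer)
gemm_tensor_core<<<1, 32>>>(d_A, d_B, d_C, N, K);   // 32 threads = one warp
cudaDeviceSynchronize();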

4.2 Mixed Precision Strategy

The standard mixed-precision recipe (NVIDIA's Transformer Engine applies the same idea with FP8):

// Input: FP16/BF16 for memory efficiency
// Compute: FP32 accumulation for numerical stability  
// Output: FP16/BF16 for downstream efficiency
 
// Per-element illustration of the recipe (idx and prev_sum are schematic)
__global__ void gemm_mixed_precision() {
    // Load in reduced precision
    half A_half = A[idx];
    half B_half = B[idx];
    
    // Promote to FP32 before multiplying
    float A_float = __half2float(A_half);
    float B_float = __half2float(B_half);
    
    // Accumulate in FP32 (Tensor Cores perform this step in hardware)
    float sum = fmaf(A_float, B_float, prev_sum);
    
    // Store the result in reduced precision
    C[idx] = __float2half(sum);
}

5. Advanced Optimization Techniques

5.1 Double Buffering & Async Copies

Overlap memory and compute for peak utilization:

// Schematic double-buffered main loop. load_tile_async, cp_async_commit_group,
// and cp_async_wait_group stand in for Ampere's cp.async machinery
// (__pipeline_memcpy_async / __pipeline_commit / __pipeline_wait_prior, or
// cuda::pipeline); see the sketch below the kernel.
__global__ void gemm_double_buffered() {
    // Double-buffered shared memory
    __shared__ float As[2][TILE_M][TILE_K];
    __shared__ float Bs[2][TILE_K][TILE_N];
    
    int read_stage = 0, write_stage = 1;
    
    // Prefetch the first tile
    load_tile_async(As[write_stage], A, 0);
    load_tile_async(Bs[write_stage], B, 0);
    cp_async_commit_group();
    
    for (int k = TILE_K; k < K; k += TILE_K) {
        // Swap buffers
        read_stage ^= 1;
        write_stage ^= 1;
        
        // Start loading the next tile (asynchronously)
        load_tile_async(As[write_stage], A, k);
        load_tile_async(Bs[write_stage], B, k);
        cp_async_commit_group();
        
        // Wait until only the newest copy group is still in flight,
        // i.e. the tile in read_stage has arrived
        cp_async_wait_group(1);
        __syncthreads();
        
        // Compute on the current tile while the next one streams in
        compute_tile(As[read_stage], Bs[read_stage], C_accum);
    }
    
    // Epilogue: the last prefetched tile still has to be consumed
    cp_async_wait_group(0);
    __syncthreads();
    compute_tile(As[write_stage], Bs[write_stage], C_accum);
}
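One way the schematic cp_async_* helpers might map onto CUDA's pipeline primitives (Ampere or newer, CUDA 11+). This is a simplified sketch: the signature differs slightly from the call above (it takes the tile's source pointer and its row stride), and it assumes TILE_K is a multiple of 4 with 16-byte-aligned pointers.

#include <cuda_pipeline.h>

// Copy a TILE_M x TILE_K float tile from global memory (row stride `ld`)
// into shared memory, 16 bytes (one float4) per cp.async instruction.
template<int TILE_M, int TILE_K>
__device__ void load_tile_async(float (&dst)[TILE_M][TILE_K],
                                const float* src, int ld) {
    int tid      = threadIdx.y * blockDim.x + threadIdx.x;
    int nthreads = blockDim.x * blockDim.y;
    int chunks   = TILE_M * (TILE_K / 4);        // float4 chunks in the tile

    for (int c = tid; c < chunks; c += nthreads) {
        int row = c / (TILE_K / 4);
        int col = (c % (TILE_K / 4)) * 4;
        __pipeline_memcpy_async(&dst[row][col], src + row * ld + col, 16);
    }
}

// cp_async_commit_group()  maps to  __pipeline_commit();
// cp_async_wait_group(N)   maps to  __pipeline_wait_prior(N);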

5.2 Stream-K Decomposition

Load balancing when the number of output tiles does not divide evenly across the SMs:

// Persistent-block tile scheduler: instead of a fixed 2D grid, distribute
// output tiles over a 1D grid so every block stays busy. (Full Stream-K goes
// one step further and also splits the K dimension of individual tiles, with
// a fix-up pass to combine partial results; see the split-K sketch below.)
__global__ void gemm_stream_k(int total_tiles, int tiles_n) {
    int tile_id = blockIdx.x;
    
    while (tile_id < total_tiles) {
        // Convert the 1D tile_id to 2D tile coordinates
        int tile_m = tile_id / tiles_n;
        int tile_n = tile_id % tiles_n;
        
        // Process this output tile
        compute_tile(tile_m, tile_n);
        
        // Grid-stride to the next unclaimed tile
        tile_id += gridDim.x;
    }
}
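To see why splitting along K helps, here is a hypothetical, heavily simplified split-K kernel: each block owns one K-slice of the output and combines partial results with atomicAdd (real Stream-K replaces the atomics with a deterministic fix-up pass; C must be zeroed beforehand here):

__global__ void gemm_split_k(const float* A, const float* B, float* C,
                             int M, int N, int K, int k_slices) {
    int slice   = blockIdx.z;                        // which K-slice this block owns
    int k_begin = slice * (K / k_slices);
    int k_end   = (slice == k_slices - 1) ? K : k_begin + K / k_slices;

    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;

    // Partial dot product over this block's K-slice only
    float partial = 0.0f;
    for (int k = k_begin; k < k_end; ++k)
        partial += A[row * K + k] * B[k * N + col];

    // Combine partial results from all K-slices
    atomicAdd(&C[row * N + col], partial);
}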

6. Roofline Analysis for GEMM

6.1 Identifying Bottlenecks

// Theoretical analysis (A100 SXM numbers)
float peak_flops = 312e12f;       // FP16 Tensor Core peak, FLOP/s
float peak_bandwidth = 1555e9f;   // HBM2 bandwidth, bytes/s
 
// Problem characteristics
float flops = 2.0f * M * N * K;
float bytes = sizeof(half) * (M*K + K*N + M*N);   // assuming FP16 operands
float arithmetic_intensity = flops / bytes;        // FLOP per byte
 
// Performance bounds (the roofline)
float compute_bound_perf = peak_flops;
float memory_bound_perf = arithmetic_intensity * peak_bandwidth;
float theoretical_perf = fminf(compute_bound_perf, memory_bound_perf);
 
// Ridge point: peak_flops / peak_bandwidth ≈ 200 FLOP/byte on A100;
// lower AI is memory bound, higher AI is compute bound.

6.2 Optimization Strategy Selection

if (arithmetic_intensity < peak_flops / peak_bandwidth) {  // below the ridge point
    // Memory bound - focus on data movement
    optimize_coalescing();
    use_shared_memory_efficiently();
    minimize_global_memory_traffic();
} else {
    // Compute bound - focus on math throughput  
    use_tensor_cores();
    optimize_mixed_precision();
    maximize_occupancy();
}

7. Structured Sparsity in GEMM

7.1 2:4 Sparse GEMM

Hardware acceleration for structured sparse patterns:

// Sparse GEMM with 2:4 pattern
#include <cusparseLt.h>
 
// Conceptual layout of a 2:4-compressed operand: half the values plus a
// 2-bit position index per kept element (cuSPARSELt packs both into the
// single buffer returned by cusparseLtSpMMACompress)
struct SparseMatrix {
    void* compressed_values;    // 50% of the original values
    void* metadata;             // 2-bit index per kept (nonzero) element
    size_t compressed_size;
};
 
// Schematic sparse x dense GEMM call; handle, plan, and the matmul
// descriptors are assumed to be initialized beforehand (see outline below)
cusparseLtMatmul(&handle, &plan, &alpha,
                 sparse_A.compressed_values, dense_B,
                 &beta, dense_C, dense_D,
                 workspace, streams, num_streams);
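For context, the surrounding cuSPARSELt flow looks roughly like this (function names from the cuSPARSELt API; argument lists omitted because they are long and version-dependent):

// 1. cusparseLtInit(&handle);
// 2. cusparseLtStructuredDescriptorInit(...)   // describe 2:4-sparse A
//    cusparseLtDenseDescriptorInit(...)        // describe dense B, C, D
// 3. cusparseLtMatmulDescriptorInit(...) and cusparseLtMatmulPlanInit(...)
// 4. cusparseLtSpMMAPrune(...)                 // enforce the 2:4 pattern
// 5. cusparseLtSpMMACompress(...)              // pack values + metadata
// 6. cusparseLtMatmul(...)                     // the call shown above
// 7. cusparseLtMatmulPlanDestroy(...) and cusparseLtDestroy(...)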

Performance characteristics:

  • Theoretical: 2× speedup (half the math)
  • Practical: 1.3-1.6× (overhead from metadata and scheduling)
  • Memory: ~50% reduction in weight storage

8. Verification & Benchmarking

8.1 Correctness Validation

bool validate_gemm(const Matrix& A, const Matrix& B,
                   const Matrix& C_gpu, const Matrix& C_ref,
                   int M, int N) {
    float max_error = 0.0f;
    float max_relative_error = 0.0f;
    
    for (int i = 0; i < M * N; i++) {
        float error = fabsf(C_gpu.data[i] - C_ref.data[i]);
        float rel_err = error / fmaxf(fabsf(C_ref.data[i]), 1e-7f);
        
        max_error = fmaxf(max_error, error);
        max_relative_error = fmaxf(max_relative_error, rel_err);
    }
    
    // Tolerances are generous because FP16/Tensor Core results differ from an
    // FP32 reference; scale them with K for very large reductions.
    return max_error < 1e-3f && max_relative_error < 1e-2f;
}
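The reference C_ref can come from cuBLAS or from a plain CPU loop. A minimal sketch of the latter, assuming row-major float buffers like the validator above:

// Naive row-major reference: C = A * B, accumulated in double for accuracy
void gemm_reference(const float* A, const float* B, float* C,
                    int M, int N, int K) {
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            double acc = 0.0;
            for (int k = 0; k < K; k++) {
                acc += (double)A[i * K + k] * (double)B[k * N + j];
            }
            C[i * N + j] = (float)acc;
        }
    }
}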

8.2 Performance Metrics

struct GEMMMetrics {
    float achieved_tflops;
    float memory_bandwidth_gbps;
    float efficiency_vs_peak;      // filled in by the caller: achieved / peak
    float speedup_vs_cublas;       // filled in by comparing against a cuBLAS run
    
    void compute(float runtime_ms, int M, int N, int K) {
        float flops = 2.0f * M * N * K;
        achieved_tflops = flops / (runtime_ms * 1e-3f) / 1e12f;
        
        // Minimum traffic: each FP32 matrix read or written exactly once
        float bytes = sizeof(float) * (float(M)*K + float(K)*N + float(M)*N);
        memory_bandwidth_gbps = bytes / (runtime_ms * 1e-3f) / 1e9f;
    }
};
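A usage sketch: time the kernel with CUDA events and feed the elapsed time into GEMMMetrics (the kernel and launch configuration are placeholders):

cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);

cudaEventRecord(start);
gemm_tiled<<<grid, block>>>(d_A, d_B, d_C, M, N, K);   // kernel under test
cudaEventRecord(stop);
cudaEventSynchronize(stop);

float runtime_ms = 0.0f;
cudaEventElapsedTime(&runtime_ms, start, stop);        // milliseconds

GEMMMetrics metrics;
metrics.compute(runtime_ms, M, N, K);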

9. Interview Practice Questions

10. Hands-On Exercises

10.1 Exercise 1: Roofline Analysis

Given M=N=K=4096, FP16 precision on A100 GPU:

  • Calculate arithmetic intensity
  • Determine if compute or memory bound
  • Estimate theoretical peak performance
  • Compare with cuBLAS baseline

10.2 Exercise 2: Tile Size Optimization

Design shared memory tile sizes that:

  • Keep register usage under 128 per thread
  • Fit in 48KB shared memory per SM
  • Avoid bank conflicts for both A and B tiles
  • Maximize data reuse

10.3 Exercise 3: Tensor Core Integration

Convert a basic GEMM kernel to use tensor cores:

  • Show fragment declarations and loading
  • Handle mixed precision correctly
  • Optimize for warp-level cooperation
  • Measure speedup vs CUDA core version

11. Key Takeaways

  1. Tiling is everything: Transform memory-bound to compute-bound through data reuse
  2. Memory hierarchy mastery: Shared memory staging with coalescing and bank conflict avoidance
  3. Tensor cores are the future: Learn WMMA API and mixed precision techniques
  4. Overlap is critical: Use async copies and double buffering for peak utilization
  5. Roofline guides optimization: Identify bottlenecks before optimizing

GEMM optimization is both an art and a science. Master these techniques, and you'll have the foundation for optimizing any compute-intensive GPU kernel! 🚀