GPU Matrix Multiplication Optimization
GEMM optimization on GPUs: tiling, memory hierarchy, coalescing, and tensor cores for peak performance.
Matrix multiplication (GEMM) is the backbone of modern AI and the ultimate test of GPU optimization skills. Let's master the art of making silicon sing with perfectly orchestrated data movement and compute! 🎯
1. GEMM Fundamentals
1.1 The Operation
Compute C = αAB + βC
where:
- A: [M×K] matrix
- B: [K×N] matrix
- C: [M×N] matrix
- FLOPs: ~2MNK (K multiply-adds, i.e. 2K FLOPs, for each of the M×N output elements)
1.2 Arithmetic Intensity (AI)
The holy grail of optimization:
AI = FLOPs / bytes moved = 2MNK / ((M·K + K·N + M·N) × bytes per element), assuming each matrix is read or written exactly once
Goal: Increase AI through data reuse (tiling) to become compute-bound rather than memory-bound.
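As a quick sanity check, here is a minimal host-side sketch of that calculation; the problem size and element width are example inputs, not anything prescribed above.
// Arithmetic intensity under the ideal "each matrix touched once" assumption
#include <cstdio>

int main() {
    long long M = 4096, N = 4096, K = 4096;   // example problem size
    double element_bytes = 2.0;               // FP16 operands
    double flops = 2.0 * M * N * K;
    double bytes = element_bytes * (M * K + K * N + M * N);
    double ai = flops / bytes;                // ~1365 FLOP/byte for this case
    printf("Arithmetic intensity: %.1f FLOP/byte\n", ai);
    return 0;
}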
2. GPU Mapping Strategy
2.1 Thread Hierarchy Mapping
The key insight: map the problem hierarchy onto the hardware hierarchy
- Grid → the full C matrix
- Thread block → one tile of C, staged through shared memory
- Warp → a sub-tile of the block's tile
- Thread → a small register tile of output elements
- K dimension → an outer loop that streams tiles of A and B through shared memory
2.2 Basic Tiling Strategy
// Each thread block computes one TILE_SIZE × TILE_SIZE tile of C.
// Assumes M, N, K are multiples of TILE_SIZE and blockDim == (TILE_SIZE, TILE_SIZE).
__global__ void gemm_tiled(const float* A, const float* B, float* C, int M, int N, int K) {
    const int TILE_SIZE = 16;
// Thread block tile coordinates
int bx = blockIdx.x, by = blockIdx.y;
int tx = threadIdx.x, ty = threadIdx.y;
// Shared memory for tiles
__shared__ float As[TILE_SIZE][TILE_SIZE];
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
float sum = 0.0f;
// Loop over K dimension in tiles
for (int k = 0; k < K; k += TILE_SIZE) {
// Load tiles into shared memory
As[ty][tx] = A[(by * TILE_SIZE + ty) * K + (k + tx)];
Bs[ty][tx] = B[(k + ty) * N + (bx * TILE_SIZE + tx)];
__syncthreads();
// Compute partial dot product
for (int i = 0; i < TILE_SIZE; i++) {
sum += As[ty][i] * Bs[i][tx];
}
__syncthreads();
}
// Write result
C[(by * TILE_SIZE + ty) * N + (bx * TILE_SIZE + tx)] = sum;
}
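A host-side launch for this kernel could look like the sketch below; d_A, d_B, d_C are assumed device allocations, and the grid shape assumes M and N are multiples of the 16-wide tile.
dim3 block(16, 16);                 // TILE_SIZE × TILE_SIZE threads
dim3 grid(N / 16, M / 16);          // one block per tile of C
gemm_tiled<<<grid, block>>>(d_A, d_B, d_C, M, N, K);
cudaDeviceSynchronize();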
3. Memory Optimization Techniques
3.1 Shared Memory & Register Tiling
The performance secret: Exploit the memory hierarchy aggressively
// Sketch: load_tile_A/B, load_*_registers, compute_tile and store_C_tile are placeholder helpers
template<int TILE_M, int TILE_N, int TILE_K, int REG_TILE_M, int REG_TILE_N>
__global__ void gemm_optimized(const float* A, const float* B, float* C, int M, int N, int K) {
// Shared memory staging
__shared__ float As[TILE_M][TILE_K + 1]; // +1 to avoid bank conflicts
__shared__ float Bs[TILE_K][TILE_N + 1];
// Register tiles for accumulation
float C_reg[REG_TILE_M][REG_TILE_N] = {0};
float A_reg[REG_TILE_M];
float B_reg[REG_TILE_N];
// K-loop with register blocking
for (int k = 0; k < K; k += TILE_K) {
// Load tiles to shared memory (coalesced)
load_tile_A(As, A, k);
load_tile_B(Bs, B, k);
__syncthreads();
// Inner loop: shared memory to registers
for (int kk = 0; kk < TILE_K; kk++) {
// Load from shared to registers
load_A_registers(A_reg, As, kk);
load_B_registers(B_reg, Bs, kk);
// Compute: registers only
compute_tile(C_reg, A_reg, B_reg);
}
__syncthreads();
}
// Store results
store_C_tile(C, C_reg);
}
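The heart of the inner loop is a register-level outer product: each thread multiplies its REG_TILE_M values of A against its REG_TILE_N values of B and accumulates into C_reg. A minimal sketch of what the compute_tile helper above might look like (the helper name and tile parameters follow this example, not a library API):
// Register-only outer product: REG_TILE_M × REG_TILE_N fused multiply-adds per call
template<int REG_TILE_M, int REG_TILE_N>
__device__ void compute_tile(float (&C_reg)[REG_TILE_M][REG_TILE_N],
                             const float (&A_reg)[REG_TILE_M],
                             const float (&B_reg)[REG_TILE_N]) {
    for (int m = 0; m < REG_TILE_M; m++) {
        for (int n = 0; n < REG_TILE_N; n++) {
            C_reg[m][n] = fmaf(A_reg[m], B_reg[n], C_reg[m][n]);
        }
    }
}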
3.2 Bank Conflict Avoidance
Shared memory is split into 32 four-byte banks; avoid conflicts for peak bandwidth:
// Bad: with a 32-wide row, a column access has a stride of 32 floats,
// so every thread in the warp hits bank 0
__shared__ float tile_bad[32][32];
float val = tile_bad[threadIdx.x][0]; // 32-way bank conflict!
// Good: padding to 33 columns changes the stride, spreading accesses across banks
__shared__ float tile_padded[32][33]; // extra column breaks the stride
float val2 = tile_padded[threadIdx.x][0]; // now maps to bank threadIdx.x % 32 (conflict-free)
// Advanced: swizzling the column index achieves the same effect without padding
__shared__ float tile_swizzled[32][32];
int swizzled_col = (col + row) % 32;  // or an XOR swizzle such as col ^ (row & 31)
float val3 = tile_swizzled[row][swizzled_col];
3.3 Coalescing Optimization
The golden rule: Consecutive threads should access consecutive memory locations.
// Excellent: Vectorized coalesced loads
__global__ void gemm_vectorized(const float* A) {
    // Each thread loads 4 consecutive floats with one 128-bit instruction
    // (requires A to be 16-byte aligned and the row length a multiple of 4)
    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    float4 A_vec = reinterpret_cast<const float4*>(A)[global_idx];
    // Unpack for computation
    float A_vals[4] = {A_vec.x, A_vec.y, A_vec.z, A_vec.w};
}
// Memory transaction analysis:
// 32 threads × 4 floats = 128 floats = 512 bytes = 4 cache lines
// Optimal: 4 transactions for 32 threads
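For contrast, here are the two access patterns side by side when reading a row-major M×K matrix A; these are illustrative fragments (row, col, tx are assumed to be the usual thread indices), not a complete kernel.
// Coalesced: consecutive threads read consecutive elements of one row
float good = A[row * K + tx];   // one or two 128-byte transactions per warp
// Strided: consecutive threads read down a column, K floats apart
float bad  = A[tx * K + col];   // up to 32 separate transactions per warp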
4. Tensor Cores & Mixed Precision
4.1 Tensor Core Programming
Modern GPUs have specialized matrix units for AI workloads:
#include <mma.h>
using namespace nvcuda;
// One warp computes one 16×16 tile of C; this sketch launches one warp per block.
// A and B are row-major half-precision matrices (leading dimensions K and N).
__global__ void gemm_tensor_core(const half* A, const half* B, float* C, int M, int N, int K) {
    // Tile coordinates for this warp
    int warp_m = blockIdx.y;
    int warp_n = blockIdx.x;
    // Tensor Core fragments
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
    // Initialize accumulator
    wmma::fill_fragment(c_frag, 0.0f);
    for (int k = 0; k < K; k += 16) {
        // Load 16×16 tiles of A and B for this K step
        wmma::load_matrix_sync(a_frag, A + warp_m * 16 * K + k, K);
        wmma::load_matrix_sync(b_frag, B + k * N + warp_n * 16, N);
        // Matrix multiply-accumulate on the Tensor Cores
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }
    // Store the accumulated 16×16 result tile
    wmma::store_matrix_sync(C + warp_m * 16 * N + warp_n * 16, c_frag, N, wmma::mem_row_major);
}
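Launching it could look like the following sketch (one 32-thread warp per block, M and N assumed to be multiples of 16; d_A, d_B, d_C are assumed device buffers):
dim3 grid(N / 16, M / 16);   // one block (one warp) per 16×16 tile of C
gemm_tensor_core<<<grid, 32>>>(d_A, d_B, d_C, M, N, K);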
4.2 Mixed Precision Strategy
The standard mixed-precision pattern, used by FP16/BF16 library GEMMs and Transformer Engine:
// Input: FP16/BF16 for memory efficiency
// Compute: FP32 accumulation for numerical stability
// Output: FP16/BF16 for downstream efficiency
#include <cuda_fp16.h>
// Illustrative (unoptimized) kernel: one thread computes one element of C
__global__ void gemm_mixed_precision(const half* A, const half* B, half* C,
                                     int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;
    // Accumulate in FP32 for numerical stability
    float sum = 0.0f;
    for (int k = 0; k < K; k++) {
        // Load in reduced precision, promote to FP32 for the multiply-add
        float a = __half2float(A[row * K + k]);
        float b = __half2float(B[k * N + col]);
        sum = fmaf(a, b, sum);
    }
    // Store in reduced precision
    C[row * N + col] = __float2half(sum);
}
5. Advanced Optimization Techniques
5.1 Double Buffering & Async Copies
Overlap memory and compute for peak utilization:
// Sketch: load_tile_async and cp_async_* stand in for cp.async / cuda::pipeline
// primitives; TILE_M/N/K, A, B, K and the register accumulator C_accum are assumed
// to be defined as in the earlier tiled kernels.
__global__ void gemm_double_buffered() {
    // Double-buffered shared memory
    __shared__ float As[2][TILE_M][TILE_K];
    __shared__ float Bs[2][TILE_K][TILE_N];
    int read_stage = 0, write_stage = 1;
    // Prefetch the first tile
    load_tile_async(As[write_stage], A, 0);
    load_tile_async(Bs[write_stage], B, 0);
    cp_async_commit_group();
    for (int k = TILE_K; k < K; k += TILE_K) {
        // Swap buffers: the tile just prefetched becomes the read stage
        read_stage ^= 1;
        write_stage ^= 1;
        // Start loading the next tile (asynchronous)
        load_tile_async(As[write_stage], A, k);
        load_tile_async(Bs[write_stage], B, k);
        cp_async_commit_group();
        // Wait for the current tile; the newest group may still be in flight
        cp_async_wait_group(1);
        __syncthreads();
        // Compute on the current tile while the next one loads
        compute_tile(As[read_stage], Bs[read_stage], C_accum);
    }
    // Epilogue: drain the pipeline and compute the final prefetched tile
    cp_async_wait_group(0);
    __syncthreads();
    compute_tile(As[write_stage], Bs[write_stage], C_accum);
}
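On Ampere and later, the asynchronous copies above map to the cp.async instruction. One way the load_tile_async placeholder could be implemented is with the CUDA pipeline primitives, roughly as sketched below; the tile sizes are assumed values and the signature is generalized with explicit leading-dimension and tile-origin arguments.
#include <cuda_pipeline.h>

// Assumed tile sizes for this sketch
constexpr int TILE_M = 64, TILE_K = 16;

// Threads cooperatively copy 4-byte elements of the tile from global to shared
// memory without staging them through registers; the copies are guaranteed to
// have landed once the corresponding wait returns.
__device__ void load_tile_async(float (&dst)[TILE_M][TILE_K], const float* src,
                                int ld, int tile_row, int tile_col) {
    for (int i = threadIdx.y; i < TILE_M; i += blockDim.y) {
        for (int j = threadIdx.x; j < TILE_K; j += blockDim.x) {
            __pipeline_memcpy_async(&dst[i][j],
                                    &src[(tile_row + i) * ld + (tile_col + j)],
                                    sizeof(float));
        }
    }
}
// __pipeline_commit() and __pipeline_wait_prior(N) play the roles of
// cp_async_commit_group() and cp_async_wait_group(N) used above.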
5.2 Stream-K Decomposition
Load balancing for irregular problem sizes. The sketch below shows the first ingredient: a persistent 1D work distribution over output tiles instead of a fixed 2D grid. Full Stream-K additionally splits the K dimension of individual tiles across thread blocks and reconciles the partial sums in a fix-up step.
// Distribute work across thread blocks more evenly:
// launch a fixed number of blocks and hand out output tiles with a grid-stride loop
__global__ void gemm_stream_k(int total_tiles, int tiles_n) {
    int tile_id = blockIdx.x;
    while (tile_id < total_tiles) {
        // Convert the 1D tile_id to 2D tile coordinates
        int tile_m = tile_id / tiles_n;
        int tile_n = tile_id % tiles_n;
        // Process this output tile (placeholder for the tiled inner kernel)
        compute_tile(tile_m, tile_n);
        // Grab the next tile
        tile_id += gridDim.x;
    }
}
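To illustrate the K-splitting half of the idea, here is a deliberately simplified split-K sketch rather than the full Stream-K scheme: each blockIdx.z handles one slice of the K dimension and contributes its partial sum with an atomic add. Real Stream-K implementations write partials to a workspace and run a deterministic reduction instead. The kernel name and all parameters here are illustrative.
// Simplified split-K: C must be zeroed before launch; K assumed divisible by split_k
__global__ void gemm_split_k(const float* A, const float* B, float* C,
                             int M, int N, int K, int split_k) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;
    // This block's slice of the K dimension
    int k_begin = blockIdx.z * (K / split_k);
    int k_end   = k_begin + (K / split_k);
    float partial = 0.0f;
    for (int k = k_begin; k < k_end; k++) {
        partial = fmaf(A[row * K + k], B[k * N + col], partial);
    }
    // Merge partial sums from the split_k slices
    atomicAdd(&C[row * N + col], partial);
}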
6. Roofline Analysis for GEMM
6.1 Identifying Bottlenecks
// Theoretical analysis (NVIDIA A100, 40 GB)
float peak_flops = 312e12f;      // dense FP16 Tensor Core peak, FLOP/s
float peak_bandwidth = 1555e9f;  // HBM2 bandwidth, bytes/s
// Ridge point: peak_flops / peak_bandwidth ≈ 200 FLOP/byte
// Problem characteristics
float flops = 2.0f * M * N * K;
float bytes = sizeof(half) * (M*K + K*N + M*N); // assuming FP16 operands
float arithmetic_intensity = flops / bytes;
// Performance bounds
float compute_bound_perf = peak_flops;
float memory_bound_perf = arithmetic_intensity * peak_bandwidth;
float theoretical_perf = fminf(compute_bound_perf, memory_bound_perf);
6.2 Optimization Strategy Selection
if (arithmetic_intensity < peak_flops / peak_bandwidth) {
    // Below the ridge point: memory bound, focus on data movement
optimize_coalescing();
use_shared_memory_efficiently();
minimize_global_memory_traffic();
} else {
// Compute bound - focus on math throughput
use_tensor_cores();
optimize_mixed_precision();
maximize_occupancy();
}
7. Structured Sparsity in GEMM
7.1 2:4 Sparse GEMM
Hardware acceleration for structured sparse patterns:
// Sparse GEMM with the 2:4 pattern via cuSPARSELt
#include <cusparseLt.h>
// Conceptual layout of a compressed 2:4 sparse operand
struct SparseMatrix {
    void* compressed_values; // 50% of the original values
    void* metadata;          // 2 bits per group of 4 elements (which 2 survive)
    int compressed_size;
};
// Sparse GEMM call (simplified; the compressed buffer produced by
// cusparseLtSpMMACompress already carries the metadata, so only the
// compressed A pointer is passed here)
cusparseLtMatmul(&handle, &plan, &alpha,
                 compressed_A, dense_B,
                 &beta, dense_C, dense_D,
                 workspace, streams, num_streams);
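The call above is only the last step of a longer setup sequence; the typical cuSPARSELt workflow, simplified (check the library documentation for exact signatures), is:
- cusparseLtInit: create the library handle
- cusparseLtStructuredDescriptorInit / cusparseLtDenseDescriptorInit: describe the sparse A and the dense B, C operands
- cusparseLtMatmulDescriptorInit, cusparseLtMatmulAlgSelectionInit, cusparseLtMatmulPlanInit: configure and plan the matmul
- cusparseLtSpMMAPrune: enforce (or verify) the 2:4 pattern on A
- cusparseLtSpMMACompress: produce the compressed values plus metadata buffer
- cusparseLtMatmul: execute, as shown above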
Performance characteristics:
- Theoretical: 2× speedup (half the math)
- Practical: 1.3-1.6× (overhead from metadata and scheduling)
- Memory: ~50% reduction in weight storage
8. Verification & Benchmarking
8.1 Correctness Validation
bool validate_gemm(const Matrix& C_gpu, const Matrix& C_ref, int M, int N) {
    float max_error = 0.0f;
    float max_relative_error = 0.0f;
    for (int i = 0; i < M * N; i++) {
        float error = fabsf(C_gpu.data[i] - C_ref.data[i]);
        float rel_err = error / fmaxf(fabsf(C_ref.data[i]), 1e-7f);
        max_error = fmaxf(max_error, error);
        max_relative_error = fmaxf(max_relative_error, rel_err);
    }
    // Tolerances suitable for FP32; loosen them for FP16/TF32 accumulation paths
    return max_error < 1e-3f && max_relative_error < 1e-2f;
}
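C_ref here is a trusted reference result; a straightforward (slow) CPU implementation is usually enough to produce it. A minimal sketch, assuming row-major float matrices:
// Naive CPU reference: C_ref = A * B, row-major, O(M*N*K)
void gemm_reference(const float* A, const float* B, float* C_ref,
                    int M, int N, int K) {
    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            double acc = 0.0;   // double accumulation for a tighter reference
            for (int k = 0; k < K; k++) {
                acc += (double)A[m * K + k] * (double)B[k * N + n];
            }
            C_ref[m * N + n] = (float)acc;
        }
    }
}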
8.2 Performance Metrics
struct GEMMMetrics {
    float achieved_tflops;
    float memory_bandwidth_gbps;
    float efficiency_vs_peak;
    float speedup_vs_cublas;
    void compute(float runtime_ms, int M, int N, int K,
                 float peak_tflops, float cublas_runtime_ms) {
        float flops = 2.0f * M * N * K;
        achieved_tflops = flops / (runtime_ms * 1e-3f) / 1e12f;
        // Minimum global-memory traffic, assuming FP32 operands
        float bytes = sizeof(float) * (M*K + K*N + M*N);
        memory_bandwidth_gbps = bytes / (runtime_ms * 1e-3f) / 1e9f;
        efficiency_vs_peak = achieved_tflops / peak_tflops;
        speedup_vs_cublas = cublas_runtime_ms / runtime_ms;
    }
};
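Runtime is best measured with CUDA events around the kernel launch. A minimal timing sketch follows; the 19.5 TFLOPS figure is the A100 FP32 (CUDA core) peak to match the FP32 byte count above, and cublas_ms is assumed to have been measured the same way for the cuBLAS baseline.
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaEventRecord(start);
gemm_tiled<<<grid, block>>>(d_A, d_B, d_C, M, N, K);   // kernel under test
cudaEventRecord(stop);
cudaEventSynchronize(stop);
float runtime_ms = 0.0f;
cudaEventElapsedTime(&runtime_ms, start, stop);        // in practice: warm up and average several runs
GEMMMetrics metrics;
metrics.compute(runtime_ms, M, N, K, /*peak_tflops=*/19.5f, /*cublas_runtime_ms=*/cublas_ms);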
9. Interview Practice Questions
10. Hands-On Exercises
10.1 Exercise 1: Roofline Analysis
Given M=N=K=4096, FP16 precision on A100 GPU:
- Calculate arithmetic intensity
- Determine if compute or memory bound
- Estimate theoretical peak performance
- Compare with cuBLAS baseline
10.2 Exercise 2: Tile Size Optimization
Design shared memory tile sizes that:
- Keep register usage under 128 per thread
- Fit in 48KB shared memory per SM
- Avoid bank conflicts for both A and B tiles
- Maximize data reuse
10.3 Exercise 3: Tensor Core Integration
Convert a basic GEMM kernel to use tensor cores:
- Show fragment declarations and loading
- Handle mixed precision correctly
- Optimize for warp-level cooperation
- Measure speedup vs CUDA core version
11. Key Takeaways
- Tiling is everything: Transform memory-bound to compute-bound through data reuse
- Memory hierarchy mastery: Shared memory staging with coalescing and bank conflict avoidance
- Tensor cores are the future: Learn WMMA API and mixed precision techniques
- Overlap is critical: Use async copies and double buffering for peak utilization
- Roofline guides optimization: Identify bottlenecks before optimizing
GEMM optimization is both an art and a science. Master these techniques, and you'll have the foundation for optimizing any compute-intensive GPU kernel! 🚀