HipKittens: Fast and Furious AMD Kernels
HipKittens is a C++ embedded domain-specific language that provides tile-based programming primitives for high-performance AI kernel development on AMD GPUs. The framework introduces novel scheduling patterns (8-wave ping-pong and 4-wave interleave), explicit register management, and chiplet-aware cache optimization to achieve performance competitive with or exceeding hand-optimized assembly kernels across diverse AI workloads.
1. Introduction and Problem Statement
The AI hardware landscape faces a critical challenge: while AMD GPUs offer state-of-the-art compute and memory bandwidth comparable to NVIDIA's latest offerings, they suffer from a "CUDA moat" problem. Peak-performance AMD kernels are written in raw assembly by a handful of experts, making it difficult to scale across the breadth of AI workloads.
The Hardware Lottery Problem
| Specification | NVIDIA B200 SXM5 | AMD MI355X OAM |
|---|---|---|
| BF16 compute | 2.2 PFLOPs | 2.5 PFLOPs |
| MXFP4 compute | 9.0 PFLOPs | 10.1 PFLOPs |
| Memory capacity | 180 GB | 288 GB |
| Memory bandwidth | 8.0 TB/s | 8.0 TB/s |
Despite AMD's competitive hardware specs, software support lags significantly:
- AMD's AITER library achieves only 30% of peak performance on some workloads
- PyTorch Llama GQA backwards reaches just 24% of peak performance
- Compilers like Triton sacrifice performance for simplicity
Key question: Are entirely new programming primitives needed for AMD, or can existing tile-based abstractions be adapted?
2. Technical Approach: HipKittens Framework
HipKittens (HK) provides a minimal collection of C++ embedded programming primitives for AMD GPUs, building on the tile-based philosophy of ThunderKittens but reimagining the implementation for AMD's architecture.
2.1 Optimized Programmable Memory Access
2.1.1 Developer-Controlled Register Scheduling
Challenge: AMD hardware splits 512 registers per SIMD into:
- 256 VGPRs (Vector General-Purpose Registers)
- 256 AGPRs (Accumulator Registers)
However, the HIPCC compiler prevents using AGPRs as inputs to matrix instructions, forcing redundant data movement.
HK Solution: Bypass the compiler entirely by letting developers pin registers explicitly.
```cpp
// Define explicit register ranges
using Q_ranges = split_many_t<type_list<range<24, 39>>, 4>;

// Create a register tile with pinned registers
rt<bf16, 16, 128, row_l, rt_16x32_s, Q_ranges> Q_i;
```
Impact: This enables state-of-the-art backwards attention kernels (Table 1):
| Method | Seq. Length | TFLOPS |
|---|---|---|
| HK (standard) | 4096 | 855 |
| HK (pinned registers) | 4096 | 1024 |
| AMD Assembly (AITER) | 4096 | 1018 |
2.1.2 Heterogeneous Matrix Core Layouts
Challenge: Unlike NVIDIA's compositional core matrix structure, AMD matrix instructions use entirely different layouts for each shape, creating an explosion of tile layouts.
HK Solution: Implement optimized swizzle patterns for commonly co-occurring layouts:
- Register tiles: Default to smallest MFMA instruction (16×16×32) for maximal scheduling control
- Shared memory tiles: Bank-conflict-free swizzles for common access patterns
- Global memory: Swizzle HBM addresses (not shared memory) for async loads
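To make the swizzling idea concrete, the sketch below shows a generic XOR-based, bank-conflict-avoiding mapping in Python. It illustrates the technique in principle only; the function name, vector granularity, and slot count are assumptions for the example, not HipKittens' actual layouts.

```python
# Generic illustration of address swizzling (not HK's actual pattern): XOR the
# vector slot within a shared-memory row with bits of the row index, so the
# same column of consecutive rows lands in different bank groups.
def xor_swizzle(row, vec, vecs_per_row=8):
    """Return the swizzled vector slot for (row, vec) within a tile row."""
    return vec ^ (row % vecs_per_row)

# Column 0 of eight consecutive rows maps to eight distinct slots, so a wave
# reading one column across rows does not serialize on a single bank group.
print([xor_swizzle(r, 0) for r in range(8)])  # -> [0, 1, 2, 3, 4, 5, 6, 7]
```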
2.2 Overlapping Compute and Memory
2.2.1 Why Wave Specialization Fails on AMD
The dominant NVIDIA pattern—wave specialization (producer-consumer)—underperforms on AMD due to:
- Static register allocation: Producers consume registers without contributing to computation
- Limited output tile size: Reduces arithmetic intensity
- Smaller shared memory: 40% less SRAM per processor than NVIDIA B200
Experimental Evidence:
| Configuration | MFMA Shape | Output Tile | TFLOPS |
|---|---|---|---|
| 4 producers / 8 consumers | 16×16×32 | 128×256 | 893 |
| 4 producers / 12 consumers | 16×16×32 | 192×256 | 1278 |
| 0 producers / 8 consumers | 16×16×32 | 256×256 | 1610 |
| NVIDIA TK (B200) | 256×256×16 | 256×256 | 1538 |
2.2.2 HK's Scheduling Patterns
HK identifies two high-performance patterns that generalize across AI workloads:
1. 8-Wave Ping-Pong (Balanced Workloads)
- 8 waves per thread block (2 per SIMD)
- Waves alternate between compute and memory roles
- Uses large tile primitives (similar to wave specialization)
- Best for: Balanced compute/memory workloads
2. 4-Wave Interleave (Imbalanced Workloads)
- 1 wave per SIMD (4 total)
- Fine-grained instruction interleaving
- Uses small base tile primitives
- Best for: Compute-heavy or memory-heavy workloads
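Both patterns exist to overlap memory movement with matrix math. The toy model below, with purely illustrative timings and a hypothetical total_time helper, shows why: an overlapped schedule approaches max(compute, load) per tile instead of their sum.

```python
# Toy latency model (illustrative numbers, not hardware measurements). When the
# load for tile k+1 overlaps the MFMA work on tile k, the steady-state cost per
# tile is max(compute, load) rather than compute + load.
def total_time(num_tiles, t_mfma, t_load, overlapped):
    if overlapped:
        return t_load + num_tiles * max(t_mfma, t_load)  # only the first load is exposed
    return num_tiles * (t_mfma + t_load)                 # serialize load, then compute

print(round(total_time(64, t_mfma=1.0, t_load=0.8, overlapped=False), 1))  # 115.2
print(round(total_time(64, t_mfma=1.0, t_load=0.8, overlapped=True), 1))   # 64.8
```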
Performance vs. Programmability Tradeoff:
| Kernel Pattern | LoC | TFLOPS | Use Case |
|---|---|---|---|
| FP8 GEMM (8-wave) | 48 | 3222 | Simpler code |
| FP8 GEMM (4-wave) | 183 | 3327 | Peak performance |
| MHA backwards (8-wave) | 331 | 894 | Good balance |
| MHA backwards (4-wave) | 989 | 1091 | Maximum speed |
Key Finding: The simple 8-wave pattern is sufficient to match AMD's hand-optimized assembly kernels across BF16 GEMM, FP8 GEMM, and attention forward.
2.3 Chiplet-Aware Cache Optimization
Modern AMD GPUs use chiplet architectures (MI355X has 8 chiplets), creating a hierarchical cache structure:
- Each chiplet: 32 CUs with private L2 cache (4MB)
- All chiplets: Shared LLC (last-level cache)
Cache Reuse Algorithm
Cost Model:
Bandwidth = LLC_Bandwidth × LLC_Hit% + L2_Bandwidth × L2_Hit%

Two Optimization Principles:
- L2 Reuse: Thread blocks on same chiplet should cover rectangular output regions
- LLC Reuse: Coordinate across chiplets to overlap input matrix access
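To see how the two principles interact, the minimal sketch below evaluates the cost model above on the hit rates reported in the Performance Impact table below. The per-level bandwidth constants are assumptions chosen for illustration, not MI355X specifications.

```python
# Minimal sketch of the cost model above; bandwidth constants are illustrative
# assumptions, not hardware figures.
L2_BW_TBPS = 20.0   # assumed aggregate L2 bandwidth across chiplets
LLC_BW_TBPS = 12.0  # assumed aggregate LLC bandwidth

def effective_bw(l2_hit, llc_hit):
    return LLC_BW_TBPS * llc_hit + L2_BW_TBPS * l2_hit

print(effective_bw(0.55, 0.95))  # row-major (naive)
print(effective_bw(0.79, 0.24))  # L2-optimized only: LLC reuse collapses
print(effective_bw(0.75, 0.93))  # HK (joint L2 + LLC)
```

With these assumed constants the model reproduces the qualitative ordering of the measurements below: optimizing L2 hit rate alone loses more LLC bandwidth than it gains.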
Algorithm 1: XCD Swizzle for Cache Reuse
```python
def xcd_swizzle(block_x, block_y, grid_x, grid_y, num_xcds, W, C):
    """
    Remap block coordinates so that blocks scheduled on the same XCD cover a
    rectangular window of the output (L2 reuse) while groups of blocks across
    XCDs still overlap their input tiles (LLC reuse).

    W: window height in output tiles (optimizes L2)
    C: chunk size (optimizes LLC)
    """
    # Step 1: Flatten the grid and identify this block's XCD
    # (hardware assigns consecutive block IDs round-robin across XCDs)
    xy = block_x + grid_x * block_y
    xcd = xy % num_xcds
    local = xy // num_xcds
    # (xcd, local, and the chunk size C drive the LLC-level grouping,
    #  which is elided in this simplified listing)

    # Step 2: Traverse the grid in windows of W rows instead of full rows
    tiles_per_window = W * grid_x
    window_id = xy // tiles_per_window
    first_row = window_id * W

    # Step 3: Map the position within the window back to 2D coordinates
    within = xy % tiles_per_window
    row = first_row + (within % W)
    col = within // W
    return (row, col)
```

Performance Impact:
| Block Order | L2 Hit% | LLC Hit% | Mem. BW | TFLOPS | Speedup |
|---|---|---|---|---|---|
| Row-major (naive) | 55% | 95% | 15.1 TB/s | 1113 | 1.0× |
| L2-optimized only | 79% | 24% | 14.9 TB/s | 991 | 0.89× |
| HK (L2+LLC) | 75% | 93% | 18.3 TB/s | 1145 | 1.03× |
For problematic shapes (e.g., 57 tiles across 8 XCDs):
| Shape | Strategy | L2 Hit% | LLC Hit% | TFLOPS | Speedup |
|---|---|---|---|---|---|
| 14592³ | Row-major | 36% | 76% | 900 | 1.0× |
| 14592³ | HK optimized | 78% | 55% | 1068 | 1.19× |
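For intuition about Algorithm 1's effect, the remapping can be exercised directly; the grid size, window height, and chunk size below are illustrative parameters, not values from the experiments above.

```python
# Walk the first few flattened block IDs of a 16x16 tile grid through the
# remapping: consecutive IDs now cover a compact 4-row window instead of a
# single long output row.
for xy in range(8):
    bx, by = xy % 16, xy // 16
    print(xy, "->", xcd_swizzle(bx, by, grid_x=16, grid_y=16,
                                num_xcds=8, W=4, C=2))
```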
3. Key Results
3.1 GEMM Performance
BF16 GEMM (MI355X):
- Matches AMD assembly (AITER) and HipBLASLT
- 1.3-3.0× faster than the Triton compiler
- Single 8-wave kernel generalizes across problem sizes
FP8 GEMM (MI355X):
- 8-wave: 3222 TFLOPS (48 lines of code)
- 4-wave: 3327 TFLOPS (183 lines of code)
- Competitive with hand-optimized assembly
3.2 Attention Kernels
Attention Forward:
- 1.0-2.1× faster than AITER (assembly)
- 1.3-4.5× faster than PyTorch SDPA
- 1.2-4.5× faster than Triton
Attention Backward:
- GQA non-causal: 1.8-2.5× faster than all baselines
- Uses multiple MFMA shapes (16×16×32, 32×32×16)
- Leverages explicit register pinning
3.3 Memory-Bound Kernels
Fused Dropout-Residual-LayerNorm:
- 1.1-2.2× faster than AITER and PyTorch compiled
Rotary Positional Encoding:
- 1.2-1.8× faster than baselines
3.4 Cross-Platform Validation
| Workload | MI325X (CDNA3) | MI355X (CDNA4) | Status |
|---|---|---|---|
| BF16 GEMM | ✓ Competitive | ✓ Competitive | Matches assembly |
| FP8 GEMM | ✓ Competitive | ✓ Competitive | Matches assembly |
| MHA/GQA Forward | ✓ Best | ✓ Best | Beats all baselines |
| MHA/GQA Backward | ✓ Best | ✓ Best | 1.8-2.5× speedup |
| Memory-bound | ✓ Best | ✓ Best | 1.1-2.2× speedup |
4. Practical Implications
4.1 Democratizing AMD Kernel Development
Before HipKittens:
- Peak performance required raw assembly expertise
- Limited to handful of AMD engineers
- Slow to scale across AI workloads
- Example: GQA backwards unsupported in assembly libraries
With HipKittens:
- Tile-based abstractions familiar to AI researchers
- PyTorch-inspired API (mma, exp, add operations)
- Reusable scheduling patterns (8-wave, 4-wave)
- Comprehensive kernel suite released
4.2 Real-World Validation
Model Training:
- Pretrained Llama 1B (10B tokens on Slim Pajama)
- Pretrained BERT 110M
- Result: Matched perplexity of PyTorch/AITER baselines
4.3 Code Simplicity
Example: 8-wave ping-pong attention forward kernel structure:
```cpp
// HipKittens tile-based code (simplified; accumulator tile types are illustrative)
rt<bf16, 16, 128, row_l> Q_tile, K_tile, V_tile;
rt<float, 16, 16, row_l> QK_tile;   // attention-score accumulator
rt<float, 16, 128, row_l> O_tile;   // output accumulator

// Load tiles from global memory
load(Q_tile, Q_global);
load(K_tile, K_global);

// Compute attention with bulk tile operations
mma(QK_tile, Q_tile, K_tile);  // attention scores
exp(QK_tile, QK_tile);         // softmax exponentiation (normalization omitted here)
mma(O_tile, QK_tile, V_tile);  // weighted sum over values

// Store the result
store(O_global, O_tile);
```
For comparison: the full MHA backward kernel is 331 lines in the 8-wave pattern versus 989 lines in the hand-optimized 4-wave variant (Section 2.2.2), in place of the raw assembly previously required for peak performance.
5. Related Work and Context
5.1 Comparison with Existing Approaches
| Framework | Target | Abstraction Level | AMD Support | Peak Performance |
|---|---|---|---|---|
| Raw Assembly | AMD/NVIDIA | Lowest | ✓ Full | ✓ Best (but not scalable) |
| CUTLASS | NVIDIA | C++ Templates | ✗ None | ✓ Excellent |
| ThunderKittens | NVIDIA | Tile-based C++ | ✗ None | ✓ Excellent |
| Triton | Both | Python DSL | ✓ Partial | ✗ 1.3-3× slower |
| Mojo | Both | Python-like | ✓ Partial | ✗ 2× slower |
| TileLang | Both | Python DSL | ✓ Limited | ? Incomplete eval |
| HipKittens | AMD | Tile-based C++ | ✓ Full | ✓ Matches/beats assembly |
5.2 Key Architectural Differences
NVIDIA Advantages:
- Larger shared memory (40% more per processor)
- Dedicated memory hardware (TMA)
- Asynchronous matrix multiply (wgmma)
- Register reallocation
- Hardware synchronization (mbarriers)
- Compositional core matrix structure
AMD Advantages:
- 2× larger register file
- Competitive peak compute/bandwidth
- Lower hardware cost
- More memory capacity (288GB vs 180GB)
HipKittens' Contribution: Shows that AMD's architectural differences don't require entirely new abstractions—tile-based programming works, but the implementation must be rethought.
6. Technical Innovations Summary
6.1 Three-Level Optimization Strategy
HK's optimizations target three levels of the hierarchy:
- Registers: developer-controlled register pinning and heterogeneous MFMA tile layouts (Section 2.1)
- Waves: 8-wave ping-pong and 4-wave interleave schedules to overlap compute and memory (Section 2.2)
- Chiplets: XCD-aware block swizzling for joint L2 and LLC reuse (Section 2.3)
6.2 Unified Programming Model Vision
Long-term Goal: A single, tile-based software layer for high-performance AI kernels that translates across GPU vendors.
Evidence of Generality:
- Tile abstractions work on both NVIDIA and AMD
- PyTorch-inspired API familiar to researchers
- Scheduling patterns differ but remain composable
- Same front-end interface, different back-end instantiation
7. Conclusion
HipKittens demonstrates that:
- Tile-based abstractions generalize across GPU vendors (NVIDIA → AMD)
- Compiler bypass is sometimes necessary for peak performance (register pinning)
- Simple scheduling patterns suffice (8-wave ping-pong matches assembly)
- Chiplet architectures require new optimizations (L2/LLC joint optimization)
- A unified programming model is achievable (same interface, different implementation)
Impact: Opens the hardware landscape for AI by providing the first systematic, high-performance, and accessible AMD kernel framework—moving beyond the "CUDA moat" toward true multi-vendor portability.