ML performance is a systems problem
Training a Transformer is not just about the math - it is about whether your memory access pattern fits the GPU's fast on-chip memory, whether your operators can be fused to avoid redundant writes, and whether your communication strategy matches your hardware topology. This project tackles both sides: optimizing GPU operators with Triton, and scaling training across nodes with MPI.
Fused Triton Kernel: D = ReLU(A × B + C)
Tile Assignment + Parallelism
Divides the output matrix into tiles; each GPU thread block computes one tile in parallel. BLOCK_M, BLOCK_N, and BLOCK_K are tuned via a configuration grid search for best performance.
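The tile-to-program mapping described above can be sketched in plain Python (the matrix and tile sizes here are illustrative, not the project's actual values):

```python
# Sketch of the tile-to-program mapping (sizes here are illustrative).
# Each Triton program instance (pid_m, pid_n) owns one BLOCK_M x BLOCK_N
# tile of the output D and runs independently of all others.
M, N = 512, 1024             # output matrix shape
BLOCK_M, BLOCK_N = 64, 64    # tile sizes, tuned by grid search in practice

# ceil-division, equivalent to triton.cdiv, so edge tiles are covered
grid = (-(-M // BLOCK_M), -(-N // BLOCK_N))
print(grid)  # (8, 16) -> 128 programs, one per output tile
```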
Shared Memory Tiling
Loads sub-blocks of A and B into fast on-chip SRAM before computation. Dramatically reduces expensive global DRAM accesses - the primary GPU performance bottleneck. Register accumulation for partial results.
Operator Fusion ⭐
Fuses MatMul + Add + ReLU into a single kernel pass. Eliminates intermediate result writes to global memory - the key insight for achieving ≥1.25× PyTorch speedup. Performance depends on memory access patterns, not raw compute.
```python
# Triton kernel - fused matmul + add + ReLU with shared-memory tiling
import triton
import triton.language as tl

@triton.jit
def fused_matmul_add_relu(A, B, C, D, M, N, K,
                          stride_am, stride_ak, stride_bk, stride_bn,
                          stride_cm, stride_cn, stride_dm, stride_dn,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                          BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)   # rows of this tile
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)   # columns of this tile
    rk = tl.arange(0, BLOCK_K)
    a_ptrs = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
    b_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # register accumulator
    for k in range(0, K, BLOCK_K):
        # load tiles of A and B into SRAM, zero-padding out-of-bounds lanes
        a = tl.load(a_ptrs, mask=(rm[:, None] < M) & (rk[None, :] + k < K), other=0.0)
        b = tl.load(b_ptrs, mask=(rk[:, None] + k < K) & (rn[None, :] < N), other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    mask = (rm[:, None] < M) & (rn[None, :] < N)
    c_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    d_ptrs = D + rm[:, None] * stride_dm + rn[None, :] * stride_dn
    acc = acc + tl.load(c_ptrs, mask=mask, other=0.0)  # fuse Add
    acc = tl.maximum(acc, 0.0)                         # fuse ReLU
    tl.store(d_ptrs, acc, mask=mask)                   # single write to global memory
```
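The block sizes and launch parameters above are chosen by benchmarking every combination from a small search space. A pure-Python sketch of building such a space (the candidate values are illustrative, not the project's actual grid):

```python
from itertools import product

# Candidate values are illustrative, not the project's actual search grid.
BLOCK_MS, BLOCK_NS, BLOCK_KS = [64, 128], [64, 128], [32, 64]
NUM_WARPS, NUM_STAGES = [4, 8], [2, 3]

configs = [
    dict(BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, num_warps=w, num_stages=s)
    for m, n, k, w, s in product(BLOCK_MS, BLOCK_NS, BLOCK_KS,
                                 NUM_WARPS, NUM_STAGES)
]
print(len(configs))  # 32 candidates; each is compiled, timed, and the fastest kept
```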
2D Data × Tensor Parallelism with MPI
Custom AllReduce + AllToAll
Implemented both collective communication primitives from scratch with mpi4py and benchmarked them against the native MPI versions - the custom ones are slower because native MPI uses topology-aware ring and tree algorithms that exploit interconnects like NVLink.
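To see why the ring algorithm is hard to beat, here is a single-process simulation of ring AllReduce (this is a sketch of the textbook algorithm, not the project's mpi4py code): each "rank" r holds a buffer split into P chunks, and values circulate around a ring in 2(P-1) steps, so per-rank traffic stays roughly constant as P grows.

```python
# Single-process sketch of the classic ring AllReduce (sum); "ranks" are
# simulated as lists, so no MPI is needed to run it.
def ring_allreduce_sim(data):
    P = len(data)
    data = [list(chunks) for chunks in data]
    # Phase 1: reduce-scatter. After P-1 steps, rank r holds the
    # fully reduced chunk (r + 1) % P.
    for step in range(P - 1):
        for r in range(P):
            i = (r - step) % P                  # chunk rank r forwards
            data[(r + 1) % P][i] += data[r][i]  # neighbor accumulates it
    # Phase 2: all-gather. The reduced chunks circulate until every
    # rank has the complete summed buffer.
    for step in range(P - 1):
        for r in range(P):
            i = (r + 1 - step) % P              # reduced chunk to forward
            data[(r + 1) % P][i] = data[r][i]
    return data

print(ring_allreduce_sim([[1, 1, 1], [2, 2, 2], [3, 3, 3]]))
# [[6, 6, 6], [6, 6, 6], [6, 6, 6]]
```

A naive implementation instead sends every rank's full buffer to a root and broadcasts the result back, which serializes on one link - the ring keeps all links busy at once.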
Data Parallelism
Splits dataset uniformly across nodes via split_data(). Same model weights, different data shards. Requires gradient synchronization via AllReduce after each backward pass. Scales compute linearly with nodes.
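A minimal sketch of the sharding step, assuming split_data shards along the batch dimension (the project's actual signature may differ):

```python
import numpy as np

# Hypothetical sketch of batch-dimension sharding; the project's split_data
# signature may differ. Every rank keeps identical weights but sees only
# its own shard, so gradients must be AllReduce-averaged after backward.
def split_data(x, num_ranks, rank):
    return np.array_split(x, num_ranks, axis=0)[rank]

x = np.arange(10).reshape(10, 1)  # toy dataset of 10 samples
print([split_data(x, 4, r).shape[0] for r in range(4)])  # [3, 3, 2, 2]
```

np.array_split handles batch sizes that do not divide evenly, giving the first ranks one extra sample each.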
Tensor Model Parallelism
Shards Transformer weights: fc_q/k/v split on output dimension, fc_o split on input dimension. Reduces memory per node but increases communication frequency - requires careful DP/MP balance.
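A NumPy sketch of why the split directions must pair up: a column-parallel layer (fc_q/k/v-style, output dim split) feeding a row-parallel layer (fc_o-style, input dim split) produces partial outputs that sum to the unsharded result - exactly what AllReduce restores. The weights and sizes here are toy stand-ins, not the project's layers:

```python
import numpy as np

# Toy stand-ins, not the project's layers: W_v plays an fc_v weight
# (column-parallel, output dim split), W_o plays fc_o (row-parallel,
# input dim split).
rng = np.random.default_rng(0)
d, mp = 8, 2                          # model dim, tensor-parallel degree
x = rng.standard_normal((4, d))       # activations (batch, d)
W_v = rng.standard_normal((d, d))
W_o = rng.standard_normal((d, d))

ref = (x @ W_v) @ W_o                 # unsharded reference forward

partials = []
for r in range(mp):
    W_v_r = np.split(W_v, mp, axis=1)[r]  # rank r's output-dim shard
    W_o_r = np.split(W_o, mp, axis=0)[r]  # rank r's matching input-dim shard
    partials.append((x @ W_v_r) @ W_o_r)  # purely local compute

out = sum(partials)                   # what AllReduce over ranks produces
print(np.allclose(out, ref))  # True
```

Splitting either weight along the other axis would break this identity, which is why the sharding rule only becomes obvious once you trace the shapes.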
Forward + Backward Communication
naive_collect_forward_input/output gather activations in the forward pass; naive_collect_backward_output/x return gradients via Reduce-Scatter in the backward pass. Backward requires more communication than forward because gradients must be aggregated across ranks before they can be scattered back to the matching shards.
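A single-process sketch of the two patterns (the naive_collect_* names come from the project; these simulation bodies over per-rank arrays are illustrative):

```python
import numpy as np

# Illustrative single-process simulations of the two collective patterns;
# the project's naive_collect_* functions use real MPI ranks instead.
def allgather_sim(parts):
    """Forward: concatenate activation slices from all MP ranks."""
    return np.concatenate(parts, axis=-1)

def reduce_scatter_sim(full_grads, mp):
    """Backward: sum gradients across ranks, then hand each rank only
    the slice matching its weight shard - a sum plus a split, hence the
    extra coordination relative to the forward pass."""
    total = np.sum(full_grads, axis=0)
    return np.split(total, mp, axis=-1)

acts = [np.ones((2, 3)) * r for r in range(2)]
print(allgather_sim(acts).shape)                        # (2, 6)
grads = [np.ones((2, 6)) for _ in range(2)]
print([g.shape for g in reduce_scatter_sim(grads, 2)])  # [(2, 3), (2, 3)]
```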
Technologies Used
| Component | Tool & Purpose |
|---|---|
| GPU Kernel | Triton - tile-based kernel, shared memory, register accumulation, operator fusion |
| Distributed Comm | MPI (mpi4py) - AllReduce, AllToAll, Reduce-Scatter from scratch + native |
| Parallelism | 2D: Data Parallel (DP) × Tensor Model Parallel (MP) |
| Tuning | Grid search over BLOCK_M/N/K, num_warps, num_stages |
| Tensors | NumPy + PyTorch tensors - computation substrate |
Results
- Operator fusion is the highest-impact GPU optimization - avoiding intermediate memory writes beats compute savings
- Memory access patterns dominate GPU performance - SRAM tiling matters more than arithmetic throughput
- Communication is the bottleneck in distributed training - MPI hardware-aware ring algorithms are very hard to match
- Tensor parallelism reduces memory per node but requires careful sharding of fc_q/k/v vs fc_o dimensions
What I took away
- GPU performance is 80% memory, 20% compute - reducing global memory accesses via SRAM tiling and operator fusion is the correct mental model.
- Implementing AllReduce from scratch confirmed why MPI is faster: hardware-aware ring/tree algorithms exploit topology that naive send/recv ignores.
- Tensor parallelism sharding rules are subtle - fc_q/k/v split on output dim and fc_o on input dim is not obvious until you trace gradient shapes carefully.
- Backward communication needs more coordination than forward - Reduce-Scatter for gradients is more complex than AllGather for activations.