GPU Kernel Optimization + Distributed Training

ML performance is a systems problem

Training a Transformer is not just about the math - it is about whether your memory access pattern fits the GPU cache, whether your operators can be fused to avoid redundant writes, and whether your communication strategy matches your hardware topology. This project tackles both levels: optimizing on-GPU operations with Triton kernels, and scaling training across nodes with MPI.

Fused Triton Kernel: D = ReLU(A × B + C)

Tile Assignment + Parallelism

Divides the output matrix into tiles. Each GPU thread block computes one tile in parallel. BLOCK_M, BLOCK_N, BLOCK_K tuned via configuration grid search for best performance.
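
A sketch of what that configuration grid search can look like with Triton's built-in autotuner - the specific block sizes, num_warps, and num_stages values here are illustrative, not the project's tuned results:

import triton

# Illustrative tuning grid - the best config depends on the target GPU.
configs = [
    triton.Config({"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk},
                  num_warps=warps, num_stages=stages)
    for bm in (64, 128)
    for bn in (64, 128)
    for bk in (32, 64)
    for warps in (4, 8)
    for stages in (2, 4)
]

# The fused kernel defined further down can be wrapped with the autotuner,
# which benchmarks every config and caches the fastest one per (M, N, K) shape:
#   @triton.autotune(configs=configs, key=["M", "N", "K"])
#   @triton.jit
#   def fused_matmul_add_relu(...): ...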

Shared Memory Tiling

Loads sub-blocks of A and B into fast on-chip SRAM before computation. Dramatically reduces expensive global DRAM accesses - the primary GPU performance bottleneck. Register accumulation for partial results.

Operator Fusion ⭐

Fuses MatMul + Add + ReLU into a single kernel pass. Eliminates intermediate result writes to global memory - the key insight behind the ≥1.25× speedup over the PyTorch baseline. Performance here depends on memory access patterns, not raw compute.
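
To make concrete what fusion removes, here is an unfused PyTorch version of the same computation; each step launches its own kernel and round-trips a full M × N intermediate through global memory (shapes are illustrative):

import torch

M, N, K = 1024, 1024, 1024
A = torch.randn(M, K, device="cuda")
B = torch.randn(K, N, device="cuda")
C = torch.randn(M, N, device="cuda")

tmp = A @ B               # matmul result written out to DRAM
tmp = tmp + C             # read back, add, write again
out = torch.relu(tmp)     # read back, apply ReLU, write a third time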

import triton
import triton.language as tl

# Triton kernel - fused matmul + add + ReLU with shared-memory tiling
# (assumes row-major, contiguous A [M, K], B [K, N], C and D [M, N])
@triton.jit
def fused_matmul_add_relu(A, B, C, D, M, N, K,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)                              # tile row index
    pid_n = tl.program_id(1)                              # tile column index
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)          # output rows of this tile
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)          # output columns of this tile
    rk = tl.arange(0, BLOCK_K)
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # per-tile accumulator in registers
    for k in range(0, K, BLOCK_K):
        a = tl.load(A + rm[:, None] * K + (k + rk)[None, :],          # tile of A into SRAM
                    mask=(rm[:, None] < M) & ((k + rk)[None, :] < K), other=0.0)
        b = tl.load(B + (k + rk)[:, None] * N + rn[None, :],          # tile of B into SRAM
                    mask=((k + rk)[:, None] < K) & (rn[None, :] < N), other=0.0)
        acc += tl.dot(a, b)
    mask = (rm[:, None] < M) & (rn[None, :] < N)
    acc = acc + tl.load(C + rm[:, None] * N + rn[None, :], mask=mask, other=0.0)  # fuse Add
    acc = tl.maximum(acc, 0.0)                                                    # fuse ReLU
    tl.store(D + rm[:, None] * N + rn[None, :], acc, mask=mask)  # single write to global memory
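
A hedged launch sketch for the kernel above - one program instance per output tile - followed by a correctness check against the unfused PyTorch reference (sizes and block shapes are illustrative, not the tuned configuration):

import torch
import triton

M, N, K = 1024, 1024, 1024
A = torch.randn(M, K, device="cuda")
B = torch.randn(K, N, device="cuda")
C = torch.randn(M, N, device="cuda")
D = torch.empty(M, N, device="cuda")

BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))   # one program per output tile
fused_matmul_add_relu[grid](A, B, C, D, M, N, K,
                            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K)

# Loose tolerance because tl.dot may use TF32 on Ampere-class GPUs.
torch.testing.assert_close(D, torch.relu(A @ B + C), rtol=1e-2, atol=1e-2)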

2D Data × Tensor Parallelism with MPI

Custom AllReduce + AllToAll

Implemented both collective communication primitives from scratch with mpi4py and benchmarked them against native MPI. The custom versions are slower: native MPI uses topology-aware ring and tree algorithms that exploit interconnects such as NVLink.
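
A minimal sketch of a from-scratch AllReduce in that spirit - gather to a root rank, reduce, then broadcast back. The function name and the all-to-root strategy are illustrative; the project's actual primitives may use a different scheme:

import numpy as np
from mpi4py import MPI

def naive_allreduce(local: np.ndarray, comm: MPI.Comm) -> np.ndarray:
    """Gather-to-root then broadcast: no ring, no tree, no topology awareness."""
    rank, size = comm.Get_rank(), comm.Get_size()
    result = local.copy()
    if rank == 0:
        buf = np.empty_like(local)
        for src in range(1, size):
            comm.Recv(buf, source=src)   # point-to-point receive from every rank
            result += buf                # reduce on the root
    else:
        comm.Send(local, dest=0)
    comm.Bcast(result, root=0)           # ship the reduced tensor back out
    return result

# grads = naive_allreduce(grads, MPI.COMM_WORLD)
# vs. native: MPI.COMM_WORLD.Allreduce(MPI.IN_PLACE, grads, op=MPI.SUM)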

Data Parallelism

Splits dataset uniformly across nodes via split_data(). Same model weights, different data shards. Requires gradient synchronization via AllReduce after each backward pass. Scales compute linearly with nodes.
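
A rough sketch of the data-parallel step, assuming split_data shards along the batch dimension; the helper's exact signature and the stand-in gradient array are illustrative:

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, dp_size = comm.Get_rank(), comm.Get_size()

def split_data(x: np.ndarray, rank: int, world_size: int) -> np.ndarray:
    """Illustrative batch-dimension shard: each rank keeps one contiguous slice."""
    shard = x.shape[0] // world_size
    return x[rank * shard:(rank + 1) * shard]

full_batch = np.random.rand(128, 512).astype(np.float32)
local_batch = split_data(full_batch, rank, dp_size)          # different shard per rank

# ... forward/backward on local_batch would produce real gradients ...
local_grads = np.random.rand(512, 512).astype(np.float32)    # stand-in for those gradients

# Gradient synchronization after every backward pass, then average across replicas.
comm.Allreduce(MPI.IN_PLACE, local_grads, op=MPI.SUM)
local_grads /= dp_size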

Tensor Model Parallelism

Shards Transformer weights: fc_q/k/v split on output dimension, fc_o split on input dimension. Reduces memory per node but increases communication frequency - requires careful DP/MP balance.
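
A sketch of that sharding rule, assuming weights stored as (in_features, out_features) NumPy arrays; the layout and helper names are illustrative, not the project's code:

import numpy as np

def shard_qkv(weight: np.ndarray, mp_rank: int, mp_size: int) -> np.ndarray:
    """fc_q / fc_k / fc_v: column-parallel, split on the output dimension."""
    cols = weight.shape[1] // mp_size
    return weight[:, mp_rank * cols:(mp_rank + 1) * cols]

def shard_o(weight: np.ndarray, mp_rank: int, mp_size: int) -> np.ndarray:
    """fc_o: row-parallel, split on the input dimension so each shard consumes
    the partial attention output produced by the sharded fc_q/k/v."""
    rows = weight.shape[0] // mp_size
    return weight[mp_rank * rows:(mp_rank + 1) * rows, :]

# Hidden size 512 with mp_size 4: each rank holds a 512x128 slice of fc_q
# and a 128x512 slice of fc_o; partial fc_o outputs are summed across ranks.
W_q = np.random.rand(512, 512).astype(np.float32)
W_o = np.random.rand(512, 512).astype(np.float32)
print(shard_qkv(W_q, 0, 4).shape, shard_o(W_o, 0, 4).shape)   # (512, 128) (128, 512)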

Forward + Backward Communication

naive_collect_forward_input/output handle activations; naive_collect_backward_output/x handle gradients via Reduce-Scatter. The backward pass requires more communication than the forward pass because gradients must be aggregated across ranks.
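
For the gradient path, a minimal sketch of what a Reduce-Scatter step looks like with native mpi4py; the sizes and the feature-dimension split are illustrative:

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, mp_size = comm.Get_rank(), comm.Get_size()

hidden = 512                                    # assumes mp_size divides hidden
shard = hidden // mp_size

# Every rank holds a full-length gradient, laid out here as mp_size shards.
# Reduce-Scatter sums across ranks and leaves each rank only its own shard -
# the reverse of the AllGather used for activations in the forward pass.
full_grad = np.random.rand(mp_size, shard).astype(np.float32)
my_shard = np.empty(shard, dtype=np.float32)
comm.Reduce_scatter_block(full_grad, my_shard, op=MPI.SUM)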

Technologies Used

Component | Tool & Purpose
GPU Kernel | Triton - tile-based kernel, shared memory, register accumulation, operator fusion
Distributed Comm | MPI (mpi4py) - AllReduce, AllToAll, Reduce-Scatter from scratch + native
Parallelism | 2D: Data Parallel (DP) × Tensor Model Parallel (MP)
Tuning | Grid search over BLOCK_M/N/K, num_warps, num_stages
Tensors | NumPy + PyTorch tensors - computation substrate

Results

  • ≥1.25× speedup over the PyTorch baseline kernel
  • 3 collective primitives implemented from scratch
  • 2D data + tensor parallel training pipeline

  • Operator fusion is the highest-impact GPU optimization - avoiding intermediate memory writes beats compute savings
  • Memory access patterns dominate GPU performance - SRAM tiling matters more than arithmetic throughput
  • Communication is the bottleneck in distributed training - MPI hardware-aware ring algorithms are very hard to match
  • Tensor parallelism reduces memory per node but requires careful sharding of fc_q/k/v vs fc_o dimensions

What I took away

  • GPU performance is 80% memory, 20% compute - reducing global memory accesses via SRAM tiling and operator fusion is the correct mental model.
  • Implementing AllReduce from scratch confirmed why MPI is faster: hardware-aware ring/tree algorithms exploit topology that naive send/recv ignores.
  • Tensor parallelism sharding rules are subtle - fc_q/k/v split on output dim and fc_o on input dim is not obvious until you trace gradient shapes carefully.
  • Backward communication needs more coordination than forward - Reduce-Scatter for gradients is more complex than AllGather for activations.