
Modern GPU Programming for MLSys

This article introduces a comprehensive guide to modern GPU programming for machine learning systems, covering GPU hardware architecture, programming models, and step-by-step optimization of key kernels like GEMM and FlashAttention using the TIRx DSL. It is based on a CMU course and targets the Blackwell GPU generation.

Source: Hacker News AI. Author: sonabinu.


Machine learning systems sit at the heart of modern AI workloads. In these systems, performance often comes down to the quality of a small number of GPU kernels. Attention kernels, LLM prefill and decode kernels, low-precision block-scaled GEMMs, fused MoE layers, and other large fused kernels all directly shape end-to-end speed in both training and serving.

To make these kernels fast, however, we need more than a list of optimization tricks. Modern GPUs are no longer simple variations of the same old design. Recent architectures introduce richer memory spaces (such as Blackwell's Tensor Memory, TMEM), new asynchronous access patterns (such as TMA bulk copies), and increasingly specialized execution units (such as the tcgen05 Tensor Cores). To program them well, we need both a clear mental model of the hardware and a practical understanding of how high-performance kernels are built. This book is about developing both.

The book follows a simple progression: first understand the GPU hardware, then learn the programming model we will use, and finally build state-of-the-art kernels step by step. Our main target is the Blackwell generation, and our main running examples are fast matrix multiplication (GEMM) and FlashAttention. Along the way, we will also study the core ingredients behind GPU optimization: data layout, asynchronous data movement, and asynchronous coordination.
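Asynchronous data movement and asynchronous coordination usually meet in a producer/consumer pattern over a small ring of shared-memory buffers. Below is a minimal CPU-side sketch in plain Python, assuming a `PIPE_DEPTH`-slot ring in which a producer thread stands in for the TMA loads and a consumer thread stands in for the MMA. The names `stage_and_phase` and `run_pipeline` are hypothetical, and counting semaphores substitute for hardware mbarriers, which track completion with a phase parity bit rather than a count. This is an illustration of the idea, not TIRx code.

```python
import threading

PIPE_DEPTH = 2  # hypothetical number of buffer slots in the ring


def stage_and_phase(i, depth=PIPE_DEPTH):
    """Return (slot index, phase parity) for the i-th K-tile.

    The parity flips each time the ring wraps around; it is the bit an
    mbarrier-style wait would check to tell a fresh fill from a stale one.
    """
    return i % depth, (i // depth) & 1


def run_pipeline(tiles, depth=PIPE_DEPTH):
    """Stream `tiles` through a `depth`-slot ring buffer; return them in consumption order."""
    slots = [None] * depth
    # Counting semaphores stand in for the per-slot "full" and "empty" barriers.
    full = [threading.Semaphore(0) for _ in range(depth)]   # data ready to consume
    empty = [threading.Semaphore(1) for _ in range(depth)]  # slot free to overwrite
    consumed = []

    def producer():  # plays the role of the load (TMA) warp
        for i, tile in enumerate(tiles):
            slot, _ = stage_and_phase(i, depth)
            empty[slot].acquire()  # wait until the consumer has released this slot
            slots[slot] = tile     # stands in for an async copy into shared memory
            full[slot].release()   # signal that the slot now holds fresh data

    def consumer():  # plays the role of the MMA warp
        for i in range(len(tiles)):
            slot, _ = stage_and_phase(i, depth)
            full[slot].acquire()          # wait for the data to land
            consumed.append(slots[slot])  # stands in for an MMA on this slot
            empty[slot].release()         # hand the slot back to the producer

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return consumed


if __name__ == "__main__":
    print([stage_and_phase(i) for i in range(5)])  # [(0, 0), (1, 0), (0, 1), (1, 1), (0, 0)]
    print(run_pipeline(list(range(6))))            # [0, 1, 2, 3, 4, 5]
```

With a depth of 2, the producer can fill slot 1 while the consumer is still working on slot 0, which is the overlap double buffering buys; a deeper ring keeps more loads in flight at the cost of more buffer memory.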

The material grows out of the Machine Learning Systems course series at Carnegie Mellon University. To make the ideas easier to study and run, this book uses the TIRx Python DSL to build real GPU kernel examples step by step. TIRx stays close to the hardware, which lets us reason about low-level control while still learning through runnable code.

How This Book Is Organized

Part I, Understanding the GPU. This part introduces the overall organization of the GPU, general recipes for writing fast kernels, and key concepts such as data layout, asynchronous memory operations, and coordination. It builds the hardware intuition that the rest of the book relies on.

Part II, TIRx Overview. This part introduces the key elements of TIRx, which serve as the foundation for the code examples throughout the book.

Part III, GEMM: Tiled to SOTA. A complete guide to optimizing a tiled GEMM, built up through TMA pipelining, persistent scheduling, warp specialization, and 2-CTA clusters.

Part IV, Flash Attention 4. A complete attention kernel built from the Part III techniques: two MMAs with a softmax between them, online-softmax rescaling, causal masking, and grouped-query attention (GQA).

Reference. TIRx language reference and compiler internals.
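The Part IV description packs several ideas into one line: two matrix multiplies with a softmax between them, online-softmax rescaling, causal masking, and GQA. A minimal plain-Python sketch of those ideas for a single query row follows, assuming standard scaled dot-product attention with scale 1/sqrt(d). The function names are hypothetical, the "MMA" comments only mark where a real kernel would issue Tensor Core matrix multiplies over whole tiles, and none of this is the book's TIRx implementation.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def reference_attention_row(q, K, V, q_pos, causal=True):
    """Plain softmax(q K^T / sqrt(d)) V for one query row, used for checking."""
    scale = 1.0 / math.sqrt(len(q))
    n = q_pos + 1 if causal else len(K)  # causal: only keys 0..q_pos are visible
    s = [scale * dot(q, K[j]) for j in range(n)]
    mx = max(s)
    e = [math.exp(x - mx) for x in s]
    z = sum(e)
    return [sum(e[j] * V[j][c] for j in range(n)) / z for c in range(len(V[0]))]


def flash_attention_row(q, K, V, q_pos, block=2, causal=True):
    """One query row of attention, streaming K/V in blocks with online softmax."""
    scale = 1.0 / math.sqrt(len(q))
    m = -math.inf          # running max of the scores seen so far
    l = 0.0                # running sum of exp(score - m)
    o = [0.0] * len(V[0])  # unnormalized output accumulator
    for start in range(0, len(K), block):
        if causal and start > q_pos:
            break  # every key in this block is masked, so skip the block entirely
        idx = range(start, min(start + block, len(K)))
        # "MMA 1": scores for this KV block, with the causal mask applied.
        s = [-math.inf if causal and j > q_pos else scale * dot(q, K[j]) for j in idx]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)           # rescales the old state; 0.0 on the first block
        p = [math.exp(x - m_new) for x in s]  # masked entries become exp(-inf) = 0.0
        l = l * alpha + sum(p)
        # "MMA 2": accumulate P @ V after rescaling the previous partial output.
        o = [oc * alpha for oc in o]
        for pj, j in zip(p, idx):
            for c in range(len(o)):
                o[c] += pj * V[j][c]
        m = m_new
    return [oc / l for oc in o]  # final normalization by the softmax denominator


def kv_head_for(q_head, num_q_heads, num_kv_heads):
    """GQA: each group of consecutive query heads shares one K/V head."""
    group = num_q_heads // num_kv_heads
    return q_head // group
```

The invariant is that after each block, `o / l` equals softmax attention over the keys seen so far; multiplying `l` and `o` by `alpha = exp(m - m_new)` keeps that true whenever the running max grows. Skipping blocks that lie entirely above the diagonal is where causal masking saves work, and `kv_head_for` shows the head mapping GQA relies on, assuming the number of query heads is a multiple of the number of K/V heads.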

Part I, Understanding the GPU

GPU Execution Model

What Makes a Kernel Fast

Data Layout and Its Notation

Tensor Core Operand Layouts Across GPU Generations

Async Data Movement: TMA

Tensor Cores: tcgen05

Special Memory: TMEM

Async Coordination: mbarriers

Advanced: Cluster Launch Control

Part II, TIRx Overview

Introduction to TIRx

TIRx Layout API

Part III, GEMM: Tiled to SOTA

Building a Tiled GEMM

GEMM

Optimization Path

Step 1: Sequential Single-Tile GEMM

Step 2: K-Loop Accumulation

Step 3: Spatial Tiling (Multi-CTA)

Exercises

Pipelining GEMM with TMA

Step 4: TMA Async Load

Step 5: Software Pipeline (PIPE_DEPTH=2)

Step 6: Persistent Kernel + Tile Scheduler

Exercises

Scaling GEMM with Warp Specialization and Clusters

Step 7: Warp Specialization + Pipeline

Step 8: 2-CTA Cluster

Step 9: Multi-Consumer Warp Specialization

End-to-End Result

Exercises

Part IV, Flash Attention 4

Flash Attention 4

Algorithm Shape

Tile-Primitive Graph

Warp Roles and Scopes

Reading the Fragments

The Two MMA Phases

TMEM Layout and Reuse

How Barriers Connect the Roles

Pipelining Structure

Rescaling and Writeback

Causal Masking

GQA Support

Tile Scheduling

Compile and Verify

Differences from GEMM

Exercises

Reference

Reference

Debugging Warp-Specialized Kernels

Compiler Internals

TIRx Language Reference