Axe: A Simple Unified Layout Abstraction for Machine Learning Compilers

Notice: This research summary and analysis were automatically generated using AI technology. For absolute accuracy, please refer to the original arXiv source.

Scaling modern deep learning workloads demands coordinated placement of data and compute across device meshes, memory hierarchies, and heterogeneous accelerators. We present Axe Layout, a hardware-aware abstraction that maps logical tensor coordinates to a multi-axis physical space via named axes. Axe unifies tiling, sharding, replication, and offsets across inter-device distribution and on-device layouts, enabling collective primitives to be expressed consistently from device meshes to threads. Building on Axe, we design a multi-granularity, distribution-aware DSL and compiler that composes thread-local control with collective operators in a single kernel. Experiments show that our unified approach brings performance close to hand-tuned kernels across the latest GPU devices, multi-device environments, and accelerator backends.


💡 Research Summary

The paper introduces Axe Layout, a hardware‑aware abstraction that unifies the representation of data placement and compute mapping across the many hierarchical levels encountered in modern deep‑learning workloads. Traditional deep‑learning frameworks and compilers treat distributed sharding, on‑device tiling, thread binding, and replication as separate concerns, leading to fragmented APIs and sub‑optimal optimization pipelines. Axe addresses this fragmentation by extending the classic shape‑stride model with named axes that correspond to physical resources such as GPU lanes, warps, registers, device mesh dimensions, and accelerator memory banks.

An Axe layout is defined as a triple (D, R, O):

  • D (Shard) – an ordered list of iterators, each specified by (extent, stride, axis). This list partitions a logical tensor index across multiple hardware axes, effectively describing multi‑dimensional sharding in a single formalism.
  • R (Replica) – a multiset of replication iterators that enumerate offsets independent of the logical index, enabling broadcasting or data replication across devices or thread groups.
  • O (Offset) – a fixed coordinate vector that adds a constant offset to every result, useful for reserving specific memory banks or register slots.

The mapping is a set‑valued function:
  L(x) = { D(x) + r + O | r ∈ R },
where D(x) is the base coordinate derived from the shard iterators and r enumerates all replica offsets. This formulation naturally captures a wide range of scenarios: intra‑warp register tiling for tensor‑core instructions, sharding across a 2 × 2 GPU mesh, and placement into multidimensional scratchpad memories of AI accelerators such as AWS Trainium or Google TPU.
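The set‑valued mapping above can be sketched in a few lines of Python. This is a minimal illustration of the formalism, not the paper's actual API; all class and function names here (ShardIter, layout_map) are hypothetical:

```python
from dataclasses import dataclass
from itertools import product

@dataclass(frozen=True)
class ShardIter:
    extent: int   # number of positions along this iterator
    stride: int   # stride applied on the physical axis
    axis: str     # named physical axis, e.g. "lane", "warp", "device"

def layout_map(x, shard, replicas, offset):
    """Evaluate L(x) = { D(x) + r + O | r in R } as a list of axis -> coordinate dicts.

    shard    : ordered list of ShardIter (D) decomposing the logical index x
    replicas : list of (axis, extent) pairs (R); each enumerates extent offsets
    offset   : constant per-axis offset dict (O)
    """
    # D(x) + O: peel the logical index mixed-radix style by iterator extents
    base = dict(offset)
    for it in reversed(shard):          # last iterator varies fastest
        x, digit = divmod(x, it.extent)
        base[it.axis] = base.get(it.axis, 0) + digit * it.stride
    # R: enumerate every combination of replica offsets and add it to D(x) + O
    coords = []
    for combo in product(*(range(extent) for _, extent in replicas)):
        c = dict(base)
        for (axis, _), r in zip(replicas, combo):
            c[axis] = c.get(axis, 0) + r
        coords.append(c)
    return coords

# Example: a length-8 tensor sharded over 2 devices (4 contiguous elements each)
# and replicated across 2 thread groups. Logical element 5 lands on device 1,
# local slot 1, and appears once per thread group.
shard = [ShardIter(2, 1, "device"), ShardIter(4, 1, "elem")]
print(layout_map(5, shard, replicas=[("group", 2)], offset={}))
```

Note how replication falls out for free: an empty replica list yields exactly one physical coordinate per logical index, while each replica iterator multiplies the result set, which is what lets broadcasting and sharding share one formalism.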

Building on this abstraction, the authors design a multi‑granularity, distribution‑aware DSL. Tensor declarations carry Axe layout metadata, and a small set of layout operators—canonicalize, group, tile, slice—allows programmers to transform layouts declaratively. The compiler parses these operators to perform range analysis, axis matching, and schedule selection. Crucially, because D and R are treated uniformly, the compiler can fuse thread‑local control flow (e.g., loop nests, vectorized loads) with collective operations (e.g., all‑reduce, scatter) inside a single kernel. This unifies the low‑level, loop‑binding style of CuTe with the block‑level abstraction of Triton, giving developers the productivity of high‑level DSLs while retaining the performance potential of hand‑tuned kernels.
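To make the operator idea concrete, here is a hypothetical `tile` that splits one shard iterator into an outer and an inner iterator bound to different named axes. This is a self‑contained sketch of the general technique, not the paper's actual operator signature:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ShardIter:
    extent: int   # number of positions along this iterator
    stride: int   # stride applied on the physical axis
    axis: str     # named physical axis

def tile(shard, index, inner_extent, inner_axis):
    """Split shard[index] into (outer, inner) iterators.

    The outer iterator keeps the original axis with a coarsened stride; the
    inner iterator covers inner_extent positions on inner_axis. The total
    number of addressed positions (outer.extent * inner.extent) is unchanged,
    so the transformation is a pure re-description of the same layout.
    """
    it = shard[index]
    assert it.extent % inner_extent == 0, "inner extent must divide the original"
    outer = ShardIter(it.extent // inner_extent, it.stride * inner_extent, it.axis)
    inner = ShardIter(inner_extent, it.stride, inner_axis)
    return shard[:index] + [outer, inner] + shard[index + 1:]

# Example: retile a 128-element row iterator as 4 chunks of 32 lanes each,
# binding the inner positions to a "lane" axis (a typical warp-level split).
row = [ShardIter(128, 1, "row")]
print(tile(row, 0, 32, "lane"))
```

Because the result is again an ordered list of shard iterators, such operators compose: tiling can be applied repeatedly to descend from device mesh to warp to lane, which is the property the compiler exploits during lowering.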

The implementation targets three back‑ends: NVIDIA’s Blackwell and Hopper GPUs, and AWS Trainium‑1. Benchmarks cover mixture‑of‑experts (MoE) layers, multi‑GPU GEMM + Reduce‑Scatter, and multi‑head attention (MHA) kernels. Compared with strong baselines—FlashInfer, SGLang, cuBLAS + NCCL, Triton‑Distributed, and vendor‑provided Trainium libraries—Axe‑generated kernels achieve speedups of 1.23 × – 1.44 ×, often matching or surpassing hand‑optimized implementations. The authors attribute these gains to reduced data movement, better exploitation of hardware‑specific tiling (e.g., tensor‑core lane layouts), and the ability to automatically replicate or offset data where needed.

The paper’s contributions are fourfold:

  1. An Axe layout model that formally encodes sharding, replication, and offsets across inter‑ and intra‑device boundaries.
  2. A suite of layout operators that enable the compiler to reason about and transform layouts during lowering.
  3. A multi‑granularity DSL and compiler pipeline that seamlessly integrates thread‑level and collective primitives.
  4. Empirical validation showing near‑hand‑tuned performance across heterogeneous hardware.

Limitations are acknowledged. The current prototype is tuned for NVIDIA GPUs and Trainium; extending to other vendors (e.g., AMD GPUs, Habana Gaudi) will require defining additional axes and possibly new schedule heuristics. Complex layout transformations can increase compilation time, and the heuristic‑driven scheduler may not always find the globally optimal schedule for every workload. Future work includes dynamic layout re‑adjustment, multi‑stage pipeline support, and cross‑vendor axis standardization to broaden applicability.

In summary, Axe presents a compelling unified abstraction that bridges the gap between expressive, high‑level programming models and low‑level hardware‑specific optimizations. By treating data placement and compute mapping as a single, mathematically grounded construct, it simplifies compiler design, reduces boilerplate, and delivers performance that rivals expert‑crafted kernels across a spectrum of modern AI hardware.

