
2025·09·05 ·5 min ·note

A mildly cursed 3.5× Triton tl.reduce optimization


TL;DR: For small, compile-time K, manually unrolling a 3D→2D bitwise-OR reduction can beat tl.reduce by ~3.5×.

I’ll take you on a small adventure through some weird Triton compiler behaviour.

We’ll look at a reduction used inside an attention variant. It reduces a 3D tensor along its last axis with bitwise OR to produce a 2D tensor, and making it fast is where things got strange.

Concretely, given an integer tensor

$$ A \in \mathbb{Z}^{M \times N \times K} $$

we want a 2D output

$$ \begin{aligned} Y &\in \mathbb{Z}^{M \times N} \\ Y[m,n] &= \operatorname{OR}_{k=0}^{K-1} A[m,n,k] \end{aligned} $$

Equivalently,

$$ \begin{aligned} Y[m,n] &= A[m,n,0] \mid A[m,n,1] \mid \cdots \\ &\quad \mid A[m,n,K-1]. \end{aligned} $$
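As a quick reference for checking kernel output on small inputs, here is a direct plain-Python transcription of the definition. The function name `or_reduce_last_axis` and the nested-list representation are illustrative choices, not part of the original setup.

```python
from functools import reduce
import operator


def or_reduce_last_axis(A):
    """Y[m][n] = A[m][n][0] | A[m][n][1] | ... | A[m][n][K-1] for a nested M x N x K list."""
    # 0 is the identity element for bitwise OR, so it is a safe initial value
    return [[reduce(operator.or_, fiber, 0) for fiber in plane] for plane in A]


# Example with M=1, N=2, K=3
A = [[[0b001, 0b010, 0b100],
      [0b000, 0b000, 0b1000]]]
print(or_reduce_last_axis(A))  # [[7, 8]]
```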

In row-major memory, the linear index is

$$ \operatorname{idx}(m,n,k) = m \cdot (N \cdot K) + n \cdot K + k. $$
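Because k varies fastest, the K values OR-ed into a given Y[m, n] sit in one contiguous run starting at m·N·K + n·K; the manual kernel later in this post relies on exactly that. A small plain-Python check, with `idx` as a hypothetical helper, that the formula matches row-major flattening order:

```python
def idx(m, n, k, N, K):
    """Row-major linear index of A[m, n, k] in an M x N x K tensor."""
    return m * (N * K) + n * K + k


M, N, K = 2, 3, 4
# Enumerate (m, n, k) in row-major order (k fastest) and check that the
# formula assigns them consecutive indices 0, 1, 2, ...
order = [(m, n, k) for m in range(M) for n in range(N) for k in range(K)]
assert [idx(m, n, k, N, K) for (m, n, k) in order] == list(range(M * N * K))

# The K elements reduced into Y[1, 2] are contiguous in memory
print([idx(1, 2, k, N, K) for k in range(K)])  # [20, 21, 22, 23]
```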

I guess if you’re reading this, you already know what Triton is, but here’s a short recap for context.

Triton is a Python DSL for writing tile-oriented GPU kernels: each program instance works on blocks of data held in on-chip SRAM and registers. You write kernels in Python; Triton lowers them through several internal IR stages, applies optimizations along the way, and produces a binary that runs on your GPU (PTX compiled to a cubin on NVIDIA hardware).

The tl.reduce kernel

```python
import triton
import triton.language as tl


@triton.jit
def u64_or(a, b):
    # Combine function for tl.reduce: bitwise OR of two elements
    return a | b


@triton.jit
def reduce_or_kernel(input_ptr, output_ptr,
                     M, N, K,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # Each program instance produces one [BLOCK_M, BLOCK_N] tile of the output
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)      # [BLOCK_M]
    n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)      # [BLOCK_N]
    k = tl.arange(0, BLOCK_K)                        # [BLOCK_K], constexpr length

    # Linearized pointer arithmetic for a [BLOCK_M, BLOCK_N, BLOCK_K] tile
    offs3d = m[:, None, None] * (N * K) + n[None, :, None] * K + k[None, None, :]

    mask3d = ((m[:, None, None] < M) &
              (n[None, :, None] < N) &
              (k[None, None, :] < K))

    # Masked-off lanes load 0, the identity element for OR
    tile = tl.load(input_ptr + offs3d, mask=mask3d, other=0)  # [BM, BN, BK]
    out  = tl.reduce(tile, axis=2, combine_fn=u64_or)         # [BM, BN]

    out_offs = m[:, None] * N + n[None, :]
    out_mask = (m[:, None] < M) & (n[None, :] < N)
    tl.store(output_ptr + out_offs, out, mask=out_mask)
```
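To make the kernel’s indexing concrete, here is a NumPy emulation of what a single program instance (pid_m, pid_n) of reduce_or_kernel computes: the same broadcasted offsets and masks, a masked gather with other=0, and an OR over the last axis. This is a CPU sketch for sanity-checking, not Triton code; `emulate_program` is a hypothetical name, and `np.bitwise_or.reduce` stands in for `tl.reduce` with `u64_or`.

```python
import numpy as np


def emulate_program(inp_flat, M, N, K, pid_m, pid_n, BLOCK_M, BLOCK_N, BLOCK_K):
    """CPU emulation of one (pid_m, pid_n) instance of reduce_or_kernel."""
    m = pid_m * BLOCK_M + np.arange(BLOCK_M)  # [BLOCK_M]
    n = pid_n * BLOCK_N + np.arange(BLOCK_N)  # [BLOCK_N]
    k = np.arange(BLOCK_K)                    # [BLOCK_K]

    # Same broadcasted offsets and masks as the kernel
    offs3d = m[:, None, None] * (N * K) + n[None, :, None] * K + k[None, None, :]
    mask3d = (m[:, None, None] < M) & (n[None, :, None] < N) & (k[None, None, :] < K)

    # Masked load with other=0: masked-off lanes hold 0, the identity for OR
    tile = np.zeros((BLOCK_M, BLOCK_N, BLOCK_K), dtype=inp_flat.dtype)
    tile[mask3d] = inp_flat[offs3d[mask3d]]

    out = np.bitwise_or.reduce(tile, axis=2)  # stands in for tl.reduce(..., combine_fn=u64_or)
    out_offs = m[:, None] * N + n[None, :]
    out_mask = (m[:, None] < M) & (n[None, :] < N)
    return out_offs, out_mask, out


# Run a 2D grid of emulated programs and compare with a direct reduction
M, N, K = 5, 7, 3
BLOCK_M, BLOCK_N, BLOCK_K = 4, 4, 4  # BLOCK_K >= K; tl.arange needs power-of-two lengths
rng = np.random.default_rng(0)
A = rng.integers(0, 1 << 16, size=(M, N, K), dtype=np.uint64)

out = np.zeros(M * N, dtype=np.uint64)
for pid_m in range((M + BLOCK_M - 1) // BLOCK_M):
    for pid_n in range((N + BLOCK_N - 1) // BLOCK_N):
        offs, msk, vals = emulate_program(A.ravel(), M, N, K, pid_m, pid_n,
                                          BLOCK_M, BLOCK_N, BLOCK_K)
        out[offs[msk]] = vals[msk]

print(np.array_equal(out.reshape(M, N), np.bitwise_or.reduce(A, axis=2)))  # True
```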

We see a ~2.5× slowdown after introducing tl.reduce. That’s odd: the reduction length is known at compile time (BLOCK_K is a tl.constexpr), so you’d expect the compiler to have everything it needs to fully unroll the reduction, but apparently it doesn’t.

We can confirm that tl.reduce itself is the bottleneck: allocating an empty on-chip (SRAM) tensor of the same shape and running tl.reduce on it, with no global loads at all, caused a similar slowdown.

So let’s generate unrolled code manually and compare against tl.reduce. I wrote a tiny codegen helper to do this.

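```python
K = 16  # compile-time reduction length
code = "res = load_k(0)"  # load_k(i): placeholder for the i-th masked tl.load along K
for i in range(1, K):
    code += f"\nres = res | load_k({i})"
```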

The kernel sketch then becomes:

```python
# Base offset of the contiguous run of K values for each (m, n) in the tile
input_base = m[:, None] * (N * K) + n[None, :] * K
mask       = (m[:, None] < M) & (n[None, :] < N)

# Stream K values from global memory, OR-ing them into a single accumulator.
# The codegen helper writes this loop out as K straight-line statements,
# so no loop remains in the generated kernel.
res = tl.load(input_ptr + input_base + 0, mask=mask, other=0)
for i in range(1, K):
    val = tl.load(input_ptr + input_base + i, mask=mask, other=0)
    res = res | val

# Store the 2D output tile
out_offs = m[:, None] * N + n[None, :]
tl.store(output_ptr + out_offs, res, mask=mask)
```
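The sketch can be emulated the same way on the CPU: each `load_k(i)` is a masked 2D gather at `input_base + i`, and OR-ing the K gathered tiles gives the same result as reducing a 3D tile. A NumPy sketch under those assumptions, with `emulate_unrolled` as a hypothetical helper:

```python
import numpy as np


def emulate_unrolled(inp_flat, M, N, K, pid_m, pid_n, BLOCK_M, BLOCK_N):
    """CPU emulation of one program instance of the manual-unroll sketch."""
    m = pid_m * BLOCK_M + np.arange(BLOCK_M)
    n = pid_n * BLOCK_N + np.arange(BLOCK_N)
    mask = (m[:, None] < M) & (n[None, :] < N)
    input_base = m[:, None] * (N * K) + n[None, :] * K  # start of each (m, n) run

    def load_k(i):
        # Masked 2D load of element k=i for every (m, n) in the tile, other=0
        vals = np.zeros((BLOCK_M, BLOCK_N), dtype=inp_flat.dtype)
        vals[mask] = inp_flat[(input_base + i)[mask]]
        return vals

    res = load_k(0)
    for i in range(1, K):  # the codegen helper writes this out as straight-line code
        res = res | load_k(i)
    return m[:, None] * N + n[None, :], mask, res


M, N, K = 6, 5, 16
BLOCK_M, BLOCK_N = 4, 4
rng = np.random.default_rng(1)
A = rng.integers(0, 1 << 16, size=(M, N, K), dtype=np.uint64)

out = np.zeros(M * N, dtype=np.uint64)
for pid_m in range((M + BLOCK_M - 1) // BLOCK_M):
    for pid_n in range((N + BLOCK_N - 1) // BLOCK_N):
        offs, msk, vals = emulate_unrolled(A.ravel(), M, N, K, pid_m, pid_n, BLOCK_M, BLOCK_N)
        out[offs[msk]] = vals[msk]

print(np.array_equal(out.reshape(M, N), np.bitwise_or.reduce(A, axis=2)))  # True
```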

For K=16, for example, the generated body is:

```python
res = load_k(0)
res = res | load_k(1)
res = res | load_k(2)
res = res | load_k(3)
res = res | load_k(4)
res = res | load_k(5)
res = res | load_k(6)
res = res | load_k(7)
res = res | load_k(8)
res = res | load_k(9)
res = res | load_k(10)
res = res | load_k(11)
res = res | load_k(12)
res = res | load_k(13)
res = res | load_k(14)
res = res | load_k(15)
```
Dropped into a complete kernel (hard-coded for K=16), it looks like this:

```python
import triton
import triton.language as tl


@triton.jit
def manual_unroll_k16_kernel(
    input_ptr, output_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Assumes K == 16: the body below reads exactly 16 values per (m, n).
    # BLOCK_K is unused; it only keeps the signature identical to reduce_or_kernel.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    m2 = m[:, None]
    n2 = n[None, :]
    mask = (m2 < M) & (n2 < N)

    # Offset of the first of the K contiguous values for each (m, n)
    input_base = m2 * N * K + n2 * K

    # Manual unroll: 16 masked global loads OR-ed into one accumulator
    res = tl.load(input_ptr + input_base + 0, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 1, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 2, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 3, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 4, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 5, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 6, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 7, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 8, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 9, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 10, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 11, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 12, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 13, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 14, mask=mask, other=0)
    res = res | tl.load(input_ptr + input_base + 15, mask=mask, other=0)

    # Store the 2D output tile
    out_offs = m2 * N + n2
    tl.store(output_ptr + out_offs, res, mask=mask)
```

It’s quite ugly but it works.
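Since the unrolled body has to match K exactly, it is convenient to let the codegen helper emit the whole kernel. A sketch that extends the codegen snippet above to produce full source text for any K; the function name `gen_manual_unroll_kernel` is hypothetical, and the emitted kernel mirrors manual_unroll_k16_kernel.

```python
def gen_manual_unroll_kernel(K: int) -> str:
    """Return source for a kernel like manual_unroll_k16_kernel, unrolled for this K."""
    lines = [
        "import triton",
        "import triton.language as tl",
        "",
        "",
        "@triton.jit",
        f"def manual_unroll_k{K}_kernel(",
        "    input_ptr, output_ptr,",
        "    M, N, K,",
        "    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,",
        "):",
        "    pid_m = tl.program_id(0)",
        "    pid_n = tl.program_id(1)",
        "    m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)",
        "    n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)",
        "    m2 = m[:, None]",
        "    n2 = n[None, :]",
        "    mask = (m2 < M) & (n2 < N)",
        "    input_base = m2 * N * K + n2 * K",
    ]
    # One straight-line masked load per k, OR-ed into the accumulator
    for i in range(K):
        load = f"tl.load(input_ptr + input_base + {i}, mask=mask, other=0)"
        lines.append(f"    res = {load}" if i == 0 else f"    res = res | {load}")
    lines.append("    tl.store(output_ptr + m2 * N + n2, res, mask=mask)")
    return "\n".join(lines) + "\n"


src = gen_manual_unroll_kernel(4)
compile(src, "<generated>", "exec")  # syntax check only: nothing is imported or executed
print(src)
```

Because `triton.jit` reads the decorated function’s source via `inspect`, exec-ing the string directly will likely fail; writing `src` to a `.py` file and importing that module is the simpler route.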

Benchmark: tl.reduce vs manual unroll

   M |    N |  K | tl.reduce (ms) | manual (ms) | speedup
-----+------+----+-----------------+-------------+---------
  64 |  128 |  4 |          0.014 |       0.012 |   1.19x
 128 |  256 |  8 |          0.056 |       0.016 |   3.58x
 256 |  512 | 16 |          0.088 |       0.031 |   2.84x
 512 | 1024 |  4 |          0.013 |       0.013 |   1.01x
1024 | 2048 |  8 |          1.038 |       0.414 |   2.50x
2048 | 4096 | 16 |          8.438 |       3.275 |   2.58x

For K=8 and K=16 the manual version is about 2.5 to 3.6× faster; at K=4 the difference is small (1.01 to 1.19×). You can check whether a similar problem benefits with the code in this gist. Treat this as a quick workaround for when you don’t want to touch the Triton compiler itself; it works pretty well, though.

Versioning: Triton 3.4.0.