
FlashMHF: Turning the FFN into “Multi-Head + Flash” — Notes on Flash Multi-Head FFN

2026-02-23

FFN redesign + SRAM-fused kernel: better perplexity/downstream accuracy, 3–5x lower peak memory, and up to 1.08x inference speedup.


TL;DR (4 things to remember)

  • Motivation: An FFN (especially SwiGLU) is structurally very close to “single-head attention”, except softmax is replaced by element-wise nonlinearity/gating — so FFNs should also benefit from the multi-head principle.
  • Why naïve MH-FFN breaks: (1) activation memory scales linearly with head count (H copies of large intermediates); (2) as models scale, the ratio d_ff / d_h explodes, causing a scaling imbalance and degraded performance.
  • What FlashMHF changes: use parallel FFN sub-networks + gating to keep a balanced effective expansion ratio per head; and use an I/O-aware fused kernel (SRAMFFN) to compute SwiGLU online without materializing huge intermediates in HBM.
  • Result: for 128M–1.3B pretraining, compared to SwiGLU FFN: better perplexity/downstream accuracy, 3–5x lower peak memory, and up to 1.08x faster inference (avg ~1.05x).

FlashMHF Slide 01


1) Background: why touch the FFN?

Inside a Transformer block, attention gets most of the spotlight — but the FFN/MLP is a large chunk of parameters and compute. FlashMHF targets a pragmatic goal: improve FFN expressivity like multi-head attention does, without blowing up memory.

FlashMHF Slide 02


2) “FFN ≈ Attention” structural symmetry (the key viewpoint)

The paper rewrites SwiGLU as a kind of “generalized attention”: replace softmax(QK^T) with an element-wise nonlinearity ϕ, then multiply by V.

In a more direct form (ignoring batch/head dims):

A = SiLU(Q K_g^T) ⊙ (Q K_u^T)
O = A V

This suggests the direction: if multi-head structure is a strong inductive bias for attention, the FFN might likewise gain expressivity from being decomposed into heads.
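This symmetry is easy to check numerically. Below is a minimal NumPy sketch of SwiGLU written in the "generalized attention" form above; the shapes and toy dimensions are illustrative, not the paper's configuration:

```python
import numpy as np

def silu(x):
    # SiLU(x) = x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def swiglu_ffn(Q, K_g, K_u, V):
    """SwiGLU FFN in 'generalized attention' form:
    A = SiLU(Q K_g^T) ⊙ (Q K_u^T), then O = A V.
    Q: (L, d_model) token activations; K_g, K_u, V: (d_ff, d_model)
    are the gate/up/down projection weights viewed as 'keys'/'values'."""
    A = silu(Q @ K_g.T) * (Q @ K_u.T)  # (L, d_ff) -- the large intermediate
    return A @ V                       # (L, d_model)
```

Note that `A` plays the role of attention scores, except it comes from an element-wise nonlinearity/gate rather than a row-wise softmax.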

FlashMHF Slide 03


3) Two fatal issues of naïve MH-FFN

3.1 Memory pressure: intermediate activations × H

A naïve multi-head FFN materializes one large intermediate activation per head (roughly L × d_ff each), so peak activation memory scales roughly linearly in the head count: H heads ≈ H× the single-head footprint.
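A back-of-envelope helper makes the linear-in-H scaling concrete (the fp16 assumption and the helper name are mine, not from the paper):

```python
def naive_mhffn_peak_activation_bytes(L, d_ff, H, bytes_per_elem=2):
    """Rough peak activation memory of a naive MH-FFN: each of the H heads
    keeps its own (L, d_ff) intermediate alive, assuming fp16 elements."""
    return L * d_ff * H * bytes_per_elem

# e.g. L=4096, d_ff=5760, H=8 -> ~377 MB of intermediates per layer, per batch element
```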

Figure 2: Memory limitation of MH-FFN

3.2 Scaling imbalance: d_ff / d_h blows up with model size

MH-FFN often inherits the MHA habit of keeping d_h fixed (e.g. 128). As the model scales, d_ff grows, so d_ff / d_h becomes too large and drifts away from the empirically good range.

The paper gives a simple sanity check (naïve MH-FFN with d_h = 128):

  • 128M: d_ff / d_h = 2048 / 128 = 16
  • 370M: d_ff / d_h = 2688 / 128 = 21
  • 1.3B: d_ff / d_h = 5760 / 128 = 45

4) FlashMHF: fix scaling + fix memory

4.1 Parallel FFN sub-networks (dense-MoE-like, but every token aggregates)

High-level idea:

  • Each head no longer has a single monolithic FFN path; instead it has E parallel sub-networks.
  • For each token, compute gating weights within the head and aggregate the E sub-network outputs.

This keeps each sub-network at a reasonable internal width/expansion ratio, which directly combats the d_ff / d_h imbalance. Since aggregation is dense, there is no routing / all-to-all overhead.
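A minimal NumPy sketch of one such head, assuming a per-token softmax gate over the E sub-networks (the exact gating function and weight shapes are my assumption; the paper's parameterization may differ):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def head_with_subnets(x, Wg_list, Wu_list, Wd_list, W_gate):
    """One FlashMHF-style head with E parallel SwiGLU sub-networks,
    mixed densely: every token aggregates all E outputs, so there is
    no routing / all-to-all as in sparse MoE.
    x: (L, d_h); Wg/Wu: (d_h, d_sub); Wd: (d_sub, d_h); W_gate: (d_h, E)."""
    logits = x @ W_gate                               # (L, E) gate logits
    gates = np.exp(logits - logits.max(axis=-1, keepdims=True))
    gates = gates / gates.sum(axis=-1, keepdims=True) # softmax over sub-networks
    out = np.zeros_like(x)
    for e in range(len(Wg_list)):
        h = silu(x @ Wg_list[e]) * (x @ Wu_list[e])   # small SwiGLU sub-network
        out += gates[:, e:e + 1] * (h @ Wd_list[e])   # dense weighted aggregation
    return out
```

Each sub-network keeps a modest width `d_sub`, which is how the effective expansion ratio per head stays balanced as the model scales.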

FlashMHF Slide 04

4.2 I/O-aware flash algorithm (SRAMFFN: blockwise online compute)

The goal is simple: avoid writing

A = SiLU(Q K_g^T) ⊙ (Q K_u^T)   # huge: L × d_ff

to HBM and reading it back.

FlashMHF’s fused kernel tiles K_g/K_u/V along d_ff, computes each block in SRAM, and accumulates its contribution directly into O (“finish a block → add to O → discard the block activation”). Because the gating is element-wise, with no softmax normalization coupling positions across d_ff, the blockwise partial sums are exact. Peak memory drops from “store a giant intermediate” to “store an output accumulator”.
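The blockwise idea can be sketched in NumPy; on real hardware this is a fused GPU kernel with the block resident in SRAM, and the block size here is purely illustrative:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def swiglu_blockwise(Q, K_g, K_u, V, block=256):
    """Blockwise SwiGLU in the spirit of SRAMFFN: tile K_g/K_u/V along
    d_ff, compute one (L, block) slice of the activation at a time, and
    accumulate it into O -- the full (L, d_ff) intermediate never exists.
    Since the elementwise gate has no cross-block coupling, O is exact."""
    L, d_model = Q.shape
    d_ff = K_g.shape[0]
    O = np.zeros((L, d_model))
    for s in range(0, d_ff, block):
        e = min(s + block, d_ff)
        A_blk = silu(Q @ K_g[s:e].T) * (Q @ K_u[s:e].T)  # (L, block) only
        O += A_blk @ V[s:e]            # add this block's contribution, then discard
    return O
```

This matches the dense computation exactly while holding only one block-sized activation plus the output accumulator at any time.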

Figure 3: Parallel sub-networks & SRAMFFN


5) Results (only the key numbers)

5.1 Language modeling validation loss (PG19 val, lower is better)

From Table 1 (370M / 1.3B):

  • 370M Baseline (SwiGLU): 3.030
  • 370M FlashMHF: d_h=64 3.046, d_h=128 3.014, d_h=256 3.029
  • 1.3B Baseline (SwiGLU): 2.843
  • 1.3B FlashMHF: d_h=64 2.849, d_h=128 2.793, d_h=256 2.799

Empirically, d_h=128 tends to be a sweet spot: too small bottlenecks each head; too large reduces the number of heads and loses subspace diversity.

5.2 Efficiency: 3–5x lower peak memory, up to 1.08x speedup

On Hopper benchmarks (Section 4.3):

  • Peak memory: 3–5x lower than a standard SwiGLU FFN
  • Inference latency: up to 1.08x speedup (avg ~1.05x)

Figure 8: Memory and latency comparison