benchmarks.SaturnNPU.kernel_library.smolvla_fused_attention
Source: benchmarks/SaturnNPU/kernel_library/smolvla_fused_attention.py
Fused attention (SDPA) kernel — matches MLIR variant_0_12_1024_64_bf16.
MLIR source: SaturnNPU/kernels/iree_linalg_ext.attention/variant_0_12_1024_64_bf16.mlir
  Q:      [12, 1024, 64] bf16, indexing (d0, d1, d3) → [batch, q_seq, head_dim]
  K:      [12, 1024, 64] bf16, indexing (d0, d4, d3) → [batch, k_seq, head_dim]
  V:      [12, 64, 1024] bf16, indexing (d0, d2, d4) → [batch, head_dim, k_seq]
  scale:  scalar bf16
  output: [12, 1024, 64] bf16
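Read as an iteration space, the maps share d0 (batch), contract Q·K over d3 (head_dim), reduce softmax and the V product over d4 (key position), and leave d2 as the output head_dim column. A sketch of the resulting equation, assuming the output map is (d0, d1, d2), which is not listed above:

```latex
O[d_0, d_1, d_2] = \sum_{d_4} \operatorname{softmax}_{d_4}\!\Big(\mathrm{scale} \cdot \sum_{d_3} Q[d_0, d_1, d_3]\, K[d_0, d_4, d_3]\Big)\, V[d_0, d_2, d_4]
```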
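This demo covers one Q-block (32 rows) and two K tiles (k_seq = 64 in total, 32 rows per tile), with head_dim = 64. A production kernel wraps this body in loops over the 12 heads and all 1024 query rows (32 Q-blocks), and runs the K-tile loop over all 1024 key rows (32 K tiles) instead of two.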
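A minimal sketch of the loop nest this implies for the full problem, assuming heads are outermost, then Q-blocks, then K tiles. The loop order, the constant names, and tile_schedule are hypothetical, not taken from the source:

```python
# Hypothetical loop nest for the full [12, 1024, 64] problem. The tile sizes
# come from the text; the loop order and all names are assumptions.
HEADS, Q_SEQ, K_SEQ = 12, 1024, 1024
Q_BLOCK, K_TILE = 32, 32  # rows per Q-block / rows per K tile, as in the demo


def tile_schedule(heads=HEADS, q_seq=Q_SEQ, k_seq=K_SEQ):
    """Yield (head, q_start, k_start) for each invocation of the tile body."""
    for h in range(heads):
        for q0 in range(0, q_seq, Q_BLOCK):
            # m, l and O would be re-initialized here, once per Q-block
            for k0 in range(0, k_seq, K_TILE):
                yield h, q0, k0


schedule = list(tile_schedule())
print(len(schedule))  # 12 heads * 32 Q-blocks * 32 K tiles = 12288
```

Calling tile_schedule(heads=1, q_seq=32, k_seq=64) reproduces the demo: one Q-block and the two K tiles starting at rows 0 and 32.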
Flash attention
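Softmax normalizes over the full key sequence, so applying it to each K tile independently gives wrong results. The kernel instead uses online softmax (flash attention), carrying a running row-max m, a running row-sum l, and an unnormalized output accumulator O across K tiles: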
    m = -inf, l = 0, O = 0
    for each K tile:
        S  = Q @ K^T * scale
        m' = max(m, rowmax(S))
        α  = exp(m − m')
        O  = α * O + exp(S − m') @ V
        l  = α * l + rowsum(exp(S − m'))
        m  = m'
    output = O / l
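The recurrence can be checked in plain Python. This sketch runs it in float64 on small random inputs (not the kernel's bf16/fp8 datapath), splits the keys into two tiles as the demo does, and compares against a direct softmax; online_softmax_attention and direct_attention are illustrative names, not part of the module:

```python
import math
import random


def online_softmax_attention(q_rows, k_tiles, v_tiles, scale, m_init=-math.inf):
    """Flash-attention recurrence: q_rows [q, d]; K and V tiles [k, d] each."""
    out = []
    for q in q_rows:
        m, l = m_init, 0.0
        o = [0.0] * len(v_tiles[0][0])
        for k_tile, v_tile in zip(k_tiles, v_tiles):
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
            m_new = max(m, max(s))
            alpha = math.exp(m - m_new)          # rescales stats from earlier tiles
            p = [math.exp(x - m_new) for x in s]
            o = [alpha * oi + sum(pj * v[c] for pj, v in zip(p, v_tile))
                 for c, oi in enumerate(o)]
            l = alpha * l + sum(p)
            m = m_new
        out.append([oi / l for oi in o])         # output = O / l
    return out


def direct_attention(q_rows, keys, values, scale):
    """Softmax over all keys at once, for comparison."""
    out = []
    for q in q_rows:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
        mx = max(s)
        p = [math.exp(x - mx) for x in s]
        z = sum(p)
        out.append([sum(pj * v[c] for pj, v in zip(p, values)) / z
                    for c in range(len(values[0]))])
    return out


random.seed(0)
D = 4
q = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(3)]
keys = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(8)]
vals = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(8)]
tiled = online_softmax_attention(q, [keys[:4], keys[4:]], [vals[:4], vals[4:]], 0.5)
exact = direct_attention(q, keys, vals, 0.5)
```

Passing m_init=-100.0 gives identical results here: on the first tile O and l are still zero, so α only multiplies zeros.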
DRAM layouts (all bf16, column-blocked so each vload reads a contiguous [32,16] chunk):
  Q_DRAM:   [32, 64] bf16, stored as [128, 16] (4 × [32,16] blocks: cols 0:16 / 16:32 / 32:48 / 48:64)
  KT_DRAM:  K tile [32,64] pre-transposed → K^T [64,32], stored as [128,16] (top/bot × left/right)
  VT_DRAM:  V_mlir tile [64,32] pre-transposed → V_std [32,64], stored as [128,16]
  OUT_DRAM: [32, 64] bf16, stored as [128, 16]
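A sketch of the column-blocked layout, assuming blocks are ordered band-major (every 16-column block of rows 0:32, then rows 32:64). That matches the Q_DRAM column order and reads "top/bot × left/right" as top-left, top-right, bottom-left, bottom-right for KT_DRAM. The helper names are hypothetical:

```python
ROWS, COLS = 32, 16  # one contiguous vload chunk: [32, 16] bf16


def to_col_blocked(mat):
    """[R, C] -> [R*C/16, 16]: 32-row bands, each cut into 16-column blocks."""
    R, C = len(mat), len(mat[0])
    assert R % ROWS == 0 and C % COLS == 0
    out = []
    for r0 in range(0, R, ROWS):          # top band first
        for c0 in range(0, C, COLS):      # left block first within a band
            out.extend(row[c0:c0 + COLS] for row in mat[r0:r0 + ROWS])
    return out


def from_col_blocked(blocked, R, C):
    """Inverse of to_col_blocked for an [R, C] matrix."""
    mat = [[None] * C for _ in range(R)]
    i = 0
    for r0 in range(0, R, ROWS):
        for c0 in range(0, C, COLS):
            for r in range(r0, r0 + ROWS):
                mat[r][c0:c0 + COLS] = blocked[i]
                i += 1
    return mat


q = [[r * 64 + c for c in range(64)] for r in range(32)]   # stand-in Q tile [32, 64]
qb = to_col_blocked(q)                                     # Q_DRAM view [128, 16]
kt = [[r * 32 + c for c in range(32)] for r in range(64)]  # stand-in K^T [64, 32]
ktb = to_col_blocked(kt)                                   # KT_DRAM view [128, 16]
tile_bytes = len(qb) * COLS * 2                            # 2 bytes per bf16 element
```

tile_bytes comes out to 4096, the x15 tile size in the scalar register map, and a single [32,16] bf16 block is 1024 bytes, the x16 scale size.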
Q @ K^T (head_dim = 64, so two MXU passes, each contracting over 32 of the head_dim columns and accumulating into acc[0]):
  Q_lo [32,32] fp8 @ K^T_top [32,32] fp8 → acc[0] (fresh)
  Q_hi [32,32] fp8 @ K^T_bot [32,32] fp8 → acc[0] (accumulate)
  vmatpop.bf16 then writes acc[0] [32,32] into registers v16 (left [32,16]) and v17 (right [32,16]).
exp_s @ V (the output is [32,64], so two independent MXU passes, one per 32-column half):
  exp_s [32,32] fp8 @ V_left [32,32] fp8 → acc[0] (O left half)
  exp_s [32,32] fp8 @ V_right [32,32] fp8 → acc[1] (O right half)
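The two pass structures split different dimensions. Splitting the contraction (head_dim) produces partial products that must be summed into the same accumulator; splitting the output columns produces independent halves that are simply placed side by side. A small check of both identities, using integers as stand-ins for fp8 so equality is exact (this says nothing about fp8 rounding; matmul and cols are illustrative helpers):

```python
import random


def matmul(a, b):
    """Plain [n, k] @ [k, m] on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def cols(m, c0, c1):
    """Column slice m[:, c0:c1]."""
    return [row[c0:c1] for row in m]


def rand_mat(r, c):
    return [[random.randint(-3, 3) for _ in range(c)] for _ in range(r)]


random.seed(1)
Q, KT, V = rand_mat(32, 64), rand_mat(64, 32), rand_mat(32, 64)

# Q @ K^T: split the head_dim contraction; both passes land in acc[0].
acc0 = matmul(cols(Q, 0, 32), KT[:32])                        # fresh
acc0 = [[a + b for a, b in zip(r1, r2)]                       # accumulate
        for r1, r2 in zip(acc0, matmul(cols(Q, 32, 64), KT[32:]))]

# exp_s @ V: split the output columns; the halves are independent (acc[0], acc[1]).
P = acc0  # any [32, 32] matrix stands in for exp_s here
left, right = matmul(P, cols(V, 0, 32)), matmul(P, cols(V, 32, 64))
```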
bf16 → fp8 quantization uses the accumulator roundtrip (vmatpush.acc.bf16 followed by vmatpop.fp8):
  vmatpush.acc.bf16(vd=slot, vs1=v) reads v and v+1 as one [32,32] bf16 tile into acc[slot]
  vmatpop.fp8(vd=dst, vs1=slot) converts acc[slot] [32,32] to fp8 and writes it to register dst [32,32]
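A hypothetical Python model of the pairing convention, useful when reading the register maps (for example exp_s in v22+v23 becoming fp8 in v24). It assumes register v holds columns 0:16 and v+1 columns 16:32 of the [32,32] tile, by analogy with vmatpop.bf16 writing the left half to v16 and the right half to v17. The function names only mirror the mnemonics, and quantize is a placeholder because the fp8 format is not specified here:

```python
def vmatpush_acc_bf16(regs, acc, slot, vs1):
    """Model: read v{vs1} (cols 0:16) and v{vs1+1} (cols 16:32) as one [32,32] tile."""
    acc[slot] = [left + right for left, right in zip(regs[vs1], regs[vs1 + 1])]


def vmatpop_bf16(regs, acc, vd, slot):
    """Model: split acc[slot] into v{vd} (cols 0:16) and v{vd+1} (cols 16:32)."""
    regs[vd] = [row[:16] for row in acc[slot]]
    regs[vd + 1] = [row[16:] for row in acc[slot]]


def vmatpop_fp8(regs, acc, vd, slot, quantize):
    """Model: write acc[slot] to the single fp8 register v{vd}, element-wise quantized."""
    regs[vd] = [[quantize(x) for x in row] for row in acc[slot]]


# exp_s roundtrip from the register map: v22+v23 (bf16) -> acc[1] -> v24 (fp8)
tile = [[float(r * 32 + c) for c in range(32)] for r in range(32)]
regs = {22: [row[:16] for row in tile], 23: [row[16:] for row in tile]}
acc = {}
vmatpush_acc_bf16(regs, acc, slot=1, vs1=22)
vmatpop_fp8(regs, acc, vd=24, slot=1, quantize=round)  # round = placeholder quantizer
vmatpop_bf16(regs, acc, vd=25, slot=1)                 # bf16 pop back into a pair
```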
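-inf initialization: m (the running row-max) is initialized to -100.0 via vli.all imm=-100. The immediate is an integer (vli.all calls torch.full(shape, -100, dtype=bf16), which gives -100.0 in bf16), and true bf16 -inf (0xFF80) cannot be encoded as a vli.all integer immediate. -100.0 is a safe stand-in: on the first tile O and l are still 0, so α = exp(m − m') only multiplies zeros, and after that tile m equals the true row-max as long as that max exceeds -100. With SCALE ≈ 0.125, a scaled logit below -100 would require a raw dot product q·k below -800, which is not expected for real attention inputs, so in practice the first tile always overwrites the initial m.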
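The bit patterns and the logit bound can be checked with the standard library. A sketch, assuming bf16 is the upper 16 bits of an IEEE float32 (the standard bf16 definition) and SCALE = 1/√64; the helper names are hypothetical:

```python
import math
import struct


def bf16_bits(x):
    """bf16 encoding of x as the top 16 bits of its float32 encoding.

    Exact for values bf16 can represent; other values are truncated, not rounded.
    """
    (u32,) = struct.unpack("<I", struct.pack("<f", x))
    return u32 >> 16


def bf16_to_float(bits):
    """Decode a 16-bit bf16 pattern by widening it to float32."""
    (x,) = struct.unpack("<f", struct.pack("<I", bits << 16))
    return x


SCALE = 1 / math.sqrt(64)          # head_dim = 64
print(hex(bf16_bits(-math.inf)))   # 0xff80
print(hex(bf16_bits(-100.0)))      # 0xc2c8
print(-100 / SCALE)                # -800.0: raw q.k needed for a scaled logit of -100
```

-100.0 round-trips exactly (bf16_to_float(0xC2C8) == -100.0), so the initial m is precisely the intended value.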
SmolVLAFusedAttentionProgram
Bases: Program
Flash attention: output[b,q,h] = softmax(Q[b,q,:] @ K[b,:,:]^T * scale) @ V[b,h,:], with V in the MLIR layout [batch, head_dim, k_seq]
One Q-block (32 rows), two K tiles (k_seq=64), head_dim=64. Inputs are bf16 (matching MLIR). Quantized to fp8 on-chip via acc roundtrip.
MRF register map (simulator: each fp8 reg = [32,32], bf16 reg = [32,16])
────────────────
Persistent (survive across K tiles):
  v0  Q_lo    fp8  [32,32]  Q[:, 0:32] after bf16→fp8 roundtrip
  v1  Q_hi    fp8  [32,32]  Q[:, 32:64] after bf16→fp8 roundtrip
  v2  m_prev  bf16 [32,16]  running row-max; init = -100.0 (see the -inf note in the module docstring)
  v3  l_prev  bf16 [32,16]  running row-sum; init = 0
  v4  O_col0  bf16 [32,16]  O[:, 0:16]
  v5  O_col1  bf16 [32,16]  O[:, 16:32]
  v6  O_col2  bf16 [32,16]  O[:, 32:48]
  v7  O_col3  bf16 [32,16]  O[:, 48:64]
  v8  scale   bf16 [32,16]
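Per K tile (after load + quantize):
  v9   KT_top    fp8 [32,32]  K^T[0:32, :]     (vmatpush.acc.bf16 reads v9+v10 as a pair)
  v11  KT_bot    fp8 [32,32]  K^T[32:64, :]    (vmatpush.acc.bf16 reads v11+v12 as a pair)
  v12  VT_left   fp8 [32,32]  V_std[:, 0:32]   (vmatpush.acc.bf16 reads v12+v13 as a pair)
  v14  VT_right  fp8 [32,32]  V_std[:, 32:64]  (vmatpush.acc.bf16 reads v14+v15 as a pair)
  v12 is shared: it holds the right half of KT_bot's bf16 pair and later VT_left, so KT_bot must be quantized into v11 before VT_left's bf16 data is loaded into v12/v13.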
Temporaries (reused each iteration):
  v16,v17  scores sl,sr         bf16 [32,16]  (written as a pair by vmatpop.bf16 from acc[0])
  v18,v19  scaled sl,sr         bf16 [32,16]
  v20      tile_max / exp_diff  bf16 [32,16]
  v21      m_new                bf16 [32,16]
  v22,v23  exp_s left,right     bf16 [32,16]
  v24      exp_s_fp8            fp8  [32,32]  (acc[1] roundtrip; vmatpush reads v22+v23 as a pair)
  v25,v26  vc_left col0,col1    bf16 [32,16]  (written as a pair by vmatpop.bf16 from acc[0])
  v27,v28  vc_right col0,col1   bf16 [32,16]  (written as a pair by vmatpop.bf16 from acc[1])
Scalar register map
───────────────────
  VMEM:  x1=Q  x2=KT0  x3=KT1  x4=VT0  x5=VT1  x6=SCALE  x7=OUT
  DRAM:  x9=KT0  x10=KT1  x11=VT0  x12=VT1  x13=SCALE  x14=OUT
  Sizes: x15=4096 (bytes in one [128,16] bf16 tile)  x16=1024 (bytes in one [32,16] bf16 scale block)
fused_attention_reference(Q, K, V_mlir, scale)
Exact SDPA matching the MLIR affine maps. Returns [q_rows, head_dim] float.
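A pure-Python sketch of what such a reference computes for a single head, assuming V is passed in the MLIR layout [head_dim, k_seq] (the V_mlir parameter name suggests this; the real function may expect batched tensors). fused_attention_reference_sketch is a hypothetical stand-in, not the module's function:

```python
import math


def fused_attention_reference_sketch(Q, K, V_mlir, scale):
    """Single-head SDPA on nested lists.

    Q: [q_rows, head_dim], K: [k_seq, head_dim],
    V_mlir: [head_dim, k_seq] (MLIR layout, map (d0, d2, d4) without the batch dim).
    Returns [q_rows, head_dim].
    """
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]  # one logit per key
        mx = max(s)                                                # stabilize exp
        p = [math.exp(x - mx) for x in s]
        z = sum(p)
        # output column h contracts over k_seq along row h of V_mlir
        out.append([sum(pj * vh[j] for j, pj in enumerate(p)) / z for vh in V_mlir])
    return out


# Equal logits: each output column is the mean of the matching V_mlir row.
print(fused_attention_reference_sketch(
    [[0.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]], [[1.0, 3.0], [10.0, 20.0]], 0.125))
# [[2.0, 15.0]]
```

Indexing V_mlir by row h, rather than transposing it first, mirrors the (d0, d2, d4) map directly.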