FlashAttention

# Example: N = d = 4, d_v = 2, scale = 1/sqrt(d) = 1/2

Q =
[ [1, 0, 2, 0],
  [0, 1, 1, 0],
  [1, 1, 0, 1],
  [0, 2, 1, 1] ]

K =
[ [2, 1, 0, 0],
  [0, 1, 1, 0],
  [1, 0, 1, 1],
  [0, 0, 2, 1] ]

V =
[ [1, 0],
  [0, 1],
  [1, 1],
  [2, 1] ]


========================
1) NORMAL ATTENTION
========================

Scores: S = (Q K^T) / 2  →  S ∈ R^{4×4}
S =
[ [1.0, 1.0, 1.5, 2.0],
  [0.5, 1.0, 0.5, 1.0],
  [1.5, 0.5, 1.0, 0.5],
  [1.0, 1.5, 1.0, 1.5] ]

Row-wise softmax P = softmax_row(S)  (decimals rounded to 4 places)
For row 1: exp([1.0, 1.0, 1.5, 2.0]) ≈ [2.7183, 2.7183, 4.4817, 7.3891], whose sum is 17.3073; dividing by that sum gives the first row of P.
P ≈
[ [0.1571, 0.1571, 0.2589, 0.4269],
  [0.1888, 0.3112, 0.1888, 0.3112],
  [0.4269, 0.1571, 0.2589, 0.1571],
  [0.1888, 0.3112, 0.1888, 0.3112] ]

Output: O = P V  (rounded to 4 places)
O ≈
[ [1.2699, 0.8429],
  [1.0000, 0.8112],
  [1.0000, 0.5731],
  [1.0000, 0.8112] ]
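
To check these numbers, here is a minimal NumPy sketch of the standard attention path (a plain softmax over the full score matrix; an illustration of this example, not the FlashAttention kernel):

import numpy as np

# Inputs from the example above: N = d = 4, d_v = 2.
Q = np.array([[1., 0., 2., 0.],
              [0., 1., 1., 0.],
              [1., 1., 0., 1.],
              [0., 2., 1., 1.]])
K = np.array([[2., 1., 0., 0.],
              [0., 1., 1., 0.],
              [1., 0., 1., 1.],
              [0., 0., 2., 1.]])
V = np.array([[1., 0.],
              [0., 1.],
              [1., 1.],
              [2., 1.]])

scale = 1.0 / np.sqrt(Q.shape[1])             # 1/sqrt(d) = 0.5
S = (Q @ K.T) * scale                         # full 4x4 score matrix
P = np.exp(S - S.max(axis=1, keepdims=True))  # shift each row for numerical stability
P = P / P.sum(axis=1, keepdims=True)          # row-wise softmax
O = P @ V                                     # 4x2 output
print(np.round(P, 4))                         # matches P above
print(np.round(O, 4))                         # matches O above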


========================================
2) FLASHATTENTION TILING (same numbers)
========================================

Choose tile sizes: b_q = 2 (query rows per tile), b_k = 2 (key/value rows per tile)

Query tiles:
Q_{A1} = Q[0:2,:] =
[ [1, 0, 2, 0],
  [0, 1, 1, 0] ]

Q_{A2} = Q[2:4,:] =
[ [1, 1, 0, 1],
  [0, 2, 1, 1] ]

Key/Value tiles (paired rows):
K_{B1} = K[0:2,:] =
[ [2, 1, 0, 0],
  [0, 1, 1, 0] ]
V_{B1} = V[0:2,:] =
[ [1, 0],
  [0, 1] ]

K_{B2} = K[2:4,:] =
[ [1, 0, 1, 1],
  [0, 0, 2, 1] ]
V_{B2} = V[2:4,:] =
[ [1, 1],
  [2, 1] ]

Local score blocks S_{A,B} = (Q_A K_B^T) / 2:

S_{A1,B1} =
[ [1.0, 1.0],
  [0.5, 1.0] ]

S_{A1,B2} =
[ [1.5, 2.0],
  [0.5, 1.0] ]

S_{A2,B1} =
[ [1.5, 0.5],
  [1.0, 1.5] ]

S_{A2,B2} =
[ [1.0, 0.5],
  [1.0, 1.5] ]

(Placing these 2×2 tiles into the 4×4 grid reconstructs the full S above.
FlashAttention streams these tiles and uses an online softmax; no full 4×4 S or P is materialized.)
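
Below is a rough NumPy sketch of that streaming pass (the online-softmax recurrence with these tile sizes, not the fused CUDA kernel; Q, K, V are the same arrays as in Section 1). Each query tile carries a running row max m, a running softmax denominator l, and an unnormalized output accumulator acc; whenever a new key/value tile raises the row max, the old statistics are rescaled by exp(m_old - m_new):

import numpy as np

def flash_attention(Q, K, V, b_q=2, b_k=2):
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, V.shape[1]))
    for i in range(0, N, b_q):                      # query tiles A1, A2
        Qi = Q[i:i+b_q]
        m = np.full((Qi.shape[0], 1), -np.inf)      # running row max
        l = np.zeros((Qi.shape[0], 1))              # running softmax denominator
        acc = np.zeros((Qi.shape[0], V.shape[1]))   # unnormalized output
        for j in range(0, N, b_k):                  # stream key/value tiles B1, B2
            S_ij = (Qi @ K[j:j+b_k].T) * scale      # local 2x2 score tile
            m_new = np.maximum(m, S_ij.max(axis=1, keepdims=True))
            r = np.exp(m - m_new)                   # rescale factor for old stats
            p = np.exp(S_ij - m_new)                # local shifted exponentials
            l = l * r + p.sum(axis=1, keepdims=True)
            acc = acc * r + p @ V[j:j+b_k]
            m = m_new
        O[i:i+b_q] = acc / l                        # normalize once per query tile
    return O

Running flash_attention(Q, K, V) on the example's arrays reproduces O from Section 1 (first row [1.2699, 0.8429], and so on) without ever forming the full 4x4 S or P.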
