FlashAttention: a worked 4×4 example
# Example: N = 4, d = 4, d_v = 2, scale = 1/sqrt(d) = 1/2
Q =
[ [1, 0, 2, 0],
[0, 1, 1, 0],
[1, 1, 0, 1],
[0, 2, 1, 1] ]
K =
[ [2, 1, 0, 0],
[0, 1, 1, 0],
[1, 0, 1, 1],
[0, 0, 2, 1] ]
V =
[ [1, 0],
[0, 1],
[1, 1],
[2, 1] ]
========================
1) NORMAL ATTENTION
========================
Scores: S = (Q K^T) / 2 → S ∈ R^{4×4}
S =
[ [1.0, 1.0, 1.5, 2.0],
[0.5, 1.0, 0.5, 1.0],
[1.5, 0.5, 1.0, 0.5],
[1.0, 1.5, 1.0, 1.5] ]
Row-wise softmax P = softmax_row(S) (rounded to 4 decimals)
P ≈
[ [0.1571, 0.1571, 0.2589, 0.4269],
[0.1888, 0.3112, 0.1888, 0.3112],
[0.4269, 0.1571, 0.2589, 0.1571],
[0.1888, 0.3112, 0.1888, 0.3112] ]
(Rows 2 and 4 are identical because their score rows differ only by a constant shift of 0.5, which softmax cancels.)
Output: O = P V (rounded to 4 decimals)
O ≈
[ [1.2699, 0.8429],
[1.0000, 0.8112],
[1.0000, 0.5731],
[1.0000, 0.8112] ]
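The numbers above can be reproduced with a few lines of NumPy; this is a minimal sketch of the dense computation, with the matrices and the 1/2 scale taken straight from the example:

```python
import numpy as np

# Q, K, V exactly as in the example (N = 4, d = 4, d_v = 2)
Q = np.array([[1, 0, 2, 0],
              [0, 1, 1, 0],
              [1, 1, 0, 1],
              [0, 2, 1, 1]], dtype=float)
K = np.array([[2, 1, 0, 0],
              [0, 1, 1, 0],
              [1, 0, 1, 1],
              [0, 0, 2, 1]], dtype=float)
V = np.array([[1, 0],
              [0, 1],
              [1, 1],
              [2, 1]], dtype=float)

scale = 1.0 / np.sqrt(Q.shape[1])             # 1/sqrt(d) = 1/2
S = Q @ K.T * scale                           # 4x4 score matrix
P = np.exp(S - S.max(axis=1, keepdims=True))  # numerically stable row softmax
P /= P.sum(axis=1, keepdims=True)
O = P @ V                                     # 4x2 attention output
print(np.round(O, 4))
```

Subtracting the row max before exponentiating does not change the softmax result, but keeps the exponentials from overflowing; the same trick is what the online softmax below generalizes.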
========================================
2) FLASHATTENTION TILING (same numbers)
========================================
Choose tile sizes: b_q = 2 (query rows per tile), b_k = 2 (key/value rows per tile)
Query tiles:
Q_{A1} = Q[0:2,:] =
[ [1, 0, 2, 0],
[0, 1, 1, 0] ]
Q_{A2} = Q[2:4,:] =
[ [1, 1, 0, 1],
[0, 2, 1, 1] ]
Key/Value tiles (paired rows):
K_{B1} = K[0:2,:] =
[ [2, 1, 0, 0],
[0, 1, 1, 0] ]
V_{B1} = V[0:2,:] =
[ [1, 0],
[0, 1] ]
K_{B2} = K[2:4,:] =
[ [1, 0, 1, 1],
[0, 0, 2, 1] ]
V_{B2} = V[2:4,:] =
[ [1, 1],
[2, 1] ]
Local score blocks S_{A,B} = (Q_A K_B^T) / 2:
S_{A1,B1} =
[ [1.0, 1.0],
[0.5, 1.0] ]
S_{A1,B2} =
[ [1.5, 2.0],
[0.5, 1.0] ]
S_{A2,B1} =
[ [1.5, 0.5],
[1.0, 1.5] ]
S_{A2,B2} =
[ [1.0, 0.5],
[1.0, 1.5] ]
(Placing these 2×2 tiles into the 4×4 grid reconstructs the full S above.
FlashAttention streams these tiles and uses an online softmax; no full 4×4 S or P is materialized.)
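The streaming pass can be sketched in NumPy, assuming the tile sizes above (b_q = b_k = 2). The running statistics m (row-wise max), l (softmax denominator) and acc (unnormalized output) follow the online-softmax recurrence; those variable names are mine, and this is an illustration of the idea, not the actual fused GPU kernel:

```python
import numpy as np

Q = np.array([[1, 0, 2, 0], [0, 1, 1, 0], [1, 1, 0, 1], [0, 2, 1, 1]], dtype=float)
K = np.array([[2, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 1], [0, 0, 2, 1]], dtype=float)
V = np.array([[1, 0], [0, 1], [1, 1], [2, 1]], dtype=float)

scale = 0.5                      # 1/sqrt(d)
b_q = b_k = 2                    # tile sizes from the example
N, d_v = Q.shape[0], V.shape[1]
O = np.zeros((N, d_v))

for qs in range(0, N, b_q):                  # one pass per query tile (A1, A2)
    Qt = Q[qs:qs + b_q]
    m = np.full(b_q, -np.inf)                # running row-wise max
    l = np.zeros(b_q)                        # running softmax denominator
    acc = np.zeros((b_q, d_v))               # running unnormalized output
    for ks in range(0, N, b_k):              # stream over key/value tiles (B1, B2)
        St = Qt @ K[ks:ks + b_k].T * scale   # local 2x2 score block S_{A,B}
        m_new = np.maximum(m, St.max(axis=1))
        alpha = np.exp(m - m_new)            # rescales stats accumulated so far
        Pt = np.exp(St - m_new[:, None])     # local exponentials under the new max
        l = alpha * l + Pt.sum(axis=1)
        acc = alpha[:, None] * acc + Pt @ V[ks:ks + b_k]
        m = m_new
    O[qs:qs + b_q] = acc / l[:, None]        # normalize once per query tile

print(np.round(O, 4))
```

Only 2×2 score blocks ever exist at once; the single division by l at the end makes the result match the dense softmax output up to floating-point error.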