FlashAttention
# Example: N = d = 4, d_v = 2, scale = 1/sqrt(d) = 1/2

Q = [ [1, 0, 2, 0],
      [0, 1, 1, 0],
      [1, 1, 0, 1],
      [0, 2, 1, 1] ]

K = [ [2, 1, 0, 0],
      [0, 1, 1, 0],
      [1, 0, 1, 1],
      [0, 0, 2, 1] ]

V = [ [1, 0],
      [0, 1],
      [1, 1],
      [2, 1] ]

========================
1) NORMAL ATTENTION
========================

Scores: S = (Q K^T) / 2,  S ∈ R^{4×4}

S = [ [1.0, 1.0, 1.5, 2.0],
      [0.5, 1.0, 0.5, 1.0],
      [1.5, 0.5, 1.0, 0.5],
      [1.0, 1.5, 1.0, 1.5] ]

Row-wise softmax: P = softmax_row(S)  (decimals rounded to 4 places)

P ≈ [ [0.1571, 0.1571, 0.2589, 0.4269],
      [0.1888, 0.3112, 0.1888, 0.3112],
      [0.4269, 0.1571, 0.2589, 0.1571],
      [0.1888, 0.3112, 0.1888, 0.3112] ]

Output: O = P V  (approximate)

O ≈ [ [1.2699, 0.8429],
      [1.0000, 0.8112],
      [1.0000, 0.5731],
      [1.0000, 0.8112] ]

========================================
2) FLASHATTENTION TILING (same numbers)
========================================...
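The normal-attention numbers above can be reproduced with a short NumPy check (a minimal sketch; the max-subtraction inside the softmax is the standard numerical-stability trick, not something the worked example needs at these magnitudes):

```python
import numpy as np

# Q, K, V from the worked example: N = d = 4, d_v = 2, scale = 1/sqrt(d) = 1/2.
Q = np.array([[1, 0, 2, 0], [0, 1, 1, 0], [1, 1, 0, 1], [0, 2, 1, 1]], float)
K = np.array([[2, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 1], [0, 0, 2, 1]], float)
V = np.array([[1, 0], [0, 1], [1, 1], [2, 1]], float)

def attention(Q, K, V):
    """Standard (fully materialized) attention: O = softmax(Q K^T / sqrt(d)) V."""
    d = Q.shape[1]
    S = Q @ K.T / np.sqrt(d)                  # N x N score matrix
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)         # row-wise softmax
    return P @ V

O = attention(Q, K, V)
print(np.round(O, 4))
# Row 0 comes out to [1.2699, 0.8429], rows 1 and 3 to [1.0, 0.8112].
```

Note that this version builds the full N×N matrices S and P in memory; avoiding that is exactly what the tiling in section 2 is for.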
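The tiled pass can be sketched on the same numbers. This is an illustrative implementation of the online-softmax recurrence, assuming a block size B = 2 (so K and V are streamed in two tiles); m is the running row max, l the running softmax denominator, and O the unnormalized output accumulator, all rescaled by exp(m_old - m_new) whenever a new tile raises the max:

```python
import numpy as np

Q = np.array([[1, 0, 2, 0], [0, 1, 1, 0], [1, 1, 0, 1], [0, 2, 1, 1]], float)
K = np.array([[2, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 1], [0, 0, 2, 1]], float)
V = np.array([[1, 0], [0, 1], [1, 1], [2, 1]], float)

def flash_attention(Q, K, V, B=2):
    """Tiled attention with online softmax; never materializes the N x N matrix."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, V.shape[1]))      # unnormalized output accumulator
    m = np.full(N, -np.inf)            # running row max of the scores
    l = np.zeros(N)                    # running softmax denominator
    for j in range(0, N, B):           # stream over K/V tiles of size B
        Kj, Vj = K[j:j + B], V[j:j + B]
        S = Q @ Kj.T * scale                      # N x B score tile
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(m - m_new)                 # rescale old statistics
        P = np.exp(S - m_new[:, None])
        l = alpha * l + P.sum(axis=1)
        O = alpha[:, None] * O + P @ Vj
        m = m_new
    return O / l[:, None]              # normalize once at the end

print(np.round(flash_attention(Q, K, V), 4))
# Identical to the normal-attention output, e.g. row 0 ≈ [1.2699, 0.8429].
```

The point of the rescaling is that each tile's contribution is added as if the running max were final; when a later tile produces a larger score, everything accumulated so far is multiplied by exp(m_old - m_new), which is exactly the correction the true softmax would have applied.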