
Commit 4526e64

June24-Wu and June authored
[New Question] Attention with Linear Biases (Medium) (#78)
Co-authored-by: June <[email protected]>
1 parent 596424f commit 4526e64


6 files changed: +262 −0 lines changed
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
<p>
Implement Attention with Linear Biases (ALiBi), following the method described in
<a href="https://arxiv.org/pdf/2108.12409" target="_blank">
"Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
</a>, for a given set of matrices.
Given the query matrix <code>Q</code> of size <code>M×d</code>, the key matrix <code>K</code> of size <code>N×d</code>, and the value matrix
<code>V</code> of size <code>N×d</code>, your program should compute the output matrix using the formula:
</p>

<p>
$$
\text{Attention}_{\text{ALiBi}}(Q, K, V) = \text{softmax}\Bigl( \frac{QK^T}{\sqrt{d}} + \alpha \cdot \Delta \Bigr)V
$$
</p>

<p>
where &alpha; is a scalar slope controlling the strength of the linear bias and &Delta; is the <code>M×N</code> matrix of relative positions with entries <code>&Delta;<sub>ij</sub> = i - j</code> for query index <code>i</code> and key index <code>j</code>.
The softmax function is applied row-wise. <code>Q</code>, <code>K</code>, <code>V</code>, <code>output</code>, and <code>&alpha;</code> are all of data type <code>float32</code>;
<code>M</code>, <code>N</code>, and <code>d</code> are of data type <code>int32</code>.
</p>
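
<p>
For intuition, a minimal PyTorch sketch of the formula above; the helper name <code>alibi_attention</code> is illustrative only (it is not the required <code>solve</code> entry point) and it assumes <code>Q</code>, <code>K</code>, <code>V</code> are <code>float32</code> tensors of shapes <code>M×d</code>, <code>N×d</code>, and <code>N×d</code>:
</p>
<pre><code>import torch

# Builds the M×N relative-position matrix Delta and applies the biased, row-wise softmax.
def alibi_attention(Q, K, V, alpha):
    M, N = Q.shape[0], K.shape[0]
    d = Q.shape[1]
    scores = Q @ K.t() / d ** 0.5                                        # (M, N)
    delta = (torch.arange(M, device=Q.device).unsqueeze(1)
             - torch.arange(N, device=Q.device).unsqueeze(0))            # delta[i, j] = i - j
    weights = torch.softmax(scores + alpha * delta, dim=1)               # row-wise softmax
    return weights @ V                                                   # (M, d)
</code></pre>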

<h2>Implementation Requirements</h2>
<ul>
<li>Use only native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The final result must be stored in the output matrix <code>output</code></li>
</ul>

<h2>Example 1:</h2>
<p>
<strong>Input:</strong><br>
<code>Q</code> (2×4):
\[
\begin{bmatrix}
1.0 & 0.0 & 0.0 & 0.0 \\
0.0 & 1.0 & 0.0 & 0.0
\end{bmatrix}
\]
<code>K</code> (3×4):
\[
\begin{bmatrix}
1.0 & 0.0 & 0.0 & 0.0 \\
0.0 & 1.0 & 0.0 & 0.0 \\
0.0 & 0.0 & 1.0 & 0.0
\end{bmatrix}
\]
<code>V</code> (3×4):
\[
\begin{bmatrix}
1.0 & 2.0 & 3.0 & 4.0 \\
5.0 & 6.0 & 7.0 & 8.0 \\
9.0 & 10.0 & 11.0 & 12.0
\end{bmatrix}
\]
\(\alpha = 0.5\)
</p>

<p>
<strong>Output:</strong><br>
<code>output</code> (2×4):
\[
\begin{bmatrix}
3.05 & 4.05 & 5.05 & 6.05 \\
3.93 & 4.93 & 5.93 & 6.93
\end{bmatrix}
\]
</p>
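
<p>
As a check of the first output row: the scaled scores are \( [1, 0, 0]/\sqrt{4} = [0.5, 0, 0] \), the bias for query row \( i = 0 \) is \( 0.5 \cdot [0, -1, -2] = [0, -0.5, -1] \), so the logits are \( [0.5, -0.5, -1] \), the softmax weights are approximately \( [0.629, 0.231, 0.140] \), and the corresponding weighted sum of the rows of <code>V</code> is approximately \( [3.05, 4.05, 5.05, 6.05] \).
</p>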

<h2>Example 2:</h2>
<p>
<strong>Input:</strong><br>
<code>Q</code> (1×2):
\[
\begin{bmatrix}
1.0 & 2.0
\end{bmatrix}
\]
<code>K</code> (2×2):
\[
\begin{bmatrix}
1.0 & 0.0 \\
0.0 & 1.0
\end{bmatrix}
\]
<code>V</code> (2×2):
\[
\begin{bmatrix}
3.0 & 4.0 \\
5.0 & 6.0
\end{bmatrix}
\]
\(\alpha = 0.8\)
</p>

<p>
<strong>Output:</strong><br>
<code>output</code> (1×2):
\[
\begin{bmatrix}
3.95 & 4.95
\end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>Matrix <code>Q</code> is of size <code>M×d</code> and matrices <code>K</code> and <code>V</code> are of size <code>N×d</code></li>
<li>1 &le; <code>M</code>, <code>N</code> &le; 2048</li>
<li>1 &le; <code>d</code> &le; 1024</li>
<li>-1.0 &le; <code>&alpha;</code> &le; 1.0</li>
</ul>
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
import ctypes
from typing import Any, List, Dict
import torch
from core.challenge_base import ChallengeBase

class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="Attention with Linear Biases",
            atol=1e-04,
            rtol=1e-04,
            num_gpus=1,
            access_tier="free"
        )

    def reference_impl(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, output: torch.Tensor, M: int, N: int, d: int, alpha: float):
        assert Q.shape == (M, d)
        assert K.shape == (N, d)
        assert V.shape == (N, d)
        assert output.shape == (M, d)

        # Scaled dot-product scores, shape (M, N)
        scale = d ** 0.5
        attn = torch.matmul(Q, K.t()) / scale

        # Linear bias: pos_bias[i, j] = alpha * (i - j)
        pos_bias = alpha * (torch.arange(M, device=Q.device).unsqueeze(1) - torch.arange(N, device=K.device).unsqueeze(0))
        attn = attn + pos_bias

        attn = torch.softmax(attn, dim=1)  # (M, N), softmax applied row-wise
        torch.matmul(attn, V, out=output)

    def get_solve_signature(self) -> Dict[str, Any]:
        return {
            "Q": ctypes.POINTER(ctypes.c_float),
            "K": ctypes.POINTER(ctypes.c_float),
            "V": ctypes.POINTER(ctypes.c_float),
            "output": ctypes.POINTER(ctypes.c_float),
            "M": ctypes.c_int,
            "N": ctypes.c_int,
            "d": ctypes.c_int,
            "alpha": ctypes.c_float,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        dtype = torch.float32
        Q = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], device="cuda", dtype=dtype)
        K = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]], device="cuda", dtype=dtype)
        V = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], device="cuda", dtype=dtype)
        output = torch.empty(2, 4, device="cuda", dtype=dtype)
        return {"Q": Q, "K": K, "V": V, "output": output, "M": 2, "N": 3, "d": 4, "alpha": 0.5}

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        dtype = torch.float32
        tests = []

        # basic_example 1
        tests.append({
            "Q": torch.tensor([[1.0, 2.0]], device="cuda", dtype=dtype),
            "K": torch.tensor([[1.0, 0.0], [0.0, 1.0]], device="cuda", dtype=dtype),
            "V": torch.tensor([[3.0, 4.0], [5.0, 6.0]], device="cuda", dtype=dtype),
            "output": torch.empty(1, 2, device="cuda", dtype=dtype),
            "M": 1, "N": 2, "d": 2, "alpha": 0.8
        })

        # basic_example 2
        tests.append({
            "Q": torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], device="cuda", dtype=dtype),
            "K": torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]], device="cuda", dtype=dtype),
            "V": torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], device="cuda", dtype=dtype),
            "output": torch.empty(2, 4, device="cuda", dtype=dtype),
            "M": 2, "N": 3, "d": 4, "alpha": 0.5
        })

        # zero_matrices
        tests.append({
            "Q": torch.zeros((3, 5), device="cuda", dtype=dtype),
            "K": torch.zeros((3, 5), device="cuda", dtype=dtype),
            "V": torch.zeros((3, 5), device="cuda", dtype=dtype),
            "output": torch.empty(3, 5, device="cuda", dtype=dtype),
            "M": 3, "N": 3, "d": 5, "alpha": 0.5
        })

        # mixed_values
        tests.append({
            "Q": torch.tensor([[-1.0, 2.0, -3.0], [4.0, -5.0, 6.0], [-7.0, 8.0, -9.0], [10.0, -11.0, 12.0]], device="cuda", dtype=dtype),
            "K": torch.tensor([[2.0, -1.0, 3.0], [-4.0, 5.0, -6.0], [7.0, -8.0, 9.0], [-10.0, 11.0, -12.0]], device="cuda", dtype=dtype),
            "V": torch.tensor([[1.0, 0.5, -0.5], [-1.0, 2.0, 3.0], [4.0, -2.0, 1.0], [0.0, 1.0, -1.0]], device="cuda", dtype=dtype),
            "output": torch.empty(4, 3, device="cuda", dtype=dtype),
            "M": 4, "N": 4, "d": 3, "alpha": 1.0
        })

        # large_matrices
        tests.append({
            "Q": torch.empty((64, 32), device="cuda", dtype=dtype).uniform_(-0.1, 0.1),
            "K": torch.empty((128, 32), device="cuda", dtype=dtype).uniform_(-0.1, 0.1),
            "V": torch.empty((128, 32), device="cuda", dtype=dtype).uniform_(-0.1, 0.1),
            "output": torch.empty(64, 32, device="cuda", dtype=dtype),
            "M": 64, "N": 128, "d": 32, "alpha": -0.76
        })

        # different alpha
        tests.append({
            "Q": torch.empty((64, 32), device="cuda", dtype=dtype).uniform_(-1, 1),
            "K": torch.empty((128, 32), device="cuda", dtype=dtype).uniform_(-1, 1),
            "V": torch.empty((128, 32), device="cuda", dtype=dtype).uniform_(-1, 1),
            "output": torch.empty(64, 32, device="cuda", dtype=dtype),
            "M": 64, "N": 128, "d": 32, "alpha": -0.3
        })

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        dtype = torch.float32
        M, N, d = 2048, 2048, 1024
        Q = torch.empty((M, d), device="cuda", dtype=dtype).uniform_(-0.1, 0.1)
        K = torch.empty((N, d), device="cuda", dtype=dtype).uniform_(-0.1, 0.1)
        V = torch.empty((N, d), device="cuda", dtype=dtype).uniform_(-0.1, 0.1)
        output = torch.empty(M, d, device="cuda", dtype=dtype)
        return {"Q": Q, "K": K, "V": V, "output": output, "M": M, "N": N, "d": d, "alpha": 0.5}
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
#include <cuda_runtime.h>

// Q, K, V, output are device pointers
extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int M, int N, int d, float alpha) {

}
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
from gpu.host import DeviceContext
from gpu.id import block_dim, block_idx, thread_idx
from memory import UnsafePointer
from math import ceildiv

# Q, K, V, output are device pointers (i.e. pointers to memory on the GPU)
@export
def solve(Q: UnsafePointer[Float32], K: UnsafePointer[Float32], V: UnsafePointer[Float32],
          output: UnsafePointer[Float32], M: Int32, N: Int32, d: Int32, alpha: Float32):
    pass
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
import torch

# Q, K, V, output are tensors on the GPU
def solve(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, output: torch.Tensor,
          M: int, N: int, d: int, alpha: float):
    pass
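
One way this template could be completed, mirroring reference_impl in the challenge definition above; a sketch, not necessarily the intended or fastest solution, writing the result into the preallocated output tensor:

import torch

# Pure-PyTorch solve: scaled scores, ALiBi bias alpha * (i - j), row-wise softmax, then V.
def solve(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, output: torch.Tensor,
          M: int, N: int, d: int, alpha: float):
    scores = Q @ K.t() / d ** 0.5                                    # (M, N)
    delta = (torch.arange(M, device=Q.device).unsqueeze(1)
             - torch.arange(N, device=Q.device).unsqueeze(0))        # delta[i, j] = i - j
    weights = torch.softmax(scores + alpha * delta, dim=1)
    torch.matmul(weights, V, out=output)                             # store result in output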
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
import torch
import triton
import triton.language as tl

# Q, K, V, output are tensors on the GPU
def solve(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, output: torch.Tensor, M: int, N: int, d: int, alpha: float):
    pass
