
Operator Fusion, Using Attention as an Example

This post uses Attention as an example to walk through how fused operators are written.

Code

Writing PyTorch extensions in C++

Matrix Multiplication + Softmax Operator Fusion

Operator fusion merges several computation operations into a single one, reducing both the amount of computation and the number of memory accesses, and thereby improving efficiency. For example, fusing matrix multiplication with softmax saves a round trip to memory: the matmul result no longer has to be written out and read back just to compute the softmax.
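As a minimal illustration of the idea (a sketch only, not the implementation developed below): the unfused version writes the matmul result out and reads it back for the softmax, while a fused version produces and normalizes each output row in a single pass.

import torch

A, B = torch.rand(4, 3), torch.rand(3, 5)

# Unfused: the matmul output C is materialized, then read again by softmax.
C = A @ B
out_unfused = torch.softmax(C, dim=1)

# "Fused" sketch: each output row is computed and normalized in one pass,
# so the full intermediate matrix never needs to be revisited.
out_fused = torch.empty(A.shape[0], B.shape[1])
for i in range(A.shape[0]):
    row = A[i] @ B              # compute one output row
    row = row - row.max()       # numerically stable softmax of that row
    out_fused[i] = torch.exp(row) / torch.exp(row).sum()

assert torch.allclose(out_unfused, out_fused)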

Online softmax

The softmax formula is:

\[ \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}} \]

In practice, for numerical stability (to avoid overflow), the maximum is computed first and subtracted before exponentiating:

\[ \text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_{j=1}^n e^{x_j - \max(x)}} \]
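A quick sanity check of why the subtraction matters (a small sketch, not part of the derivation below): without it, exp overflows for moderately large inputs, while the shifted version stays finite and gives the same distribution.

import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])

# Naive softmax overflows: exp(1000) is inf in float32, so the result is nan.
naive = torch.exp(x) / torch.exp(x).sum()
print(naive)      # tensor([nan, nan, nan])

# Subtracting the max keeps the exponents in a safe range.
shifted = x - x.max()
stable = torch.exp(shifted) / torch.exp(shifted).sum()
print(stable)     # tensor([0.0900, 0.2447, 0.6652])
assert torch.allclose(stable, torch.softmax(x, dim=0))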

In a neural network the input is usually a matrix (not a single vector), so softmax is applied to each row. Example code:

import torch

X = torch.rand(4, 4)

def my_softmax(X, dim=1):
    # subtract the row max for numerical stability (does not change the result)
    X = X - torch.max(X, dim=dim, keepdim=True)[0]
    return torch.exp(X) / torch.sum(torch.exp(X), dim=dim, keepdim=True)

print(torch.softmax(X, dim=1))
print(my_softmax(X, dim=1))
assert torch.allclose(torch.softmax(X, dim=1), my_softmax(X, dim=1))

However, fusing this with a matrix multiplication brings no benefit yet: the two functions are still computed independently. Fusion requires interleaving the two computations, so the softmax has to be folded into the matrix multiplication itself. In other words, the computation must become iterative: the softmax result is updated incrementally as elements arrive.

Online normalizer calculation for softmax [2] provides exactly such a scheme. For simplicity, start with a single element. Given a list containing only the element [1], its softmax is computed as:

  1. exp: compute \(e^{1}\), giving \([e^{1}]\)
  2. max: compute \(\max(1)\), giving 1
  3. -max: compute \(e^{1-1}\), giving \([e^{1-1}]\)
  4. sum: sum up \(e^{1-1}\), giving \(e^{1-1}\)
  5. softmax: compute \([e^{1-1}] / (e^{1-1})\), giving \([1]\)

Now add one element and see what changes. Suppose the new element is 2, so the list becomes [1, 2]. The incremental softmax computation is:

  1. exp: compute \(e^{2}\), giving \(e^{2}\)
  2. max: compare with the previous \(\max(1)=1\), giving \(\max(1, 2) = 2\)
  3. -max: every element has to be updated by subtracting the new maximum \(\max(1, 2)=2\), giving \([e^{1-2}, e^{2-2}]\)
  4. sum: the sum has to be recomputed
  5. softmax: compute \([e^{1-2}, e^{2 - 2}] / (e^{1 - 2} + e^{2 - 2})\)

Comparing the two runs, step 3 forces step 4 to recompute the sum, but that sum only appears as the denominator in step 5, and it can be adjusted flexibly: it changes from \(e^{1-1}\) to \(e^{1-2}+e^{2-2}\). If the previous maximum is old_max, the new maximum is new_max, and an existing element is old_v, then the exp of the existing element can be updated as

\[ e^{old\_v-old\_max} \rightarrow e^{old\_v-old\_max+old\_max-new\_max} = e^{old\_v-old\_max} \times e^{old\_max-new\_max} \]

For the newly added element, \(e^{2-2}\) is simply added to the running sum. The code is as follows:

import numpy as np

nums = [1, 2]
max_v = float("-inf")  # running maximum
norm_v = 0.0           # running normalizer (softmax denominator)
for num in nums:
    old_max_v = max_v
    max_v = max(max_v, num)
    # rescale the old normalizer to the new max, then add the new term
    norm_v = norm_v * np.exp(old_max_v - max_v) + np.exp(num - max_v)
softmax_nums = [np.exp(num - max_v) / norm_v for num in nums]

The formula works for two elements; does it also hold for more? Naturally, yes:

\[ \begin{aligned} e^{v1-old\_max}+e^{v2-old\_max} &\rightarrow e^{v1-new\_max}+e^{v2-new\_max} \\ &= e^{v1-old\_max+old\_max-new\_max}+e^{v2-old\_max+old\_max-new\_max} \\ &= e^{v1-old\_max} \times e^{old\_max-new\_max} + e^{v2-old\_max} \times e^{old\_max-new\_max} \\ &= (e^{v1-old\_max} + e^{v2-old\_max}) \times e^{old\_max-new\_max} \end{aligned} \]

Extending this to a 2-D matrix (one online softmax per row) gives:

import torch
import numpy as np

X = torch.rand(4, 4)

def online_softmax(X):
    value = torch.zeros_like(X)
    for row in range(X.shape[0]):
        row_max = float("-inf")   # running maximum of the row
        normalizer_term = 0.0     # running softmax denominator
        for col in range(X.shape[1]):
            val = X[row, col].item()
            old_row_max = row_max
            row_max = max(old_row_max, val)
            # rescale the old denominator to the new max, then add the new term
            normalizer_term = normalizer_term * np.exp(old_row_max - row_max) + np.exp(val - row_max)
        value[row, :] = torch.exp(X[row, :] - row_max) / normalizer_term
    return value

print(torch.softmax(X, dim=1))
print(online_softmax(X))
assert torch.allclose(torch.softmax(X, dim=1), online_softmax(X))

Softmax + Matrix Multiplication

The usual fusion target is to compute the softmax of a matrix multiplication's result, i.e.

import torch

M, N, K = 4, 2, 4
A1, A2 = torch.rand(size=(M, N)), torch.rand(size=(N, K))
output = torch.softmax(A1 @ A2, dim=1)

Combining this with the online softmax above, it can be rewritten as:

import torch
import numpy as np

M, N, K = 4, 2, 4
A1, A2 = torch.rand(size=(M, N)), torch.rand(size=(N, K))

def matmul_softmax(A1, A2):
    output = torch.zeros(size=(A1.shape[0], A2.shape[1]))
    for i in range(A1.shape[0]):
        row_max = float("-inf")   # running maximum of the output row
        normalizer_term = 0.0     # running softmax denominator
        for j in range(A2.shape[1]):
            # one output element: dot product of row i of A1 and column j of A2
            output[i, j] = sum(map(lambda x: x[0] * x[1], zip(A1[i], A2[:, j])))
            val = output[i, j].item()

            old_row_max = row_max
            row_max = max(old_row_max, val)
            normalizer_term = normalizer_term * np.exp(old_row_max - row_max) + np.exp(val - row_max)
        output[i, :] = torch.exp(output[i, :] - row_max) / normalizer_term
    return output

print(torch.softmax(A1 @ A2, dim=1))
print(matmul_softmax(A1, A2))
assert torch.allclose(torch.softmax(A1 @ A2, dim=1), matmul_softmax(A1, A2))

Going one step further and combining this with tiled matrix multiplication, it becomes:

import torch
import numpy as np

M, N, K = 4, 6, 8
A1, A2 = torch.rand(size=(M, N)), torch.rand(size=(N, K))

def block_matmul(sub_A1, sub_A2):
    # naive matmul of two sub-blocks
    output = torch.zeros(size=(sub_A1.shape[0], sub_A2.shape[1]))
    for i in range(sub_A1.shape[0]):
        for j in range(sub_A2.shape[1]):
            for k in range(sub_A2.shape[0]):
                output[i][j] += sub_A1[i][k] * sub_A2[k][j]
    return output


def tiled_matmul_softmax(A1, A2):
    block_size_M, block_size_N, block_size_K = 2, 3, 4
    block_M, block_N, block_K = M // block_size_M, N // block_size_N, K // block_size_K

    output = torch.zeros(size=(A1.shape[0], A2.shape[1]))
    for i in range(0, A1.shape[0], block_M):
        start_i, end_i = i, i + block_M
        # per-slot running max and denominator; slot (ii, jj) tracks the elements
        # that land at column offset jj of each column block of row start_i + ii
        row_max = torch.full((block_M, block_N), float("-inf"))
        old_row_max = torch.zeros((block_M, block_N))
        normalizer_term = torch.zeros((block_M, block_N))

        for j in range(0, A2.shape[1], block_N):
            start_j, end_j = j, j + block_N
            for k in range(0, A2.shape[0], block_K):
                start_k, end_k = k, k + block_K
                sub_A1 = A1[start_i:end_i, start_k:end_k]
                sub_A2 = A2[start_k:end_k, start_j:end_j]
                output[start_i:end_i, start_j:end_j] += block_matmul(sub_A1, sub_A2)

            # a block of the output is now complete, so walk through its elements
            # and update the online-softmax statistics
            for ii, row in enumerate(range(start_i, end_i)):
                for jj, col in enumerate(range(start_j, end_j)):
                    val = output[row][col]
                    old_row_max[ii][jj] = row_max[ii][jj]
                    row_max[ii][jj] = max(old_row_max[ii][jj], val)
                    normalizer_term[ii][jj] = normalizer_term[ii][jj] * np.exp(old_row_max[ii][jj] - row_max[ii][jj]) + np.exp(val - row_max[ii][jj])

        for ii, row in enumerate(range(start_i, end_i)):
            row_max_v, _ = torch.max(row_max, dim=1)
            # recompute the sum: rescale each partial denominator by exp(old_max - new_max)
            sum_ = torch.sum(normalizer_term[ii] * torch.exp(row_max[ii] - row_max_v[ii]))
            output[row, :] = torch.exp(output[row, :] - row_max_v[ii]) / sum_
    return output


print(torch.softmax(A1 @ A2, dim=1))
print(tiled_matmul_softmax(A1, A2))
assert torch.allclose(torch.softmax(A1 @ A2, dim=1), tiled_matmul_softmax(A1, A2))

Because each pass now produces a whole block of the output, the block has to be unpacked element by element to update the per-row running maximum row_max and the denominator normalizer_term.

When the softmax is finally computed, the maximum across all blocks is taken with torch.max(row_max, dim=1), and the denominator must also account for all blocks. This is analogous to merging the normalizer_term of the two lists [1, 2] and [3, 4]: their individual normalizers are \(e^{1-2}+e^{2-2}\) and \(e^{3-4}+e^{4-4}\), and the merged result should be \(e^{1-4}+e^{2-4}+e^{3-4}+e^{4-4}\). Written as a formula:

\[ \begin{aligned} \sum_{i=1}^{m} e^{x_i-max(x)} + \sum_{i=1}^{n} e^{y_i-max(y)} & \rightarrow \sum_{i=1}^{m} e^{x_i-max(x,y)} + \sum_{i=1}^{n} e^{y_i-max(x,y)} \\ &= \sum_{i=1}^{m} e^{x_i-max(x)}e^{max(x)-max(x,y)} + \sum_{i=1}^{n} e^{y_i-max(y)}e^{max(y)-max(x,y)} \end{aligned} \]

If \(max(x,y)=max(x)\), the corresponding factor is simply 1 and the result is unaffected. The corresponding code is:

sum_ = torch.sum(normalizer_term[ii] * torch.exp(row_max[ii] - row_max_v[ii]))
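To make the merging rule concrete, here is a small standalone sketch (not taken from the fused code above) that combines the running (max, normalizer) pairs of two chunks and recovers the full softmax denominator:

import math

def chunk_stats(chunk):
    # (max, normalizer) of one chunk, as the online softmax would track them
    m = max(chunk)
    return m, sum(math.exp(x - m) for x in chunk)

m1, n1 = chunk_stats([1, 2])
m2, n2 = chunk_stats([3, 4])

# merge: rescale each partial normalizer to the global max, then add
m = max(m1, m2)
merged = n1 * math.exp(m1 - m) + n2 * math.exp(m2 - m)

full = sum(math.exp(x - max([1, 2, 3, 4])) for x in [1, 2, 3, 4])
assert math.isclose(merged, full)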

This completes the fusion of matrix multiplication and softmax. Next, a CUDA version is implemented to compare performance.

CUDA Implementation

Start from the CUDA matrix multiplication implemented in the previous post:

if (row < input1.size(0) && col < input2.size(1)) {
    scalar_t value = 0.0;
    for (int k = 0; k < input1.size(1); ++k) {
        value += input1[row][k] * input2[k][col];
    }
    output[row][col] = value;
}

After this loop, one element of the output matrix has been computed. Next, each row's row_max and normalizer_term have to be computed. Since this runs inside a CUDA thread block, shared memory is needed to exchange the per-row results between threads.

// use shared memory to compute the maximum of each row
__shared__ scalar_t row_max[16][16];
__shared__ scalar_t normalizer_term[16][16];
row_max[threadIdx.y][threadIdx.x] = value; // stage each thread's result in row_max so they can be compared
__syncthreads(); // make sure every thread has finished its computation before moving on

The computation has three steps (a short Python sketch of the reduction pattern used in steps 1 and 3 follows these snippets).

  1. Find the maximum of each row with a tree reduction:

for (int i = blockDim.x / 2; i > 0; i /= 2) {
    if (threadIdx.x < i) {
        row_max[threadIdx.y][threadIdx.x] = max(row_max[threadIdx.y][threadIdx.x], row_max[threadIdx.y][threadIdx.x + i]);
    }
    __syncthreads();
}
  2. Compute each element's contribution to the row's softmax denominator:
normalizer_term[threadIdx.y][threadIdx.x] = exp(value - row_max[threadIdx.y][0]);
__syncthreads();
  3. Sum the per-element contributions to obtain the row denominator, then write the normalized result:
for (int i = blockDim.x / 2; i > 0; i /= 2) {
    if (threadIdx.x < i) {
        normalizer_term[threadIdx.y][threadIdx.x] += normalizer_term[threadIdx.y][threadIdx.x + i];
    }
    __syncthreads();
}
// finally, write the normalized value back to the output matrix
output[row][col] = exp(value - row_max[threadIdx.y][0]) / normalizer_term[threadIdx.y][0];
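The halving-stride loops in steps 1 and 3 are standard parallel tree reductions. Below is a rough serial Python sketch of the same access pattern (illustration only, assuming a power-of-two row width; it is not the kernel itself):

import numpy as np

row = np.random.rand(16)     # one row held by 16 threads (blockDim.x = 16)
buf = row.copy()             # stands in for the shared-memory row

i = len(buf) // 2
while i > 0:
    for tid in range(i):     # threads 0..i-1 are active in this round
        buf[tid] = max(buf[tid], buf[tid + i])
    i //= 2                  # after log2(16) rounds, buf[0] holds the row max

assert np.isclose(buf[0], row.max())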

The complete code is below (this is not a full attention implementation, only a matrix multiplication + softmax):

#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

// Matrix multiply kernel
template <typename scalar_t>
__global__ void matrix_multiply_kernel(const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input1,
                                       const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input2,
                                       torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < input1.size(0) && col < input2.size(1)) {
        scalar_t value = 0.0;
        for (int k = 0; k < input1.size(1); ++k) {
            value += input1[row][k] * input2[k][col];
        }

        // use shared memory to compute the maximum of each row
        __shared__ scalar_t row_max[16][16];
        __shared__ scalar_t normalizer_term[16][16];
        row_max[threadIdx.y][threadIdx.x] = value;
        __syncthreads();

        for (int i = blockDim.x / 2; i > 0; i /= 2) {
            if (threadIdx.x < i) {
                row_max[threadIdx.y][threadIdx.x] = max(row_max[threadIdx.y][threadIdx.x], row_max[threadIdx.y][threadIdx.x + i]);
            }
            __syncthreads();
        }
        // each element's contribution to the row's softmax denominator
        normalizer_term[threadIdx.y][threadIdx.x] = exp(value - row_max[threadIdx.y][0]);

        __syncthreads();
        // sum the normalizer_term of each row
        for (int i = blockDim.x / 2; i > 0; i /= 2) {
            if (threadIdx.x < i) {
                normalizer_term[threadIdx.y][threadIdx.x] += normalizer_term[threadIdx.y][threadIdx.x + i];
            }
            __syncthreads();
        }

        // softmax of each row
        output[row][col] = exp(value - row_max[threadIdx.y][0]) / normalizer_term[threadIdx.y][0];
    }
}

torch::Tensor matrix_multiply(torch::Tensor input1, torch::Tensor input2) {
    int rows1 = input1.size(0);
    int cols1 = input1.size(1);
    int cols2 = input2.size(1);

    auto options = torch::TensorOptions().device(input1.device());
    torch::Tensor output = torch::zeros({rows1, cols2}, options);

    const dim3 threads(16, 16);
    const dim3 blocks((cols2 + threads.x - 1) / threads.x,
                      (rows1 + threads.y - 1) / threads.y);

    AT_DISPATCH_FLOATING_TYPES(input1.scalar_type(), "matrix_multiply_kernel", ([&] {
        matrix_multiply_kernel<<<blocks, threads>>>(
            input1.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            input2.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            output.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
    }));

    return output;
}


// lazily keeping the original function name here
std::vector<torch::Tensor> attention_cuda_forward(
        torch::Tensor q,
        torch::Tensor k,
        torch::Tensor v) {

    torch::Tensor scores = matrix_multiply(q, k);
    return {scores};
}

The test code is as follows:

import torch

import mulsoftmax
import timeit

seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

def my_py_softmax(x, dim):
    e = torch.exp(x)
    s = torch.sum(e, dim=dim, keepdim=True)
    return e / s

def py_mulsoft(q, k, v):
    # print(q @ k.T)
    return torch.softmax(q @ k.T, dim=1)

def check_forward(q, k, v):
    baseline_values = py_mulsoft(q, k, v)
    cpp_values = mulsoftmax.forward(q, k, v)[0]

    print("base o", baseline_values)
    print("cpp o", cpp_values)
    print(torch.all(torch.isclose(baseline_values, cpp_values)))

def compare_time(loop=100):
    q, k, v = torch.rand(size=(m, n), device=device), torch.rand(size=(m, n), device=device), torch.rand(size=(m, n), device=device)
    print("py", timeit.timeit(lambda: py_mulsoft(q, k, v), number=loop))
    print("cpp", timeit.timeit(lambda: mulsoftmax.forward(q, k, v)[0], number=loop))

if __name__ == "__main__":
    m, n = 16, 40
    device = "cuda"
    q, k, v = torch.rand(size=(m, n), device=device), torch.rand(size=(m, n), device=device), torch.rand(size=(m, n), device=device)
    # first check that the result is correct
    check_forward(q, k, v)
    q, k, v = torch.rand(size=(m, n)), torch.rand(size=(m, n)), torch.rand(size=(m, n))
    # loop 10,000 times to compare performance
    compare_time(10000)

The output is as follows:

run    py (s)   cuda (s)
0      0.5136   0.2909
1      0.6143   0.3322
2      0.7300   0.3608

Other Implementations

Matrix multiplication with the 2-D matrices flattened to 1-D

In some cases, to reduce space complexity, the 2-D matrices are flattened into 1-D arrays before the multiplication. Here the softmax computation is added on top of that. In essence the indexing simply changes from input1[row][k] to input1[row * K + k] and from input2[k][col] to input2[k * N + col]. Fusing the softmax is awkward in this layout, however, because softmax needs the per-row maximum and the flat array offers no direct way to compute it across the row. As a lazy shortcut, the kernel checks whether the current element is in the last column and, if so, computes the softmax for the whole row, which makes the computation fairly slow.
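Before looking at the CUDA version, here is a minimal Python sketch of this scheme (flattened indexing plus the last-column shortcut); it uses a plain shifted softmax for the row rather than the online update the kernel uses:

import torch

M, K, N = 4, 3, 5
A = torch.rand(M, K).reshape(-1)   # flattened (M, K)
B = torch.rand(K, N).reshape(-1)   # flattened (K, N)
out = torch.zeros(M * N)

for index in range(M * N):         # one "thread" per output element
    row, col = index // N, index % N
    value = sum(A[row * K + k] * B[k * N + col] for k in range(K))
    out[row * N + col] = value
    if col == N - 1:               # last column: the whole row is ready, so softmax it
        row_vals = out[row * N:(row + 1) * N]
        row_vals -= row_vals.max()
        out[row * N:(row + 1) * N] = torch.exp(row_vals) / torch.exp(row_vals).sum()

assert torch.allclose(
    out.reshape(M, N),
    torch.softmax(A.reshape(M, K) @ B.reshape(K, N), dim=1),
)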

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

#define BLOCK_SIZE 256

template <typename scalar_t>
__global__ void matrix_multiply_vector_kernel(const torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> input1,
                                              const torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> input2,
                                              torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> output,
                                              const int M, const int N, const int K
) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int row = index / N;
    int col = index % N;

    if (row < M && col < N) {
        float value = 0.0;
        for (int k = 0; k < K; ++k) {
            value += input1[row * K + k] * input2[k * N + col];
        }
        output[row * N + col] = value;
        // shortcut: the thread handling the last column runs the online softmax for the whole row
        if (col == N - 1) {
            float row_max = 0.0;
            float normalizer_term = 0.0;
            float old_row_max = 0.0;
            for (int i = 0; i < N; ++i) {
                old_row_max = row_max;
                row_max = max(row_max, output[row * N + i]);
                normalizer_term = normalizer_term * exp(old_row_max - row_max) + exp(output[row * N + i] - row_max);
            }
            for (int i = 0; i < N; ++i) {
                output[row * N + i] = exp(output[row * N + i] - row_max) / normalizer_term;
            }
        }
    }
}

std::vector<torch::Tensor> matmul_vector(torch::Tensor input1, torch::Tensor input2) {
    int M = input1.size(0);
    int K = input1.size(1);
    int N = input2.size(1);

    auto options = torch::TensorOptions().device(input1.device());

    const dim3 threads(BLOCK_SIZE);
    const dim3 blocks((M * N + threads.x - 1) / threads.x);

    // Reshape input tensors to vectors
    auto input1_vector = input1.reshape({-1});
    auto input2_vector = input2.reshape({-1});
    torch::Tensor output_vector = torch::zeros({M * N}, options);

    AT_DISPATCH_FLOATING_TYPES(input1_vector.scalar_type(), "matrix_multiply_vector_kernel", ([&] {
        matrix_multiply_vector_kernel<<<blocks, threads>>>(
            input1_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            input2_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            output_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            M, N, K
        );
    }));
    return {output_vector.reshape({M, N}), output_vector.reshape({M, N})};
}

std::vector<torch::Tensor> attention_cuda_forward(
        torch::Tensor q,
        torch::Tensor k,
        torch::Tensor v) {

    return matmul_vector(q, k);
}

The output is as follows:

run    py (s)   cuda (s)
0      0.4239   0.4907
1      0.5069   0.4388
2      0.5462   0.5799

Even though this takes a rather clumsy approach, its speed is actually almost on par with native PyTorch.

Implementing softmax as a separate kernel

As the non-fused counterpart, this implements softmax on the flattened matrix as its own kernel, i.e. the softmax runs only after the matrix multiplication has finished. A fairly naive way to write it...

template <typename scalar_t>
__global__ void softmax_kernel(torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> output,
                               const int M, const int N) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int row = index / N;

    if (row < M) {
        float row_max = 0.0;
        float normalizer_term = 0.0;
        float old_row_max = 0.0;

        for (int i = 0; i < N; ++i) {
            old_row_max = row_max;
            row_max = max(row_max, output[row * N + i]);
            normalizer_term = normalizer_term * exp(old_row_max - row_max) + exp(output[row * N + i] - row_max);
        }

        for (int i = 0; i < N; ++i) {
            output[row * N + i] = exp(output[row * N + i] - row_max) / normalizer_term;
        }
    }
}

torch::Tensor matrix_softmax_vector_softmax(torch::Tensor input1, torch::Tensor input2) {
    int M = input1.size(0);
    int K = input1.size(1);
    int N = input2.size(1);

    auto options = torch::TensorOptions().device(input1.device());

    const dim3 threads(BLOCK_SIZE_VECTOR);
    const dim3 blocks((M * N + threads.x - 1) / threads.x);

    // Reshape input tensors to vectors
    auto input1_vector = input1.reshape({-1});
    auto input2_vector = input2.reshape({-1});
    torch::Tensor output_vector = torch::zeros({M * N}, options);

    // plain 1-D matrix multiplication (no fused softmax)
    AT_DISPATCH_FLOATING_TYPES(input1_vector.scalar_type(), "matrix_multiply_vector_softmax_kernel", ([&] {
        matrix_multiply_vector_softmax_kernel<<<blocks, threads>>>(
            input1_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            input2_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            output_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            M, N, K
        );
    }));

    cudaDeviceSynchronize();

    AT_DISPATCH_FLOATING_TYPES(output_vector.scalar_type(), "softmax_kernel", ([&] {
        softmax_kernel<<<blocks, threads>>>(
            output_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            M, N
        );
    }));
    return output_vector.reshape({M, N});
}

The output is as follows:

run    py (s)   cuda (s)
0      0.4239   1.1874
1      0.5069   1.1101
2      0.5462   1.3238

Comparing the results, this CUDA version is still slower than PyTorch: it neither uses the built-in softmax nor fuses the operators, so the computation ends up slower.

A Fused-Operator Version of Attention

The Attention formula is shown below. One matrix multiplication + softmax has already been fused above; next, the result has to be fused with the second matrix multiplication as well.

def py_attention(q, k, v):
    return torch.softmax(q @ k.T, dim=1) @ v

Unlike before, where the matrix multiplication came first and the softmax second, here the softmax result is in turn multiplied by another matrix. Going back to the beginning, suppose the result of the matrix multiplication out = q @ k.T is

\[ [[1, 2, 3, 4], [1, 2, 3, 4]] \]

During the computation, the first half is typically produced before the second half, i.e.

\[ [[1, 2], [1, 2]] \quad \text{and then} \quad [[3, 4], [3, 4]] \]

First pass

Each row of out is processed independently, so simulating a single row (and a single output column) is enough to represent all of them. Assume the first column of v is all ones. The first pass handles

\[ [1, 2] \]

  1. \(v_1 = \max(1, 2)\)
  2. \(\text{denominator1} = \exp(1-v_1) + \exp(2-v_1)\)
  3. \([\exp(1-v_1) / \text{denominator1}, \exp(2-v_1) / \text{denominator1}]\)
  4. Then multiply by the first two entries of the first column of v (only two entries are involved here):

\[ [[1], [1]] \]

  5. \(\exp(1-v_1) / \text{denominator1} * 1 + \exp(2-v_1) / \text{denominator1} * 1\)

Second pass

\[ [3, 4] \]

  1. \(v_2 = \max(3, 4, v_1)\)
  2. \(\text{denominator2} = \text{denominator1} \times \exp(v_1-v_2)+d_t\), where \(d_t=\exp(3-v_2) + \exp(4-v_2)\); see Online Softmax above
  3. \([\exp(3-v_2) / \text{denominator2}, \exp(4-v_2) / \text{denominator2}]\)
  4. Then multiply by the last two entries of the first column of v:

\[ [[1], [1]] \]

  5. \(\exp(3-v_2) / \text{denominator2} * 1 + \exp(4-v_2) / \text{denominator2} * 1\)

So the result accumulated so far is

\[ \begin{aligned} \exp(1-v_1) / \text{denominator1} * 1 + \exp(2-v_1) / \text{denominator1} * 1 \\ + \exp(3-v_2) / \text{denominator2} * 1 + \exp(4-v_2) / \text{denominator2} * 1 \end{aligned} \]

However, the correct result should be

\[ \begin{aligned} & \exp(1-v_2) / \text{denominator2} * 1 + \exp(2-v_2) / \text{denominator2} * 1 \\ & + \exp(3-v_2) / \text{denominator2} * 1 + \exp(4-v_2) / \text{denominator2} * 1 \end{aligned} \]

Calibrating the earlier result

Comparing the two, only the earlier partial result is off, so only that part needs to be calibrated, i.e. multiplied by a coefficient. How is this coefficient computed? Taking the first term as an example, dividing the correct result by the previously computed one gives

\[ \begin{aligned} & \frac{\exp(1-v_1)}{\text{denominator1}}\times \frac{\text{denominator1}}{\text{denominator2}} \times \exp(v_1-v_2) = \frac{\exp(1-v_2)}{\text{denominator2}} \\ \end{aligned} \]

The same holds for the second term, so the coefficient is

\[ \frac{\text{denominator1}*\exp(v_1-v_2)}{\text{denominator2}} \]

Therefore, at each calibration step the accumulated result is first multiplied by this coefficient (the previous denominator, rescaled to the new maximum, divided by the current denominator), and step 5 is carried out afterwards.

During the computation, the denominator has to be kept around, and the previous maximum must also be stored so the calibration can be applied.
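As a quick numeric check of the calibration coefficient, here is a standalone sketch using the toy row [1, 2, 3, 4] from above (not the final implementation):

import math

row = [1, 2, 3, 4]

# first pass: columns [1, 2]
v1 = max(1, 2)
denominator1 = math.exp(1 - v1) + math.exp(2 - v1)
acc = math.exp(1 - v1) / denominator1 * 1 + math.exp(2 - v1) / denominator1 * 1

# second pass: columns [3, 4]
v2 = max(3, 4, v1)
denominator2 = denominator1 * math.exp(v1 - v2) + math.exp(3 - v2) + math.exp(4 - v2)

# calibrate the first-pass accumulator, then add the new contribution
acc *= denominator1 * math.exp(v1 - v2) / denominator2
acc += math.exp(3 - v2) / denominator2 * 1 + math.exp(4 - v2) / denominator2 * 1

# reference: full softmax of the row dotted with a column of ones
full = sum(math.exp(x - max(row)) for x in row)
expected = sum(math.exp(x - max(row)) / full * 1 for x in row)
assert math.isclose(acc, expected)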

Python Implementation

For clarity, the variable names used below are:

  • old_row_max: the previous maximum
  • row_max: the latest maximum
  • denominator: the previous denominator
  • mod_denominator: the previous denominator after calibration
  • new_molecule: the latest numerator
  • cur_denominator: the current denominator

So the calibration formula becomes

\[ \begin{aligned} \frac{\text{denominator1}*\exp(v_1-v_2)}{\text{denominator2}} &= \frac{\text{denominator} * \exp(\text{old\_row\_max}-\text{row\_max})}{\text{cur\_denominator}} \\ &= \frac{\text{mod\_denominator}}{\text{cur\_denominator}} \end{aligned} \]

import torch

M = 4
N = 6

q = torch.rand((M, N))
k = torch.rand((M, N))
v = torch.rand((M, N))

def flash_attention(q, k, v):
    output = torch.zeros(q.shape)

    block_m = 2
    block_n = 2
    block_head = N
    for i in range(0, M, block_m):
        start_i, end_i = i, i + block_m

        old_row_max = torch.zeros([block_m]) - float("inf")
        denominator = torch.zeros([block_m])  # stores the running denominator
        acc = torch.zeros([block_m, block_head])
        q_sub = q[start_i:end_i, :]
        for j in range(0, M, block_n):
            start_j, end_j = j, j + block_n
            k_sub = k[start_j:end_j, :]
            v_sub = v[start_j:end_j, :]
            qk = q_sub @ k_sub.T
            # online softmax
            row_max = torch.max(
                torch.stack((torch.max(qk, dim=1).values, old_row_max), dim=0), dim=0
            ).values
            mod_denominator = denominator * torch.exp(old_row_max - row_max)  # calibrate the previous denominator
            new_molecule = torch.exp(qk - row_max.reshape(-1, 1))  # the latest numerator
            cur_denominator = torch.sum(new_molecule, -1) + mod_denominator  # update the denominator

            # numerator / denominator: the softmax of the current block
            new_softmax = new_molecule / torch.unsqueeze(cur_denominator, dim=1)
            # calibration coefficient
            acc *= torch.unsqueeze(mod_denominator / cur_denominator, dim=1)
            acc += new_softmax @ v_sub
            # keep the values needed for the next iteration
            old_row_max = row_max
            denominator = cur_denominator

        output[start_i:end_i, :] = acc

    return output


def naive_attention(q, k, v):
    return torch.softmax(q @ k.T, dim=1) @ v


if __name__ == "__main__":
    desired = naive_attention(q, k, v)
    actual = flash_attention(q, k, v)
    print(desired)
    print(actual)
    assert torch.allclose(desired, actual)

CUDA Implementation

References

[1] Online softmax

[2] Online normalizer calculation for softmax

[3] Flash Attention on INTEL GPU