0%

用 Cpp 写 PyTorch 的插件

从零开始,用 Cpp 写 PyTorch 的插件,包括 CPU 和 GPU 的版本。

代码

为什么

一般来说,在原生功能不能满足需求的时候,插件可以作为补充。比如,PyTorch 的 torch.nn.functional 中没有 softmax 函数,但是 torch.nn 中有,所以可以用 torch.nn.functional.softmax 来代替。但是,如果要用 softmax 的导数,就需要用到 softmax 的原始定义,这个时候就需要自己写插件了。

如果是比较简单的需求,则可以直接用 Python 完成。然而,当对性能要求较高时,往往会使用 Cpp 来写插件,最后甚至会优化为 CUDA 代码。

怎么写

从例子出发,一步步来写。假设要实现一个最简单的 Attention 模块,输入为 \(q,k,v \in \mathbb{R}^{M \times N}\),输出为 \(out \in \mathbb{R}^{M \times N}\)(不考虑 Batch Size 以及 Head 数量的情况)。Attention 模块的计算公式为:

\[ out = \text{softmax}(qk^T)v \]

只实现 forward 函数,不实现 backward 函数。

CPU 版本

类似于写 Python 的库,创建一个文件夹,目录结构如下所示:

1
2
3
4
attention
├── attention.cpp
├── setup.py
└── __init__.py

Python 部分

核心代码在 attention.cpp 中 ,首先在 setup.py 中添加如下代码:

1
2
3
4
5
6
7
8
9
10
11
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
name='attention',
ext_modules=[
CppExtension('attention', ['attention.cpp']),
],
cmdclass={
'build_ext': BuildExtension
})

这样,就可以用 python setup.py install 来安装插件了。或者用 pip install -e attention 以便于快速调试

__init__.py 中导入 attention 模块,以在 Python 中调用 forward 函数,直接计算 attention。

1
from .attention import forward

Cpp 部分

接下来,我们需要在 attention.cpp 中实现 forward 函数,为进行区分,这里使用这个函数名称 attention_forward ,这个函数的输入是 q、k、v 三个 Tensor,输出是 torch::Tensor。而具体的计算步骤可以拆解为三个步骤:

  1. 矩阵的乘法
  2. softmax
  3. 矩阵的乘法

使用 PYBIND11_MODULEattention_forward函数暴露出去,绑定到 forward上,这样就能用 forward函数来调用 attention_forward。整合后的完整代码如下,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include <torch/extension.h>
#include <vector>

// 参数:queries(Q),keys(K),values(V)
// 返回:方便起见,返回一个 vector,实际上只有一个元素,便于后续扩展
std::vector<torch::Tensor> attention_forward(
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v) {

torch::Tensor scores = torch::matmul(q, k.transpose(0, 1));
scores = torch::softmax(scores, 1);
torch::Tensor attention = torch::matmul(scores, v);
return {attention};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &attention_forward, "Attention forward (CPU)");
}

测试

用 Python 版本的实现来测试 Cpp 版本的实现是否正确,测试代码如下,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch

import attention
import timeit

torch.manual_seed(42)

def py_attention(q, k, v):
return torch.softmax(q @ k.T, dim=1) @ v

def check_forward(q, k, v):
baseline_values = py_attention(q, k, v)
cpp_values = attention.forward(q, k, v)[-1]

print("base o", baseline_values)
print("cpp o", cpp_values)
print(torch.all(torch.isclose(baseline_values, cpp_values)))

def compare_time(q, k, v, loop=100):
print("py", timeit.timeit(lambda: py_attention(q, k, v), number=loop))
print("cpp", timeit.timeit(lambda: attention.forward(q, k, v), number=loop))

if __name__ == "__main__":
m, n = 2, 4
device = "cuda" if torch.cuda.is_available() else "cpu"
q, k, v = torch.rand(size=(m, n), device=device), torch.rand(size=(m, n), device=device), torch.rand(size=(m, n), device=device)
print("q", q)
print("k", k)
print("v", v)
print("="*20)
check_forward(q, k, v)
compare_time(q, k, v)

测试通过后,可以再使用 compare_time 函数对比一下二者的速度。理论上,二者的速度是相差无几的。因为均用的是 PyTorch 的矩阵乘法和 softmax 函数。

但是,如果需要进行更进一步的优化技巧,那么就需要自己实现矩阵乘法和 softmax 函数了。这里,我们只实现最简单的矩阵乘法和 softmax 函数,然后再对比一下二者的速度。

矩阵乘法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
torch::Tensor my_matmul(const torch::Tensor &a, const torch::Tensor &b) {
TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-dimensional");
TORCH_CHECK(a.size(1) == b.size(0), "Dimensions mismatch");

auto m = a.size(0);
auto n = b.size(1);
auto p = a.size(1);

torch::Tensor result = torch::zeros({m, n}, torch::dtype(torch::kFloat32));
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
float sum = 0.0;
for (int k = 0; k < p; k++) {
sum += a[i][k].item<float>() * b[k][j].item<float>();
}
result[i][j] = sum;
}
}
return result;
}

softmax

由于 softmax函数比较特殊,后续会结合算子融合一起优化,所以简单的对其进行展开。用 torch::exptorch::sum实现了一遍,为了方便也可以直接使用 torch::softmax(scores, 1)

1
2
3
4
5
torch::Tensor my_softmax(const torch::Tensor& scores) {
torch::Tensor exponents = torch::exp(scores);
torch::Tensor sum = torch::sum(exponents, 1, true);
return exponents / sum;
}

但把这两个函数替换到 attention_forward函数中后,再次运行 compare_time函数,发现手写的 Cpp 版本的实现要比 Python 版本的实现慢。为什么?因为,当前只是简单的实现了矩阵乘法和 softmax,而 PyTorch 中的矩阵乘法和 softmax 都是经过优化的,所以速度会更快。

另外,使用原生的矩阵乘法和 softmax 函数,可以在 GPU 上运行,而手写的矩阵乘法和 softmax 函数,只能在 CPU 上运行。因此,接下来将其改造为 GPU 版本,然后再进行优化。

GPU 版本

Python 部分

setup.py 中更改为如下代码,把 CppExtension 改为 CUDAExtension

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
name='attention',
ext_modules=[
CUDAExtension('attention', [
'attention.cpp',
'attention_kernel.cu',
])
],
cmdclass={
'build_ext': BuildExtension
})

Cpp 部分

为了兼容之前的代码,这里将之前的 attention_forward更新为 attention_cpu_forward,同时加了一个类型判断,如果输入的 Tensor不在同一个设备上,则抛出异常。而对于 attention_cuda_forward的实现需要在 attention_kernel.cu中实现。注意:这里需要提前定义好 attention_cuda_forward函数,否则会报错。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include <torch/extension.h>

#include <vector>

std::vector<torch::Tensor> attention_cuda_forward(
torch::Tensor q,
torch::Tensor k,
torch::Tensor v);

torch::Tensor my_matmul(const torch::Tensor &a, const torch::Tensor &b) {
TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-dimensional");
TORCH_CHECK(a.size(1) == b.size(0), "Dimensions mismatch");

auto m = a.size(0);
auto n = b.size(1);
auto p = a.size(1);

torch::Tensor result = torch::zeros({m, n}, torch::dtype(torch::kFloat32));
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
float sum = 0.0;
for (int k = 0; k < p; k++) {
sum += a[i][k].item<float>() * b[k][j].item<float>();
}
result[i][j] = sum;
}
}
return result;
}

std::vector<torch::Tensor> attention_cpu_forward(
torch::Tensor q,
torch::Tensor k,
torch::Tensor v) {
torch::Tensor scores = my_matmul(q, k);
torch::Tensor attention = my_matmul(torch::softmax(scores, 1), v);
return {scores, attention};
}

// 参数:queries(Q),keys(K),values(V)
std::vector<torch::Tensor> attention_forward(
torch::Tensor &q,
torch::Tensor &k,
torch::Tensor &v) {
if (!(q.device().type() == k.device().type() && q.device().type() == v.device().type())) {
throw std::runtime_error("Input tensors q, k, and v must be on the same device");
}

if (q.is_cuda()) {
return attention_cuda_forward(q, k.transpose(0, 1), v);
} else {
return attention_cpu_forward(q, k.transpose(0, 1), v);
}
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &attention_forward, "Attention forward (CUDA)");
}

Cu 部分

在同级目录下,创建 attention_kernel.cu文件。首先实现 attention_cuda_forward函数的主要逻辑。其中,主要对矩阵乘法进行了优化,使用了 CUDA 的并行计算。然后,使用 AT_DISPATCH_FLOATING_TYPES宏来实现对不同类型的支持,这样就可以支持 floatdouble类型了。

对于 matrix_multiply函数,与一般写法不同的是,需要提前创建好输出的 Tensor,然后再传入到 CUDA 的函数中。并且需要创建好 blocksthreads,然后再调用 CUDA 的函数。这里指定了每个 CUDA 是有 16 x 16 线程的块,而这些块是可以并行计算的,所以能够加速计算。可参考 An Even Easier Introduction to CUDA

而在传递参数的时候,需要使用 packed_accessor。这里的 packed_accessor的第一个参数是 Tensor的类型,第二个参数是 Tensor的维度,第三个参数是 Tensor的类型,第四个参数是 Tensor的维度。这里的 packed_accessor的第三个参数和第四个参数,是为了支持 CUDA 的。

接下来就是实现 matrix_multiply_kernel。矩阵的乘法中,如果要计算输出矩阵的第一个值,则需要用到输入矩阵的第一行和第一列。因此,这里需要根据 blockthread的索引,来计算出对应的行和列。然后,就是普通的矩阵乘法的实现了。原来的矩阵乘法的实现是:

1
2
3
4
5
out = [[0 for _ in range(n)] for _ in range(m)]
for i in range(m):
for j in range(n):
for k in range(p):
out[i][j] += input1[i][k] * input2[k][j]

相当于把外面两个循环分别交给了 CUDA 的 blockthread来计算。这样,就可以实现并行计算了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

// Matrix multiply kernel
template <typename scalar_t>
__global__ void matrix_multiply_kernel(const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input1,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input2,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;

if (row < input1.size(0) && col < input2.size(1)) {
scalar_t value = 0.0;
for (int k = 0; k < input1.size(1); ++k) {
value += input1[row][k] * input2[k][col];
}
output[row][col] = value;
}
}

torch::Tensor matrix_multiply(torch::Tensor input1, torch::Tensor input2) {
int rows1 = input1.size(0);
int cols1 = input1.size(1);
int cols2 = input2.size(1);

auto options = torch::TensorOptions().device(input1.device());
torch::Tensor output = torch::zeros({rows1, cols2}, options);

const dim3 threads(16, 16);
const dim3 blocks((cols2 + threads.x - 1) / threads.x,
(rows1 + threads.y - 1) / threads.y);

AT_DISPATCH_FLOATING_TYPES(input1.scalar_type(), "matrix_multiply_kernel", ([&] {
matrix_multiply_kernel<<<blocks, threads>>>(
input1.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
input2.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
output.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
}));

return output;
}


std::vector<torch::Tensor> attention_cuda_forward(
torch::Tensor q,
torch::Tensor k,
torch::Tensor v) {

torch::Tensor scores = matrix_multiply(q, k);
torch::Tensor attention = matrix_multiply(torch::softmax(scores, 1), v);
return {scores, attention};
}

测试代码不变,依旧可以用上面的来进行检验。至此,最简单的实验就完成了。接下来,就是对其进行优化了。

矩阵乘法的优化

Matmul

重新回到矩阵乘法上,假设有两个矩阵 A1 和 A2,形状分别为 \(M \times N\)\(N \times K\),则矩阵乘法的计算公式为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

M, N, K = 4, 2, 4

A1 = torch.rand(size=(M, N))
A2 = torch.rand(size=(N, K))

output = torch.zeros(size=(M, K))
for i in range(M):
for j in range(K):
sum_ = 0
for k in range(N):
sum_ += A1[i][k] * A2[k][j]
output[i][j] = sum_

一种朴素的优化手段是把最后一个循环并行计算。

1
2
3
for i in range(M):
for j in range(K):
output[i][j] = sum(map(lambda x: x[0] * x[1], zip(A1[i], A2[:, j])))

利用多线程/进程(下文统称为job)进行并行计算可以提高程序的计算速度,但这样需要每个job都能访问到 A1 和 A2 的数据,所以这就引入了全局内存和共享内存的概念。

  • 全局内存(Global Memory):全局内存是一种在计算机程序中可被所有线程或进程访问的内存空间。它通常用于存储全局变量、静态变量以及动态分配的内存等。全局内存的特点是可以在整个程序执行过程中进行读写操作,但它的访问速度相对较慢。
  • 共享内存(Shared Memory):共享内存是一种特殊的内存区域,被多个线程或进程同时访问和共享。通过将数据存储在共享内存中,不同的线程或进程可以直接读取和写入这些数据,而无需使用其他的通信机制。共享内存的特点是高效的数据共享和访问速度,因为不需要进行复制或传输数据。

CUDA 中依旧存在类似的概念

  • 全局内存(Global Memory):在 CUDA 中,全局内存是一个设备(GPU)上可见的主机(CPU)内存空间。它可以由所有的线程块和线程访问,用于存储全局变量和动态分配的内存等。全局内存的读写操作相对较慢,因为涉及主机与设备之间的数据传输。
  • 共享内存(Shared Memory):在 CUDA 中,共享内存是位于每个线程块中的一块高速缓存内存。它被同一个线程块内的线程共享,并且比全局内存具有更快的读写速度。共享内存通常用于优化算法的性能,通过在线程块内部共享数据来减少全局内存的访问。

简而言之,全局内存可以很方便的存各种东西,但是速度慢;共享内存是一个好东西,速度块,但通常大小受限。所以,分别有两种优化:

  • 全局内存中,考虑如何加速访问
  • 共享内存中,考虑如何减小占用空间

这就引入了 Tiled matmul 算法。

Tiled Matmul

  • 对于加速访问,在无法控制硬件的前提下,只能通过并行的方式同时读取数据。
  • 对于减小占用空间,可以通过拆分矩阵,把大矩阵拆分成若干小矩阵,然后再进行计算。

重新思考矩阵 output的计算过程,每个元素的计算其实是独立的,其本质可以拆成若干独立的小块,如下图所示:

From https://penny-xu.github.io/blog/tiled-matrix-multiplication

由于矩阵 output每个元素是完全独立的,可以将其拆成若干个小矩阵来计算。如上图所示,把矩阵 output 拆成了 4 个小矩阵。对应的代码如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch

M, N, K = 4, 2, 4

A1 = torch.rand(size=(M, N))
A2 = torch.rand(size=(N, K))

block_size = 2 # 方便起见,这里设置为 N,M,K 的公约数。同时也拆成了 2x2 的 block
block_M, block_N, block_K = M // block_size, N // block_size, K // block_size

def matmul(sub_A1, sub_A2):
output = torch.zeros(size=(sub_A1.shape[0], sub_A2.shape[1]))
for i in range(sub_A1.shape[0]):
for j in range(sub_A2.shape[1]):
for k in range(sub_A2.shape[0]):
output[i][j] += sub_A1[i][k] * sub_A2[k][j]
return output

output11 = matmul(A1[:block_M, :], A2[:, :block_K])
output12 = matmul(A1[:block_M, :], A2[:, block_K:])
output21 = matmul(A1[block_M:, :], A2[:, :block_K])
output22 = matmul(A1[block_M:, :], A2[:, block_K:])
output = torch.cat([torch.cat([output11, output12], dim=1), torch.cat([output21, output22], dim=1)], dim=0)
print(output)
print(A1 @ A2)
assert torch.allclose(output, A1 @ A2)

对于,左上角矩阵 output11,实际上是由 block_size个矩阵乘法,再求和得到的 output11 = matmul(A1[:block_M, :block_N], A2[:block_N, :block_K]) + matmul(A1[:block_M, block_N:], A2[block_N:, :block_K])。再把它扩展的灵活一点

  1. 不局限于只能扩展为 2 \(\times\) 2 矩阵
  2. block_size 可以针对 M, N, K 进行调整
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch

M, N, K = 4, 6, 8

A1 = torch.rand(size=(M, N))
A2 = torch.rand(size=(N, K))

block_size_M, block_size_N, block_size_K = 2, 3, 4
block_M, block_N, block_K = M // block_size_M, N // block_size_N, K // block_size_K

def block_matmul(sub_A1, sub_A2):
output = torch.zeros(size=(sub_A1.shape[0], sub_A2.shape[1]))
for i in range(sub_A1.shape[0]):
for j in range(sub_A2.shape[1]):
for k in range(sub_A2.shape[0]):
output[i][j] += sub_A1[i][k] * sub_A2[k][j]
return output

def matmul(A1, A2):
output = torch.zeros(size=(A1.shape[0], A2.shape[1]))
for i in range(0, A1.shape[0], block_M):
start_i, end_i = i, i + block_M
for j in range(0, A2.shape[1], block_N):
start_j, end_j = j, j + block_N
for k in range(0, A2.shape[0], block_K):
start_k, end_k = k, k + block_K
# 计算每个 block 的矩阵乘法
sub_A1 = A1[start_i:end_i, start_k:end_k]
sub_A2 = A2[start_k:end_k, start_j:end_j]
# 把每个 block 的结果放到对应的位置
output[start_i:end_i, start_j:end_j] += block_matmul(sub_A1, sub_A2)
return output
print(matmul(A1, A2))
print(A1 @ A2)
assert torch.allclose(matmul(A1, A2), A1 @ A2)

From https://penny-xu.github.io/blog/tiled-matrix-multiplication

再次回顾这样做的目的,原作者是这么说的:

  • With or without tiling, the same number of accesses into global memory occur. The difference is that, without tiling, each thread must sequentially (one after the other) access global memory 8 times.
  • With tiling, we can parallelize the access to global memory so that each thread only sequentially accesses global memory 4 times.
  • To summarize, the point is not to reduce the number of multiplications or even the total number of global memory accesses, but rather to reduce the number of sequential global memory accesses per thread. In other words, we better share the heavy load of memory access across threads.

简而言之,作者的观点是通过拆成 block 的形式,能够并行的读取全局内存数据。个人观点,在某些情况下,拆分后的 block 可以恰好把数据放到共享内存中,以加速计算?

最后,对于 block size大小的选择:

  • 越大的 block size表示拆分后的矩阵个数变少,这样访问全局内存的次数会更多。但是,每个矩阵比较大,这样线程的并行度会更高。即,IO变小,计算变快。
  • 越小的 block size表示拆分后的矩阵个数变多,这样访问全局内存的次数会更少。但是,每个矩阵比较小,这样线程的并行度会更低。即,IO变大,计算变慢。

下一篇:以Attention为例的算子融合

参考

[1] https://pytorch.org/tutorials/advanced/cpp_extension.html

[2] Tiled matmul