auto m = a.size(0);
auto n = b.size(1);
auto p = a.size(1);

torch::Tensor result = torch::zeros({m, n}, torch::dtype(torch::kFloat32));
for (int i = 0; i < m; i++) {
    for (int j = 0; j < n; j++) {
        float sum = 0.0;
        for (int k = 0; k < p; k++) {
            sum += a[i][k].item<float>() * b[k][j].item<float>();
        }
        result[i][j] = sum;
    }
}
return result;
}
// Arguments: queries (Q), keys (K), values (V)
std::vector<torch::Tensor> attention_forward(
    torch::Tensor &q,
    torch::Tensor &k,
    torch::Tensor &v) {
    if (!(q.device().type() == k.device().type() &&
          q.device().type() == v.device().type())) {
        throw std::runtime_error("Input tensors q, k, and v must be on the same device");
    }
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &attention_forward, "Attention forward (CUDA)");
}
The .cu part

In the same directory, create an attention_kernel.cu file. Start by implementing the main logic of the attention_cuda_forward function. The main optimization is the matrix multiplication, which is parallelized with CUDA. The AT_DISPATCH_FLOATING_TYPES macro is then used to dispatch on the element type, so both float and double are supported.
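As a rough sketch of that dispatch pattern (the kernel and variable names here are placeholders, not lines taken from the extension):

AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "some_kernel", ([&] {
    // inside the lambda, scalar_t is bound to float or double,
    // matching input.scalar_type()
    some_kernel<scalar_t><<<blocks, threads>>>(/* packed accessors, sizes, ... */);
}));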
For the matrix_multiply function, unlike the usual way of writing it, the output Tensor has to be created up front and then passed into the CUDA function. The blocks and threads also have to be set up before the CUDA kernel is launched. Here each CUDA block is given 16 x 16 threads, and the blocks can execute in parallel, which is what provides the speedup; see An Even Easier Introduction to CUDA. A sketch of this launch configuration follows.
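A minimal sketch of that launch configuration, assuming the rows1 x cols2 output allocated in matrix_multiply below; the kernel name matrix_multiply_kernel is a placeholder, and the launch itself would sit inside the AT_DISPATCH_FLOATING_TYPES lambda:

const dim3 threads(16, 16);                               // 16 x 16 threads per block
const dim3 blocks((rows1 + threads.x - 1) / threads.x,    // enough blocks to cover every output row...
                  (cols2 + threads.y - 1) / threads.y);   // ...and every output column
matrix_multiply_kernel<scalar_t><<<blocks, threads>>>(/* packed accessors */);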
When passing the arguments, packed_accessor has to be used. Its first template parameter is the Tensor's element type, the second is the number of dimensions, the third is the pointer traits (torch::RestrictPtrTraits), and the fourth is the index type (size_t here). The third and fourth parameters are what make the accessor usable inside CUDA kernels.
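For the 2-D tensors used in this kernel, a call therefore looks like the following (illustrative, mirroring the pattern used later in the .cu file):

// <element type, number of dimensions, pointer traits, index type>
input1.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>()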
if (row < input1.size(0) && col < input2.size(1)) {
    scalar_t value = 0.0;
    for (int k = 0; k < input1.size(1); ++k) {
        value += input1[row][k] * input2[k][col];
    }
    output[row][col] = value;
}
}
torch::Tensor matrix_multiply(torch::Tensor input1, torch::Tensor input2) {
    int rows1 = input1.size(0);
    int cols1 = input1.size(1);
    int cols2 = input2.size(1);

    auto options = torch::TensorOptions().device(input1.device());
    torch::Tensor output = torch::zeros({rows1, cols2}, options);
block_size = 2  # For convenience, chosen as a common divisor of M, N, K; this also splits things into 2x2 blocks
block_M, block_N, block_K = M // block_size, N // block_size, K // block_size
def matmul(sub_A1, sub_A2):
    output = torch.zeros(size=(sub_A1.shape[0], sub_A2.shape[1]))
    for i in range(sub_A1.shape[0]):
        for j in range(sub_A2.shape[1]):
            for k in range(sub_A2.shape[0]):
                output[i][j] += sub_A1[i][k] * sub_A2[k][j]
    return output
block_size_M, block_size_N, block_size_K = 2, 3, 4
block_M, block_N, block_K = M // block_size_M, N // block_size_N, K // block_size_K
def block_matmul(sub_A1, sub_A2):
    output = torch.zeros(size=(sub_A1.shape[0], sub_A2.shape[1]))
    for i in range(sub_A1.shape[0]):
        for j in range(sub_A2.shape[1]):
            for k in range(sub_A2.shape[0]):
                output[i][j] += sub_A1[i][k] * sub_A2[k][j]
    return output
def matmul(A1, A2):
    output = torch.zeros(size=(A1.shape[0], A2.shape[1]))
    for i in range(0, A1.shape[0], block_M):
        start_i, end_i = i, i + block_M
        for j in range(0, A2.shape[1], block_N):
            start_j, end_j = j, j + block_N
            for k in range(0, A2.shape[0], block_K):
                start_k, end_k = k, k + block_K
                # Compute the matrix product for this pair of blocks
                sub_A1 = A1[start_i:end_i, start_k:end_k]
                sub_A2 = A2[start_k:end_k, start_j:end_j]
                # Accumulate each block's result into the corresponding position
                output[start_i:end_i, start_j:end_j] += block_matmul(sub_A1, sub_A2)
    return output

print(matmul(A1, A2))
print(A1 @ A2)
assert torch.allclose(matmul(A1, A2), A1 @ A2)
From https://penny-xu.github.io/blog/tiled-matrix-multiplication
With or without tiling, the same number of accesses into global
memory occur. The difference is that, without tiling, each thread must
sequentially (one after the other) access global memory 8 times.
With tiling, we can parallelize the access to global memory so that
each thread only sequentially accesses global memory 4 times.
To summarize, the point is not to reduce the number of
multiplications or even the total number of global memory accesses, but
rather to reduce the number of sequential global memory accesses per
thread. In other words, we better share the heavy load of memory access
across threads.
M, N, K = 4, 6, 8
A1, A2 = torch.rand(size=(M, N)), torch.rand(size=(N, K))
def block_matmul(sub_A1, sub_A2):
    output = torch.zeros(size=(sub_A1.shape[0], sub_A2.shape[1]))
    for i in range(sub_A1.shape[0]):
        for j in range(sub_A2.shape[1]):
            for k in range(sub_A2.shape[0]):
                output[i][j] += sub_A1[i][k] * sub_A2[k][j]
    return output
def tiled_matmul_softmax(A1, A2):
    block_size_M, block_size_N, block_size_K = 2, 3, 4
    block_M, block_N, block_K = M // block_size_M, N // block_size_N, K // block_size_K

    output = torch.zeros(size=(A1.shape[0], A2.shape[1]))
    for i in range(0, A1.shape[0], block_M):
        start_i, end_i = i, i + block_M
        row_max = torch.tensor([[0. for _ in range(block_N)] for _ in range(block_M)])
        old_row_max = torch.tensor([[0. for _ in range(block_N)] for _ in range(block_M)])
        normalizer_term = torch.tensor([[0. for _ in range(block_N)] for _ in range(block_M)])

        for j in range(0, A2.shape[1], block_N):
            start_j, end_j = j, j + block_N
            for k in range(0, A2.shape[0], block_K):
                start_k, end_k = k, k + block_K
                sub_A1 = A1[start_i:end_i, start_k:end_k]
                sub_A2 = A2[start_k:end_k, start_j:end_j]
                output[start_i:end_i, start_j:end_j] += block_matmul(sub_A1, sub_A2)

            # This block of the output is now fully accumulated, so walk over its
            # entries and update the running max and normalizer for the softmax
            for ii, row in enumerate(range(start_i, end_i)):
                for jj, col in enumerate(range(start_j, end_j)):
                    val = output[row][col]
                    old_row_max[ii][jj] = row_max[ii][jj]
                    row_max[ii][jj] = max(old_row_max[ii][jj], val)
                    normalizer_term[ii][jj] = normalizer_term[ii][jj] * np.exp(old_row_max[ii][jj] - row_max[ii][jj]) + np.exp(val - row_max[ii][jj])
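The update in that inner loop is the standard streaming ("online") softmax recurrence: keep a running maximum $m$ and a running normalizer $\ell$, and fold in each newly finished value $x$ as

$m' = \max(m, x), \qquad \ell' = \ell \, e^{m - m'} + e^{x - m'}.$

Once every value of a row has been seen, each entry can be normalized as $e^{x_i - m} / \ell$, which is what the CUDA kernels below do in their final loop.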
if (row < input1.size(0) && col < input2.size(1)) {
    scalar_t value = 0.0;
    for (int k = 0; k < input1.size(1); ++k) {
        value += input1[row][k] * input2[k][col];
    }
    output[row][col] = value;
}
After this loop, one entry of the output matrix has been computed. We still need to compute each row's row_max and normalizer_term. But because this code runs inside a single CUDA block, shared memory would be needed to communicate the per-row results; a rough sketch of what that would involve follows.
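For illustration only (this is not part of the extension), the shared-memory communication inside one 16 x 16 block could look roughly like this, and it still leaves the cross-block case unresolved, which is why the code below instead flattens everything into a 1-D kernel:

__shared__ scalar_t row_vals[16][16];          // per-block staging buffer
row_vals[threadIdx.y][threadIdx.x] = value;    // each thread publishes its element
__syncthreads();                               // wait until the whole block has written

if (threadIdx.x == 0) {                        // one thread per row reduces its row
    scalar_t block_row_max = row_vals[threadIdx.y][0];
    for (int c = 1; c < 16; ++c) {
        block_row_max = max(block_row_max, row_vals[threadIdx.y][c]);
    }
    // block_row_max only covers the columns owned by this block; combining it with
    // the other blocks' partial maxima would need yet another level of communication.
}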
template <typename scalar_t>
__global__ void matrix_multiply_vector_kernel(
    const torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> input1,
    const torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> input2,
    torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> output,
    const int M,
    const int N,
    const int K) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int row = index / N;
    int col = index % N;

    if (row < M && col < N) {
        float value = 0.0;
        for (int k = 0; k < K; ++k) {
            value += input1[row * K + k] * input2[k * N + col];
        }
        output[row * N + col] = value;
        if (col == N - 1) {
            float row_max = 0.0;
            float normalizer_term = 0.0;
            float old_row_max = 0.0;
            for (int i = 0; i < N; ++i) {
                old_row_max = row_max;
                row_max = max(row_max, output[row * N + i]);
                normalizer_term = normalizer_term * exp(old_row_max - row_max) + exp(output[row * N + i] - row_max);
            }
            for (int i = 0; i < N; ++i) {
                output[row * N + i] = exp(output[row * N + i] - row_max) / normalizer_term;
            }
        }
    }
}
std::vector<torch::Tensor> matmul_vector(torch::Tensor input1, torch::Tensor input2) {
    int M = input1.size(0);
    int K = input1.size(1);
    int N = input2.size(1);

    auto options = torch::TensorOptions().device(input1.device());
    const dim3 threads(BLOCK_SIZE);
    const dim3 blocks((M * N + threads.x - 1) / threads.x);

    // Reshape input tensors to vectors
    auto input1_vector = input1.reshape({-1});
    auto input2_vector = input2.reshape({-1});
    torch::Tensor output_vector = torch::zeros({M * N}, options);

    AT_DISPATCH_FLOATING_TYPES(input1_vector.scalar_type(), "matrix_multiply_vector_kernel", ([&] {
        matrix_multiply_vector_kernel<<<blocks, threads>>>(
            input1_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            input2_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            output_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            M, N, K
        );
    }));
    return {output_vector.reshape({M, N}), output_vector.reshape({M, N})};
}
std::vector<torch::Tensor> attention_cuda_forward(
    torch::Tensor q,
    torch::Tensor k,
    torch::Tensor v) {
template <typename scalar_t>
__global__ void softmax_kernel(
    torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> output,
    const int M,
    const int N) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int row = index / N;

    if (row < M) {
        float row_max = 0.0;
        float normalizer_term = 0.0;
        float old_row_max = 0.0;

        for (int i = 0; i < N; ++i) {
            old_row_max = row_max;
            row_max = max(row_max, output[row * N + i]);
            normalizer_term = normalizer_term * exp(old_row_max - row_max) + exp(output[row * N + i] - row_max);
        }

        for (int i = 0; i < N; ++i) {
            output[row * N + i] = exp(output[row * N + i] - row_max) / normalizer_term;
        }
    }
}
torch::Tensor matrix_softmax_vector_softmax(torch::Tensor input1, torch::Tensor input2) {
    int M = input1.size(0);
    int K = input1.size(1);
    int N = input2.size(1);

    auto options = torch::TensorOptions().device(input1.device());
    const dim3 threads(BLOCK_SIZE_VECTOR);
    const dim3 blocks((M * N + threads.x - 1) / threads.x);

    // Reshape input tensors to vectors
    auto input1_vector = input1.reshape({-1});
    auto input2_vector = input2.reshape({-1});
    torch::Tensor output_vector = torch::zeros({M * N}, options);

    // Plain 1-D matrix multiplication
    AT_DISPATCH_FLOATING_TYPES(input1_vector.scalar_type(), "matrix_multiply_vector_softmax_kernel", ([&] {
        matrix_multiply_vector_softmax_kernel<<<blocks, threads>>>(
            input1_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            input2_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            output_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            M, N, K
        );
    }));

    cudaDeviceSynchronize();

    // Row-wise softmax on the result
    AT_DISPATCH_FLOATING_TYPES(output_vector.scalar_type(), "softmax_kernel", ([&] {
        softmax_kernel<<<blocks, threads>>>(
            output_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            M, N
        );
    }));
    return output_vector.reshape({M, N});