# Problem dimensions: A1 is (M, N), A2 is (N, K), so A1 @ A2 is (M, K).
M, N, K = 4, 6, 8
A1 = torch.rand(size=(M, N))
A2 = torch.rand(size=(N, K))
def block_matmul(sub_A1, sub_A2):
    """Multiply two tiles with an explicit triple loop.

    Reference (deliberately naive) implementation: accumulates each output
    element directly in the result tensor, exactly like ``sub_A1 @ sub_A2``
    but spelled out scalar by scalar.
    """
    n_rows = sub_A1.shape[0]
    n_cols = sub_A2.shape[1]
    inner_dim = sub_A2.shape[0]
    product = torch.zeros(size=(n_rows, n_cols))
    for r in range(n_rows):
        for c in range(n_cols):
            # Accumulate into the tensor element so rounding matches the
            # original element-by-element float32 accumulation.
            for t in range(inner_dim):
                product[r, c] += sub_A1[r, t] * sub_A2[t, c]
    return product
def tiled_matmul_softmax(A1, A2):
    """Tiled matrix multiply with running (online) softmax statistics.

    Computes ``A1 @ A2`` block by block via ``block_matmul``, then updates a
    running max and normalizer per output element in the online-softmax style.

    NOTE(review): relies on the module-level globals M, N, K and on `np`
    (presumably numpy — confirm the import exists at file top).
    NOTE(review): as shown the function never applies the final
    normalization and has no return statement — it looks truncated; the
    remainder presumably follows in the original document.
    """
    # Number of blocks along each dimension; block_* is the resulting
    # per-block extent (M // block_size_M, etc.). Assumes the globals divide
    # evenly — TODO confirm.
    block_size_M, block_size_N, block_size_K = 2, 3, 4
    block_M, block_N, block_K = M // block_size_M, N // block_size_N, K // block_size_K
    output = torch.zeros(size=(A1.shape[0], A2.shape[1]))
    for i in range(0, A1.shape[0], block_M):
        start_i, end_i = i, i + block_M
        # Per-element running statistics for the online softmax, reset once
        # per row-block.
        # NOTE(review): they are indexed [ii][jj] within a column block but
        # are NOT reset between j iterations, so state carries across column
        # blocks of the same row-block — verify this is intended.
        # NOTE(review): seeding row_max with 0. (rather than -inf) weakens
        # numerical stability for all-negative values.
        row_max = torch.tensor([[0. for _ in range(block_N)] for _ in range(block_M)])
        old_row_max = torch.tensor([[0. for _ in range(block_N)] for _ in range(block_M)])
        normalizer_term = torch.tensor([[0. for _ in range(block_N)] for _ in range(block_M)])
        for j in range(0, A2.shape[1], block_N):
            start_j, end_j = j, j + block_N
            # Accumulate the full (i, j) output block over all K-tiles.
            for k in range(0, A2.shape[0], block_K):
                start_k, end_k = k, k + block_K
                sub_A1 = A1[start_i:end_i, start_k:end_k]
                sub_A2 = A2[start_k:end_k, start_j:end_j]
                output[start_i:end_i, start_j:end_j] += block_matmul(sub_A1, sub_A2)
            # The (i, j) output block is now complete, so walk its elements
            # and fold each one into the running softmax statistics.
            for ii, row in enumerate(range(start_i, end_i)):
                for jj, col in enumerate(range(start_j, end_j)):
                    val = output[row][col]
                    # Standard online-softmax update: remember the old max,
                    # take the new max, and rescale the normalizer before
                    # adding the new term.
                    old_row_max[ii][jj] = row_max[ii][jj]
                    row_max[ii][jj] = max(old_row_max[ii][jj], val)
                    normalizer_term[ii][jj] = normalizer_term[ii][jj] * np.exp(old_row_max[ii][jj] - row_max[ii][jj]) + np.exp(val - row_max[ii][jj])
if (row < input1.size(0) && col < input2.size(1)) { scalar_t value = 0.0; for (int k = 0; k < input1.size(1); ++k) { value += input1[row][k] * input2[k][col]; } output[row][col] = value }
在循环之后,已经计算完了输出矩阵中的一项。需要继续计算每行的 row_max 和 normalizer_term。但由于这里是 CUDA 中的某个 block,所以需要借助共享内存来通信每行的结果。
if (row < input1.size(0) && col < input2.size(1)) { scalar_t value = 0.0; for (int k = 0; k < input1.size(1); ++k) { value += input1[row][k] * input2[k][col]; }
// One thread per output element of output = input1 @ input2, over flattened
// (1-D indexed, row-major) buffers: input1 is M*K, input2 is K*N, output is
// M*N. The last thread of each row additionally applies a softmax over that
// row in place.
template <typename scalar_t>
__global__ void matrix_multiply_vector_kernel(
        const torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> input1,
        const torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> input2,
        torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> output,
        const int M, const int N, const int K) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int row = index / N;
    int col = index % N;
    if (row < M && col < N) {
        // Dot product of row `row` of input1 with column `col` of input2.
        float value = 0.0;
        for (int k = 0; k < K; ++k) {
            value += input1[row * K + k] * input2[k * N + col];
        }
        output[row * N + col] = value;
        // NOTE(review): the softmax below reads every element of the row,
        // but nothing synchronizes with the other threads writing that row
        // (they may live in other thread blocks) — this is a data race
        // unless each full row is guaranteed to land in one synchronized
        // block. The two-kernel variant later in this file (separate
        // softmax_kernel after cudaDeviceSynchronize) avoids it.
        if (col == N - 1) {
            // Online softmax over the row. Seed the running max with -inf:
            // a 0.0 seed silently clamps the max for all-negative rows and
            // risks exp() underflow; -inf makes the first iteration
            // contribute exp(0) = 1 exactly.
            float row_max = -INFINITY;
            float normalizer_term = 0.0;
            float old_row_max = 0.0;
            for (int i = 0; i < N; ++i) {
                old_row_max = row_max;
                row_max = max(row_max, output[row * N + i]);
                // Rescale the accumulated normalizer to the new max, then
                // add the current term.
                normalizer_term = normalizer_term * exp(old_row_max - row_max) + exp(output[row * N + i] - row_max);
            }
            for (int i = 0; i < N; ++i) {
                output[row * N + i] = exp(output[row * N + i] - row_max) / normalizer_term;
            }
        }
    }
}
// Host wrapper: flattens the inputs, launches the fused matmul(+softmax)
// kernel with one thread per output element, and returns the (M, N) result.
// Returns the same tensor twice to preserve the original two-element
// interface callers expect.
std::vector<torch::Tensor> matmul_vector(torch::Tensor input1, torch::Tensor input2) {
    int M = input1.size(0);
    int K = input1.size(1);
    int N = input2.size(1);

    // Match the output dtype to the input. AT_DISPATCH below instantiates
    // the kernel for input1's scalar type, so a dtype-less TensorOptions
    // (default float) output buffer would be reinterpreted with the wrong
    // element size when the inputs are double.
    auto options = torch::TensorOptions().dtype(input1.scalar_type()).device(input1.device());
    const dim3 threads(BLOCK_SIZE);
    // Enough blocks to cover all M*N output elements.
    const dim3 blocks((M * N + threads.x - 1) / threads.x);

    // Reshape input tensors to flat vectors so the kernel can use plain
    // row-major index arithmetic.
    auto input1_vector = input1.reshape({-1});
    auto input2_vector = input2.reshape({-1});
    torch::Tensor output_vector = torch::zeros({M * N}, options);

    AT_DISPATCH_FLOATING_TYPES(input1_vector.scalar_type(), "matrix_multiply_vector_kernel", ([&] {
        matrix_multiply_vector_kernel<<<blocks, threads>>>(
            input1_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            input2_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            output_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            M, N, K);
    }));
    return {output_vector.reshape({M, N}), output_vector.reshape({M, N})};
}
std::vector<torch::Tensor> attention_cuda_forward( torch::Tensor q, torch::Tensor k, torch::Tensor v){
// In-place row softmax over a flattened (M, N) buffer, one thread per row.
template <typename scalar_t>
__global__ void softmax_kernel(
        torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> output,
        const int M, const int N) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int row = index / N;
    // Let exactly one thread handle each row. The launch covers M*N threads,
    // so with `row < M` alone all N threads mapped to a row would run the
    // same read-modify-write loop concurrently — a data race in which one
    // thread can overwrite the row with normalized values while another is
    // still scanning it for the max.
    if (row < M && index % N == 0) {
        // Seed the running max with -inf so rows whose entries are all
        // negative are handled correctly (a 0.0 seed clamps the max and
        // risks exp() underflow).
        float row_max = -INFINITY;
        float normalizer_term = 0.0;
        float old_row_max = 0.0;

        // Online softmax pass: track the max and rescale the normalizer as
        // it moves.
        for (int i = 0; i < N; ++i) {
            old_row_max = row_max;
            row_max = max(row_max, output[row * N + i]);
            normalizer_term = normalizer_term * exp(old_row_max - row_max) + exp(output[row * N + i] - row_max);
        }

        // Normalization pass.
        for (int i = 0; i < N; ++i) {
            output[row * N + i] = exp(output[row * N + i] - row_max) / normalizer_term;
        }
    }
}
// Host wrapper: computes softmax(input1 @ input2) row-wise as two separate
// kernel launches — a plain matmul kernel, a device-wide synchronization,
// then a row-softmax kernel — avoiding the in-kernel softmax race of the
// fused variant.
// NOTE(review): the function's closing brace is outside this excerpt; the
// visible body ends at the return statement.
torch::Tensor matrix_softmax_vector_softmax(torch::Tensor input1, torch::Tensor input2){
    int M = input1.size(0);
    int K = input1.size(1);
    int N = input2.size(1);

    // NOTE(review): no dtype is set here, so the output buffer is created
    // with the default (float) dtype even when the inputs are double, while
    // AT_DISPATCH below instantiates the kernels for the input dtype —
    // confirm inputs are always float32.
    auto options = torch::TensorOptions().device(input1.device());
    const dim3 threads(BLOCK_SIZE_VECTOR);
    // Enough blocks to cover all M*N output elements.
    const dim3 blocks((M * N + threads.x - 1) / threads.x);

    // Reshape input tensors to vectors
    auto input1_vector = input1.reshape({-1});
    auto input2_vector = input2.reshape({-1});
    torch::Tensor output_vector = torch::zeros({M * N}, options);

    // Plain flattened (1-D indexed) matrix multiplication.
    AT_DISPATCH_FLOATING_TYPES(input1_vector.scalar_type(), "matrix_multiply_vector_softmax_kernel", ([&] {
        matrix_multiply_vector_softmax_kernel<<<blocks, threads>>>(
            input1_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            input2_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            output_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            M, N, K
        );
    }));

    // Ensure every matmul result is written before the softmax kernel
    // reads whole rows.
    cudaDeviceSynchronize();

    AT_DISPATCH_FLOATING_TYPES(output_vector.scalar_type(), "softmax_kernel", ([&] {
        softmax_kernel<<<blocks, threads>>>(
            output_vector.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
            M, N
        );
    }));
    return output_vector.reshape({M, N});