Flash Attention 是 Transformer 中使用的一种有效机制,用于增强注意力计算,从而使操作更快、占用内存更少。这篇文章将详细介绍在 CUDA 中实现的 Flash Attention 的前向传递。
算法概述
在前向内核中,我们使用矩阵 Q(查询)、K(键)和 V(值)分成块以进行高效计算。基本步骤是:
- 将 Q、K 和 V 分成块。
- 将查询、键和值矩阵加载到共享内存 (SRAM) 中。
- 使用 softmax 计算注意力分数。
- 计算输出。
内存分配和初始化
在执行计算之前,我们为矩阵分配并初始化设备内存:
template <typename T> T* allocateAndInitializeDeviceMemory(size_t size, bool initializeToZero = false, bool initializeToNegativeInfinity = false) { T* device_ptr; cudaMalloc(&device_ptr, size); // Allocate memory on the device if (initializeToZero) { cudaMemset(device_ptr, 0, size); // Initialize to zero } else if (initializeToNegativeInfinity) { float negative_infinity_host = -INFINITY; cudaMemset(device_ptr, *reinterpret_cast<int*>(&negative_infinity_host), size); // Initialize to negative infinity } else { // Generate random numbers if no initialization is specified } return device_ptr; }
CUDA 内核: forward_kernel
主 CUDA 内核forward_kernel
处理计算。
线程和块索引
块中的每个线程计算特定查询和键矩阵的输出的一部分。
用于 Q、K、V 的 SRAM
我们为 Q、K 和 V 的图块声明共享内存:
extern __shared__ float shared_memory[]; float* query_matrix_tile = shared_memory; float* key_matrix_tile = &shared_memory[tile_size]; float* value_matrix_tile = &shared_memory[tile_size * 2];
将矩阵加载到 SRAM 中
我们将 K 和 V 的图块从全局内存加载到共享内存中:
for (int embedding_index = 0; embedding_index < embedding_dimension; embedding_index++) { key_matrix_tile[(thread_index_x * embedding_dimension) + embedding_index] = key_matrix_device_pointer[qkv_offset + ...]; value_matrix_tile[(thread_index_x * embedding_dimension) + embedding_index] = value_matrix_device_pointer[qkv_offset + ...]; } __syncthreads(); // Ensure all threads have completed loading
计算注意力分数
对于每个块,根据查询和键计算分数:
for (int column_index_inner = 0; column_index_inner < block_size_columns; column_index_inner++) { sum += query_matrix_tile[(thread_index_x * embedding_dimension) + embedding_index] * key_matrix_tile[(column_index_inner * embedding_dimension) + embedding_index]; } score_matrix_tile[(block_size_columns * thread_index_x) + column_index_inner] = sum * softmax_scale;
Softmax计算
缩放分数并应用 softmax 函数:
float row_sum = 0; for (int column_index_inner = 0; column_index_inner < block_size_columns; column_index_inner++) { score_matrix_tile[(block_size_columns * thread_index_x) + column_index_inner] = __expf(score - row_max); row_sum += score_matrix_tile[(block_size_columns * thread_index_x) + column_index_inner]; }
计算输出
通过将分数与值矩阵相结合来计算最终输出矩阵:
for (int embedding_index = 0; embedding_index < embedding_dimension; embedding_index++) { float probability_times_value = 0; for (int column_index_inner = 0; column_index_inner < block_size_columns; column_index_inner++) { probability_times_value += score_matrix_tile[(block_size_columns * thread_index_x) + column_index_inner] * value_matrix_tile[(column_index_inner * embedding_dimension) + embedding_index]; } output_matrix_device_pointer[qkv_offset + ...] = ... }
同步和最终步骤
所有计算完成后,我们执行同步以确保所有线程都完成其任务:
__syncthreads(); // Ensure all computations are complete before proceeding
实现 CUDA 内核的main
功能
main
函数是 CUDA 程序的核心,我们在这里设置问题、分配内存并启动内核。下面,我们将逐步介绍此功能所涉及的步骤。
1. 问题设置
我们首先定义任务的关键参数:
-
batch_size
、num_heads
、sequence_length
和embedding_dimension
来定义输入维度。 -
block_size_columns
和block_size_rows
用于控制线程在块内的分组方式。 -
softmax_scale
在 softmax 计算期间缩放分数。
int batch_size = 2; int num_heads = 8; int sequence_length = 64; int embedding_dimension = 32; int block_size_columns = 8; int block_size_rows = 8; float softmax_scale = 1.0f / std::sqrt(embedding_dimension);
2. 内存分配和初始化
我们为query
、 key
、 value
矩阵以及 softmax 计算期间使用的sum
和max
矩阵分配和初始化设备内存。
size_t matrix_size = batch_size * num_heads * sequence_length * embedding_dimension * sizeof(float); float* query_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size); float* key_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size); float* value_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size); float* sum_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>( batch_size * num_heads * sequence_length * sizeof(float), true); float* max_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>( batch_size * num_heads * sequence_length * sizeof(float), false, true); float* output_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size, true);
3. 网格和块配置
CUDA 需要定义网格和块配置。在这里,我们使用 2D 网格和 1D 块。
dim3 block(block_size_rows); dim3 grid(sequence_length / block_size_rows, num_heads, batch_size); size_t shared_memory_size = 4 * block_size_columns * embedding_dimension * sizeof(float);
4. 启动内核
启动forward_kernel
时指定了网格和块尺寸。我们传递所需的指针和配置参数。
forward_kernel<<<grid, block, shared_memory_size>>>( query_matrix_device_pointer, key_matrix_device_pointer, value_matrix_device_pointer, sequence_length, embedding_dimension, sequence_length / block_size_columns, sequence_length / block_size_rows, block_size_columns, block_size_rows, softmax_scale, sum_matrix_device_pointer, max_matrix_device_pointer, output_matrix_device_pointer); cudaDeviceSynchronize();
5. 验证和写入结果
将矩阵写入文件: writeMatrixToFile
函数
writeMatrixToFile
函数是验证和分析 CUDA 内核输出的重要实用程序。它可以将设备矩阵保存到文件(例如 CSV)中,以进行调试或进一步评估。本节解释其实现及其在工作流程中的作用。
目的
- 调试:提供一种检查中间或最终计算结果的方法。
- 分析:启用数据可视化或与我实际实现的外部工具(例如Python)中的预期输出进行比较,请检查存储库。
执行
writeMatrixToFile
函数将矩阵数据从设备内存复制到主机内存并将其写入文件。
#include <fstream> #include <iostream> void writeMatrixToFile(float* device_pointer, const std::string& filename, int batch_size, int num_heads, int sequence_length, int embedding_dimension) { size_t size = batch_size * num_heads * sequence_length * embedding_dimension; float* host_pointer = new float[size]; // Copy data from device to host cudaMemcpy(host_pointer, device_pointer, size * sizeof(float), cudaMemcpyDeviceToHost); // Open file for writing std::ofstream file(filename); if (!file.is_open()) { std::cerr << "Failed to open file: " << filename << std::endl; delete[] host_pointer; return; } // Write data to the file for (int batch = 0; batch < batch_size; ++batch) { for (int head = 0; head < num_heads; ++head) { for (int seq = 0; seq < sequence_length; ++seq) { for (int embed = 0; embed < embedding_dimension; ++embed) { size_t index = ((batch * num_heads + head) * sequence_length + seq) * embedding_dimension + embed; file << host_pointer[index]; if (embed < embedding_dimension - 1) { file << ","; } } file << "\n"; } } } // Clean up file.close(); delete[] host_pointer; std::cout << "Matrix written to file: " << filename << std::endl; }
关键步骤
- 复制数据:使用
cudaMemcpy
将矩阵数据从设备传输到主机内存。 - 文件处理:以写入模式打开文件并验证打开是否成功。
- 数据写入:迭代矩阵维度(
batch_size
、num_heads
、sequence_length
、embedding_dimension
)并以 CSV 格式写入值。 - Cleanup :关闭文件并释放动态分配的主机内存。
主函数中的用法
writeMatrixToFile
函数在内核执行后被调用,以保存计算出的输出矩阵:
writeMatrixToFile(output_matrix_device_pointer, "output_matrix.csv", batch_size, num_heads, sequence_length, embedding_dimension);
优点
- 多功能性:支持任意维度的矩阵。
- 易于使用:自动导出 CUDA 输出以进行离线分析的过程。
- 可读性:以适合调试和可视化的结构化格式组织数据。
6. 清理
最后,我们释放分配的设备内存以避免内存泄漏。
cudaFree(query_matrix_device_pointer); cudaFree(key_matrix_device_pointer); cudaFree(value_matrix_device_pointer); cudaFree(sum_matrix_device_pointer); cudaFree(max_matrix_device_pointer); cudaFree(output_matrix_device_pointer);
完整的主要功能
这是完整的main
功能:
int main() { int batch_size = 2, num_heads = 8, sequence_length = 64, embedding_dimension = 32; int block_size_columns = 8, block_size_rows = 8; float softmax_scale = 1.0f / std::sqrt(embedding_dimension); size_t matrix_size = batch_size * num_heads * sequence_length * embedding_dimension * sizeof(float); float* query_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size); float* key_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size); float* value_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size); float* sum_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>( batch_size * num_heads * sequence_length * sizeof(float), true); float* max_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>( batch_size * num_heads * sequence_length * sizeof(float), false, true); float* output_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size, true); dim3 block(block_size_rows); dim3 grid(sequence_length / block_size_rows, num_heads, batch_size); size_t shared_memory_size = 4 * block_size_columns * embedding_dimension * sizeof(float); forward_kernel<<<grid, block, shared_memory_size>>>( query_matrix_device_pointer, key_matrix_device_pointer, value_matrix_device_pointer, sequence_length, embedding_dimension, sequence_length / block_size_columns, sequence_length / block_size_rows, block_size_columns, block_size_rows, softmax_scale, sum_matrix_device_pointer, max_matrix_device_pointer, output_matrix_device_pointer); cudaDeviceSynchronize(); writeMatrixToFile(output_matrix_device_pointer, "output_matrix.csv", batch_size, num_heads, sequence_length, embedding_dimension); cudaFree(query_matrix_device_pointer); cudaFree(key_matrix_device_pointer); cudaFree(value_matrix_device_pointer); cudaFree(sum_matrix_device_pointer); cudaFree(max_matrix_device_pointer); cudaFree(output_matrix_device_pointer); return 0; }
要点
- 内存管理:正确分配和释放设备内存,以确保效率并避免内存泄漏。
- CUDA 配置:了解网格和块配置对于最大化性能至关重要。
- 调试:在开发过程中使用
cudaDeviceSynchronize()
尽早捕获内核错误。 - 后处理:将输出保存到文件中以供验证和进一步分析。
此设置可确保 CUDA 内核在现实场景中结构化且高效地实现。
结论
Flash Attention 的前向传递有效地并行计算注意力分数,利用共享内存进行性能优化。通过通过块构建计算并有效利用 GPU 资源,Flash Attention 极大地提高了 Transformer 的速度。
链接:
原文: https://hamdi.bearblog.dev/understanding-flash-attention-forward-with-cuda/