/*
* 测试用例:演示沐曦GPU __ballot_sync bug及workaround验证
*
* 问题:在warp 1中,线程59-63(lane_id 27-31)执行__ballot_sync时,
* 即使predicate为真,返回0x00000000而不是正确的bitmask
*
*/
#include <mc_runtime.h>
#include <stdio.h>
#define USE_METAX // 启用workaround
// Reference implementation: a thin wrapper around the __ballot_sync intrinsic,
// issued with an all-ones participation mask. According to the file header this
// path behaves correctly on NVIDIA GPUs and is the one that misbehaves on MetaX GPUs.
//
// pred    : this thread's vote.
// returns : the value produced by __ballot_sync, narrowed to 32 bits.
//
// NOTE(review): the result is stored in a uint32_t. If the target device's
// warpSize were 64, votes of lanes 32-63 could not be represented here —
// TODO confirm warpSize on the MetaX target.
__device__ uint32_t ballot_original(bool pred) {
    const uint32_t allLanesMask = 0xFFFFFFFF;
    const uint32_t votes = __ballot_sync(allLanesMask, pred);
    return votes;
}
// Workaround (MetaX GPUs): emulate a 32-lane ballot through shared memory.
//
// Each thread publishes its predicate into its own shared-memory slot, the
// block synchronizes, and every thread then assembles the 32-bit mask of the
// group of 32 consecutive flattened thread indices it belongs to
// (bit i is set when the i-th thread of that group voted true).
//
// Preconditions:
//   * Must be reached by ALL threads of the block from non-divergent code,
//     because it contains __syncthreads() barriers.
//   * The block must have at most 1024 threads (capacity of the scratch array).
//
// pred    : this thread's vote.
// returns : vote bitmask of the caller's 32-thread group; bits belonging to
//           indices beyond the end of the block are 0.
__device__ uint32_t ballot_workaround(bool pred) {
    // Flattened index over all three block dimensions. The lane is derived from
    // this same index (not from threadIdx.x alone) so that lane and warp always
    // agree, even when blockDim.x is not a multiple of 32.
    uint32_t thrdIdxInBlk = threadIdx.x
                          + threadIdx.y * blockDim.x
                          + threadIdx.z * blockDim.x * blockDim.y;
    uint32_t nThrdsInBlk = blockDim.x * blockDim.y * blockDim.z;
    uint32_t lane_id = thrdIdxInBlk % 32;
    uint32_t warp_base = thrdIdxInBlk - lane_id;  // first flattened index of this group

    // One slot per thread of the block.
    __shared__ uint32_t s_warp_pred[1024];
    s_warp_pred[thrdIdxInBlk] = pred ? 1 : 0;

    // Make every thread's vote visible before any thread reads the slots.
    __syncthreads();

    uint32_t ballotResult = 0;
    for (int i = 0; i < 32; i++) {
        uint32_t slot = warp_base + i;
        // Slots past the end of the block were never written; skip them so a
        // partial trailing group does not read uninitialized shared memory.
        if (slot < nThrdsInBlk && s_warp_pred[slot]) {
            ballotResult |= (1U << i);
        }
    }

    // Keep the scratch array intact until all threads have finished reading,
    // so a subsequent call cannot overwrite votes that are still in use.
    __syncthreads();
    return ballotResult;
}
// Test kernel: reproduces the production scenario and records the result of
// both ballot implementations.
//
// Launch expectations: results are indexed by warp only (blockIdx is not
// used), so the kernel is meant to be launched with a single block; the block
// index is flattened over x and y.
//
// d_original_results   : device array with at least 2 elements; element w
//                        receives the ballot_original result seen by lane 0
//                        of warp w (w = 0, 1).
// d_workaround_results : same layout, for ballot_workaround.
// nInfoBits            : for nInfoBits <= 128, threads whose flattened index
//                        lies in [128 - nInfoBits, 128) vote true.
__global__ void test_ballot_bug(
    uint32_t* d_original_results,
    uint32_t* d_workaround_results,
    uint32_t nInfoBits
) {
    uint32_t thrdIdxInBlk = threadIdx.x + threadIdx.y * blockDim.x;
    // Lane and warp are both derived from the flattened index so they stay
    // consistent with each other for any blockDim.x (previously the lane came
    // from threadIdx.x alone, which is only equivalent when blockDim.x is a
    // multiple of 32).
    uint32_t lane_id = thrdIdxInBlk % 32;
    uint32_t warp_id = thrdIdxInBlk / 32;

    // Exact predicate logic from the production code (including the int16_t
    // narrowing of the start index); intentionally left untouched.
    uint32_t N_MAX_INFO_BITS = 128;
    int16_t interleaverTblStartIdx = N_MAX_INFO_BITS - nInfoBits;
    bool pred = ((thrdIdxInBlk < N_MAX_INFO_BITS) && (thrdIdxInBlk >= interleaverTblStartIdx)) ? true : false;

    // Hardware ballot under test.
    uint32_t original_ballot = ballot_original(pred);

    // Shared-memory workaround. Called unconditionally by every thread because
    // it contains block-wide barriers.
    uint32_t workaround_ballot = ballot_workaround(pred);

    // Lane 0 of warps 0 and 1 stores its view of both results.
    if (lane_id == 0 && warp_id < 2) {
        d_original_results[warp_id] = original_ballot;
        d_workaround_results[warp_id] = workaround_ballot;
    }

    // Debug output for every thread that voted true.
    if (pred) {
        printf("threadIdx.x=%d, threadIdx.y=%d, thrdIdxInBlk=%d, lane_id=%d, warp_id=%d, pred=%d\n",
               threadIdx.x, threadIdx.y, thrdIdxInBlk, lane_id, warp_id, pred);
        printf(" original_ballot=0x%08x, workaround_ballot=0x%08x\n",
               original_ballot, workaround_ballot);
    }
}
// ---------------------------------------------------------------------------
// Host-side verification helpers and test driver.
// ---------------------------------------------------------------------------

// Prints one warp's ballot with a pass/fail marker; returns true on a match.
static bool print_warp_result(int warp, uint32_t actual, uint32_t expected) {
    const bool ok = (actual == expected);
    printf(" Warp %d: 0x%08x %s\n", warp, actual, ok ? "✓ 正确" : "✗ 错误");
    return ok;
}

// Host reference for the kernel's predicate: returns the ballot that a
// 32-thread warp `warpId` should produce. Uses the same int16_t narrowing of
// the start index as the kernel so host and device evaluate the same condition.
static uint32_t expected_ballot_for_warp(uint32_t warpId, uint32_t nInfoBits, uint32_t nMaxInfoBits) {
    const int16_t startIdx = nMaxInfoBits - nInfoBits;
    uint32_t mask = 0;
    for (uint32_t lane = 0; lane < 32; ++lane) {
        const uint32_t tid = warpId * 32 + lane;
        // The cast reproduces the implicit signed->unsigned conversion that the
        // kernel's comparison performs.
        if ((tid < nMaxInfoBits) && (tid >= static_cast<uint32_t>(startIdx))) {
            mask |= (1U << lane);
        }
    }
    return mask;
}

// Verification helper (original entry point, interface unchanged): compares
// the first nWarps entries of both result arrays against ONE expected mask
// shared by all warps and prints a report.
void verify_results(uint32_t* original, uint32_t* workaround, int nWarps, uint32_t expected_mask) {
    printf("\n========== 验证结果 ==========\n");
    printf("预期bitmask (所有pred为true的线程): 0x%08x\n", expected_mask);
    printf("\n原始 __ballot_sync 结果:\n");
    for (int i = 0; i < nWarps; i++) {
        print_warp_result(i, original[i], expected_mask);
    }
    printf("\nWorkaround 结果:\n");
    for (int i = 0; i < nWarps; i++) {
        print_warp_result(i, workaround[i], expected_mask);
    }
}

// Per-warp variant: every warp is checked against its own expected mask
// (a warp without any true vote is expected to report 0x00000000).
// Returns true when all workaround results match; mismatches of the original
// __ballot_sync are reported but do not affect the return value, since
// demonstrating them is the purpose of this test.
static bool verify_results_per_warp(const uint32_t* original, const uint32_t* workaround,
                                    const uint32_t* expected_masks, int nWarps) {
    printf("\n========== 验证结果 ==========\n");
    printf("预期bitmask (所有pred为true的线程):\n");
    for (int i = 0; i < nWarps; i++) {
        printf(" Warp %d: 0x%08x\n", i, expected_masks[i]);
    }
    printf("\n原始 __ballot_sync 结果:\n");
    for (int i = 0; i < nWarps; i++) {
        print_warp_result(i, original[i], expected_masks[i]);
    }
    printf("\nWorkaround 结果:\n");
    bool workaroundOk = true;
    for (int i = 0; i < nWarps; i++) {
        if (!print_warp_result(i, workaround[i], expected_masks[i])) {
            workaroundOk = false;
        }
    }
    return workaroundOk;
}

// Reports a failed runtime call; returns true when `err` signals success.
static bool mc_check(mcError_t err, const char* what) {
    if (err == mcSuccess) {
        return true;
    }
    printf("CUDA Error: %s\n", mcGetErrorString(err));
    printf("  failed call: %s\n", what);
    return false;
}

// Test driver.
// Exit code: 0 = workaround verified, 1 = workaround produced a wrong mask,
// -1 = allocation or runtime error.
int main() {
    // Test parameters reproducing the production scenario. With nInfoBits = 69
    // the predicate is true for flattened thread indices 59..127, i.e.
    //   warp 0 (threads  0-31): no votes          -> 0x00000000
    //   warp 1 (threads 32-63): lanes 27-31 vote  -> 0xF8000000
    const uint32_t nInfoBits = 69;
    const uint32_t N_MAX_INFO_BITS = 128;
    // The kernel stores results for warps 0 and 1 only.
    enum { kNumCheckedWarps = 2 };
    const size_t resultBytes = kNumCheckedWarps * sizeof(uint32_t);

    printf("========== 测试配置 ==========\n");
    printf("nInfoBits = %u, N_MAX_INFO_BITS = %u\n", nInfoBits, N_MAX_INFO_BITS);
    printf("Block dim: 32x16x1 (512 threads)\n");
    printf("Grid dim: 1x1x1\n\n");

    // Expected bitmask of every checked warp, derived from the same predicate
    // the kernel evaluates.
    uint32_t expected_masks[kNumCheckedWarps];
    for (uint32_t w = 0; w < kNumCheckedWarps; ++w) {
        expected_masks[w] = expected_ballot_for_warp(w, nInfoBits, N_MAX_INFO_BITS);
    }

    uint32_t* d_original = NULL;
    uint32_t* d_workaround = NULL;
    uint32_t* h_original = (uint32_t*)malloc(resultBytes);
    uint32_t* h_workaround = (uint32_t*)malloc(resultBytes);

    int exitCode = -1;
    // Single-pass block: any failure breaks out to the common cleanup below so
    // that no host or device buffer is leaked on an error path.
    do {
        if (h_original == NULL || h_workaround == NULL) {
            printf("Host allocation failed\n");
            break;
        }

        // Device buffers, zero-initialized so that an unwritten slot reads as 0.
        if (!mc_check(mcMalloc(&d_original, resultBytes), "mcMalloc(d_original)")) break;
        if (!mc_check(mcMalloc(&d_workaround, resultBytes), "mcMalloc(d_workaround)")) break;
        if (!mc_check(mcMemset(d_original, 0, resultBytes), "mcMemset(d_original)")) break;
        if (!mc_check(mcMemset(d_workaround, 0, resultBytes), "mcMemset(d_workaround)")) break;

        // Launch the kernel.
        printf("========== 启动Kernel ==========\n");
        dim3 block(32, 16, 1);  // 32x16x1 = 512 threads
        dim3 grid(1, 1, 1);
        test_ballot_bug<<<grid, block>>>(d_original, d_workaround, nInfoBits);

        // Launch-configuration errors are reported separately from execution
        // errors, so check both.
        if (!mc_check(mcGetLastError(), "test_ballot_bug launch")) break;
        if (!mc_check(mcDeviceSynchronize(), "mcDeviceSynchronize")) break;

        // Copy the results back to the host.
        if (!mc_check(mcMemcpy(h_original, d_original, resultBytes, mcMemcpyDeviceToHost),
                      "mcMemcpy(h_original)")) break;
        if (!mc_check(mcMemcpy(h_workaround, d_workaround, resultBytes, mcMemcpyDeviceToHost),
                      "mcMemcpy(h_workaround)")) break;

        // Verify each warp against its own expected mask.
        const bool workaroundOk =
            verify_results_per_warp(h_original, h_workaround, expected_masks, kNumCheckedWarps);
        exitCode = workaroundOk ? 0 : 1;
    } while (0);

    // Cleanup (runs on every path).
    if (d_original != NULL) mcFree(d_original);
    if (d_workaround != NULL) mcFree(d_workaround);
    free(h_original);
    free(h_workaround);

    if (exitCode != -1) {
        printf("\n========== 测试完成 ==========\n");
    }
    return exitCode;
}
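/*
 * Observed behaviour:
 *  - __ballot_sync works correctly in warps 0, 2 and 3.
 *  - Only a specific lane range (27-31) of warp 1 is affected.
 *  - For lanes 27-31 of warp 1 (threads 59-63), __ballot_sync returns
 *    0x00000000 even though the predicate is true.
 */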