import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

# Check whether a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seeds so the results are reproducible
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Dimension parameters taken from the debugging session
bsz = 8                # batch size
num_heads = 16         # number of attention heads
tgt_len = 1            # target sequence length
squeeze_head_dim = 24  # per-head dimension of query/key
head_dim = 48          # per-head dimension of value

dtype = torch.float16  # with float16 the call below fails with:
                       # "Head dimension of query/key must greater or equal to head dimension in query"
# dtype = torch.float32

# Create random inputs with the required shapes and move them to the target device
# query_states shape: [8, 16, 1, 24]
query_states = torch.randn(bsz, num_heads, tgt_len, squeeze_head_dim, device=device, dtype=dtype)
# key_states shape: [8, 16, 1, 24]
key_states = torch.randn(bsz, num_heads, tgt_len, squeeze_head_dim, device=device, dtype=dtype)
# value_states shape: [8, 16, 1, 48]
value_states = torch.randn(bsz, num_heads, tgt_len, head_dim, device=device, dtype=dtype)

# Remaining parameters
attention_mask = None  # no attention mask
dropout = 0.0          # dropout probability
is_causal = False      # non-causal attention

# Print input shape information
print(f"query_states shape: {query_states.shape}, device: {query_states.device}")
print(f"key_states shape: {key_states.shape}, device: {key_states.device}")
print(f"value_states shape: {value_states.shape}, device: {value_states.device}")

# Note: these flags only select the backend used by F.scaled_dot_product_attention;
# they have no effect on flash_attn_func, which always runs the FlashAttention kernel.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False,          # disable FlashAttention
    enable_mem_efficient=False,  # disable memory-efficient attention
    enable_math=True,            # force the naive (math) implementation
):
    # flash_attn_func expects (batch, seqlen, nheads, headdim) tensors and requires
    # query, key and value to share the same head dimension, so for the shapes used
    # here it raises the error quoted next to the dtype above.
    attn_output = flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout,
        softmax_scale=None,
        causal=is_causal
    )

# Print information about the output
print("\nAttention computed successfully!")
print(f"Output shape: {attn_output.shape}, device: {attn_output.device}")
print("First head of the first two samples:")
print(attn_output[:2, 0, :, :].cpu())  # move back to CPU for printing
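
Because flash_attn_func requires query, key and value to share one head dimension, it rejects the 24/24/48 shapes above regardless of the sdp_kernel flags. As a minimal fallback sketch (an assumption, not part of the original script), the same tensors can instead be run through F.scaled_dot_product_attention, which the flags above actually control and whose math backend accepts a value head dimension that differs from the query/key head dimension:

with torch.backends.cuda.sdp_kernel(
    enable_flash=False,          # disable FlashAttention
    enable_mem_efficient=False,  # disable memory-efficient attention
    enable_math=True,            # force the naive (math) implementation
):
    # SDPA expects (batch, heads, seq, dim), which matches the shapes above;
    # the output head dimension follows value, i.e. [8, 16, 1, 48]
    sdpa_output = F.scaled_dot_product_attention(
        query_states, key_states, value_states,
        attn_mask=attention_mask, dropout_p=dropout, is_causal=is_causal,
    )
print(f"SDPA fallback output shape: {sdpa_output.shape}")

If the FlashAttention kernel itself is needed, another workaround (again an assumption, not taken from the original script) is to zero-pad query and key from 24 to 48 channels so all three tensors share head_dim=48: zero padding leaves the QK^T scores unchanged, but softmax_scale must then be passed explicitly as 24 ** -0.5, because flash_attn would otherwise scale by the padded dimension 48.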