"""Minimal attention repro where query/key and value head dims differ.

Tensors throughout use the ``(batch, heads, seq_len, head_dim)`` layout that
``torch.nn.functional.scaled_dot_product_attention`` expects. When the
flash-attn kernel is usable the inputs are converted to its
``(batch, seq_len, heads, head_dim)`` layout and padded to a common head
dimension; otherwise PyTorch's own implementation is used.
"""

import math

import torch
import torch.nn.functional as F

try:
    # flash-attn is optional: without it (e.g. on a CPU-only machine) the
    # script falls back to PyTorch's scaled_dot_product_attention.
    from flash_attn import flash_attn_func, flash_attn_varlen_func  # noqa: F401
except ImportError:
    flash_attn_func = None
    flash_attn_varlen_func = None


def _force_math_sdp() -> None:
    """Restrict PyTorch SDPA to the math backend (fused backends disabled)."""
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)


def _pad_last_dim(tensor: torch.Tensor, size: int) -> torch.Tensor:
    """Zero-pad the last dimension of *tensor* up to *size* (no-op if wide enough)."""
    missing = size - tensor.size(-1)
    return F.pad(tensor, (0, missing)) if missing > 0 else tensor


def _can_use_flash(query: torch.Tensor) -> bool:
    """Return True when the flash-attn kernel can be used for *query*."""
    return (
        flash_attn_func is not None
        and query.is_cuda
        and query.dtype in (torch.float16, torch.bfloat16)
    )


def _flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float,
    is_causal: bool,
) -> torch.Tensor:
    """Run flash_attn_func on ``(batch, heads, seq_len, dim)`` inputs."""
    qk_dim = query.size(-1)
    v_dim = value.size(-1)
    # The kernel is given one shared head dim: zero-pad the narrower side.
    # Zero columns in q/k leave q @ k^T unchanged, and zero columns in v only
    # add zero output columns, which are sliced off below.
    width = max(qk_dim, v_dim)
    # flash-attn expects (batch, seq_len, heads, dim), hence the transposes.
    q = _pad_last_dim(query, width).transpose(1, 2)
    k = _pad_last_dim(key, width).transpose(1, 2)
    v = _pad_last_dim(value, width).transpose(1, 2)
    out = flash_attn_func(
        q,
        k,
        v,
        dropout_p=dropout_p,
        # Must be explicit: the default would be derived from the padded
        # width instead of the real query/key head dim.
        softmax_scale=1.0 / math.sqrt(qk_dim),
        causal=is_causal,
    )
    return out[..., :v_dim].transpose(1, 2)


def attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
) -> torch.Tensor:
    """Compute scaled dot-product attention.

    Args:
        query: Tensor of shape ``(batch, heads, tgt_len, qk_dim)``.
        key: Tensor of shape ``(batch, heads, src_len, qk_dim)``.
        value: Tensor of shape ``(batch, heads, src_len, v_dim)``; ``v_dim``
            may differ from ``qk_dim``.
        dropout_p: Attention dropout probability.
        is_causal: Whether to apply a causal mask.
            NOTE(review): flash-attn and PyTorch may align the causal mask
            differently when tgt_len != src_len -- confirm before relying
            on it.

    Returns:
        Tensor of shape ``(batch, heads, tgt_len, v_dim)``.
    """
    if _can_use_flash(query):
        return _flash_attention(query, key, value, dropout_p, is_causal)
    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=is_causal,
    )


def main() -> None:
    """Build random inputs, run attention once and print the result."""
    _force_math_sdp()

    # Check whether a GPU is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    # Fix the random seeds so results are reproducible.
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)

    # Dimensions taken from the original debugging session.
    bsz = 8  # batch size
    num_heads = 16  # number of attention heads
    tgt_len = 1  # target sequence length
    squeeze_head_dim = 24  # per-head dim of query/key
    head_dim = 48  # per-head dim of value

    # Half precision is what the flash kernel needs; use float32 on CPU,
    # where half-precision attention support depends on the PyTorch version.
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    # query/key: [8, 16, 1, 24], value: [8, 16, 1, 48]
    query_states = torch.randn(
        bsz, num_heads, tgt_len, squeeze_head_dim, device=device, dtype=dtype
    )
    key_states = torch.randn(
        bsz, num_heads, tgt_len, squeeze_head_dim, device=device, dtype=dtype
    )
    value_states = torch.randn(
        bsz, num_heads, tgt_len, head_dim, device=device, dtype=dtype
    )

    dropout = 0.0  # dropout probability
    is_causal = False  # non-causal attention

    print(f"query_states shape: {query_states.shape}, 设备: {query_states.device}")
    print(f"key_states shape: {key_states.shape}, 设备: {key_states.device}")
    print(f"value_states shape: {value_states.shape}, 设备: {value_states.device}")

    # Top-level boundary of a debug script: report any failure instead of
    # dumping a traceback.
    try:
        attn_output = attention(
            query_states,
            key_states,
            value_states,
            dropout_p=dropout,
            is_causal=is_causal,
        )
    except Exception as e:
        print(f"\n注意力计算出错: {str(e)}")
        if "CUDA" in str(e) and not torch.cuda.is_available():
            print("错误可能是由于没有可用的GPU导致的")
    else:
        print("\n注意力计算成功!")
        print(f"输出形状: {attn_output.shape}, 设备: {attn_output.device}")
        print("输出前两个样本的第一个头的结果:")
        # Output is (batch, heads, tgt_len, head_dim), so index 0 on dim 1
        # really is the first head. Move to CPU for printing.
        print(attn_output[:2, 0, :, :].cpu())


if __name__ == "__main__":
    main()