The customer also ran this test on an NVIDIA GPU, where this error does not occur.
Below is the console output with the try/except removed. The attachment contains the test code and a screenshot of the error. The test container is sglang.
python test4.py
Using device: cuda
query_states shape: torch.Size([8, 16, 1, 24]), device: cuda:0
key_states shape: torch.Size([8, 16, 1, 24]), device: cuda:0
value_states shape: torch.Size([8, 16, 1, 48]), device: cuda:0
/opt/conda/lib/python3.10/contextlib.py:103: FutureWarning: torch.backends.cuda.sdp_kernel()
is deprecated. In the future, this context manager will be removed. Please see torch.nn.attention.sdpa_kernel()
for the new context manager, with updated signature.
self.gen = func(*args, **kwds)
Traceback (most recent call last):
File "/data/lhz/BD/test1.py", line 51, in <module>
attn_output = flash_attn_func(
File "/opt/conda/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 1054, in flash_attn_func
return FlashAttnFunc.apply(
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
return super().apply(*args, **kwargs)  # type: ignore[misc]
File "/opt/conda/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 704, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state, attn_mask = _flash_attn_forward(
File "/opt/conda/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 110, in _flash_attn_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state, attn_mask = flash_attn_cuda.fwd(
RuntimeError: Head dimension of query/key must greater or equal to head dimension in query
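For reference, here is a minimal sketch of the failing call, reconstructed only from the shapes printed above. The actual test4.py is in the attachment; the dtype, the causal flag, and the variable names are assumptions, not the customer's exact code. The point it illustrates is that the value head dimension (48) is larger than the query/key head dimension (24), which this build's flash_attn_cuda.fwd rejects with the RuntimeError shown above.

import torch
from flash_attn import flash_attn_func

device = "cuda"
dtype = torch.float16  # assumption; fp16 or bf16 is required by flash_attn

# flash_attn_func expects (batch, seqlen, nheads, headdim)
query_states = torch.randn(8, 16, 1, 24, dtype=dtype, device=device)
key_states   = torch.randn(8, 16, 1, 24, dtype=dtype, device=device)
value_states = torch.randn(8, 16, 1, 48, dtype=dtype, device=device)

# Q/K head_dim = 24, V head_dim = 48: on this container the call below raises
# the RuntimeError from the traceback, while the customer reports it succeeds
# on NVIDIA GPUs.
attn_output = flash_attn_func(query_states, key_states, value_states)
print(attn_output.shape)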