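你好，我在沐曦 C500 上测试了几个 fla 版本（均为沐曦提供的版本），都遇到了截图中所示的错误。
mx-smi 的输出也见截图。报错发生在 Triton autotuner 编译 `chunk_scaled_dot_kkt_fwd_kernel` 时 metax 后端的 `make_mlir` 阶段（`RuntimeError: PassManager::run failed`），完整 traceback 如下：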
```
[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/openvla/vla-scripts/distill_train_stage2.py", line 1453, in <module>
[rank0]:     distill_train()
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/draccus/argparsing.py", line 228, in wrapper_inner
[rank0]:     response = fn(cfg, *args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/openvla/vla-scripts/distill_train_stage2.py", line 1235, in distill_train
[rank0]:     student_output = student_vla(
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2030, in forward
[rank0]:     loss = self.module(*inputs, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1793, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/openvla/prismatic/extern/hf/modeling_prismatic.py", line 404, in forward
[rank0]:     language_model_output = self.language_model(
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
[rank0]:     outputs = self.model(
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 988, in forward
[rank0]:     layer_outputs = self._gradient_checkpointing_func(
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 496, in checkpoint
[rank0]:     ret = function(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/openvla/prismatic/models/backbones/llm/llama_gated_delta.py", line 220, in forward
[rank0]:     attn_output, present_key_value = self.self_attn(
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/openvla/prismatic/models/backbones/llm/gated_delta/gated_delta_net.py", line 289, in forward
[rank0]:     o, new_recurrent_state = chunk_gated_delta_rule(
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/fla/ops/gated_delta_rule/chunk.py", line 313, in chunk_gated_delta_rule
[rank0]:     o, final_state = ChunkGatedDeltaRuleFunction.apply(
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/fla/utils.py", line 164, in wrapper
[rank0]:     return fn(*contiguous_args, **contiguous_kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 503, in decorate_fwd
[rank0]:     return fwd(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/fla/ops/gated_delta_rule/chunk.py", line 174, in forward
[rank0]:     g, o, A, final_state = chunk_gated_delta_rule_fwd(
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/fla/ops/gated_delta_rule/chunk.py", line 31, in chunk_gated_delta_rule_fwd
[rank0]:     A = chunk_scaled_dot_kkt_fwd(
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/fla/ops/common/chunk_scaled_dot_kkt.py", line 114, in chunk_scaled_dot_kkt_fwd
[rank0]:     chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in <lambda>
[rank0]:     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 396, in run
[rank0]:     return self.fn.run(*args, **kwargs)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 212, in run
[rank0]:     timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 212, in <dictcomp>
[rank0]:     timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 137, in _bench
[rank0]:     return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/triton/testing.py", line 152, in do_bench
[rank0]:     fn()
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 118, in kernel_call
[rank0]:     self.fn.run(
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/triton/runtime/jit.py", line 662, in run
[rank0]:     kernel = self.compile(
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/triton/compiler/compiler.py", line 283, in compile
[rank0]:     next_module = compile_ir(module, metadata)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/triton/backends/metax/compiler.py", line 419, in <lambda>
[rank0]:     stages["mlir"] = lambda src, metadata: self.make_mlir(src, metadata, options, self.capability)
[rank0]:   File "/mnt/afs/lixiaoou/intern/wanyang/envs/openvla-triton-test-fla040/lib/python3.10/site-packages/triton/backends/metax/compiler.py", line 328, in make_mlir
[rank0]:     pm.run(mod)
[rank0]: RuntimeError: PassManager::run failed
```