Closed · babu111 closed this issue 3 months ago
Also, using flash_attn==2.4.2 results in the following error:
ImportError: /share/miniconda3/envs/openrlhf/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE
We have removed the RoPE hack because Hugging Face has fixed the RoPE accuracy bug for flash attention; please see: https://github.com/OpenLLMAI/OpenRLHF/issues/251
thank you for your reply!
This is the transformers implementation of RoPE; notice that its forward method takes a position_ids input. Your hack does not have this parameter, so the position_ids tensor ends up being bound to seq_len instead.
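A minimal sketch of the mismatch described above, using hypothetical class names rather than the actual transformers or OpenRLHF code: when a monkey-patched rotary embedding keeps the old (x, seq_len) signature while the caller is written against the newer (x, position_ids) signature, the position_ids argument silently lands in the seq_len parameter.

```python
class NewRotaryEmbedding:
    # Newer transformers-style signature: the second positional
    # argument is position_ids.
    def forward(self, x, position_ids):
        return ("used position_ids", position_ids)


class OldHackRotaryEmbedding:
    # Old hacked signature: the second positional argument is seq_len.
    def forward(self, x, seq_len=None):
        return ("used seq_len", seq_len)


position_ids = [0, 1, 2, 3]


def attention_call(rope, x):
    # Caller written against the new signature: it passes position_ids
    # as the second positional argument.
    return rope.forward(x, position_ids)


new_out = attention_call(NewRotaryEmbedding(), x=None)
hack_out = attention_call(OldHackRotaryEmbedding(), x=None)

# With the hacked module, position_ids is silently bound to seq_len:
print(new_out)   # -> ('used position_ids', [0, 1, 2, 3])
print(hack_out)  # -> ('used seq_len', [0, 1, 2, 3])
```

Because Python binds positional arguments by position, not by name, no error is raised; the wrong parameter just receives the tensor, which is why the bug shows up as wrong results rather than a crash.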