OpenLLMAI / OpenRLHF

An Easy-to-use, Scalable and High-performance RLHF Framework (70B+ PPO Full Tuning & Iterative DPO & LoRA & Mixtral)
https://openrlhf.readthedocs.io/
Apache License 2.0
1.72k stars · 161 forks

this repo's hack of rope embedding accepts different input than transformers #252

Closed · babu111 closed this issue 3 months ago

babu111 commented 3 months ago
[screenshot: transformers RoPE implementation]

This is the transformers implementation of RoPE; notice that it takes a position_ids argument.

Your hack doesn't have this parameter, so when a caller passes position_ids positionally, it is fed into seq_len instead.

[screenshot: this repo's patched RoPE implementation]
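The mismatch can be sketched as follows (the signatures here are simplified stand-ins, not the actual transformers or OpenRLHF code):

```python
# Hypothetical, simplified RoPE forward signatures to illustrate the bug.

def hf_rope_forward(x, position_ids, seq_len=None):
    # transformers-style RoPE: position_ids is the second positional argument
    return ("hf", position_ids, seq_len)

def hacked_rope_forward(x, seq_len=None):
    # the patched RoPE drops the position_ids parameter entirely
    return ("hack", seq_len)

# A caller written against the transformers signature passes position_ids
# positionally; under the patched signature that tensor silently lands in
# seq_len, which is the mismatch reported in this issue.
position_ids = [0, 1, 2, 3]
result = hacked_rope_forward("x", position_ids)
assert result == ("hack", [0, 1, 2, 3])  # position_ids became seq_len
```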
babu111 commented 3 months ago

Also, using flash_attn==2.4.2 results in the following error:

ImportError: /share/miniconda3/envs/openrlhf/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE
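An undefined-symbol error like this usually means the prebuilt flash-attn wheel was compiled against a different torch ABI than the one installed. Demangling the symbol (a diagnostic step, not a fix) shows it is a core torch operator that the loader cannot find:

```shell
# Demangle the missing symbol; it resolves to at::_ops::zeros::call(...),
# i.e. a torch-internal symbol the flash-attn extension expects but the
# installed libtorch does not export under that ABI.
echo '_ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE' | c++filt
```

The usual remedy is to pick a flash-attn release built for your exact torch/CUDA combination, or to rebuild flash-attn from source against the installed torch (e.g. `pip install flash-attn --no-build-isolation`).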

hijkzzz commented 3 months ago

We have removed the RoPE hack because HF has fixed the RoPE accuracy bug for flash attention. Please see: https://github.com/OpenLLMAI/OpenRLHF/issues/251

babu111 commented 3 months ago

Thank you for your reply!