OpenLLMAI / OpenRLHF

An Easy-to-use, Scalable and High-performance RLHF Framework (70B+ PPO Full Tuning & Iterative DPO & LoRA & Mixtral)
https://openrlhf.readthedocs.io/
Apache License 2.0

question about support matrix #193

Closed paulcx closed 4 months ago

paulcx commented 5 months ago

Just for clarification: what does "34B Full Tuning with 4 A100" mean in that table? Does it cover PPO, DPO, or both? Have you tested training a 34B Llama model with DPO on 8*A100?

hijkzzz commented 5 months ago

The Support Matrix currently refers to RLHF. I think 34B Llama DPO on 8*A100 with ZeRO-3 + reference policy offload is OK.
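As a rough sanity check on that estimate, the sketch below counts only model-state memory for DPO under ZeRO-3. It assumes the usual mixed-precision Adam accounting (2 bytes per parameter for bf16 weights, 2 for gradients, 12 for fp32 master weights plus the two Adam moments). The 34e9 parameter count, the helper name, and the even split across 8 GPUs are illustrative assumptions; activations, communication buffers, and fragmentation are ignored.

```python
GIB = 1024 ** 3

def zero3_model_state_gib(num_params, num_gpus, adam_offload=True, ref_offload=True):
    """Rough per-GPU model-state memory (GiB) for DPO under ZeRO-3.

    Hypothetical helper: ignores activations, buffers and fragmentation.
    """
    policy_weights = 2 * num_params  # bf16 weights, sharded across GPUs
    policy_grads = 2 * num_params    # bf16 gradients, sharded across GPUs
    # fp32 master weights + Adam first and second moments, unless kept on the CPU.
    optimizer = 0 if adam_offload else 12 * num_params
    # Frozen bf16 reference policy, unless its parameters are kept on the CPU.
    reference = 0 if ref_offload else 2 * num_params
    total_bytes = policy_weights + policy_grads + optimizer + reference
    return total_bytes / num_gpus / GIB

with_offload = zero3_model_state_gib(34e9, 8)
without_offload = zero3_model_state_gib(34e9, 8, adam_offload=False, ref_offload=False)
print(f"with offload:    {with_offload:.1f} GiB per GPU")
print(f"without offload: {without_offload:.1f} GiB per GPU")
```

Under these assumptions the offloaded configuration needs about 16 GiB per GPU for model states, against about 71 GiB without any offload, which would leave very little of an 80 GB card for activations.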

paulcx commented 5 months ago

Is there an example script that implements reference model offload?

hijkzzz commented 5 months ago

See here: https://github.com/OpenLLMAI/OpenRLHF/blob/76d9b65a2b56266991119cb7c197fe16345a40aa/examples/train_dpo.py#L155

paulcx commented 5 months ago

thanks!

paulcx commented 4 months ago

I made several preliminary attempts (34B Llama DPO on 8*A100 80G with ZeRO-3 + reference policy offload), and each one ended in OOM. Are my parameters below misconfigured?

../train_dpo.py \
     --save_path ./output \
     --save_steps -1 \
     --logging_steps 1 \
     --eval_steps -1 \
     --train_batch_size 8 \
     --micro_train_batch_size 1 \
     --pretrain xxx \
     --bf16 \
     --max_epochs 1 \
     --max_len 2048 \
     --zero_stage 3 \
     --beta 0.1 \
     --learning_rate 5e-7 \
     --dataset xxx \
     --dataset_probs 0.72,0.08,0.12,0.08 \
     --flash_attn \
     --gradient_checkpointing \
     --adam_offload \
     --ref_offload

hijkzzz commented 4 months ago

What is your CPU memory size?

Could you try changing line 96 of train_dpo.py to:

    # strategy prepare
    (ref_model, (model, optim, scheduler)) = strategy.prepare(ref_model, (model, optim, scheduler))

paulcx commented 4 months ago

After modifying train_dpo.py, the result is still the same. I have 1 TB of CPU memory.

hijkzzz commented 4 months ago

It works well on my side with this script:

set -x 

read -r -d '' training_commands <<EOF
../train_dpo.py \
     --save_path ./ckpt/13b_llama_dpo \
     --save_steps -1 \
     --logging_steps 1 \
     --eval_steps -1 \
     --train_batch_size 128 \
     --micro_train_batch_size 1 \
     --pretrain codellama/CodeLlama-34b-Instruct-hf \
     --bf16 \
     --max_samples 1024 \
     --max_epochs 1 \
     --max_len 2048 \
     --zero_stage 3 \
     --beta 0.1 \
     --learning_rate 5e-7 \
     --dataset Anthropic/hh-rlhf,tasksource/oasst1_pairwise_rlhf_reward,lmsys/chatbot_arena_conversations,openai/webgpt_comparisons \
     --dataset_probs 0.72,0.08,0.12,0.08 \
     --flash_attn \
     --gradient_checkpointing \
     --adam_offload \
     --ref_offload
EOF
     # --wandb [WANDB_TOKENS]
     # --ipo [for IPO]
     # --label_smoothing 0.1 [for cDPO]

if [[ ${1} != "slurm" ]]; then
    export PATH=$HOME/.local/bin/:$PATH
    deepspeed $training_commands
fi

Train epoch:   0%|          | 0/1 [00:00<?, ?it/s]
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Train step of epoch 0:   0%|          | 0/128 [00:00<?, ?it/s]
Train step of epoch 0:   2%|▊         | 2/128 [00:56<58:22, 27.80s/it, preference_loss=0.693, chosen_reward=0, reject_reward=0, acc_mean=0, loss_mean=0.132]
Train step of epoch 0:   2%|█▏        | 3/128 [01:21<55:08, 26.47s/it, preference_loss=0.693, chosen_reward=0, reject_reward=0, acc_mean=0, loss_mean=0.188]

GPU status


|   7  NVIDIA H100 PCIe               On  | 00000000:E1:00.0 Off |                    0 |
| N/A   44C    P0             112W / 350W |  40417MiB / 81559MiB |    100%      Default |
|                                         |                      |             Disabled

paulcx commented 4 months ago

After a few attempts, in particular uninstalling transformers==4.38.2 and reinstalling 4.37.2, there seems to be some progress. It now fails with a new error: "TypeError: LlamaRotaryEmbedding.forward() missing 1 required positional argument: 'position_ids'".

I'm wondering whether this is related to the new fix:

# https://github.com/OpenLLMAI/OpenRLHF/issues/217
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
output = self.model(sequences, attention_mask=attention_mask, position_ids=position_ids)
log_probs = log_probs_from_logits(output["logits"][:, :-1, :], sequences[:, 1:])

Train epoch:   0%|          | 0/1 [00:00<?, ?it/s]
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
Train step of epoch 0:   0%|          | 0/128 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/OpenRLHF/examples/scripts/../train_dpo.py", line 188, in <module>
    train(args)
  File "/home/OpenRLHF/examples/scripts/../train_dpo.py", line 121, in train
    trainer.fit(args)
  File "/home/OpenRLHF/openrlhf/trainer/dpo_trainer.py", line 117, in fit
    chosen_logps, rejected_logps, aux_loss = self.concatenated_forward(
  File "/home/OpenRLHF/openrlhf/trainer/dpo_trainer.py", line 232, in concatenated_forward
    output = model(input_ids, attention_mask=att_masks, return_output=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/OpenRLHF/openrlhf/models/actor.py", line 181, in forward
    output = self.model(sequences, attention_mask=attention_mask, position_ids=position_ids)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 1852, in forward
    loss = self.module(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1560, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1560, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1060, in forward
    layer_outputs = self._gradient_checkpointing_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 488, in checkpoint
    ret = function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1560, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 798, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1560, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 508, in forward
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1560, in _call_impl
    result = forward_call(*args, **kwargs)
TypeError: LlamaRotaryEmbedding.forward() missing 1 required positional argument: 'position_ids'

Versions: transformers==4.37.2 (OOM with 4.38.1 or 4.38.2), deepspeed==0.13.4, flash-attn==2.4.2
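To see what the two `position_ids` lines in the snippet quoted above compute, here is a pure-Python sketch of the same arithmetic on nested lists (the original operates on PyTorch tensors). The function name and the example mask are made up for illustration.

```python
from itertools import accumulate

def position_ids_from_mask(attention_mask):
    """Mirror `mask.long().cumsum(-1) - 1` followed by `masked_fill_(mask == 0, 1)`."""
    result = []
    for row in attention_mask:
        # Cumulative sum along the sequence, minus one.
        running = [total - 1 for total in accumulate(row)]
        # Padded positions (mask == 0) receive the filler value 1.
        result.append([pos if keep else 1 for pos, keep in zip(running, row)])
    return result

# One left-padded sequence and one unpadded sequence.
mask = [[0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1]]
print(position_ids_from_mask(mask))  # [[1, 1, 0, 1, 2], [0, 1, 2, 3, 4]]
```

The real tokens are numbered from 0 regardless of how much left padding precedes them, which is the purpose of deriving the positions from the mask.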

hijkzzz commented 4 months ago

I used 4.38.1. Please use the NGC container.

paulcx commented 4 months ago

I did; my image is based on nvcr.io/nvidia/pytorch:23.12-py3. transformers 4.38.1 triggers the OOM at the very beginning of the pipeline.

[2024-03-03 01:31:51,663] [INFO] [partition_parameters.py:343:__exit__] finished initializing model - num_params = 542, num_elems = 33.93B
Traceback (most recent call last):
  File "/home/OpenRLHF/examples/scripts/../train_dpo.py", line 188, in <module>
    train(args)
  File "/home/OpenRLHF/examples/scripts/../train_dpo.py", line 23, in train
    model = Actor(
  File "/home/OpenRLHF/openrlhf/models/actor.py", line 71, in __init__
    self.model = AutoModelForCausalLM.from_pretrained(
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py", line 561, in from_pretrained
    return model_class.from_pretrained(
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py", line 3375, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/partition_parameters.py", line 503, in wrapper
    f(module, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1095, in __init__
    self.model = LlamaModel(config)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/partition_parameters.py", line 503, in wrapper
    f(module, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 922, in __init__
    causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/partition_parameters.py", line 238, in wrapped_fn
    tensor: Tensor = fn(*args, **kwargs)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 298.02 GiB. GPU 0 has a total capacity of 79.15 GiB of which 64.17 GiB is free. Process 2365995 has 14.98 GiB memory in use. Of the allocated memory 13.70 GiB is allocated by PyTorch, and 637.91 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
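The size of the failed allocation can be checked by hand. The traceback shows `torch.full` building a square mask of side `config.max_position_embeddings` with an integer fill value, which PyTorch stores as int64 (8 bytes per element). The sketch below assumes a context length of 200,000 positions. That value is a guess, since the model config is not shown in this thread, but it reproduces the reported 298.02 GiB.

```python
GIB = 1024 ** 3

def square_mask_gib(max_position_embeddings, bytes_per_element=8):
    """Size in GiB of a dense (n, n) mask; 8 bytes per element assumes int64."""
    return max_position_embeddings ** 2 * bytes_per_element / GIB

# 200_000 is an assumed context length, not a value taken from the thread;
# the two smaller values are only there for comparison.
for n in (4_096, 16_384, 200_000):
    print(f"max_position_embeddings={n:>7}: {square_mask_gib(n):8.2f} GiB")
```

With a few thousand positions the mask is negligible, but it grows quadratically, so a long-context checkpoint could exceed GPU memory during model construction, before any batch is processed. That may be why the OOM depends on the checkpoint rather than on the batch settings.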

hijkzzz commented 4 months ago

I have uploaded Dockerfiles; please follow the new README.md to install. There is no OOM issue on my side.

Alternatively, with v4.37 you could disable our RoPE hack (just search for the function replace_rope... and comment it out).

paulcx commented 4 months ago

I finally got it to work after disabling replace_rope. Does disabling it have any other side effects?

Update: I tried the new Dockerfile and it fails with the same OOM as before. The transformers version inside the container is 4.38.2.

hijkzzz commented 4 months ago

Just see the issue: https://github.com/OpenLLMAI/OpenRLHF/issues/191

It is very strange that you get OOM on your machine; I cannot reproduce the OOM with v4.38.2.

paulcx commented 4 months ago

It's weird. I'm going to fix the RoPE replacement issue next week. Any idea about "LlamaRotaryEmbedding.forward() missing 1 required positional argument: 'position_ids'"? I guess we need to override the attention function as well?
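The error message itself points to a signature mismatch: the attention code in the traceback calls `self.rotary_emb(value_states, seq_len=kv_seq_len)`, while the `forward` that ends up being invoked requires `position_ids`. The toy classes below are stand-ins written for this illustration, not the transformers implementation. They reproduce the same kind of `TypeError` and show one hypothetical way a call site could adapt by inspecting the signature.

```python
import inspect

class OldStyleRotary:
    """Toy stand-in: forward() takes only an optional seq_len."""
    def forward(self, x, seq_len=None):
        return ("cos", "sin")

class NewStyleRotary:
    """Toy stand-in: forward() additionally requires position_ids."""
    def forward(self, x, position_ids, seq_len=None):
        return ("cos", "sin")

def call_rotary(rotary, x, position_ids, seq_len):
    """Pass position_ids only if the given forward() declares it."""
    params = inspect.signature(rotary.forward).parameters
    if "position_ids" in params:
        return rotary.forward(x, position_ids)
    return rotary.forward(x, seq_len=seq_len)

try:
    # Old-style call site against a new-style forward, as in the traceback.
    NewStyleRotary().forward("value_states", seq_len=5)
except TypeError as err:
    print(err)

# Both variants work once the call site matches the signature.
print(call_rotary(OldStyleRotary(), "value_states", [0, 1, 2], 3))
print(call_rotary(NewStyleRotary(), "value_states", [0, 1, 2], 3))
```

So either the embedding or the code that calls it has to change; mixing an embedding written for one signature with attention code written for the other is what produces this error.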

hijkzzz commented 4 months ago

I have removed this patch because the issue was fixed in transformers v4.38.2. Could you try deepspeed==0.13.2 with transformers==4.38.2?

See https://github.com/OpenLLMAI/OpenRLHF/commit/177f04203013b51e26b62236c8f4017b88d1dfde

paulcx commented 4 months ago

It does not work with deepspeed==0.13.2 and transformers==4.38.2 (only 4.37.2 works in my environment).

@hijkzzz What is your PyTorch version?

hijkzzz commented 4 months ago

I just use the Dockerfiles from OpenRLHF.

paulcx commented 4 months ago

According to https://github.com/OpenLLMAI/OpenRLHF/blob/d5915d8f0c5830e0d7baf9900ff3ea5914b42dbe/dockerfile/Dockerfile#L18

Some of the pip-installed library versions: vllm==0.3.2, torch==2.1.2+cu121, transformers==4.38.2, deepspeed (not installed)
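Because the behaviour in this thread changes with the installed versions, it can help to print them from inside the training environment itself. This is a small sketch using the standard-library `importlib.metadata`; the package list is simply the set of libraries mentioned in this thread, and the helper name is made up.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(package):
    """Return the installed version string, or 'not installed'."""
    try:
        return version(package)
    except PackageNotFoundError:
        return "not installed"

for name in ("torch", "transformers", "deepspeed", "flash-attn", "vllm"):
    print(f"{name}: {installed_version(name)}")
```

Running this inside the container, rather than reading the Dockerfile, shows what actually ended up installed after every build step.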

hijkzzz commented 4 months ago

also ./build_openrlhf.sh

paulcx commented 4 months ago

Finally, I fixed the "LlamaRotaryEmbedding.forward() missing 1 required positional argument: 'position_ids'" error by patching the transformers code. Now the replace_rope_embedding patch from #191 works with transformers==4.37.2.

I did this because, after multiple experiments, 4.37.2 is the only version that does not run into OOM in my setup.