CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
MIT License

MPT is not working #585

Open ouhenio opened 11 months ago

ouhenio commented 11 months ago

🐛 Describe the bug

When running the following code:

import trlx

trainer = trlx.train(
    "mosaicml/mpt-7b",
    samples=[
        ['Question: 1 + 2 Answer:', '3'],
        ['Question: Solve this equation: ∀n>0, s=2, sum(n ** -s). Answer:', '(pi ** 2)/ 6']
    ]
)

A ValueError is raised:

Traceback (most recent call last):
  File "--/rl-llm/train.py", line 14, in <module>
    trainer = trlx.train(
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/trlx.py", line 92, in train
    trainer = get_trainer(config.train.trainer)(
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/trainer/accelerate_sft_trainer.py", line 32, in __init__
    super().__init__(config, **kwargs)
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/trainer/accelerate_base_trainer.py", line 66, in __init__
    self.model = self.setup_model()
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/trainer/accelerate_base_trainer.py", line 161, in setup_model
    freeze_bottom_causal_layers(model.base_model, self.config.model.num_layers_unfrozen)
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/utils/modeling.py", line 24, in freeze_bottom_causal_layers
    hidden_layers = hf_get_decoder_blocks(model)
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/utils/modeling.py", line 148, in hf_get_decoder_blocks
    return findattr(model, hidden_layers_attrs)
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/utils/modeling.py", line 96, in findattr
    raise ValueError(f"Could not find an attribute from `{attrs}` in `{obj}`")
ValueError: Could not find an attribute from `('h', 'layers', 'model.layers', 'decoder.layers', 'transformer.h', 'transformer.blocks', 'model.decoder.layers', 'gpt_neox.layers', 'decoder.block')` in `MptModel(
  (wte): Embedding(50432, 4096)
  (blocks): ModuleList(
    (0-31): 32 x MptBlock(
      (norm_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (attn): MptAttention(
        (Wqkv): Linear(in_features=4096, out_features=12288, bias=False)
        (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (norm_2): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (ffn): MptMLP(
        (up_proj): Linear(in_features=4096, out_features=16384, bias=False)
        (act): GELU(approximate='none')
        (down_proj): Linear(in_features=16384, out_features=4096, bias=False)
      )
      (resid_attn_dropout): Dropout(p=0, inplace=False)
    )
  )
  (norm_f): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
)

I'm not sure what is going on, since #546 supposedly fixed it.
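
One possible reading of the traceback (an assumption, not checked against the trlx source): the candidate list in the error message contains `transformer.blocks` but no bare `blocks`, while the printed `MptModel` holds its layers directly under `blocks`. The sketch below mimics a dotted-attribute lookup to illustrate that mismatch. The helper names `resolve_dotted`, `find_first_attr`, and `lookup_fails` are hypothetical, and a `SimpleNamespace` stands in for the real model, so this is not trlx's implementation.

```python
from types import SimpleNamespace

# Candidate paths copied from the ValueError message in the traceback.
CANDIDATE_ATTRS = (
    "h", "layers", "model.layers", "decoder.layers", "transformer.h",
    "transformer.blocks", "model.decoder.layers", "gpt_neox.layers",
    "decoder.block",
)

_MISSING = object()  # sentinel, so that a legitimate None value is not mistaken for "absent"


def resolve_dotted(obj, path):
    """Follow a dotted attribute path; return _MISSING if any hop is absent."""
    current = obj
    for name in path.split("."):
        current = getattr(current, name, _MISSING)
        if current is _MISSING:
            return _MISSING
    return current


def find_first_attr(obj, paths):
    """Hypothetical stand-in for the lookup seen in the traceback."""
    for path in paths:
        value = resolve_dotted(obj, path)
        if value is not _MISSING:
            return value
    raise ValueError(f"Could not find an attribute from `{paths}` in `{obj}`")


def lookup_fails(obj, paths):
    """Return True if none of the candidate paths resolve on obj."""
    try:
        find_first_attr(obj, paths)
    except ValueError:
        return True
    return False


# Stand-in for the printed MptModel: the blocks sit directly on the object.
fake_mpt_base = SimpleNamespace(
    wte="embedding", blocks=["block0", "block1"], norm_f="layernorm"
)

# With the candidate list from the error message, nothing matches.
failed_with_current_list = lookup_fails(fake_mpt_base, CANDIDATE_ATTRS)

# Adding a bare "blocks" candidate (an assumption about a possible fix) resolves.
found_with_blocks = find_first_attr(fake_mpt_base, CANDIDATE_ATTRS + ("blocks",))
```

If this reading is right, the lookup would succeed on a wrapper that exposes `transformer.blocks` but fail on the inner `MptModel` that is actually passed in here.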

I tried installing trlx in two ways. First, directly from GitHub:

pip install -U git+https://github.com/CarperAI/trlx.git

Second, as an editable install from a local clone:

git clone https://github.com/CarperAI/trlx.git

cd trlx

pip install torch --extra-index-url https://download.pytorch.org/whl/cu118

pip install -e .

The error occurs with both installations.

Which trlX version are you using?

0.7.0

Additional system and package information

Linux