huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Failed wrapping Bert2Bert to AutoModelForSeq2SeqLMWithValueHead #1250

Closed: henry-tujia closed this issue 8 months ago

henry-tujia commented 9 months ago

I observed that this wrapper checks for the existence of the lm_head by matching keywords against the model's module names. However, for an Encoder-Decoder model like Bert2Bert, the decoder's prediction head (BertOnlyMLMHead, registered under the module name cls) never matches those keywords, so the final linear layer is not detected.
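
For reference, a minimal sketch of the check I mean, paraphrasing trl/models/modeling_value_head.py (the exact keyword list may differ across TRL versions):

# Rough paraphrase of the head check in TRL's value-head wrapper.
# `pretrained_model` stands in for the wrapped EncoderDecoderModel.
lm_head_namings = ["lm_head", "embed_out"]
has_lm_head = any(
    any(keyword in name for keyword in lm_head_namings)
    for name, _ in pretrained_model.named_modules()
)
# For Bert2Bert the head module is named "decoder.cls", so has_lm_head
# stays False and the ValueError below is raised.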

Error Message

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File "/root/.vscode-server/extensions/ms-python.python-2023.22.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/root/.vscode-server/extensions/ms-python.python-2023.22.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/root/.vscode-server/extensions/ms-python.python-2023.22.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/root/.vscode-server/extensions/ms-python.python-2023.22.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.vscode-server/extensions/ms-python.python-2023.22.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/root/.vscode-server/extensions/ms-python.python-2023.22.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/nfs/haotan/RL/seq2seq_ppo.py", line 138, in <module>
    model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/trl/models/modeling_base.py", line 274, in from_pretrained
    model = cls(pretrained_model, **multi_adapter_args, **trl_model_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/trl/models/modeling_value_head.py", line 293, in __init__
    raise ValueError("The model does not have a language model head, please use a model that has one.")
ValueError: The model does not have a language model head, please use a model that has one.

Bert2Bert Model Code


from transformers import (
    BertConfig,
    BertLMHeadModel,
    BertModel,
    EncoderDecoderModel,
)
from trl import AutoModelForSeq2SeqLMWithValueHead

vocabsize = 64000
max_length = 1024
batch_size = 32

encoder_config = BertConfig(
    vocab_size=vocabsize,
    max_position_embeddings=max_length + 64,  # this should be some large value
    num_attention_heads=8,
    num_hidden_layers=6,
    hidden_size=512,
    type_vocab_size=1,
)

encoder = BertModel(config=encoder_config)

decoder_config = BertConfig(
    vocab_size=vocabsize,
    max_position_embeddings=max_length + 64,  # this should be some large value
    num_attention_heads=8,
    num_hidden_layers=6,
    hidden_size=512,
    type_vocab_size=1,
    is_decoder=True,           # very important: run the decoder causally
    add_cross_attention=True,  # very important: cross-attend to the encoder
)

decoder = BertLMHeadModel(config=decoder_config)

model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

# `tokenizer` is defined earlier in my script (a BERT tokenizer with
# bos/eos/pad/sep tokens set).
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.sep_token_id = tokenizer.sep_token_id
model.config.max_length = max_length

# This line raises the ValueError above:
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model)

Bert2Bert Model Structure

EncoderDecoderModel(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(64000, 512, padding_idx=0)
      (position_embeddings): Embedding(1088, 512)
      (token_type_embeddings): Embedding(1, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=512, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=512, bias=True)
            (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=512, out_features=512, bias=True)
      (activation): Tanh()
    )
  )
  (decoder): BertLMHeadModel(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(64000, 512, padding_idx=0)
        (position_embeddings): Embedding(1088, 512)
        (token_type_embeddings): Embedding(1, 512)
        (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-5): 6 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=512, out_features=512, bias=True)
                (key): Linear(in_features=512, out_features=512, bias=True)
                (value): Linear(in_features=512, out_features=512, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=512, out_features=512, bias=True)
                (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (crossattention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=512, out_features=512, bias=True)
                (key): Linear(in_features=512, out_features=512, bias=True)
                (value): Linear(in_features=512, out_features=512, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=512, out_features=512, bias=True)
                (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (intermediate): BertIntermediate(
              (dense): Linear(in_features=512, out_features=3072, bias=True)
              (intermediate_act_fn): GELUActivation()
            )
            (output): BertOutput(
              (dense): Linear(in_features=3072, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
        )
      )
    )
    (cls): BertOnlyMLMHead(
      (predictions): BertLMPredictionHead(
        (transform): BertPredictionHeadTransform(
          (dense): Linear(in_features=512, out_features=512, bias=True)
          (transform_act_fn): GELUActivation()
          (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
        )
        (decoder): Linear(in_features=512, out_features=64000, bias=True)
      )
    )
  )
)
lvwerra commented 9 months ago

Looks like this should work if you add the name to the lm_head_namings attribute of the AutoModelForSeq2SeqLMWithValueHead.
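
A minimal sketch of that suggestion, assuming lm_head_namings is a plain class-level list of name keywords; the "cls" entry below is inferred from the decoder head's module name in the structure above, not confirmed against TRL:

from trl import AutoModelForSeq2SeqLMWithValueHead

# Extend the keyword list before wrapping, so the name check also
# matches the Bert decoder's "cls" (BertOnlyMLMHead) prediction head.
# "cls" is an assumption based on the printed module names.
AutoModelForSeq2SeqLMWithValueHead.lm_head_namings = list(
    AutoModelForSeq2SeqLMWithValueHead.lm_head_namings
) + ["cls"]

model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model)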

henry-tujia commented 9 months ago

I'll give it a try and share the outcome. Thanks for the response!

github-actions[bot] commented 8 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.