evo-design / evo

Biological foundation modeling from molecular to genome scale
Apache License 2.0
PEFT Prompt Tuning: forward() got an unexpected keyword argument 'inputs_embeds' #62

Open mikeleske opened 4 months ago

Has anyone successfully applied PEFT Prompt Tuning to Evo?

I am using the HF SFT Trainer with the following PEFT config:

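    from peft import PromptTuningConfig, TaskType, get_peft_model

    # `model` is the already-loaded base model (loading code not shown)
    peft_config = PromptTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=128,
        tokenizer_name_or_path='togethercomputer/evo-1-131k-base'
    )
    peft_model = get_peft_model(model, peft_config)
    peft_model.print_trainable_parameters()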

I get the following error:

    Traceback (most recent call last):
      File "sft.py", line 227, in <module>
        trainer.train()
      File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/trl/trainer/sft_trainer.py", line 361, in train
        output = super().train(*args, **kwargs)
      File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/trainer.py", line 1859, in train
        return inner_training_loop(
      File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
        tr_loss_step = self.training_step(model, inputs)
      File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/trainer.py", line 3138, in training_step
        loss = self.compute_loss(model, inputs)
      File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/trainer.py", line 3161, in compute_loss
        outputs = model(**inputs)
      File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/peft/peft_model.py", line 1177, in forward
        return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
      File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
    TypeError: forward() got an unexpected keyword argument 'inputs_embeds'
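
Reading the last PEFT frame, the wrapper calls the base model with `inputs_embeds=...`, so the error seems to mean that the base model's `forward()` has no such parameter. Below is a minimal sketch of how that could be checked, using only the standard library. `ToyBackbone` and `accepts_inputs_embeds` are hypothetical names made up for illustration, not part of Evo or PEFT, and the toy class only stands in for a model whose `forward()` takes token ids.

    ```python
    import inspect


    class ToyBackbone:
        """Hypothetical stand-in for a model whose forward() only takes token ids."""

        def forward(self, input_ids, labels=None):
            return {"logits": input_ids}

        def __call__(self, *args, **kwargs):
            # Mirrors how a module call dispatches to forward()
            return self.forward(*args, **kwargs)


    def accepts_inputs_embeds(model) -> bool:
        """Return True if model.forward can receive an `inputs_embeds` keyword."""
        params = inspect.signature(model.forward).parameters
        # A **kwargs catch-all would also accept the keyword
        has_var_kwargs = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        return "inputs_embeds" in params or has_var_kwargs


    toy = ToyBackbone()
    print(accepts_inputs_embeds(toy))  # False

    # Calling the way the PEFT frame in the traceback does raises a TypeError
    try:
        toy(inputs_embeds=[[0.0]])
    except TypeError as err:
        print(err)
    ```

Running the same check against the real base model (`accepts_inputs_embeds(model)`) should show whether its `forward()` signature is what Prompt Tuning trips over.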

Any pointers on what I am doing wrong would be greatly appreciated.