turboLJY / Few-Shot-KG2Text

Source for the ACL 2021 Findings paper "Few-shot Knowledge Graph-to-Text Generation with Pretrained Language Models"

Error when running the model in test: cannot use inputs_embeds #10

Open Shj451148969 opened 3 years ago

Shj451148969 commented 3 years ago

transformers version 4.12.2

nodes = nodes.to(device)
student_embeddings = student(nodes, edges, types)

node_masks = node_masks.to(device)
generated_ids = plm.generate(input_ids=None, inputs_embeds=student_embeddings, attention_mask=node_masks, num_beams=4, max_length=config["max_seq_length"], early_stopping=True)

The error is:

Traceback (most recent call last):
  File "/home/amax/shj/tmp/pycharm_project_21/train.py", line 152, in <module>
    run_eval(read_configuration("./config.yaml"), 31)
  File "/home/amax/shj/tmp/pycharm_project_21/train.py", line 140, in run_eval
    generated, reference = run_eval_batch(config, batch, student, plm, device, tokenizer)
  File "/home/amax/shj/tmp/pycharm_project_21/run.py", line 81, in run_eval_batch
    generated_ids = plm.generate(input_ids=None, inputs_embeds=student_embeddings, attention_mask=node_masks, num_beams=4, max_length=config["max_seq_length"], early_stopping=True)
  File "/home/amax/anaconda3/envs/shj_dev/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/amax/anaconda3/envs/shj_dev/lib/python3.9/site-packages/transformers/generation_utils.py", line 913, in generate
    input_ids = self._prepare_decoder_input_ids_for_generation(
  File "/home/amax/anaconda3/envs/shj_dev/lib/python3.9/site-packages/transformers/generation_utils.py", line 424, in _prepare_decoder_input_ids_for_generation
    torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device) * decoder_start_token_id
AttributeError: 'NoneType' object has no attribute 'shape'

Shj451148969 commented 3 years ago

I checked the code in generation_utils.py. It seems that generate() builds the decoder start tokens from input_ids, not inputs_embeds, so if you set input_ids=None it throws this error.
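
To make the failure mode above concrete, here is a minimal, library-free sketch. It is not the transformers source: `FakeTensor`, `start_ids_from_input_ids`, and `start_ids_with_fallback` are hypothetical names that only imitate the shape lookup on the failing line of the traceback. The fallback to `inputs_embeds` is an assumed fix for illustration, not something the library is known to do in 4.12.2.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor; it only carries a shape."""
    shape: Tuple[int, ...]


def start_ids_from_input_ids(input_ids: Optional[FakeTensor],
                             start_token_id: int) -> List[List[int]]:
    # Imitates the failing line: the batch size is read from input_ids
    # unconditionally, so input_ids=None raises AttributeError.
    batch_size = input_ids.shape[0]
    return [[start_token_id] for _ in range(batch_size)]


def start_ids_with_fallback(input_ids: Optional[FakeTensor],
                            inputs_embeds: Optional[FakeTensor],
                            start_token_id: int) -> List[List[int]]:
    # Assumed fix (illustration only): take the batch size from
    # inputs_embeds when input_ids is missing.
    source = input_ids if input_ids is not None else inputs_embeds
    if source is None:
        raise ValueError("need either input_ids or inputs_embeds")
    return [[start_token_id] for _ in range(source.shape[0])]


if __name__ == "__main__":
    # Shape (batch, nodes, hidden) is an arbitrary example.
    embeds = FakeTensor(shape=(2, 5, 16))
    try:
        start_ids_from_input_ids(None, start_token_id=2)
    except AttributeError as err:
        print("reproduced:", err)
    print(start_ids_with_fallback(None, embeds, start_token_id=2))
```

The first call fails in the same way as the traceback because nothing checks for `None` before `.shape` is accessed. The second call succeeds because the batch dimension of the embeddings is enough to build one start token per example.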