aehrc / cxrmate

CXRMate: Longitudinal Data and a Semantic Similarity Reward for Chest X-Ray Report Generation
https://huggingface.co/aehrc/cxrmate
Apache License 2.0

error when running cxrmate.ipynb #3

Open AceMcAwesome77 opened 10 months ago

AceMcAwesome77 commented 10 months ago

Hi, I am trying to run through the example code in cxrmate.ipynb. When I get to this line:

outputs = encoder_decoder.generate(
    pixel_values=images.to(device),
    decoder_input_ids=prompt['input_ids'],
    special_token_ids=[
        tokenizer.additional_special_tokens_ids[
            tokenizer.additional_special_tokens.index('[PMT-SEP]')
        ],
        tokenizer.bos_token_id,
        tokenizer.sep_token_id,
    ],
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    mask_token_id=tokenizer.pad_token_id,
    return_dict_in_generate=True,
    use_cache=True,
    max_length=256 + prompt['input_ids'].shape[1],
    num_beams=4,
)

I get the following error:

Traceback (most recent call last):

  Cell In[11], line 1
    outputs = encoder_decoder.generate(

  File ~\AppData\Local\anaconda3\lib\site-packages\torch\utils\_contextlib.py:115 in decorate_context
    return func(*args, **kwargs)

  File ~\AppData\Local\anaconda3\lib\site-packages\transformers\generation\utils.py:1593 in generate
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(

  File ~\AppData\Local\anaconda3\lib\site-packages\transformers\generation\utils.py:742 in _prepare_encoder_decoder_kwargs_for_generation
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)

  File ~\AppData\Local\anaconda3\lib\site-packages\torch\nn\modules\module.py:1501 in _call_impl
    return forward_call(*args, **kwargs)

  File ~\.cache\huggingface\modules\transformers_modules\aehrc\cxrmate\1f014633b98564f21316b32e167b5796381690d8\modelling_longitudinal.py:91 in forward
    return ModelOutputWithProjectionEmbedding(

  File ~\AppData\Local\anaconda3\lib\site-packages\transformers\utils\generic.py:325 in __init__
    raise TypeError(

TypeError: transformers_modules.aehrc.cxrmate.1f014633b98564f21316b32e167b5796381690d8.modelling_longitudinal.ModelOutputWithProjectionEmbedding is not a dataclasss. This is a subclass of ModelOutput and so must use the @dataclass decorator.

anicolson commented 10 months ago

Hi AceMcAwesome77, this @dataclass decorator issue has been fixed.
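
For anyone who hits the same TypeError elsewhere, here is a minimal, stdlib-only sketch of the failure mode the traceback describes. It is not the actual change made in modelling_longitudinal.py. `BaseOutput` is a hypothetical stand-in for transformers' `ModelOutput`, written on the assumption that the real check behaves roughly like this, and the class and field names below are illustrative only.

```python
from dataclasses import dataclass, is_dataclass
from typing import Optional


class BaseOutput:
    """Hypothetical stand-in for transformers' ModelOutput (not the real class)."""

    def __init__(self, *args, **kwargs):
        # Assumed behaviour: a subclass that was never decorated with
        # @dataclass falls through to this __init__ and is rejected.
        if type(self) is not BaseOutput and not is_dataclass(self):
            raise TypeError(
                f"{type(self).__name__} is not a dataclass; "
                "subclasses must use the @dataclass decorator."
            )


class UndecoratedOutput(BaseOutput):
    # Missing @dataclass: instantiating this class raises TypeError.
    last_hidden_state: Optional[list] = None


@dataclass
class DecoratedOutput(BaseOutput):
    # With @dataclass, a generated __init__ replaces the inherited one,
    # and is_dataclass() reports True for instances.
    last_hidden_state: Optional[list] = None
    projected_embedding: Optional[list] = None  # hypothetical field name
```

In this sketch, `UndecoratedOutput(last_hidden_state=[1.0])` raises a TypeError, while `DecoratedOutput(last_hidden_state=[1.0])` constructs normally. If the error persists after a fix like this has been published, it may be worth checking whether an older copy of the remote code is still being loaded from the local Hugging Face cache.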