illuin-tech / colpali

The code used to train and run inference with the ColPali architecture.
https://huggingface.co/vidore
MIT License

Bfloat16 breaks numpy() in interpretability plots #3

Closed · pablohl closed this issue 3 months ago

pablohl commented 3 months ago

https://github.com/illuin-tech/colpali/blob/ad93aeb0d683a1ba621325834817441d9de5a780/colpali_engine/interpretability/gen_interpretability_plots.py#L86

`attention_map_normalized` is a torch bfloat16 tensor, which breaks in two different places in `plot_utils.py` (NumPy has no bfloat16 dtype, so calling `.numpy()` on the tensor raises a `TypeError`):

https://github.com/illuin-tech/colpali/blob/ad93aeb0d683a1ba621325834817441d9de5a780/colpali_engine/interpretability/plot_utils.py#L38

https://github.com/illuin-tech/colpali/blob/ad93aeb0d683a1ba621325834817441d9de5a780/colpali_engine/interpretability/plot_utils.py#L98
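
A minimal reproduction, assuming any recent PyTorch and NumPy (the tensor shape is illustrative, not taken from the repo):

```python
import torch

# NumPy has no bfloat16 dtype, so .numpy() on a bfloat16 tensor fails.
attention_map_normalized = torch.rand(32, 32, dtype=torch.bfloat16)
attention_map_normalized.numpy()
# TypeError: Got unsupported ScalarType BFloat16
```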

A solution is to convert the tensor to float first: `attention_map_normalized = attention_map_normalized.float()`
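
A sketch of that fix in isolation (the variable name comes from `gen_interpretability_plots.py`; the stand-in tensor below is illustrative):

```python
import torch

# Stand-in for the bfloat16 tensor produced upstream (shape illustrative).
attention_map_normalized = torch.rand(32, 32, dtype=torch.bfloat16)

# Upcast to float32 before any NumPy/matplotlib call;
# .float() is shorthand for .to(torch.float32).
attention_map_normalized = attention_map_normalized.float()
attention_map = attention_map_normalized.cpu().numpy()  # now succeeds
```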

bilelomrani1 commented 3 months ago

Hi @pablohl, good catch. For reference, this is related to pytorch#90574. Do you want to contribute your fix?