pablohl closed this issue 3 months ago
https://github.com/illuin-tech/colpali/blob/ad93aeb0d683a1ba621325834817441d9de5a780/colpali_engine/interpretability/gen_interpretability_plots.py#L86
attention_map_normalized is a torch BFloat16 tensor; this breaks in two different places in plot_utils.py:
https://github.com/illuin-tech/colpali/blob/ad93aeb0d683a1ba621325834817441d9de5a780/colpali_engine/interpretability/plot_utils.py#L38
https://github.com/illuin-tech/colpali/blob/ad93aeb0d683a1ba621325834817441d9de5a780/colpali_engine/interpretability/plot_utils.py#L98
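For context, numpy has no bfloat16 dtype, so any code path that implicitly converts the tensor to a numpy array (as matplotlib does) fails. A minimal sketch of the failure, using a dummy tensor in place of the real attention map:

import torch

# Dummy stand-in for the attention map produced in
# gen_interpretability_plots.py; only the dtype matters here.
attention_map_normalized = torch.rand(32, 32, dtype=torch.bfloat16)

try:
    # The plot_utils.py call sites above eventually hand the tensor to
    # matplotlib, which converts it via .numpy(); numpy cannot handle
    # bfloat16, so this raises.
    attention_map_normalized.numpy()
except TypeError as err:
    print(err)  # e.g. "Got unsupported ScalarType BFloat16"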
A solution is to convert the tensor to float before plotting:

attention_map_normalized = attention_map_normalized.float()
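With that cast in place the downstream numpy conversion succeeds, since .float() upcasts BFloat16 to float32. Continuing the sketch above:

# float32 is supported by numpy, so the conversion inside
# plot_utils.py no longer raises.
attention_map_normalized = attention_map_normalized.float()
print(attention_map_normalized.numpy().dtype)  # float32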
Hi @pablohl, good catch! For reference, this is related to pytorch#90574. Do you want to contribute your fix?