aertslab / CREsted


Error in calculate_contribution_scores_regions with pytorch #45

Open lldelisle opened 1 week ago

lldelisle commented 1 week ago


Hi, I got the following error when trying to use calculate_contribution_scores_regions with pytorch:

2024-10-16T15:33:14.181525+0200 INFO Calculating contribution scores for 4 class(es) and 2 region(s).
Region:   0%|          | 0/2 [00:00<?, ?it/s]
/scratch/izar/ldelisle/CREsted_test1/venv9/lib/python3.10/site-packages/crested/tl/_explainer_torch.py:255: FutureWarning: The input object of type 'Tensor' is an array-like implementing one of the corresponding protocols (`__array__`, `__array_interface__` or `__array_struct__`); but not a sequence (or 0-D). In the future, this object will be coerced as if it was first converted using `np.array(obj)`. To retain the old behaviour, you have to either modify the type 'Tensor', or assign to an empty array created with `np.empty(correct_shape, dtype=object)`.
  return np.array(x_shuffle)
/scratch/izar/ldelisle/CREsted_test1/venv9/lib/python3.10/site-packages/crested/tl/_explainer_torch.py:255: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  return np.array(x_shuffle)
Region:   0%|          | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/ldelisle/scripts/scitas_sbatchhistory/2024/20241002_tryCREest/20241016_plots_with_model.py", line 121, in <module>
    scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_regions(
  File "/scratch/izar/ldelisle/CREsted_test1/venv9/lib/python3.10/site-packages/crested/tl/_crested.py", line 994, in calculate_contribution_scores_regions
    return self.calculate_contribution_scores_sequence(
  File "/scratch/izar/ldelisle/CREsted_test1/venv9/lib/python3.10/site-packages/crested/tl/_crested.py", line 1085, in calculate_contribution_scores_sequence
    scores[:, i, :, :] = explainer.expected_integrated_grad(
  File "/scratch/izar/ldelisle/CREsted_test1/venv9/lib/python3.10/site-packages/crested/tl/_explainer_torch.py", line 71, in expected_integrated_grad
    baselines = torch.tensor(
TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint64, uint32, uint16, uint8, and bool.

If you need more info, let me know and I can try to make a minimal example.

Version information

I don't have session_info, but here is the output of pip list:

Package                           Version                                                                                                                   
--------------------------------- ------------                                                                                                              
absl-py                           1.2.0                                                                                                                     
anndata                           0.10.5.post1                                                                                                              
appdirs                           1.4.4                                                                                                                     
array_api_compat                  1.9                                                                                                                       
astunparse                        1.6.3                                                                                                                     
backports.entry-points-selectable 1.1.1                                                                                                                     
certifi                           2021.10.8                                                                                                                 
charset-normalizer                2.0.12                                                                                                                    
click                             8.1.7                                                                                                                     
crested                           1.1.0                                                                                                                     
cycler                            0.11.0                                                                                                                    
Cython                            0.29.30                                                                                                                   
distlib                           0.3.4  
docker-pycreds                    0.4.0                                                                                                                     
exceptiongroup                    1.2.2                                                                                                                     
filelock                          3.5.0                                                                                                                     
fonttools                         4.54.1                                                                                                                    
fsspec                            2024.9.0                                                                                                                  
gast                              0.5.3                                                                                                                     
gitdb                             4.0.11                                                                                                                    
GitPython                         3.1.43                                                                                                                    
google-pasta                      0.2.0                                                                                                                     
h5py                              3.12.1                                                                                                                    
idna                              3.3                                                                                                                       
Jinja2                            3.1.4                                                                                                                     
joblib                            1.4.2                                                                                                                     
keras                             3.6.0                                                                                                                     
Keras-Preprocessing               1.1.2                                                                                                                     
kiwisolver                        1.3.2                                                                                                                     
logomaker                         0.8                                                                                                                       
loguru                            0.7.2                                                                                                                     
markdown-it-py                    3.0.0                                                                                                                     
MarkupSafe                        3.0.1                                                                                                                     
matplotlib                        3.5.2                                                                                                                     
mdurl                             0.1.2                                                                                                                     
ml_dtypes                         0.5.0                                                                                                                     
mpmath                            1.2.1                                                                                                                     
namex                             0.0.8                                                                                                                     
natsort                           8.4.0                                                                                                                     
networkx                          3.3                                                                                                                       
numpy                             1.22.4  
nvidia-cuda-cupti-cu12            12.1.105                                                                                                        
nvidia-cuda-nvrtc-cu12            12.1.105                                                                                                                  
nvidia-cuda-runtime-cu12          12.1.105                                                                                                                  
nvidia-cudnn-cu12                 9.1.0.70                                                                                                                  
nvidia-cufft-cu12                 11.0.2.54                                                                                                                 
nvidia-curand-cu12                10.3.2.106                                                                                                                
nvidia-cusolver-cu12              11.4.5.107                                                                                                                
nvidia-cusparse-cu12              12.1.0.106                                                                                                                
nvidia-nccl-cu12                  2.20.5                                                                                                                    
nvidia-nvjitlink-cu12             12.6.77                                                                                                                   
nvidia-nvtx-cu12                  12.1.105                                                                                                                  
opt-einsum                        3.3.0                                                                                                                     
optree                            0.13.0                                                                                                                    
packaging                         21.3                                                                                                                      
pandas                            1.4.2                                                                                                                     
Pillow                            9.0.0                                                                                                                     
pip                               24.2                                                                                                                      
platformdirs                      2.4.0                                                                                                                     
ply                               3.11                                                                                                                      
pooch                             1.6.0                                                                                                                     
protobuf                          3.20.0                                                                                                                    
psutil                            6.0.0                                                                                                                     
pybigtools                        0.2.1                                                                                                                     
pyBigWig                          0.3.23                                                                                                                    
Pygments                          2.18.0  
pyparsing                         3.0.6
pysam                             0.22.1
python-dateutil                   2.8.2
pytz                              2021.3
PyYAML                            6.0.2
requests                          2.26.0
rich                              13.9.2
scikit-learn                      1.5.2
scipy                             1.8.1
seaborn                           0.13.2
semver                            2.8.1
sentry-sdk                        1.9.0
setproctitle                      1.3.3
setuptools                        58.3.0
six                               1.16.0
smmap                             5.0.1
sphire                            1.4.1
sympy                             1.8
termcolor                         1.1.0
threadpoolctl                     3.5.0
torch                             2.4.1
tqdm                              4.66.5
triton                            3.0.0
typing_extensions                 4.12.2
urllib3                           1.26.6
virtualenv                        20.10.0
wandb                             0.18.3
wheel                             0.44.0
wrapt                             1.13.3
xarray                            2022.3.0

lldelisle commented 1 week ago

(I don't think this is the origin of the bug, but just so you know: I trained the model with version 1.0.0 and upgraded to 1.1.0 for plotting.)

LukasMahieu commented 1 week ago

I haven't seen this before; we normally have unit tests covering this. I'll try to reproduce it ASAP.

lldelisle commented 1 week ago

I have numpy 1.22.4; I think this is the issue...

LukasMahieu commented 1 week ago

Actually, I have seen this before. I remember that upgrading to numpy 2.x fixed it. How did you create your environment? When I create a standard crested environment with Python 3.10, numpy 2.1.2 gets installed and I don't run into any issues. For Python 3.10, numpy also recommends version 1.23+.
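Because the behavior depends on which numpy the interpreter actually imports, a quick check run inside the activated environment can confirm the version and location in use. A minimal sketch, assuming nothing beyond numpy itself; the helper name numpy_is_at_least_2 is hypothetical, and the comparison uses numpy's own NumpyVersion class:

```python
import numpy as np


def numpy_is_at_least_2() -> bool:
    """Return True if the numpy imported by this interpreter is 2.0 or newer."""
    # NumpyVersion compares release strings correctly (e.g. "1.22.4" < "2.0.0").
    return np.lib.NumpyVersion(np.__version__) >= "2.0.0"


if __name__ == "__main__":
    # np.__file__ shows which installation was picked up, which is worth checking
    # when a virtualenv was created with --system-site-packages.
    print(np.__version__, np.__file__)
    print("numpy >= 2:", numpy_is_at_least_2())
```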

lldelisle commented 1 week ago

I am new to GPUs, so I created the virtual environment on my HPC using module load to make sure PyTorch would recognize the GPU, and then tried to install the other dependencies without overwriting the existing installations. Here are the commands:

module purge
module load gcc/11.3.0 python/3.10.4 openmpi/4.1.3-cuda
# With openmpi comes cuda/11.8.0
# Only the first time:
virtualenv --system-site-packages venv9
# Activate
source venv9/bin/activate
pip install --no-cache-dir torch
pip install --no-cache-dir keras
pip install --no-cache-dir crested urllib3==1.26.6 numpy=='1.22.4' platformdirs=='2.4.0' sentry_sdk==1.9.0

I've just created a new virtualenv with:

# Try pytorch with pip
module purge
module load gcc/11.3.0 python/3.10.4
# Only the first time:
virtualenv venv7
# Activate
source venv7/bin/activate

pip install --no-cache-dir torch
pip install --no-cache-dir crested

That solved the issue (though I hit another bug, which I will report in a separate issue).

Then I would say that you should set numpy>2 in dependencies no?

LukasMahieu commented 1 week ago

Then I would say that you should set numpy>2 in dependencies no?

Yes, I'll do exactly that. Thanks for bringing this up.

lldelisle commented 1 week ago

You're welcome. I like it when people report bugs in my own projects, because it keeps others from hitting the same bug and then, too shy to report it, giving up on the package... I hope you feel the same way, because it seems I've run into other bugs too...

LukasMahieu commented 1 week ago

Absolutely, feedback is greatly appreciated! I just realized the solution isn't so simple after all, since torch would require numpy >2 but tensorflow only works with <2 🙃. I'll see if I can fix this in the code itself early next week.
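The warnings from _explainer_torch.py:255 point at the mechanism: np.array(x_shuffle) received a list of torch tensors, numpy 1.22 ended up building an object-dtype array (hence the ragged-sequence warning), and torch.tensor then rejected that dtype. One version-independent workaround is to convert each element explicitly and stack the results. This is a minimal sketch, not CREsted's actual fix: stack_as_float_array is a hypothetical helper, it assumes every element of x_shuffle has the same shape, and the float32 default is an assumption.

```python
import numpy as np


def stack_as_float_array(items, dtype=np.float32):
    """Stack equally shaped arrays or tensors into one numeric ndarray.

    Hypothetical helper: anything with a .detach() method (e.g. a torch.Tensor)
    is moved to CPU and converted explicitly, so numpy never has to guess how
    to coerce it, regardless of the installed numpy version.
    """
    arrays = []
    for x in items:
        if hasattr(x, "detach"):  # duck-typed check for torch.Tensor
            x = x.detach().cpu().numpy()
        arrays.append(np.asarray(x, dtype=dtype))
    # np.stack raises ValueError on mismatched shapes instead of silently
    # producing an object array the way np.array does on older numpy.
    return np.stack(arrays)
```

In the explainer this would stand in for np.array(x_shuffle), e.g. torch.tensor(stack_as_float_array(x_shuffle)). If the tensors never need to leave torch, torch.stack(x_shuffle) achieves the same result without the numpy round-trip.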