DearCaat / RRT-MIL

[CVPR 2024] Feature Re-Embedding: Towards Foundation Model-Level Performance in Computational Pathology

RuntimeError: Expected canUse32BitIndexMath(input) && canUse32BitIndexMath(output) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.) #10

Closed: shubhaminnani closed this issue 1 week ago

shubhaminnani commented 1 month ago

While running the code, I am hitting the error in the title. Can you please help?

Traceback (most recent call last):
  File "/mnt/e/ssl/RRT-MIL/main.py", line 750, in <module>
    main(args=args)
  File "/mnt/e/ssl/RRT-MIL/main.py", line 67, in main
    ckc_metric = one_fold(args,k,ckc_metric,train_p, train_l, test_p, test_l,val_p,val_l)
  File "/mnt/e/ssl/RRT-MIL/main.py", line 280, in one_fold
    train_loss,start,end = train_loop(args,model,train_loader,optimizer,device,amp_autocast,criterion,loss_scaler,scheduler,k,epoch)
  File "/mnt/e/ssl/RRT-MIL/main.py", line 449, in train_loop
    train_logits = model(bag)
  File "/home/sinnani/miniconda3/envs/mambamil/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/ssl/RRT-MIL/modules/rrt.py", line 232, in forward
    x = self.online_encoder(x)
  File "/home/sinnani/miniconda3/envs/mambamil/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/ssl/RRT-MIL/modules/rrt.py", line 188, in forward
    x = layer(x)
  File "/home/sinnani/miniconda3/envs/mambamil/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/ssl/RRT-MIL/modules/rrt.py", line 110, in forward
    x,attn = self.forward_trans(x,need_attn=need_attn)
  File "/mnt/e/ssl/RRT-MIL/modules/rrt.py", line 123, in forward_trans
    z = self.attn(self.norm(x))
  File "/home/sinnani/miniconda3/envs/mambamil/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/ssl/RRT-MIL/modules/rmsa.py", line 218, in forward
    attn_regions = self.attn(x_regions)  # nW*B, region_size*region_size, C
  File "/home/sinnani/miniconda3/envs/mambamil/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/ssl/RRT-MIL/modules/rmsa.py", line 107, in forward
    pe = self.pe(attn)
  File "/home/sinnani/miniconda3/envs/mambamil/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sinnani/miniconda3/envs/mambamil/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/sinnani/miniconda3/envs/mambamil/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected canUse32BitIndexMath(input) && canUse32BitIndexMath(output) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

PyTorch version: 2.0.1 (py3.10_cuda11.8_cudnn8.7.0_0)

DearCaat commented 1 month ago

I found a similar issue in the official PyTorch repository. If the input, i.e., the feature tensor, is too large (more than 2**31 elements), PyTorch itself appears to enforce this limit. Can the input size be lowered? Lowering the dimensions or increasing region_num should do the trick. If this is not the cause, could you provide a code snippet that reproduces the error reliably? I will try to find the problem. Much appreciated.
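The limit mentioned above comes from the check named in the error message, `canUse32BitIndexMath`. The following is a minimal pure-Python sketch of that check's element-count part. It assumes a contiguous tensor, because the real ATen check also walks the strides. The helper name `fits_32bit_index_math` is hypothetical and does not belong to PyTorch's Python API.

```python
import math

INT32_MAX = 2**31 - 1  # largest value a signed 32-bit index can hold


def fits_32bit_index_math(shape):
    """Approximate PyTorch's canUse32BitIndexMath for a *contiguous* tensor.

    Hypothetical helper: the real ATen check also inspects strides, but for a
    contiguous tensor it reduces to "fewer than INT32_MAX elements".
    """
    numel = math.prod(shape)
    return numel < INT32_MAX


# A (batch, channels, height, width) tensor just under and just over the limit.
print(fits_32bit_index_math((1, 8, 16383, 16383)))  # True
print(fits_32bit_index_math((1, 8, 16384, 16384)))  # False: exactly 2**31 elements
```

Only the total element count matters here. A tensor with a small channel count can still exceed the limit when its spatial dimensions are large.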

shubhaminnani commented 1 week ago
Namespace(datasets='tcga', dataset_root='/mnt/e/ssl/RRTMIL', tcga_max_patch=-1, fix_loader_random=False, fix_train_random=False, val_ratio=0.0, fold_start=0, cv_fold=10, persistence=False, same_psize=0, tcga_sub='gbm', cls_alpha=1.0, aux_alpha=1.0, auto_resume=False, num_epoch=200, early_stopping=True, max_epoch=130, input_dim=1024, n_classes=3, batch_size=1, num_workers=0, loss='ce', opt='adam', save_best_model_stage=0.0, model='rrtmil', seed=1, lr=0.0002, lr_sche='cosine', lr_supi=False, weight_decay=1e-05, accumulation_steps=1, clip_grad=0.0, always_test=True, ds_average=False, only_rrt_enc=False, act='relu', dropout=0.25, attn='rmsa', pool='attn', ffn=False, n_trans_layers=2, mlp_ratio=4.0, qkv_bias=True, all_shortcut=True, region_attn='native', min_region_num=0, region_num=8, trans_dim=64, n_heads=8, trans_drop_out=0.1, drop_path=0.0, pos='none', pos_pos=0, peg_k=7, peg_1d=False, epeg=True, epeg_bias=True, epeg_2d=False, epeg_k=13, epeg_type='attn', cr_msa=True, crmsa_k=3, crmsa_heads=1, crmsa_mlp=True, da_act='tanh', patch_shuffle=False, group_shuffle=False, shuffle_group=0, title='rrtmil', project='mil_new_c16', log_iter=100, amp=False, wandb=False, no_log=False, model_path='results/mil_new_c16/rrtmil')

I am using the above parameters to train a three-class subtyping model on the TCGA dataset with a feature size of 1024.

DearCaat commented 1 week ago

Sorry, I'd like to explain a little more about what feature size means here. The code that raises this error performs a convolution on the attention matrix of all the patches in each region. That attention matrix has n_patch*n_patch*n_head elements per region (n_patch = patches per region). All region_num*region_num regions are batched into one call, which is the nW*B dimension in the traceback. With region_num = 8 and n_head = 8, the whole tensor therefore holds about 8*8*8*n_patch**2 elements. This reaches 2**31 once n_patch reaches (2**31 / (8*8*8))**0.5 = 2048. So the error occurs for any slide in the TCGA-GBM dataset with roughly 2048*8*8 (region_num**2) = 131,072 patches or more. It would help me if you could change region_num to 16 or 32 to double-check whether the error still occurs.
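To check a given slide before training, one can estimate the size of the attention tensor that reaches the convolution. The sketch below rests on several assumptions, all hypothetical:

- The bag is split into region_num × region_num regions with equal patch counts, with padding approximated by rounding up.
- The tensor passed to `self.pe` has shape (region_num**2, n_heads, n_patch, n_patch) for a batch size of 1.
- The helpers `attention_numel` and `smallest_safe_region_num` are illustrative names and are not part of RRT-MIL.

The repository's actual padding and layout may differ.

```python
import math

INT32_MAX = 2**31 - 1  # PyTorch's default limit for 32-bit index math


def attention_numel(num_patches, region_num=8, n_heads=8):
    """Approximate element count of the attention tensor in one RMSA call.

    Assumes (hypothetically) region_num**2 regions, each holding
    ceil(num_patches / region_num**2) patches, and an attention map of shape
    (region_num**2, n_heads, n_patch, n_patch) for a batch size of 1.
    """
    n_regions = region_num**2
    n_patch = math.ceil(num_patches / n_regions)
    return n_regions * n_heads * n_patch * n_patch


def smallest_safe_region_num(num_patches, n_heads=8, candidates=(8, 16, 32, 64)):
    """Return the first candidate region_num whose attention tensor fits, else None."""
    for region_num in candidates:
        if attention_numel(num_patches, region_num, n_heads) < INT32_MAX:
            return region_num
    return None


# A hypothetical 400,000-patch slide with n_heads=8:
print(attention_numel(400_000, region_num=8) < INT32_MAX)   # False
print(attention_numel(400_000, region_num=64) < INT32_MAX)  # True
```

Under these assumptions, the element count is roughly n_heads × num_patches² / region_num². Doubling region_num therefore cuts the estimate by about a factor of four. This is why raising region_num removes the error without changing the feature size.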

shubhaminnani commented 1 week ago

Hi, thank you for your prompt response. The error no longer occurs when region_num is increased; I changed it to 64.

DearCaat commented 1 week ago

Best wishes for your research! 😄