Closed: shubhaminnani closed this issue 1 week ago.
I found a similar issue in the official PyTorch repository. If the input feature size is too large (> 2**31 elements), it seems that PyTorch itself is doing the limiting. Can the input size be lowered? Lowering the dimensions or increasing region_num should do the trick. If this is not the cause, could you provide a code snippet that reproduces the error stably, and I will try to track down the problem.
Much appreciated.
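For reference, a minimal sketch of a guard one could put in front of the failing convolution. The 2**31 cap and the remedies (lower the feature dimension or raise region_num) come from this thread; the `safe_conv` helper and its placement are illustrative assumptions, not part of the repo:

```python
import torch

INT32_LIMIT = 2 ** 31  # many PyTorch/cuDNN kernels index buffers with 32-bit ints

def safe_conv(conv: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run `conv` only if the input stays under the 32-bit element limit."""
    if x.numel() >= INT32_LIMIT:
        raise ValueError(
            f"input has {x.numel()} elements (>= 2**31); "
            "lower the feature dimension or increase region_num"
        )
    return conv(x)
```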
Namespace(datasets='tcga', dataset_root='/mnt/e/ssl/RRTMIL', tcga_max_patch=-1, fix_loader_random=False, fix_train_random=False, val_ratio=0.0, fold_start=0, cv_fold=10, persistence=False, same_psize=0, tcga_sub='gbm', cls_alpha=1.0, aux_alpha=1.0, auto_resume=False, num_epoch=200, early_stopping=True, max_epoch=130, input_dim=1024, n_classes=3, batch_size=1, num_workers=0, loss='ce', opt='adam', save_best_model_stage=0.0, model='rrtmil', seed=1, lr=0.0002, lr_sche='cosine', lr_supi=False, weight_decay=1e-05, accumulation_steps=1, clip_grad=0.0, always_test=True, ds_average=False, only_rrt_enc=False, act='relu', dropout=0.25, attn='rmsa', pool='attn', ffn=False, n_trans_layers=2, mlp_ratio=4.0, qkv_bias=True, all_shortcut=True, region_attn='native', min_region_num=0, region_num=8, trans_dim=64, n_heads=8, trans_drop_out=0.1, drop_path=0.0, pos='none', pos_pos=0, peg_k=7, peg_1d=False, epeg=True, epeg_bias=True, epeg_2d=False, epeg_k=13, epeg_type='attn', cr_msa=True, crmsa_k=3, crmsa_heads=1, crmsa_mlp=True, da_act='tanh', patch_shuffle=False, group_shuffle=False, shuffle_group=0, title='rrtmil', project='mil_new_c16', log_iter=100, amp=False, wandb=False, no_log=False, model_path='results/mil_new_c16/rrtmil')
I am using the above parameters to train a three-class subtyping problem on the TCGA dataset with a feature size of 1024.
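As a hedged sketch for locating an oversized slide (the `.pt` layout under dataset_root and the `(n_patch, 1024)` feature shape are assumptions, not confirmed by the repo), one could print the patch count of every feature bag:

```python
import glob
import torch

# Walk the extracted feature bags and print (path, n_patch, feat_dim).
for path in sorted(glob.glob('/mnt/e/ssl/RRTMIL/**/*.pt', recursive=True)):
    feats = torch.load(path, map_location='cpu')  # assumed shape: (n_patch, 1024)
    print(path, tuple(feats.shape))
```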
Sorry, I'd like to explain a little more about what feature size means here. The failing code actually performs a convolution over the attention matrix of all the patches in a region. That attention matrix has dimensions n_patch * n_patch * n_head, so with n_head = 8 the per-region patch limit is roughly (2**31)**0.5 / 8 ≈ 5792. If any slide in the TCGA-GBM dataset has more than 5792 * 8 * 8 (region_num) ≈ 370,728 patches, this error occurs. It would help me if you could change region_num to 16 or 32 and double-check whether the error still exists.
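Writing that bound out as a short sketch (the numbers are taken from this thread; labeling the middle factor of 8 as n_head is my reading of the formula, not something the repo states):

```python
# Per-region attention matrix: n_patch_per_region x n_patch_per_region x n_head,
# which must stay below 2**31 elements.
n_head = 8
region_num = 8

per_region_limit = (2 ** 31) ** 0.5 / n_head          # ≈ 5792.6 patches per region
slide_limit = per_region_limit * n_head * region_num  # ≈ 370,728 patches per slide

def may_overflow(n_patches: int) -> bool:
    """True if a slide has enough patches to trip the 2**31 limit."""
    return n_patches > slide_limit

print(int(per_region_limit), int(slide_limit))  # 5792 370727
```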
Hi, thank you for your prompt response. The issue does not occur when region_num is increased; I changed it to 64.
Best wishes for your research! :smile:
While running the code, I am facing the issue in the title; can you please help?
PyTorch version: 2.0.1 py3.10_cuda11.8_cudnn8.7.0_0
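For completeness, the version details above can be captured with standard PyTorch attributes (nothing repo-specific is assumed here):

```python
import torch

print(torch.__version__)               # 2.0.1
print(torch.version.cuda)              # 11.8
print(torch.backends.cudnn.version())  # 8700
```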