Closed ysl2 closed 2 years ago
Hello,
Which variant are you training here? Is it the tiny version with 48 embedding dims?
I was training with base.yaml. However, when I switched to tiny.yaml, the bug still occurred, but on a different line:
Traceback (most recent call last):
File "train.py", line 306, in <module>
main(arguments)
File "train.py", line 166, in main
segs_S1 = model_1(inputs_S1)
File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vision_transformer.py", line 49, in forward
return self.swin_unet(x)
File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 1116, in forward
x, x_downsample, v_values_1, k_values_1, q_values_1, v_values_2, k_values_2, q_values_2 = self.forward_features(
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 957, in forward_features
x, v1, k1, q1, v2, k2, q2 = layer(x, i)
File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 723, in forward
x, v1, k1, q1 = blk(x, attn_mask, None, None, None)
File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 389, in forward
x, x2, v, k, q = self.forward_part1(x, mask_matrix, prev_v, prev_k, prev_q, is_decoder)
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 337, in forward_part1
attn_windows, cross_attn_windows, v, k, q = self.attn(x_windows, mask=attn_mask, prev_v=prev_v, prev_k=prev_k,
File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 195, in forward
attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N
RuntimeError: CUDA out of memory. Tried to allocate 9.02 GiB (GPU 0; 23.69 GiB total capacity; 16.22 GiB already allocated; 4.95 GiB free; 16.62 GiB reserved in total by PyTorch)
To train the tiny version you will need at least 12 GB of GPU memory. From the stack trace I can see you are using a 24 GB GPU, which is enough to train even the base version. What is the batch size in this case? Normally I use a batch size of 1; if I exceed that, I get the same OOM error.
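For reference, the 9.02 GiB allocation at the `attn` line can be checked with a back-of-the-envelope count of the window-attention map `(B_, nH, N, N)`. This is a minimal sketch, assuming float32, cubic 7³ windows, batch size 1, and 3 heads in the first stage; the head count and the 128³ input size used below are assumptions, not values taken from the repo:

```python
import math

def window_attention_bytes(d, h, w, window=7, num_heads=3,
                           batch=1, bytes_per_el=4):
    """Rough size in bytes of one window-attention map (B_, nH, N, N).

    Assumes float32 and cubic windows; num_heads=3 for the first
    stage of the tiny variant is an assumption.
    """
    n = window ** 3                      # tokens per 3D window (N)
    n_windows = (math.ceil(d / window) *
                 math.ceil(h / window) *
                 math.ceil(w / window))  # windows per volume
    b_ = batch * n_windows               # B_ in the stack trace
    return b_ * num_heads * n * n * bytes_per_el

print(window_attention_bytes(128, 128, 128) / 2 ** 30)  # ~9.02 GiB
```

For a 128×128×128 input this evaluates to about 9.02 GiB, which matches the allocation in the traceback, so the attention map for a single full-resolution volume alone accounts for the failed allocation; cropping the input volume smaller would shrink it cubically.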
Working with 1 GPUs
=> merge config from configs/vt_unet_tiny.yaml
SwinTransformerSys3D expand initial----depths:[2, 2, 2, 1];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.1;num_classes:2;embed_dims:48;window:(7, 7, 7)
---final upsample expand_first---
pretrained_path:./pretrained_ckpt/swin_tiny_patch4_window7_224.pth
---start load pretrained modle by splitting---
total number of trainable parameters 5374620
<bound method EDiceLoss_Val.metric of EDiceLoss_Val()>
Train dataset number of batch: 1332
Val dataset number of batch: 332
Bench Test dataset number of batch: 264
start training now!
Traceback (most recent call last):
File "train.py", line 307, in <module>
main(arguments)
File "train.py", line 167, in main
segs_S1 = model_1(inputs_S1)
File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vision_transformer.py", line 49, in forward
return self.swin_unet(x)
File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 1116, in forward
x, x_downsample, v_values_1, k_values_1, q_values_1, v_values_2, k_values_2, q_values_2 = self.forward_features(
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 957, in forward_features
x, v1, k1, q1, v2, k2, q2 = layer(x, i)
File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 723, in forward
x, v1, k1, q1 = blk(x, attn_mask, None, None, None)
File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 389, in forward
x, x2, v, k, q = self.forward_part1(x, mask_matrix, prev_v, prev_k, prev_q, is_decoder)
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 337, in forward_part1
attn_windows, cross_attn_windows, v, k, q = self.attn(x_windows, mask=attn_mask, prev_v=prev_v, prev_k=prev_k,
File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 195, in forward
attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N
RuntimeError: CUDA out of memory. Tried to allocate 9.02 GiB (GPU 0; 23.69 GiB total capacity; 16.22 GiB already allocated; 4.95 GiB free; 16.62 GiB reserved in total by PyTorch)
This is the entire console output. My batch_size is 1; that is the default, and I didn't modify this parameter.
This seems to be an issue with PyTorch. Were you able to solve it?
Note: I modified the dataloader to fit my own dataset. I've searched for lots of solutions but failed to fix this. Very sad :-( Could you please help me? Maybe the model is too large to train? I don't know. Thanks!