JackieHanLab / TOSICA

Transformer for One-Stop Interpretable Cell-type Annotation
MIT License
121 stars 23 forks

model depth=2 causes more than 8 times the GPU memory usage of depth=1 when training on the hPancreas dataset #10

Closed rushrush2022 closed 6 months ago

rushrush2022 commented 1 year ago

@JackieHanLab problem: with model depth=2, training on the hPancreas dataset uses more than 8 times the GPU memory of depth=1, and a 16 GB GPU is not enough. environment: python=3.9, pytorch=1.12.1, Tesla T4 16 GB GPU.

hPancreas dataset training, i.e. the demo_train.h5ad, strictly following the tutorial config: max_g=300, max_gs=300, mask_ratio=0.015, n_unannotated=1, batch_size=8, embed_dim=48, depth=2, num_heads=4, lr=0.001, epochs=10, lrf=0.01. GPU memory usage suddenly exceeds 16 GB at about 8% of the first epoch. Even if I reduce embed_dim to 2 and num_heads to 2, I still hit OOM on the 16 GB GPU.

error:
File "/home/user1/codes/TOSICA/TOSICA/customized_linear.py", line 52, in backward
    grad_weight = grad_weight * mask
RuntimeError: CUDA out of memory. Tried to allocate 166.00 MiB (GPU 0; 14.76 GiB total capacity; 13.31 GiB already allocated; 157.75 MiB free; 13.94 GiB reserved in total by PyTorch)

Then I changed to depth=1, keeping embed_dim=48 and num_heads=4, and GPU memory usage stays under 2 GB; training with epochs=30 is no problem. The depth is the number of transformer blocks, which are mostly the self-attention parts. Could you check why depth=2 vs depth=1 causes more than 8 times the GPU memory usage? Did you train with depth=2 on a 16 GB GPU? Thanks a lot in advance!!
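For reference, this is roughly the call I am running. It is only a sketch: I am assuming the config keys above map one-to-one onto TOSICA.train keyword arguments as in the tutorial, and the gmt_path/project names below are placeholders.

```python
import scanpy as sc
import TOSICA

# demo_train.h5ad is the hPancreas reference from the tutorial
ref_adata = sc.read_h5ad('demo_train.h5ad')

# Sketch of the training call; keyword names follow the config listed above
# and are assumed to be accepted by TOSICA.train as in the tutorial.
TOSICA.train(
    ref_adata,
    gmt_path='human_gobp',     # pathway mask used in the tutorial
    label_name='Celltype',
    project='hPancreas_demo',  # placeholder project name
    max_g=300, max_gs=300, mask_ratio=0.015, n_unannotated=1,
    batch_size=8, embed_dim=48,
    depth=2,                   # depth=2 runs out of memory on 16 GB; depth=1 stays under 2 GB
    num_heads=4, lr=0.001, epochs=10, lrf=0.01,
)
```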

wang-qf commented 1 year ago

I used the default model settings on a V100 (32 GB) and also ran out of memory, at around 15% of the first epoch. python=3.10.0, torch=2.0.0.

JackieHanLab commented 1 year ago

Thank you for your interest in TOSICA. We have previously encountered the same error and determined that it is caused by using a higher version of Torch. We have found that using torch=1.7.1 eliminates this memory error. However, we are uncertain about the specific location of the problematic code and would greatly appreciate any assistance in identifying it.
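For anyone who wants to try that workaround, a pin along these lines should do it (the +cu110 build suffix is an assumption; pick the build matching your local CUDA):

```
pip install torch==1.7.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html
```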

Yutong18 commented 10 months ago

I encountered a memory issue while using attn_weights during the training process on PyTorch 2.0.0. As a solution, I modified the code to only return attn_weights during the evaluation phase, which effectively mitigated the memory issue.
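In case it helps others, here is a minimal sketch of the idea. This is not the actual TOSICA attention code, just the ViT-style pattern it follows with illustrative names: hand the attention map back only in eval mode, so the large (batch, heads, tokens, tokens) tensor is not kept around during training.

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Illustrative multi-head self-attention; returns weights only in eval mode."""
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)           # each: (B, heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        # Return the attention map only when evaluating (for interpretability);
        # during training return None so the (B, heads, N, N) tensor is not
        # carried around outside the autograd graph.
        weights = attn.detach() if not self.training else None
        return out, weights
```

With a module like this, weights is None while the model is in train() mode; switch to eval() before pulling the attention maps out for interpretation.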

Thank you!