atomicarchitects / equiformer

[ICLR 2023 Spotlight] Equiformer: Equivariant Graph Attention Transformer for 3D Atomistic Graphs
https://arxiv.org/abs/2206.11990
MIT License

Reduce the model size #12

Closed xnuohz closed 1 year ago

xnuohz commented 1 year ago

Hi, thanks for sharing the code. I'd like to try it on my own dataset.

However, unlike MD17, where the molecules have only 12 atoms, my dataset has more atoms per molecule and therefore needs more GPU memory. Could you give me some advice on how to reduce the model or the input size?

The train and evaluation batch sizes are both set to 8; the run only avoids OOM during testing if the force calculation is removed.

Thank you.

Number of params: 3500609
Epoch: [0][0/2500]  loss_e: 0.76386, loss_f: 0.26663, e_MAE: 231721.79688, f_MAE: 40633.41797, time/step=1221ms, lr=1.00e-06
Epoch: [0][100/2500]    loss_e: 0.78266, loss_f: 0.17736, e_MAE: 237424.20057, f_MAE: 26658.75540, time/step=269ms, lr=1.00e-06
Epoch: [0][200/2500]    loss_e: 0.67438, loss_f: 0.14274, e_MAE: 204576.80185, f_MAE: 21441.58661, time/step=256ms, lr=1.00e-06
Epoch: [0][300/2500]    loss_e: 0.60979, loss_f: 0.12493, e_MAE: 184983.89528, f_MAE: 18766.71555, time/step=248ms, lr=1.00e-06
Epoch: [0][400/2500]    loss_e: 0.56140, loss_f: 0.11186, e_MAE: 170304.47533, f_MAE: 16805.37575, time/step=244ms, lr=1.00e-06
Epoch: [0][500/2500]    loss_e: 0.52590, loss_f: 0.10222, e_MAE: 159535.96707, f_MAE: 15353.32029, time/step=241ms, lr=1.00e-06
Epoch: [0][600/2500]    loss_e: 0.49673, loss_f: 0.09442, e_MAE: 150684.69430, f_MAE: 14181.63761, time/step=239ms, lr=1.00e-06
Epoch: [0][700/2500]    loss_e: 0.47570, loss_f: 0.08797, e_MAE: 144305.35255, f_MAE: 13212.08826, time/step=238ms, lr=1.00e-06
Epoch: [0][800/2500]    loss_e: 0.45563, loss_f: 0.08253, e_MAE: 138218.50663, f_MAE: 12395.78744, time/step=237ms, lr=1.00e-06
Epoch: [0][900/2500]    loss_e: 0.44249, loss_f: 0.07803, e_MAE: 134231.66208, f_MAE: 11719.96249, time/step=236ms, lr=1.00e-06
Epoch: [0][1000/2500]   loss_e: 0.42950, loss_f: 0.07414, e_MAE: 130291.51578, f_MAE: 11135.17785, time/step=236ms, lr=1.00e-06
Epoch: [0][1100/2500]   loss_e: 0.41839, loss_f: 0.07068, e_MAE: 126920.91153, f_MAE: 10616.74550, time/step=235ms, lr=1.00e-06
Epoch: [0][1200/2500]   loss_e: 0.40806, loss_f: 0.06769, e_MAE: 123787.15601, f_MAE: 10166.87837, time/step=235ms, lr=1.00e-06
Epoch: [0][1300/2500]   loss_e: 0.39778, loss_f: 0.06493, e_MAE: 120668.85974, f_MAE: 9753.33389, time/step=235ms, lr=1.00e-06
Epoch: [0][1400/2500]   loss_e: 0.39040, loss_f: 0.06247, e_MAE: 118430.25954, f_MAE: 9383.97292, time/step=235ms, lr=1.00e-06
Epoch: [0][1500/2500]   loss_e: 0.38277, loss_f: 0.06031, e_MAE: 116114.40569, f_MAE: 9058.41404, time/step=234ms, lr=1.00e-06
Epoch: [0][1600/2500]   loss_e: 0.37518, loss_f: 0.05833, e_MAE: 113813.39178, f_MAE: 8760.11055, time/step=234ms, lr=1.00e-06
Epoch: [0][1700/2500]   loss_e: 0.36870, loss_f: 0.05644, e_MAE: 111847.87797, f_MAE: 8476.67351, time/step=234ms, lr=1.00e-06
Epoch: [0][1800/2500]   loss_e: 0.36322, loss_f: 0.05477, e_MAE: 110185.60253, f_MAE: 8224.51457, time/step=234ms, lr=1.00e-06
Epoch: [0][1900/2500]   loss_e: 0.35698, loss_f: 0.05320, e_MAE: 108291.43613, f_MAE: 7988.39987, time/step=234ms, lr=1.00e-06
Epoch: [0][2000/2500]   loss_e: 0.35160, loss_f: 0.05175, e_MAE: 106659.09827, f_MAE: 7770.07496, time/step=233ms, lr=1.00e-06
Epoch: [0][2100/2500]   loss_e: 0.34805, loss_f: 0.05041, e_MAE: 105582.97466, f_MAE: 7569.16291, time/step=233ms, lr=1.00e-06
Epoch: [0][2200/2500]   loss_e: 0.34287, loss_f: 0.04914, e_MAE: 104013.07274, f_MAE: 7378.71877, time/step=233ms, lr=1.00e-06
Epoch: [0][2300/2500]   loss_e: 0.33794, loss_f: 0.04798, e_MAE: 102517.56985, f_MAE: 7204.14332, time/step=233ms, lr=1.00e-06
Epoch: [0][2400/2500]   loss_e: 0.33326, loss_f: 0.04687, e_MAE: 101095.59328, f_MAE: 7036.94210, time/step=233ms, lr=1.00e-06
Epoch: [0][2499/2500]   loss_e: 0.32863, loss_f: 0.04580, e_MAE: 99691.93445, f_MAE: 6876.80393, time/step=233ms, lr=1.00e-06
Traceback (most recent call last):
  File "main_aliqm.py", line 489, in <module>
    main(args)
  File "main_aliqm.py", line 236, in main
    val_err, val_loss = evaluate(args=args, model=model, criterion=criterion, 
  File "main_aliqm.py", line 449, in evaluate
    pred_y, pred_dy = model(node_atom=data.z, pos=data.pos, batch=data.batch)
  File "/home/ubuntu/Softwares/anaconda3/envs/rapids/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/Softwares/anaconda3/envs/rapids/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/Projects/equiformer/nets/graph_attention_transformer_aliqm.py", line 319, in forward
    torch.autograd.grad(
  File "/home/ubuntu/Softwares/anaconda3/envs/rapids/lib/python3.8/site-packages/torch/autograd/__init__.py", line 300, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 44.00 MiB (GPU 0; 23.69 GiB total capacity; 22.38 GiB already allocated; 16.44 MiB free; 22.52 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
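
For context, the torch.autograd.grad call at graph_attention_transformer_aliqm.py line 319 in the traceback computes forces by differentiating the predicted energy with respect to the atomic positions, so the forward activations have to be kept even during evaluation. A minimal toy sketch of that pattern (energy_model below is a hypothetical stand-in, not the repo's actual forward pass):

import torch

def energy_model(pos):
    # Hypothetical stand-in for the Equiformer forward pass: maps positions to a scalar energy.
    return (pos ** 2).sum()

pos = torch.randn(100, 3, requires_grad=True)
energy = energy_model(pos)
# Differentiating the energy w.r.t. positions needs the full autograd graph of the forward pass,
# which is what dominates GPU memory when forces are evaluated on larger molecules.
forces = -torch.autograd.grad(energy, pos)[0]
print(forces.shape)  # torch.Size([100, 3])

Note that the max_split_size_mb hint at the end of the error message (set via PYTORCH_CUDA_ALLOC_CONF) only mitigates allocator fragmentation; it does not reduce the memory the autograd graph itself requires.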
yilunliao commented 1 year ago

Hi @xnuohz

You can use a smaller L_{max} (e.g., "128x0e+128x1e+128x2e" -> "128x0e+128x1e"), a smaller number of channels, or a smaller number of blocks.

Using a smaller maximum number of neighbors might also help. (A rough sketch of the L_{max} change is shown below.)
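
As an illustration of the L_{max} suggestion, here is a minimal sketch (assuming e3nn, which Equiformer builds on, is installed) comparing the per-node feature dimension of the two irreps strings above; the channel count (the 128 multiplicity) and the number of blocks are separate hyperparameters:

from e3nn import o3

irreps_lmax2 = o3.Irreps("128x0e+128x1e+128x2e")  # L_max = 2, as in the example above
irreps_lmax1 = o3.Irreps("128x0e+128x1e")         # L_max = 1, degree-2 channels dropped

# Per-node feature dimension: 128*1 + 128*3 + 128*5 = 1152 vs. 128*1 + 128*3 = 512.
print(irreps_lmax2.lmax, irreps_lmax2.dim)  # 2 1152
print(irreps_lmax1.lmax, irreps_lmax1.dim)  # 1 512

Dropping the degree-2 features shrinks every intermediate node representation and the tensor products acting on it, so it reduces both parameter count and activation memory.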

xnuohz commented 1 year ago

Thanks @yilunliao. Of the options you mentioned, which one do you think will affect the model's metrics the most? I typically observe a trade-off between runtime/model size and accuracy.

yilunliao commented 1 year ago

I am not sure; it can depend on the application or dataset. My guess is that reducing L_{max} from 2 to 1 would hurt accuracy the most in most cases.