ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

Fine-tuning of foundation model failure on main branch #590

Closed: stenczelt closed this issue 1 month ago

stenczelt commented 1 month ago

Describe the bug

Fine-tuning of the MP0b model fails on the main branch (commit edde9e9c00405c6b4a90ace371891fabeb7ba55e), but it worked on the earlier commit a386d997ae5675a9129d87cd53e2ce435871510a on the same hardware.

To Reproduce

Steps to reproduce the behavior:

  1. Install MACE from source.
  2. Run fine-tuning of the medium model from https://github.com/ACEsuit/mace-mp/releases/tag/mace_mp_0b:
mace_run_train \
    --name="CASTEP_001" \
    --train_file="train.xyz" \
    --valid_fraction=0.05 \
    --test_file="test.xyz" \
    --E0s='{6:-148.1738815378, 14:-163.4954402249}' \
    --model="MACE" \
    --energy_key='CASTEP_energy' \
    --forces_key='CASTEP_forces' \
    --stress_key='CASTEP_stress' \
    --multiheads_finetuning True \
    --foundation_model /home/coder/.cache/mace/mace_agnesi_mediummodel \
    --device=cuda
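The `--E0s` value above is a Python dict literal that maps atomic numbers (6 = C, 14 = Si) to isolated-atom reference energies. Below is a minimal sketch for checking such a string before launching a long run. It assumes that MACE interprets the value as a Python literal and that the energies are in eV. `parse_e0s` is a hypothetical helper, not part of MACE:

```python
import ast

# Value copied from the --E0s argument in the command above.
# Assumption: keys are atomic numbers, values are energies in eV.
E0S_ARG = "{6:-148.1738815378, 14:-163.4954402249}"


def parse_e0s(text: str) -> dict[int, float]:
    """Parse an E0s string such as '{6: -148.17, 14: -163.50}' safely (hypothetical helper)."""
    value = ast.literal_eval(text)  # evaluates Python literals only, never arbitrary code
    if not isinstance(value, dict):
        raise ValueError("E0s must be a dict literal of {atomic_number: energy}")
    # Normalise types so later lookups by int atomic number work.
    return {int(z): float(e) for z, e in value.items()}


if __name__ == "__main__":
    e0s = parse_e0s(E0S_ARG)
    print(e0s)
```

Using `ast.literal_eval` instead of `eval` means a malformed or malicious argument raises an error rather than executing code.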

The run fails with the following error:

2024-09-13 06:50:52.886 INFO: ===========TRAINING===========
2024-09-13 06:50:52.886 INFO: Started training, reporting errors on validation set
2024-09-13 06:50:52.886 INFO: Loss metrics on validation set
2024-09-13 06:50:57.696 INFO: Initial: head: pt_head, loss=  0.0039, RMSE_E_per_atom=   119.3 meV, RMSE_F=   104.0 meV / A, RMSE_stress=    13.9 meV / A^3
2024-09-13 06:50:57.774 INFO: Initial: head: default, loss=  0.0899, RMSE_E_per_atom=   742.3 meV, RMSE_F=  1364.8 meV / A, RMSE_stress=    20.2 meV / A^3
Traceback (most recent call last):
  File "/home/coder/project/venv/bin/mace_run_train", line 8, in <module>
    sys.exit(main())
  File "/home/coder/project/venv/lib/python3.9/site-packages/mace/cli/run_train.py", line 62, in main
    run(args)
  File "/home/coder/project/venv/lib/python3.9/site-packages/mace/cli/run_train.py", line 575, in run
    tools.train(
  File "/home/coder/project/venv/lib/python3.9/site-packages/mace/tools/train.py", line 218, in train
    train_one_epoch(
  File "/home/coder/project/venv/lib/python3.9/site-packages/mace/tools/train.py", line 340, in train_one_epoch
    _, opt_metrics = take_step(
  File "/home/coder/project/venv/lib/python3.9/site-packages/mace/tools/train.py", line 378, in take_step
    loss.backward()
  File "/home/coder/.local/lib/rolos-python-blank/site-packages/torch/_tensor.py", line 521, in backward
    torch.autograd.backward(
  File "/home/coder/.local/lib/rolos-python-blank/site-packages/torch/autograd/__init__.py", line 289, in backward
    _engine_run_backward(
  File "/home/coder/.local/lib/rolos-python-blank/site-packages/torch/autograd/graph.py", line 769, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: NVML_SUCCESS == r INTERNAL ASSERT FAILED at "../c10/cuda/CUDACachingAllocator.cpp":838, please report a bug to PyTorch. 
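The log above reports `RMSE_E_per_atom` for each head (`pt_head` and `default`) before training. The metric can be sketched as follows, assuming the usual definition: the root-mean-square of each configuration's energy error divided by its number of atoms. The function name is hypothetical, and MACE's own implementation may differ in details such as batching:

```python
import math
from typing import Sequence


def rmse_energy_per_atom_mev(
    e_pred: Sequence[float], e_ref: Sequence[float], n_atoms: Sequence[int]
) -> float:
    """RMSE of per-atom energy errors; inputs in eV, result in meV/atom (hypothetical helper)."""
    if not (len(e_pred) == len(e_ref) == len(n_atoms)) or len(e_pred) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    # Per-configuration error normalised by atom count, then squared.
    squared = [((p - r) / n) ** 2 for p, r, n in zip(e_pred, e_ref, n_atoms)]
    # Root of the mean, converted from eV to meV.
    return 1000.0 * math.sqrt(sum(squared) / len(squared))
```

Consider two configurations with 2 and 4 atoms and energy errors of +1 eV and -2 eV. Both per-atom errors have magnitude 0.5 eV, so the function returns 500 meV/atom.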

See the full log attached below.

Additional context

Environment:

Collecting environment information...
PyTorch version: 2.4.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.9.5 (default, Nov 23 2021, 15:27:38)  [GCC 9.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-119-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100 80GB PCIe
  MIG 2g.20gb     Device  0:

Nvidia driver version: 550.54.15
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Byte Order:                           Little Endian
Address sizes:                        46 bits physical, 48 bits virtual
CPU(s):                               32
On-line CPU(s) list:                  0-31
Thread(s) per core:                   1
Core(s) per socket:                   32
Socket(s):                            1
NUMA node(s):                         1
Vendor ID:                            GenuineIntel
CPU family:                           6
Model:                                85
Model name:                           Intel Xeon Processor (Cascadelake)
Stepping:                             5
CPU MHz:                              2593.910
BogoMIPS:                             5187.82
Virtualization:                       VT-x
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            1 MiB
L1i cache:                            1 MiB
L2 cache:                             128 MiB
NUMA node0 CPU(s):                    0-31
Vulnerability Gather data sampling:   Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:          KVM: Mitigation: VMX disabled
Vulnerability L1tf:                   Mitigation; PTE Inversion; VMX conditional cache flushes, SMT disabled
Vulnerability Mds:                    Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown:               Mitigation; PTI
Vulnerability Mmio stale data:        Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Mitigation; IBRS
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; IBRS; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI SW loop, KVM SW loop
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Mitigation; Clear CPU buffers; SMT Host state unknown
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq md_clear arch_capabilities

Versions of relevant libraries:
[pip3] mace-torch==0.3.7
[pip3] numpy==1.26.3
[pip3] torch==2.4.1+cu124
[pip3] torch-ema==0.3
[pip3] torchaudio==2.4.1+cu124
[pip3] torchmetrics==1.4.1
[pip3] torchvision==0.19.1+cu124
[pip3] triton==3.0.0
[conda] Could not collect

CASTEP_001_run-123_debug.log
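The failing environment reports `torch==2.4.1+cu124`, and the replies below confirm that downgrading PyTorch resolved the error, although the target version is not stated. When comparing environments, the local `+cu124` build tag gets in the way of plain string comparison. Below is a minimal stdlib-only sketch with hypothetical helper names. If `packaging` is installed, `packaging.version.Version(...).release` does the same job:

```python
import re


def base_version(version: str) -> tuple[int, ...]:
    """Turn '2.4.1+cu124' into (2, 4, 1) by dropping the local '+...' build tag (hypothetical helper)."""
    public = version.split("+", 1)[0]  # PEP 440 local label follows '+'
    match = re.match(r"\d+(?:\.\d+)*", public)  # keep the leading numeric release segment
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def same_release_line(a: str, b: str, depth: int = 2) -> bool:
    """True if two versions share their first `depth` components, e.g. both 2.4.x."""
    return base_version(a)[:depth] == base_version(b)[:depth]
```

For example, `same_release_line(torch.__version__, "2.4.1+cu124")` would flag an environment on the same 2.4.x line as the one in this report.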

ilyes319 commented 1 month ago

@stenczelt Was this solved by downgrading the PyTorch version?

stenczelt commented 1 month ago

Yes!