QuantumLab-ZY / HamGNN

An E(3)-equivariant Graph Neural Network for predicting the electronic Hamiltonian matrix
GNU General Public License v3.0

RuntimeError: CUDA out of memory. #20

Open · Youhaojen opened this issue 5 months ago

Youhaojen commented 5 months ago

Dear Zhong Yang,

I tried to train the model on 3 RTX 3090 GPUs. The training data were generated with ABACUS and consist of 60 structures of 96-atom Mg2Ge. I get an error when running the training process; the full error message is below:

##################################################################
#                                                                #
#    ██╗  ██╗ █████╗ ███╗   ███╗ ██████╗ ███╗   ██╗███╗   ██╗    #
#    ██║  ██║██╔══██╗████╗ ████║██╔════╝ ████╗  ██║████╗  ██║    #
#    ███████║███████║██╔████╔██║██║  ███╗██╔██╗ ██║██╔██╗ ██║    #
#    ██╔══██║██╔══██║██║╚██╔╝██║██║   ██║██║╚██╗██║██║╚██╗██║    #
#    ██║  ██║██║  ██║██║ ╚═╝ ██║╚██████╔╝██║ ╚████║██║ ╚████║    #
#    ╚═╝  ╚═╝╚═╝  ╚═╝╚═╝     ╚═╝ ╚═════╝ ╚═╝  ╚═══╝╚═╝  ╚═══╝    #
#       Author: Yang Zhong       Email: yzhong@fudan.edu.cn      #
##################################################################

{'dataset_params': {'batch_size': 1,
                    'csv_params': {'crystal_path': 'crystals',
                                   'file_type': 'poscar',
                                   'id_prop_path': './',
                                   'l_pred_atomwise_tensor': True,
                                   'l_pred_crystal_tensor': False,
                                   'rank_tensor': 0},
                    'database_type': 'db',
                    'db_params': {'db_path': './',
                                  'property_list': ['energy', 'hamiltonian']},
                    'graph_data_path': './graph/',
                    'max_num_nbr': 32,
                    'radius': 6.0,
                    'split_file': None,
                    'test_ratio': 0.1,
                    'train_ratio': 0.7,
                    'val_ratio': 0.2},
 'losses_metrics': {'losses': [{'loss_weight': 1.0,
                                'metric': L1Loss(),
                                'prediction': 'hamiltonian',
                                'target': 'hamiltonian'}],
                    'metrics': [{'metric': L1Loss(),
                                 'prediction': 'hamiltonian',
                                 'target': 'hamiltonian'}]},
 'molecular_dynamics': {'device': None,
                        'dt': 0.1,
                        'energy_units_to_eV': 1.0,
                        'initial_xyz': '',
                        'length_units_to_A': 1.0,
                        'log_frequency': 2,
                        'logdir': './',
                        'model': None,
                        'n_steps': 100,
                        'nvt_q': 43.06225052549201,
                        'r_max': 6.0,
                        'save_frequency': 2,
                        'seed': 66,
                        'temperature': 300.0},
 'optim_params': {'gradient_clip_val': 0.0,
                  'lr': 0.001,
                  'lr_decay': 0.5,
                  'lr_patience': 5,
                  'max_epochs': 3000,
                  'min_epochs': 100,
                  'stop_patience': 30},
 'output_nets': {'HamGNN_out': {'add_H0': True,
                                'band_num_control': None,
                                'calculate_band_energy': False,
                                'ham_only': True,
                                'ham_type': 'abacus',
                                'k_path': None,
                                'nao_max': 27,
                                'nonlinearity_type': 'norm',
                                'num_k': 5,
                                'soc_switch': False,
                                'symmetrize': True},
                 'output_module': 'HamGNN_out'},
 'post_processing': {'EPC': {}, 'post_utility': 'EPC'},
 'profiler_params': {'progress_bar_refresh_rat': 1, 'train_dir': './train'},
 'representation_nets': {'HamGNN_pre': {'add_edge_tp': False,
                                        'cutoff': 20.0,
                                        'cutoff_func': 'cos',
                                        'edge_sh_normalization': 'component',
                                        'edge_sh_normalize': True,
                                        'export_triplet': False,
                                        'feature_irreps_hidden': '32x0o+128x0e+128x1o+64x1e+128x2e+32x2o+64x3o+32x3e+32x4o+32x4e',
                                        'invariant_layers': 2,
                                        'invariant_neurons': 64,
                                        'irreps_edge_output': '32x0o+128x0e+128x1o+64x1e+128x2e+32x2o+64x3o+32x3e+32x4o+32x4e',
                                        'irreps_edge_sh': '0e + 1o + 2e + 3o + '
                                                          '4e + 5o + 6e',
                                        'irreps_node_features': '32x0o+128x0e+128x1o+64x1e+128x2e+32x2o+64x3o+32x3e+32x4o+32x4e',
                                        'irreps_node_output': '32x0o+128x0e+128x1o+64x1e+128x2e+32x2o+64x3o+32x3e+32x4o+32x4e',
                                        'irreps_node_prev': '16x0o+16x0e+8x1o+8x1e+8x2e+8x2o+8x3o+8x3e+8x4o+8x4e',
                                        'irreps_triplet_output': '32x0o+128x0e+128x1o+64x1e+128x2e+32x2o+64x3o+32x3e+32x4o+32x4e',
                                        'num_interaction_layers': 5,
                                        'num_node_attr_feas': 64,
                                        'num_radial': 8,
                                        'num_spherical': 8,
                                        'num_types': 120,
                                        'rbf_func': 'bessel',
                                        'resnet': True,
                                        'set_features': True}},
 'setup': {'GNN_Net': 'HamGNN_pre',
           'accelerator': 'ddp',
           'checkpoint_path': './network_weights_silicon.ckpt',
           'ignore_warnings': True,
           'l_minus_mean': False,
           'load_from_checkpoint': False,
           'num_gpus': [0, 1, 2],
           'precision': 32,
           'property': 'hamiltonian',
           'resume': False,
           'stage': 'fit'}}
Loading graph data from ./graph/graph_data.npz!
Building model
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Start training.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]
initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/3
initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/3
initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/3
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 3 processes
----------------------------------------------------------------------------------------------------

  | Name           | Type       | Params
----------------------------------------------
0 | representation | HamGNN_pre | 86.3 M
1 | output_module  | HamGNN_out | 276 K 
----------------------------------------------
86.5 M    Trainable params
0         Non-trainable params
86.5 M    Total params
346.162   Total estimated model params size (MB)
Validation sanity check: 0it [00:00, ?it/s]/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:120: UserWarning: strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks. Consider setting num_workers>0 and persistent_workers=True
  rank_zero_warn(
Validation sanity check:   0%|          | 0/1 [00:00<?, ?it/s]/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:59: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 10. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
  warning_cache.warn(
/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:432: UserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 0:   0%|          | 0/2 [00:00<?, ?it/s] 
-------------------------------------------------------------------------------
HamGNN 33 <module>
sys.exit(load_entry_point('HamGNN==0.1.0', 'console_scripts', 'HamGNN')())

main.py 308 HamGNN
train_and_eval(configure)

main.py 263 train_and_eval
trainer.fit(model, data)

trainer.py 740 fit
self._call_and_handle_interrupt(

trainer.py 685 _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)

trainer.py 777 _fit_impl
self._run(model, ckpt_path=ckpt_path)

trainer.py 1199 _run
self._dispatch()

trainer.py 1279 _dispatch
self.training_type_plugin.start_training(self)

ddp_spawn.py 173 start_training
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)

ddp_spawn.py 201 spawn
mp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), nprocs=self.num_processes)

spawn.py 240 spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')

spawn.py 198 start_processes
while not context.join():

spawn.py 160 join
raise ProcessRaisedException(msg, error_index, failed_process.pid)

torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 208, in _wrapped_function
    result = function(*args, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 236, in new_process
    results = trainer.run_stage()
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
    return self._run_train()
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1319, in _run_train
    self.fit_loop.run()
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 193, in advance
    batch_output = self.batch_loop.run(batch, batch_idx)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 215, in advance
    result = self._run_optimization(
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 266, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 378, in _optimizer_step
    lightning_module.optimizer_step(
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/core/lightning.py", line 1652, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 164, in step
    trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 339, in optimizer_step
    self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 163, in optimizer_step
    optimizer.step(closure=closure, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/torch/optim/optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/torch/optim/adamw.py", line 100, in step
    loss = closure()
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 148, in _wrap_closure
    closure_result = closure()
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 160, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 142, in closure
    step_output = self._step_fn()
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 435, in _training_step
    training_step_output = self.trainer.accelerator.training_step(step_kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 219, in training_step
    return self.training_type_plugin.training_step(*step_kwargs.values())
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 383, in training_step
    return self.model(*args, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 963, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/pytorch_lightning/overrides/base.py", line 81, in forward
    output = self.module.training_step(*inputs, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/HamGNN-0.1.0-py3.9.egg/HamGNN/models/Model.py", line 100, in training_step
    pred = self(data)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/HamGNN-0.1.0-py3.9.egg/HamGNN/models/Model.py", line 228, in forward
    representation = self.representation(data)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/HamGNN-0.1.0-py3.9.egg/HamGNN/models/HamGNN/net.py", line 680, in forward
    self.convnet[i](data)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/HamGNN-0.1.0-py3.9.egg/HamGNN/models/HamGNN/nequip/nn/_convnetlayer.py", line 160, in forward
    data = self.conv(data)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/HamGNN-0.1.0-py3.9.egg/HamGNN/models/HamGNN/nequip/nn/_interaction_block.py", line 168, in forward
    edge_features = self.tp(
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/e3nn-0.5.1-py3.9.egg/e3nn/o3/_tensor_product/_tensor_product.py", line 542, in forward
    return self._compiled_main_left_right(x, y, real_weight)
  File "/home/danken/application/compiler/anaconda3/envs/HamGNN/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/torch/fx/graph_module/___torch_mangle_16.py", line 1382, in forward
    getitem_172 = torch.slice(_168, 1, 9472, 9504)
    reshape_340 = torch.reshape(getitem_172, [-1, 32, 1])
    einsum_236 = torch.einsum("edb,eca->ecdab", [reshape_17, reshape_12])
                 ~~~~~~~~~~~~ <--- HERE
    mul_132 = torch.mul(reshape_17, 0.33333333333333331)
    einsum_237 = torch.einsum("dca,dba->dcb", [mul_132, reshape_12])

Traceback of TorchScript, original code (most recent call last):
  File "<eval_with_key>.75", line 1154, in forward
    getitem_172 = reshape_2[(slice(None, None, None), slice(9472, 9504, None))]
    reshape_340 = getitem_172.reshape((-1, 32, 1));  getitem_172 = None
    einsum_236 = torch.functional.einsum('edb,eca->ecdab', reshape_17, reshape_12)
                 ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    mul_132 = reshape_17 * 0.3333333333333333;  reshape_17 = None
    einsum_237 = torch.functional.einsum('dca,dba->dcb', mul_132, reshape_12);  mul_132 = None
RuntimeError: CUDA out of memory. Tried to allocate 222.00 MiB (GPU 1; 23.69 GiB total capacity; 21.37 GiB already allocated; 218.44 MiB free; 21.70 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
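The last line of the message suggests max_split_size_mb, though here reserved (21.70 GiB) is barely above allocated (21.37 GiB), so fragmentation is probably not the real problem. For completeness, a minimal sketch of setting that allocator option (it must be in the environment before the first CUDA allocation):

import os

# Set before torch initializes CUDA, e.g. at the top of the launch
# script or exported in the shell before running HamGNN. This only
# mitigates fragmentation; it cannot help if the model genuinely
# needs more than the card's 24 GiB.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"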

Following is my config.yaml.

dataset_params:
  batch_size: 1
  split_file: null
  test_ratio: 0.1
  train_ratio: 0.7
  val_ratio: 0.2
  graph_data_path: ./graph/ # Directory where graph_data.npz is located

losses_metrics:
  losses:
  - loss_weight: 1.0
    metric: mae
    prediction: hamiltonian
    target: hamiltonian
  #- loss_weight: 1.0
  #  metric: mae
  #  prediction: band_gap
  #  target: band_gap
  #- loss_weight: 0.001
  #  metric: mae
  #  prediction: band_energy
  #  target: band_energy
  #- loss_weight: 1.0
  #  metric: mae
  #  prediction: overlap
  #  target: overlap
  #- loss_weight: 1.0
  #  metric: mae
  #  prediction: peak
  #  target: peak
  #- loss_weight: 0.0
  #  metric: mae
  #  prediction: hamiltonian_imag
  #  target: hamiltonian_imag
  #- loss_weight: 0.0001
  #  metric: abs_mae
  #  prediction: wavefunction
  #  target: wavefunction
  metrics:
  - metric: mae
    prediction: hamiltonian
    target: hamiltonian
  #- metric: mae
  #  prediction: band_gap
  #  target: band_gap
  #- metric: mae
  #  prediction: peak
  #  target: peak
  #- metric: mae
  #  prediction: overlap
  #  target: overlap
  #- metric: mae
  #  prediction: hamiltonian_imag
  #  target: hamiltonian_imag
  #- metric: mae
  #  prediction: hamiltonian_imag
  #  target: hamiltonian_imag
  #- metric: mae
  #  prediction: band_energy
  #  target: band_energy
  #- metric: abs_mae
  #  prediction: wavefunction
  #  target: wavefunction

# Generally, the optim_params module only needs to set the initial learning rate (lr)
optim_params:
  lr: 0.001
  lr_decay: 0.5
  lr_patience: 5
  gradient_clip_val: 0.0
  max_epochs: 3000
  min_epochs: 100
  stop_patience: 30

output_nets:
  output_module: HamGNN_out
  HamGNN_out:
    ham_only: true # true: only the Hamiltonian H is computed; false: fit both H and S
    ham_type: abacus # openmx: fit openmx Hamiltonian; abacus: fit abacus Hamiltonian
    nao_max: 27 # The maximum number of atomic orbitals in the data set: 14, 19, or 26 for openmx; 13, 27, or 40 for abacus
    add_H0: true # Generally true, the complete Hamiltonian is predicted as the sum of H_scf plus H_nonscf (H0)
    symmetrize: true # if set to true, the Hermitian symmetry constraint is imposed on the Hamiltonian
    calculate_band_energy: false # Whether to calculate the energy bands to train the model
    num_k: 5 # When calculating the energy bands, the number of K points to use
    band_num_control: null # `dict`: controls how many orbitals are considered for each atom in energy bands; `int`: [vbm-num, vbm+num]; `null`: all bands
    k_path: null # `auto`: Automatically determine the k-point path; `null`: random k-point path; `list`: list of k-point paths provided by the user
    soc_switch: false # if true, fit the SOC Hamiltonian
    nonlinearity_type: norm # norm or gate

profiler_params:
  progress_bar_refresh_rat: 1
  train_dir: ./train # The folder for saving training information and prediction results; this directory can be read by TensorBoard to monitor the training process.

representation_nets:
  # Network parameters usually do not need to be changed.
  HamGNN_pre:
    cutoff: 20.0
    resnet: True
    cutoff_func: cos
    edge_sh_normalization: component
    edge_sh_normalize: true
    ######## Irreps set 1 (crystal): ################
    feature_irreps_hidden: 32x0o+128x0e+128x1o+64x1e+128x2e+32x2o+64x3o+32x3e+32x4o+32x4e
    irreps_edge_output: 32x0o+128x0e+128x1o+64x1e+128x2e+32x2o+64x3o+32x3e+32x4o+32x4e
    irreps_edge_sh: 0e + 1o + 2e + 3o + 4e + 5o + 6e
    irreps_node_features: 32x0o+128x0e+128x1o+64x1e+128x2e+32x2o+64x3o+32x3e+32x4o+32x4e
    irreps_node_output: 32x0o+128x0e+128x1o+64x1e+128x2e+32x2o+64x3o+32x3e+32x4o+32x4e
    irreps_triplet_output: 32x0o+128x0e+128x1o+64x1e+128x2e+32x2o+64x3o+32x3e+32x4o+32x4e
    invariant_layers: 2
    invariant_neurons: 64
    num_interaction_layers: 5
    num_radial: 8
    num_spherical: 8
    export_triplet: false
    rbf_func: bessel
    set_features: true
    add_edge_tp: false
    num_types: 120
    irreps_node_prev: 16x0o+16x0e+8x1o+8x1e+8x2e+8x2o+8x3o+8x3e+8x4o+8x4e
    num_node_attr_feas: 64

setup:
  GNN_Net: HamGNN_pre
  accelerator: ddp
  ignore_warnings: true
  checkpoint_path: ./network_weights_silicon.ckpt # Path to the model weights file
  load_from_checkpoint: false
  resume: false
  num_gpus: [0,1,2] # null: use cpu; [i]: use the ith GPU device
  precision: 32
  property: hamiltonian
  stage: fit # fit: training; test: inference

I suspect that the number of atoms per structure in my database is too large, which causes the error, but I'm not entirely certain of the exact reason. Could you provide some guidance?

QuantumLab-ZY commented 5 months ago

Dear Youhaojen,

The number of atoms in your database is too large: a 96-atom Mg2Ge cell requires more memory for training than an RTX 3090 provides. You can try smaller features first, such as 32x0o+128x0e+32x1o+32x1e+32x2e+32x2o+32x3o+32x3e+32x4o+32x4e. If that still exceeds the GPU's memory, you will have to use smaller structures (fewer than 40 atoms in the unit cell, I guess) for the training set. By the way, considering the limited size of your training set, with only 60 structures, employing multiple GPUs may not be necessary.
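To see how much this shrinks the feature width (and hence the tensor-product activations that failed above), the two irreps strings can be compared with e3nn, which the traceback shows is already installed. A minimal sketch:

from e3nn import o3

# The original hidden irreps from the config versus the suggested smaller set.
orig  = o3.Irreps("32x0o+128x0e+128x1o+64x1e+128x2e+32x2o+64x3o+32x3e+32x4o+32x4e")
small = o3.Irreps("32x0o+128x0e+32x1o+32x1e+32x2e+32x2o+32x3o+32x3e+32x4o+32x4e")

print(orig.dim)   # 2784 components per feature vector
print(small.dim)  # 1696 components per feature vector

Presumably the smaller string would be substituted into feature_irreps_hidden and the related irreps_* entries under representation_nets.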

Best wishes, Yang Zhong

Youhaojen commented 5 months ago

Dear Yang Zhong,

Thanks for your kind reply. The 96-atom Mg2Ge structures are indeed too large for the RTX 3090, so I used 200 unit-cell structures for training instead. It is working now.
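For anyone hitting the same wall, per-process peak GPU usage can be read from torch's built-in counters to confirm there is headroom after such a change. A minimal sketch (hypothetical placement, e.g. printed after a training step):

import torch

GIB = 1024 ** 3
# Peak statistics for the current CUDA device since process start.
print(f"peak allocated: {torch.cuda.max_memory_allocated() / GIB:.2f} GiB")
print(f"peak reserved:  {torch.cuda.max_memory_reserved() / GIB:.2f} GiB")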

Thank you again.

Best, Hao-Jen You