baker-laboratory / RoseTTAFold-All-Atom


CUDA Out of Memory Error #129

Open Dong-HoSeo opened 2 weeks ago

Dong-HoSeo commented 2 weeks ago

Environment:

GPU: NVIDIA RTX 3060 12GB
PyTorch Version: 2.3.1+cu118
CUDA Version: 11.8
OS: Linux Ubuntu
Python Version: 3.10.13
Sequence length: 651 amino acids

I am encountering a CUDA Out of Memory error when running the run_inference.py script from the RoseTTAFold-All-Atom repository. The error occurs during the model inference step. Below is the detailed error traceback:

Running PSIPRED
Running hhsearch
Error executing job with overrides: []
Traceback (most recent call last):
  File "/home/dhseo/Data_HDD2/RoseTTAFold-All-Atom/rf2aa/run_inference.py", line 206, in main
    runner.infer()
  File "/home/dhseo/Data_HDD2/RoseTTAFold-All-Atom/rf2aa/run_inference.py", line 155, in infer
    outputs = self.run_model_forward(input_feats)
  File "/home/dhseo/Data_HDD2/RoseTTAFold-All-Atom/rf2aa/run_inference.py", line 121, in run_model_forward
    outputs = recycle_step_legacy(self.model,
  File "/home/dhseo/Data_HDD2/RoseTTAFold-All-Atom/rf2aa/training/recycling.py", line 30, in recycle_step_legacy
    output_i = ddp_model(input_i)
  File "/home/dhseo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dhseo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dhseo/Data_HDD2/RoseTTAFold-All-Atom/rf2aa/model/RoseTTAFoldModel.py", line 364, in forward
    pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=use_checkpoint, p2p_crop=p2p_crop)
  File "/home/dhseo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dhseo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dhseo/Data_HDD2/RoseTTAFold-All-Atom/rf2aa/model/layers/Embeddings.py", line 335, in forward
    templ = self.templ_stack(templ, rbf_feat, t1d, use_checkpoint=use_checkpoint, p2p_crop=p2p_crop) # (B, T, L,L, d_templ)
  File "/home/dhseo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dhseo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dhseo/Data_HDD2/RoseTTAFold-All-Atom/rf2aa/model/layers/Embeddings.py", line 185, in forward
    templ = self.block[i_block](templ, rbf_feat, state)
  File "/home/dhseo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dhseo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dhseo/Data_HDD2/RoseTTAFold-All-Atom/rf2aa/model/Track_module.py", line 374, in forward
    gate = einsum('bli,bmj->blmij', left, right).reshape(B,L,L,-1)
  File "/home/dhseo/.local/lib/python3.10/site-packages/opt_einsum/contract.py", line 507, in contract
    return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)
  File "/home/dhseo/.local/lib/python3.10/site-packages/opt_einsum/contract.py", line 591, in _core_contract
    new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)
  File "/home/dhseo/.local/lib/python3.10/site-packages/opt_einsum/sharing.py", line 151, in cached_einsum
    return einsum(*args, **kwargs)
  File "/home/dhseo/.local/lib/python3.10/site-packages/opt_einsum/contract.py", line 353, in _einsum
    return fn(einsum_str, *operands, **kwargs)
  File "/home/dhseo/.local/lib/python3.10/site-packages/opt_einsum/backends/torch.py", line 45, in einsum
    return torch.einsum(equation, *operands)
  File "/home/dhseo/.local/lib/python3.10/site-packages/torch/functional.py", line 380, in einsum
    return einsum(equation, *_operands)
  File "/home/dhseo/.local/lib/python3.10/site-packages/torch/functional.py", line 385, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.72 GiB. GPU has a total capacity of 11.76 GiB of which 871.88 MiB is free. Including non-PyTorch memory, this process has 10.90 GiB memory in use. Of the allocated memory 9.07 GiB is allocated by PyTorch, and 873.77 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Any insights or suggestions on how to address this CUDA out of memory error would be greatly appreciated. Is there any way to further optimize memory usage, or are there specific configurations that could help mitigate this issue?
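The error message itself suggests setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. As far as I understand, this has to take effect before PyTorch makes its first CUDA allocation, so it would either be exported in the shell before launching run_inference.py or set at the very top of the script, roughly like this (just a sketch of where I would put it, not code from the repository):

import os

# Must be set before the first CUDA allocation -- safest is before torch is imported at all.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # imported only after the allocator option is in the environment

Would this be expected to help here, or is the 651-residue input simply too large for 12 GB?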

Thank you in advance for your assistance!

inuyasha10121 commented 2 weeks ago

I was encountering a similar problem and may have found a potential hot-fix (though if the sequence is simply too long to fit, this will not help). I have been trying to model a protein-RNA complex and found that even the protein alone (516 residues) would not run on a GPU with the same specs as yours. After some print debugging, I discovered that the model gets through the first round of prediction fine but crashes during the second. Digging further, I saw that two tensors are popped from the inputs between cycles; I suspect they are not garbage-collected quickly enough to free their GPU memory before the next cycle starts, so I added a line that forces PyTorch to clear its CUDA cache, and now I'm able to predict the protein structure.

rf2aa/training/recycling.py

def add_recycle_inputs(network_input, output_i, i_cycle, gpu, return_raw=False, use_checkpoint=False):
    input_i = {}
    # Only the per-cycle MSA/sequence features are indexed by the recycle index and moved to the GPU;
    # everything else is passed through unchanged.
    for key in network_input:
        if key in ['msa_latent', 'msa_full', 'seq']:
            input_i[key] = network_input[key][:,i_cycle].to(gpu, non_blocking=True)
        else:
            input_i[key] = network_input[key]

    L = input_i["msa_latent"].shape[2]
    msa_prev, pair_prev, _, alpha, mask_recycle = output_i
    # Re-initialize the coordinate tensor from ChemData's initial coordinates, broadcast to all L positions.
    xyz_prev = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(1,L,1,1).to(gpu, non_blocking=True)

    # Feed the previous cycle's outputs back in as recycling inputs.
    input_i['msa_prev'] = msa_prev
    input_i['pair_prev'] = pair_prev
    input_i['xyz'] = xyz_prev
    input_i['mask_recycle'] = mask_recycle
    input_i['sctors'] = alpha
    input_i['return_raw'] = return_raw
    input_i['use_checkpoint'] = use_checkpoint

    # Drop the stale previous-cycle entries that are no longer passed to the model.
    input_i.pop('xyz_prev')
    input_i.pop('alpha_prev')
    torch.cuda.empty_cache() #JME: Force GPU to clear popped tensors
    return input_i
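
If anyone wants to confirm that memory is actually being released between cycles, a small logging helper along these lines (my own sketch, not part of the repository; the function name is made up) can be called just before and after the empty_cache() line:

import torch

def log_cuda_memory(tag: str, device: int = 0) -> None:
    # memory_allocated: bytes currently backing live tensors.
    # memory_reserved: bytes the caching allocator is holding on to (allocated + cached).
    allocated_gib = torch.cuda.memory_allocated(device) / 2**30
    reserved_gib = torch.cuda.memory_reserved(device) / 2**30
    print(f"[{tag}] allocated: {allocated_gib:.2f} GiB | reserved: {reserved_gib:.2f} GiB")

Calling it as log_cuda_memory("before empty_cache") and log_cuda_memory("after empty_cache") should show the reserved figure dropping once the cache is cleared, while the allocated figure reflects the tensors still in use.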