hpcaitech / FastFold

Optimizing AlphaFold Training and Inference on GPU Clusters
Apache License 2.0

Illegal memory access error after get_chi_atom_indices #101

Open georgkempf opened 1 year ago

georgkempf commented 1 year ago

I tried to predict a 5-subunit complex (~5000 aa in total) and get the following error with various settings (1-4x A100 80GB, with and without --inplace, with and without --chunk_size 1-32). The error seems to be associated with exceeding the GPU memory, and I am not sure whether this is expected at this sequence length and available GPU memory. I installed FastFold from the recent commit 930a58a into a clean conda environment and built Triton from source. A smaller complex (~2000 aa) ran without errors.

terminate called after throwing an instance of 'c10::Error'
  what():  NCCL error in: /opt/conda/conda-bld/pytorch_1659484810403/work/torch/csrc/distributed/c10d/NCCLUtils.hpp:173, unhandled cuda error, NCCL version 2.10.3
Process Group destroyed on rank 1
Exception raised from ncclCommAbort at /opt/conda/conda-bld/pytorch_1659484810403/work/torch/csrc/distributed/c10d/NCCLUtils.hpp:173 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f43cf264497 in .../fastfold/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f43cf23bc94 in .../fastfold/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x19ea61 (0x7f44092e2a61 in .../fastfold/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cpp.so)
frame #3: c10d::ProcessGroupNCCL::~ProcessGroupNCCL() + 0x118 (0x7f44092c6098 in .../fastfold/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cpp.so)
frame #4: c10d::ProcessGroupNCCL::~ProcessGroupNCCL() + 0x9 (0x7f44092c6369 in .../fastfold/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cpp.so)
frame #5: <unknown function> + 0x9d7799 (0x7f440f4fd799 in .../fastfold/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x354732 (0x7f440ee7a732 in .../fastfold/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #7: <unknown function> + 0x3555ff (0x7f440ee7b5ff in .../fastfold/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x116878 (0x55a0881ca878 in .../fastfold/bin/python3)
frame #9: <unknown function> + 0x11699d (0x55a0881ca99d in .../fastfold/bin/python3)
frame #10: <unknown function> + 0x1fd471 (0x55a0882b1471 in .../fastfold/bin/python3)
frame #11: <unknown function> + 0x10e937 (0x55a0881c2937 in .../fastfold/bin/python3)
frame #12: _PyGC_CollectNoFail + 0x2b (0x55a0882b134b in .../fastfold/bin/python3)
frame #13: PyImport_Cleanup + 0x371 (0x55a0882b11b1 in .../fastfold/bin/python3)
frame #14: Py_FinalizeEx + 0x7a (0x55a0882aff9a in .../fastfold/bin/python3)
frame #15: Py_Exit + 0x8 (0x55a0881454bc in .../fastfold/bin/python3)
frame #16: <unknown function> + 0x9141b (0x55a08814541b in .../fastfold/bin/python3)
frame #17: <unknown function> + 0x910ee (0x55a0881450ee in .../fastfold/bin/python3)
frame #18: PyRun_SimpleStringFlags + 0x4a (0x55a088141f12 in .../fastfold/bin/python3)
frame #19: Py_RunMain + 0x27b (0x55a0882abc1b in .../fastfold/bin/python3)
frame #20: Py_BytesMain + 0x39 (0x55a088283619 in .../fastfold/bin/python3)
frame #21: __libc_start_main + 0xf5 (0x7f444b239555 in /lib64/libc.so.6)
frame #22: <unknown function> + 0x1cf525 (0x55a088283525 in .../fastfold/bin/python3)

Traceback (most recent call last):
  File ".../.../FastFold/inference.py", line 519, in <module>
    main(args)
  File ".../.../FastFold/inference.py", line 149, in main
    inference_multimer_model(args)
  File ".../.../FastFold/inference.py", line 282, in inference_multimer_model
    torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))
  File ".../fastfold/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File ".../fastfold/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File ".../fastfold/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File ".../fastfold/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File ".../.../FastFold/inference.py", line 136, in inference_model
    out = model(batch)
  File ".../fastfold/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../.../FastFold/fastfold/model/hub/alphafold.py", line 507, in forward
    outputs, m_1_prev, z_prev, x_prev = self.iteration(
  File ".../.../FastFold/fastfold/model/hub/alphafold.py", line 264, in iteration
    template_embeds = self.template_embedder(
  File ".../fastfold/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../.../FastFold/fastfold/model/fastnn/embedders_multimer.py", line 351, in forward
    self.template_single_embedder(
  File ".../fastfold/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../.../FastFold/fastfold/model/fastnn/embedders_multimer.py", line 238, in forward
    all_atom_multimer.compute_chi_angles(
  File ".../.../FastFold/fastfold/utils/all_atom_multimer.py", line 403, in compute_chi_angles
    chi_atom_indices = get_chi_atom_indices(aatype.device)
  File ".../.../FastFold/fastfold/utils/all_atom_multimer.py", line 365, in get_chi_atom_indices
    return torch.tensor(chi_atom_indices, device=device)
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

Shenggan commented 1 year ago

The most likely cause is indeed running out of memory: for single-precision (fp32) inference on a 40GB card, about 5000 residues is the length limit. We recommend using --inplace --chunk_size 1.
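For a rough sense of where that limit comes from: assuming the standard AlphaFold pair-representation width c_z = 128, a single [N, N, c_z] pair tensor is already about 12 GiB in fp32 at N = 5000, and inference keeps several such tensors plus intermediates alive. The sketch below is a back-of-the-envelope estimate of one copy, not FastFold's own memory accounting; `pair_rep_bytes` is a hypothetical helper.

```python
GIB = 2**30

def pair_rep_bytes(n_res: int, c_z: int = 128, bytes_per_elem: int = 4) -> int:
    """Bytes for one [n_res, n_res, c_z] pair-representation tensor.

    c_z = 128 is assumed (AlphaFold's pair channel width); bytes_per_elem is
    4 for fp32 and 2 for bf16/fp16.
    """
    return n_res * n_res * c_z * bytes_per_elem

for n in (2000, 5000):
    fp32 = pair_rep_bytes(n) / GIB
    bf16 = pair_rep_bytes(n, bytes_per_elem=2) / GIB
    print(f"N={n}: one pair tensor is {fp32:.2f} GiB in fp32, {bf16:.2f} GiB in bf16")
```

Because this grows quadratically with N, going from ~2000 to ~5000 residues makes every pair-sized tensor about 6.25x larger, while halving the bytes per element halves it.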

You may need to set PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:15000 to run inference on such an extremely long sequence, or use bfloat16 for inference.
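PYTORCH_CUDA_ALLOC_CONF is read by PyTorch's CUDA caching allocator when it initializes, so it must be in the environment before the first CUDA allocation. It is usually set in the shell in front of the inference command; the minimal sketch below does the same from inside a Python launcher. The helper name `set_cuda_alloc_conf` is hypothetical, and it replaces any other allocator options already stored in the variable.

```python
import os

def set_cuda_alloc_conf(max_split_size_mb: int) -> str:
    """Hypothetical helper: set PYTORCH_CUDA_ALLOC_CONF for this process.

    Call it before `import torch` (or at least before any CUDA tensor is
    created); processes started afterwards inherit the value. Any existing
    options in the variable are replaced.
    """
    value = f"max_split_size_mb:{max_split_size_mb}"
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = value
    return value

# Usage: set_cuda_alloc_conf(15000), then import torch and run inference.
```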

georgkempf commented 1 year ago

I tried again with --inplace --chunk_size 1 and PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:15000, but still got the same error. The card has 80 GB of memory. Is my understanding correct that using 2 GPUs would speed up the job but not raise the memory limit to 160 GB for a long sequence? What would be the best way to switch to bf16?

oahzxl commented 1 year ago

CUDA execution is asynchronous, so you need to set CUDA_LAUNCH_BLOCKING=1 to locate the bug. Alternatively, you can share the FASTA file so we can reproduce it.
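Because kernel launches return before the kernel finishes, the Python line that reports a CUDA error is often not the one that caused it. With CUDA_LAUNCH_BLOCKING=1 each launch waits for completion, so the traceback points at the faulting kernel, at the cost of speed. A minimal sketch, assuming you re-run the failing command from a small Python launcher; child processes, including the workers started by torch.multiprocessing.spawn, inherit the environment. The helpers `debug_env` and `run_blocking` are hypothetical.

```python
import os
import subprocess
import sys
from typing import Dict, List, Mapping, Optional

def debug_env(base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Copy of the current environment (or of `base`) with synchronous CUDA launches enabled."""
    env = dict(os.environ if base is None else base)
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    return env

def run_blocking(cmd: List[str]) -> int:
    """Run `cmd` with CUDA_LAUNCH_BLOCKING=1 and return its exit code."""
    return subprocess.run(cmd, env=debug_env()).returncode
```

For example, `run_blocking([sys.executable, "inference.py", ...])` with the same arguments that produced the crash.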

georgkempf commented 1 year ago

This is the traceback with CUDA_LAUNCH_BLOCKING=1. I saw in the installation instructions that CUDA >= 11.4 is recommended for building Triton, but environment.yml installs cudatoolkit 11.3, and there doesn't seem to be a ColossalAI release for CUDA > 11.3. Could this cause any problems?
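To compare the CUDA version PyTorch was built with against the 11.4 minimum mentioned for Triton, `torch.version.cuda` reports it as a string (or `None` for CPU-only builds). A small sketch; `meets_minimum` is a hypothetical helper, and it checks only the PyTorch build, not the driver or the toolkit Triton compiles against.

```python
from typing import Optional, Tuple

def parse_version(v: str) -> Tuple[int, int]:
    """'11.3' -> (11, 3); extra components such as a patch number are ignored."""
    major, minor = v.split(".")[:2]
    return int(major), int(minor)

def meets_minimum(cuda_version: Optional[str], minimum: str = "11.4") -> bool:
    """True if a CUDA version string is at least `minimum`; False for None (CPU-only build)."""
    if cuda_version is None:
        return False
    return parse_version(cuda_version) >= parse_version(minimum)

# Typical use (requires PyTorch):
#   import torch
#   print(torch.version.cuda, meets_minimum(torch.version.cuda))
```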

Traceback (most recent call last):
  File ".../fastfold/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File ".../.../FastFold/inference.py", line 136, in inference_model
    out = model(batch)
  File ".../fastfold/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../.../FastFold/fastfold/model/hub/alphafold.py", line 507, in forward
    outputs, m_1_prev, z_prev, x_prev = self.iteration(
  File ".../.../FastFold/fastfold/model/hub/alphafold.py", line 264, in iteration
    template_embeds = self.template_embedder(
  File ".../fastfold/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../.../FastFold/fastfold/model/fastnn/embedders_multimer.py", line 339, in forward
    pair_act = self.template_pair_embedder(
  File ".../fastfold/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../.../FastFold/fastfold/model/fastnn/embedders_multimer.py", line 215, in forward
    query_embedding = self.query_embedding_layer_norm(query_embedding)
  File ".../fastfold/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../.../FastFold/fastfold/model/fastnn/kernel/layer_norm.py", line 38, in forward
    return LayerNormTritonFunc.apply(input, self.normalized_shape, self.weight, self.bias,
  File ".../.../FastFold/fastfold/model/fastnn/kernel/triton/layer_norm.py", line 164, in forward
    _layer_norm_fwd_fused[(M,)](
  File ".../fastfold/lib/python3.8/site-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "<string>", line 23, in _layer_norm_fwd_fused
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

oahzxl commented 1 year ago

The bug has been fixed in https://github.com/hpcaitech/FastFold/pull/103, which will be merged into the main branch soon.

georgkempf commented 1 year ago

Great, thanks a lot! This time it ran for a while, but then another error occurred.

Command line args were: --gpus 4 --inplace --chunk_size 1

Traceback (most recent call last):
  File ".../fastfold/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File ".../.../FastFold/inference.py", line 136, in inference_model
    out = model(batch)
  File ".../fastfold/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../.../FastFold/fastfold/model/hub/alphafold.py", line 507, in forward
    outputs, m_1_prev, z_prev, x_prev = self.iteration(
  File ".../.../FastFold/fastfold/model/hub/alphafold.py", line 373, in iteration
    m, z, s = self.evoformer.inplace(
  File ".../.../FastFold/fastfold/model/fastnn/evoformer.py", line 319, in inplace
    m, z = checkpoint_blocks(
  File ".../.../FastFold/fastfold/utils/checkpointing.py", line 73, in checkpoint_blocks
    return exec(blocks, args)
  File ".../.../FastFold/fastfold/utils/checkpointing.py", line 60, in exec
    a = wrap(block(*a))
  File ".../.../FastFold/fastfold/model/fastnn/evoformer.py", line 131, in inplace
    z = self.communication.inplace(m[0], msa_mask, z)
  File ".../.../FastFold/fastfold/model/fastnn/ops.py", line 206, in inplace
    left_act = M_mask_col * left_act
RuntimeError: The size of tensor a (1282) must match the size of tensor b (5128) at non-singleton dimension 2
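The two sizes in this error are consistent with the 4-GPU run splitting the residue axis: 5128 / 4 = 1282, so one operand appears to be a per-GPU shard while the other still spans all 5128 residues. The sketch below uses a hypothetical helper to compute contiguous per-rank slices, assuming an even ceil(n / world_size) split; it only illustrates the arithmetic and is not how FastFold or the fix in PR #105 distributes tensors.

```python
import math
from typing import Tuple

def shard_bounds(n: int, world_size: int, rank: int) -> Tuple[int, int]:
    """Hypothetical helper: [start, end) of an axis of length n owned by `rank`,
    assuming contiguous chunks of ceil(n / world_size) elements."""
    chunk = math.ceil(n / world_size)
    start = min(rank * chunk, n)
    end = min(start + chunk, n)
    return start, end

# The residue axis in the error: 5128 residues over --gpus 4.
print([shard_bounds(5128, 4, r) for r in range(4)])
# prints [(0, 1282), (1282, 2564), (2564, 3846), (3846, 5128)]
# A full-length mask would have to be cut to the same slice (mask[start:end])
# before multiplying, so both operands have length 1282 on that dimension.
```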

oahzxl commented 1 year ago

Fixed in https://github.com/hpcaitech/FastFold/pull/105.

georgkempf commented 1 year ago

Thanks again for the super fast fix. It now runs for several hours, but at some point it still crashes with an out-of-memory (OOM) error. I have already lowered max_split_size_mb to 1000. Is it possible to switch the precision to fp16 globally?

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File ".../fastfold/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File ".../.../FastFold/inference.py", line 136, in inference_model
    out = model(batch)
  File ".../fastfold/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../.../FastFold/fastfold/model/hub/alphafold.py", line 507, in forward
    outputs, m_1_prev, z_prev, x_prev = self.iteration(
  File ".../.../FastFold/fastfold/model/hub/alphafold.py", line 389, in iteration
    outputs["sm"] = self.structure_module(
  File ".../fastfold/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../.../FastFold/fastfold/model/nn/structure_module.py", line 886, in forward
    outputs = self._forward_multimer(s, z, aatype, mask)
  File ".../.../FastFold/fastfold/model/nn/structure_module.py", line 825, in _forward_multimer
    s = s + self.ipa(s, z, rigids, mask)
  File ".../fastfold/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../.../FastFold/fastfold/model/nn/structure_module.py", line 397, in forward
    pt_att = sum([c**2 for c in pt_att])
RuntimeError: CUDA out of memory. Tried to allocate 4.70 GiB (GPU 0; 79.21 GiB total capacity; 77.87 GiB already allocated; 103.12 MiB free; 77.96 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
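PyTorch's hint at the end of this message applies when reserved memory is much larger than allocated memory, i.e. when the cache is fragmented. Here the gap is small: 77.96 GiB reserved versus 77.87 GiB allocated leaves only about 0.09 GiB cached but unused, far short of the 4.70 GiB request, which suggests that lowering max_split_size_mb cannot help and the tensors simply do not fit. A small sketch that extracts these figures from the message; the regexes are written against this message's wording and may need adjusting for other PyTorch versions.

```python
import re
from typing import Dict

# Conversion factors to GiB for the units PyTorch prints.
_TO_GIB = {"KiB": 1 / 2**20, "MiB": 1 / 2**10, "GiB": 1.0}
_NUM = r"([\d.]+) (KiB|MiB|GiB)"
_FIELDS = {
    "requested": r"Tried to allocate " + _NUM,
    "total": _NUM + r" total capacity",
    "allocated": _NUM + r" already allocated",
    "free": _NUM + r" free",
    "reserved": _NUM + r" reserved in total",
}

def parse_oom(msg: str) -> Dict[str, float]:
    """Pull the memory figures (converted to GiB) out of a CUDA OOM message."""
    info = {}
    for name, pattern in _FIELDS.items():
        m = re.search(pattern, msg)
        if m:
            info[name] = float(m.group(1)) * _TO_GIB[m.group(2)]
    return info

msg = ("CUDA out of memory. Tried to allocate 4.70 GiB (GPU 0; 79.21 GiB total capacity; "
       "77.87 GiB already allocated; 103.12 MiB free; 77.96 GiB reserved in total by PyTorch)")
info = parse_oom(msg)
cached_unused = info["reserved"] - info["allocated"]  # memory the cache holds but tensors don't use
print(f"cached but unused: {cached_unused:.2f} GiB; requested: {info['requested']:.2f} GiB")
```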

oahzxl commented 1 year ago

It runs out of memory in structure generation because multimer structure generation consumes much more memory than monomer. Unfortunately, we haven't optimized that part yet.
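The 4.70 GiB request in the traceback can be reproduced with a back-of-the-envelope estimate, assuming the point-attention term in IPA holds one fp32 tensor of shape [N_res, N_res, heads, points] per coordinate, with AlphaFold's default 12 heads and 4 query points, and N_res = 5128 (the residue count visible in the earlier shape-mismatch error). `ipa_point_term_bytes` is a hypothetical helper for this arithmetic only.

```python
GIB = 2**30

def ipa_point_term_bytes(n_res: int, heads: int = 12, points: int = 4,
                         bytes_per_elem: int = 4) -> int:
    """Bytes for one [n_res, n_res, heads, points] tensor.

    Assumes AlphaFold's default IPA sizes (12 heads, 4 query points) and fp32.
    """
    return n_res * n_res * heads * points * bytes_per_elem

one = ipa_point_term_bytes(5128) / GIB
print(f"one coordinate: {one:.2f} GiB")  # prints 4.70, matching the failed allocation
print(f"three squared coordinates held at once: {3 * one:.2f} GiB")
```

Under these assumptions, `sum([c**2 for c in pt_att])` keeps all three squared tensors (about 14 GiB) alive until the sum completes, on top of `pt_att` itself and the partial sums; accumulating one coordinate at a time would lower the peak, and bf16 would halve each term.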

oahzxl commented 1 year ago

Monomer supports bf16, but multimer doesn't yet. Maybe we will support it in the future.

georgkempf commented 1 year ago

Looking forward to these optimizations. Overall great project!

bj600800 commented 4 months ago

> Monomer supports bf16, but multimer doesn't yet. Maybe we will support it in the future.

How can I use bf16 or fp32 to reduce GPU memory usage? Thanks for your kind help!