sokrypton / ColabDesign

Making Protein Design accessible to all via Google Colab!

OOM on AF DB proteins? #116

Open · Abhishaike opened this issue 1 year ago

Abhishaike commented 1 year ago

I'm creating a binder for this protein: https://alphafold.ebi.ac.uk/entry/Q8W3K0 and I get the error below on both T4 and A100 GPUs.

The error itself makes sense, but I'm confused: how could a protein that needs over 150 GB of memory have been folded by AlphaFold in the first place? Shouldn't any protein that AlphaFold can take as input also work in ColabDesign? Or does ColabDesign use more memory?

Stage 1: running (logits → soft)
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-3-09dc0c470929> in <module>
     36 if optimizer == "pssm_semigreedy":
---> 37   model.design_pssm_semigreedy(120, 32, **flags)
     38   pssm = softmax(model._tmp["seq_logits"],1)

[... 25 frames hidden ...]
UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 164515000848 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation: 1006.63MiB
              constant allocation:    97.4KiB
        maybe_live_out allocation:  660.41MiB
     preallocated temp allocation:  153.22GiB
                 total allocation:  154.84GiB
Peak buffers:
    Buffer 1:
        Size: 22.78GiB
        XLA Label: copy
        Shape: f32[288,4,4,1152,1152]
        ==========================

    Buffer 2:
        Size: 22.78GiB
        XLA Label: copy
        Shape: f32[288,4,4,1152,1152]
        ==========================

    Buffer 3:
        Size: 22.78GiB
        Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_attention_starting_node/broadcast_in_dim[shape=(288, 4, 4, 1152, 1152) broadcast_dimensions=()]" source_file="/usr/local/lib/python3.9/dist-packages/haiku/_src/stateful.py" source_line=640
        XLA Label: broadcast
        Shape: f32[288,4,4,1152,1152]
        ==========================

    Buffer 4:
        Size: 648.00MiB
        XLA Label: fusion
        Shape: f32[128,1152,1152]
        ==========================

    Buffer 5:
        Size: 648.00MiB
        Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/gating_linear/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
        XLA Label: custom-call
        Shape: f32[1327104,128]
        ==========================

    Buffer 6:
        Size: 648.00MiB
        Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/output_projection/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
        XLA Label: custom-call
        Shape: f32[1327104,128]
        ==========================

    Buffer 7:
        Size: 648.00MiB
        XLA Label: fusion
        Shape: f32[128,1152,1152]
        ==========================

    Buffer 8:
        Size: 648.00MiB
        Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/mul" source_file="/usr/local/lib/python3.9/dist-packages/haiku/_src/layer_norm.py" source_line=205
        XLA Label: fusion
        Shape: f32[128,1152,1152]
        ==========================

    Buffer 9:
        Size: 648.00MiB
        Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/mul" source_file="/usr/local/lib/python3.9/dist-packages/haiku/_src/layer_norm.py" source_line=205
        XLA Label: fusion
        Shape: f32[128,1152,1152]
        ==========================

    Buffer 10:
        Size: 648.00MiB
        Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/right_projection/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
        XLA Label: custom-call
        Shape: f32[1327104,128]
        ==========================

    Buffer 11:
        Size: 648.00MiB
        Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/right_gate/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
        XLA Label: custom-call
        Shape: f32[1327104,128]
        ==========================

    Buffer 12:
        Size: 648.00MiB
        Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/left_projection/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
        XLA Label: custom-call
        Shape: f32[1327104,128]
        ==========================

    Buffer 13:
        Size: 648.00MiB
        Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/left_gate/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
        XLA Label: custom-call
        Shape: f32[1327104,128]
        ==========================

    Buffer 14:
        Size: 648.00MiB
        XLA Label: fusion
        Shape: f32[128,1152,1152]
        ==========================

    Buffer 15:
        Size: 648.00MiB
        Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/structure_module/broadcast_in_dim[shape=(1152, 1152, 128) broadcast_dimensions=()]" source_file="/usr/local/lib/python3.9/dist-packages/haiku/_src/stateful.py" source_line=640
        XLA Label: broadcast
        Shape: f32[1152,1152,128]
        ==========================

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:
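The sizes in the "Peak buffers" list follow directly from the tensor shapes, which makes it easier to see where the memory goes. The sketch below recomputes them. It assumes the 1152 axis is the total residue count being modeled and that the f32[288,4,4,1152,1152] buffers (Buffers 1-3, from `triangle_attention_starting_node`) hold 4 × L³ elements, since 288 × 4 × 4 × 1152 × 1152 = 4 × 1152³; the helper `tri_attn_logits_bytes` and that reading of the axes are illustrative assumptions inferred from the op names, not something the log states.

```python
from math import prod

GIB = 2**30  # bytes per GiB
MIB = 2**20  # bytes per MiB

def f32_bytes(shape):
    """Size in bytes of a dense float32 array (4 bytes per element)."""
    return prod(shape) * 4

# Shapes copied from the "Peak buffers" section of the log above.
tri_attn_shape = (288, 4, 4, 1152, 1152)  # Buffers 1-3
pair_shape = (128, 1152, 1152)            # e.g. Buffers 4, 7, 14

print(f"{f32_bytes(tri_attn_shape) / GIB:.2f} GiB")  # 22.78 GiB, as in the log
print(f"{f32_bytes(pair_shape) / MIB:.2f} MiB")      # 648.00 MiB, as in the log
print(f"{164515000848 / GIB:.2f} GiB")               # the failed allocation: 153.22 GiB

def tri_attn_logits_bytes(L, heads=4):
    """Hypothetical helper: float32 buffer with heads * L**3 elements (cubic in L)."""
    return heads * L**3 * 4

# 288 * 4 == 1152, so Buffers 1-3 match heads * L**3 at L = 1152.
assert tri_attn_logits_bytes(1152) == f32_bytes(tri_attn_shape)

# Under this assumption, halving the length shrinks these buffers by 8x.
for L in (288, 576, 1152):
    print(L, f"{tri_attn_logits_bytes(L) / GIB:.2f} GiB")
```

If the cubic reading is right, trimming the target to the region around the intended binding site would be the most direct way to shrink the largest buffers, because they grow much faster than the 128-channel pair activations, which scale only with L².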
Abhishaike commented 1 year ago (Wed, Mar 8, 2023, 2:17 PM)

This is with PSSM; the memory issue doesn't occur with semigreedy. Why is PSSM so memory-intensive?

sokrypton commented 1 year ago

Computing gradients takes about 2x more memory. Semigreedy just tries random mutations and accepts those that improve the loss, which does not require computing gradients.
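The semigreedy idea described in the reply can be sketched in a few lines: propose random point mutations, score each with a forward-only loss, and keep the best proposal only if it beats the current sequence. This is a minimal sketch, not ColabDesign's implementation. The function `semigreedy`, its `iters`/`num_tries` parameters, and the toy Hamming-distance loss standing in for the AlphaFold-based loss are all hypothetical.

```python
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

def semigreedy(seq, loss_fn, iters=100, num_tries=10, seed=0):
    """Hypothetical sketch: random point mutations, accepted only if the loss drops.

    Only loss_fn(seq) is ever called (forward passes), so no gradients
    or stored backward-pass activations are needed.
    """
    rng = random.Random(seed)
    best_seq, best_loss = seq, loss_fn(seq)
    for _ in range(iters):
        candidates = []
        for _ in range(num_tries):
            i = rng.randrange(len(best_seq))  # random position
            aa = rng.choice(AMINO_ACIDS)      # random residue
            cand = best_seq[:i] + aa + best_seq[i + 1:]
            candidates.append((loss_fn(cand), cand))
        cand_loss, cand = min(candidates)     # best of this round's proposals
        if cand_loss < best_loss:             # accept only improvements
            best_seq, best_loss = cand, cand_loss
    return best_seq, best_loss

# Toy stand-in for the AlphaFold-based loss (hypothetical): Hamming distance to a target.
target = "MKTAYIAKQR"
def toy_loss(s):
    return sum(a != b for a, b in zip(s, target))

seq, loss = semigreedy("A" * len(target), toy_loss, iters=500)
print(seq, loss)
```

Because every step only evaluates the loss, peak memory is that of a single forward pass. A gradient-based stage such as the PSSM stage must additionally keep what the backward pass needs, which is consistent with the roughly 2x figure above.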