jyaacoub / MutDTA

Improving the precision oncology pipeline by providing binding affinity perturbation predictions on a priori identified cancer driver genes.

Memory issues for AlphaFlow #89

Closed jyaacoub closed 4 months ago

jyaacoub commented 6 months ago

Related: #84

Potential Solutions

Inputs 12+ failed for Davis (~178 proteins impacted).

203/229 for KIBA.

jyaacoub commented 6 months ago

A histogram plot with staggered labels shows the extent of this issue, especially for Davis, which has many sequences above 1000 residues.

[image: per-dataset sequence-length histograms]

CODE FOR PLOTS

```python
#%%
import os
import pandas as pd
import matplotlib.pyplot as plt

# Function to load sequences and their lengths from csv files
def load_sequences(directory):
    lengths = []
    labels_positions = {}  # Maps each file's index to the first length in that file, for labeling
    files = sorted([f for f in os.listdir(directory)
                    if f.endswith('.csv') and f.startswith('input_')])
    for file in files:
        file_path = os.path.join(directory, file)
        data = pd.read_csv(file_path)

        # Extract lengths
        current_lengths = data['seqres'].apply(len)
        lengths.extend(current_lengths)

        # Store the label position using the first length in the current file
        labels_positions[int(file.split('_')[1].split('.')[0])] = current_lengths.iloc[0]
    return lengths, labels_positions

p = lambda d: f"/cluster/home/t122995uhn/projects/data/{d}/alphaflow_io"
DATASETS = {d: p(d) for d in ['davis', 'kiba', 'pdbbind']}
DATASETS['platinum'] = "/cluster/home/t122995uhn/projects/data/PlatinumDataset/raw/alphaflow_io"

fig, axs = plt.subplots(len(DATASETS), 1,
                        figsize=(10, 5*len(DATASETS) + len(DATASETS)))
n_bins = 50  # Adjust the number of bins according to your preference

for i, (dataset, d_dir) in enumerate(DATASETS.items()):
    # Load sequences and positions for labels
    lengths, labels_positions = load_sequences(d_dir)

    # Plot histogram
    ax = axs[i]
    n, bins, patches = ax.hist(lengths, bins=n_bins, color='blue', alpha=0.7)
    ax.set_title(dataset)

    # Add counts to each bin
    for count, x, patch in zip(n, bins, patches):
        ax.text(x + 0.5, count, str(int(count)), ha='center', va='bottom')

    # Add red file-index labels at each file's first sequence length
    for label, pos in labels_positions.items():
        ax.text(pos, label, str(label), color='red', ha='center')

    # Optional: additional formatting for readability
    ax.set_xlabel('Sequence Length')
    ax.set_ylabel('Frequency')
    ax.set_xlim([0, max(lengths) + 10])  # Adjust xlim to make sure labels fit

plt.tight_layout()
plt.show()
# %%
```

jyaacoub commented 5 months ago

AlphaFlow low-memory code:

The linked AlphaFlow issue mentions setting long_sequence_inference=True in the ModelConfig: https://github.com/bjing2016/alphaflow/issues/17
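
For reference, a minimal sketch of what enabling that flag might look like via OpenFold's model_config helper (which AlphaFlow builds on); the preset name and the low_prec argument here are assumptions, not the exact AlphaFlow call site:

```python
# Sketch only: assumes OpenFold's model_config() accepts long_sequence_inference,
# as referenced in the linked issue. The "initial_training" preset is a placeholder;
# use whatever preset the predict script already loads.
from openfold.config import model_config

config = model_config(
    "initial_training",
    low_prec=True,                 # optional: low-precision parameters
    long_sequence_inference=True,  # enable memory-saving paths for long inputs
)
```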

Outcome

This just barely helped: only 2 additional proteins were output before running into OOM again.

jyaacoub commented 4 months ago

Solution

See commit https://github.com/jyaacoub/alphaflow/commit/b93e28915fb37900ab574c5e2230c705d166d17e for predict_deepspeed.py

To summarize, there are four ways to get around the memory issues. Listed in order of least to greatest impact on time complexity (a combined invocation is sketched after this list), they are:

  1. --low_pres: Use low precision for parameters (torch.bfloat16 vs. torch.float32). This also improves runtime, since fewer high-precision calculations are needed, at the risk of reduced accuracy.
  2. --chunk_size: Chunk GPU calculations by module. Setting this to 4 or 2 is usually sufficient; on 2x A100s, these settings combined with --low_pres got us to sequence lengths of 1070 and 1167, respectively.
  3. --cpu_offload: Offload parameters to the CPU as soon as they are not in use.
  4. --lma: Low-memory attention, using Rabe & Staats' low-memory attention algorithm. This increases time complexity considerably and should only be used when absolutely necessary. For this flag to modify the default chunk sizes, the OpenFold source must be changed (see https://github.com/aqlaboratory/openfold/pull/435).
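
A hypothetical combined invocation is sketched below. The four memory flags are the ones listed above; the remaining arguments mirror AlphaFlow's standard predict.py interface and are assumptions to adjust for your setup:

```bash
# Sketch of running the patched predict_deepspeed.py. Escalate the memory
# flags from top to bottom only as OOMs persist; the input/weights/output
# paths are illustrative placeholders.
python predict_deepspeed.py \
    --mode alphafold \
    --input_csv input_12.csv \
    --msa_dir ./msa_dir \
    --weights params/alphaflow_pdb_base_202402.pt \
    --outpdb ./out \
    --low_pres \
    --chunk_size 4 \
    --cpu_offload \
    --lma
```

In practice, start with --low_pres and --chunk_size alone, since --cpu_offload and especially --lma carry the largest runtime penalties.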