jyaacoub / MutDTA

Improving the precision oncology pipeline by providing binding affinity perturbation predictions on a priori identified cancer driver genes.

Setup AlphaFlow #84

Closed jyaacoub closed 3 months ago

jyaacoub commented 3 months ago

Environment setup is problematic because the OpenFold dependency is very finicky to install on HPC:

Minimal code to replicate the error with OpenFold:

from openfold.utils.kernel.attention_core import attention_core
import torch
a = torch.tensor([[0]], device='cuda', dtype=torch.float32)
attention_core(a,a,a)

Raises:

RuntimeError: attn_softmax_inplace_forward_ not implemented on CPU

Solution:

  1. Clone the repo on the login node and check out the right commit:
    git clone https://github.com/aqlaboratory/openfold
    cd openfold
    git reset --hard 103d037
  2. salloc into a node with a GPU
    • You might also need to load the CUDA environment before running pip, with module load StdEnv/2020 cudacore/.11.7.0 (make sure to reactivate your environment afterwards, since this replaces the Python you have loaded)
    • IMPORTANT: to avoid the "no kernel image is available" error, you must do this on the same type of GPU you want to run on:
      salloc -t 1:00:00 -p gpu -A kumargroup_gpu --gres=gpu:v100:1 --mem 4G -c 4
  3. Run pip install . inside the cloned repo (this doesn't require internet access). This should install openfold v1.0.1.
    • You might need to downgrade gcc, since CUDA 11.3 only supports gcc < 10.0.0 (see the compatibility chart)
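Before running pip install, one quick way to check whether the active gcc is below the limit mentioned above is to parse the output of gcc --version. This is a minimal sketch, assuming the first line of that output contains a dotted version such as 9.3.0 (as GNU gcc typically prints); the max_major=10 default only mirrors the gcc < 10.0.0 limit stated above, and the function names are hypothetical.

```python
import re


def parse_gcc_version(version_text):
    """Return (major, minor, patch) from `gcc --version` output.

    Assumes the first line contains a dotted version such as '9.3.0'.
    """
    first_line = version_text.splitlines()[0]
    match = re.search(r"(\d+)\.(\d+)\.(\d+)", first_line)
    if match is None:
        raise ValueError(f"no version number found in: {first_line!r}")
    return tuple(int(part) for part in match.groups())


def gcc_is_supported(version_text, max_major=10):
    """True if the gcc major version is strictly below max_major (default mirrors gcc < 10.0.0)."""
    return parse_gcc_version(version_text)[0] < max_major
```

On the GPU node it could be fed subprocess.run(["gcc", "--version"], capture_output=True, text=True).stdout; keeping the parsing separate from the subprocess call makes it easy to test without gcc installed.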

CUDA error: no kernel image

RuntimeError: CUDA error: no kernel image is available for execution on the device

Minimal code to replicate:

import torch
a = torch.tensor([[0]], device='cuda', dtype=torch.float32)
b = torch.tensor([0], device='cuda', dtype=torch.float32)
torch.matmul(a, b)

Fix: make sure you are salloc'd into the same type of GPU you want to run inference on (see step 2 above).
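The "no kernel image" error generally means none of the compiled CUDA binaries target the GPU's compute capability, which is why the GPU type used for building and running matters. The sketch below is a pure-Python version of that compatibility check, assuming the standard CUDA rules: an sm_XY binary runs on devices with the same major version X and a minor version of at least Y, while compute_XY PTX can be JIT-compiled for any device of capability X.Y or newer. The function names are hypothetical.

```python
def parse_arch(arch):
    """Split an arch string such as 'sm_70' or 'compute_86' into (kind, (major, minor))."""
    kind, suffix = arch.split("_", 1)
    # keep only the digits (drops suffixes such as the 'a' in 'sm_90a'); the last digit is the minor version
    digits = "".join(ch for ch in suffix if ch.isdigit())
    return kind, (int(digits[:-1]), int(digits[-1]))


def has_kernel_image(device_capability, arch_list):
    """Return True if any entry of arch_list can run on a GPU with the given (major, minor) capability."""
    major, minor = device_capability
    for arch in arch_list:
        kind, (a_major, a_minor) = parse_arch(arch)
        # sm_XY binaries run on the same major version with an equal or newer minor version
        if kind == "sm" and a_major == major and a_minor <= minor:
            return True
        # compute_XY PTX can be JIT-compiled for any equal or newer capability
        if kind == "compute" and (a_major, a_minor) <= (major, minor):
            return True
    return False
```

On a GPU node this could be called as has_kernel_image(torch.cuda.get_device_capability(), torch.cuda.get_arch_list()). Note that torch.cuda.get_arch_list() only describes PyTorch's own build, not a separately compiled extension such as OpenFold's. For example, a V100 (capability 7.0) has no usable image in a build that only contains sm_80, which would produce this kind of error.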


After setting this up, it is also a good idea to run ring3 while the files are being produced, so that we can start training as soon as AlphaFlow is done (see code for this here: https://github.com/jyaacoub/MutDTA/issues/77#issuecomment-2004646889)
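One way to feed ring3 while AlphaFlow is still writing is to poll the output directory and only hand over PDBs that look finished. This is a minimal sketch under a loudly stated assumption: a file counts as finished once it has not been modified for min_age_s seconds. That heuristic, the names finished_pdbs and watch, and the process callback are all hypothetical and not part of the linked ring3 code.

```python
import os
import time


def finished_pdbs(out_dir, already_done, min_age_s=300, now=None):
    """Return sorted paths of .pdb files in out_dir that look complete and are not yet processed.

    Heuristic assumption: a file is complete once its last modification is at least
    min_age_s seconds old.
    """
    now = time.time() if now is None else now
    ready = []
    for name in os.listdir(out_dir):
        path = os.path.join(out_dir, name)
        if not name.endswith(".pdb") or path in already_done:
            continue
        if now - os.path.getmtime(path) >= min_age_s:
            ready.append(path)
    return sorted(ready)


def watch(out_dir, process, poll_s=600):
    """Hypothetical polling loop: call process(path) once per newly finished PDB; runs until interrupted."""
    done = set()
    while True:
        for path in finished_pdbs(out_dir, done):
            process(path)
            done.add(path)
        time.sleep(poll_s)
```

Tracking processed paths in a set keeps each PDB from being handed to ring3 twice across polls.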

jyaacoub commented 3 months ago

During testing we created some PDBs that had only 5 conformations. Here we delete those PDBs, which means we must resubmit the jobs for all 20 splits, since this was done before sorting the input_csv files by length.

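import os
from tqdm import tqdm

alphaflow_dir = 'out_pdb'  # hypothetical placeholder: set to the directory containing the AlphaFlow output PDBs

invalid_files = {}  # file -> (MODEL count, path) for every PDB with fewer than 50 conformations
removed = {}        # file -> path for every PDB that was deleted
for file in tqdm(os.listdir(alphaflow_dir)):
    if not file.endswith('.pdb'):
        continue
    file_path = os.path.join(alphaflow_dir, file)
    try:
        with open(file_path, 'r') as f:
            content = f.read()
        occurrences = content.count("MODEL")  # each conformation starts with a MODEL record
        if occurrences < 50:
            invalid_files[file] = (occurrences, file_path)

        if occurrences == 5:  # test-run output: delete and flag for resubmission
            os.remove(file_path)
            removed[file] = file_path

    except Exception as e:
        print(f"Error processing file {file_path}: {e}")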
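Note that content.count("MODEL") counts every occurrence of the substring, including any that appear inside REMARK or other records. A slightly stricter sketch, assuming standard PDB formatting where each conformation begins with a line starting with MODEL, counts only those record lines. The names count_models and classify are hypothetical; the thresholds mirror the loop above.

```python
def count_models(pdb_text):
    """Count MODEL record lines (one per conformation) in the text of a multi-model PDB file."""
    return sum(1 for line in pdb_text.splitlines() if line.startswith("MODEL"))


def classify(pdb_text, expected=50, bad=5):
    """Return 'delete', 'incomplete', or 'ok' using the same thresholds as the loop above."""
    n = count_models(pdb_text)
    if n == bad:
        return "delete"
    if n < expected:
        return "incomplete"
    return "ok"
```

Swapping count_models(content) in for content.count("MODEL") keeps the rest of the loop unchanged.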
jyaacoub commented 3 months ago

CODE TO PREPARE INPUT FILES GIVEN A DIRECTORY OF MSAs

# This script takes an MSA directory input path (flat path) and prepares a directory for alphaflow input containing:
# .
# ├── input.csv
# ├── msa_dir
# │   └── <prot_id>
# │       └── a3m
# │           └── <prot_id>.a3m -> ~[...]/msa_dir/<prot_id>.msa.a3m *
#[...]
# └── out_pdb
#
# * is a symbolic link to the original msa file from the msa_dir
#
# 4 directories, 2 files
# input.csv has columns name,seqres

# %%
from multiprocessing import Pool
from tqdm import tqdm
import pandas as pd
import os
DATASET = 'kiba'
MSA_DIR_IN = f'/cluster/home/t122995uhn/projects/data/{DATASET}/aln/'
OUT_DIR =    f'/cluster/home/t122995uhn/projects/data/{DATASET}/alphaflow_io'
MSA_DIR_OUT = os.path.join(OUT_DIR, 'msa_dir')
n_splits = 20  # Number of files to split `input.csv` into for parallel conformation generation

os.makedirs(MSA_DIR_OUT, exist_ok=True)

def process_file(filename):
    # get prot_id and target seq
    prot_id, ext = os.path.splitext(filename)
    fp = os.path.join(MSA_DIR_IN, filename)
    dst = os.path.join(MSA_DIR_OUT, prot_id, 'a3m')
    t_seq = None

    if ext == '.a3m':
        with open(fp, 'r') as f:
            for i in range(2):  # line 1 is the header, line 2 is the target (query) sequence
                t_seq = f.readline()
            t_seq = t_seq.strip()

        os.makedirs(dst, exist_ok=True)
        dst = os.path.join(dst, f'{prot_id}.a3m')
        # create symlink in correct dir structure
        if not os.path.exists(dst):
            os.symlink(fp, dst)

    elif ext == '.aln':
        with open(fp, 'r') as f_aln:
            aln_seqs = f_aln.readlines()
        t_seq = aln_seqs[0].strip()  # first aligned sequence is the target

        os.makedirs(dst, exist_ok=True)
        dst = os.path.join(dst, f'{prot_id}.a3m')
        if not os.path.exists(dst):
            # convert to a3m by giving each aligned sequence a numbered header
            with open(dst, 'w') as f:
                for i, l in enumerate(aln_seqs):
                    f.write(f"> {i}\n")
                    f.write(l)
    else:
        return None

    return prot_id, t_seq

#%%
files = os.listdir(MSA_DIR_IN)
with Pool(processes=None) as pool:
    results = list(tqdm(pool.imap(process_file, files), total=len(files)))

#%% Filter out None results and create CSV
csv_p = os.path.join(OUT_DIR, 'input.csv')

results = [r for r in results if r is not None]
df = pd.DataFrame(results, columns=['name', 'seqres'])

# sort by sequence length
df = df.sort_values(by='seqres', key=lambda x: x.str.len())
df.to_csv(csv_p, index=False)

# %% Save split CSV files
chunk_size = len(df) // n_splits

for i in range(n_splits):
    df_chunk = df.iloc[i*chunk_size:(i+1)*chunk_size]
    df_chunk.to_csv(os.path.join(OUT_DIR, f'input_{i}.csv'), index=False)

# Handle any remainder
if len(df) % n_splits != 0:
    df_chunk = df.iloc[n_splits*chunk_size:]
    df_chunk.to_csv(os.path.join(OUT_DIR, f'input_{n_splits}.csv'), index=False)
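The split above writes n_splits files of len(df) // n_splits rows plus one extra input_{n_splits}.csv for any remainder, so 21 files can be produced for n_splits = 20. If exactly n_splits files are preferred, one alternative (a sketch, not what the script above does) is to spread the remainder over the first chunks, the same way numpy.array_split does. The helper name split_bounds is hypothetical.

```python
def split_bounds(n_rows, n_splits):
    """Return (start, stop) row ranges for n_splits contiguous chunks whose sizes differ by at most one.

    The first n_rows % n_splits chunks get one extra row, matching numpy.array_split.
    """
    base, extra = divmod(n_rows, n_splits)
    bounds, start = [], 0
    for i in range(n_splits):
        stop = start + base + (1 if i < extra else 0)
        bounds.append((start, stop))
        start = stop
    return bounds


# Usage with the DataFrame from the script above (kept as a comment so this block runs on its own):
# for i, (start, stop) in enumerate(split_bounds(len(df), n_splits)):
#     df.iloc[start:stop].to_csv(os.path.join(OUT_DIR, f'input_{i}.csv'), index=False)
```

Because df is already sorted by sequence length, contiguous ranges still keep similar-length sequences in the same job.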
jyaacoub commented 2 months ago

Setting up AlphaFlow on Narval

1. Load the correct environment and create a venv:

module load StdEnv/2020
module load python/3.9
module load cuda/11.4
python -m venv .venv
source .venv/bin/activate

2. Pip install all dependencies from the repo except OpenFold:

pip install numpy==1.21.2 pandas==1.5.3
pip install torch==1.12.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install biopython==1.79 dm-tree==0.1.6 modelcif==0.7 ml-collections==0.1.0 scipy==1.7.1 absl-py einops
pip install pytorch_lightning==2.0.4 fair-esm mdtraj wandb
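After step 2, it can help to confirm that the pinned versions actually ended up in the venv before building OpenFold. This is a minimal sketch using the standard-library importlib.metadata (Python 3.8+); the PINS dict copies the == pins from the commands above, and the helper names are hypothetical. Package names are passed to importlib.metadata.version as written, and older Pythons may not normalize dashes versus underscores, so a package reported as missing is worth double-checking with pip show.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the pip install commands above
PINS = {
    "numpy": "1.21.2",
    "pandas": "1.5.3",
    "torch": "1.12.1+cu113",
    "biopython": "1.79",
    "dm-tree": "0.1.6",
    "modelcif": "0.7",
    "ml-collections": "0.1.0",
    "scipy": "1.7.1",
    "pytorch_lightning": "2.0.4",
}


def installed_versions(packages):
    """Look up installed distribution versions; packages that are not installed map to None."""
    found = {}
    for pkg in packages:
        try:
            found[pkg] = version(pkg)
        except PackageNotFoundError:
            found[pkg] = None
    return found


def mismatched_pins(pins, installed):
    """Return {package: installed_version} for every package whose installed version differs from its pin."""
    return {pkg: installed.get(pkg) for pkg, wanted in pins.items() if installed.get(pkg) != wanted}
```

Inside the activated venv, print(mismatched_pins(PINS, installed_versions(PINS))) should print an empty dict if every pin was honoured.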

3. Get v1.0.1 from the openfold repo @103d037

wget https://github.com/aqlaboratory/openfold/archive/refs/tags/v1.0.1.zip
unzip v1.0.1.zip
cd openfold-1.0.1

4. salloc into a GPU node to install OpenFold:

salloc --gres=gpu:1 -t 10 -c 6 --mem=8G

Make sure you are inside the root directory of the downloaded and unzipped OpenFold package before running the next command:

module load StdEnv/2020 cuda/11.4
pip install .

Modules loaded for the above pip install to work: (screenshot not reproduced)

jyaacoub commented 1 week ago

UPDATE AFLOW DOCS (https://github.com/jyaacoub/alphaflow/blob/deepspeed/docs/README-HPC.md) BASED ON NOUR's FEEDBACK:

Also, a couple of other things that might be helpful for the docs:

  • Run module load cuda/11.4 at the very beginning of the AlphaFlow installation, before installing torch (otherwise I think it defaults to cuda/11.0, which causes issues because cuda/11.4 is referenced later)
  • I ran into memory issues when installing torch (although I don't think this happens every time); this was fixed by changing the install command to pip --no-cache-dir install torch==1.12.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
  • The openfold deepspeed initialization bug fix is probably better done before the openfold installation.

Also include the Rust module load info in the docs.

Another couple of things for the docs:

  • Remove the cd openfold step from the AlphaFlow setup; the pip install should be done one directory up from there
  • For step 1 of the DeepSpeed installation, load the following modules instead: module load StdEnv/2020 intel/2020.1.217 cuda/11.4 openmpi/4.0.3 mpi4py/3.1.3
  • Make sure to activate/reactivate the venv AFTER module load mpi4py/3.1.3, because this adds another Python to $PATH, which affects downstream installations.
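Because loading mpi4py/3.1.3 puts another Python on $PATH, a quick sanity check after reactivating the venv is to confirm that the first python found on $PATH lives in the venv's bin directory. This is a minimal sketch using only the standard library; venv_dir is whatever path the venv was created at (e.g. .venv from the Narval steps above), and the function name is hypothetical. Symlinks are deliberately not resolved, because a venv's bin/python is usually a symlink back to the module's interpreter.

```python
import os
import shutil


def venv_python_first(venv_dir, path_env=None):
    """True if the first `python` on PATH sits in venv_dir/bin (symlinks are intentionally not resolved)."""
    found = shutil.which("python", path=path_env)  # path_env=None means use the current $PATH
    if found is None:
        return False
    expected_dir = os.path.abspath(os.path.join(venv_dir, "bin"))
    return os.path.dirname(os.path.abspath(found)) == expected_dir
```

For example, after source .venv/bin/activate, venv_python_first(".venv") should return True; if it returns False, the venv was probably activated before the module load.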