jyaacoub / MutDTA

Improving the precision oncology pipeline by providing binding affinity perturbation predictions on a priori identified cancer driver genes.

Implement Distributed Training #19

Closed jyaacoub closed 1 year ago

jyaacoub commented 1 year ago

Use distributed training options as described in: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

Likely need to use torch.nn.parallel.DistributedDataParallel if we need to train the module.

According to the following, using torch.nn.DataParallel is not recommended if you want to train. It can be okay for the ESM module, which will not be trained...

(image)
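
For orientation, here is a minimal sketch of the DistributedDataParallel pattern the linked tutorial describes (placeholder model and data, not the MutDTA module; assumes a launcher such as torchrun has set RANK, WORLD_SIZE, and LOCAL_RANK):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # one process per GPU; the launcher provides the rendezvous env vars
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(128, 1).cuda(local_rank)   # placeholder model
    model = DDP(model, device_ids=[local_rank])        # gradients sync across ranks

    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    x = torch.randn(64, 128, device=f'cuda:{local_rank}')  # placeholder batch
    y = torch.randn(64, 1, device=f'cuda:{local_rank}')
    for _ in range(3):
        optim.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()                                # gradient all-reduce happens here
        optim.step()

    dist.destroy_process_group()

if __name__ == '__main__':
    main()

Unlike DataParallel, each process owns one GPU and only gradients are communicated, so there is no per-forward replicate/scatter/gather cost.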

jyaacoub commented 1 year ago

We need to aim for model parallel training for the sake of scalability in the future, when we try out different architectures.

(image)

The image comes from a tutorial on DDP with PyTorch: https://medium.com/mlearning-ai/distributed-data-parallel-with-slurm-submitit-pytorch-168c1004b2ca
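
As a point of reference, here is a toy sketch of what plain model parallelism looks like in PyTorch (placeholder layers, not the MutDTA architecture; assumes two visible GPUs):

import torch
import torch.nn as nn

class TwoGPUNet(nn.Module):
    # toy two-stage split: first half lives on cuda:0, second half on cuda:1
    def __init__(self):
        super().__init__()
        self.part1 = nn.Sequential(nn.Linear(128, 256), nn.ReLU()).to('cuda:0')
        self.part2 = nn.Linear(256, 1).to('cuda:1')

    def forward(self, x):
        x = self.part1(x.to('cuda:0'))
        return self.part2(x.to('cuda:1'))   # activations hop between GPUs

model = TwoGPUNet()
out = model(torch.randn(8, 128))            # input starts on CPU
print(out.shape, out.device)                # torch.Size([8, 1]) cuda:1

This scales the model across devices, but without pipelining the GPUs run one after another rather than in parallel.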

jyaacoub commented 1 year ago

From the FSDP tutorial

At a high level, FSDP works as follows:

In the constructor: shard the model parameters so that each rank only keeps its own shard.

In the forward path: run all_gather to collect the parameter shards from all ranks and recover the full parameters for the FSDP unit, run the forward computation, then discard the gathered shards.

In the backward path: run all_gather again to recover the full parameters, run the backward computation, then run reduce_scatter to synchronize gradients and discard the parameters not owned locally.
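
A minimal wrapping sketch following that pattern (placeholder model, not MutDTA; assumes a launcher such as torchrun sets the rank environment variables and there is one GPU per process):

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

# placeholder model; its parameters get sharded across ranks when wrapped
model = torch.nn.Sequential(
    torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 1)
).cuda(local_rank)
model = FSDP(model)

# optimizer is built after wrapping so it sees the sharded parameters
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
x = torch.randn(32, 512, device=f'cuda:{local_rank}')
loss = model(x).sum()    # forward: all_gather shards, compute, discard
loss.backward()          # backward: all_gather, compute grads, reduce_scatter
optim.step()

dist.destroy_process_group()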

jyaacoub commented 1 year ago

Communication overhead with torch.nn.DataParallel makes it hard to justify:

1 gpu:

submitit INFO (2023-08-02 15:48:10,845) - Starting with JobEnvironment(job_id=9557258, hostname=node99.uhnh4h.cluster, local_rank=0(1), node=0(1), global_rank=0(1))
submitit INFO (2023-08-02 15:48:10,845) - Loading pickle: /cluster/home/t122995uhn/projects/MutDTA/log_test/Test_gpu_1/9557258_submitted.pkl
[device(type='cuda', index=0)]
Time to load: 82.89621496200562
Time to tok: 0.031525611877441406
Time to inf: **0.7391221523284912**

2 gpu:

submitit INFO (2023-08-02 15:43:48,784) - Starting with JobEnvironment(job_id=9557040, hostname=node35.uhnh4h.cluster, local_rank=0(1), node=0(1), global_rank=0(1))
submitit INFO (2023-08-02 15:43:48,784) - Loading pickle: /cluster/home/t122995uhn/projects/MutDTA/log_test/Test_gpu_2/9557040_submitted.pkl
[device(type='cuda', index=0), device(type='cuda', index=1)]
Time to load: 83.33079123497009
Time to tok: 0.03038311004638672
Time to inf: **4.740084886550903**

3 gpu:

submitit INFO (2023-08-02 15:51:26,154) - Starting with JobEnvironment(job_id=9557307, hostname=node35.uhnh4h.cluster, local_rank=0(1), node=0(1), global_rank=0(1))
submitit INFO (2023-08-02 15:51:26,155) - Loading pickle: /cluster/home/t122995uhn/projects/MutDTA/log_test/Test_gpu_3/9557307_submitted.pkl
[device(type='cuda', index=0), device(type='cuda', index=1), device(type='cuda', index=2)]
Time to load: 83.82851767539978
Time to tok: 0.03020191192626953
Time to inf: **9.758665084838867**

Code to replicate:

#%%
import os, random, itertools, math, pickle, json, time
import config

from tqdm import tqdm
import pandas as pd
import numpy as np

import submitit

import torch
from transformers import AutoTokenizer, EsmModel

from src.models.utils import print_device_info

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print_device_info(device)

#%% Testing esm model 
# based on https://github.com/facebookresearch/esm#main-models-you-should-use-
# I should be using ESM-MSA-1b (esm1b_t33_650M_UR50S)
# from https://github.com/facebookresearch/esm/tree/main#pre-trained-models-
# Luke used esm2_t6_8M_UR50D for his experiments

# https://huggingface.co/facebook/esm2_t33_650M_UR50D is <10GB
# https://huggingface.co/facebook/esm2_t36_3B_UR50D is 11GB
df = pd.read_csv('../data/DavisKibaDataset/davis_msa/processed/XY.csv', index_col=0)
prot_seqs = list(df['prot_seq'].unique())

# %%
def run_esm(prot_seqs):
    start_t = time.time()
    # this will raise a warning since lm head is missing but that is okay since we are not using it:
    device_ids = [torch.device(f'cuda:{x}') for x in range(torch.cuda.device_count())]
    print(device_ids)

    esm_tok = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
    esm_mdl = EsmModel.from_pretrained('facebook/esm2_t6_8M_UR50D').to(device_ids[0])
    print('Time to load:', time.time()-start_t)
    start_t = time.time()

    tok = esm_tok(prot_seqs, return_tensors='pt', padding=True)
    print('Time to tok:', time.time()-start_t)
    start_t = time.time()

    out = torch.nn.DataParallel(esm_mdl, device_ids=device_ids)(**tok)
    print('Time to inf:', time.time()-start_t)
    print(out)

    return out.last_hidden_state.to(torch.device('cpu'))

#%%
jobs = []
for num_gpus in range(1,4):
    print(num_gpus)
    log_folder = f"log_test/Test_gpu_{num_gpus}"
    executor = submitit.AutoExecutor(folder=log_folder)

    executor.update_parameters(timeout_min=30, 
                            slurm_partition="gpu",
                            slurm_account='kumargroup_gpu',
                            slurm_mem='10G',
                            slurm_gres=f'gpu:v100:{num_gpus}',
                            slurm_job_name=f'Test_gpu_{num_gpus}')

    # The submission interface is identical to concurrent.futures.Executor
    job = executor.submit(run_esm, prot_seqs[:15])  # runs run_esm on the first 15 sequences
    print('state:', job.state)
    print('id:', job.job_id)  # ID of your job
    jobs.append(job)

output = job.result()  # waits for the submitted function to complete and returns its output

# %%
jyaacoub commented 1 year ago

Validating torch.nn.DataParallel results

I retested this with a simple experiment comparing total inference time for 15 input sequences on 1 GPU vs 15*3=45 input sequences on 3 GPUs (note: these are the same 15 sequences, just repeated 3 times).

1 GPU with 15 input sequences

submitit INFO (2023-08-14 14:09:44,695) - Starting with JobEnvironment(job_id=9896550, hostname=node99.uhnh4h.cluster, local_rank=0(1), node=0(1), global_rank=0(1))
submitit INFO (2023-08-14 14:09:44,695) - Loading pickle: /cluster/home/t122995uhn/projects/MutDTA/log_test/Test_gpu_1_15/9896550_submitted.pkl
Devices [device(type='cuda', index=0)]
Time to load: 83.6283392906189
Time to tok: 0.03265643119812012
Time to inf: **1.4991624355316162**
submitit INFO (2023-08-14 14:11:13,632) - Job completed successfully

3 GPUs with 45 input sequences

submitit INFO (2023-08-15 11:06:42,430) - Starting with JobEnvironment(job_id=9907346, hostname=node99.uhnh4h.cluster, local_rank=0(1), node=0(1), global_rank=0(1))
submitit INFO (2023-08-15 11:06:42,430) - Loading pickle: /cluster/home/t122995uhn/projects/MutDTA/log_test/Test_gpu_3_15/9907346_submitted.pkl
Devices [device(type='cuda', index=0), device(type='cuda', index=1), device(type='cuda', index=2)]
Time to load: 84.35338306427002
Time to tok: 0.09466838836669922
Time to infer: 9.231343746185303
submitit INFO (2023-08-15 11:08:18,806) - Job completed successfully
1.5 s * 3 = 4.5 s (the expected total if the three batches of 15 were run sequentially on one GPU)
9.2 s - 4.5 s = 4.7 s

That is roughly 4.7 s of communication overhead, which is more than the sequential runtime itself, so the 3-GPU DataParallel run is slower than simply running the batches one after another on a single GPU. We can't increase the batch size any further due to memory limitations, so this is the best we can get with this method.
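
For comparison, here is a hypothetical sketch of the sequential single-GPU baseline implied by the arithmetic above (placeholder sequences; this was not part of the original test):

import time
import torch
from transformers import AutoTokenizer, EsmModel

device = torch.device('cuda:0')
esm_tok = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
esm_mdl = EsmModel.from_pretrained('facebook/esm2_t6_8M_UR50D').to(device)

prot_seqs = ['MKT' * 100] * 45   # placeholder sequences standing in for the 45 used above
batches = [prot_seqs[i:i + 15] for i in range(0, len(prot_seqs), 15)]

start_t = time.time()
with torch.no_grad():
    for batch in batches:
        # run each batch of 15 on the single GPU, one after another
        tok = esm_tok(batch, return_tensors='pt', padding=True).to(device)
        esm_mdl(**tok)
torch.cuda.synchronize()
print('Sequential single-GPU inference time:', time.time() - start_t)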

Code to replicate the DataParallel test:

#%%
import os, random, itertools, math, pickle, json, time
import config

from tqdm import tqdm
import pandas as pd
import numpy as np

import submitit

import torch
from transformers import AutoTokenizer, EsmModel

from src.models.utils import print_device_info

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print_device_info(device)

#%% Testing esm model 
# based on https://github.com/facebookresearch/esm#main-models-you-should-use-
# I should be using ESM-MSA-1b (esm1b_t33_650M_UR50S)
# from https://github.com/facebookresearch/esm/tree/main#pre-trained-models-
# Luke used esm2_t6_8M_UR50D for his experiments

# https://huggingface.co/facebook/esm2_t33_650M_UR50D is <10GB
# https://huggingface.co/facebook/esm2_t36_3B_UR50D is 11GB
df = pd.read_csv('../data/DavisKibaDataset/davis_nomsa/processed/XY.csv', index_col=0)
prot_seqs = list(df['prot_seq'].unique())

# %%
def run_esm(prot_seqs):
    start_t = time.time()
    # this will raise a warning since lm head is missing but that is okay since we are not using it:
    device_ids = [torch.device(f'cuda:{x}') for x in range(torch.cuda.device_count())]
    print('Devices', device_ids)

    esm_tok = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
    esm_mdl = EsmModel.from_pretrained('facebook/esm2_t6_8M_UR50D').to(device_ids[0])
    print('Time to load:', time.time()-start_t)
    start_t = time.time()

    tok = esm_tok(prot_seqs, return_tensors='pt', padding=True)
    print('Time to tok:', time.time()-start_t)
    start_t = time.time()

    # DATA Parallel
    out = torch.nn.DataParallel(esm_mdl, device_ids=device_ids)(**tok)
    print('Time to infer:', time.time()-start_t)
    # print(out)

    return f'COMPLETED {len(prot_seqs)} seqs with {len(device_ids)} gpus'

#%%
jobs = []
batch = 15
for num_gpus in [3]:
    print(num_gpus)
    log_folder = f"log_test/Test_gpu_{num_gpus}_{batch*num_gpus}"
    executor = submitit.AutoExecutor(folder=log_folder)

    executor.update_parameters(timeout_min=30, 
                            slurm_partition="gpu",
                            slurm_account='kumargroup_gpu',
                            slurm_mem='15G',
                            slurm_gres=f'gpu:v100:{num_gpus}',
                            slurm_constraint='gpu32g',
                            slurm_job_name=f'Test_gpu_{num_gpus}')

    # The submission interface is identical to concurrent.futures.Executor
    job = executor.submit(run_esm, prot_seqs[:batch]*num_gpus)  # runs run_esm on batch*num_gpus sequences
    print('state:', job.state)
    print('id:', job.job_id)  # ID of your job
    jobs.append(job)

output = job.result()  # waits for the submitted function to complete and returns its output

# %%
jyaacoub commented 1 year ago

Initial DDP tests

Tests with DGraphDTA with batch_size=64 on 3 GPUs

Non-weighted results (edge_weight_opt == 'binary'):

For weighted it's even more impressive:

It seems like the dominant cost is the communication overhead, since we also get ~11 s per epoch here:

----------------- HYPERPARAMETERS -----------------
               Batch size: 64
            Learning rate: 0.0001
                  Dropout: 0.4
               Num epochs: 100
              Edge option: simple
Data loaded
DGraphDTA Loaded
Epoch 1/100: Train Loss: 0.0064, Time elapsed: 24.0365629196167
Epoch 2/100: Train Loss: 0.0055, Time elapsed: 11.699882984161377
Epoch 3/100: Train Loss: 0.0054, Time elapsed: 11.656761646270752

vs 84s for non-distributed on only 80% of the data:

(.venv) [t122995uhn@h4huhnlogin1:MutDTA]$ tail /cluster/home/t122995uhn/projects/sbatch/out/old/lt_2023-08-15/DG_davis_weighted9771036.out
Epoch 1999/2000: 100%|██████████| 376/376 [01:24<00:00,  4.45batch/s, Train Loss=0.14] 
Epoch 1999/2000: Train Loss: 0.1399, Val Loss: 0.4199, Best Val Loss: 0.3952 @ Epoch 1053
Epoch 2000/2000: 100%|██████████| 376/376 [01:25<00:00,  4.42batch/s, Train Loss=0.14]

This is a >85% reduction in per-epoch training time!

Code

To replicate these results, run the https://github.com/jyaacoub/MutDTA/commit/7407ca03ba36b87cb783634c9beeb065327dd602 version of the test.py file.
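
For context only (this is not the test.py from that commit), a minimal sketch of a typical DDP training loop with a DistributedSampler, using a placeholder dataset and model:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

# placeholder dataset standing in for the Davis data
dataset = TensorDataset(torch.randn(1024, 64), torch.randn(1024, 1))
sampler = DistributedSampler(dataset)                  # each rank sees its own shard
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

# placeholder model standing in for DGraphDTA
model = DDP(torch.nn.Linear(64, 1).cuda(local_rank), device_ids=[local_rank])
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(3):
    sampler.set_epoch(epoch)                           # reshuffle shards each epoch
    for x, y in loader:
        x, y = x.cuda(local_rank), y.cuda(local_rank)
        optim.zero_grad()
        torch.nn.functional.mse_loss(model(x), y).backward()
        optim.step()

dist.destroy_process_group()

Each process here drives one GPU and only gradient synchronization crosses devices, which is why the per-epoch time drops compared to single-process training.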

jyaacoub commented 1 year ago

Test with EsmDTA with batch_size=10 on 3 GPUs:

Massive improvement here as well.

(.venv) [t122995uhn@h4huhnlogin1:MutDTA]$ tail -n 5 /cluster/home/t122995uhn/projects/sbatch/out/old/ED_davis_weighted9864276.out      
Epoch 239/2000: Train Loss: 0.2128, Val Loss: 0.3600, Best Val Loss: 0.3494 @ Epoch 205
Epoch 240/2000: 100%|██████████| 2401/2401 [17:41<00:00,  2.26batch/s, Train Loss=0.211]
Model saved to: results/model_checkpoints/ours//EDM_davisD_nomsaF_simpleE_10B_0.0001LR_0.4D_2000E.model_tmp
Epoch 240/2000: Train Loss: 0.2106, Val Loss: 0.3512, Best Val Loss: 0.3494 @ Epoch 205
Epoch 241/2000:  47%|████▋     | 1132/2401 [08:11<10:20,  2.05batch/s, Train Loss=0.204]slurmstepd: error: *** JOB 9864276 ON node99 CANCELLED AT 2023-08-14T15:24:43 DUE TO TIME LIMIT ***

(.venv) [t122995uhn@h4huhnlogin1:MutDTA]$ tail -n 2 slurm_tests/DDP/9939089/9939089_*.out
==> slurm_tests/DDP/9939089/9939089_0_log.out <==
Data loaded
Epoch 1/100: Train Loss: 0.0007, Time elapsed: 297.28276681900024

==> slurm_tests/DDP/9939089/9939089_1_log.out <==
Data loaded
Epoch 1/100: Train Loss: 0.0009, Time elapsed: 297.28658509254456

==> slurm_tests/DDP/9939089/9939089_2_log.out <==
Data loaded
Epoch 1/100: Train Loss: 0.0010, Time elapsed: 297.28863620758057

297.3 s / 60 ≈ 5.0 min per epoch vs the previous ~17.5 min per epoch, which is a ~3.5x speed-up.

jyaacoub commented 1 year ago

This is pretty much resolved; the environment works well, it's just a bit messy and needs to be cleaned up. Marking as done for now; will come back if FSDP becomes necessary.