hpcaitech / FastFold

Optimizing AlphaFold Training and Inference on GPU Clusters
Apache License 2.0

DAP seems not to work with 2 GPUs #164

Closed semal closed 1 year ago

semal commented 1 year ago
[screenshot: GPU memory usage and utilization during the run]

Environment: Docker, cudatoolkit 11.3, A100, triton kernels; the inference data does not have a batch dimension. No matter how long the sequence is, each GPU uses 908 MB of memory at 100% utilization, but the run never seems to finish.

Gy-Lu commented 1 year ago

Hi, have you tried terminating the process (e.g. with Ctrl+C) to see which operation is blocked? Could you also show us your run.sh? If you can paste your sequence as well, we will try to reproduce the bug. Meanwhile, triton needs CUDA 11.4, so it seems unreasonable to run it with cudatoolkit 11.3. DAP doesn't need a batch dimension.

semal commented 1 year ago

Thanks for the reply. Here is the Ctrl+C message:

[screenshot: traceback printed after Ctrl+C]

Here is my run.sh:

complex_id=CM012-1
python inference_with_features_pkl.py \
    --output_dir ./output/ \
    --output_prefix $complex_id \
    --model_name model_1_multimer_v3 \
    --model_preset multimer \
    --gpus 2 \
    --inplace \
    --chunk_size 10 \
    --param_path /data/ab01/gongchaohui/B30/20230301_alphafold-tools/alphafold/data/params/params_model_1_multimer_v3.npz \
    /data/ab03/gongchaohui/B30/20221205_EpitopePipeline/output/180-rerun/$complex_id/save/$complex_id/features.pkl

I'm using a features.pkl that contains all of the input features for AlphaFold. I tested it on a single GPU and it works well. I also tested the old kernels (not using triton) and hit the same problem. I'll try CUDA 11.4 and report the result later.

Gy-Lu commented 1 year ago

Hi, I find it strange that there is colossalai init information in your log. I am not sure whether something is wrong in your inference_with_features_pkl.py; I guess you call colossalai launch somewhere in your script.

semal commented 1 year ago
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import random
import sys
import time
import tempfile
import contextlib

import numpy as np
import torch
import torch.multiprocessing as mp
import pickle
from fastfold.model.hub import AlphaFold

import fastfold
import fastfold.relax.relax as relax
from fastfold.common import protein, residue_constants
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.model.nn.triangular_multiplicative_update import set_fused_triangle_multiplication
from fastfold.data import feature_pipeline

from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.tensor_utils import tensor_tree_map

torch_major, torch_minor = map(int, torch.__version__.split(".")[:2])
if (torch_major, torch_minor) >= (1, 12):
    # Allow TF32 matmuls on Ampere GPUs for faster inference.
    torch.backends.cuda.matmul.allow_tf32 = True

def seed_torch(seed=1029):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)

@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
    with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
        fasta_file.write(fasta_str)
        fasta_file.seek(0)
        yield fasta_file.name

def add_data_args(parser: argparse.ArgumentParser):
    parser.add_argument('--chunk_size', type=int, default=None)
    parser.add_argument('--inplace', default=False, action='store_true')

def inference_model(rank, world_size, result_q, batch, args):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    # init distributed for Dynamic Axial Parallelism
    fastfold.distributed.init_dap()
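    # NOTE: init_dap() appears to bring up colossalai's distributed runtime under
    # the hood, which is likely where the colossalai init banner in the log comes from.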
    torch.cuda.set_device(rank)
    config = model_config(args.model_name)
    if args.chunk_size:
        config.globals.chunk_size = args.chunk_size

    if "v3" in args.param_path:
        set_fused_triangle_multiplication()

    config.globals.inplace = args.inplace
    config.globals.is_multimer = args.model_preset == 'multimer'
    model = AlphaFold(config)
    import_jax_weights_(model, args.param_path, version=args.model_name)

    model = inject_fastnn(model)
    model = model.eval()
    model = model.cuda()

    set_chunk_size(model.globals.chunk_size)

    with torch.no_grad():
        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}

        t = time.perf_counter()
        out = model(batch)
        print(f"Inference time: {time.perf_counter() - t}")

    out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

    result_q.put(out)

    torch.distributed.barrier()
    torch.cuda.synchronize()

def main(args):
    if args.model_preset == "multimer":
        inference_multimer_model(args)
    else:
        inference_monomer_model(args)

def inference_multimer_model(args):
    print("running in multimer mode...")
    config = model_config(args.model_name)
    output_dir_base = args.output_dir
    output_prefix = args.output_prefix
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(sys.maxsize)
    # seed_torch(seed=1029)

    feature_processor = feature_pipeline.FeaturePipeline(
        config.data
    )

    if not os.path.exists(output_dir_base):
        os.makedirs(output_dir_base)

    # Read the input feature file features.pkl
    with open(args.input_feature_path, 'rb') as f:
        feature_dict = pickle.load(f)
    # Build the per-residue seq_length feature
    num_res = feature_dict['aatype'].shape[0]
    feature_dict['seq_length'] = np.asarray(
        [num_res] * num_res,
        dtype=np.int32
    )
    processed_feature_dict = feature_processor.process_features(
        feature_dict, mode='predict', is_multimer=True,
    )

    batch = processed_feature_dict

    manager = mp.Manager()
    result_q = manager.Queue()
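    # Launch one worker process per GPU; spawn passes each process its rank,
    # and inference_model() initializes DAP before running the forward pass.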
    torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))

    out = result_q.get()
    out['asym_id'] = feature_dict['asym_id']
    if args.save_prediction_result:
        # Save the prediction result .pkl
        print(f'Saving prediction result to {args.output_dir}..')
        prediction_result_path = os.path.join(args.output_dir,
            f'{output_prefix}_{args.model_name}.pkl')
        with open(prediction_result_path, 'wb') as f:
            pickle.dump(out, f)

    # Toss out the recycling dimensions --- we don't need them anymore
    batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)

    plddt = out["plddt"]
    mean_plddt = np.mean(plddt)

    plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)

    unrelaxed_protein = protein.from_prediction(features=batch,
                                                result=out,
                                                b_factors=plddt_b_factors)

    # Save the unrelaxed PDB.
    unrelaxed_output_path = os.path.join(args.output_dir,
        f'{output_prefix}_{args.model_name}_unrelaxed.pdb')
    with open(unrelaxed_output_path, 'w') as f:
        f.write(protein.to_pdb(unrelaxed_protein))

    if(args.relaxation):
        print('Relaxing...')
        amber_relaxer = relax.AmberRelaxation(
            use_gpu=True,
            **config.relax,
        )

        # Relax the prediction.
        t = time.perf_counter()
        relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
        print(f"Relaxation time: {time.perf_counter() - t}")

        # Save the relaxed PDB.
        relaxed_output_path = os.path.join(args.output_dir,
            f'{output_prefix}_{args.model_name}_relaxed.pdb')
        with open(relaxed_output_path, 'w') as f:
            f.write(relaxed_pdb_str)

def inference_monomer_model(args):
    print("running in monomer mode...")
    config = model_config(args.model_name)

    output_dir_base = args.output_dir
    output_prefix = args.output_prefix
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(sys.maxsize)
    # seed_torch(seed=1029)

    # Read the input feature file features.pkl
    with open(args.input_feature_path, 'rb') as f:
        feature_dict = pickle.load(f)
    feature_dict['seq_length'] = np.array([feature_dict['seq_length']])
    feature_processor = feature_pipeline.FeaturePipeline(config.data)
    processed_feature_dict = feature_processor.process_features(
        feature_dict,
        mode='predict',
    )

    batch = processed_feature_dict

    manager = mp.Manager()
    result_q = manager.Queue()
    torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))

    out = result_q.get()

    # Toss out the recycling dimensions --- we don't need them anymore
    batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)

    plddt = out["plddt"]
    mean_plddt = np.mean(plddt)

    plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)

    unrelaxed_protein = protein.from_prediction(features=batch,
                                                result=out,
                                                b_factors=plddt_b_factors)

    # Save the unrelaxed PDB.
    unrelaxed_output_path = os.path.join(args.output_dir,
                                            f'{output_prefix}_{args.model_name}_unrelaxed.pdb')
    with open(unrelaxed_output_path, 'w') as f:
        f.write(protein.to_pdb(unrelaxed_protein))

    if(args.relaxation):
        amber_relaxer = relax.AmberRelaxation(
            use_gpu=True,
            **config.relax,
        )

        # Relax the prediction.
        t = time.perf_counter()
        relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
        print(f"Relaxation time: {time.perf_counter() - t}")

        # Save the relaxed PDB.
        relaxed_output_path = os.path.join(args.output_dir,
                                            f'{output_prefix}_{args.model_name}_relaxed.pdb')
        with open(relaxed_output_path, 'w') as f:
            f.write(relaxed_pdb_str)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_feature_path",
        type=str,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--output_prefix",
        type=str,
        help="""Prefix for output files."""
    )
    parser.add_argument("--model_name",
                        type=str,
                        default="model_1",
                        help="""Name of a model config. Choose one of model_{1-5} or 
             model_{1-5}_ptm or model_{1-5}_multimer, as defined on the AlphaFold GitHub.""")
    parser.add_argument("--param_path",
                        type=str,
                        default=None,
                        help="""Path to model parameters. If None, parameters are selected
             automatically according to the model name from 
             ./data/params""")
    parser.add_argument(
        "--relaxation", action="store_true", default=False,
    )
    parser.add_argument("--gpus",
                        type=int,
                        default=1,
                        help="""Number of GPUs with which to run inference""")
    parser.add_argument('--save_prediction_result',
                        type=bool,
                        default=True)
    parser.add_argument('--data_random_seed', type=str, default=None)
    parser.add_argument(
        "--model_preset",
        type=str,
        default="monomer",
        choices=["monomer", "multimer"],
        help="Choose preset model configuration - the monomer model, the monomer model with "
        "extra ensembling, monomer model with pTM head, or multimer model",
    )
    add_data_args(parser)
    args = parser.parse_args()

    main(args)

This is my script, modified from inference.py. I think the colossalai log comes from init_dap.

semal commented 1 year ago
[screenshot: GPU memory usage after the upgrade]

I updated CUDA to 11.6 and PyTorch to 1.13; the problem remains. It seems to use more memory in the initial state than before.

semal commented 1 year ago
[screenshot: the line of code where the process hangs]

After some debugging, I found that the process stops at this line of code.
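
A minimal diagnostic sketch, in case the hang is inside an NCCL collective: NCCL_DEBUG and NCCL_DEBUG_SUBSYS are standard NCCL environment variables (not FastFold-specific) that make each rank log which communicator operations it enters. They can be set in the Python script before the workers are spawned, or exported in run.sh.

import os

# Ask NCCL to log initialization and peer-to-peer activity on every rank.
os.environ['NCCL_DEBUG'] = 'INFO'
os.environ['NCCL_DEBUG_SUBSYS'] = 'INIT,P2P'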

Gy-Lu commented 1 year ago

Hi, I can't reproduce your problem. Can you provide more information about your environment, such as your colossalai version and communication backend?
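
For reference, a minimal sketch of how the requested version information could be collected (assuming colossalai is installed alongside FastFold):

import torch
import colossalai

print("torch:", torch.__version__)
print("CUDA (torch build):", torch.version.cuda)
print("NCCL:", torch.cuda.nccl.version())
print("colossalai:", colossalai.__version__)
print("GPU:", torch.cuda.get_device_name(0))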

semal commented 1 year ago

After lengthy troubleshooting, I finally discovered that the issue may be caused by IOMMU being enabled. However, I wasn't sure whether disabling IOMMU would affect Docker. Later, I found that setting the environment variable NCCL_P2P_DISABLE=1 resolves the issue, but I am unsure whether this setting could cause other problems for distributed training.
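
A minimal sketch of the workaround described above: NCCL_P2P_DISABLE=1 is a standard NCCL environment variable that forces NCCL to avoid peer-to-peer transfers between GPUs, the path most likely to be affected by IOMMU. It must be set before the process group is initialized, e.g. next to the RANK/WORLD_SIZE assignments in inference_model() or exported in run.sh; the main cost is that inter-GPU communication falls back to slower non-P2P paths, so DAP may run somewhat slower.

import os

# Workaround for the IOMMU-related hang: disable NCCL peer-to-peer transfers.
os.environ['NCCL_P2P_DISABLE'] = '1'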