aqlaboratory / openfold

Trainable, memory-efficient, and GPU-friendly PyTorch reproduction of AlphaFold 2
Apache License 2.0

Can I run without CUDA or a GPU? #229

Open bzjf opened 2 years ago

bzjf commented 2 years ago

I am running the instance in VMware. I cannot install CUDA there, or the whole machine breaks. So can I run the instance without CUDA, regardless of speed?

gahdritz commented 2 years ago

Not in the default configuration, but if you disable the custom CUDA kernels, it's technically possible to run the model on CPU. Of course, that would cause a slowdown so large that I'd expect model inference to be impractical for almost any application. Also, CPU inference is completely untested. LMK if you still want to move forward and I can explain how to remove said kernels.
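Roughly, once the kernel imports are stripped out of `openfold/model/primitives.py` and `openfold/model/structure_module.py`, CPU inference would look something like the sketch below. This is completely untested and only a sketch: the config flags and entry points here are from memory (they should match the `globals` section of `openfold/config.py` and the `AlphaFold` class), and weight loading plus feature processing are omitted.

```python
# Completely untested sketch of CPU-only inference with the custom CUDA
# kernels removed. Assumes the kernel imports in
# openfold/model/primitives.py and openfold/model/structure_module.py have
# already been deleted, and that model_config() / AlphaFold are importable
# as in the current repo layout.
import torch

from openfold.config import model_config
from openfold.model.model import AlphaFold

config = model_config("model_1")          # any inference preset
config.globals.use_lma = False            # plain PyTorch attention path
config.globals.use_flash = False          # FlashAttention requires CUDA
config.globals.offload_inference = False  # nothing to offload to on CPU

model = AlphaFold(config).eval().to("cpu")

# `batch` would be the processed feature dict that run_pretrained_openfold.py
# normally builds; with it in hand, inference is just:
# with torch.no_grad():
#     out = model(batch)
```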

bzjf commented 2 years ago

> Not in the default configuration, but if you disable the custom CUDA kernels, it's technically possible to run the model on CPU. Of course, that would cause a slowdown so large that I'd expect model inference to be impractical for almost any application. Also, CPU inference is completely untested. LMK if you still want to move forward and I can explain how to remove said kernels.

Thanks for your answer. I'm fairly sure that, at least for a long time, I'll have to use the CPU instead of CUDA, so it would be kind of you to explain how to disable the CUDA kernels. Thank you for all your help.

epenning commented 2 years ago

EDIT: Note that this example is only for structure prediction, not training.

I'm not on the OpenFold team, but I happened to have adjusted my installation so that it runs on an AMD GPU instead of an Nvidia one. I did it by brute force, removing CUDA-related code that produced errors in my environment, so there are probably more elegant solutions and not all of my changes are necessarily correct. No guarantees that this will work for your case, but I'm sure you would need to do something similar. This is the relevant part of my diff, although it also includes some of the low-memory / offloading configuration changes, which you may want to adjust:

Diff for testing OpenFold on AMD GPU

I removed PyTorch from the environment for reasons specific to my own system; in general it should probably be left as a dependency. This diff applies on top of commit [feed4ae22edf899b37bee49293fff902bdd64e2d](https://github.com/aqlaboratory/openfold/tree/feed4ae22edf899b37bee49293fff902bdd64e2d).

```diff
diff --git a/environment.yml b/environment.yml
index 0dfb9db..3a3250b
--- a/environment.yml
+++ b/environment.yml
@@ -1,4 +1,4 @@
-name: openfold_venv
+name: base
 channels:
   - conda-forge
   - bioconda
@@ -9,14 +9,12 @@ dependencies:
   - conda-forge::pip
   - conda-forge::openmm=7.5.1
   - conda-forge::pdbfixer
-  - conda-forge::cudatoolkit==11.3.*
   - bioconda::hmmer==3.3.2
   - bioconda::hhsuite==3.3.0
   - bioconda::kalign2==2.04
-  - pytorch::pytorch=1.12.*
   - pip:
       - biopython==1.79
-      - deepspeed==0.5.10
+      - deepspeed
       - dm-tree==0.1.6
       - ml-collections==0.1.0
       - numpy==1.21.2
@@ -27,4 +25,3 @@ dependencies:
       - typing-extensions==3.10.0.2
       - pytorch_lightning==1.5.10
       - wandb==0.12.21
-      - git+https://github.com/NVIDIA/dllogger.git
diff --git a/openfold/config.py b/openfold/config.py
index 4413da1..c402624
--- a/openfold/config.py
+++ b/openfold/config.py
@@ -196,7 +196,7 @@ tm_enabled = mlc.FieldReference(False, field_type=bool)
 eps = mlc.FieldReference(1e-8, field_type=float)
 templates_enabled = mlc.FieldReference(True, field_type=bool)
 embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
-tune_chunk_size = mlc.FieldReference(True, field_type=bool)
+tune_chunk_size = mlc.FieldReference(False, field_type=bool)
 
 NUM_RES = "num residues placeholder"
 NUM_MSA_SEQ = "msa placeholder"
@@ -369,11 +369,11 @@ config = mlc.ConfigDict(
             "chunk_size": chunk_size,
             # Use Staats & Rabe's low-memory attention algorithm. Mutually
             # exclusive with use_flash.
-            "use_lma": False,
+            "use_lma": True,
             # Use FlashAttention in selected modules. Mutually exclusive with
             # use_lma. Doesn't work that well on long sequences (>1000 residues).
             "use_flash": False,
-            "offload_inference": False,
+            "offload_inference": True,
             "c_z": c_z,
             "c_m": c_m,
             "c_t": c_t,
@@ -451,7 +451,7 @@ config = mlc.ConfigDict(
                 # runtime. Useful for inference on very long sequences.
                 # Mutually exclusive with average_templates. Automatically
                 # enabled if offload_inference is set.
-                "offload_templates": False,
+                "offload_templates": True,
             },
             "extra_msa": {
                 "extra_msa_embedder": {
diff --git a/openfold/model/primitives.py b/openfold/model/primitives.py
index 354d2cb..ce65ccc
--- a/openfold/model/primitives.py
+++ b/openfold/model/primitives.py
@@ -34,7 +34,6 @@ from scipy.stats import truncnorm
 
 from openfold.utils.checkpointing import get_checkpoint_fn
 from openfold.utils.chunk_utils import _chunk_slice
-from openfold.utils.kernel.attention_core import attention_core
 from openfold.utils.tensor_utils import (
     permute_final_dims,
     flatten_final_dims,
@@ -480,13 +479,7 @@ class Attention(nn.Module):
 
         # [*, Q, H, C_hidden]
         if(use_memory_efficient_kernel):
-            if(len(biases) > 2):
-                raise ValueError(
-                    "If use_memory_efficient_kernel is True, you may only "
-                    "provide up to two bias terms"
-                )
-            o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
-            o = o.transpose(-2, -3)
+            raise ValueError("No memory efficient kernel on AMD.")
         elif(use_lma):
             biases = [
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
diff --git a/openfold/model/structure_module.py b/openfold/model/structure_module.py
old mode 100644
new mode 100755
index 72ea4cd..5cbf2cd
--- a/openfold/model/structure_module.py
+++ b/openfold/model/structure_module.py
@@ -40,8 +40,6 @@ from openfold.utils.tensor_utils import (
     flatten_final_dims,
 )
 
-attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
-
 
 class AngleResnetBlock(nn.Module):
     def __init__(self, c_hidden):
@@ -352,12 +350,7 @@ class InvariantPointAttention(nn.Module):
             a += pt_att
             del pt_att
             a += square_mask.unsqueeze(-3)
-            # in-place softmax
-            attn_core_inplace_cuda.forward_(
-                a,
-                reduce(mul, a.shape[:-1]),
-                a.shape[-1],
-            )
+            a = self.softmax(a)
         else:
             a = a + pt_att
             a = a + square_mask.unsqueeze(-3)
diff --git a/openfold/np/relax/amber_minimize.py b/openfold/np/relax/amber_minimize.py
index dfc0984..f461873
--- a/openfold/np/relax/amber_minimize.py
+++ b/openfold/np/relax/amber_minimize.py
@@ -105,7 +105,7 @@ def _openmm_minimize(
     _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)
 
     integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
-    platform = openmm.Platform.getPlatformByName("CUDA" if use_gpu else "CPU")
+    platform = openmm.Platform.getPlatformByName("OpenCL" if use_gpu else "CPU")
     simulation = openmm_app.Simulation(
         pdb.topology, system, integrator, platform
     )
diff --git a/run_pretrained_openfold.py b/run_pretrained_openfold.py
index 82642a7..e285388
--- a/run_pretrained_openfold.py
+++ b/run_pretrained_openfold.py
@@ -34,12 +34,6 @@ import torch
 torch_versions = torch.__version__.split(".")
 torch_major_version = int(torch_versions[0])
 torch_minor_version = int(torch_versions[1])
-if(
-    torch_major_version > 1 or
-    (torch_major_version == 1 and torch_minor_version >= 12)
-):
-    # Gives a large speedup on Ampere-class GPUs
-    torch.set_float32_matmul_precision("high")
 
 torch.set_grad_enabled(False)
 
@@ -266,7 +260,8 @@ def main(args):
             feature_dict,
             feature_processor,
             args.config_preset,
-            args.multimer_ri_gap
+            args.multimer_ri_gap,
+            args.subtract_plddt
         )
 
         unrelaxed_output_path = os.path.join(
diff --git a/setup.py b/setup.py
index 1d5c3a0..ed40d35
--- a/setup.py
+++ b/setup.py
@@ -16,10 +16,7 @@ import os
 from setuptools import setup, Extension, find_packages
 import subprocess
 
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
-
-from scripts.utils import get_nvidia_cc
-
+from torch.utils.cpp_extension import BuildExtension
 
 version_dependent_macros = [
     '-DVERSION_GE_1_1',
@@ -27,50 +24,6 @@ version_dependent_macros = [
     '-DVERSION_GE_1_5',
 ]
 
-extra_cuda_flags = [
-    '-std=c++14',
-    '-maxrregcount=50',
-    '-U__CUDA_NO_HALF_OPERATORS__',
-    '-U__CUDA_NO_HALF_CONVERSIONS__',
-    '--expt-relaxed-constexpr',
-    '--expt-extended-lambda'
-]
-
-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-
-    return raw_output, bare_metal_major, bare_metal_minor
-
-compute_capabilities = set([
-    (3, 7),  # K80, e.g.
-    (5, 2),  # Titan X
-    (6, 1),  # GeForce 1000-series
-])
-
-compute_capabilities.add((7, 0))
-_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
-if int(bare_metal_major) >= 11:
-    compute_capabilities.add((8, 0))
-
-compute_capability, _ = get_nvidia_cc()
-if compute_capability is not None:
-    compute_capabilities = set([compute_capability])
-
-cc_flag = []
-for major, minor in list(compute_capabilities):
-    cc_flag.extend([
-        '-gencode',
-        f'arch=compute_{major}{minor},code=sm_{major}{minor}',
-    ])
-
-extra_cuda_flags += cc_flag
-
-
 setup(
     name='openfold',
     version='1.0.0',
@@ -82,31 +35,8 @@ setup(
     packages=find_packages(exclude=["tests", "scripts"]),
     include_package_data=True,
     package_data={
-        "openfold": ['utils/kernel/csrc/*'],
         "": ["resources/stereo_chemical_props.txt"]
     },
-    ext_modules=[CUDAExtension(
-        name="attn_core_inplace_cuda",
-        sources=[
-            "openfold/utils/kernel/csrc/softmax_cuda.cpp",
-            "openfold/utils/kernel/csrc/softmax_cuda_kernel.cu",
-        ],
-        include_dirs=[
-            os.path.join(
-                os.path.dirname(os.path.abspath(__file__)),
-                'openfold/utils/kernel/csrc/'
-            )
-        ],
-        extra_compile_args={
-            'cxx': ['-O3'] + version_dependent_macros,
-            'nvcc': (
-                ['-O3', '--use_fast_math'] +
-                version_dependent_macros +
-                extra_cuda_flags
-            ),
-        }
-    )],
-    cmdclass={'build_ext': BuildExtension},
     classifiers=[
         'License :: OSI Approved :: Apache Software License',
         'Operating System :: POSIX :: Linux',
```
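For the CPU-only case in the original question, the same idea applies: the custom kernels are CUDA implementations of operations that also exist in plain PyTorch, so you mostly need to avoid importing them and fall back to the standard ops. A rough sketch of the two substitutions (untested on my side, and note the OpenMM import spelling depends on your OpenMM version):

```python
import torch

# 1) The attn_core_inplace_cuda.forward_(a, ...) call removed above is an
#    in-place softmax over the last dimension of the attention logits; the
#    plain PyTorch equivalent, which is what the diff's `a = self.softmax(a)`
#    amounts to, is simply:
a = torch.randn(2, 8, 16, 16)  # dummy logits, arbitrary shape
a = torch.nn.functional.softmax(a, dim=-1)

# 2) For the relaxation step, check which OpenMM platforms your build
#    actually provides before hard-coding "CUDA", "OpenCL", or "CPU".
#    (On OpenMM < 7.6 this import is `from simtk import openmm`.)
import openmm
platforms = [
    openmm.Platform.getPlatform(i).getName()
    for i in range(openmm.Platform.getNumPlatforms())
]
print(platforms)  # e.g. ['Reference', 'CPU', 'OpenCL']
```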
lkqnaruto commented 2 years ago

> Not in the default configuration, but if you disable the custom CUDA kernels, it's technically possible to run the model on CPU. Of course, that would cause a slowdown so large that I'd expect model inference to be impractical for almost any application. Also, CPU inference is completely untested. LMK if you still want to move forward and I can explain how to remove said kernels.

Could you please elaborate on how to remove the kernels? I want to run it in debug mode in an IDE so that I can better understand how the code runs.