flatironinstitute / deepblast

Neural Networks for Protein Sequence Alignment
BSD 3-Clause "New" or "Revised" License

DeepBLAST align method not working #112

Closed: mortonjt closed this issue 1 year ago

mortonjt commented 1 year ago

I'm running into some weird Numba errors when I try to align sequences at inference time. Below is a self-contained reproducible example:

from deepblast.trainer import DeepBLAST
from transformers import T5EncoderModel, T5Tokenizer
x = 'ETTSSIGEKVWDDKDNQDGKRPEKVSVNLLANGEKVKTLDVTSETNWKYEFKDLPKYDEGKKIEYTVTEDHVKDYTTDINGTTITNKYTP'
y = 'GETSATVTKNWDDNNNQDGKRPTEIKVELYQDGKATGKTAILNESNNWTHTWTGLDEKAKGQQVKYTVEELTKVKGYTTHVDNNDMGNLITTNKYTP'
# Local ProtT5-XL-UniRef50 weights and a trained DeepBLAST checkpoint (cluster-specific paths)
pretrain_path = "/mnt/home/jmorton/ceph/prot_t5_xl_uniref50"
path = '/mnt/ceph/users/jmorton/research/gert/deep_blast_training/results/protrans_l8_dp0.5_bn_med-3x/epoch=17-step=288-validation_loss=0.1562.ckpt'

tokenizer = T5Tokenizer.from_pretrained(
    pretrain_path, do_lower_case=False)
lm = T5EncoderModel.from_pretrained(pretrain_path)

# Restore the trained DeepBLAST checkpoint, supplying the ProtT5 encoder and tokenizer
model = DeepBLAST.load_from_checkpoint(path, lm=lm, tokenizer=tokenizer)
model.align(x, y)
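For context, the traceback shows that `align` unpacks the decoded traceback into predicted states and maps them through `revstate_f` to build a string. The sketch below shows how such a per-column state string can be expanded back into two gapped sequences. It is a minimal sketch under stated assumptions: `states_to_gapped` is a hypothetical helper, and the symbols used (`':'` for a match, `'1'` for a residue of `x` against a gap, `'2'` for a gap against a residue of `y`) are assumptions for illustration, not DeepBLAST's documented encoding.

```python
from typing import Tuple


def states_to_gapped(x: str, y: str, states: str,
                     match: str = ":", gap_in_y: str = "1",
                     gap_in_x: str = "2") -> Tuple[str, str]:
    """Expand a per-column state string into two gapped sequences.

    The default symbols are an illustrative assumption, not DeepBLAST's
    documented encoding:
      match    -> x[i] aligned to y[j]
      gap_in_y -> x[i] aligned to a gap
      gap_in_x -> a gap aligned to y[j]
    """
    i = j = 0
    out_x, out_y = [], []
    for s in states:
        if s == match:
            out_x.append(x[i])
            out_y.append(y[j])
            i += 1
            j += 1
        elif s == gap_in_y:
            out_x.append(x[i])
            out_y.append("-")
            i += 1
        elif s == gap_in_x:
            out_x.append("-")
            out_y.append(y[j])
            j += 1
        else:
            raise ValueError(f"unknown state symbol {s!r}")
    # A global alignment must consume every residue of both sequences.
    if i != len(x) or j != len(y):
        raise ValueError("state string does not cover both sequences")
    return "".join(out_x), "".join(out_y)


# Prints "ACGT" and then "A-GT" on separate lines.
print(*states_to_gapped("ACGT", "AGT", ":1::"), sep="\n")
```

Raising on leftover residues, rather than silently padding, makes a mismatch between the state alphabet and the real encoding show up immediately.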

I get the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[10], line 5
      3 x = res.iloc[i][0] 
      4 y = res.iloc[i][1]
----> 5 s = model.align(x, y)
      6 dps.append(s)

File /mnt/ceph/users/jmorton/deepblast/deepblast/trainer.py:77, in DeepBLAST.align(self, x, y)
     75 seq, order = pack_sequences([x_code], [y_code])
     76 gen = self.aligner.traceback(seq, order)
---> 77 decoded, _ = next(gen)
     78 pred_x, pred_y, pred_states = zip(*decoded)
     79 s = ''.join(list(map(revstate_f, pred_states)))

File /mnt/ceph/users/jmorton/deepblast/deepblast/alignment.py:150, in NeedlemanWunschAligner.traceback(self, x, order)
    148 B, _, _ = match.shape
    149 for b in range(B):
--> 150     aln = self.nw.decode(
    151         match[b, :xlen[b], :ylen[b]].unsqueeze(0),
    152         gap[b, :xlen[b], :ylen[b]].unsqueeze(0)
    153     )
    154     decoded = self.nw.traceback(aln.squeeze())
    155     yield decoded, aln

File /mnt/ceph/users/jmorton/deepblast/deepblast/nw_cuda.py:315, in NeedlemanWunschDecoder.decode(self, theta, A)
    313 """ Shortcut for doing inference. """
    314 with torch.enable_grad():
--> 315     nll = self.forward(theta, A)
    316     v = torch.sum(nll)
    317     v_grad, _ = torch.autograd.grad(v, (theta, A), create_graph=True)

File /mnt/ceph/users/jmorton/deepblast/deepblast/nw_cuda.py:264, in NeedlemanWunschDecoder.forward(self, theta, A)
    263 def forward(self, theta, A):
--> 264     return NeedlemanWunschFunction.apply(theta, A, self.operator)

File /mnt/ceph/users/jmorton/deepblast/deepblast/nw_cuda.py:184, in NeedlemanWunschFunction.forward(ctx, theta, A, operator)
    181 Vt = torch.zeros((B), dtype=theta.dtype, device=theta.device)
    182 bpg = (B + (tpb - 1)) // tpb  # blocks per grid
--> 184 _forward_pass_kernel[tpb, bpg](theta.detach(), A.detach(), Q, Vt)
    186 ctx.save_for_backward(theta, A, Q)
    187 ctx.others = operator

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/cuda/compiler.py:727, in _KernelConfiguration.__call__(self, *args)
    726 def __call__(self, *args):
--> 727     return self.dispatcher.call(args, self.griddim, self.blockdim,
    728                                 self.stream, self.sharedmem)

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/cuda/compiler.py:915, in Dispatcher.call(self, args, griddim, blockdim, stream, sharedmem)
    913     kernel = next(iter(self.overloads.values()))
    914 else:
--> 915     kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
    917 kernel.launch(args, griddim, blockdim, stream, sharedmem)

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/cuda/compiler.py:940, in Dispatcher.typeof_pyval(self, val)
    936 def typeof_pyval(self, val):
    937     # Based on _DispatcherBase.typeof_pyval, but differs from it to support
    938     # the CUDA Array Interface.
    939     try:
--> 940         return typeof(val, Purpose.argument)
    941     except ValueError:
    942         if numba.cuda.is_cuda_array(val):
    943             # When typing, we don't need to synchronize on the array's
    944             # stream - this is done when the kernel is launched.

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/typing/typeof.py:35, in typeof(val, purpose)
     32 if ty is None:
     33     msg = _termcolor.errmsg(
     34         f"Cannot determine Numba type of {type(val)}")
---> 35     raise ValueError(msg)
     36 return ty

ValueError: Cannot determine Numba type of <class 'torch.Tensor'>
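The `ValueError` above is raised while Numba types the kernel arguments. A `numba.cuda` kernel can accept objects that expose the CUDA Array Interface (a `__cuda_array_interface__` attribute, which PyTorch tensors on a GPU provide), but a PyTorch tensor on the CPU does not expose it, so `typeof` gives up. In other words, the CUDA kernel in `nw_cuda.py` appears to be receiving CPU tensors. Below is a minimal sketch of a guard that would fail earlier with a clearer message. `require_cuda_arrays` is a hypothetical helper (not part of DeepBLAST or Numba); its `hasattr` check is similar in spirit to `numba.cuda.is_cuda_array`, and it deliberately rejects host arrays.

```python
from typing import Any


def require_cuda_arrays(**arrays: Any) -> None:
    """Raise a descriptive error if any argument is not CUDA-array-like.

    Hypothetical pre-launch check: objects exposing __cuda_array_interface__
    (e.g. PyTorch tensors on a GPU) can be typed by Numba's CUDA dispatcher,
    whereas a CPU tensor lacks that attribute.
    """
    bad = [name for name, arr in arrays.items()
           if not hasattr(arr, "__cuda_array_interface__")]
    if bad:
        raise TypeError(
            "CUDA kernel received non-CUDA inputs: " + ", ".join(sorted(bad))
            + "; move the model and tensors to a GPU or use a CPU code path."
        )
```

Used before the launch in `NeedlemanWunschFunction.forward`, a call such as `require_cuda_arrays(theta=theta, A=A)` would turn the opaque Numba typing failure into a message that names the offending inputs.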

This is based on pull request https://github.com/flatironinstitute/deepblast/pull/111, since the align method was broken earlier. @blackwer, do you have any intuition about what is going on with Numba?

mortonjt commented 1 year ago

Another clarification: the above error only happens on the CPU. On the GPU I get different errors that look like they are due to the install (which is also weird, since the code is able to train):

TypingError                               Traceback (most recent call last)
Cell In[11], line 5
      3 x = res.iloc[i][0] 
      4 y = res.iloc[i][1]
----> 5 s = model.align(x, y)
      6 dps.append(s)

File /mnt/ceph/users/jmorton/deepblast/deepblast/trainer.py:77, in DeepBLAST.align(self, x, y)
     75 seq, order = pack_sequences([x_code], [y_code])
     76 gen = self.aligner.traceback(seq, order)
---> 77 decoded, _ = next(gen)
     78 pred_x, pred_y, pred_states = zip(*decoded)
     79 s = ''.join(list(map(revstate_f, pred_states)))

File /mnt/ceph/users/jmorton/deepblast/deepblast/alignment.py:150, in NeedlemanWunschAligner.traceback(self, x, order)
    148 B, _, _ = match.shape
    149 for b in range(B):
--> 150     aln = self.nw.decode(
    151         match[b, :xlen[b], :ylen[b]].unsqueeze(0),
    152         gap[b, :xlen[b], :ylen[b]].unsqueeze(0)
    153     )
    154     decoded = self.nw.traceback(aln.squeeze())
    155     yield decoded, aln

File /mnt/ceph/users/jmorton/deepblast/deepblast/nw_cuda.py:315, in NeedlemanWunschDecoder.decode(self, theta, A)
    313 """ Shortcut for doing inference. """
    314 with torch.enable_grad():
--> 315     nll = self.forward(theta, A)
    316     v = torch.sum(nll)
    317     v_grad, _ = torch.autograd.grad(v, (theta, A), create_graph=True)

File /mnt/ceph/users/jmorton/deepblast/deepblast/nw_cuda.py:264, in NeedlemanWunschDecoder.forward(self, theta, A)
    263 def forward(self, theta, A):
--> 264     return NeedlemanWunschFunction.apply(theta, A, self.operator)

File /mnt/ceph/users/jmorton/deepblast/deepblast/nw_cuda.py:184, in NeedlemanWunschFunction.forward(ctx, theta, A, operator)
    181 Vt = torch.zeros((B), dtype=theta.dtype, device=theta.device)
    182 bpg = (B + (tpb - 1)) // tpb  # blocks per grid
--> 184 _forward_pass_kernel[tpb, bpg](theta.detach(), A.detach(), Q, Vt)
    186 ctx.save_for_backward(theta, A, Q)
    187 ctx.others = operator

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/cuda/compiler.py:727, in _KernelConfiguration.__call__(self, *args)
    726 def __call__(self, *args):
--> 727     return self.dispatcher.call(args, self.griddim, self.blockdim,
    728                                 self.stream, self.sharedmem)

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/cuda/compiler.py:915, in Dispatcher.call(self, args, griddim, blockdim, stream, sharedmem)
    913     kernel = next(iter(self.overloads.values()))
    914 else:
--> 915     kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
    917 kernel.launch(args, griddim, blockdim, stream, sharedmem)

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/cuda/compiler.py:923, in Dispatcher._compile_for_args(self, *args, **kws)
    921 assert not kws
    922 argtypes = [self.typeof_pyval(a) for a in args]
--> 923 return self.compile(tuple(argtypes))

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/cuda/compiler.py:1084, in Dispatcher.compile(self, sig)
   1082 if not self._can_compile:
   1083     raise RuntimeError("Compilation disabled")
-> 1084 kernel = _Kernel(self.py_func, argtypes, link=self.link,
   1085                  **self.targetoptions)
   1086 # Inspired by _DispatcherBase.add_overload, but differs slightly
   1087 # because we're inserting a _Kernel object instead of a compiled
   1088 # function.
   1089 c_sig = [a._code for a in argtypes]

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/cuda/compiler.py:390, in _Kernel.__init__(self, py_func, argtypes, link, debug, lineinfo, inline, fastmath, extensions, max_registers, opt, device)
    381 self.extensions = extensions or []
    383 nvvm_options = {
    384     'debug': self.debug,
    385     'lineinfo': self.lineinfo,
    386     'fastmath': fastmath,
    387     'opt': 3 if opt else 0
    388 }
--> 390 cres = compile_cuda(self.py_func, types.void, self.argtypes,
    391                     debug=self.debug,
    392                     lineinfo=self.lineinfo,
    393                     inline=inline,
    394                     fastmath=fastmath,
    395                     nvvm_options=nvvm_options)
    396 tgt_ctx = cres.target_context
    397 code = self.py_func.__code__

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/cuda/compiler.py:200, in compile_cuda(pyfunc, return_type, args, debug, lineinfo, inline, fastmath, nvvm_options)
    197     flags.nvvm_options = nvvm_options
    199 # Run compilation pipeline
--> 200 cres = compiler.compile_extra(typingctx=typingctx,
    201                               targetctx=targetctx,
    202                               func=pyfunc,
    203                               args=args,
    204                               return_type=return_type,
    205                               flags=flags,
    206                               locals={},
    207                               pipeline_class=CUDACompiler)
    209 library = cres.library
    210 library.finalize()

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/compiler.py:693, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    669 """Compiler entry point
    670 
    671 Parameter
   (...)
    689     compiler pipeline
    690 """
    691 pipeline = pipeline_class(typingctx, targetctx, library,
    692                           args, return_type, flags, locals)
--> 693 return pipeline.compile_extra(func)

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/compiler.py:429, in CompilerBase.compile_extra(self, func)
    427 self.state.lifted = ()
    428 self.state.lifted_from = None
--> 429 return self._compile_bytecode()

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/compiler.py:497, in CompilerBase._compile_bytecode(self)
    493 """
    494 Populate and run pipeline for bytecode input
    495 """
    496 assert self.state.func_ir is None
--> 497 return self._compile_core()

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/compiler.py:476, in CompilerBase._compile_core(self)
    474         self.state.status.fail_reason = e
    475         if is_final_pipeline:
--> 476             raise e
    477 else:
    478     raise CompilerError("All available pipelines exhausted")

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/compiler.py:463, in CompilerBase._compile_core(self)
    461 res = None
    462 try:
--> 463     pm.run(self.state)
    464     if self.state.cr is not None:
    465         break

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/compiler_machinery.py:353, in PassManager.run(self, state)
    350 msg = "Failed in %s mode pipeline (step: %s)" % \
    351     (self.pipeline_name, pass_desc)
    352 patched_exception = self._patch_error(msg, e)
--> 353 raise patched_exception

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/compiler_machinery.py:341, in PassManager.run(self, state)
    339 pass_inst = _pass_registry.get(pss).pass_inst
    340 if isinstance(pass_inst, CompilerPass):
--> 341     self._runPass(idx, pass_inst, state)
    342 else:
    343     raise BaseException("Legacy pass in use")

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/compiler_machinery.py:296, in PassManager._runPass(self, index, pss, internal_state)
    294     mutated |= check(pss.run_initialization, internal_state)
    295 with SimpleTimer() as pass_time:
--> 296     mutated |= check(pss.run_pass, internal_state)
    297 with SimpleTimer() as finalize_time:
    298     mutated |= check(pss.run_finalizer, internal_state)

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/compiler_machinery.py:269, in PassManager._runPass.<locals>.check(func, compiler_state)
    268 def check(func, compiler_state):
--> 269     mangled = func(compiler_state)
    270     if mangled not in (True, False):
    271         msg = ("CompilerPass implementations should return True/False. "
    272                "CompilerPass with name '%s' did not.")

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/typed_passes.py:105, in BaseTypeInference.run_pass(self, state)
     99 """
    100 Type inference and legalization
    101 """
    102 with fallback_context(state, 'Function "%s" failed type inference'
    103                       % (state.func_id.func_name,)):
    104     # Type inference
--> 105     typemap, return_type, calltypes, errs = type_inference_stage(
    106         state.typingctx,
    107         state.targetctx,
    108         state.func_ir,
    109         state.args,
    110         state.return_type,
    111         state.locals,
    112         raise_errors=self._raise_errors)
    113     state.typemap = typemap
    114     # save errors in case of partial typing

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/typed_passes.py:83, in type_inference_stage(typingctx, targetctx, interp, args, return_type, locals, raise_errors)
     81     infer.build_constraint()
     82     # return errors in case of partial typing
---> 83     errs = infer.propagate(raise_errors=raise_errors)
     84     typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
     86 # Output all Numba warnings

File /mnt/ceph/users/jmorton/venv/deepblast/lib/python3.9/site-packages/numba/core/typeinfer.py:1086, in TypeInferer.propagate(self, raise_errors)
   1083 force_lit_args = [e for e in errors
   1084                   if isinstance(e, ForceLiteralArg)]
   1085 if not force_lit_args:
-> 1086     raise errors[0]
   1087 else:
   1088     raise reduce(operator.or_, force_lit_args)

TypingError: Failed in cuda mode pipeline (step: nopython frontend)
Failed in cuda mode pipeline (step: nopython frontend)
Internal error at <numba.core.typeinfer.CallConstraint object at 0x155464d120a0>.
libNVVM cannot be found. Do `conda install cudatoolkit`:
libnvvm.so: cannot open shared object file: No such file or directory
During: resolving callee type: type(<numba.cuda.compiler.Dispatcher object at 0x155476b137c0>)
During: typing of call at /mnt/ceph/users/jmorton/deepblast/deepblast/nw_cuda.py (63)

Enable logging at debug level for details.

File "../../../../../deepblast/deepblast/nw_cuda.py", line 63:
def _forward_pass_device(theta, A, Q):
    <source elided>

            v = _soft_max_device(maxargs, Q[i, j])
            ^

During: resolving callee type: type(<numba.cuda.compiler.Dispatcher object at 0x155476b13900>)
During: typing of call at /mnt/ceph/users/jmorton/deepblast/deepblast/nw_cuda.py (76)

File "../../../../../deepblast/deepblast/nw_cuda.py", line 76:
def _forward_pass_kernel(theta, A, Q, Vt):
    <source elided>
    if batchid < theta.shape[0]:
        Vt[batchid] = _forward_pass_device(theta[batchid], A[batchid],
        ^

@blackwer, let me know if you have any immediate thoughts; I'll keep looking into this.
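The GPU traceback ends in `libNVVM cannot be found` and `libnvvm.so: cannot open shared object file`, meaning the NVVM compiler library from the CUDA toolkit was not visible when Numba tried to compile the kernel. Numba can report what it detects via `numba -s`. As a stdlib-only alternative, the sketch below scans candidate toolkit roots for `nvvm/lib64/libnvvm.so*`. The `nvvm/lib64` layout and the use of `CUDA_HOME`/`CUDA_PATH` are assumptions about a typical Linux CUDA toolkit install, and `find_libnvvm` is a hypothetical helper.

```python
import os
from pathlib import Path
from typing import Iterable, List, Optional


def find_libnvvm(roots: Optional[Iterable[str]] = None) -> List[Path]:
    """Return libnvvm shared objects found under candidate CUDA toolkit roots.

    Assumes the usual Linux toolkit layout <root>/nvvm/lib64/libnvvm.so*.
    By default the roots come from the CUDA_HOME / CUDA_PATH environment
    variables, which may be unset (e.g. if no CUDA module was loaded).
    """
    if roots is None:
        roots = [os.environ[v] for v in ("CUDA_HOME", "CUDA_PATH")
                 if os.environ.get(v)]
    hits: List[Path] = []
    for root in roots:
        lib_dir = Path(root) / "nvvm" / "lib64"
        if lib_dir.is_dir():
            hits.extend(sorted(lib_dir.glob("libnvvm.so*")))
    return hits


if __name__ == "__main__":
    found = find_libnvvm()
    if found:
        print("libnvvm candidates:", *found, sep="\n  ")
    else:
        print("No libnvvm found; is the CUDA toolkit loaded and CUDA_HOME set?")
```

An empty result on a compute node, despite training working elsewhere, would point at the environment (for example, a missing module load) rather than at DeepBLAST itself.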

mortonjt commented 1 year ago

It turned out that some modules were not loaded. The GPU seems to be working now. Closing.