bowang-lab / scGPT

https://scgpt.readthedocs.io/en/latest/
MIT License

TypeError: bwd(): incompatible function arguments. #253

Open IchigoJiken opened 2 months ago

IchigoJiken commented 2 months ago

Hi,

I am hitting an error similar to the one in the following closed issue, but I have not been able to resolve it: https://github.com/bowang-lab/scGPT/issues/43

I ran https://github.com/bowang-lab/scGPT/blob/main/tutorials/Tutorial_Perturbation.ipynb and got the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[28], line 6
      3 train_loader = pert_data.dataloader["train_loader"]
      4 valid_loader = pert_data.dataloader["val_loader"]
----> 6 train(
      7     model,
      8     train_loader,
      9 )
     11 val_res = eval_perturb(valid_loader, model, device)
     12 val_metrics = compute_perturbation_metrics(
     13     val_res, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
     14 )

File /y/Jiro/2024/BR240105-01_main/240722-foundation_model/240722-scGPT_test/240917-perturbation.py:61
     58     loss = loss_mse = criterion(output_values, target_values, masked_positions)
     60 model.zero_grad()
---> 61 scaler.scale(loss).backward()
     62 scaler.unscale_(optimizer)
     63 with warnings.catch_warnings(record=True) as w:

File ~/miniconda3/envs/scgpt_3/lib/python3.10/site-packages/torch/_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    482 if has_torch_function_unary(self):
    483     return handle_torch_function(
    484         Tensor.backward,
    485         (self,),
   (...)
    490         inputs=inputs,
    491     )
--> 492 torch.autograd.backward(
    493     self, gradient, retain_graph, create_graph, inputs=inputs
    494 )

File ~/miniconda3/envs/scgpt_3/lib/python3.10/site-packages/torch/autograd/__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    246     retain_graph = create_graph
    248 # The reason we repeat the same comment below is that
    249 # some Python versions print out the first line of a multi-line function
    250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252     tensors,
    253     grad_tensors_,
    254     retain_graph,
    255     create_graph,
    256     inputs,
    257     allow_unreachable=True,
    258     accumulate_grad=True,
    259 )

File ~/miniconda3/envs/scgpt_3/lib/python3.10/site-packages/torch/autograd/function.py:288, in BackwardCFunction.apply(self, *args)
    282     raise RuntimeError(
    283         "Implementing both 'backward' and 'vjp' for a custom "
    284         "Function is not allowed. You should only implement one "
    285         "of them."
    286     )
    287 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 288 return user_fn(self, *args)

File ~/miniconda3/envs/scgpt_3/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:78, in FlashAttnQKVPackedFunc.backward(ctx, dout, *args)
     76 qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
     77 dqkv = torch.empty_like(qkv)
---> 78 _flash_attn_backward(
     79     dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
     80     dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
     81     ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
     82     rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
     83 )
     84 return dqkv, None, None, None, None, None, None, None

File ~/miniconda3/envs/scgpt_3/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:44, in _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, rng_state, num_splits, generator)
     36 """
     37 num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
     38 not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
   (...)
     41 This hyperparameter can be tuned for performance, but default value (heuristic) should work fine.
     42 """
     43 dout = dout.contiguous()  # CUDA code assumes that dout is contiguous
---> 44 _, _, _, softmax_d = flash_attn_cuda.bwd(
     45     dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
     46     max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal,
     47     # num_splits, generator, rng_state)
     48     num_splits, generator)
     49 # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
     50 #     breakpoint()
     51 return dq, dk, dv, softmax_d

TypeError: bwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: torch.Tensor, arg4: torch.Tensor, arg5: torch.Tensor, arg6: torch.Tensor, arg7: torch.Tensor, arg8: torch.Tensor, arg9: torch.Tensor, arg10: torch.Tensor, arg11: int, arg12: int, arg13: float, arg14: float, arg15: bool, arg16: bool, arg17: int, arg18: Optional[torch.Generator], arg19: Optional[torch.Tensor]) -> List[torch.Tensor]

Invoked with: tensor([[[ 7.6294e-06,  2.0862e-06, -1.4901e-06,  ..., -4.2915e-06,
          -2.4438e-06,  1.3232e-05],
         [ 2.9802e-07, -1.0252e-05,  7.8678e-06,  ...,  3.5167e-06,
          -3.5763e-06, -1.7881e-07],
         [-8.8811e-06, -9.5367e-07,  3.0398e-06,  ..., -1.1384e-05,
          -8.2850e-06,  8.2850e-06],
         ...,
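
Reading the signature in the error message against the call site, the compiled flash_attn_cuda extension expects 20 arguments, with Optional[torch.Generator] and Optional[torch.Tensor] (the rng_state) as the last two, while the installed wrapper passes only 19 and carries the matching 20-argument call as a commented-out line. That looks like a mismatch between the flash_attn Python sources and the compiled binary. A minimal diagnostic sketch, using only module and function names that appear in the traceback:

```python
# Diagnostic sketch: compare the installed Python wrapper with the signature
# the compiled extension actually exposes. Module names are taken from the
# traceback above.
import inspect

import flash_attn
import flash_attn_cuda  # pybind11 extension called by flash_attn 1.x
from flash_attn import flash_attn_interface

print("flash_attn version:", flash_attn.__version__)
# pybind11 records the accepted argument types in the function docstring;
# check whether the last argument is Optional[torch.Tensor] (rng_state).
print(flash_attn_cuda.bwd.__doc__)
# The Python call site from the traceback; check whether it passes rng_state.
print(inspect.getsource(flash_attn_interface._flash_attn_backward))
```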

I downloaded the scGPT_human checkpoint from https://drive.google.com/drive/folders/1oWh-ZRdhtoGQ2Fw24HP41FgLoomVo-y

I am using flash_attn 1.0.3, torch 2.1.0, scgpt 0.2.1, CUDA 12.1.0, and Python 3.10.11.
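
If reinstalling flash-attn does not line the two sides up (e.g. `pip uninstall flash-attn` followed by `pip install flash-attn==1.0.4 --no-build-isolation`, so that the wrapper and the binary come from the same build), one hedged local workaround follows directly from the traceback: the wrapper already contains the 20-argument call as a commented-out line, and that line matches the signature the extension reports. Restoring it inside _flash_attn_backward in site-packages/flash_attn/flash_attn_interface.py would give:

```python
# Hedged patch sketch, grounded only in the traceback above: swap the active
# 19-argument call for the commented-out 20-argument one, so rng_state is
# passed through as the final Optional[torch.Tensor] the extension expects.
_, _, _, softmax_d = flash_attn_cuda.bwd(
    dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal,
    num_splits, generator, rng_state)
```

Alternatively, if the scgpt version in use exposes a use_fast_transformer flag on the model constructor (the tutorials set one), turning it off should avoid the flash-attn backward path entirely, at the cost of speed.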