facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

Torch JIT breaks when using memory_efficient_attention #406

Open Dango233 opened 1 year ago

Dango233 commented 1 year ago

🐛 Bug

torch.jit.trace breaks with the following error:
`RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_generic`
The output of the op contains an int that can't be traced by JIT.

Command

To Reproduce

torch.jit.trace the module mentioned in
huggingface/diffusers#532
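
For context, here is a minimal sketch of the kind of module that triggers this. It is not the actual pipeline from diffusers#532, just an assumed reduction around `xformers.ops.memory_efficient_attention`:

```python
# Hypothetical minimal repro: tracing any module whose forward calls
# memory_efficient_attention used to hit the int-output error, because the
# underlying custom op also returned its RNG seed/offset as Python ints.
import torch
import xformers.ops as xops

class AttnBlock(torch.nn.Module):
    def forward(self, q, k, v):
        return xops.memory_efficient_attention(q, k, v)

q = k = v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.half)
# Raised "RuntimeError: unsupported output type: int, ..." before the fix.
traced = torch.jit.trace(AttnBlock(), (q, k, v))
```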

Expected behavior

No int outputs, so the module can be JIT traced.

danthe3rd commented 1 year ago

Thanks for reporting :) Should be fixed in https://github.com/facebookresearch/xformers/pull/438

geekinglcq commented 1 year ago

> Thanks for reporting :) Should be fixed in #438

Hello, has this been fixed now?

danthe3rd commented 1 year ago

Hi, the PR was merged, so it should be fixed, yes. Please let us know if you have other issues.

geekinglcq commented 1 year ago

Thank you. I have tried the newest commit of xformers, and the `RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_generic` is resolved.
However, another problem appears. When I run the following code:

```python
import torch

# `pipeline` is a `diffusers.LDMTextToImagePipeline`; inputs are latents,
# a timestep, and text-encoder hidden states.
inputs = (torch.randn(2, 4, 64, 64, dtype=torch.half, device='cuda:6'),
          torch.randn(1, dtype=torch.half, device='cuda:6'),
          torch.randn(2, 77, 768, dtype=torch.half, device='cuda:6'))
with torch.no_grad():
    with torch.autocast("cuda"):
        jit_unet = torch.jit.trace(pipeline.unet, inputs, strict=False)
```

[screenshot of the error attached] But if I run the code above twice, the error disappears by itself 😂 and the pipeline works fine in the later parts.

gigadeplex commented 9 months ago

I'm getting this error too:

```
  return self._op(*args, **kwargs or {})
RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_cutlass
```

roninjiang commented 8 months ago

Got this error too.

xinlin-xiao commented 7 months ago

Got the same errors when I use torch.jit.trace. Any update?

ShijunK commented 3 months ago

I think the original fix (https://github.com/facebookresearch/xformers/pull/438) did work, but the issue was re-introduced later in https://github.com/facebookresearch/xformers/pull/587

Question to @danthe3rd: what's the purpose of the two int output values rng_seed and rng_offset? Is it possible to re-apply the fix from #438?

danthe3rd commented 3 months ago

Oh, this is a regression, right. The purpose of rng_seed and rng_offset is to keep the RNG state for the backward pass. This is useful when there is dropout in the FW pass and we need to mask the exact same values in the BW pass (we don't want to save a "dropout" mask, which would be too expensive). There are also complications due to replaying CUDA Graphs (in which case we want the RNG to be different). I believe we should be able to store these values in a torch.Tensor, or maybe there is a best practice for this sort of issue? @drisspg or @fmassa maybe?
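
For illustration, here is a rough sketch of the replay pattern described above (my own code, not the xformers kernels; `DropoutReplay` is a made-up name): save only the RNG seed, carried as a 0-dim tensor rather than a Python int, and regenerate the same dropout mask in the backward pass.

```python
import torch

class DropoutReplay(torch.autograd.Function):
    """Sketch: dropout that replays its mask in backward from a saved seed
    instead of storing the mask itself."""

    @staticmethod
    def forward(ctx, x, p):
        # Draw a seed and keep it as a 0-dim tensor (not a Python int),
        # so only Tensor values flow through the saved state.
        seed = torch.randint(0, 2**31 - 1, ())
        gen = torch.Generator(device=x.device).manual_seed(int(seed))
        mask = torch.rand(x.shape, generator=gen, device=x.device) >= p
        ctx.save_for_backward(seed)
        ctx.p, ctx.shape, ctx.dev = p, x.shape, x.device
        return x * mask / (1 - p)

    @staticmethod
    def backward(ctx, grad_out):
        (seed,) = ctx.saved_tensors
        # Re-seed with the saved state to rebuild the exact same mask.
        gen = torch.Generator(device=ctx.dev).manual_seed(int(seed))
        mask = torch.rand(ctx.shape, generator=gen, device=ctx.dev) >= ctx.p
        return grad_out * mask / (1 - ctx.p), None

y = DropoutReplay.apply(torch.randn(8, 16, requires_grad=True), 0.1)
y.sum().backward()
```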

drisspg commented 3 months ago

https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L1018-L1055 cc @danthe3rd

danthe3rd commented 3 months ago

Does JIT support SymInt? Because the version in PT outputs SymInt; I'm not exactly sure why. Anyway, we want to rely on the PyTorch version moving forward (with the C++ code moving to the PyTorch repo), so hopefully this can be fixed at the same time.

ShijunK commented 2 months ago

@danthe3rd, which version of torch are you referring to? For torch 2.2.0, I see the type is Tensor for both the seed and the offset:

```
func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
```

https://github.com/pytorch/pytorch/blob/4cb7dd0fc99981062cebf8d5a94e62b48bf78446/aten/src/ATen/native/native_functions.yaml#L14484-L14488

and here is where they are initialized: https://github.com/pytorch/pytorch/blob/d47f715d29d05e28b94c280f15dce097ef3dc7cb/aten/src/ATen/native/transformers/cuda/attention.cu#L978-L982
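
As a quick sanity check (my own sketch, assuming torch 2.2 with a CUDA build), tracing through SDPA with dropout and the memory-efficient backend should not hit the int-output error, since the RNG state comes back as Tensors:

```python
import torch
import torch.nn.functional as F

class SDPA(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.half)
# Force the memory-efficient backend (the one backed by the cutlass/xformers-style kernel).
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False,
                                    enable_mem_efficient=True):
    traced = torch.jit.trace(SDPA(), (q, k, v))
```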

ShijunK commented 2 months ago

> Anyway, we want to rely on the PyTorch version moving forward (with the C++ code moving to the PyTorch repo)

@danthe3rd, are you referring to `at::_scaled_dot_product_efficient_attention`?