pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

forward AD implementation: _scaled_dot_product_efficient_attention #98164

Open enkeejunior1 opened 1 year ago

enkeejunior1 commented 1 year ago

🚀 The feature, motivation and pitch

Hi there,

I encountered an error message asking me to file an issue requesting a feature implementation. The error message is as follows:

NotImplementedError: Trying to use forward AD with _scaled_dot_product_efficient_attention that does not support it because it has not been implemented yet. Please file an issue to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml so that we can prioritize its implementation. Note that forward AD support for some operators require PyTorch to be built with TorchScript and for JIT to be enabled. If the environment var PYTORCH_JIT=0 is set or if the library is not built with TorchScript, some operators may no longer be used with forward AD.

I would appreciate it if you could prioritize the implementation of this feature. Thank you for your help.

Alternatives

No response

Additional context

I ran forward AD on Stable Diffusion with the diffusers library, with dtype=torch.float32 and device=cuda.
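
A minimal sketch of the failing pattern, stripped of the diffusers code (the shapes are arbitrary, and it assumes the memory-efficient backend is the one selected on CUDA for these inputs):

```python
import torch
import torch.nn.functional as F
import torch.autograd.forward_ad as fwAD

q = torch.randn(1, 8, 16, 64, device="cuda")
k = torch.randn(1, 8, 16, 64, device="cuda")
v = torch.randn(1, 8, 16, 64, device="cuda")
tangent = torch.randn_like(q)  # direction for the forward-mode derivative

with fwAD.dual_level():
    dual_q = fwAD.make_dual(q, tangent)
    # On CUDA this dispatches to _scaled_dot_product_efficient_attention,
    # which raises the NotImplementedError above under forward AD.
    out = F.scaled_dot_product_attention(dual_q, k, v)
```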

ShenQianli commented 12 months ago

Same issue here. I'm working on something similar to data distillation for Stable Diffusion models, which involves bi-level optimization and requires computing the gradient of a gradient. So far I don't have a plan B in mind. An official implementation would be an excellent help! Thanks!
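
For context, the computation involved is forward-over-reverse: a forward-mode JVP through a reverse-mode gradient. A minimal sketch with torch.func, using a stand-in quadratic loss:

```python
import torch
from torch.func import grad, jvp

def loss(w):
    # stand-in for the inner objective of the bi-level problem
    return (w ** 2).sum()

w = torch.randn(3)
v = torch.randn(3)  # direction vector

# Hessian-vector product: differentiate the reverse-mode gradient
# in forward mode along direction v.
_, hvp = jvp(grad(loss), (w,), (v,))
```

With a Stable Diffusion UNet in the inner objective, it is the jvp step that runs into the attention error above.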

kulikovvictor commented 11 months ago

+1

mickelliu commented 10 months ago

> requires the computation of 'the gradient of gradient'.

This works in torch 1.13 but not in torch 2.0. What changed?

moktea commented 8 months ago

same issue

moktea commented 8 months ago

> same issue

I've resolved this issue! Try using version 0.11.0 of the diffusers library. I solved it by switching from 0.21.0 to 0.11.0.
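
Presumably the downgrade helps because older diffusers releases use the manual attention path rather than torch 2.0's F.scaled_dot_product_attention. If downgrading is not an option, one possible (untested here) workaround is to keep the newer diffusers and force SDPA onto the math backend, whose decomposition is expected to support forward AD:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64, device="cuda")
k = torch.randn(1, 8, 16, 64, device="cuda")
v = torch.randn(1, 8, 16, 64, device="cuda")

# Disable the flash and memory-efficient kernels so SDPA falls back to
# the math implementation, which decomposes into ops with forward-AD rules.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=True, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
```

Another option may be switching the pipeline's attention processors back to the non-SDPA implementation (e.g. diffusers' AttnProcessor instead of AttnProcessor2_0).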