tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

Feature: Burn equivalent to torch.retain_grad #1802

Open ArthurBrussee opened 3 months ago

ArthurBrussee commented 3 months ago

Feature description

I'm writing some tests to check gradients against a reference implementation. This works great for leaf nodes, but at the moment I can't seem to get gradients of intermediate nodes. PyTorch solves this with my_tensor.retain_grad(), which instructs the autodiff engine to keep that tensor's gradient during the backward pass. An equivalent in Burn would help with this.
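
To illustrate the current behavior, here is a minimal sketch assuming the `Autodiff<NdArray>` backend (exact signatures may vary between Burn versions): the gradient of the leaf `x` is available after `backward()`, but asking for the gradient of the intermediate `y` yields nothing, which is what a retain_grad-style API would address.

```rust
use burn::backend::{Autodiff, NdArray};
use burn::tensor::Tensor;

type B = Autodiff<NdArray>;

fn main() {
    let device = Default::default();

    // Leaf tensor tracked by autodiff.
    let x = Tensor::<B, 1>::from_floats([1.0f32, 2.0, 3.0], &device).require_grad();

    // Intermediate node: its gradient is not kept during the backward pass.
    let y = x.clone() * x.clone();

    let loss = y.clone().sum();
    let grads = loss.backward();

    // Works: gradient w.r.t. the leaf node.
    let dx = x.grad(&grads);
    println!("leaf grad present: {}", dx.is_some()); // true

    // Not available today: the intermediate gradient is not retained.
    let dy = y.grad(&grads);
    println!("intermediate grad present: {}", dy.is_some()); // false
}
```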

Feature motivation

Testing gradients of intermediate activations against a reference implementation.

Suggest a Solution

An exact equivalent à la my_tensor.retain_grad(), or, alternatively, make my_tensor.require_grad() valid on non-leaf nodes (it currently panics). The semantics of retain/require are slightly different, but the use cases for retained-but-only-if-calculated gradients don't seem that significant to me... not sure!
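
A rough sketch of how the two options could look at the call site; this is hypothetical API, nothing below exists in Burn today:

```rust
// Hypothetical usage only; neither call currently works on non-leaf tensors.

// Option A: PyTorch-style retain_grad(), keeping the intermediate's gradient
// if and when it is computed during the backward pass.
let y = (x.clone() * x.clone()).retain_grad();

// Option B: allow require_grad() on non-leaf nodes (today this panics),
// forcing the gradient for `y` to always be tracked.
// let y = (x.clone() * x.clone()).require_grad();

let grads = y.clone().sum().backward();
let dy = y.grad(&grads); // would be Some(...) under either option
```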

Apogeum12 commented 1 month ago

Any updates on this feature? If it's available in Burn now, how can it be used? I'm working on a GAN, and it would be nice to have this feature, like in PyTorch. Currently I have to run the generator twice in one iteration to generate the fake data; otherwise, when I clone the fake data, the generator loss grows to NaN after a couple of iterations.