deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

scale gradient for backward pass #521

Closed: enpasos closed this issue 3 years ago

enpasos commented 3 years ago

Question or maybe Enhancement

I'm missing a feature to scale the gradient for the backward pass (as used, e.g., in MuZero), something like tensor * scale + stop_gradient(tensor) * (1 - scale). I'm not sure whether the feature is missing or I'm simply not seeing the proper way to do it.

Workaround

I worked around it by adding an additional forward pass, keeping the resulting tensors as outputs and feeding them back into the training forward pass as "stop_gradient(tensor)" inputs. This works functionally, but comes at the cost of:

  1. higher memory consumption on the training device (memory is scarce on my GPU)
  2. lower performance
  3. higher complexity

roywei commented 3 years ago

You can use block.getParameters() to get the parameters, getArray() to get each parameter's value, and then getGradient() to access its gradient. You can update the gradient in place (e.g. grad.muli(scale)).

Not sure if this is what you want; if not, please provide some Python code in TF, PyTorch or MXNet so we can take a look. Thanks!

enpasos commented 3 years ago

Thank you very much for your reply. I think the methods you mentioned are useful for some use cases. However, for use cases where the node in question is passed through many times in the graph, I don't see how those methods lead to a simple solution.
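To illustrate why scaling an already accumulated parameter gradient is not the same thing, here is a toy scalar sketch (plain Java, not DJL; the unroll, the loss and the names Dual, RecurrentUnroll and lossGradient are hypothetical). A shared weight w is applied twice in a recurrent unroll, h1 = w * h0 and h2 = w * h1', with a loss term at each step, L = h1 + h2, where h1' is h1 passed through the scale_gradient function from the MuZero pseudocode. Derivatives are computed with forward-mode dual numbers, where stop_gradient simply zeroes the derivative; this yields the same dL/dw as a backward pass over the same graph.

```java
/** Dual number: a value and its derivative with respect to one chosen input (forward-mode AD). */
final class Dual {
    final double value;
    final double deriv;

    Dual(double value, double deriv) {
        this.value = value;
        this.deriv = deriv;
    }

    Dual add(Dual o) {
        return new Dual(value + o.value, deriv + o.deriv);
    }

    Dual mul(Dual o) {
        // Product rule: (uv)' = u'v + uv'
        return new Dual(value * o.value, deriv * o.value + value * o.deriv);
    }

    Dual mul(double c) {
        return new Dual(value * c, deriv * c);
    }

    /** Same value, zero derivative: nothing propagates through this node. */
    Dual stopGradient() {
        return new Dual(value, 0.0);
    }

    /** tensor * scale + stop_gradient(tensor) * (1 - scale) */
    Dual scaleGradient(double scale) {
        return mul(scale).add(stopGradient().mul(1.0 - scale));
    }
}

/** Hypothetical two-step unroll sharing one weight w: h1 = w*h0, h2 = w*h1', L = h1 + h2. */
final class RecurrentUnroll {
    /** Returns dL/dw, where h1' = scaleGradient(h1, scale) is fed into the second step. */
    static double lossGradient(double w, double h0, double scale) {
        Dual wd = new Dual(w, 1.0); // seed: differentiate with respect to w
        Dual h = new Dual(h0, 0.0); // the initial state does not depend on w
        Dual h1 = wd.mul(h);
        Dual h2 = wd.mul(h1.scaleGradient(scale));
        return h1.add(h2).deriv;
    }
}
```

With w = 2 and h0 = 1, lossGradient returns 5.0 for scale = 1 (no scaling) and 4.0 for scale = 0.5. Scaling the accumulated 5.0 by 0.5 afterwards would give 2.5 instead, because that also halves the two terms (h0 and w * h0) that never pass through the scaled hidden state; only the term flowing back through h1' should shrink.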

I would like to give some more information about the use case I am looking at:

MuZero use case: a Java implementation of MuZero based on DJL (with MXNet as the engine).

Need: The MuZero paper comes with Python pseudocode (see the supplementary data). The pseudocode uses this function

    from typing import Any
    import tensorflow as tf

    def scale_gradient(tensor: Any, scale):
        """Scales the gradient for the backward pass."""
        return tensor * scale + tf.stop_gradient(tensor) * (1 - scale)

to scale down the gradient that is backpropagated through the recurrently called dynamics function.
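The pseudocode leaves the forward value unchanged (tensor * scale + tensor * (1 - scale) = tensor), while only the scale fraction of the incoming gradient reaches tensor, because the stop_gradient branch contributes nothing in the backward pass. The following toy reverse-mode autodiff on scalars sketches that behavior; it is plain Java, not DJL or MXNet, and the class Scalar and its methods are hypothetical names chosen for illustration.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/** A scalar node in a tiny reverse-mode autodiff graph (illustration only, not DJL). */
final class Scalar {
    final double value;
    double grad = 0.0;                 // accumulated d(output)/d(this) after backward()
    private final Scalar[] parents;    // inputs this node was computed from
    private final double[] localGrads; // d(this)/d(parents[i])

    private Scalar(double value, Scalar[] parents, double[] localGrads) {
        this.value = value;
        this.parents = parents;
        this.localGrads = localGrads;
    }

    /** A leaf node (input or constant). */
    static Scalar of(double value) {
        return new Scalar(value, new Scalar[0], new double[0]);
    }

    Scalar add(Scalar o) {
        return new Scalar(value + o.value, new Scalar[] {this, o}, new double[] {1.0, 1.0});
    }

    Scalar mul(Scalar o) {
        return new Scalar(value * o.value, new Scalar[] {this, o}, new double[] {o.value, value});
    }

    Scalar mul(double c) {
        return new Scalar(value * c, new Scalar[] {this}, new double[] {c});
    }

    /** Same value, but a new leaf: the backward pass cannot reach this node's inputs. */
    Scalar stopGradient() {
        return Scalar.of(value);
    }

    /** tensor * scale + stop_gradient(tensor) * (1 - scale), as in the MuZero pseudocode. */
    Scalar scaleGradient(double scale) {
        return this.mul(scale).add(this.stopGradient().mul(1.0 - scale));
    }

    /** Backward pass from this node: visit nodes so each is processed after all its consumers. */
    void backward() {
        List<Scalar> order = new ArrayList<>();
        topoSort(this, new HashSet<>(), order);
        grad = 1.0;
        for (int i = order.size() - 1; i >= 0; i--) {
            Scalar node = order.get(i);
            for (int p = 0; p < node.parents.length; p++) {
                node.parents[p].grad += node.localGrads[p] * node.grad; // chain rule
            }
        }
    }

    /** Post-order DFS: every node is appended after all of its inputs. */
    private static void topoSort(Scalar node, Set<Scalar> seen, List<Scalar> order) {
        if (!seen.add(node)) {
            return;
        }
        for (Scalar p : node.parents) {
            topoSort(p, seen, order);
        }
        order.add(node);
    }
}
```

For example, with x = Scalar.of(3.0), y = x.scaleGradient(0.5) and z = y.mul(y), calling z.backward() leaves y.value at 3.0, but x.grad becomes 3.0 instead of the unscaled 6.0. In DJL the same composition would presumably be in.mul(scale).add(stopGradient(in).mul(1 - scale)) on top of a stop-gradient operator such as the helper posted later in this thread; treat that line as an untested assumption.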

Support in the frameworks: In TensorFlow, stop_gradient is available in the Python API. As I am using MXNet, I searched for support in MXNet and found this.

enpasos commented 3 years ago

I think I found the function in the MXNet Python API: BlockGrad

enpasos commented 3 years ago

It would be great to have it in Java, too.

enpasos commented 3 years ago

I'll test this:

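    import ai.djl.mxnet.engine.MxNDManager;
    import ai.djl.mxnet.engine.MxOpParams;
    import ai.djl.ndarray.NDArray;

    /** Returns an array with the same values as in, through which no gradient flows back. */
    public static NDArray stopGradient(NDArray in) {
        MxNDManager manager = (MxNDManager) in.getManager(); // requires the MXNet engine
        MxOpParams params = new MxOpParams(); // stop_gradient takes no parameters
        return manager.invoke("stop_gradient", in, params);
    }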
enpasos commented 3 years ago

The stopGradient helper works well for me: I was able to remove my workaround and thereby freed enough GPU memory to double the batch size.

As this is general functionality (used, e.g., in MuZero), it would be very useful to add it to the Java API as well, e.g. to the NDArray interface and its implementations.