Closed: enpasos closed this issue 3 years ago
You can use block.getParameters() to get the parameters, getArray() to get the parameter value, and getGradient() to access the gradient. You can then do an in-place update on the gradient value (e.g. grad.muli(scale)).
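Put together, a minimal sketch of that approach in DJL (the helper name scaleGradients is just for illustration; getParameters(), getArray(), getGradient() and muli() are the DJL calls mentioned above), applied after backward() and before the optimizer step:

import ai.djl.ndarray.NDArray;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.util.Pair;

// Illustrative helper: scale every parameter gradient of a block in place.
public static void scaleGradients(Block block, float scale) {
    for (Pair<String, Parameter> pair : block.getParameters()) {
        NDArray grad = pair.getValue().getArray().getGradient();
        grad.muli(scale); // in-place update of the gradient values
    }
}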
Not sure if this is what you want; if not, please provide some Python code in TF, PyTorch or MXNet so we can take a look. Thanks!
Thank you very much for your reply. I think the methods you mentioned are useful for some use cases. But for use cases where the node in question is passed through many times on the graph, I do not see a clever way in which those methods lead to a simple solution.
I would like to give some more information about the use case I am looking at:
MuZero use case: a Java implementation of MuZero based on DJL (with MXNet as the engine).
Need: The MuZero paper comes with Python pseudocode (see the supplementary data). The pseudocode uses this function
def scale_gradient(tensor: Any, scale):
  """Scales the gradient for the backward pass."""
  return tensor * scale + tf.stop_gradient(tensor) * (1 - scale)
to scale down the gradients flowing back through the recurrently called dynamics function.
Support in the frameworks
In TensorFlow I see the function stop_gradient in the Python API. As I am using MXNet, I searched for support in MXNet and found this.
It would be great to have it in Java, too.
I'll test this:
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.MxOpParams;
import ai.djl.ndarray.NDArray;

// Invoke the MXNet stop_gradient operator directly through the MXNet engine classes.
public static NDArray stopGradient(NDArray in) {
    MxNDManager manager = (MxNDManager) in.getManager();
    MxOpParams params = new MxOpParams();
    return manager.invoke("stop_gradient", in, params);
}
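For the scale_gradient helper from the MuZero pseudocode, a minimal sketch of how it could be composed on top of stopGradient (the method scaleGradient is my own illustration; mul() and add() are existing NDArray methods):

// The forward value is unchanged; only the fraction `scale` of the gradient
// flows back through `tensor` in the backward pass.
public static NDArray scaleGradient(NDArray tensor, float scale) {
    return tensor.mul(scale).add(stopGradient(tensor).mul(1 - scale));
}

In the MuZero pseudocode this is applied, for example, to the hidden state after each unroll step of the dynamics function.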
The stopGradient works well for me: I could remove my workaround and therefore regained GPU memory ... enough to double the batch size.
As this is general functionality (used e.g. in MuZero), it would be very useful to add it to the Java API as well, e.g. in the NDArray interface and its implementations.
Question or maybe Enhancement
I'm missing a feature to scale the gradient for the backward pass (as used e.g. in MuZero) ... something like tensor * scale + stop_gradient(tensor) * (1 - scale). I'm not sure if the feature is missing or if I'm simply not seeing the proper way to do it.
Workaround
I worked around it by adding an additional forward pass, keeping the tensors as outputs, and feeding them back in on the training forward pass as "stop_gradient(tensor)" inputs. This works functionally, but comes at the cost of an extra forward pass and additional GPU memory.