Open matteoguarrera opened 6 days ago
Dear Matteo,

This line replaces the old params with those passed as arguments (and therefore creates gradients for `jax.grad`). The `copy` method here simply overrides the dict. In newer versions of Flax, this line should be equivalent to:

```python
loss = self.mod.apply(dict(variable, params=params), ϕ, method="loss")
```
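For context, the override that `FrozenDict.copy` performed can be sketched with plain dicts, which newer Flax versions use for variable collections by default. The names below are illustrative, not taken from the repo:

```python
# Stand-in for a Flax variables dict with two collections.
variables = {"params": {"w": 1.0}, "stats": {"n": 3}}
# New parameter values, e.g. the argument differentiated by jax.grad.
new_params = {"w": 2.0}

# Old Flax: variables.copy({"params": new_params}) returned a new
# FrozenDict with the "params" collection replaced. With plain dicts,
# the same shallow override is a dict merge:
merged = {**variables, "params": new_params}

# The original is untouched; only "params" is swapped in the copy.
print(merged)     # {'params': {'w': 2.0}, 'stats': {'n': 3}}
print(variables)  # {'params': {'w': 1.0}, 'stats': {'n': 3}}
```

This is why passing `dict(variable, params=params)` to `apply` reproduces the old `copy` behavior: both produce a new top-level mapping whose `params` entry is the argument being differentiated.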
I am not quite familiar with Flax, and this command has been deprecated in later versions, so it lacks documentation. How are you using this `copy` method? Is it for stopping gradients?
https://github.com/ASK-Berkeley/Neural-Spectral-Methods/blob/090e7a173f27734dbae5ec479a410ca1748981af/src/train.py#L18
Do you know how to make it compatible with the latest version of Flax? Thank you