ASK-Berkeley / Neural-Spectral-Methods

[ICLR 2024] Neural Spectral Methods: Self-supervised learning in the spectral domain.
https://arxiv.org/abs/2312.05225
MIT License

flax.core.copy #2

Open matteoguarrera opened 6 days ago

matteoguarrera commented 6 days ago

I am not very familiar with flax, and this command has been deprecated in later versions, so it lacks documentation. How are you using this copy method? Is it for stopping gradients?

https://github.com/ASK-Berkeley/Neural-Spectral-Methods/blob/090e7a173f27734dbae5ec479a410ca1748981af/src/train.py#L18

Do you know how to make it compatible with the latest version of flax? Thank you!

mrlazy1708 commented 6 days ago

Dear Matteo,

This line replaces the old params with those passed as arguments, so that jax.grad differentiates the loss with respect to them. It does not stop gradients: the copy method here simply overrides the params entry of the variables dict.
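For readers unfamiliar with the idiom, the override can be illustrated in plain Python: dict(mapping, key=value) builds a new dict holding the mapping's top-level entries, with key added or replaced. A flax FrozenDict is a Mapping, so dict() accepts it too. This is a minimal sketch; the collection names and values below are hypothetical placeholders, not the repository's actual variables.

```python
# Hypothetical variables mapping with a "params" collection and one other
# collection; names and values are placeholders for illustration only.
old_variables = {"params": {"w": 1.0}, "batch_stats": {"mean": 0.0}}
new_params = {"w": 2.0}

# dict(mapping, key=value) makes a new dict with a shallow copy of the
# mapping's top-level entries, then adds or replaces "params".
variables = dict(old_variables, params=new_params)

print(variables["params"])       # {'w': 2.0}
print(variables["batch_stats"])  # {'mean': 0.0}
print(old_variables["params"])   # {'w': 1.0}  (original is left unchanged)
```

Because the original mapping is never mutated, this works equally well on immutable FrozenDicts and inside traced JAX functions.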

In newer versions of flax, this line should be equivalent to:

loss = self.mod.apply(dict(variable, params=params), ϕ, method="loss")