CZXIANGOvO opened this issue 2 months ago (status: Open)
If it's truly the same model you're running, this is surprising. Given the magnitude of the difference, though, I suspect the models or optimizers differ in important ways: for example, maybe the precise definition of "learning rate" differs between the two implementations. I don't know either optax or pytorch well enough to guess where that difference might lie, but if it's important to you to debug these differences in the implementations, that's probably where I'd start.
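For what it's worth, one quick way to test this hypothesis is to run a few identical update steps in both frameworks and compare the resulting parameters. The sketch below uses SGD with momentum purely as an illustration; the hyperparameter values and gradients are assumptions, not taken from the original script.

```python
# A minimal sketch (not the reporter's actual setup) for checking whether two
# optimizers implement the same update rule: run a few steps on identical
# parameters and gradients in both frameworks and compare the results.
import jax.numpy as jnp
import numpy as np
import optax
import torch

lr, momentum = 0.1, 0.9  # illustrative values, not from the original script
w0 = np.array([1.0, -2.0, 3.0], dtype=np.float32)
g = np.array([0.5, 0.25, -1.0], dtype=np.float32)  # a fixed, shared gradient

# optax: SGD with momentum
opt = optax.sgd(learning_rate=lr, momentum=momentum)
params = jnp.asarray(w0)
state = opt.init(params)

# PyTorch: SGD with momentum on the same initial parameters
p = torch.tensor(w0, requires_grad=True)
torch_opt = torch.optim.SGD([p], lr=lr, momentum=momentum)

for step in range(3):
    updates, state = opt.update(jnp.asarray(g), state, params)
    params = optax.apply_updates(params, updates)

    p.grad = torch.tensor(g)
    torch_opt.step()

    diff = np.abs(np.asarray(params) - p.detach().numpy()).max()
    print(f"step {step}: max abs param diff = {diff:.3e}")

# If the update rules match, the difference should stay at float-precision
# level; a growing gap points to differing definitions of learning rate,
# momentum, dampening, or weight decay between the two implementations.
```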
We're using the same model.
Sure, but what I'm suggesting is that you may not be using the same optimizer.
Description
Please specify cuda:0 at the very beginning of the script.
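A hedged sketch of one way to read this instruction, assuming it means pinning both frameworks to the same GPU before anything else runs:

```python
# Assumed interpretation: restrict both JAX and PyTorch to GPU 0 before
# either framework initializes, so both runs use the same device.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # must be set before importing jax/torch

import jax
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(jax.devices())  # should list only the first GPU
print(device)
```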
System info (python version, jaxlib version, accelerator, etc.)
Download the code: https://drive.google.com/file/d/1H8uPgPdslVpizmSsif6oK4ey2e-oum9x/view?usp=sharing