I have a class that contains an nnx.Module and trains it. I try to save and restore the model by accessing this attribute, but as the title says, when I restore the model its loss is as bad as that of a randomly initialized model.
I have no way to describe the problem as anything other than what the title says: I train a model, halve the loss from its initialization, save the model using the instructions in the tutorial on saving and loading models (or the instructions given here https://github.com/google/flax/issues/4383, or the instructions on the Orbax website), and then restore it in another file and re-run the training loop. However, at the final step my loss is the same as the loss I got at initialization. Note that the parameters are not the ones I had at initialization, but completely different ones that are equally poor when evaluated on my objective function.
I have attached the code for my model, my training file, and my loading function.
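For reference, here is a minimal sketch of the save/restore pattern I am trying to follow, based on the Flax NNX checkpointing tutorial; `MyModel` and the checkpoint path below are placeholders, not my actual code:

```python
import pathlib

import orbax.checkpoint as ocp
from flax import nnx


class MyModel(nnx.Module):  # placeholder for my real module
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 4, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)


ckpt_dir = pathlib.Path('/tmp/my-checkpoints')  # Orbax wants an absolute path
checkpointer = ocp.StandardCheckpointer()

# Saving: split the trained model into graphdef + state, save the state
model = MyModel(rngs=nnx.Rngs(0))
_, state = nnx.split(model)
checkpointer.save(ckpt_dir / 'state', state)

# Restoring: build an abstract model, restore the state into it, then merge
abstract_model = nnx.eval_shape(lambda: MyModel(rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
restored_state = checkpointer.restore(ckpt_dir / 'state', abstract_state)
restored_model = nnx.merge(graphdef, restored_state)
```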
Model file:
Training file:
Load function: