albertz opened this issue 2 years ago.
I have no clear opinion on this yet. The most important thing for me is that there is a crash if something is not as expected. I had a case with one of our student assistants (Hiwi) where decoding ran without any warning or crash, although he had loaded a network whose parameter shapes did not match the ones given in the config. But as all loaded parameters were consistent with each other, the model still worked fine. This can be very dangerous, and with returnn_common I think this is more likely to happen, as you make code changes in the same files.
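To make the failure described above crash instead of silently running, the loader could compare the parameter shapes implied by the config with those stored in the checkpoint. This is a minimal sketch, assuming parameters are represented as plain dicts from name to shape tuple. `check_checkpoint` and `CheckpointMismatchError` are hypothetical names, not part of RETURNN or returnn_common.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


class CheckpointMismatchError(Exception):
    """Raised when a checkpoint does not match the parameters the model expects."""


def check_checkpoint(expected: Dict[str, Shape], loaded: Dict[str, Shape]) -> None:
    """Fail loudly instead of silently running with mismatched parameters.

    `expected` maps param names to the shapes implied by the current config/model code,
    `loaded` maps param names to the shapes found in the checkpoint.
    """
    errors = []
    missing = sorted(set(expected) - set(loaded))
    unexpected = sorted(set(loaded) - set(expected))
    if missing:
        errors.append(f"missing in checkpoint: {missing}")
    if unexpected:
        errors.append(f"unexpected in checkpoint: {unexpected}")
    # Names present on both sides must also agree in shape.
    for name in sorted(set(expected) & set(loaded)):
        if expected[name] != loaded[name]:
            errors.append(
                f"shape mismatch for {name!r}: config expects {expected[name]}, "
                f"checkpoint has {loaded[name]}"
            )
    if errors:
        raise CheckpointMismatchError("; ".join(errors))
```

For example, `check_checkpoint({"enc/W": (512, 1024)}, {"enc/W": (256, 1024)})` raises instead of letting decoding continue with the checkpoint's shapes (the param name here is made up).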
PyTorch `nn.Module` has a `_version` class attribute, which is supposed to be increased whenever the module's parameters change, to better handle old checkpoints; the version is saved in the checkpoint. See here. Do we want something similar?
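A minimal sketch of what something similar could look like, assuming a checkpoint is a plain dict with a `"version"` entry next to the params. Checkpoints without a version are treated as version 1, since our checkpoints currently store none. `VersionedModule`, `upgrade_params` and the `W` to `weight` rename are hypothetical illustrations. In PyTorch itself, the corresponding hook is `nn.Module._load_from_state_dict`, which receives the stored version through its `local_metadata` argument.

```python
from typing import Any, Dict


class VersionedModule:
    """Module that records a version for its parameter format (idea of PyTorch's _version)."""

    _version = 1

    @classmethod
    def upgrade_params(cls, params: Dict[str, Any], from_version: int) -> Dict[str, Any]:
        """Convert params saved with `from_version` to the format of `from_version + 1`."""
        raise NotImplementedError(f"{cls.__name__}: no conversion from version {from_version}")

    @classmethod
    def save(cls, params: Dict[str, Any]) -> Dict[str, Any]:
        # Store the current format version next to the params.
        return {"version": cls._version, "params": dict(params)}

    @classmethod
    def load(cls, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
        version = checkpoint.get("version", 1)  # no version stored -> assume version 1
        params = dict(checkpoint["params"])
        if version > cls._version:
            raise ValueError(
                f"checkpoint version {version} is newer than {cls.__name__} version {cls._version}"
            )
        # Upgrade one version step at a time until the current format is reached.
        while version < cls._version:
            params = cls.upgrade_params(params, version)
            version += 1
        return params


class Linear(VersionedModule):
    # Hypothetical history: version 1 called the weight "W", version 2 calls it "weight".
    _version = 2

    @classmethod
    def upgrade_params(cls, params: Dict[str, Any], from_version: int) -> Dict[str, Any]:
        if from_version == 1:
            params = dict(params)
            params["weight"] = params.pop("W")
            return params
        return super().upgrade_params(params, from_version)
```

Each upgrade step only needs to know about a single version transition, so the steps compose when a checkpoint is several versions behind.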
If so, how would we do it? Currently our checkpoint does not store any such information. Furthermore, the computation graph is created on the RETURNN side, which does not know about the `nn.Module` anymore.

RETURNN so far made sure that all checkpoints stayed compatible for a given net dict. Of course there never were any guarantees when the user changed the net dict, so this was completely up to the user. But there was also no real error checking. For a given layer, however, RETURNN handled conversions automatically. It did not use a version number for this, but it was always able to infer the format from the parameter names. The same logic was also used when the user switched e.g. from NativeLSTM to StandardLSTM. Otherwise, it encoded the version inside the parameter name itself, see
`LayerBase.batch_norm`.

Now that we have basically shifted all param handling over to RETURNN-common, when we change the parameters in some `Module`, it would not automatically work: RETURNN would not be able to handle this automatically. So we need similar conversion handling in RETURNN-common. So far this was never needed, but we will need it at some point in the future.
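Without any stored version, the RETURNN-style approach is to infer the format from which parameter names exist, optionally with a version tag inside the name. The sketch below illustrates this for a hypothetical normalization module whose old format stored a variance (`variance`) and whose new format stores a standard deviation under a tagged name (`v2_std`). These names and formats are made up for illustration and are not RETURNN's actual `batch_norm` parameters.

```python
import math
from typing import Dict, List

Params = Dict[str, List[float]]


def load_norm_params(checkpoint: Params, prefix: str) -> Params:
    """Return the (hypothetical) normalization params of `prefix` in the current format.

    No version number is stored; the format is inferred from the names present:
      old format: "<prefix>/variance" -> per-feature variance
      new format: "<prefix>/v2_std"   -> per-feature standard deviation, version tag in the name
    """
    new_name = f"{prefix}/v2_std"
    old_name = f"{prefix}/variance"
    if new_name in checkpoint:  # already in the current format
        return {new_name: list(checkpoint[new_name])}
    if old_name in checkpoint:  # old format: convert variance to standard deviation
        return {new_name: [math.sqrt(v) for v in checkpoint[old_name]]}
    raise KeyError(f"{prefix}: no known param format among {sorted(checkpoint)}")
```

Because the conversion is keyed purely on names, the same function handles both old and new checkpoints, and anything it does not recognize raises instead of loading silently.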
How would we do it? In RETURNN, when params are loaded but some param name is not found in the checkpoint, the default param loading logic fails, and RETURNN then falls back to `CustomCheckpointLoader`, which tries some automatic conversions. See `TFNetwork.load_params_from_file`. So either we extend that by adding some API through which a user of RETURNN-common can add conversion rules in an easy way (this should be straightforward), or we need to replicate the whole logic on the RETURNN-common side. It's also not clear whether we really need the version or not.
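The first option (an API for user-defined conversion rules that is only consulted when default loading fails) could look roughly like the following sketch. It assumes each module's params are a plain dict keyed by name. `register_param_conversion`, `load_module_params` and the `Conv` rename from `kernel` to `filter` are hypothetical and do not exist in RETURNN or returnn_common.

```python
from typing import Callable, Dict, List, Set

Params = Dict[str, object]
Converter = Callable[[Params], Params]

# Hypothetical registry: module type name -> conversion rules registered by users.
_CONVERTERS: Dict[str, List[Converter]] = {}


def register_param_conversion(module_type: str) -> Callable[[Converter], Converter]:
    """Decorator registering a rule that converts an old param dict of `module_type`."""
    def decorator(fn: Converter) -> Converter:
        _CONVERTERS.setdefault(module_type, []).append(fn)
        return fn
    return decorator


def load_module_params(module_type: str, expected: Set[str], checkpoint: Params) -> Params:
    """Try default (exact name) loading first; use registered rules only if names are missing."""
    if expected.issubset(checkpoint):
        return {name: checkpoint[name] for name in expected}
    for convert in _CONVERTERS.get(module_type, []):
        converted = convert(dict(checkpoint))  # rules get a copy, the checkpoint stays untouched
        if expected.issubset(converted):
            return {name: converted[name] for name in expected}
    missing = sorted(expected.difference(checkpoint))
    raise KeyError(f"{module_type}: params {missing} not found and no conversion rule applies")


@register_param_conversion("Conv")
def _conv_kernel_to_filter(params: Params) -> Params:
    # Hypothetical rename: an older Conv module stored its weights as "kernel".
    if "kernel" in params:
        params["filter"] = params.pop("kernel")
    return params
```

With this, `load_module_params("Conv", {"filter", "bias"}, {"kernel": w, "bias": b})` succeeds through the registered rule. Keeping the exact-name path first means a matching checkpoint never touches the rules, and a checkpoint that no rule can fix still raises rather than loading partially.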