Closed navid-mahmoudian closed 3 years ago
Hi Navid, thanks for reporting this! We've not encountered such errors.
You don't need to run net.update to continue training the network; this will only update the internal parameters needed for the actual entropy coding, not the training.
But I'll take a look, because this should not crash. It looks like you found a good explanation, and your fix looks right. I'll run some tests and push a fix soon if need be.
Thanks!
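For readers who land here looking for the "load" half of the checkpoint round trip, here is a minimal sketch of resuming training in plain PyTorch. It is not taken from the CompressAI examples: nn.Linear stands in for a real model, and the checkpoint keys ("epoch", "state_dict", "optimizer") are assumptions chosen for illustration. In line with the reply above, no update() call is needed just to keep training.

```python
import os
import tempfile

import torch
from torch import nn

# Stand-in for a compression model; any nn.Module is restored the same way.
net = nn.Linear(4, 2)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# One training step so that the optimizer has state worth saving.
loss = net(torch.ones(1, 4)).pow(2).mean()
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth.tar")

    # Save at the end of an epoch. The key names are illustrative assumptions.
    torch.save(
        {
            "epoch": 7,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        path,
    )

    # Later, possibly in a different run: rebuild the objects, then restore them.
    resumed_net = nn.Linear(4, 2)
    resumed_optimizer = torch.optim.Adam(resumed_net.parameters(), lr=1e-3)

    checkpoint = torch.load(path, map_location="cpu")
    resumed_net.load_state_dict(checkpoint["state_dict"])
    resumed_optimizer.load_state_dict(checkpoint["optimizer"])
    start_epoch = checkpoint["epoch"] + 1  # continue from the next epoch
```

If the real training script also keeps a second optimizer or a learning-rate scheduler, their state dicts would be saved and restored in the same way.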
Hello Jean, thank you for your reply. I didn't fully understand this part of your answer:
"You don't need to run net.update to continue training the network; this will only update the internal parameters needed for the actual entropy coding, not the training."
Let's say I ran my code for one day and it trained up to epoch T. A few days later, I want to continue training from epoch T (that is, I am not loading the states in the same execution of the code that saved them). So, to me, to continue training we have to update the internal parameters for the entropy coding. Am I missing something? Maybe the purpose of net.update is not clear to me.
In fact, I took the idea of calling net.update after digging into your CompressAI/compressai/utils/update_model/__main__.py script. There, I noticed that an argument controls whether to update the model's CDF parameters or not:
if not args.no_update:
    net.update(force=True)
Thank you again for your time.
Hi Navid,
update() only changes the internal buffers needed for the actual entropy coding (with the range ANS or a range coder), which are not used during training. The main parameters (the ones with gradients) are loaded by the usual load_state_dict. You can find more information and details in the original TensorFlow documentation: https://tensorflow.github.io/compression/docs/entropy_bottleneck.html.
We also provide some details here: https://interdigitalinc.github.io/CompressAI/tutorial_train.html#updating-the-model but we might need to update it with more details. Feel free to contribute to the documentation if you have improvements :-)
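To make the distinction between trained parameters and coding buffers concrete, here is a toy sketch. ToyEntropyModel, its scale parameter and its cdf_table buffer are hypothetical names invented for this illustration; this is not the CompressAI implementation, only the same pattern in miniature. The forward pass (and therefore training) touches only the parameter, while the buffer changes only when update() is called.

```python
import torch
from torch import nn


class ToyEntropyModel(nn.Module):
    """Hypothetical toy model: one learnable parameter, one derived buffer."""

    def __init__(self):
        super().__init__()
        # Learned by gradient descent.
        self.scale = nn.Parameter(torch.ones(1))
        # Derived table: no gradient, only needed for actual entropy coding.
        self.register_buffer("cdf_table", torch.zeros(5))

    def forward(self, x):
        # Training-time likelihood: uses the parameter, never the buffer.
        upper = torch.sigmoid((x + 0.5) / self.scale)
        lower = torch.sigmoid((x - 0.5) / self.scale)
        return upper - lower

    def update(self):
        # Recompute the coding table from the current parameter value.
        with torch.no_grad():
            samples = torch.arange(
                -2, 3, device=self.scale.device, dtype=self.scale.dtype
            )
            self.cdf_table.copy_(torch.sigmoid(samples / self.scale))


model = ToyEntropyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One training step: only the parameter receives a gradient.
x = torch.tensor([0.0, 1.0, -2.0])
loss = -torch.log(model(x)).mean()
loss.backward()
optimizer.step()

table_before = model.cdf_table.clone()  # untouched by training
model.update()
table_after = model.cdf_table.clone()  # refreshed from the parameter
```

In this toy, training runs fine without ever calling update(); the table only matters once you want to compress or decompress with the trained model.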
Navid, I've pushed a simple fix here: https://github.com/InterDigitalInc/CompressAI/commit/6c08ced208482c6a1646c0ea1ab40350a6056572. This should fix the issue you encountered.
Thank you very much Jean.
Hello, First of all, thank you for providing this very nice library. I ran into an error that I wanted to share with you so that it can be fixed. Let's say I have a model which has been trained for several epochs and is now saved using the save_checkpoint function, as you do in the example CompressAI/examples/train.py. Since you haven't mentioned the load part there, in order to load this checkpoint and continue training, I do as follows:
First, I wanted to be sure that I am doing it correctly (for example, I use net.update with force=True to update the entropy model parameters, etc.).
Second, if I do so, I get an error in the update function of class EntropyBottleneck(EntropyModel).
This is because
To fix this, you should change
samples = torch.arange(max_length)
to
samples = torch.arange(max_length, device=pmf_start.device)
in your def update(self, force=False).
Since I am not using exactly the code explained above (I changed your original code a little, so to simplify the explanation I used your simplified train.py code), I would first like you to verify whether you have this problem on your side. Then, if that is the case and you fix it, I think a similar problem may exist elsewhere in the code as well.
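The proposed device fix can be sketched in isolation as follows. The values of pmf_start and max_length are made-up stand-ins, and the broadcasting is a simplified version of what an update() method might do; only the torch.arange(..., device=...) call mirrors the change suggested above. Without the device argument, torch.arange creates its tensor on the CPU, so combining it with a tensor that lives on a GPU raises a device-mismatch error.

```python
import torch

# Use the GPU when there is one; the pattern is identical on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in values: in the real code these are derived from the model.
pmf_start = torch.tensor([-3.0, -1.0], device=device)
max_length = 4

# Create the sample indices on the same device as pmf_start.
samples = torch.arange(max_length, device=pmf_start.device)

# One row of sample positions per channel, starting at that channel's pmf_start.
grid = samples[None, :] + pmf_start[:, None]
```

Reading the device from an existing tensor, instead of hard-coding "cuda", keeps the code working whether the model sits on the CPU or on any GPU.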
Again I wanted to say thank you for this wonderful library you have provided. Best, Navid