InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License
1.19k stars 232 forks

RuntimeError in update the bottleneck parameters #5

Closed: navid-mahmoudian closed this issue 3 years ago

navid-mahmoudian commented 3 years ago

Hello, First of all, thank you for providing this very nice library. I ran into an error that I want to share with you so it can be fixed. Say I have a model that has been trained for several epochs and saved with the save_checkpoint function, as in CompressAI/examples/train.py. Since the example does not show the loading part, to load this checkpoint and continue training I do the following:

device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
net = AutoEncoder()
net = net.to(device)
optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)
aux_optimizer = optim.Adam(net.aux_parameters(), lr=args.aux_learning_rate)
criterion = RateDistortionLoss(lmbda=args.lmbda)
checkpoint = torch.load("checkpoint.pth.tar")
net.load_state_dict((checkpoint["net_state_dict"]))
net.update(force=True)  # update the model CDFs parameters.

First, I wanted to make sure I am doing this correctly (for example, I call net.update with force=True to update the entropy model parameters, etc.).
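(As a side note, to fully resume training one would typically also restore the optimizer states saved in the checkpoint. A minimal self-contained sketch; the model and the checkpoint key names are placeholders, not necessarily the exact keys used by save_checkpoint in train.py:)

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

# Placeholder model and keys, standing in for the real AutoEncoder
# and whatever save_checkpoint actually stores.
net = nn.Linear(4, 4)
optimizer = optim.Adam(net.parameters(), lr=1e-4)

# Save, as a save_checkpoint-style helper would (to an in-memory buffer here).
buffer = io.BytesIO()
torch.save(
    {
        "epoch": 10,
        "net_state_dict": net.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    buffer,
)

# Resume: restore the model *and* the optimizer, then continue training
# from the next epoch.
buffer.seek(0)
checkpoint = torch.load(buffer)
net.load_state_dict(checkpoint["net_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1
print(start_epoch)  # 11
```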

Second, if I do so, I get an error in the update function of class EntropyBottleneck(EntropyModel):

samples = samples[None, :] + pmf_start[:, None, None]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

This is because

samples.device = cpu
pmf_start.device = cuda:0

To fix this, you should change samples = torch.arange(max_length) to samples = torch.arange(max_length, device=pmf_start.device) in def update(self, force=False).
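(A minimal standalone illustration of the mismatch and the fix, using hypothetical tensors rather than the actual EntropyBottleneck internals:)

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
pmf_start = torch.tensor([0.0, 1.0], device=device)
max_length = 4

# Buggy pattern: torch.arange defaults to the CPU, so on a CUDA machine the
# broadcasting below raises the "Expected all tensors to be on the same
# device" RuntimeError:
#   samples = torch.arange(max_length)

# Fixed pattern: create samples on the same device as pmf_start.
samples = torch.arange(max_length, device=pmf_start.device)
samples = samples[None, :] + pmf_start[:, None, None]
print(samples.shape)  # torch.Size([2, 1, 4])
```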

Since I am not using exactly the code above (I have changed your original code a little, so to simplify the explanation I used your simplified train.py code), I would first like you to verify that you can reproduce this problem on your side. If so, and you fix it, I suspect a similar problem exists elsewhere in the code as well.

Again, thank you for this wonderful library. Best, Navid

jbegaint commented 3 years ago

Hi Navid, thanks for reporting this! We've not encountered such errors.

You don't need to run net.update to continue training the network; it only updates the internal parameters needed for the actual entropy coding, not for training.

But I'll take a look, because this should not crash. It looks like you found a good explanation, and your fix looks right. I'll run some tests and push a fix soon if need be.

Thanks!

navid-mahmoudian commented 3 years ago

Hello Jean, Thank you for your reply. I didn't fully understand this part of your answer:

You don't need to run net.update to continue training the network, this will only update the internal parameters needed for the actual entropy coding, not the training.

Let's say I run my code for one day and it finishes at epoch T. A few days later, I want to continue training from epoch T (i.e., I am not loading the states in the same execution in which I saved them). It seems to me that, to continue training, we have to update the internal parameters for the entropy coding. Am I missing something? Maybe the purpose of net.update is not clear to me.

In fact, I took the idea of calling net.update from digging into your CompressAI/compressai/utils/update_model/__main__.py script. There, I noticed the script takes an argument controlling whether to update the model's CDF parameters:

if not args.no_update:
    net.update(force=True)

Thank you again for your time.

jbegaint commented 3 years ago

Hi Navid,

update() only changes the internal buffers needed for the actual entropy coding (with the range ANS entropy coder or a range coder); these are not used during training. The main parameters (the ones with gradients) are loaded via the usual load_state_dict. You can find more information/details in the original TensorFlow documentation: https://tensorflow.github.io/compression/docs/entropy_bottleneck.html.
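(To make the parameter/buffer distinction concrete, here is a toy analogy in plain PyTorch; ToyBottleneck is a made-up class, not the actual EntropyBottleneck implementation:)

```python
import torch
import torch.nn as nn


class ToyBottleneck(nn.Module):
    """Toy analogy: a learnable parameter plus a derived coding table."""

    def __init__(self):
        super().__init__()
        # Trained by gradient descent; saved/restored via state_dict.
        self.scale = nn.Parameter(torch.ones(3))
        # Derived table read only at coding time, like the quantized CDFs.
        self.register_buffer("cdf_table", torch.zeros(3))

    def update(self):
        # Recompute the coding table from the trained parameters.
        self.cdf_table = torch.cumsum(self.scale.detach().abs(), dim=0)


net = ToyBottleneck()
net.update()
print(net.cdf_table)  # tensor([1., 2., 3.])
```

Gradient steps change scale, but cdf_table stays stale until update() is called. That is why resuming training only needs load_state_dict, while actual compression needs the refreshed table.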

We also provide some details here: https://interdigitalinc.github.io/CompressAI/tutorial_train.html#updating-the-model but we might need to update it with more details. Feel free to contribute to the documentation if you have improvements :-)

jbegaint commented 3 years ago

Navid, I've pushed a simple fix here: https://github.com/InterDigitalInc/CompressAI/commit/6c08ced208482c6a1646c0ea1ab40350a6056572. This should fix the issue you encountered.

navid-mahmoudian commented 3 years ago

Thank you very much Jean.