InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License

Auxiliary loss optimization #167

Closed: cucNaifuXue closed this issue 1 year ago

cucNaifuXue commented 1 year ago

Hi,

I am confused about the "auxiliary loss". It should be minimized, but I don't know its influence on the whole model's rate-distortion (RD) performance.

Should I adjust the lr of the aux optimizer according to some metric when training my own model?

In the "Custom model" tutorial, the example looks like the factorized prior model, bmshj2018-factorized. Rate and distortion are both included in the main loss, and I don't understand how the aux loss works.

Thanks for answers.

YodaEmbedding commented 1 year ago

aux_loss adjusts only the EntropyBottleneck.quantiles to ensure that the trained distribution satisfies a few target boundary conditions at runtime. Nothing further. It does not take any image data as input, and it has no effect on the training of the model, since EntropyBottleneck.quantiles is not used in the model's forward pass during training.

https://github.com/InterDigitalInc/CompressAI/blob/216af74fa09a141eaab68dd0ff2edd89629f62dc/compressai/optimizers/net_aux.py#L46-L57

https://github.com/InterDigitalInc/CompressAI/blob/216af74fa09a141eaab68dd0ff2edd89629f62dc/compressai/entropy_models/entropy_models.py#L431-L434

https://github.com/InterDigitalInc/CompressAI/blob/216af74fa09a141eaab68dd0ff2edd89629f62dc/compressai/entropy_models/entropy_models.py#L378-L383
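
To make the split concrete, here is a minimal sketch of the usual two-optimizer training step, loosely modeled on examples/train.py (the lambda, learning rates, and random batch are placeholders):

```python
import torch
import torch.nn.functional as F

from compressai.zoo import bmshj2018_factorized

net = bmshj2018_factorized(quality=1)

# Split parameters: everything except the quantiles goes to the main
# optimizer; the quantiles alone go to the aux optimizer.
main_params = [p for n, p in net.named_parameters() if not n.endswith(".quantiles")]
aux_params = [p for n, p in net.named_parameters() if n.endswith(".quantiles")]
optimizer = torch.optim.Adam(main_params, lr=1e-4)
aux_optimizer = torch.optim.Adam(aux_params, lr=1e-3)

x = torch.rand(8, 3, 256, 256)  # stand-in for a training batch
out = net(x)

# Main RD loss: rate (bits per pixel, from the likelihoods) plus a
# weighted distortion term. This is what determines RD performance.
N, _, H, W = x.shape
num_pixels = N * H * W
bpp = sum(
    likelihoods.log2().sum() / -num_pixels
    for likelihoods in out["likelihoods"].values()
)
distortion = F.mse_loss(out["x_hat"], x)
loss = bpp + 0.01 * 255**2 * distortion  # 0.01 is a placeholder lambda

optimizer.zero_grad()
loss.backward()
optimizer.step()

# Aux loss: updates only the quantiles; it never sees the image batch.
aux_loss = net.aux_loss()
aux_optimizer.zero_grad()
aux_loss.backward()
aux_optimizer.step()
```

Note that aux_loss.backward() produces gradients only for the quantiles, so the two steps do not interfere with each other.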

Thus, I would say that minimizing aux_loss as much as possible isn't critical, since it has nearly no effect on the trained RD performance. All it does is make sure the support of the encoding distribution is finite, ensuring that we don't use too many symbols at runtime and that each symbol's probability stays above the precision threshold. Potentially, it may also act as a small regularizer that keeps the distributions in check. The .quantiles are not used at all during training, so aux_loss is not relevant until after training finishes.
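
To see where the quantiles do matter, here is a hedged sketch of the post-training workflow (the quality index and random input are placeholders): calling update() uses the learned quantiles to fix a finite symbol range and build the quantized CDF tables that compress()/decompress() rely on.

```python
import torch

from compressai.zoo import bmshj2018_factorized

# pretrained=True downloads released weights; any trained model works here.
net = bmshj2018_factorized(quality=1, pretrained=True).eval()
net.update()  # builds the quantized CDF tables from the learned quantiles

x = torch.rand(1, 3, 256, 256)
enc = net.compress(x)  # actual range coding happens here
rec = net.decompress(enc["strings"], enc["shape"])["x_hat"]
```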



Freed-Wu commented 1 year ago

> it has nearly no effect on the trained RD performance.

So the user shouldn't adjust the lr of the aux optimizer according to some metric when training the model, really?

YodaEmbedding commented 1 year ago

The only time the .quantiles parameters are referenced during training is in _get_medians, which reads .quantiles[:, :, 1] (the midpoints):

https://github.com/InterDigitalInc/CompressAI/blob/23e9c70bb0930b76538b81b2b307c6a1f622334a/compressai/entropy_models/entropy_models.py#L385-L387

...which is only used to offset the values before noise is added:

https://github.com/InterDigitalInc/CompressAI/blob/23e9c70bb0930b76538b81b2b307c6a1f622334a/compressai/entropy_models/entropy_models.py#L496-L498

The model should not be too heavily influenced by a near-constant "midpoint" bias.
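
To illustrate, here is a rough sketch of the two quantization paths, assuming the usual noise-based training surrogate (function and variable names are illustrative, not the library's):

```python
import torch

def quantize(y: torch.Tensor, medians: torch.Tensor, training: bool) -> torch.Tensor:
    # medians corresponds to quantiles[:, :, 1], broadcast over the latent y.
    if training:
        # Offset by the median, add uniform noise (the training surrogate
        # for rounding), then undo the offset. The shift cancels out, so
        # the median acts only as a near-constant bias here.
        return (y - medians) + torch.empty_like(y).uniform_(-0.5, 0.5) + medians
    # Inference: center the rounding grid on the learned median so the
    # zero symbol lands near the mode of the distribution.
    return torch.round(y - medians) + medians
```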


Funnily enough: