InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License

Bug in JointAutoregressiveHierarchicalPriors._compress_ar #25

Closed · BYchao100 closed this issue 3 years ago

BYchao100 commented 3 years ago

https://github.com/InterDigitalInc/CompressAI/blob/master/compressai/models/priors.py

In `JointAutoregressiveHierarchicalPriors._compress_ar`:

```python
ctx_p = F.conv2d(
    y_crop,
    self.context_prediction.weight,
    bias=self.context_prediction.bias,
)
```

`self.context_prediction.weight` ignores the mask, leading to a context mismatch between `_compress_ar` and `_decompress_ar`: `y_hat` in `_compress_ar` is the full feature map, so the context convolution needs masked weights.

Solution:

```python
masked_weight = self.context_prediction.weight * self.context_prediction.mask
ctx_p = F.conv2d(
    y_crop,
    masked_weight,
    bias=self.context_prediction.bias,
)
```
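
For context, this is roughly where the line sits in the autoregressive loop (a simplified sketch, not the verbatim library code; surrounding names like `kernel_size`, `height`, `width`, and `params` are taken from the method):

```python
# Simplified sketch of _compress_ar with the masked weights applied
masked_weight = self.context_prediction.weight * self.context_prediction.mask
for h in range(height):
    for w in range(width):
        # y_hat is the full feature map, so the "future" positions must be
        # zeroed out through the masked weights to match _decompress_ar
        y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size]
        ctx_p = F.conv2d(y_crop, masked_weight, bias=self.context_prediction.bias)

        p = params[:, :, h : h + 1, w : w + 1]
        gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1))
        # ... scale/mean extraction and range coding of y[:, :, h, w] follow here
```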

jbegaint commented 3 years ago

Good catch :-). However, the context_prediction weights are actually zeroed in place during training: https://github.com/InterDigitalInc/CompressAI/blob/master/compressai/layers/layers.py#L46, so with the current implementation this is not an issue.
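For reference, the masking in `MaskedConv2d` works roughly like this (a simplified sketch of compressai/layers/layers.py, not the verbatim code):

```python
import torch
import torch.nn as nn

class MaskedConv2d(nn.Conv2d):
    """2D convolution whose weights are re-masked in-place on every forward pass."""

    def __init__(self, *args, mask_type="A", **kwargs):
        super().__init__(*args, **kwargs)
        self.register_buffer("mask", torch.ones_like(self.weight.data))
        _, _, h, w = self.mask.size()
        # Zero the center (for type "A") and all "future" positions
        self.mask[:, :, h // 2, w // 2 + (mask_type == "B"):] = 0
        self.mask[:, :, h // 2 + 1:] = 0

    def forward(self, x):
        self.weight.data *= self.mask  # masked entries are zeroed only here
        return super().forward(x)
```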

jbegaint commented 3 years ago

But thanks, it might be a good idea to have this in place if we change the MaskedConv2d implementation.

BYchao100 commented 3 years ago

Thank you for your reply. Yes, I see. In fact, I ran into this problem and then found the solution above. During inference, I load the checkpoint and find that `ckp['model_state_dict']['context_prediction.weight']` is unmasked. Now I am confused: according to the implementation of `MaskedConv2d`, the weight should already have been masked.
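A quick way to reproduce the observation on a saved checkpoint (hypothetical path; the `mask` buffer is stored in the state dict alongside the weight):

```python
import torch

ckp = torch.load("ckp.tar", map_location="cpu")  # hypothetical checkpoint path
state = ckp["model_state_dict"]

weight = state["context_prediction.weight"]
mask = state["context_prediction.mask"]  # registered buffer, saved with the model

# Sum over the entries the mask should have zeroed; a non-zero value here
# means the checkpoint stores unmasked weights.
print((weight * (1 - mask)).abs().sum())
```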

BYchao100 commented 3 years ago

The model is saved as follows:

```python
torch.save({
    'epoch': epoch,
    'model_state_dict': jahp_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'aux_optimizer_state_dict': aux_optimizer.state_dict(),
}, os.path.join(ckp_dir, "ckp.tar"))
```

where the model is created with:

```python
jahp_model = JointAutoregressiveHierarchicalPriors().cuda()
```

jbegaint commented 3 years ago

OK, that's weird. Can you share a bit more code or information about your training setup?

BYchao100 commented 3 years ago

Well, yes. The train function is as follows:

```python
def train(train_dataloader, eval_dataloader, epochs, lmbda, ckp_dir, log_dir, resume=False):
    jahp_model = JointAutoregressiveHierarchicalPriors().cuda()
    rd_criterion = RateDistortion(lmbda)
    optimizer = optim.Adam(jahp_model.parameters(), lr=1e-4)
    aux_optimizer = optim.Adam(jahp_model.aux_parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=20)

    tb_writer = SummaryWriter(log_dir)

    if resume:
        ckp = torch.load(os.path.join(ckp_dir, "ckp.tar"))
        start_epoch = ckp['epoch'] + 1
        jahp_model.load_state_dict(ckp['model_state_dict'])
        optimizer.load_state_dict(ckp['optimizer_state_dict'])
        aux_optimizer.load_state_dict(ckp['aux_optimizer_state_dict'])
    else:
        start_epoch = 0
    train_step = 0

    for epoch in range(start_epoch, epochs):
        # train
        train_step = train_one_epoch(jahp_model, rd_criterion, train_dataloader, optimizer,
                                     aux_optimizer, train_step, tb_writer, clip_max_norm=1.0)

        # save model
        torch.save({
            'epoch': epoch,
            'model_state_dict': jahp_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'aux_optimizer_state_dict': aux_optimizer.state_dict(),
        }, os.path.join(ckp_dir, "ckp.tar"))

        # eval
        loss, bpp_loss, mse_loss, psnr, aux_loss = eval_epoch(jahp_model, rd_criterion, eval_dataloader, epoch, tb_writer)
        print("Epoch:{}, Eval bpp:{}, Eval mse:{}, Eval psnr:{}".format(epoch, bpp_loss, mse_loss, psnr))
        if epoch > 400:
            scheduler.step(loss)

    tb_writer.close()
```
BYchao100 commented 3 years ago

The train_one_epoch function is as follows:

```python
def train_one_epoch(model, criterion, train_dataloader, optimizer, aux_optimizer, train_step,
                    tb_writer=None, clip_max_norm=None):
    model.train()
    device = next(model.parameters()).device

    train_size = 0
    for x in train_dataloader:
        x = x.to(device).contiguous()

        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        x = x / 255.
        out = model(x)

        out_criterion = criterion(out, x)
        out_criterion["loss"].backward()
        if clip_max_norm:
            nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        aux_loss = model.aux_loss()
        aux_loss.backward()
        aux_optimizer.step()

        train_step += 1
        if tb_writer:
            tb_writer.add_scalar('train loss', out_criterion["loss"].item(), train_step)
            tb_writer.add_scalar('train mse', out_criterion["mse_loss"].item(), train_step)
            tb_writer.add_scalar('train bpp', out_criterion["bpp_loss"].item(), train_step)
            tb_writer.add_scalar('train aux', model.aux_loss().item(), train_step)

        train_size += x.shape[0]

    print("train sz:{}".format(train_size))
    return train_step
```
BYchao100 commented 3 years ago

I hope this information helps. It is not a big problem for the program anyway.

jbegaint commented 3 years ago

Thanks for the code, I managed to re-create the bug! What's happening is that you are saving your model right after the training step and before evaluating it, so the masked weight entries have been overwritten by the optimizer update, and the mask is only re-applied on the next forward pass. I'll apply a fix based on your earlier proposal. Thanks for reporting this!
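A minimal, hypothetical reproduction of that sequence (random input, dummy rate loss): one optimizer step followed by an immediate save, with no forward pass in between, mirroring the training loop above:

```python
import torch
from compressai.models import JointAutoregressiveHierarchicalPriors

model = JointAutoregressiveHierarchicalPriors()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

x = torch.rand(1, 3, 64, 64)
out = model(x)  # forward pass: MaskedConv2d zeroes the masked weight entries in-place

# Dummy rate loss so that gradients flow back through context_prediction
bpp = sum((-torch.log2(l)).sum() for l in out["likelihoods"].values())
bpp.backward()
optimizer.step()  # the dense update re-fills the masked entries

w = model.context_prediction.weight.data
m = model.context_prediction.mask
print((w * (1 - m)).abs().sum())  # > 0: a state_dict saved right now is unmasked

# One possible workaround until the fix lands: re-apply the mask before saving
model.context_prediction.weight.data *= model.context_prediction.mask
torch.save({"model_state_dict": model.state_dict()}, "ckp.tar")
```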

BYchao100 commented 3 years ago

Thank you for your explanation! I see. In MaskedConv2d, self.weight is masked and then super().forward(x) is called. During backpropagation, the gradient of self.weight comes from the (unmasked) forward(x) call, while the in-place masking self.weight.data *= self.mask contributes nothing to the graph, so the gradient of self.weight ends up unmasked and the next optimizer step re-fills the masked entries.
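A tiny check of this behaviour (hypothetical shapes, using the layer exported by compressai.layers):

```python
import torch
from compressai.layers import MaskedConv2d

conv = MaskedConv2d(1, 1, kernel_size=5, padding=2)
out = conv(torch.rand(1, 1, 16, 16))  # forward() zeroes the masked weight entries in-place
out.sum().backward()

masked = conv.mask == 0
print(conv.weight.data[masked].abs().sum())  # 0 right after the forward pass
print(conv.weight.grad[masked].abs().sum())  # > 0: the gradient is dense, so the next
                                             # optimizer step re-fills the masked entries
```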

jbegaint commented 3 years ago

Great, I'm closing this. Feel free to open other issues if you encounter any bugs!