Closed · BYchao100 closed this issue 3 years ago
Good catch :-). However, the context_prediction weights are actually zeroed in place during training: https://github.com/InterDigitalInc/CompressAI/blob/master/compressai/layers/layers.py#L46, so with the current implementation this is not an issue.
But thanks, it might be a good idea to have this safeguard in place in case we change the MaskedConv2d implementation.
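For context, MaskedConv2d re-applies the mask on every forward pass; condensed, the linked layers.py does roughly the following (a paraphrase, details may differ from the current source):

import torch
import torch.nn as nn

class MaskedConv2d(nn.Conv2d):
    # Condensed paraphrase of the linked implementation: a causal ("A"/"B")
    # mask is stored as a buffer and multiplied into the weights *in place*
    # on every forward call.
    def __init__(self, *args, mask_type="A", **kwargs):
        super().__init__(*args, **kwargs)
        self.register_buffer("mask", torch.ones_like(self.weight.data))
        _, _, h, w = self.mask.size()
        self.mask[:, :, h // 2, w // 2 + (mask_type == "B"):] = 0
        self.mask[:, :, h // 2 + 1:] = 0

    def forward(self, x):
        self.weight.data *= self.mask  # re-zero the masked positions before the conv
        return super().forward(x)

So any non-zero values written into the masked positions are wiped out again on the next training forward pass.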
Thank you for your reply. Yes, I see. In fact, I ran into this problem and then found the workaround. During inference, I loaded the checkpoint and found that ckp['model_state_dict']['context_prediction.weight'] is unmasked. Now I am confused: according to the implementation of MaskedConv2d, the weight should have been masked.
The model is saved as follows:

torch.save({
    'epoch': epoch,
    'model_state_dict': jahp_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'aux_optimizer_state_dict': aux_optimizer.state_dict(),
}, os.path.join(ckp_dir, "ckp.tar"))

and the model is created with:

jahp_model = JointAutoregressiveHierarchicalPriors().cuda()
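This is roughly how I checked the saved checkpoint (a standalone sketch; the path and the type-"A" mask layout are assumptions from my setup):

import torch

# Load the checkpoint on CPU and look directly at the stored context_prediction kernel.
ckp = torch.load("ckp.tar", map_location="cpu")
w = ckp['model_state_dict']['context_prediction.weight']
_, _, h, kw = w.shape

# For a type-"A" causal mask, the centre position onward in the middle row and
# every row below it should be exactly zero; in my checkpoint they are not.
print(w[:, :, h // 2, kw // 2:].abs().max())
print(w[:, :, h // 2 + 1:, :].abs().max())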
OK, that's weird. Can you share a bit more code or information about your training setup?
Well, yes. The train function is as follows:

def train(train_dataloader, eval_dataloader, epochs, lmbda, ckp_dir, log_dir, resume=False):
    jahp_model = JointAutoregressiveHierarchicalPriors().cuda()
    rd_criterion = RateDistortion(lmbda)
    optimizer = optim.Adam(jahp_model.parameters(), lr=1e-4)
    aux_optimizer = optim.Adam(jahp_model.aux_parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=20)
    tb_writer = SummaryWriter(log_dir)
    if resume:
        ckp = torch.load(os.path.join(ckp_dir, "ckp.tar"))
        start_epoch = ckp['epoch'] + 1
        jahp_model.load_state_dict(ckp['model_state_dict'])
        optimizer.load_state_dict(ckp['optimizer_state_dict'])
        aux_optimizer.load_state_dict(ckp['aux_optimizer_state_dict'])
    else:
        start_epoch = 0
    train_step = 0
    for epoch in range(start_epoch, epochs):
        # train
        train_step = train_one_epoch(jahp_model, rd_criterion, train_dataloader, optimizer, aux_optimizer, train_step, tb_writer, clip_max_norm=1.0)
        # save model
        torch.save({
            'epoch': epoch,
            'model_state_dict': jahp_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'aux_optimizer_state_dict': aux_optimizer.state_dict(),
        }, os.path.join(ckp_dir, "ckp.tar"))
        # eval
        loss, bpp_loss, mse_loss, psnr, aux_loss = eval_epoch(jahp_model, rd_criterion, eval_dataloader, epoch, tb_writer)
        print("Epoch:{}, Eval bpp:{}, Eval mse:{}, Eval psnr:{}".format(epoch, bpp_loss, mse_loss, psnr))
        if epoch > 400:
            scheduler.step(loss)
    tb_writer.close()
The train_one_epoch function is as follows:

def train_one_epoch(model, criterion, train_dataloader, optimizer, aux_optimizer, train_step, tb_writer=None, clip_max_norm=None):
    model.train()
    device = next(model.parameters()).device
    train_size = 0
    for x in train_dataloader:
        x = x.to(device).contiguous()
        optimizer.zero_grad()
        aux_optimizer.zero_grad()
        x = x / 255.
        out = model(x)
        out_criterion = criterion(out, x)
        out_criterion["loss"].backward()
        if clip_max_norm:
            nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()
        aux_loss = model.aux_loss()
        aux_loss.backward()
        aux_optimizer.step()
        train_step += 1
        if tb_writer:
            tb_writer.add_scalar('train loss', out_criterion["loss"].item(), train_step)
            tb_writer.add_scalar('train mse', out_criterion["mse_loss"].item(), train_step)
            tb_writer.add_scalar('train bpp', out_criterion["bpp_loss"].item(), train_step)
            tb_writer.add_scalar('train aux', model.aux_loss().item(), train_step)
        train_size += x.shape[0]
    print("train sz:{}".format(train_size))
    return train_step
I hope this information helps. It is not a big problem for my program anyway.
Thanks for the code, I managed to reproduce the bug! What's happening is that you are saving your model before evaluating it, so the masked weights have been overwritten by the optimizer step. I'll apply a fix based on your earlier proposal. Thanks for reporting this!
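In the meantime, a possible workaround on the training side (a sketch using the names from your snippet, not the fix that will land in CompressAI) is to re-apply the mask right before saving:

# Re-zero the masked positions of the context model just before torch.save,
# so the checkpoint always contains a properly masked kernel.
with torch.no_grad():
    cp = jahp_model.context_prediction
    cp.weight.data *= cp.mask

torch.save({
    'epoch': epoch,
    'model_state_dict': jahp_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'aux_optimizer_state_dict': aux_optimizer.state_dict(),
}, os.path.join(ckp_dir, "ckp.tar"))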
Thank you for your explanation! I see. In MaskedConv2d, self.weight.data is masked in place and then super().forward(x) is called. Because the masking is done through .data, autograd does not track it: the gradient of the loss with respect to self.weight is not masked, so the masked positions receive non-zero gradients and the optimizer step writes non-zero values back into them. That is why self.weight ends up unmasked in the checkpoint.
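A tiny standalone snippet (hypothetical toy shapes, not CompressAI code) that shows the effect:

import torch
import torch.nn as nn

# Masking via .data is invisible to autograd: masked positions still receive
# non-zero gradients, and the optimizer step writes non-zero values back into them.
conv = nn.Conv2d(1, 1, 3, padding=1)
mask = torch.ones_like(conv.weight)
mask[:, :, 2, :] = 0                      # treat the bottom row as "future" context

opt = torch.optim.Adam(conv.parameters(), lr=1e-2)

conv.weight.data *= mask                  # in-place masking, as MaskedConv2d does
out = conv(torch.randn(1, 1, 8, 8))
out.sum().backward()

print(conv.weight.grad[0, 0, 2])          # non-zero: the gradient ignores the mask
opt.step()
print(conv.weight[0, 0, 2])               # the masked row is non-zero again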
Great, I'm closing this. Feel free to open other issues if you encounter any bugs!
In JointAutoregressiveHierarchicalPriors._compress_ar (https://github.com/InterDigitalInc/CompressAI/blob/master/compressai/models/priors.py):

ctx_p = F.conv2d(y_crop, self.context_prediction.weight, bias=self.context_prediction.bias)

self.context_prediction.weight ignores the mask, which leads to a context mismatch between _compress_ar and _decompress_ar: y_hat in _compress_ar is the full feature map, so it needs to be masked. Proposed solution:

masked_weight = self.context_prediction.weight * self.context_prediction.mask
ctx_p = F.conv2d(y_crop, masked_weight, bias=self.context_prediction.bias)
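A quick standalone sanity check (hypothetical shapes, type-"A" 5x5 mask) that masking the kernel gives the same context as causally masking y_crop:

import torch
import torch.nn.functional as F

# With a causal mask, convolving the full y_crop with a masked kernel equals
# convolving a causally masked y_crop with the raw kernel, so _compress_ar
# would see the same context that _decompress_ar reconstructs step by step.
torch.manual_seed(0)
weight = torch.randn(8, 4, 5, 5)
bias = torch.randn(8)
mask = torch.ones_like(weight)
mask[:, :, 5 // 2, 5 // 2:] = 0
mask[:, :, 5 // 2 + 1:, :] = 0

y_crop = torch.randn(1, 4, 5, 5)                          # full patch, including "future" pixels
ctx_masked_kernel = F.conv2d(y_crop, weight * mask, bias=bias)
ctx_masked_input = F.conv2d(y_crop * mask[0], weight, bias=bias)
assert torch.allclose(ctx_masked_kernel, ctx_masked_input, atol=1e-5)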