NVlabs / DIODE

Official PyTorch implementation of Data-free Knowledge Distillation for Object Detection, WACV 2021.
https://openaccess.thecvf.com/content/WACV2021/html/Chawla_Data-Free_Knowledge_Distillation_for_Object_Detection_WACV_2021_paper.html

About L2 & TV loss, mean value #16

Open · mountains-high opened this issue 2 years ago

mountains-high commented 2 years ago

Good day! Thank you for the nice work and for making it open-source ~ I have some doubts about these two losses. Could you please explain why you take the "mean of L2 loss and TV loss"?

https://github.com/NVlabs/DIODE/blob/80a396d5772528d4c393a301b0a1390eb7e7e039/deepinversion_yolo.py#L178

and

https://github.com/NVlabs/DIODE/blob/80a396d5772528d4c393a301b0a1390eb7e7e039/deepinversion_yolo.py#L245

Thank you ~
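For context, here is a minimal sketch of what a mean-based TV and L2 image prior looks like in PyTorch. The function name `mean_based_image_prior` is made up for illustration, and the exact expressions on the cited lines of `deepinversion_yolo.py` may differ; only `inputs_jit` is taken from the quoted code.

```python
import torch

def mean_based_image_prior(inputs_jit: torch.Tensor):
    """Illustrative mean-based TV and L2 priors on an image batch of shape (N, C, H, W)."""
    # Differences between neighbouring pixels along width, height, and the two diagonals.
    diff1 = inputs_jit[:, :, :, :-1] - inputs_jit[:, :, :, 1:]
    diff2 = inputs_jit[:, :, :-1, :] - inputs_jit[:, :, 1:, :]
    diff3 = inputs_jit[:, :, 1:, :-1] - inputs_jit[:, :, :-1, 1:]
    diff4 = inputs_jit[:, :, :-1, :-1] - inputs_jit[:, :, 1:, 1:]

    # Mean-based total-variation prior: average squared neighbour difference,
    # so the value does not grow with image resolution or batch size.
    loss_var = (diff1.pow(2).mean() + diff2.pow(2).mean()
                + diff3.pow(2).mean() + diff4.pow(2).mean())

    # Mean-based L2 prior on the image itself.
    loss_l2 = inputs_jit.pow(2).mean()
    return loss_var, loss_l2
```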

mountains-high commented 2 years ago

The two lines below are taken from the original source, DI/cifar10:

https://github.com/NVlabs/DeepInversion/blob/6d64b65c573a8229844c746d77993b2c0431a3e5/cifar10/deepinversion_cifar10.py#L184

`loss_var = torch.norm(diff1) + torch.norm(diff2) + torch.norm(diff3) + torch.norm(diff4)`

and

https://github.com/NVlabs/DeepInversion/blob/6d64b65c573a8229844c746d77993b2c0431a3e5/cifar10/deepinversion_cifar10.py#L177

`loss = loss + l2_coeff * torch.norm(inputs_jit, 2)`

Which one is correct? Thanks
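Note that the two snippets are not numerically equivalent, which may be part of the confusion: `torch.norm(x, 2)` is the square root of the sum of squares, whereas `x.pow(2).mean()` is the average of the squares, so their values and gradient magnitudes scale very differently with tensor size. The tiny self-contained check below (with a made-up random batch) illustrates the gap; either form can serve as a regularizer, but the corresponding coefficient then has to be tuned to match.

```python
import torch

torch.manual_seed(0)
x = torch.rand(2, 3, 64, 64)  # dummy image batch, values in [0, 1)

norm_based = torch.norm(x, 2)   # sqrt(sum(x**2)): grows with tensor size
mean_based = x.pow(2).mean()    # sum(x**2) / numel: size-independent

print(f"torch.norm(x, 2) = {norm_based.item():.4f}")  # on the order of sqrt(numel / 3) here
print(f"x.pow(2).mean()  = {mean_based.item():.4f}")  # roughly 1/3 for uniform [0, 1) values
print(f"numel            = {x.numel()}")
```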