yunxiaoshi / Neural-IMage-Assessment

A PyTorch Implementation of Neural IMage Assessment

The EMD loss still seems to be wrong #10

Closed luqiang360 closed 5 years ago

luqiang360 commented 5 years ago

https://github.com/kentsyx/Neural-IMage-Assessment/blob/9c50b3e384a88a8afdc00333d01656be5526bfed/model.py#L36

The EMD loss still seems to be wrong; in my opinion the sum operation should be inside torch.abs.

Originally posted by @luqiang360 in https://github.com/kentsyx/Neural-IMage-Assessment/issues/4#issuecomment-456781249

yunxiaoshi commented 5 years ago

The implementation is correct, see Eq. 1 in the original paper. Generate some random data and compute it by hand for a comparison if you have to.

luqiang360 commented 5 years ago

Yes, I did the comparison, and your result is still wrong; you should read the original EMD equation more carefully.

Mathematically, abs(sum(A) - sum(B)) is not equal to sum(abs(A) - abs(B)), because abs is not linear; this is where the problem in your implementation lies. Moving the sum inside torch.abs gives the correct result: emd_loss += torch.abs(sum(p[:i] - q[:i])) ** r
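To make the non-linearity concrete, here is a quick sketch (the toy data and variable names are mine, not from the repo) computing the per-index term both ways on random distributions. The two orderings disagree whenever p - q changes sign within a prefix:

```python
import torch

torch.manual_seed(0)
# hypothetical 10-bin score distributions, not data from the issue
p = torch.softmax(torch.randn(10), dim=0)
q = torch.softmax(torch.randn(10), dim=0)
r = 2

# sum inside abs: |CDF_p(i) - CDF_q(i)|^r, matching Eq. 1 of the paper
loss_sum_inside = sum(torch.abs(torch.sum(p[:i] - q[:i])) ** r
                      for i in range(1, len(p) + 1))

# abs inside the sum: (sum_j |p_j - q_j|)^r, which overcounts whenever signs mix
loss_abs_inside = sum(torch.sum(torch.abs(p[:i] - q[:i])) ** r
                      for i in range(1, len(p) + 1))

print(loss_sum_inside.item(), loss_abs_inside.item())
```

Since both distributions sum to one, the full-prefix CDF difference is zero under the first ordering but strictly positive under the second, so the two losses can never coincide for distinct p and q.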

wangshy31 commented 5 years ago

@luqiang360 I think your implementation is right. Have you reproduced the final performance (VGG16 EMD 0.052) reported in the original paper?

luqiang360 commented 5 years ago

@wangshy31 I am just a little interested in the difference between EMD loss and L1/L2 loss for classification in no-reference IQA, so I didn't try to reproduce the performance. Maybe it could be reproduced with the corrected implementation; anyone interested can give it a try.

wanghonglie commented 5 years ago

Have you reproduced the final performance (VGG16 EMD 0.052) reported in the original paper?

ygean commented 5 years ago

I have reproduced the final performance on both VGG16 and MobileNet, but I didn't use this repository; I actually reimplemented NIMA myself.

luqiang360 commented 5 years ago

@zhouyuangan Which EMD loss implementation did you use, or could you share any thoughts on this issue?

ygean commented 5 years ago

@luqiang360 I finished my NIMA repo more than half a year ago and have forgotten some of the details, so I would rather just show my code here.

import torch

def earth_movers_distance_torch(y, y_pred):
    # CDFs of the target and predicted score distributions over the bin axis
    cdf_y = torch.cumsum(y, dim=1)
    cdf_pred = torch.cumsum(y_pred, dim=1)
    cdf_diff = cdf_pred - cdf_y
    # per-sample EMD with r=2: root of the mean squared CDF difference over bins
    emd_loss = torch.sqrt(torch.mean(torch.pow(cdf_diff, 2), dim=-1))
    # average the per-sample losses over the batch
    return emd_loss.mean()

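As a sanity check (the helper names and toy data below are mine), this cumsum formulation computes the same squared CDF differences as the per-index prefix-sum loop discussed earlier in the thread; before the mean/sqrt normalization the two sums match exactly:

```python
import torch

def emd_loop(y, y_pred, r=2):
    # prefix-sum form from the earlier comment: sum_i |sum(y_pred[:i] - y[:i])|^r
    total = torch.zeros(())
    for i in range(1, y.shape[0] + 1):
        total = total + torch.abs(torch.sum(y_pred[:i] - y[:i])) ** r
    return total

def emd_cumsum(y, y_pred):
    # cumsum form: the same squared CDF differences, summed over bins
    cdf_diff = torch.cumsum(y_pred, dim=0) - torch.cumsum(y, dim=0)
    return torch.sum(cdf_diff ** 2)

torch.manual_seed(0)
y = torch.softmax(torch.randn(10), dim=0)
y_hat = torch.softmax(torch.randn(10), dim=0)

a = emd_loop(y, y_hat)
b = emd_cumsum(y, y_hat)
print(a.item(), b.item())
```

The cumsum version avoids the Python loop, so it is both vectorized and easier to batch.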
luqiang360 commented 5 years ago

Thanks a lot, I think your implementation is right; I will close this issue shortly. @WangHonglie You can check zhouyuangan's NIMA repo for more details.