Closed shkarupa-alex closed 3 years ago
Well, RMI uses Cov(Y, P) and Var(Y)/Var(P), i.e., the statistics of the data, to calculate the loss value.
I think the best way is to ignore the hole areas caused by rotation. For example, if the hole areas are labeled 255, we can ignore the corresponding points when calculating the normal cross entropy loss, and also ignore these points when constructing the "high dimension points". Then these meaningless points will not take part in the loss calculation or the gradient backprop.
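For the cross entropy part, a minimal NumPy sketch of ignoring hole pixels could look like this. It is only an illustration, not code from the RMI repository: `masked_bce` and `IGNORE_LABEL` are hypothetical names, the labels are assumed to be binary with 255 marking holes, and the inputs are plain 2-D arrays rather than batched tensors.

```python
import numpy as np

IGNORE_LABEL = 255  # assumed marker for "hole" pixels, as in the example above


def masked_bce(logits, labels, ignore_label=IGNORE_LABEL):
    """Mean binary cross entropy over pixels whose label is not ignore_label."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    valid = labels != ignore_label
    # Replace ignored labels by 0 so the formula stays well-defined;
    # their loss values are zeroed out by the mask below anyway.
    targets = np.where(valid, labels, 0).astype(np.float64)
    # Numerically stable BCE with logits: max(x, 0) - x * t + log(1 + exp(-|x|))
    per_pixel = (np.maximum(logits, 0) - logits * targets
                 + np.log1p(np.exp(-np.abs(logits))))
    n_valid = valid.sum()
    if n_valid == 0:
        return 0.0  # nothing to supervise in this sample
    return float((per_pixel * valid).sum() / n_valid)
```

Since the hole pixels are multiplied by 0 before the sum, they contribute neither to the loss value nor to its gradient. In PyTorch, multi-class `CrossEntropyLoss` offers the same behaviour through its `ignore_index` argument.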
It is clear how to ignore "holes" in BCE, but I can't work out how to ignore these points when constructing the "high dimension points".
Those high dimension points include almost all (~ -r^2) pixels from the downsampled labels/logits. We can also slice the sample weights in the same manner as the labels. As a result we may have sw_vectors here https://github.com/ZJULearning/RMI/blob/master/losses/rmi/rmi.py#L183 corresponding 1-to-1 to pr_vectors.
But I don't understand how to use them next: what should be multiplied by these weights, and when? @mzhaoshuai, can you suggest any idea?
I think you can reserve a mask, where the "holes" are 0 and other areas are 1, then you can use this mask to select the meaningful points from the output of https://github.com/ZJULearning/RMI/blob/d71897d3175b617fa4d8c438a67f0a1dc5e35502/losses/rmi/rmi_utils.py#L17.
Apply a min_pooling of size radius*radius to the mask and use this pooled mask to select the meaningful points. This means every high dimensional point which contains a 0 ("hole") should be ignored. Take care of the shapes and make the mask's shape (height and width) the same as the output's shape:
https://github.com/ZJULearning/RMI/blob/d71897d3175b617fa4d8c438a67f0a1dc5e35502/losses/rmi/rmi_utils.py#L24
Feel free to re-open this issue if you still have questions.
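The min-pooling idea can be sketched in NumPy roughly as follows. This is a simplified illustration and not the repository's implementation: it works on a single 2-D map, assumes stride 1 and no padding, and the helper names `min_pool_valid`, `extract_points` and `select_valid_points` are made up for this example.

```python
import numpy as np


def min_pool_valid(mask, radius):
    """Min-pool a 2-D 0/1 mask with a radius x radius window (stride 1, no padding)."""
    mask = np.asarray(mask)
    h, w = mask.shape
    out_h, out_w = h - radius + 1, w - radius + 1
    pooled = np.ones((out_h, out_w), dtype=mask.dtype)
    # Take the minimum over the same shifted crops that form the points,
    # so a single hole pixel invalidates every window that covers it.
    for dy in range(radius):
        for dx in range(radius):
            pooled = np.minimum(pooled, mask[dy:dy + out_h, dx:dx + out_w])
    return pooled


def extract_points(x, radius):
    """Stack radius*radius shifted crops: shape (radius*radius, out_h, out_w)."""
    x = np.asarray(x)
    h, w = x.shape
    out_h, out_w = h - radius + 1, w - radius + 1
    crops = [x[dy:dy + out_h, dx:dx + out_w]
             for dy in range(radius) for dx in range(radius)]
    return np.stack(crops, axis=0)


def select_valid_points(x, mask, radius):
    """Keep only the high-dimensional points whose window contains no hole."""
    points = extract_points(x, radius)
    keep = min_pool_valid(mask, radius).astype(bool)
    return points[:, keep]  # shape (radius*radius, n_valid)
```

The pooled mask has the same height and width as each crop, so it indexes the stacked points directly. Applying `select_valid_points` with the same mask to both the labels and the predictions keeps the two sets of vectors aligned 1-to-1. Note that the number of surviving points can differ from image to image, so a batched version would need to handle that, for example by computing the statistics per image.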
Some augmentations (e.g. random-angle rotation) leave parts of the image and mask meaningless. To deal with such cases I usually use per-pixel weights (0. for holes, 1. for valid parts) and multiply the per-pixel loss by those weights.
But RMI loss uses "high dimension points", and the final loss has a shape incompatible with the original labels.
Could you please suggest the best way to exclude such "holes" from the loss (the equivalent of multiplying by a pixel weight)?