Closed xc-chengdu closed 2 years ago
hi, have you tested the Benchmark Result of YOLOX* Series?
Hi @xc-chengdu , we collect the gradients with a backward hook. The code can be found here: code. The pos_neg, pos_grad, and neg_grad tensors are all buffers used to store the collected gradients. Please feel free to ask any further questions, thanks.
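To make the hook mechanism concrete, here is a minimal sketch of collecting per-class gradient statistics with `Tensor.register_hook`. The module, its `_targets` cache, and the buffer update rule are hypothetical illustrations of the general pattern, not the repository's actual code; only the buffer names `pos_grad`/`neg_grad` follow the discussion above.

```python
import torch
import torch.nn as nn

class HookedClsHead(nn.Module):
    """Hypothetical head that accumulates pos/neg gradient magnitudes."""

    def __init__(self, num_classes=3):
        super().__init__()
        self.num_classes = num_classes
        # Buffers survive .to(device) / state_dict but are not parameters.
        self.register_buffer("pos_grad", torch.zeros(num_classes))
        self.register_buffer("neg_grad", torch.zeros(num_classes))

    def collect_grad(self, grad):
        # Called by autograd during backward; `grad` is d(loss)/d(logits),
        # shape (N, num_classes), matching the tensor the hook was set on.
        pos_mask = (self._targets > 0).float()
        self.pos_grad += (grad.abs() * pos_mask).sum(dim=0)
        self.neg_grad += (grad.abs() * (1.0 - pos_mask)).sum(dim=0)

    def forward(self, logits, targets):
        self._targets = targets  # cache one-hot targets for the hook
        if logits.requires_grad:
            logits.register_hook(self.collect_grad)
        return logits

# Usage: the hook fires automatically when backward() reaches the logits.
head = HookedClsHead(num_classes=3)
logits = torch.randn(4, 3, requires_grad=True)
targets = torch.zeros(4, 3)
targets[0, 1] = 1.0
loss = nn.functional.binary_cross_entropy_with_logits(head(logits, targets), targets)
loss.backward()
```

Because the hook is attached inside `forward`, no explicit call to `collect_grad` ever appears in the training loop; autograd invokes it with the gradient tensor during the backward pass.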
@waveboo Thank you very much for your reply. I want to use your improved Equalized Focal Loss, but I cannot determine what is passed into the gradient-collection function. For example, which input parameter does grad_out[0] correspond to? Is it possible to call the gradient-collection function in the forward pass, the way the EQLv2 implementation does? When I use EFL directly, it improves the accuracy of the tail categories a bit, but it suppresses the head categories too much, so the result is worse than plain focal loss. By debugging I found that the gradient-collection function is never called, so how can I solve this? I look forward to your reply!
> hi, have you tested the Benchmark Result of YOLOX* Series? #20

I just wanted to apply their improved loss function to my own task, so I did not reproduce their results.
@xc-chengdu ,
grad = torch.cat(self.grad_buffer[::-1], dim=1).reshape(-1, self.num_classes)
@waveboo The output dimensions of the FPN levels are inconsistent; their shapes are [32768,37], [8192,37], [2048,37], [512,37], and [128,37]. Running the line above fails because the dimensions do not match, so the tensors cannot be concatenated. How does your work ensure that every feature level outputs the same dimensions? Below is my manual gradient collection implemented in the EQLv2 style:
@xc-chengdu , we collect the gradients of each subnet in grad_buffer, and once the buffer holds the gradients of all five subnets, we concatenate them. When we save a gradient into the buffer, we reshape it to (batch_size, -1, num_classes), so the levels can be concatenated along dim 1. Your tensors are 2-dimensional, which is why they cannot be concatenated along dim 1.
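The reshape-then-concat step can be sketched as follows. The batch size and per-image anchor counts below are assumptions chosen only so the math works out; the point is that after the reshape, the five levels differ only in dim 1 and therefore concatenate cleanly.

```python
import torch

num_classes = 37
batch_size = 2
# Hypothetical per-image anchor counts for five FPN levels.
anchors_per_level = [16384, 4096, 1024, 256, 64]

grad_buffer = []
for n in anchors_per_level:
    g = torch.randn(batch_size * n, num_classes)  # raw hook output, 2-D
    # Reshape to (batch_size, -1, num_classes) before buffering,
    # so levels share dims 0 and 2 and differ only in dim 1.
    grad_buffer.append(g.reshape(batch_size, -1, num_classes))

# Concatenate along dim 1, then flatten back to (total_anchors, num_classes).
grad = torch.cat(grad_buffer[::-1], dim=1).reshape(-1, num_classes)
print(grad.shape)  # 2 * sum(anchors_per_level) rows
```

Concatenating the raw 2-D tensors along dim 1 would fail because dim 0 differs per level; the intermediate 3-D shape is what makes the `torch.cat` legal.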
And one thing you should notice: you should not use our collect function directly; you need to implement one for your own focal loss. Our gradient-collection function is designed for the automatic backward-hook collection. If you want to collect the gradients manually, you have to calculate the derivative of the focal loss and implement the collect function yourself.
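For the manual route, a sketch of what "calculate the derivative of the focal loss" means: below is an analytic gradient of plain sigmoid focal loss with respect to the logits, checked against autograd. The `gamma` value and the unweighted form (no alpha balancing) are assumptions for illustration; this is not the repository's EFL implementation.

```python
import torch

def focal_loss(logits, targets, gamma=2.0):
    # Sigmoid focal loss, summed over all elements (no alpha weighting).
    p = torch.sigmoid(logits)
    pos = -targets * (1 - p) ** gamma * torch.log(p)
    neg = -(1 - targets) * p ** gamma * torch.log(1 - p)
    return (pos + neg).sum()

def focal_grad(logits, targets, gamma=2.0):
    # Analytic d(loss)/d(logits), using dp/dx = p * (1 - p):
    #   positives: gamma*p*(1-p)^gamma*log(p) - (1-p)^(gamma+1)
    #   negatives: p^(gamma+1) - gamma*p^gamma*(1-p)*log(1-p)
    p = torch.sigmoid(logits)
    g_pos = gamma * p * (1 - p) ** gamma * torch.log(p) - (1 - p) ** (gamma + 1)
    g_neg = p ** (gamma + 1) - gamma * p ** gamma * (1 - p) * torch.log(1 - p)
    return targets * g_pos + (1 - targets) * g_neg

logits = torch.randn(8, 37, requires_grad=True)
targets = (torch.rand(8, 37) < 0.1).float()
focal_loss(logits, targets).backward()
manual = focal_grad(logits.detach(), targets)
print(torch.allclose(logits.grad, manual, atol=1e-5))
```

A manual collect function would feed `manual` (split per positive/negative sample) into the gradient buffers instead of relying on a hook; verifying it against autograd as above catches sign and exponent mistakes early.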
@xc-chengdu Meanwhile, we highly recommend using the gradient-collection hook because it is simple, easy, and less error-prone.
Hello, can you explain the mechanism of the gradient-collection function? Although it is defined in the forward function, the forward pass does not seem to call it. Also, if self.pos_neg.detach() is used, what are the input arguments of collect_grad()? Does it really work?