dwchoo opened this issue 1 month ago
You cannot simply set `use_focal_loss=False` in the postprocessor, because `use_focal_loss` and `num_classes` are global variables shared across the decoder, matcher, loss, and postprocessor in this version of the codebase. To make it work, you would have to add a background class, but we do not recommend doing this directly. If you want to conduct ablation experiments, I suggest configuring these parameters separately in each module.
@lyuwenyu
Thank you for your clarification. I strongly agree that the RT-DETR model doesn't require a background class, which aligns with its original design philosophy.
From my understanding, the postprocessor isn't directly involved in training; it only interprets the model's output. I believe there's merit in using a `softmax` function to determine the final class probabilities, and filtering objects above a certain score threshold also seems like a sound approach.
Given that the results don't differ significantly with my proposed method, would it be agreeable to modify the code to allow the use of both `sigmoid` and `softmax`? This would maintain the model's original behavior while providing additional flexibility for users who might prefer `softmax` in certain scenarios.
The modification could look something like this:

```python
if use_sigmoid:
    scores = torch.sigmoid(out_logits)
    ...
else:
    scores = torch.nn.functional.softmax(out_logits, dim=-1)
    ...
```
This way, we preserve the original behavior while adding the option to use `softmax`, without compromising the model's performance or design principles. What are your thoughts on this approach?
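As a self-contained sketch of the idea (the helper name `compute_scores` and the tensor shapes are my own illustration, not code from the repo):

```python
import torch
import torch.nn.functional as F

def compute_scores(out_logits: torch.Tensor, use_sigmoid: bool = True) -> torch.Tensor:
    """Hypothetical helper: turn raw class logits into per-class scores.

    out_logits: (batch, num_queries, num_classes) tensor of raw logits.
    """
    if use_sigmoid:
        # Matches focal-loss-style training: independent per-class probabilities.
        return torch.sigmoid(out_logits)
    # Softmax over the full class dimension, with no [:, :, :-1] slicing,
    # since the model is trained without a background class.
    return F.softmax(out_logits, dim=-1)

logits = torch.randn(2, 300, 80)  # e.g. 300 queries, 80 COCO classes
sig = compute_scores(logits, use_sigmoid=True)
soft = compute_scores(logits, use_sigmoid=False)
print(sig.shape, soft.shape)  # both torch.Size([2, 300, 80])
```

Both branches return one score per (query, class) pair; only the normalization differs, so downstream top-k selection and thresholding can stay unchanged.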
> The modification could look something like this:

It looks good to rename `use_focal_loss` as `use_sigmoid` in the postprocessor.
@lyuwenyu Like this?

```python
if use_sigmoid:
    scores = torch.sigmoid(out_logits)
    ...
else:
    scores = torch.nn.functional.softmax(out_logits, dim=-1)
    ...
```
Yes. But it may leave the background class in the top-300 candidates, which can cause problems with COCO eval.
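To illustrate the concern with a toy example (the shapes and the flattened top-k selection below are my own sketch of a typical detection postprocessor, not the repo's exact code): if the logits carry an extra background column and softmax is taken over all columns without slicing, the top-k candidate list can be dominated by background entries.

```python
import torch
import torch.nn.functional as F

num_queries, num_classes = 5, 3  # 3 real classes plus one background column
logits = torch.zeros(1, num_queries, num_classes + 1)
logits[..., num_classes] = 5.0   # background column dominates every query

scores = F.softmax(logits, dim=-1)  # softmax over ALL columns, no [:, :, :-1]

# Flattened top-k selection over (query, class) pairs, as detection
# postprocessors commonly do to build the candidate list:
topk_scores, index = torch.topk(scores.flatten(1), k=5, dim=-1)
labels = index % (num_classes + 1)
print(labels)  # every selected label equals 3, i.e. the background class
```

Candidates labeled with the background index would then be handed to COCO evaluation as (bogus) detections, which is exactly the problem described above.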
For the sake of code consistency and readability, I believe it would be beneficial to keep the `use_focal_loss` variable. I've noticed several instances where the `use_focal_loss` name is used in conjunction with the `sigmoid` function (e.g., in `matcher.py`). Similarly, there are places where `softmax` is used without the `[:, :, :-1]` slicing operation, which aligns with the method I proposed.
To stay consistent with the places where `sigmoid` is used, I suggest we keep the `use_focal_loss` variable name. This would preserve the uniformity of the code while implementing the `softmax` function as I suggested earlier.
> Yes. But it may leave the background class in the top-300 candidates, which can cause problems with COCO eval.
But `scores = F.softmax(logits)[:, :, :-1]` keeps only 79 classes, not 80. There is no 'background' class.
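A quick shape check of the slicing behavior (the tensor sizes here are illustrative):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 300, 80)  # 80 COCO classes, no background column

sliced = F.softmax(logits, dim=-1)[:, :, :-1]  # drops the last class column
full = F.softmax(logits, dim=-1)               # keeps all class columns
print(sliced.shape[-1], full.shape[-1])  # 79 80
```

Note also that calling `F.softmax(logits)` without an explicit `dim`, as in the original line, triggers an implicit-dimension deprecation warning in recent PyTorch versions, which is another reason to pass `dim=-1`.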
In the original postprocessor, we handle cases No.0 and No.2 by default (and we believe case No.1 should not occur). I think you want to add case No.1, but you would need to cover all three situations in a PR.
Thank you for your kind explanation and clarification. I appreciate you informing me about the original intention of covering cases No.0 and No.2; this insight is very helpful.
However, in the current code there doesn't seem to be any implementation that adds a background class. Even if the intention was to include case No.2, there appears to be no preprocessing code that performs this task. Given this, should we consider adding such code to handle case No.2?
From my analysis, the current code structure aligns more closely with cases No.0 and No.1. Could you clarify how you envision handling case No.2 within the current framework? This would help ensure that the implementation accurately reflects the intended functionality of the model.
Thank you again for your patience and guidance throughout this process. I look forward to your thoughts on this matter.
@lyuwenyu
After carefully reviewing the code, I'd like to propose a suggestion. From my analysis, the current postprocessor only really handles the No.0 case, where `use_focal_loss=True`.
Given this observation, what are your thoughts on removing the `use_focal_loss` parameter from the postprocessor altogether and computing only the `use_focal_loss=True` path? This would simplify the postprocessor and align it more closely with its current functionality.
I'm interested to hear your perspective on this potential modification.
If you rename `use_focal_loss`, existing model configs will no longer be compatible. I suggest not making any changes in this codebase for now. If you have specific requirements, you can fork it and customize it for your own needs.
> This approach would simplify the postprocessor and align it more closely with its current functionality.

As you said, the name doesn't express the exact meaning here, but you can explain it through comments in the code.
@lyuwenyu, thank you for your reply! I agree that making breaking changes to the config isn't ideal, as you pointed out.
How about removing or disabling the `use_focal_loss` parameter only in the postprocessor, which wouldn't affect model training or the overall code structure? Alternatively, we could disable the use of `softmax` in this specific part.
RT-DETR has garnered significant attention among object detection models and is used in the Hugging Face `transformers` package. However, the same postprocessor issue is reflected there. I submitted a PR to address it, but since it mirrors the original code (the `transformers` post-processing), it's currently under discussion with a member and a contributor. (link)
As you mentioned earlier, there has been extensive debate about whether a background class (or void class) is necessary. We haven't reached a conclusion yet, and we're eagerly awaiting your input on this matter. Your insights would be invaluable in resolving this issue and improving the model's implementation across different platforms. We appreciate your time and expertise in guiding us through this process.
Fix for Error Occurring When `use_focal_loss=False` in Postprocessor

Modification:

```python
# before
scores = F.softmax(logits)[:, :, :-1]
# after
scores = F.softmax(logits, dim=-1)
```
Example code

- When `use_focal_loss=True`: results (screenshot omitted).
- When `use_focal_loss=False`, before the fix (`scores = F.softmax(logits)[:, :, :-1]`): results (screenshot omitted).
- When `use_focal_loss=False`, after the fix (`scores = F.softmax(logits, dim=-1)`): results (screenshot omitted).
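A small reproducible demonstration of why the slicing is harmful when the logits have no background column (the constructed logits are my own illustration): before the fix, the last real class can never be predicted.

```python
import torch
import torch.nn.functional as F

# Logits where the last real class (index 79) is clearly the best match:
logits = torch.full((1, 2, 80), -5.0)
logits[..., 79] = 5.0

# BEFORE the fix: slicing silently drops class 79.
before = F.softmax(logits, dim=-1)[:, :, :-1]
print(before.shape[-1])   # 79 -- class 79 has been removed and cannot win

# AFTER the fix: all 80 classes participate, and class 79 is correctly selected.
after = F.softmax(logits, dim=-1)
print(after.argmax(-1))   # tensor([[79, 79]])
```

With the sliced scores, queries whose true best class is the last one are forced onto some other label, which matches the degraded `use_focal_loss=False` results reported above.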