IDEA-Research / GroundingDINO

[ECCV 2024] Official implementation of the paper "Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection"
https://arxiv.org/abs/2303.05499
Apache License 2.0

Text Prompts with Multiple Object Categories: Concepts are combined #85

Closed egeozguroglu closed 5 months ago

egeozguroglu commented 1 year ago

Hi, thanks for the amazing work. I've been using GroundingDINO + SAM for my research, and would like to query for multiple object categories for my use case, e.g. "jug . onion . chair . toaster . wire . counter . glass . oil . potato . package ." (as suggested here and on the GroundedSAM repo).

Unfortunately, when multiple object classes are added to the prompt as suggested, GroundingDINO makes predictions with some categories combined. I initially encountered this error in my own scripts and was able to reproduce it with your Huggingface Spaces demo. See below.

Detection Prompt: "jug . onion . chair . toaster . wire . counter . glass . oil . potato . package ."

Input image: image

Prediction Output: image

In this case, "glass" and "oil" were combined into "glass oil", which is not the desired behavior. Would you have any insights on a quick solution? I ultimately want to detect 300 object classes with one prompt, so resolving this is essential.

Dwrety commented 1 year ago

The tokenizer keeps a mapping from tokens back to the input text. You can manually separate those tokens.

egeozguroglu commented 1 year ago

Thanks @Dwrety. Do you have further insight on how that separation can be done?

Dwrety commented 1 year ago

Sure. First of all, you need to check whether the two words are embedded as a single token or as separate tokens. I assume they are separate tokens.

Then check this code from utils.py:

from typing import Dict

import torch
from transformers import AutoTokenizer


def get_phrases_from_posmap(
    posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer):
    assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
    if posmap.dim() == 1:
        # collect every token position the box responded to
        non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
        token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
        return tokenizer.decode(token_ids)
    else:
        raise NotImplementedError("posmap must be 1-dim")

Then find the token_ids of "glass oil" and manually split it into two. The issue happens because the posmap has multiple responses to the input text, and PyTorch's nonzero collects all of them.
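The behavior can be reproduced without the model. Below is a toy illustration (plain Python; `StubTokenizer` is a stand-in for the real `AutoTokenizer`, not GroundingDINO code) showing how two True posmap positions for one box decode into a single combined phrase:

```python
# StubTokenizer mimics tokenizer.decode for three single-word tokens.
class StubTokenizer:
    vocab = {0: "glass", 1: "oil", 2: "jug"}

    def decode(self, ids):
        return " ".join(self.vocab[i] for i in ids)

tokenized = {"input_ids": [0, 1, 2]}
posmap = [True, True, False]  # one box responds to both "glass" and "oil"

# Same logic as get_phrases_from_posmap: nonzero positions -> token ids -> text.
non_zero_idx = [i for i, v in enumerate(posmap) if v]
token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
phrase = StubTokenizer().decode(token_ids)
print(phrase)  # -> glass oil
```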

This is simply a hack; you might want to change the prompt to [glass. cooking oil] to avoid the confusion.
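The manual-split hack can be sketched as post-processing: compare the decoded phrase against the known prompt categories and break it apart. `CATEGORIES` and `split_combined_phrase` below are illustrative names, not part of GroundingDINO's API:

```python
# The categories from the original prompt.
CATEGORIES = ["jug", "onion", "chair", "toaster", "wire",
              "counter", "glass", "oil", "potato", "package"]

def split_combined_phrase(phrase, categories):
    """Split a combined phrase like 'glass oil' back into prompt categories."""
    found = [word for word in phrase.split() if word in categories]
    return found or [phrase]  # fall back to the raw phrase if nothing matches

print(split_combined_phrase("glass oil", CATEGORIES))  # -> ['glass', 'oil']
```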

Baboom-l commented 1 year ago

I think changing get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '') to get_phrases_from_posmap(logit == torch.max(logit), tokenized, tokenizer).replace('.', '') can solve this problem. To be honest, I don't see what benefit text_threshold gives in the inference script.
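The idea is to replace the threshold mask with an argmax mask, so each box maps to exactly one token position. A plain-Python sketch of what the torch expression `logit == torch.max(logit)` computes (note that ties keep every max position, just as the torch version does):

```python
def posmap_from_max(logits):
    """Boolean mask that is True only at the maximum-logit position(s)."""
    m = max(logits)
    return [x == m for x in logits]

def posmap_from_threshold(logits, text_threshold):
    """The original behavior: True at every position above the threshold."""
    return [x > text_threshold for x in logits]

logits = [0.41, 0.38, 0.05, 0.02]
print(posmap_from_threshold(logits, 0.25))  # -> [True, True, False, False]  (combines two tokens)
print(posmap_from_max(logits))              # -> [True, False, False, False] (single token)
```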

Glisten-5481 commented 1 year ago

@Baboom-l I adopted this approach, but found that it keeps the result with the highest score even when that score is very low, producing meaningless labels. I hope to receive more help.

Baboom-l commented 1 year ago

@Glisten-5481 Because max is used, the original text_threshold no longer matters; you should increase the box_threshold to get high-confidence results.

Baboom-l commented 1 year ago

In fact, I found that the direct max approach is also flawed: during tokenization a word may be split into two tokens, so only half a word gets output.
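With a WordPiece tokenizer (as used by BERT), a word like "toaster" can be split into "toast" / "##er", and an argmax that lands on one piece decodes half a word. One possible repair is to merge "##" continuation pieces back into whole words before decoding; `merge_wordpieces` below is an illustrative helper, not GroundingDINO code:

```python
def merge_wordpieces(tokens):
    """Join WordPiece continuation tokens ('##...') onto the preceding word."""
    words = []
    for tok in tokens:
        if tok.startswith("##") and words:
            words[-1] += tok[2:]  # glue the continuation onto the last word
        else:
            words.append(tok)
    return words

print(merge_wordpieces(["toast", "##er", "jug"]))  # -> ['toaster', 'jug']
```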

Glisten-5481 commented 1 year ago
@staticmethod
def find_index(string, lst):
    # if we meet a string like "lake river", only keep "lake"
    # this is a hack implementation for visualization which will be updated in the future
    string = string.lower().split()[0]
    for i, s in enumerate(lst):
        if string in s.lower():
            return i
    return -1

They have provided the above solutions in Grounded-SAM, but I don't think they are very practical.
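For reference, here is how that helper behaves on a combined phrase: it keeps only the first word and matches it against the class list, so "glass oil" resolves to the index of "glass" and the "oil" detection is silently dropped (the `classes` list below is illustrative):

```python
classes = ["jug", "glass", "oil"]

def find_index(string, lst):
    # only the first word of a combined phrase is kept
    string = string.lower().split()[0]
    for i, s in enumerate(lst):
        if string in s.lower():
            return i
    return -1

print(find_index("glass oil", classes))  # -> 1 (index of "glass"; "oil" is lost)
print(find_index("banana", classes))    # -> -1 (no match)
```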

fwarmuth commented 9 months ago

Any further approaches to get multi-class prompts working without multiple forward passes?

kyrareinert commented 5 months ago

Setting the parameter remove_combined=True in Grounding DINO's predict() function solved the issue for us.

The predict() function can be imported from groundingdino.util.inference