castorini / daam

Diffusion attentive attribution maps for interpreting Stable Diffusion.
MIT License

Fix compute_token_merge_indices (again) #24

Closed sALTaccount closed 1 year ago

sALTaccount commented 1 year ago

While testing the code I submitted in my previous PR some more, I was able to break it, so I replaced it with a much simpler and more elegant solution. The new solution can search for anything that can be split into tokens, not just single words. For example, you can now search for amane kanata, which previously failed because it contains a space. It still avoids matching things that don't belong: searching for hair won't match hairclip. I also re-added the exception that is raised when you search for something that isn't in the prompt.
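For anyone reading along, the matching idea can be sketched roughly as below. This is not the code from this PR, only a minimal illustration: `find_token_span_indices` and `toy_tokenize` are hypothetical names, and whitespace splitting is an assumed stand-in for the real tokenizer.

```python
from typing import List


def toy_tokenize(text: str) -> List[str]:
    # Hypothetical stand-in for the real tokenizer: lowercase, split on whitespace.
    return text.lower().split()


def find_token_span_indices(prompt_tokens: List[str], query_tokens: List[str]) -> List[int]:
    """Return the sorted indices of every contiguous run of query_tokens in prompt_tokens."""
    if not query_tokens:
        raise ValueError("search term produced no tokens")
    n = len(query_tokens)
    found = set()
    for start in range(len(prompt_tokens) - n + 1):
        # Compare whole tokens, so 'hair' can never match inside 'hairclip'.
        if prompt_tokens[start:start + n] == query_tokens:
            found.update(range(start, start + n))
    if not found:
        raise ValueError("search term not found in prompt")
    return sorted(found)


prompt = "amane kanata with long hair and a hairclip"
tokens = toy_tokenize(prompt)
print(find_token_span_indices(tokens, toy_tokenize("amane kanata")))  # [0, 1]
print(find_token_span_indices(tokens, toy_tokenize("hair")))          # [4]
```

Because the comparison is done on whole token sequences, a multi-token search term works the same way as a single word, and a term that is missing from the prompt raises a `ValueError` in this sketch.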

I also wrote a small piece of code (not included in this commit) to check what is being selected; it might be helpful if you want to test the commit. It surrounds everything that is selected with [brackets]:

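    # Debug helper. Assumes these are already in scope: `tokens` (the prompt's
    # token strings), `merge_idxs` (the selected indices), `word` (the search term).
    selected = ''
    opened = False
    for idx, token in enumerate(tokens):
        if idx in merge_idxs and not opened:
            # Entering a selected run: open a bracket.
            opened = True
            selected += '['
        elif idx not in merge_idxs and opened:
            # Leaving a selected run: close the bracket.
            opened = False
            selected += ']'
        selected += token
    if opened:
        # Close the bracket when the selection runs to the last token.
        selected += ']'
    print(f'Search term: {word}\nSelected: {selected}')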