
Perfusion - Pytorch

Implementation of Key-Locked Rank One Editing, from Nvidia AI. Project page

The selling point of this paper is its extremely low number of extra parameters per added concept, down to ~100KB.
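As a rough sanity check on that number, here is a hedged back-of-the-envelope count. The projection dims match the usage example below, but the per-projection storage (one input direction plus one output direction) and the count of wrapped projections are assumptions for illustration, not figures from the paper.

# back-of-the-envelope estimate, not taken from the paper
dim_text, dim_inner = 768, 320                 # CLIP text dim -> hypothetical cross attention inner dim

# a rank-1 update stores one input direction and one output direction
params_per_projection = dim_text + dim_inner   # 1088

num_projections = 32                           # hypothetical count of wrapped key/value projections in the unet

total_kb = params_per_projection * num_projections * 4 / 1024   # float32 bytes -> KB
print(total_kb)                                # ~136 KB, same order of magnitude as the claimed ~100KB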

It seems they successfully applied the rank-1 editing technique from a memory editing paper for LLMs, with a few improvements. They also identified that the keys determine the "where" of the new concept, while the values determine the "what", and propose local / global key-locking to a superclass concept (while learning the values).
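For intuition, here is a minimal sketch of a generic rank-1 weight edit, the basic idea those papers build on; this is not the exact gated update rule from the Perfusion paper.

import torch

W = torch.randn(320, 768)       # frozen projection weight
k = torch.randn(768)            # input (key) direction for the concept
v_target = torch.randn(320)     # desired output (value) for that input

# rank-1 update so the edited weight maps k to v_target,
# while directions orthogonal to k are left untouched
W_edited = W + torch.outer(v_target - W @ k, k) / k.dot(k)

assert torch.allclose(W_edited @ k, v_target, atol = 1e-4)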

For researchers out there, if this paper checks out, the tools in this repository should work for any other text-to-<insert modality> network that uses cross attention conditioning. Just a thought.

Appreciation

Install

$ pip install perfusion-pytorch

Usage

import torch
from torch import nn

from perfusion_pytorch import Rank1EditModule

to_keys = nn.Linear(768, 320, bias = False)
to_values = nn.Linear(768, 320, bias = False)

wrapped_to_keys = Rank1EditModule(
    to_keys,
    is_key_proj = True
)

wrapped_to_values = Rank1EditModule(
    to_values
)

text_enc = torch.randn(4, 77, 768)                  # regular input
text_enc_with_superclass = torch.randn(4, 77, 768)  # init_input in algorithm 1, for key-locking
concept_indices = torch.randint(0, 77, (4,))        # index where the concept or superclass concept token is in the sequence
key_pad_mask = torch.ones(4, 77).bool()             # not used below; would be passed to cross attention as the key padding mask

keys = wrapped_to_keys(
    text_enc,
    concept_indices = concept_indices,
    text_enc_with_superclass = text_enc_with_superclass,
)

values = wrapped_to_values(
    text_enc,
    concept_indices = concept_indices,
    text_enc_with_superclass = text_enc_with_superclass,
)

# after much training ...

wrapped_to_keys.eval()
wrapped_to_values.eval()

keys = wrapped_to_keys(text_enc)

values = wrapped_to_values(text_enc)
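For orientation, here is a hedged sketch of what the elided training could look like. diffusion_loss is a hypothetical stand-in for the denoising objective of the surrounding text-to-image model; only the standard torch.optim API is assumed.

# optimize only the parameters the wrappers leave trainable
optimizer = torch.optim.Adam(
    [p for m in (wrapped_to_keys, wrapped_to_values) for p in m.parameters() if p.requires_grad],
    lr = 1e-4
)

for _ in range(1000):
    keys = wrapped_to_keys(text_enc, concept_indices = concept_indices, text_enc_with_superclass = text_enc_with_superclass)
    values = wrapped_to_values(text_enc, concept_indices = concept_indices, text_enc_with_superclass = text_enc_with_superclass)

    loss = diffusion_loss(keys, values)   # stand-in for the real objective
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()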

The repository also contains an EmbeddingWrapper that makes it easy to train on a new concept (and eventually run inference with multiple concepts).

import torch
from torch import nn

from perfusion_pytorch import EmbeddingWrapper

embed = nn.Embedding(49408, 512) # open clip embedding, somewhere in the module tree of stable diffusion

# wrap it, and it will automatically create a new concept for learning, based on the superclass embed string

wrapped_embed = EmbeddingWrapper(
    embed,
    superclass_string = 'dog'
)

# now just pass in your prompts with the superclass id

embeds_with_new_concept, embeds_with_superclass, embed_mask, concept_indices = wrapped_embed([
    'a portrait of dog',
    'dog running through a green field',
    'a man walking his dog'
]) # (3, 77, 512), (3, 77, 512), (3, 77), (3,)

# now pass both embeds through clip text transformer
# the embed_mask needs to be passed to the cross attention as key padding mask
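To make that last comment concrete, here is a minimal sketch using torch's built-in nn.MultiheadAttention as a stand-in for the actual cross attention layers; it skips the text transformer step, so the shapes are illustrative only.

cross_attn = nn.MultiheadAttention(512, 8, batch_first = True)   # stand-in cross attention

image_tokens = torch.randn(3, 1024, 512)                         # hypothetical unet feature tokens

attended, _ = cross_attn(
    image_tokens,                   # queries come from the image side
    embeds_with_new_concept,        # keys from the text side
    embeds_with_new_concept,        # values from the text side
    key_padding_mask = ~embed_mask  # assuming embed_mask is True on valid tokens; MultiheadAttention masks where True
)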

If you can identify the CLIP instance within the stable diffusion instance, you can also pass it directly to OpenClipEmbedWrapper, whose forward pass returns everything the cross attention layers need.

ex.

from perfusion_pytorch import OpenClipEmbedWrapper

texts = [
    'a portrait of dog',
    'dog running through a green field',
    'a man walking his dog'
]

wrapped_clip_with_new_concept = OpenClipEmbedWrapper(
    stable_diffusion.path.to.clip,
    superclass_string = 'dog'
)

text_enc, superclass_enc, mask, indices = wrapped_clip_with_new_concept(texts)

# (3, 77, 512), (3, 77, 512), (3, 77), (3,)

Todo

Citations

@article{Tewel2023KeyLockedRO,
    title   = {Key-Locked Rank One Editing for Text-to-Image Personalization},
    author  = {Yoad Tewel and Rinon Gal and Gal Chechik and Yuval Atzmon},
    journal = {ACM SIGGRAPH 2023 Conference Proceedings},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:258436985}
}
@inproceedings{Meng2022LocatingAE,
    title   = {Locating and Editing Factual Associations in GPT},
    author  = {Kevin Meng and David Bau and Alex Andonian and Yonatan Belinkov},
    booktitle = {Neural Information Processing Systems},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:255825985}
}