ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/

Requesting support for inputs_embeds for tracer.invoke() #85

Closed by HuFY-dev 7 months ago

HuFY-dev commented 7 months ago

I'm using the LanguageModel class to wrap LLaVA, a vision-language model. During execution of

with tracer.invoke(inputs)

the call at nnsight/contexts/Invoker.py#L55:

self.inputs, batch_size = self.tracer._model._prepare_inputs(
  *self.inputs, **self.kwargs
)

results in errors. For reference, LLaVA's forward typically takes either

forward(
  input_ids: torch.LongTensor,
  images: Optional[torch.FloatTensor],
  **kwargs
)

or

forward(
  inputs_embeds,
  **kwargs
)

Could you add support for accepting inputs_embeds as an alternative to inputs, so that I can write code like the following?

with tracer.invoke(
    inputs=None, 
    inputs_embeds=inputs_embeds,
):
cadentj commented 7 months ago

Here's an example wrapper to get models with different .forward() args working with NNsight.

from transformer_lens import HookedTransformer

def transformerlens_to_nnsight_wrapper(original_method):
    def wrapper(self, *args, **kwargs):
        # Remap the Hugging Face-style input_ids kwarg to HookedTransformer's
        # expected `input`, and strip labels/attention_mask before forwarding.
        if "input_ids" in kwargs:
            kwargs["input"] = kwargs.pop("input_ids")
        _ = kwargs.pop("labels", None)
        _ = kwargs.pop("attention_mask", None)
        return original_method(self, *args, **kwargs)
    return wrapper

# Bind the wrapped methods to only this instance
tl_model.forward = transformerlens_to_nnsight_wrapper(HookedTransformer.forward).__get__(tl_model, HookedTransformer)
tl_model.generate = transformerlens_to_nnsight_wrapper(HookedTransformer.generate).__get__(tl_model, HookedTransformer)

# Also set a few attributes so that it works with NNsight
tl_model.device = tl_model.cfg.device
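
For example (purely illustrative; tl_model here is the wrapped HookedTransformer from above):

import torch

tokens = tl_model.to_tokens("Hello world")

# input_ids is remapped to `input` and attention_mask is stripped by the
# wrapper before the call reaches HookedTransformer.forward.
logits = tl_model.forward(input_ids=tokens, attention_mask=torch.ones_like(tokens))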

I'll look into a solution to make this easier and better documented.

HuFY-dev commented 7 months ago

From my understanding of your code, it doesn't solve the problem that the image never makes it into the model's input. LLaVA only works if the input contains either 1. both input_ids and images, or 2. inputs_embeds, because image tokens have no entry in the embedding matrix; their embeddings are computed by the image encoder followed by a linear projection.

JadenFiotto-Kaufman commented 7 months ago

@HuFY-dev I think the right thing to do here is to subclass LanguageModel and implement your own virtual methods to handle the inputs you want to pass. It shouldn't be too much effort.
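
An untested sketch of that idea (class and argument names are illustrative, and I'm assuming _prepare_inputs returns (prepared inputs, batch size) as the call site in Invoker.py above suggests; the real signature may differ):

from nnsight import LanguageModel

class LlavaLanguageModel(LanguageModel):
    """Accept inputs_embeds (or input_ids plus images) at invoke time."""

    def _prepare_inputs(self, *inputs, inputs_embeds=None, images=None, **kwargs):
        if inputs_embeds is not None:
            # Skip tokenization and take the batch size from the embeddings.
            return {"inputs_embeds": inputs_embeds, **kwargs}, inputs_embeds.shape[0]

        # Otherwise fall back to the normal text path and pass images along.
        prepared, batch_size = super()._prepare_inputs(*inputs, **kwargs)
        if images is not None:
            prepared["images"] = images
        return prepared, batch_size

In principle, with tracer.invoke(inputs_embeds=embeds): would then hit this override instead of the tokenizer path.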

HuFY-dev commented 7 months ago

I'll try that. I just think inputs_embeds is a standard argument of the .forward() method on Transformers pretrained model classes (not only VLMs but most LLMs as well), so supporting inputs_embeds would make the code more flexible. Maybe I'll write some code on my side and submit a PR if things work properly.
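
For reference, the standard Hugging Face .forward() already accepts it. A minimal example (gpt2 used purely as a stand-in model):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

ids = tok("Hello world", return_tensors="pt").input_ids
embeds = model.get_input_embeddings()(ids)  # precompute the token embeddings ourselves

# Transformers models accept inputs_embeds in place of input_ids.
out = model(inputs_embeds=embeds)
print(out.logits.shape)  # (1, seq_len, vocab_size)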