MadryLab / trak

A fast, effective data attribution method for neural networks in PyTorch
https://trak.csail.mit.edu/
MIT License

Order of Arguments in LLM can mismatch #46

Closed awe2 closed 12 months ago

awe2 commented 1 year ago

This is not a bug if the text model is wrapped as in the qnli example, but it is something to possibly make users aware of or emphasize:

In lines 390-394 of modelout_functions.TextClassificationModelOutput:

        logits = ch.func.functional_call(model,
                                         (weights, buffers),
                                         args=(input_id.unsqueeze(0),
                                               token_type_id.unsqueeze(0),
                                               attention_mask.unsqueeze(0)))

The args tuple passes an ordered set of positional arguments to model, so the correct order depends on the signature of model.forward. For example, the forward signature of BERT-base from the transformers library is:

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels = 2,   
    output_attentions = False,
    output_hidden_states = False,
)

model.forward?

    >>>input_ids: Optional[torch.Tensor] = None,
    >>>attention_mask: Optional[torch.Tensor] = None,
    >>>token_type_ids: Optional[torch.Tensor] = None,
    >>>position_ids: Optional[torch.Tensor] = None,
    >>>head_mask: Optional[torch.Tensor] = None,
    >>>inputs_embeds: Optional[torch.Tensor] = None,
    >>>labels: Optional[torch.Tensor] = None,
    >>>output_attentions: Optional[bool] = None,
    >>>output_hidden_states: Optional[bool] = None,
    >>>return_dict: Optional[bool] = None,

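The order in args (input_ids, token_type_ids, attention_mask) doesn't match the order in the model's forward signature (input_ids, attention_mask, token_type_ids), so the token type IDs and the attention mask are swapped. In my use case, this mismatch fails quietly, producing features that aren't correct.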
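To make the silent swap concrete, here is a minimal sketch in plain Python. The `forward` function is a hypothetical stand-in whose first three parameters are in the same order as the BERT signature quoted above, and plain lists stand in for tensors. It is an illustration only, not TRAK or transformers code.

```python
import inspect

# Hypothetical stand-in for model.forward: same order as the first three
# parameters of the BERT signature quoted above.
def forward(input_ids=None, attention_mask=None, token_type_ids=None):
    # Report which value ended up bound to which parameter.
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }

# Toy example: plain lists standing in for tensors.
input_id = [101, 2023, 102, 0]
token_type_id = [0, 0, 0, 0]
attention_mask = [1, 1, 1, 0]

# Positional call in the order used by the args tuple
# (input_id, token_type_id, attention_mask).
positional = forward(input_id, token_type_id, attention_mask)

# Keyword call: the order in which arguments are written no longer matters.
keyword = forward(
    input_ids=input_id,
    token_type_ids=token_type_id,
    attention_mask=attention_mask,
)

# No exception is raised in either case, but the positional call has bound
# the token type IDs to the attention_mask parameter.
swapped = positional["attention_mask"] == token_type_id

# The expected positional order can be read off the signature.
param_order = list(inspect.signature(forward).parameters)
```

Because both values are sequences of the same length, nothing fails at call time. Checking `param_order` against the order of the args tuple is one way to catch the mismatch before any features are computed.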

In the example provided for text models, you create a wrapper that correctly sets the order of the arguments. I'm wondering if there is a more robust way to provide a batch of data to an LLM, e.g. as a dictionary passed via keyword arguments, which would fix the ordering issue and hopefully match up with the transformers.pipeline framework?
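A rough sketch of the keyword-based idea, assuming a PyTorch 2.x install where `torch.func.functional_call` accepts a `kwargs` dictionary. `ToyTextClassifier` and the `batch` dictionary are hypothetical stand-ins made up for this illustration; this is not necessarily how TRAK implements it.

```python
import torch
from torch import nn
from torch.func import functional_call


class ToyTextClassifier(nn.Module):
    """Hypothetical stand-in with BERT-like forward parameter names."""

    def __init__(self, vocab_size=32, hidden=8, num_labels=2):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, hidden)
        self.seg = nn.Embedding(2, hidden)
        self.head = nn.Linear(hidden, num_labels)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        h = self.tok(input_ids) + self.seg(token_type_ids)
        mask = attention_mask.unsqueeze(-1).to(h.dtype)
        # Mean-pool over the tokens that are not masked out.
        pooled = (h * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        return self.head(pooled)


model = ToyTextClassifier()
weights = dict(model.named_parameters())
buffers = dict(model.named_buffers())

# One example as a dictionary keyed by the forward parameter names.
# The key order deliberately differs from the signature order.
batch = {
    "input_ids": torch.tensor([5, 7, 9, 0]),
    "token_type_ids": torch.tensor([0, 0, 1, 0]),
    "attention_mask": torch.tensor([1, 1, 1, 0]),
}

# Add the batch dimension and pass everything by name.
kwargs = {name: t.unsqueeze(0) for name, t in batch.items()}
logits = functional_call(model, (weights, buffers), args=(), kwargs=kwargs)

# Sanity check against an ordinary keyword call on the module.
reference = model(**kwargs)
```

Since each tensor is matched to its parameter by name, the result does not depend on how the model orders its forward signature, and a dictionary in this shape is close to what a Hugging Face tokenizer returns.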

kristian-georgiev commented 1 year ago

Good catch, will address this in the next version of TRAK. I like the idea of using keyword arguments that we pass directly.

kristian-georgiev commented 12 months ago

Resolved by #49.

https://github.com/MadryLab/trak/blob/5cbe5286f55c52a868f2baabd9eee91be2b98750/trak/modelout_functions.py#L382-L397