MadryLab / trak

A fast, effective data attribution method for neural networks in PyTorch
https://trak.csail.mit.edu/
MIT License

Regarding the model input during `traker.featurize` #52

Closed: Jiaxin-Wen closed this issue 10 months ago

Jiaxin-Wen commented 10 months ago

Take the QNLI example in TRAK. At https://github.com/MadryLab/trak/blob/76b13ca55f1a16a243a23aba312cf6aad57b84d0/examples/qnli.py#L138 we specify a batch size of 16. However, when I inspect input_ids at https://github.com/MadryLab/trak/blob/76b13ca55f1a16a243a23aba312cf6aad57b84d0/examples/qnli.py#L74, its type is GradTrackingTensor and its shape is torch.Size([1, 128]) rather than torch.Size([16, 128]).

How should I interpret this?

Jiaxin-Wen commented 10 months ago

Oh, I see that input_ids is wrapped in a BatchedTensor. Problem solved.

kristian-georgiev commented 10 months ago

Yep, this is vmap parallelizing over the batch dimension.
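To see what "vmap parallelizing over the batch dimension" means in practice, here is a minimal sketch using plain `torch.func.vmap` (not TRAK's internal code). The function `per_sample_fn` and the `observed` dict are hypothetical, added only for illustration. Inside the mapped function the batch dimension is hidden, so each example looks like a tensor of shape `(seq_len,)`; calling `.unsqueeze(0)` adds a dummy batch dimension of one. The `[1, 128]` shape reported in the question is consistent with such a per-example tensor that has had a batch dimension of one added back, and the GradTrackingTensor wrapper presumably comes from `grad` being applied on top of `vmap`.

```python
import torch
from torch.func import vmap

observed = {}  # hypothetical: records shapes seen inside the mapped function

def per_sample_fn(input_ids):
    # vmap hides the batch dimension: each call sees one example of shape (seq_len,)
    observed["inside_vmap"] = tuple(input_ids.shape)
    # Add a dummy batch dimension so code expecting (batch, seq_len) still works
    batched = input_ids.unsqueeze(0)
    observed["after_unsqueeze"] = tuple(batched.shape)
    return batched.float().sum()

batch = torch.randint(0, 1000, (16, 128))  # 16 sequences of length 128
out = vmap(per_sample_fn)(batch)
print(observed)   # {'inside_vmap': (128,), 'after_unsqueeze': (1, 128)}
print(out.shape)  # torch.Size([16])
```

The function body is traced once with a batched tensor, but the result still has one entry per example, which is why the outer batch size of 16 reappears in `out`.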

Jiaxin-Wen commented 10 months ago

Is there a proper way to get batch_size and sequence_length from input_ids once it has been parallelized? Some models in transformers still use code such as `batch_size, seq_len = input_ids.shape[:2]`.

Jiaxin-Wen commented 10 months ago

For example, when I use LLaMA as the backbone model with transformers==4.30.2, here are the relevant lines of its source code:

https://github.com/huggingface/transformers/blob/66fd3a8d626a32989f4569260db32785c6cbf42a/src/transformers/models/llama/modeling_llama.py#L508

https://github.com/huggingface/transformers/blob/66fd3a8d626a32989f4569260db32785c6cbf42a/src/transformers/models/llama/modeling_llama.py#L64

https://github.com/huggingface/transformers/blob/66fd3a8d626a32989f4569260db32785c6cbf42a/src/transformers/models/llama/modeling_llama.py#L67

I'm not sure how to adapt input_ids, which is already wrapped in a BatchedTensor, to these lines of code.

kristian-georgiev commented 10 months ago

You can think of the BatchedTensor as an ordinary torch tensor with a batch dimension equal to one. If there's no batch dimension, you can call `.unsqueeze(0)` to create a "dummy" batch dimension. Closing for now, but feel free to re-open if this doesn't resolve the issue.
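A minimal sketch of the suggested workaround, assuming a model whose `forward` unpacks `input_ids.shape[:2]` the way the transformers code above does. `ToyModel` and `per_sample_loss` are hypothetical stand-ins, not TRAK or transformers code; the per-example gradient pattern (`vmap(grad(...))` with `functional_call`) follows the standard `torch.func` recipe. Inside `vmap`, each example is 1-D, so `.unsqueeze(0)` restores the `(batch, seq_len)` layout the model expects, with `batch_size == 1`.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

class ToyModel(nn.Module):
    # Hypothetical stand-in for a model that reads (batch_size, seq_len) from input_ids
    def __init__(self, vocab=100, dim=8):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, 2)

    def forward(self, input_ids):
        batch_size, seq_len = input_ids.shape[:2]  # raises ValueError on a 1-D tensor
        h = self.emb(input_ids).sum(dim=1) / seq_len  # mean over tokens: (batch_size, dim)
        return self.head(h).reshape(batch_size, -1)  # (batch_size, 2)

model = ToyModel()
params = {k: v.detach() for k, v in model.named_parameters()}

def per_sample_loss(params, input_ids, label):
    # Under vmap, input_ids has shape (seq_len,) and label is 0-d;
    # unsqueeze(0) adds the dummy batch dimension of size one
    logits = functional_call(model, params, (input_ids.unsqueeze(0),))
    return nn.functional.cross_entropy(logits, label.unsqueeze(0))

input_ids = torch.randint(0, 100, (16, 12))  # 16 sequences of length 12
labels = torch.randint(0, 2, (16,))
# Map over the batch dimension of input_ids and labels, sharing params across examples
per_sample_grads = vmap(grad(per_sample_loss), in_dims=(None, 0, 0))(params, input_ids, labels)
print(per_sample_grads["head.weight"].shape)  # torch.Size([16, 2, 8])
```

Without the `.unsqueeze(0)`, the line `batch_size, seq_len = input_ids.shape[:2]` would see a single dimension and fail to unpack, which is the failure mode described in the question.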