MadryLab / trak

A fast, effective data attribution method for neural networks in PyTorch
https://trak.csail.mit.edu/
MIT License

RuntimeError: input.dtype() == torch::kFloat16 || input.dtype() == torch::kFloat32 error on QNLI #48

Closed: rishabluthra closed this issue 10 months ago

rishabluthra commented 11 months ago

Hey, I am trying out the BERT QNLI example from the repo on an A100 on Colab and keep running into this RuntimeError when trying to featurize:

/usr/local/lib/python3.10/dist-packages/trak/projectors.py in project(self, grads, model_id)
    295 
    296         try:
--> 297             result = fn(grads, self.proj_dim, self.seed + int(1e4) * model_id, self.num_sms)
    298         except RuntimeError as e:
    299             if str(e) == 'CUDA error: too many resources requested for launch\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n':  # noqa: E501

RuntimeError: input.dtype() == torch::kFloat16 || input.dtype() == torch::kFloat32 INTERNAL ASSERT FAILED at "fast_jl.cu":37, please report a bug to PyTorch. input must be fp16 or fp32

while running this code sample:

for batch in tqdm(loader_train, desc='Featurizing..'):
    # process batch into compatible form for TRAKer TextClassificationModelOutput
    batch = process_batch(batch)
    batch = [x.cuda() for x in batch]
    traker.featurize(batch=batch, num_samples=batch[0].shape[0])
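
For reference, a quick dtype check on the processed batch (plain torch; the names just mirror the tuple order process_batch returns) shows what the projector ends up receiving:

for name, x in zip(['input_ids', 'token_type_ids', 'attention_mask', 'labels'], batch):
    print(name, x.dtype)  # all four are torch.int64 with the default HF tokenization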

It seems the batch, when processed, is returned as ints, so I ended up modifying the process function to

def process_batch(batch):
    # cast only attention_mask to float32 (ch is torch)
    return batch['input_ids'], batch['token_type_ids'], batch['attention_mask'].to(dtype=ch.float32), batch['labels']

but that still hits the same error.

Was just wondering if you ever ran into this issue when trying it out on QNLI, or if I am missing something. Thanks!

kristian-georgiev commented 10 months ago

The reason is, as pointed out in https://github.com/MadryLab/trak/issues/44, that grads gets initialized with the dtype of batch[0]:

grads = torch.empty(size=(batch[0].shape[0], self.num_params),
                    dtype=batch[0].dtype,
                    device=batch[0].device)

Instead, it should be initialized as float16/float32.
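
A minimal sketch of that change (not the exact commit, which is linked below):

# sketch: always allocate grads in fp32 so the fast_jl projector accepts it,
# instead of inheriting the integer dtype of batch[0] (the input_ids)
grads = torch.empty(size=(batch[0].shape[0], self.num_params),
                    dtype=torch.float32,
                    device=batch[0].device)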

Fixed in https://github.com/MadryLab/trak/commit/fb5e0a495a859359b382b6b2a066800f5f7f66c7 (on branch https://github.com/MadryLab/trak/tree/0.2.2 for now; it will be merged into main in the next release).
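
Until that release, if you want the fix right away, installing straight from that branch should work:

pip install git+https://github.com/MadryLab/trak.git@0.2.2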