MadryLab / trak

A fast, effective data attribution method for neural networks in PyTorch
https://trak.csail.mit.edu/
MIT License

Gradient_Computers grad.dtype shouldn't inherit from batch[0].dtype? #44

Closed · awe2 closed this issue 1 year ago

awe2 commented 1 year ago

In lines 105-107 of trak/gradient_computers.py:

    grads = torch.empty(size=(batch[0].shape[0], self.num_params),
                        dtype=batch[0].dtype,
                        device=batch[0].device)

The grads tensor inherits its dtype from batch[0], but fast_jl.cu asserts that grads has dtype torch.float32 or torch.float16. If batch[0] is not torch.float32 or torch.float16, then CudaProjector fails. (For example, the input to a language model might be tokenized words with dtype torch.uint32.)
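
For illustration, here is a minimal sketch of the mismatch outside of TRAK (the shapes, num_params value, and the integer dtype are placeholders, not the library's actual values):

    import torch

    # A language-model batch of token IDs: an integer tensor, not float.
    batch = (torch.randint(0, 50_000, (8, 128), dtype=torch.int64),)

    num_params = 1_000
    # Allocating grads with batch[0].dtype produces an integer buffer here,
    # which the fast_jl CUDA kernels reject, since they only accept
    # torch.float32 or torch.float16.
    grads = torch.empty(size=(batch[0].shape[0], num_params),
                        dtype=batch[0].dtype,
                        device=batch[0].device)
    print(grads.dtype)  # torch.int64, incompatible with CudaProjector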

Possible solution: should grads.dtype instead inherit the user-specified dtype from TRAKer (e.g., the precision set via use_half_precision)?
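
As a rough sketch of that idea (the class and attribute names below, e.g. GradientComputerSketch and grad_dtype, are hypothetical, not TRAK's actual API):

    import torch

    class GradientComputerSketch:
        # Illustrative only: stores the dtype chosen by the user (e.g. via
        # TRAKer's half-precision option) and uses it when allocating grads.
        def __init__(self, num_params: int, grad_dtype: torch.dtype,
                     device: torch.device):
            self.num_params = num_params
            self.grad_dtype = grad_dtype   # e.g. torch.float16 or torch.float32
            self.device = device

        def allocate_grads(self, batch):
            # dtype comes from the user's choice, not from batch[0].dtype,
            # so integer token IDs no longer leak into the grads buffer.
            return torch.empty(size=(batch[0].shape[0], self.num_params),
                               dtype=self.grad_dtype,
                               device=self.device)

    # Usage: a batch of int64 token IDs still yields float16 grads.
    computer = GradientComputerSketch(num_params=1_000,
                                      grad_dtype=torch.float16,
                                      device=torch.device('cpu'))
    token_batch = (torch.randint(0, 50_000, (8, 128), dtype=torch.int64),)
    print(computer.allocate_grads(token_batch).dtype)  # torch.float16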

kristian-georgiev commented 1 year ago

Hi @awe2, good catch! I think your solution should work :) Do you want to submit a PR?

kristian-georgiev commented 1 year ago

Addressed in https://github.com/MadryLab/trak/commit/fb5e0a495a859359b382b6b2a066800f5f7f66c7 (on branch https://github.com/MadryLab/trak/tree/0.2.2 for now; it will be merged into main in the next release).