Closed by awe2 1 year ago.
Hi @awe2, good catch! I think your solution should work :) Do you want to submit a PR?
Addressed in https://github.com/MadryLab/trak/commit/fb5e0a495a859359b382b6b2a066800f5f7f66c7 (on branch https://github.com/MadryLab/trak/tree/0.2.2 for now; it will be merged into main in the next release).
In lines 105-107 of trak/gradient_computers.py, the grads tensor inherits its dtype from the tensor in batch[0], but fast_jl.cu asserts that grads has dtype torch.float32 or torch.float16. If batch[0] is neither torch.float32 nor torch.float16, the CudaProjector fails. For example, the input to a language model might be tokenized words stored as integer token IDs (e.g., torch.int64 or torch.uint32).
Possible solution: should grads take its dtype from the user-specified precision (the use_half_precision option of TRAKer) instead?
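A minimal sketch of the problem and the suggested fix, for illustration only. It does not reproduce TRAK's actual code in gradient_computers.py or fast_jl.cu. The helper names (`allocate_grads_buggy`, `allocate_grads_fixed`, `check_projector_dtype`) are hypothetical, and the check is a Python stand-in for the float32/float16 assertion described above. The sketch also assumes a zero-initialized buffer of shape (batch size, number of parameters).

```python
import torch

# Hypothetical Python stand-in for the dtype assertion in fast_jl.cu:
# only float32 and float16 gradient buffers are accepted.
SUPPORTED_PROJECTOR_DTYPES = (torch.float32, torch.float16)


def check_projector_dtype(grads: torch.Tensor) -> None:
    if grads.dtype not in SUPPORTED_PROJECTOR_DTYPES:
        raise TypeError(f"projector expects float32/float16 grads, got {grads.dtype}")


def allocate_grads_buggy(batch, num_params: int) -> torch.Tensor:
    # Problematic pattern: the buffer's dtype follows the *input* tensor,
    # so integer token IDs produce an integer gradient buffer.
    return torch.zeros(batch[0].shape[0], num_params,
                       dtype=batch[0].dtype, device=batch[0].device)


def allocate_grads_fixed(batch, num_params: int, use_half_precision: bool) -> torch.Tensor:
    # Suggested pattern: the dtype follows the user's precision setting,
    # independent of whatever dtype the inputs happen to have.
    dtype = torch.float16 if use_half_precision else torch.float32
    return torch.zeros(batch[0].shape[0], num_params,
                       dtype=dtype, device=batch[0].device)


if __name__ == "__main__":
    # Token IDs for a batch of 4 sequences of length 8 (integer dtype).
    token_ids = torch.randint(0, 1000, (4, 8), dtype=torch.int64)
    batch = (token_ids,)

    buggy = allocate_grads_buggy(batch, num_params=16)
    try:
        check_projector_dtype(buggy)
    except TypeError as err:
        print("buggy:", err)

    fixed = allocate_grads_fixed(batch, num_params=16, use_half_precision=True)
    check_projector_dtype(fixed)  # does not raise: dtype is float16
    print("fixed:", fixed.dtype, tuple(fixed.shape))
```

Tying the buffer's dtype to the precision setting rather than to the inputs means non-float inputs (token IDs, labels) no longer leak into the gradient buffer, while still respecting the user's choice of half or full precision.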