nasaharvest / presto

Lightweight, Pre-trained Transformers for Remote Sensing Timeseries
https://arxiv.org/abs/2304.14065
MIT License

Move embeddings to the device #30

Closed: mkondratyev85 closed this pull request 6 months ago

mkondratyev85 commented 6 months ago

This PR fixes issue https://github.com/nasaharvest/presto/issues/29.

First of all, thank you for making this wonderful model available to the public!

The problem is that all embeddings created during model initialization are placed on "cpu", even when the device variable imported from the .utils module points to "cuda:0". As a result, the forward pass fails with: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

Moving all embeddings (except the positional embedding) to the device at initialization time fixes this problem.
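A minimal sketch of why this kind of mismatch can happen, assuming the embedding tensors are stored as plain tensor attributes rather than registered parameters or buffers. The classes PlainAttrEncoder and BufferEncoder and the month_table attribute are hypothetical illustrations, not Presto's actual code. nn.Module.to(device) only moves registered parameters and buffers, so a plain tensor attribute stays on the CPU, while one registered with register_buffer follows the rest of the model.

```python
import torch
from torch import nn


class PlainAttrEncoder(nn.Module):
    """Hypothetical module: the embedding table is a plain tensor attribute."""

    def __init__(self, num_tokens: int = 12, dim: int = 4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Plain attribute: not a parameter or buffer, so Module.to() ignores it
        self.month_table = torch.randn(num_tokens, dim)

    def forward(self, month: torch.Tensor) -> torch.Tensor:
        return self.proj(self.month_table[month])


class BufferEncoder(nn.Module):
    """Same hypothetical module, but the table is registered as a buffer."""

    def __init__(self, num_tokens: int = 12, dim: int = 4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Registered buffer: moved by .to(device) and included in state_dict()
        self.register_buffer("month_table", torch.randn(num_tokens, dim))

    def forward(self, month: torch.Tensor) -> torch.Tensor:
        return self.proj(self.month_table[month])


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
plain = PlainAttrEncoder().to(device)
buffered = BufferEncoder().to(device)

# The plain attribute was created without a device argument and stays on the CPU;
# on a GPU machine, mixing it with CUDA weights raises a device-mismatch RuntimeError.
print(plain.month_table.device)     # cpu
print(buffered.month_table.device)  # same device as buffered.proj.weight
print("month_table" in plain.state_dict(), "month_table" in buffered.state_dict())
# False True
```

Registering the table as a buffer (or a parameter) means a single model.to(device) call moves it, which avoids sprinkling .to(device) calls through __init__.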

For some reason, moving pos_embed to the device at init time makes it impossible to load the pretrained model. To work around this, I move the positional embedding to the device at inference time instead. It may not be the cleanest solution, but it is the simplest one I've found.
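An alternative ordering that sidesteps moving individual tensors, sketched under the assumption that the positional embedding is a registered (frozen) nn.Parameter: build the model on the CPU, load the pretrained state dict, and only then call model.to(device), which moves every registered parameter and buffer together. TinyModel and its shapes are hypothetical stand-ins, not the Presto model, and this is not necessarily what #31 does.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a pretrained encoder with a fixed positional table."""

    def __init__(self, seq_len: int = 6, dim: int = 4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Frozen positional embedding kept as a Parameter so it is saved and loaded
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, dim), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim); pos_embed broadcasts over the batch dimension
        return self.proj(x + self.pos_embed[:, : x.shape[1]])


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Stand-in for pretrained weights saved from a CPU model
pretrained_state = TinyModel().state_dict()

model = TinyModel()                       # 1. build on the CPU
model.load_state_dict(pretrained_state)   # 2. load weights while everything is on the CPU
model = model.to(device)                  # 3. move all parameters and buffers at once

x = torch.randn(2, 6, 4, device=device)
out = model(x)
print(out.shape, out.device)
```

When the weights come from a file, torch.load's map_location argument controls which device the loaded tensors land on before load_state_dict copies them into the model.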

mkondratyev85 commented 6 months ago

I'm closing this PR because #31 is a better fix for the problem.