WorldCereal / worldcereal-classification

This repository contains the classification module of the WorldCereal system.
https://esa-worldcereal.org/
MIT License

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! #217

Open emmamusiari opened 2 days ago

emmamusiari commented 2 days ago

I've tried to run the worldcereal_v1_demo_custom_cropland.ipynb notebook in the "notebooks" folder.

When running this cell:

from utils import prepare_training_dataframe

training_dataframe = prepare_training_dataframe(public_df, task_type="cropland")

I get the runtime error below. I tried to fix it locally in presto.py, but immediately afterwards I got another, similar error.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 3
      1 from utils import prepare_training_dataframe
----> 3 training_dataframe = prepare_training_dataframe(public_df, task_type="cropland")

File ~/dev_lx/worldcereal-classification/notebooks/utils.py:530, in prepare_training_dataframe(df, batch_size, task_type, augment, mask_ratio, repeats)
    522 ds = WorldCerealTrainingDataset(
    523     df,
    524     task_type=task_type,
   (...)
    527     repeats=repeats,
    528 )
    529 logger.info("Computing Presto embeddings ...")
--> 530 df = get_training_df(
    531     ds,
    532     presto_model,
    533     batch_size=batch_size,
    534     valid_date_as_token=use_valid_date_token,
    535 )
    537 logger.info("Done.")
    539 return df

File ~/anaconda3/envs/WorldCereal/lib/python3.11/site-packages/worldcereal/train/data.py:288, in get_training_df(dataset, presto_model, batch_size, valid_date_as_token, num_workers)
    285 # Compute Presto embeddings; only feed valid date as token if valid_date_as_token is True
    286 with torch.no_grad():
    287     encodings = (
--> 288         presto_model.encoder(
    289             x_f,
    290             dynamic_world=dw_f.long(),
    291             mask=variable_mask_f,
    292             latlons=latlons_f,
    293             month=month_f,
    294             valid_month=valid_month_f if valid_date_as_token else None,
    295         )
    296         .cpu()
    297         .numpy()
    298     )
    300 # Convert to dataframe
    301 attrs = pd.DataFrame.from_dict(attrs)

File ~/anaconda3/envs/WorldCereal/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/WorldCereal/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/anaconda3/envs/WorldCereal/lib/python3.11/site-packages/presto/presto.py:439, in Encoder.forward(self, x, dynamic_world, latlons, mask, month, valid_month, eval_task)
    436 all_tokens, all_masks = [], []
    438 for channel_group, channel_idxs in self.band_groups.items():
--> 439     tokens = self.eo_patch_embed[channel_group](x[:, :, channel_idxs])
    441     channel_embedding = self.channel_embed(
    442         torch.tensor(self.band_group_to_idx[channel_group]).long().to(device)
    443     )
    445     channel_embedding = repeat(channel_embedding, "d -> b t d", b=x.shape[0], t=x.shape[1])

File ~/anaconda3/envs/WorldCereal/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/WorldCereal/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/anaconda3/envs/WorldCereal/lib/python3.11/site-packages/torch/nn/modules/linear.py:116, in Linear.forward(self, input)
    115 def forward(self, input: Tensor) -> Tensor:
--> 116     return F.linear(input, self.weight, self.bias)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

kvantricht commented 2 hours ago

Thank you for reporting this. It seems the presto-worldcereal code base handles the device inconsistently when the inference demo runs on a system with access to a GPU. The simplest workaround for now is to run the notebook on a system without a GPU, or to disable GPU visibility to torch (alternatively, use the Terrascope option to run on a pre-configured system). Meanwhile, we'll work on a more permanent fix.
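
For reference, one common way to hide the GPU from torch is to set the CUDA_VISIBLE_DEVICES environment variable to an empty string before torch initializes CUDA. The snippet below is only a minimal sketch of that idea, assuming it runs as the very first notebook cell after a kernel restart; the helper names hide_gpus_from_torch and cuda_is_visible are hypothetical and not part of worldcereal or presto.

```python
import os


def hide_gpus_from_torch() -> None:
    """Hide all CUDA devices from libraries that are initialized afterwards.

    Hypothetical helper for illustration. It must run before torch touches
    CUDA, so call it at the top of the notebook after restarting the kernel.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = ""


def cuda_is_visible():
    """Return torch.cuda.is_available(), or None if torch is not installed."""
    try:
        import torch
    except ImportError:
        return None
    return torch.cuda.is_available()


hide_gpus_from_torch()
# With the variable set this early, torch is expected to report no CUDA device,
# so the model and the input tensors should all stay on the CPU.
print("CUDA visible to torch:", cuda_is_visible())
```

If the variable is set only after torch has already initialized CUDA in the same kernel, it may have no effect, which is why restarting the kernel first is the safer route.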