Clay-foundation / model

The Clay Foundation Model (in development)
https://clay-foundation.github.io/model/
Apache License 2.0

Notebooks are broken in main #134

Closed: MaceGrim closed this issue 5 months ago

MaceGrim commented 6 months ago
The following notebooks are currently broken in main:

clay-v0-interpolation.ipynb
clay-v0-reconstruction.ipynb
clay-v0-location-embeddings.ipynb

This was raised in PR https://github.com/Clay-foundation/model/pull/118, but the breakage still came through in the merge.

The breakage appears to come from a change in the encoder. The same cell fails in all three notebooks.

Here's the cell that causes the error:

# Pass the pixels through the encoder & decoder of CLAY
with torch.no_grad():
    # Move data to the model's device
    batch["pixels"] = batch["pixels"].to(model.device)
    batch["timestep"] = batch["timestep"].to(model.device)
    batch["latlon"] = batch["latlon"].to(model.device)

    # Pass pixels, latlon, timestep through the encoder to create encoded patches
    (
        unmasked_patches,
        unmasked_indices,
        masked_indices,
        masked_matrix,
    ) = model.model.encoder(batch)

    # Pass the unmasked_patches through the decoder to reconstruct the pixel space
    pixels = model.model.decoder(unmasked_patches, unmasked_indices, masked_indices)

Here's an excerpt of the error:

File ~/miniforge3/envs/claymodel/lib/python3.11/site-packages/torch/nn/modules/conv.py:456, in Conv2d._conv_forward(self, input, weight, bias)
    452 if self.padding_mode != 'zeros':
    453     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    454                     weight, bias, self.stride,
    455                     _pair(0), self.dilation, self.groups)
--> 456 return F.conv2d(input, weight, bias, self.stride,
    457                 self.padding, self.dilation, self.groups)

RuntimeError: Input type (c10::Half) and bias type (float) should be the same
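The message means the convolution received a float16 (`Half`) input while its bias is float32; PyTorch's convolution requires the input and the layer's parameters to share a dtype. Below is a minimal standalone sketch of the mismatch and one possible workaround, casting the input to the layer's parameter dtype. It uses a toy `nn.Conv2d` rather than Clay's encoder, and it is not necessarily how the project fixed the notebooks.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy convolution standing in for a patch-embedding layer (not Clay's encoder).
# Its weight and bias are float32 by default.
conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, stride=2)

# Half-precision input, mimicking the dtype reported in the traceback.
x = torch.randn(1, 3, 8, 8).half()

# Mixing dtypes fails; the exact wording varies by PyTorch version and device.
try:
    conv(x)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print(f"RuntimeError: {err}")

# Workaround sketch: cast the input to the layer's parameter dtype first.
with torch.no_grad():
    out = conv(x.to(conv.weight.dtype))

print(out.dtype, tuple(out.shape))
```

Going the other way, converting the model with `model.half()`, also aligns the dtypes, but half-precision convolutions may not be supported on CPU in every PyTorch version.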

brunosan commented 6 months ago

This seems to fix the issue, but inference is now extremely slow, at least 20x slower than a month ago (my GPU was not recognized).

https://github.com/Clay-foundation/model/commit/0708614640feb2c6b9476820df848f6cf89957fd

brunosan commented 6 months ago

Update: the patch works. The slowdown was on my end (my GPU was offline).
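Since the slowdown came from the GPU not being available, a quick device check before inference can surface this early. A minimal sketch using standard PyTorch calls; `pick_device` is a hypothetical helper name, and the `nn.Linear` model is only a stand-in for the notebook's model.

```python
import torch
from torch import nn


def pick_device() -> torch.device:
    """Prefer CUDA when available; warn when falling back to CPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using GPU: {torch.cuda.get_device_name(device)}")
    else:
        device = torch.device("cpu")
        print("CUDA not available; inference will run on CPU and may be much slower.")
    return device


device = pick_device()

# Stand-in model; move its parameters to the selected device.
model = nn.Linear(4, 2).to(device)
print(next(model.parameters()).device)
```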

MaceGrim commented 6 months ago

Hmm, it's still not working for me even after checking out the patch branch.
