microsoft / satclip

PyTorch implementation of SatCLIP
MIT License

Number of input channels #12

PlekhanovaElena closed this issue 2 months ago

PlekhanovaElena commented 2 months ago

Hi there,

I'm trying to reproduce the pre-training of SatCLIP on the S100 dataset. I downloaded S100 and changed the paths in the config file default.yaml and in s2geo_dataset.py. This is the error I'm now trying to solve:

using vision transformer
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
skipped 8142/100000 images because they were smaller than 10000 bytes... they probably contained nodata pixels
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type        | Params
-----------------------------------------
0 | model    | SatCLIP     | 88.9 M
1 | loss_fun | SatCLIPLoss | 0     
-----------------------------------------
88.9 M    Trainable params
0         Non-trainable params
88.9 M    Total params
355.445   Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type        | Params
-----------------------------------------
0 | model    | SatCLIP     | 88.9 M
1 | loss_fun | SatCLIPLoss | 0     
-----------------------------------------
88.9 M    Trainable params
0         Non-trainable params
88.9 M    Total params
355.445   Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/data/eplekh/code/satclip/satclip/main.py", line 159, in <module>
    cli_main(config_fn)
  File "/data/eplekh/code/satclip/satclip/main.py", line 144, in cli_main
    cli.trainer.fit(
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1031, in _run_stage
    self._run_sanity_check()
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1060, in _run_sanity_check
    val_loop.run()
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 412, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "/data/eplekh/code/satclip/satclip/main.py", line 75, in validation_step
    loss = self.common_step(batch, batch_idx)
  File "/data/eplekh/code/satclip/satclip/main.py", line 66, in common_step
    logits_per_image, logits_per_coord = self.model(images, t_points)
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/eplekh/code/satclip/satclip/model.py", line 365, in forward
    image_features = self.encode_image(image)     
  File "/data/eplekh/code/satclip/satclip/model.py", line 358, in encode_image
    return self.visual(image.type(self.dtype))
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/eplekh/code/satclip/satclip/model.py", line 230, in forward
    x = self.conv1(x)  # shape = [*, width, grid, grid]
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [768, 4, 32, 32], expected input[64, 13, 256, 256] to have 4 channels, but got 13 channels instead

It seems the images are found, but the first convolution of the vision encoder expects 4 channels and gets 13, and I'm not sure why. I tried changing the line in_channels: 4 to in_channels: 13 in ./satclip/configs/default.yaml, but this did not help. Also, the file "/data/eplekh/code/satclip/lightning_logs/version_14077673/./configs/default-latest.yaml" that is created while running the script still contains in_channels: 4, even though I changed ./satclip/configs/default.yaml. This might be the reason, but I don't know how to fix it.
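The failing layer is `self.conv1` in model.py, which in a vision transformer acts as the patch embedding. Its weight has shape `[out_channels, in_channels, kernel, kernel]` = `[768, 4, 32, 32]`, so it only accepts 4-channel input, while the batch is `[64, 13, 256, 256]`. Below is a minimal pure-Python sketch of the shape rule `F.conv2d` enforces here. The helper name `patch_embed_output_shape` is hypothetical, and it assumes no padding and a stride equal to the patch size, as is usual for ViT patch embeddings.

```python
def patch_embed_output_shape(input_shape, weight_shape, stride):
    """Mimic the shape check of a groups=1 Conv2d used as a ViT patch embedding.

    input_shape:  (batch, channels, height, width)
    weight_shape: (out_channels, in_channels, kernel_h, kernel_w)
    """
    batch, channels, height, width = input_shape
    out_channels, in_channels, kernel_h, kernel_w = weight_shape
    if channels != in_channels:
        raise RuntimeError(
            f"expected input to have {in_channels} channels, "
            f"but got {channels} channels instead"
        )
    # No padding: each spatial size becomes floor((size - kernel) / stride) + 1.
    grid_h = (height - kernel_h) // stride + 1
    grid_w = (width - kernel_w) // stride + 1
    return (batch, out_channels, grid_h, grid_w)


# Encoder built with in_channels=4 but fed a 13-band batch: fails like the traceback.
try:
    patch_embed_output_shape((64, 13, 256, 256), (768, 4, 32, 32), stride=32)
except RuntimeError as err:
    print(err)

# Encoder built with in_channels=13: 256 / 32 = 8 patches per side.
print(patch_embed_output_shape((64, 13, 256, 256), (768, 13, 32, 32), stride=32))
# → (64, 768, 8, 8)
```

The weight shape is fixed when the model is constructed, so the channel count has to be correct at construction time; it cannot adapt to the data afterwards.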

A small additional question: is "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]" an expected output, or does it mean that the GPU is not visible?

I would very much appreciate any help. Kind regards, Elena
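On the side question: Lightning prints the value of the `CUDA_VISIBLE_DEVICES` environment variable for each local process, so `[0]` means GPU index 0 is visible to rank 0. Together with `GPU available: True (cuda), used: True` earlier in the log, this indicates the GPU is found and used. A small sketch of how that variable is interpreted; `visible_gpus` is a hypothetical helper, not a Lightning function. From inside PyTorch, `torch.cuda.is_available()` and `torch.cuda.device_count()` report the same thing.

```python
import os


def visible_gpus(environ=os.environ):
    """Return the GPU entries listed in CUDA_VISIBLE_DEVICES, or None if it is unset."""
    value = environ.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return None  # unset: the CUDA runtime exposes every GPU on the machine
    # Comma-separated device indices (or GPU UUIDs); an empty string hides all GPUs.
    return [item.strip() for item in value.split(",") if item.strip()]


print(visible_gpus({"CUDA_VISIBLE_DEVICES": "0"}))  # → ['0']
print(visible_gpus({"CUDA_VISIBLE_DEVICES": ""}))   # → []
print(visible_gpus({}))                             # → None
```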

konstantinklemmer commented 2 months ago

Are you using the default.yaml from here: https://github.com/microsoft/satclip/blob/main/satclip/configs/default.yaml?

You need to change in_channels to 13; then it should run. If you want to use a pretrained vision encoder, you need to change vision_layer to, e.g., moco_resnet50.
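The value 13 matches the number of spectral bands in Sentinel-2 Level-1C products, which is why the batch in the traceback has shape `[64, 13, 256, 256]`. A small sketch that derives `in_channels` from a band list instead of hard-coding it. The band order below is the standard Sentinel-2 ordering and is only an assumption about how the S100 images are stacked; `SENTINEL2_L1C_BANDS` and `in_channels_for` are hypothetical names.

```python
# Standard Sentinel-2 MSI Level-1C bands (assumed stacking order of the dataset images).
SENTINEL2_L1C_BANDS = [
    "B01", "B02", "B03", "B04", "B05", "B06", "B07",
    "B08", "B8A", "B09", "B10", "B11", "B12",
]


def in_channels_for(bands):
    """The vision encoder's first layer needs exactly one input channel per band."""
    if len(set(bands)) != len(bands):
        raise ValueError("duplicate band names")
    return len(bands)


print(in_channels_for(SENTINEL2_L1C_BANDS))  # → 13
```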

For more details on the vision encoders and how they are used check: https://github.com/microsoft/satclip/blob/main/satclip/model.py

PlekhanovaElena commented 2 months ago

Thanks for your reply! Changing the parameters in default.yaml didn't help, but changing them in the SatCLIPLightningModule class in main.py did. It's running now, thank you!
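The thread does not establish why editing default.yaml had no effect. As a general rule for LightningCLI-style setups, a value from the config file that is actually loaded overrides the constructor default, and the saved default-latest.yaml records the resolved result. The hypothetical sketch below illustrates that precedence with a stand-in class, not the real `SatCLIPLightningModule`; `ToyLightningModule` and `resolve_init_args` are illustrative names only.

```python
import inspect


class ToyLightningModule:
    """Hypothetical stand-in for a LightningModule; not the real SatCLIPLightningModule."""

    def __init__(self, in_channels=4):
        self.in_channels = in_channels


def resolve_init_args(cls, loaded_config):
    """Start from the constructor defaults, then apply values from the loaded config."""
    params = inspect.signature(cls.__init__).parameters
    resolved = {
        name: param.default
        for name, param in params.items()
        if name != "self" and param.default is not inspect.Parameter.empty
    }
    for key, value in loaded_config.items():
        if key not in resolved:
            # This sketch rejects keys that match no constructor argument.
            raise ValueError(f"unexpected config key: {key}")
        resolved[key] = value
    return resolved


# The loaded config has no in_channels entry: the constructor default (4) is kept.
print(resolve_init_args(ToyLightningModule, {}))  # → {'in_channels': 4}
# The loaded config sets in_channels: the explicit value wins.
print(resolve_init_args(ToyLightningModule, {"in_channels": 13}))  # → {'in_channels': 13}
```

If the saved default-latest.yaml still shows in_channels: 4 after an edit, the edited value never reached the module. A reasonable first check is which file the `config_fn` passed to `cli_main` in main.py (visible in the traceback) actually points to.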

konstantinklemmer commented 2 months ago

Great!