talmolab / dreem

DREEM Relates Every Entities' Motion (DREEM). Global Tracking Transformers for biological multi-object tracking.
https://dreem.sleap.ai
BSD 3-Clause "New" or "Revised" License

Torch Conv error: 'Given groups=1, weight of size [64, 3, 7, 7], expected input[839, 6, 115, 115] to have 3 channels, but got 6 channels instead' #91

Open · YiLin-Zhou opened this issue 3 days ago

YiLin-Zhou commented 3 days ago

Bug description

Running dreem-train on SLEAP data results in this error log:

Sanity Checking: |                                                                                      | 0/? [00:00<?, ?it/s]C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
Epoch 0:   0%|                                                                                         | 0/80 [00:00<?, ?it/s][2024-10-18 08:10:18,320][dreem.models][ERROR] - Failed on frame tensor([0], device='cuda:0') of video tensor([14], device='cuda:0')
Traceback (most recent call last):
  File "C:\Users\ylzhou\dreem\dreem\models\gtr_runner.py", line 193, in _shared_eval_step
    logits = self(instances)
             ^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\dreem\dreem\models\gtr_runner.py", line 97, in forward
    asso_preds = self.model(ref_instances, query_instances)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\dreem\dreem\models\global_tracking_transformer.py", line 93, in forward
    self.extract_features(ref_instances)
  File "C:\Users\ylzhou\dreem\dreem\models\global_tracking_transformer.py", line 127, in extract_features
    features = self.visual_encoder(crops)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\dreem\dreem\models\visual_encoder.py", line 141, in forward
    feats = self.feature_extractor(
            ^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\timm\models\resnet.py", line 635, in forward
    x = self.forward_features(x)
        ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\timm\models\resnet.py", line 614, in forward_features
    x = self.conv1(x)
        ^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\torch\nn\modules\conv.py", line 458, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\miniconda3\envs\dreem\Lib\site-packages\torch\nn\modules\conv.py", line 454, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[839, 6, 115, 115] to have 3 channels, but got 6 channels instead
[2024-10-18 08:10:18,339][dreem.models][ERROR] - Given groups=1, weight of size [64, 3, 7, 7], expected input[839, 6, 115, 115] to have 3 channels, but got 6 channels instead
...
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
[2024-10-18 08:10:18,714][imageio_ffmpeg][WARNING] - We had to kill ffmpeg to stop it.
...
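
The failure is at the first conv layer of the ResNet encoder: its weights expect a 3-channel input, but the batched crops have 6 channels. A minimal standalone repro of just that call (plain PyTorch, with the shapes copied from the log):

```python
import torch
import torch.nn.functional as F

# Shapes from the log: resnet18's conv1 weight is (64, 3, 7, 7),
# while the batched crops are (839, 6, 115, 115).
weight = torch.randn(64, 3, 7, 7)
crops = torch.randn(839, 6, 115, 115)

F.conv2d(crops, weight, stride=2, padding=3)
# RuntimeError: Given groups=1, weight of size [64, 3, 7, 7],
# expected input[839, 6, 115, 115] to have 3 channels, but got 6 channels instead
```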

My config files:

model:
  ckpt_path: null
  encoder_cfg: 
    model_name: "resnet18"
    backend: "timm"
    in_chans: 3
  d_model: 1024
  nhead: 8
  num_encoder_layers: 1
  num_decoder_layers: 1
  dropout: 0.1
  activation: "relu"
  return_intermediate_dec: False
  norm: False
  num_layers_attn_head: 2
  dropout_attn_head: 0.1
  embedding_meta: 
    pos:
        mode: "learned"
        normalize: true
    temp:
        mode: "learned"
  return_embedding: False
  decoder_self_attn: False

loss:
  neg_unmatched: false
  epsilon: 1e-4
  asso_weight: 1.0

#currently assumes adam. TODO adapt logic for other optimizers like sgd
optimizer:
  name: "Adam"
  lr: 0.001
  betas: [0.9, 0.999]
  eps: 1e-8
  weight_decay: 0.01

#currently assumes reduce lr on plateau
scheduler:
  name: "ReduceLROnPlateau"
  mode: "min"
  factor: 0.5
  patience: 10
  threshold: 1e-4
  threshold_mode: "rel"

tracker:
  window_size: 10
  use_vis_feats: true
  overlap_thresh: 0.01
  mult_thresh: true
  decay_time: null
  iou: null
  max_center_dist: 0.1

runner:
  metrics:
      train: ['num_switches']
      val: ['num_switches']
      test: ['num_switches']
  persistent_tracking:
      train: false
      val: false
      test: true

dataset:
  train_dataset:
      slp_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/train/CC24.1_fr1-30.slp", ...]
      video_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/train/cc18.4_fr1-30.avi", ...]
      padding: 5
      crop_size: 105
      chunk: false
      anchors: ["gut1", "gut2"]
      handle_missing: "centroid"

  val_dataset:
      slp_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/val/CC21.2_fr1-30.slp", ...]
      video_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/val/CC21.2_fr1-30.avi", ...]
      padding: 5
      crop_size: 105
      chunk: false
      handle_missing: "centroid"

  test_dataset:
      dir:
            path: '/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/test'
            labels_suffix: '.slp'
            vid_suffix: '.avi'
      slp_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/test/CC2.1.slp"]
      video_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/test/CC2.1.avi"]
      padding: 5
      crop_size: 105
      chunk: False
      handle_missing: "centroid"

dataloader:
  train_dataloader:
    shuffle: true
    num_workers: 0
  val_dataloader:
    shuffle: false
    num_workers: 0
  test_dataloader:
    shuffle: false
    num_workers: 0

logging:
  logger_type: null
  name: "100example_augmentation_00"
  entity: null
  job_type: "train"
  notes: "100 examplea w/ ugmentation job 00"
  dir: "./logs"
  group: "logging"
  save_dir: './logs'
  project: "GTR"
  log_model: "all"

early_stopping:
  monitor: "val_loss"
  min_delta: 0.1
  patience: 10
  mode: "min"
  check_finite: true
  stopping_threshold: 1e-8
  divergence_threshold: 30

checkpointing:
  monitor: ["val_loss","val_num_switches"]
  verbose: true
  save_last: true
  dirpath: '.'
  auto_insert_metric_name: true
  every_n_epochs: 2

trainer:
  check_val_every_n_epoch: 1
  enable_checkpointing: true
  gradient_clip_val: null
  limit_train_batches: 1.0
  limit_test_batches: 1.0
  limit_val_batches: 1.0
  log_every_n_steps: 1
  max_epochs: 26
  min_epochs: 8

view_batch:
  enable: False
  num_frames: 0
  no_train: False

aaprasad commented 3 days ago

@YiLin-Zhou Thanks for moving the issue. The reason this is happening is that you're using multiple anchors for the crops in your train dataset:

dataset:
  train_dataset:
      slp_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/train/CC24.1_fr1-30.slp", ...]
      video_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/train/cc18.4_fr1-30.avi", ...]
      padding: 5
      crop_size: 105
      chunk: false
      anchors: ["gut1", "gut2"]
      handle_missing: "centroid"

The way the anchors argument works with SLEAP files is that for each anchor/node in the list we create a node-centered crop for that anchor and then concatenate the crops along the channel dimension. Thus, the image input to the model really has shape $(n_{frames}, 3 \cdot n_{anchors}, h, w)$ — see the minimal sketch at the end of this comment. To resolve this, you can either:

  1. Only pass a single anchor. So far in our experiments we've found that a single node-centered crop generally works better than a multi-anchor crop.
  2. In your model.encoder_cfg, set in_chans to $3 \cdot n_{anchors}$, so in your case it would be

     model:
       ckpt_path: null
       encoder_cfg:
         model_name: "resnet18"
         backend: "timm"
         in_chans: 6 # was 3 before

By the way, I noticed that in your val_dataset and test_dataset you didn't specify any anchors. This will also cause an issue, but the other way around: the image crops will only have shape $(n_{frames}, 3, h, w)$ while your model would expect inputs of shape $(n_{frames}, 6, h, w)$. Also note that in this case your model would be validated on crops centered on the pose centroid rather than the anchor you wanted, since the default when anchors is not specified is the centroid of the pose.
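
For concreteness, here is a minimal standalone sketch of both the failure and the two fixes (plain PyTorch + timm rather than dreem's actual crop code; the shapes and variable names are illustrative):

```python
import timm
import torch

# One node-centered crop per anchor, each (n_frames, 3, h, w).
crop_gut1 = torch.randn(8, 3, 115, 115)
crop_gut2 = torch.randn(8, 3, 115, 115)

# Anchor crops get concatenated along the channel dimension,
# so with 2 anchors the encoder sees 3 * n_anchors = 6 channels.
crops = torch.cat([crop_gut1, crop_gut2], dim=1)  # (8, 6, 115, 115)

encoder = timm.create_model("resnet18", pretrained=False, in_chans=3)
# encoder.forward_features(crops)  # RuntimeError: expected 3 channels, got 6

# Fix 1: use a single anchor so the input stays 3-channel.
feats = encoder.forward_features(crop_gut1)  # (8, 512, 4, 4)

# Fix 2: build the encoder with in_chans = 3 * n_anchors.
encoder_6ch = timm.create_model("resnet18", pretrained=False, in_chans=6)
feats_6ch = encoder_6ch.forward_features(crops)  # first conv now takes 6 channels
```

Whichever fix you pick, the key invariant is that in_chans must equal 3 times the number of anchors, consistently across train, val, and test datasets.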