talmolab / dreem

DREEM Relates Every Entities' Motion. Global Tracking Transformers for biological multi-object tracking.
https://dreem.sleap.ai
BSD 3-Clause "New" or "Revised" License

DictConfig in struct mode does not support pop #85

Open YiLin-Zhou opened 2 days ago

YiLin-Zhou commented 2 days ago

Bug description

Running `dreem-train` on SLEAP data produces two DictConfig errors of the form `omegaconf.errors.ConfigTypeError: DictConfig in struct mode does not support pop`.

Training aborts as soon as either of these errors is thrown.
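
For context, this looks like stock OmegaConf behavior rather than anything DREEM-specific: Hydra composes configs with the struct flag set, and a struct-mode DictConfig rejects pop. A minimal reproduction, assuming nothing beyond omegaconf itself:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"ckpt_path": None})
OmegaConf.set_struct(cfg, True)  # Hydra leaves composed configs in struct mode

# Raises omegaconf.errors.ConfigTypeError:
# "DictConfig in struct mode does not support pop"
cfg.pop("ckpt_path", None)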

Hacky fix

Changing `dreem\models\gtr_runner.py` line 68, `_ = self.model_cfg.pop("ckpt_path", None)`, to the following:

try:
    # delete "ckpt_path" if it exists; otherwise swallow the error and pass
    _ = self.model_cfg.__delattr__("ckpt_path")
except Exception:
    pass

fixed the issue by deleting "ckpt_path" if it exists and passing otherwise. Obviously a hacky, code-smelly solution, but it worked. An analogous change in config.py also worked (sketched below).
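
For completeness, the analogous change at the get_checkpointing call site from the second traceback looked roughly like this sketch; since the popped value is actually used there, the key is read before deletion (assuming __delattr__ behaves the same way at that site):

try:
    dirpath = checkpoint_params.dirpath  # read the value before deleting the key
    checkpoint_params.__delattr__("dirpath")
except Exception:
    dirpath = None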

Terminal error logs

Log 1:

Error executing job with overrides: []
Traceback (most recent call last):
  File "C:\Users\ylzhou\dreem\dreem\training\train.py", line 76, in run
    model = train_cfg.get_gtr_runner()  # TODO see if we can use torch.compile()
            ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\dreem\dreem\io\config.py", line 148, in get_gtr_runner
    model = GTRRunner(
            ^^^^^^^^^^
  File "C:\Users\ylzhou\dreem\dreem\models\gtr_runner.py", line 68, in __init__
    _ = self.model_cfg.pop("ckpt_path", None)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
omegaconf.errors.ConfigTypeError: DictConfig in struct mode does not support pop
    full_key: model.ckpt_path
    object_type=dict

Log 2:

Error executing job with overrides: []
Traceback (most recent call last):
  File "C:\Users\ylzhou\dreem\dreem\training\train.py", line 81, in run
    _ = callbacks.extend(train_cfg.get_checkpointing())
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ylzhou\dreem\dreem\io\config.py", line 372, in get_checkpointing
    dirpath = checkpoint_params.pop("dirpath", None)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
omegaconf.errors.ConfigTypeError: DictConfig in struct mode does not support pop
    full_key: checkpointing.dirpath
    object_type=dict
aaprasad commented 1 day ago

Hi @YiLin-Zhou, thanks for the report! Yes, I believe this may be an edge case we missed, although it would be nice to know what exactly triggered it. Could you share the config file you used to run training? Also, while your solution would work, the correct approach is to use `omegaconf.open_dict` as a context manager. For example, see here:

from omegaconf import open_dict  # needed at module level in dreem/io/config.py

def get_model(self) -> "GlobalTrackingTransformer":
    """Getter for gtr model.

    Returns:
        A global tracking transformer with parameters indicated by cfg
    """
    from dreem.models import GlobalTrackingTransformer

    model_params = self.cfg.model
    with open_dict(model_params):
        ckpt_path = model_params.pop("ckpt_path", None)

    if ckpt_path is not None and len(ckpt_path) > 0:
        return GTRRunner.load_from_checkpoint(ckpt_path).model

    return GlobalTrackingTransformer(**model_params)
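
A minimal standalone sketch of the open_dict pattern, assuming nothing beyond omegaconf itself:

from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({"model": {"ckpt_path": None, "d_model": 1024}})
OmegaConf.set_struct(cfg, True)  # the state Hydra leaves composed configs in

# open_dict temporarily lifts the struct flag on this node, so pop works again
with open_dict(cfg.model):
    ckpt_path = cfg.model.pop("ckpt_path", None)

print(ckpt_path)                 # None (the value stored under the key)
print("ckpt_path" in cfg.model)  # False: the key has been removed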
YiLin-Zhou commented 14 hours ago

Sure, here's the config file I used for training:

model:
  ckpt_path: null
  encoder_cfg: 
    model_name: "resnet18"
    backend: "timm"
    in_chans: 3
  d_model: 1024
  nhead: 8
  num_encoder_layers: 1
  num_decoder_layers: 1
  dropout: 0.1
  activation: "relu"
  return_intermediate_dec: False
  norm: False
  num_layers_attn_head: 2
  dropout_attn_head: 0.1
  embedding_meta: 
    pos:
        mode: "learned"
        normalize: true
    temp:
        mode: "learned"
  return_embedding: False
  decoder_self_attn: False

loss:
  neg_unmatched: false
  epsilon: 1e-4
  asso_weight: 1.0

#currently assumes adam. TODO adapt logic for other optimizers like sgd
optimizer:
  name: "Adam"
  lr: 0.001
  betas: [0.9, 0.999]
  eps: 1e-8
  weight_decay: 0.01

#currently assumes reduce lr on plateau
scheduler:
  name: "ReduceLROnPlateau"
  mode: "min"
  factor: 0.5
  patience: 10
  threshold: 1e-4
  threshold_mode: "rel"

tracker:
  window_size: 10
  use_vis_feats: true
  overlap_thresh: 0.01
  mult_thresh: true
  decay_time: null
  iou: null
  max_center_dist: 0.1

runner:
  metrics:
      train: ['num_switches']
      val: ['num_switches']
      test: ['num_switches']
  persistent_tracking:
      train: false
      val: false
      test: true

dataset:
  train_dataset:
      slp_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/train/CC24.1_fr1-30.slp", ... , "/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/train/CC23.2_fr961-990.slp"]
      video_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/train/CC24.1_fr1-30.avi", ... , "/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/train/CC23.2_fr961-990.avi"]
      padding: 5
      crop_size: 105
      chunk: false
      handle_missing: "centroid"

  val_dataset:
      slp_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/val/CC21.2_fr1-30.slp", ... , "/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/val/CC39.2_fr961-990.slp"]
      video_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/val/CC21.2_fr1-30.avi", ... , "/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/val/CC39.2_fr961-990.avi"]
      padding: 5
      crop_size: 105
      chunk: false
      handle_missing: "centroid"

  test_dataset:
      dir:
            path: '/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/test'
            labels_suffix: '.slp'
            vid_suffix: '.avi'
      slp_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/test/CC2.1.slp"]
      video_files: ["/Users/ylzhou/Documents/DREEM/larva_dreem_100examples_dataset/test/CC2.1.avi"]
      padding: 5
      crop_size: 105
      chunk: False
      handle_missing: "centroid"

dataloader:
  train_dataloader:
    shuffle: true
    num_workers: 0
  val_dataloader:
    shuffle: false
    num_workers: 0
  test_dataloader:
    shuffle: false
    num_workers: 0

logging:
  logger_type: null
  name: "100exampletrain_00"
  entity: null
  job_type: "train"
  notes: "100 example train job 00"
  dir: "./logs"
  group: "logging"
  save_dir: './logs'
  project: "GTR"
  log_model: "all"

early_stopping:
  monitor: "val_loss"
  min_delta: 0.1
  patience: 10
  mode: "min"
  check_finite: true
  stopping_threshold: 1e-8
  divergence_threshold: 30

checkpointing:
  monitor: ["val_loss","val_num_switches"]
  verbose: true
  save_last: true
  dirpath: null
  auto_insert_metric_name: true
  every_n_epochs: 2

trainer:
  check_val_every_n_epoch: 1
  enable_checkpointing: true
  gradient_clip_val: null
  limit_train_batches: 1.0
  limit_test_batches: 1.0
  limit_val_batches: 1.0
  log_every_n_steps: 1
  max_epochs: 26
  min_epochs: 8

view_batch:
  enable: False
  num_frames: 0
  no_train: False