alanqrwang / keymorph

Robust multimodal image registration via keypoints
MIT License

Evaluation with run.py fails #13

Closed: HastingsGreer closed this issue 6 months ago

HastingsGreer commented 9 months ago

Thank you for publishing this code! run.py fails on my system with the following output:

(venv) tgreer@biag-w05:/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph$ python run.py --kp_align_method affine --num_keypoints 128 --loss_fn mse --eval --load_path ./weights/numkey128_aff_dice.1560.h5
{'affine_slope': -1,
 'batch_size': 1,
 'data_dir': './data/centered_IXI/',
 'dataset': 'ixi',
 'debug_mode': False,
 'dim': 3,
 'epochs': 2000,
 'eval': True,
 'gpus': '0',
 'job_name': 'keymorph',
 'kp_align_method': 'affine',
 'kp_extractor': 'conv_com',
 'kpconsistency_coeff': 0,
 'load_path': './weights/numkey128_aff_dice.1560.h5',
 'log_interval': 25,
 'loss_fn': 'mse',
 'lr': 3e-06,
 'mix_modalities': False,
 'norm_type': 'instance',
 'num_keypoints': 128,
 'num_test_subjects': 100,
 'num_workers': 1,
 'resume': False,
 'save_dir': './output/',
 'save_preds': False,
 'seed': 23,
 'steps_per_epoch': 32,
 'tps_lmbda': None,
 'transform': 'none',
 'use_amp': False,
 'use_wandb': False,
 'visualize': False,
 'wandb_api_key_path': None,
 'wandb_kwargs': {},
 'weighted_kp_align': False}
Number of GPUs: 2
Fixed train dataset has 3 modalities.
-> Modality T1 has 427 subjects (427 images, 427 masks and 0 segmentations)
-> Modality T2 has 427 subjects (427 images, 427 masks and 0 segmentations)
-> Modality PD has 427 subjects (427 images, 427 masks and 0 segmentations)
Moving train dataset has 3 modalities.
-> Modality T1 has 427 subjects (427 images, 427 masks and 0 segmentations)
-> Modality T2 has 427 subjects (427 images, 427 masks and 0 segmentations)
-> Modality PD has 427 subjects (427 images, 427 masks and 0 segmentations)
Test dataset has 3 modalities.
-> Modality T1 has 100 subjects (100 images, 100 masks and 0 segmentations)
-> Modality T2 has 100 subjects (100 images, 100 masks and 0 segmentations)
-> Modality PD has 100 subjects (100 images, 100 masks and 0 segmentations)
/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:32: UserWarning:
    There is an imbalance between your GPUs. You may want to exclude GPU 1 which
    has less than 75% of the memory or cores of GPU 0. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.
  warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))

Model Summary
---------------------------------------------------------------
module.keypoint_extractor.module.block1.conv.weight
module.keypoint_extractor.module.block1.conv.bias
module.keypoint_extractor.module.block2.conv.weight
module.keypoint_extractor.module.block2.conv.bias
module.keypoint_extractor.module.block3.conv.weight
module.keypoint_extractor.module.block3.conv.bias
module.keypoint_extractor.module.block4.conv.weight
module.keypoint_extractor.module.block4.conv.bias
module.keypoint_extractor.module.block5.conv.weight
module.keypoint_extractor.module.block5.conv.bias
module.keypoint_extractor.module.block6.conv.weight
module.keypoint_extractor.module.block6.conv.bias
module.keypoint_extractor.module.block7.conv.weight
module.keypoint_extractor.module.block7.conv.bias
module.keypoint_extractor.module.block8.conv.weight
module.keypoint_extractor.module.block8.conv.bias
module.keypoint_extractor.module.block9.conv.weight
module.keypoint_extractor.module.block9.conv.bias
Total parameters: 8794496
---------------------------------------------------------------

Running test: subject id 0->0, mod T1->T1, aug rot0
Traceback (most recent call last):
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/run.py", line 771, in <module>
    main()
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/run.py", line 616, in main
    grid, points_f, points_m = registration_model(
                               ^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/keymorph/model.py", line 45, in forward
    points_f, points_m = self.extract_keypoints_step(img_f, img_m)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/keymorph/model.py", line 69, in extract_keypoints_step
    return self.keypoint_extractor(img1), self.keypoint_extractor(img2)
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/keymorph/net.py", line 85, in forward
    out = self.block1(x)
          ^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/keymorph/layers.py", line 128, in forward
    out = self.conv(x)
          ^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 613, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 608, in _conv_forward
    return F.conv3d(
           ^^^^^^^^^
TypeError: conv3d() received an invalid combination of arguments - got (tuple, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (tuple of (Tensor,), Parameter, Parameter, tuple of (int, int, int), tuple of (int, int, int), tuple of (int, int, int), int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (tuple of (Tensor,), Parameter, Parameter, tuple of (int, int, int), tuple of (int, int, int), tuple of (int, int, int), int)
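The last frame pinpoints the failure: F.conv3d expects a Tensor as its first argument but received a tuple of (Tensor,), i.e. the image wrapped in a one-element tuple by the time it reaches block1 in keymorph/net.py. As a stopgap while the root cause is tracked down, a guard like the one below could unwrap such an input before the first convolution. unwrap_single is a hypothetical helper, not part of keymorph, and this sketch does not explain where the tuple is introduced.

```python
from typing import Any


def unwrap_single(x: Any) -> Any:
    """Hypothetical guard: peel off one-element tuples/lists until a bare value remains.

    Inputs that are not one-element sequences are returned unchanged,
    so a real tensor passes straight through.
    """
    while isinstance(x, (tuple, list)) and len(x) == 1:
        x = x[0]
    return x


if __name__ == "__main__":
    print(unwrap_single(("image",)))  # a one-element tuple is unwrapped
    print(unwrap_single((1, 2)))      # longer tuples are left alone
```

In the forward method shown in the traceback, the call would become out = self.block1(unwrap_single(x)). This only treats the symptom and assumes a single-element tuple is the only malformed input.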

Do you have any advice on how to proceed?
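Regarding the GPU imbalance warning in the log: the printed config has 'gpus': '0', yet the script reports "Number of GPUs: 2" and DataParallel warns about GPU 1, which suggests both cards are being handed to DataParallel. The warning itself suggests setting CUDA_VISIBLE_DEVICES; a minimal sketch of doing that from Python follows. The helper name is hypothetical, and the variable must be set before torch initializes CUDA. Whether running on a single GPU also avoids the TypeError above is untested.

```python
import os


def restrict_visible_gpus(gpu_ids: str) -> str:
    """Expose only the given comma-separated GPU ids (e.g. "0") to CUDA.

    Hypothetical helper: it must run before torch initializes CUDA,
    otherwise the current process keeps seeing every device.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids
    return os.environ["CUDA_VISIBLE_DEVICES"]


if __name__ == "__main__":
    restrict_visible_gpus("0")  # match the 'gpus': '0' option in the printed config
    # Import torch only after the variable is set; torch.cuda.device_count()
    # should then report a single device.
```

The shell equivalent is to prefix the original command: CUDA_VISIBLE_DEVICES=0 python run.py --kp_align_method affine --num_keypoints 128 --loss_fn mse --eval --load_path ./weights/numkey128_aff_dice.1560.h5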

Janmaking commented 7 months ago

Lines 270-280:

    with torch.set_grad_enabled(True):
        grid, points_f, points_m, points_a = registration_model(
            img_f, img_m, lmbda, True
        )

    img_a = align_img(grid, img_m)

    img_a = align_img(grid, img_m[0])
    if seg_available:
        seg_a = align_img(grid, seg_m)
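Separately, the Model Summary in the original log lists names such as module.keypoint_extractor.module.block1.conv.weight, and the traceback passes through data_parallel.py twice, so both the registration model and its keypoint extractor appear to be wrapped in torch.nn.DataParallel. If you want to try loading the weights into an unwrapped model instead, the usual step is to drop the "module" components from the state-dict keys. The sketch below uses a hypothetical helper; it assumes the checkpoint keys carry the same prefixes as the summary, and it would also strip a submodule that is genuinely named "module".

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")


def strip_dataparallel_prefix(state_dict: Mapping[str, V]) -> Dict[str, V]:
    """Drop every 'module' path component that DataParallel adds to parameter names."""
    cleaned: Dict[str, V] = {}
    for key, value in state_dict.items():
        # "module.keypoint_extractor.module.block1.conv.weight"
        #   -> "keypoint_extractor.block1.conv.weight"
        parts = [part for part in key.split(".") if part != "module"]
        cleaned[".".join(parts)] = value
    return cleaned


if __name__ == "__main__":
    names = {"module.keypoint_extractor.module.block1.conv.weight": 0}
    print(list(strip_dataparallel_prefix(names)))
    # ['keypoint_extractor.block1.conv.weight']
```

With PyTorch, the cleaned dictionary would be passed to load_state_dict on the unwrapped model. How keymorph's .h5 checkpoint nests its weights is not shown in this thread, so inspect the loaded object's keys first.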