BiomedSciAI / fuse-med-ml

A python framework accelerating ML based discovery in the medical field by encouraging code reuse. Batteries included :)
Apache License 2.0

bug in resnet3d following latest change #307

Status: Closed (itaijj closed this issue 1 year ago)

itaijj commented 1 year ago

Describe the bug: After the latest change to backbone_resnet_3d.py (https://github.com/BiomedSciAI/fuse-med-ml/commits/911eac3b91112409e494ade8d7953d27113cb79e/fuse/dl/models/backbones/backbone_resnet_3d.py), I got this error:

    return F.batch_norm(
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/nn/functional.py", line 2438, in batch_norm
    return torch.batch_norm(
RuntimeError: running_mean should contain 64 elements not 32.

Reverting to the previous version of the file did not produce the crash.

FuseMedML version: 0.30

Python version: 3.9 (taken from the python3.9 paths in the traceback; the exact patch version was not given)

To reproduce: I got this error when using this resnet config:

    resnet_kwargs:
      first_channel_dim: 32
      first_stride: 2
      stem_kernel_size: [3, 3, 3]
      stem_stride: [2, 2, 2]
      layers: [2, 2, 2, 2]
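The error message points at a channel-count mismatch in the stem: the batch norm layer holds running statistics for one channel count while its input arrives with another. The sketch below reproduces that class of error in isolation with plain PyTorch. It is not the FuseMedML code: `make_stem` and its parameters are hypothetical, and the idea that the stem convolution and its batch norm were sized from different values is an assumption based only on the message and the config above, not something confirmed in this issue.

```python
import torch
import torch.nn as nn


def make_stem(first_channel_dim: int, conv_out_channels: int) -> nn.Sequential:
    """Hypothetical stand-in for a 3D ResNet stem (NOT the FuseMedML implementation).

    The conv and the batch norm are sized independently so that a mismatch
    can be created on purpose.
    """
    return nn.Sequential(
        nn.Conv3d(1, conv_out_channels, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm3d(first_channel_dim),
        nn.ReLU(inplace=True),
    )


# Tiny single-channel volume: (batch, channels, depth, height, width).
x = torch.randn(2, 1, 8, 8, 8)

# Assumed failure mode: the conv emits 64 channels, the batch norm was built for 32.
broken = make_stem(first_channel_dim=32, conv_out_channels=64)
try:
    broken(x)
    error_message = ""
except RuntimeError as err:
    error_message = str(err)
print(error_message)

# Consistent sizing: both layers use first_channel_dim, so the forward pass succeeds.
fixed = make_stem(first_channel_dim=32, conv_out_channels=32)
out_shape = tuple(fixed(x).shape)
print(out_shape)
```

If the real stem behaves like the `broken` variant, any `first_channel_dim` other than the hardcoded value would crash on the first forward pass, which would match a failure that only appears with `first_channel_dim: 32`.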

Expected behavior: no crash


Additional context: complete log

Traceback (most recent call last):
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/clearml/binding/hydra_bind.py", line 173, in _patched_task_function
    return task_function(a_config, *a_args, **a_kwargs)
  File "/dccstor/mm_hcls/guez/ukbb/ukbb/runner.py", line 450, in main
    runner_main(cfg)
  File "/dccstor/mm_hcls/guez/ukbb/ukbb/runner.py", line 80, in runner_main
    to_return = run_train(cfg["task"], cfg["model"])
  File "/dccstor/mm_hcls/guez/ukbb/ukbb/runner.py", line 203, in run_train
    pl_trainer.fit(pl_module, data_module, ckpt_path=ckpt_path)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/mlflow/utils/autologging_utils/safety.py", line 555, in safe_patch_function
    patch_function(call_original, *args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/mlflow/utils/autologging_utils/safety.py", line 254, in patch_with_managed_run
    result = patch_function(original, *args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/mlflow/pytorch/_pytorch_autolog.py", line 370, in patched_fit
    result = original(self, *args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/mlflow/utils/autologging_utils/safety.py", line 536, in call_original
    return call_original_fn_with_event_logging(_original_fn, og_args, og_kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/mlflow/utils/autologging_utils/safety.py", line 471, in call_original_fn_with_event_logging
    original_fn_result = original_fn(*og_args, **og_kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/mlflow/utils/autologging_utils/safety.py", line 533, in _original_fn
    original_result = original(*_og_args, **_og_kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 88, in launch
    return function(*args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1103, in _run
    results = self._run_stage()
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1182, in _run_stage
    self._run_train()
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1205, in _run_train
    self.fit_loop.run()
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 213, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 202, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 249, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 370, in _optimizer_step
    self.trainer._call_lightning_module_hook(
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1347, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/core/module.py", line 1744, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 169, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/strategies/ddp.py", line 280, in optimizer_step
    optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 234, in optimizer_step
    return self.precision_plugin.optimizer_step(
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 119, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/optim/lr_scheduler.py", line 65, in wrapper
    return wrapped(*args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/optim/adamw.py", line 119, in step
    loss = closure()
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 105, in _wrap_closure
    closure_result = closure()
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 149, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 135, in closure
    step_output = self._step_fn()
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 419, in _training_step
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1485, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/strategies/ddp.py", line 351, in training_step
    return self.model(*args, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/pytorch_lightning/overrides/base.py", line 98, in forward
    output = self._forward_module.training_step(*inputs, **kwargs)
  File "/dccstor/mm_hcls/guez/fuse-med-ml/fuse/dl/lightning/pl_module.py", line 146, in training_step
    batch_dict = self.forward(batch_dict)
  File "/dccstor/mm_hcls/guez/fuse-med-ml/fuse/dl/lightning/pl_module.py", line 139, in forward
    return self._model(batch_dict)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/dccstor/mm_hcls/guez/ukbb/ukbb/arch/multi_modal_model.py", line 164, in forward
    feat = encoder(processed_data)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/dccstor/mm_hcls/guez/fuse-med-ml/fuse/dl/models/backbones/backbone_resnet_3d.py", line 219, in forward
    return self.features(x)
  File "/dccstor/mm_hcls/guez/fuse-med-ml/fuse/dl/models/backbones/backbone_resnet_3d.py", line 201, in features
    x = self.stem(x)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 168, in forward
    return F.batch_norm(
  File "/dccstor/mm_hcls/guez/anaconda3/envs/mmt/lib/python3.9/site-packages/torch/nn/functional.py", line 2438, in batch_norm
    return torch.batch_norm(
RuntimeError: running_mean should contain 64 elements not 32

SagiPolaczek commented 1 year ago

Thanks!

SagiPolaczek commented 1 year ago

Solved at #309

TY! @itaijj @liamhazan