kreshuklab / plant-seg

A tool for cell instance aware segmentation in densely packed 3D volumetric images
https://kreshuklab.github.io/plant-seg/
MIT License

Non-OOM error raised by PyTorch during `find_batch_size()` #286

Open OriKovacsiKatz opened 1 month ago

OriKovacsiKatz commented 1 month ago

Running the PlantSeg example col-0_20161116, I get a crash:

(plant-seg-dev) user-name@kitchen_computer:/home/user-name/plant-seg# python -m run_plantseg --config=/home/user-name/data/plantseg_config_2016.yaml
You are using the latest version of PlantSeg: 1.8.1.
2024-07-29 12:00:02,666 [MainThread] INFO PlantSeg - Running the pipeline on: ['/home/user-name/data/col-0_20161116/20161116.tif']
2024-07-29 12:00:02,666 [MainThread] INFO PlantSeg - Executing pipeline, see terminal for verbose logs.
2024-07-29 12:00:05,357 [MainThread] INFO PlantSeg - Executing pipeline step: 'preprocessing'. Parameters: '{'state': False, 'save_directory': 'PreProcessing', 'factor': [1.0, 1.0, 1.0], 'order': 2, 'crop_volume': '[:, :, :]', 'filter': {'state': False, 'type': 'gaussian', 'filter_param': 1.0}}'. Files ['/home/user-name/data/col-0_20161116/20161116.tif'].
2024-07-29 12:00:05,357 [MainThread] INFO PlantSeg - Skipping 'DataPreProcessing3D'. Disabled by the user.
2024-07-29 12:00:05,357 [MainThread] INFO PlantSeg - Executing pipeline step: 'cnn_prediction'. Parameters: '{'state': False, 'model_name': 'generic_confocal_3D_unet', 'device': 'cuda', 'patch': [80, 160, 160], 'stride_ratio': 0.75, 'patch_halo': [4, 8, 8], 'model_update': True, 'num_workers': 8}'. Files ['/home/user-name/data/col-0_20161116/20161116.tif'].
2024-07-29 12:00:05,389 [MainThread] INFO PlantSeg - File config_train.yml already exists. Skipping download.
2024-07-29 12:00:05,389 [MainThread] INFO PlantSeg - File best_checkpoint.pytorch already exists. Skipping download.
2024-07-29 12:00:05,389 [MainThread] INFO PlantSeg Zoo - Loaded model from PlantSeg zoo: generic_confocal_3D_unet
2024-07-29 12:00:05,496 [MainThread] INFO PlantSeg Zoo - Loaded model from user specified weights: /user-name/.plantseg_models/generic_confocal_3D_unet/best_checkpoint.pytorch
/home/user-name/plant-seg/plantseg/predictions/predict.py:80: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state = torch.load(model_path, map_location='cpu')
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/user-name/plant-seg/run_plantseg.py", line 4, in <module>
    main()
  File "/home/user-name/plant-seg/plantseg/run_plantseg.py", line 96, in main
    process_config(args.config)
  File "/home/user-name/plant-seg/plantseg/run_plantseg.py", line 77, in process_config
    raw2seg(config)
  File "/home/user-name/plant-seg/plantseg/pipeline/raw2seg.py", line 148, in raw2seg
    pipeline_step = pipeline_step_setup(input_paths, config[pipeline_step_name])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/pipeline/raw2seg.py", line 55, in configure_cnn_step
    return UnetPredictions(
           ^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/predictions/predict.py", line 93, in __init__
    self.predictor = ArrayPredictor(
                     ^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/predictions/functional/array_predictor.py", line 181, in __init__
    self.batch_size = find_batch_size(model, in_channels, patch, patch_halo, device)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/predictions/functional/array_predictor.py", line 51, in find_batch_size
    _ = model(x)
        ^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/training/model.py", line 500, in forward
    x = decoder(encoder_features, x)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/training/model.py", line 294, in forward
    x = self.upsampling(encoder_features=encoder_features, x=x)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/training/model.py", line 377, in forward
    return self.upsample(x, output_size)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/training/model.py", line 394, in _interpolate
    return F.interpolate(x, size=size, mode=mode)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/plant-seg-dev/lib/python3.11/site-packages/torch/nn/functional.py", line 4052, in interpolate
    return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected output.numel() <= std::numeric_limits<int32_t>::max() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
(plant-seg-dev) user-name@kitchen_computer:/home/user-name/plant-seg# 

I modified the code here to print more details:

  File "/home/user-name/plant-seg/plantseg/predictions/functional/array_predictor.py", line 181, in __init__
    self.batch_size = find_batch_size(model, in_channels, patch, patch_halo, device)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user-name/plant-seg/plantseg/predictions/functional/array_predictor.py", line 51, in find_batch_size

Added these print-debugging statements:
    with torch.no_grad():
        for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
            try:
                # ====================================================================
                print("# [Ori Kovacsi-Katz]: attempt setting batch_size ",sep='=')
                print(batch_size, sep=',')
                print("in_channels", sep='=')
                print(in_channels, sep=',')
                print("actual_patch_shape", sep='=')
                print(actual_patch_shape, sep=',')
                print("device", sep='=')
                print(device)
                # ====================================================================            
                x = torch.randn((batch_size, in_channels) + actual_patch_shape).to(device)
                _ = model(x)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print("# [Ori Kovacsi-Katz]: out of memory exception while attempt setting batch_size ")
                    batch_size //= 2
                    break
                else:
                    print("# [Ori Kovacsi-Katz]: Other then out of memory exception while attempt setting batch_size ")                
                    print(e)
                    raise
            finally:
                del x
                torch.cuda.empty_cache()

It crashed at batch_size=16:

(plant-seg-dev) root@n311:/home/lahavt/plant-seg# python -m run_plantseg --config=/home/lahavt/data/plantseg_config_2016.yaml
You are using the latest version of PlantSeg: 1.8.1.
2024-07-30 10:26:52,559 [MainThread] INFO PlantSeg - Running the pipeline on: ['/home/lahavt/data/col-0_20161116/20161116.tif']
2024-07-30 10:26:52,559 [MainThread] INFO PlantSeg - Executing pipeline, see terminal for verbose logs.
2024-07-30 10:26:55,362 [MainThread] INFO PlantSeg - Executing pipeline step: 'preprocessing'. Parameters: '{'state': True, 'save_directory': 'PreProcessing', 'factor': [1.0, 1.0, 1.0], 'order': 2, 'crop_volume': '[:, :, :]', 'filter': {'state': False, 'type': 'gaussian', 'filter_param': 1.0}}'. Files ['/home/lahavt/data/col-0_20161116/20161116.tif'].
# [Ori Kovacsi-Katz DEBUGGING]
Executing pipeline step: 'preprocessing'. Parameters: '{'state': True, 'save_directory': 'PreProcessing', 'factor': [1.0, 1.0, 1.0], 'order': 2, 'crop_volume': '[:, :, :]', 'filter': {'state': False, 'type': 'gaussian', 'filter_param': 1.0}}'. Files ['/home/lahavt/data/col-0_20161116/20161116.tif'].
2024-07-30 10:26:55,363 [MainThread] INFO PlantSeg - Loading stack from /home/lahavt/data/col-0_20161116/20161116.tif
2024-07-30 10:27:04,145 [MainThread] INFO PlantSeg - Preprocessing files...
2024-07-30 10:27:04,146 [MainThread] INFO PlantSeg - Cropping input image to: [:, :, :]
2024-07-30 10:27:04,147 [MainThread] INFO PlantSeg - Saving results in /home/lahavt/data/col-0_20161116/PreProcessing/20161116.h5
2024-07-30 10:27:33,122 [MainThread] INFO PlantSeg - Executing pipeline step: 'cnn_prediction'. Parameters: '{'state': False, 'model_name': 'generic_confocal_3D_unet', 'device': 'cuda', 'patch': [80, 160, 160], 'stride_ratio': 0.75, 'patch_halo': [4, 8, 8], 'model_update': True, 'num_workers': 8}'. Files ['/home/lahavt/data/col-0_20161116/PreProcessing/20161116.h5'].
# [Ori Kovacsi-Katz DEBUGGING]
Executing pipeline step: 'cnn_prediction'. Parameters: '{'state': False, 'model_name': 'generic_confocal_3D_unet', 'device': 'cuda', 'patch': [80, 160, 160], 'stride_ratio': 0.75, 'patch_halo': [4, 8, 8], 'model_update': True, 'num_workers': 8}'. Files ['/home/lahavt/data/col-0_20161116/PreProcessing/20161116.h5'].
2024-07-30 10:27:33,153 [MainThread] INFO PlantSeg - File config_train.yml already exists. Skipping download.
2024-07-30 10:27:33,154 [MainThread] INFO PlantSeg - File best_checkpoint.pytorch already exists. Skipping download.
2024-07-30 10:27:33,154 [MainThread] INFO PlantSeg Zoo - Loaded model from PlantSeg zoo: generic_confocal_3D_unet
2024-07-30 10:27:33,239 [MainThread] INFO PlantSeg Zoo - Loaded model from user specified weights: /home/lahavt/tmp/.plantseg_models/generic_confocal_3D_unet/best_checkpoint.pytorch
# [Ori Kovacsi-Katz]: modified to be weights_only=True
# [Ori Kovacsi-Katz]: before setting self.batch_size... using find_batch_size(...)  
# [Ori Kovacsi-Katz]: attempt setting batch_size 
1
in_channels
1
actual_patch_shape
(88, 176, 176)
device
cuda
# [Ori Kovacsi-Katz]: attempt setting batch_size 
2
in_channels
1
actual_patch_shape
(88, 176, 176)
device
cuda
# [Ori Kovacsi-Katz]: attempt setting batch_size 
4
in_channels
1
actual_patch_shape
(88, 176, 176)
device
cuda
# [Ori Kovacsi-Katz]: attempt setting batch_size 
8
in_channels
1
actual_patch_shape
(88, 176, 176)
device
cuda
# [Ori Kovacsi-Katz]: attempt setting batch_size 
16
in_channels
1
actual_patch_shape
(88, 176, 176)
device
cuda
# [Ori Kovacsi-Katz]: Other then out of memory exception while attempt setting batch_size 
Expected output.numel() <= std::numeric_limits<int32_t>::max() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/lahavt/plant-seg/run_plantseg.py", line 4, in <module>
    main()
  File "/home/lahavt/plant-seg/plantseg/run_plantseg.py", line 96, in main
    process_config(args.config)
  File "/home/lahavt/plant-seg/plantseg/run_plantseg.py", line 77, in process_config
    raw2seg(config)
  File "/home/lahavt/plant-seg/plantseg/pipeline/raw2seg.py", line 150, in raw2seg
    pipeline_step = pipeline_step_setup(input_paths, config[pipeline_step_name])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Changing the list so the maximum batch size is 8 avoided the crash:

    with torch.no_grad():
        for batch_size in [1, 2, 4, 8]:
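
For context, a rough back-of-the-envelope check of why the probe can exceed the int32 element limit at batch_size 16 but not at 8, even when GPU memory is not exhausted. The 64 decoder channels below are only an assumption for illustration, not a value taken from the generic_confocal_3D_unet config:

    # rough check of the element count of the failing upsample_nearest3d output
    int32_max = 2**31 - 1                  # the limit quoted in the PyTorch error

    batch_size = 16
    channels = 64                          # ASSUMED decoder feature maps, illustration only
    padded_patch = (88, 176, 176)          # patch [80, 160, 160] + halo [4, 8, 8] on both sides

    numel = batch_size * channels
    for s in padded_patch:
        numel *= s

    print(f"{numel:,} elements vs int32 max {int32_max:,}")  # 2,791,309,312 vs 2,147,483,647
    print("exceeds the limit:", numel > int32_max)           # True at 16; False at batch_size 8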

How can I fix plant-seg/plantseg/predictions/functional/array_predictor.py line 51 so that PlantSeg does not crash for any of the batch sizes probed here?

       for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:

Thanks, Ori

qin-yu commented 1 month ago

I see. Thanks for reporting this @OriKovacsiKatz

The reason it failed is that PyTorch didn't raise an OOM error but instead raised `RuntimeError: Expected output.numel() <= std::numeric_limits<int32_t>::max() to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)`, while I decided to only allow OOM errors here:

https://github.com/kreshuklab/plant-seg/blob/6c90c750469b22efee1e49716b4e041ef91683de/plantseg/predictions/functional/array_predictor.py#L48-L57

I do not understand why it raised something else. For now, yes, changing the list so its maximum is 8 would work. But removing lines 53, 54, 56 and 57 (so that a non-OOM RuntimeError is handled the same way as an OOM instead of being re-raised) would be the real solution, assuming no OOM actually happened and the error message is accurate. If what really happened was an OOM that PyTorch reports this way, then we should file an enhancement request with PyTorch, as the message suggests.
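
For reference, a minimal sketch of what the probe could look like with that change, treating any RuntimeError from the forward pass (OOM or the int32 numel limit) as a signal to stop and keep the last batch size that worked. This is only an illustration of the idea, not the code currently in array_predictor.py:

    import torch

    def find_batch_size(model, in_channels, patch_shape, patch_halo, device):
        # the probed tensor is the patch padded by the halo on both sides of each axis
        actual_patch_shape = tuple(p + 2 * h for p, h in zip(patch_shape, patch_halo))
        found = 1
        with torch.no_grad():
            for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
                x = None
                try:
                    x = torch.randn((batch_size, in_channels) + actual_patch_shape).to(device)
                    _ = model(x)
                    found = batch_size  # this size went through the forward pass
                except RuntimeError:
                    # OOM or any other runtime failure (e.g. the int32 numel limit of
                    # upsample_nearest3d): keep the last working size and stop probing
                    break
                finally:
                    if x is not None:
                        del x
                    torch.cuda.empty_cache()
        return found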

I'll keep this issue open until we figure this out.

qin-yu commented 1 month ago

"Legacy" tag because Napari GUI has workaround (single patch mode).

qin-yu commented 1 month ago

Just formatted the issue for readability.

qin-yu commented 1 month ago

Just in case I didn't sound encouraging: @OriKovacsiKatz, you are very welcome to check whether an OOM really happens on your device and then open a PR for PlantSeg and/or an issue for PyTorch. The easy way is just to stare at the PlantSeg terminal and watch nvidia-smi alongside it.
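
If useful, the same check can also be done from Python right after a probe step, using standard torch.cuda calls (a small sketch, independent of PlantSeg):

    import torch

    # run this right after the batch_size=16 probe to see whether GPU memory was
    # actually close to the limit, or whether only the int32 element count was hit
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print(f"free GPU memory: {free_bytes / 1024**3:.1f} GiB of {total_bytes / 1024**3:.1f} GiB")
    print(f"peak allocated by PyTorch: {torch.cuda.max_memory_allocated() / 1024**3:.1f} GiB")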