NRCan / geo-deep-learning

Deep learning applied to georeferenced datasets
https://geo-deep-learning.readthedocs.io/en/latest/
MIT License

Memory error - Optimization to increase batch size #9

Closed ymoisan closed 5 years ago

ymoisan commented 6 years ago

When training with samples of size 256x256 pixels, a batch size over 32 causes a CUDA out-of-memory error. We have to find a way to optimize the process in order to increase the batch size.

RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/ .../THCStorage.cu:58

NOTE : May be specific to our (GC HPC) computing environment

ymoisan commented 6 years ago

Could some kind of coroutines/generators approach be useful here? Are there unnecessary copies of data structures in memory? Check how to generate data in parallel with PyTorch.
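As a reference point, here is a minimal sketch of parallel sample loading with torch.utils.data.DataLoader; the dataset below is a placeholder that fakes 256x256 samples rather than reading our actual files:

import torch
from torch.utils.data import Dataset, DataLoader

class FakeSamples(Dataset):
    """Placeholder dataset: produces one 256x256 sample per index, so the whole
    training set never has to sit in memory at once."""
    def __len__(self):
        return 781

    def __getitem__(self, idx):
        image = torch.rand(3, 256, 256)           # a real dataset would read from disk here
        mask = torch.randint(0, 11, (256, 256))   # per-pixel class labels
        return image, mask

# num_workers > 0 prepares batches in background processes; pin_memory speeds up
# host-to-GPU copies. Neither reduces the GPU memory the model itself needs, so this
# helps throughput more than it helps the out-of-memory error.
loader = DataLoader(FakeSamples(), batch_size=32, shuffle=True,
                    num_workers=4, pin_memory=True)

images, masks = next(iter(loader))
print(images.shape, masks.shape)   # torch.Size([32, 3, 256, 256]) torch.Size([32, 256, 256])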

ymoisan commented 6 years ago

See https://github.com/pytorch/pytorch/issues/5210 and https://pytorch.org/docs/stable/bottleneck.html
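The bottleneck utility linked above wraps cProfile and the autograd profiler; here is a quick, hedged sketch of using the autograd profiler directly on one training-sized batch (the tiny model is only a stand-in for the real net):

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Conv2d(3, 11, kernel_size=3, padding=1).to(device)   # stand-in model
batch = torch.rand(32, 3, 256, 256, device=device)              # one 256x256 batch

with torch.autograd.profiler.profile(use_cuda=(device == "cuda")) as prof:
    out = model(batch)
    out.sum().backward()

# Rank operators by time to see where a batch is actually spent.
sort_key = "cuda_time_total" if device == "cuda" else "cpu_time_total"
print(prof.key_averages().table(sort_by=sort_key))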

epeterson12 commented 6 years ago

Tests were performed using the unet network rather than unetsmall in order to force a memory error. The error occurred on line 102 of unet_pytorch.py under these conditions. Here is a summary of the observations thus far:

ymoisan commented 6 years ago

@epeterson12 - an interesting concept that we might want to look at: tensor comprehensions

Also, this ticket might not be all that important after all. PyTorch 1.0 will have a compiler to make it faster, which probably also means better memory usage.

ymoisan commented 6 years ago

@epeterson12: the figures for checkpointing don't seem to be that bad. In fact, for one of the models, checkpointing actually decreases GPU time. I suggest we create some benchmarks to test. import torch.utils.checkpoint works just fine in our environment, so we could try the following:

Memory consumption and processing times could be monitored for all tests. What do you think?
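For reference, a minimal sketch of how torch.utils.checkpoint could be wrapped around one U-Net double-conv block; the class, channel counts and tensor shapes are illustrative stand-ins, not the actual unet_pytorch.py code:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class DoubleConv(nn.Module):
    """Illustrative conv-ReLU-conv-ReLU block like those in a U-Net encoder/decoder."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        )

    def forward(self, x):
        if self.training and x.requires_grad:
            # Drop this block's intermediate activations after the forward pass and
            # recompute them during backward: less GPU memory, slightly more compute.
            return checkpoint(self.block, x)
        # Plain forward in eval mode, or when the input carries no gradient
        # (checkpointing the very first block would otherwise produce no parameter grads).
        return self.block(x)

# Quick smoke test with shapes in the range of our 256x256 samples.
blk = DoubleConv(64, 128)
feats = torch.rand(8, 64, 128, 128, requires_grad=True)
blk(feats).sum().backward()

The trade-off is recomputation time, which is why monitoring both memory and wall-clock time per batch size, as suggested above, matters.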

epeterson12 commented 6 years ago

@ymoisan I think that it would be a good idea to test our net using checkpointing. I will confirm that a batch size > 32 with 256 x 256 samples causes the error. I have started modifying our code in order to test checkpointing using the methodology I mentioned in my previous comment.

epeterson12 commented 6 years ago

It turns out that the maximum batch size we can currently handle with a sample size of 256x256 is 33 with the unetsmall net and 15 with the unet net, given our current hardware and version of the code. Beyond these values, we get the out-of-memory error.
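A rough way to find such limits automatically is to probe decreasing batch sizes and catch the CUDA out-of-memory RuntimeError; this is only a sketch, and the model handle and shapes are placeholders:

import torch

def probe_max_batch_size(model, sample_shape=(3, 256, 256), start=64, device="cuda"):
    """Step the batch size down from `start` until one forward/backward pass fits."""
    for bs in range(start, 0, -1):
        try:
            model.zero_grad()
            x = torch.rand(bs, *sample_shape, device=device, requires_grad=True)
            model(x).sum().backward()
            return bs
        except RuntimeError as err:
            if "out of memory" not in str(err):
                raise
            torch.cuda.empty_cache()   # release the failed allocation before retrying
    return 0

# e.g. probe_max_batch_size(unetsmall_model)  # hypothetical handle to the real net

Numbers from such a probe are only indicative: allocator fragmentation, the optimizer state and the loss actually used all shift the real limit.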

epeterson12 commented 6 years ago

Using checkpointing in the unetsmall net increases the speed of training. Tests were performed using the following parameters:

# Training Samples    # Validation Samples    Sample Size    # Classes    # Epochs
781                   495                     256            11           200

Learning Rate    Weight Decay    Step Size    Gamma    Class Weights    Dropout
0.0001           0               4            0.9      False            False

[Figures: memory usage by batch size; processing time by batch size]

Best results

                                             Original    Checkpointed
Max batch size                               32          50
Time to complete training over 200 epochs    350 min     316 min

Using checkpoints in the net design does seem to affect the results of the training. Tests were done on the original and on the checkpointed nets with the random seed set to 7, and the resulting models gave similar but slightly different results. In the first test, the original algorithm gave results closer to the ground truth. In the second test, the checkpointed version of the net yielded better results.

epeterson12 commented 5 years ago

Validation of Checkpointed results

Results when using checkpointing are slightly different from those of the original unetsmall model because cuDNN has non-deterministic kernels. I ran tests using suggestions from the PyTorch discussions https://discuss.pytorch.org/t/non-reproducible-result-with-gpu/1831 and https://discuss.pytorch.org/t/deterministic-non-deterministic-results-with-pytorch/9087.

Using the same sample files, I ran train_model.py twice using the unetsmall model (batch_size = 32) and I ran it once using the checkpointed_unet model (batch_size = 50). Then, I classified some images with the resulting models.

The following settings were applied at the beginning of the code to try to get reproducible results:

import random
import torch

torch.backends.cudnn.deterministic = True
torch.manual_seed(999)
torch.cuda.manual_seed(999)
torch.cuda.manual_seed_all(999)
random.seed(0)

Also, the DataLoaders were instantiated with num_workers = 0 and shuffle = False. Running the original unetsmall configuration without checkpoints twice yielded two slightly different results. Here are some examples of the results obtained when running image_classification.py on one of the training images with each trained model; sections of the images that weren't classified were left white.

[Result images for three tiles (1_rgb_8000_8000, 1_rgb_0_0, on_5297_1), each compared across the ground truth, the two runs of the current code, and the checkpointed net]

Please note that the configurations and the number of samples weren't set to yield optimal results; verifying reproducibility was the goal of these tests. The number of training samples was set to the number of samples produced during sample creation.

global:
  samples_size: 256
  num_classes: 5
  data_path: /my/data/path
  number_of_bands: 3
  model_name: unetsmall     # One of unet, unetsmall, checkpointed_unet or ternausnet

sample:
  prep_csv_file: /my/prep/csv/file
  samples_dist: 200
  remove_background: True
  mask_input_image: False

training:
  output_path: /my/output/path
  num_trn_samples: 3356
  num_val_samples: 1370
  batch_size: 32
  num_epochs: 100
  learning_rate: 0.0001
  weight_decay: 0
  step_size: 4
  gamma: 0.9
  class_weights: False

models:
  unet:   &unet001
    dropout: False
    probability: 0.2    # Set with dropout
    pretrained: False   # optional
  unetsmall:
    <<: *unet001
  ternausnet:
    pretrained: ./models/TernausNet.pt    # Mandatory
  checkpointed_unet: 
    <<: *unet001
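For completeness, a minimal sketch of reading this configuration with PyYAML, assuming it is saved as config.yaml (a hypothetical path); the expected values in the comments follow from the file above:

import yaml  # PyYAML

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)   # safe_load also resolves the <<: *unet001 merge keys

print(cfg["global"]["model_name"])           # unetsmall (or checkpointed_unet, etc.)
print(cfg["training"]["batch_size"])         # 32
print(cfg["models"]["checkpointed_unet"])    # {'dropout': False, 'probability': 0.2, 'pretrained': False}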

I think that the results of the checkpointed_unet are similar enough to the unetsmall's results for us to consider it a good memory- and time-optimised version of our unetsmall net architecture. I have added it as a model choice for our program.

Throughout my tests, I observed that the models produced by training are more accurate when the random number generators aren't seeded. The checkpointed_unet, observationally, seems to be more affected by this than the unetsmall.