loeweX / Greedy_InfoMax

Code for the paper: Putting An End to End-to-End: Gradient-Isolated Learning of Representations
https://arxiv.org/abs/1905.11786
MIT License

patch main_vision.py/main_audio.py for PyTorch v1.5+ compatibility #15

bairesearch closed this pull request 3 years ago

bairesearch commented 3 years ago

This patch updates main_vision.py/main_audio.py for PyTorch v1.5+ compatibility. The chosen implementation should be reviewed.

Patch summary:

This patch prevents main_vision.py/main_audio.py from throwing the following error:

"RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation"

The patch splits the model.zero_grad(), cur_losses.backward(), and optimizer[idx].step() calls into separate passes, so that each operation is executed for every model/layer before the next operation begins. As a result, no optimizer step modifies parameters in place until all backward passes have completed.
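A minimal sketch of the reordering, using a hypothetical two-module stack (not the repository's actual models) with one SGD optimizer per module. The example deliberately omits any detach between modules and marks the input as requiring gradients, so that the second loss still depends on the first module's weight; this is just one simple way to trigger PyTorch's in-place modification check, not necessarily the path taken in Greedy_InfoMax. interleaved_step() follows the original per-module zero_grad/backward/step order, while split_step() runs all zero_grad calls, then all backward passes, then all optimizer steps.

```python
import torch
import torch.nn as nn


def make_setup(seed: int = 0):
    """Build a hypothetical two-module stack with one optimizer and one loss per module."""
    torch.manual_seed(seed)
    modules = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])
    optimizers = [torch.optim.SGD(m.parameters(), lr=0.1) for m in modules]
    # requires_grad=True on the input means autograd must keep module 0's weight
    # to compute the input's gradient, so an in-place update of it is detected later.
    x = torch.randn(8, 4, requires_grad=True)
    losses, h = [], x
    for m in modules:
        h = torch.tanh(m(h))  # no detach: loss 1 still depends on module 0
        losses.append(h.pow(2).mean())
    return modules, optimizers, losses


def interleaved_step():
    """Original ordering: zero_grad -> backward -> step, once per module."""
    modules, optimizers, losses = make_setup()
    for idx, loss in enumerate(losses):
        modules.zero_grad()
        loss.backward(retain_graph=idx < len(losses) - 1)
        optimizers[idx].step()  # in-place update before the next backward pass


def split_step():
    """Patched ordering: all zero_grad calls, then all backward passes, then all steps."""
    modules, optimizers, losses = make_setup()
    for opt in optimizers:
        opt.zero_grad()
    for idx, loss in enumerate(losses):
        loss.backward(retain_graph=idx < len(losses) - 1)
    for opt in optimizers:
        opt.step()  # parameters change only after every backward pass is done
    return modules


try:
    interleaved_step()
    print("interleaved: no error")
except RuntimeError as err:
    print("interleaved:", str(err).split(":")[0])

split_step()
print("split: ok")
```

One consequence of the split ordering is that gradients from all losses accumulate before any step. When modules are gradient-isolated (each module's input detached), every loss only reaches its own module's parameters, so this accumulation should not change the update; in this undetached toy example it does.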

It is based on the workaround described here: https://discuss.pytorch.org/t/solved-pytorch1-5-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/90256/3

It has been tested on PyTorch v1.4 to v1.7 with ~1 epoch of training on both the Vision and Audio models. On PyTorch v1.7 it has also been tested with the vision/audio training parameters train_module/train_layer set to 0/0 and 1/1 (not just 3/6), and with model_splits set to 1/1 (not just 3/6). These tests may not reveal every significant limitation the patch introduces.

bairesearch commented 3 years ago

Commit #2 (d74df94): This patch updates ClassificationModel.py for PyTorch v1.7 compatibility. The chosen implementation should be reviewed.

Patch summary:

This patch prevents ClassificationModel.py from throwing the following error:

"RuntimeError: stride should not be zero"

The patch removes the "stride=0" argument from the nn.AvgPool2d constructor call, so the layer falls back to PyTorch's default stride, which equals kernel_size.
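A small sketch of the change, assuming an illustrative 2x2 pooling window and a 4x4 input (the actual kernel size and input shape in ClassificationModel.py are not reproduced here). When stride is omitted, nn.AvgPool2d uses kernel_size as the stride, so the window tiles the input without overlap.

```python
import torch
import torch.nn as nn

x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)  # (N, C, H, W)

# Before the patch (illustrative values): an explicit stride of 0.
# Recent PyTorch versions reject this with a RuntimeError when the layer is applied.
try:
    nn.AvgPool2d(kernel_size=2, stride=0)(x)
except RuntimeError as err:
    print("stride=0 rejected:", err)

# After the patch: stride omitted, so it defaults to kernel_size.
pool = nn.AvgPool2d(kernel_size=2)
y = pool(x)
print(pool.stride, tuple(y.shape))
print(y.flatten().tolist())  # each value is the mean of one non-overlapping 2x2 block
```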

It is based on the workaround described here: https://github.com/pytorch/pytorch/issues/41767

It has been tested on PyTorch v1.2 and v1.7 with ~1 epoch of training on both the Vision and Audio models.

bairesearch commented 3 years ago

Commit #3 (2e78226): This patch updates utils.py for PyTorch v1.7 compatibility. The chosen implementation should be reviewed.

Patch summary:

This patch prevents utils.py from throwing the following error:

"RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape"

The patch adds correct = correct.contiguous() to accuracy(), so that the subsequent view() call operates on a contiguous tensor.
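A sketch of the fix, assuming accuracy() follows the common top-k pattern from the PyTorch ImageNet example (the repository's exact function body is not reproduced here). Transposing pred makes it non-contiguous, the comparison result correct can inherit that layout, and flattening a slice of it with view(-1) can then fail; calling .contiguous() first, or using .reshape(-1) as the error message suggests, avoids this. The logits and labels below are made-up values.

```python
import torch


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent (sketch of the common ImageNet-example pattern)."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)  # (batch, maxk) class indices
        pred = pred.t()  # (maxk, batch), a non-contiguous view
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        # The patch's fix: make `correct` contiguous so view(-1) below is valid.
        correct = correct.contiguous()
        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


logits = torch.tensor([[0.1, 0.7, 0.2],
                       [0.8, 0.15, 0.05],
                       [0.2, 0.3, 0.5],
                       [0.6, 0.3, 0.1]])
labels = torch.tensor([1, 0, 1, 1])
top1, top2 = accuracy(logits, labels, topk=(1, 2))
print(top1.item(), top2.item())
```

Here two of the four samples are correct at top-1 and all four at top-2, so the two values should come out as 50.0 and 100.0.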

It is based on the workaround described here: https://discuss.pytorch.org/t/contigious-vs-non-contigious-tensor/30107

It has been tested on PyTorch v1.2 and v1.7 with ~1 epoch of training on both the Vision and Audio models.

loeweX commented 3 years ago

I've finally gotten around to training a model from scratch with this code, and it achieves the same performance as before.

Thanks so much for setting this up!