uber-research / DeepPruner

DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch (ICCV 2019)
Other

Proper environment #4

ahrnbom closed this issue 4 years ago

ahrnbom commented 4 years ago

I find it difficult to construct an environment where I can run DeepPruner. PyTorch 0.4.0 and Torchvision 0.2.0 are so old that I cannot install them with conda or pip. Do you have any scripts for setting up a proper environment for DeepPruner? Ideally some kind of container (Docker or Singularity), but even a simple script that installs the necessary libraries in versions supported by DeepPruner would be a huge help.

I tried running DeepPruner with the latest PyTorch and Torchvision (1.3.1 and 0.4.2, respectively) and it doesn't work, so presumably there is some version after which it stopped working. I also get warnings about behavior that has changed since 0.4.0, so it seems we really need 0.4.0 to run DeepPruner (see below).

/opt/conda/lib/python2.7/site-packages/torch/nn/functional.py:2404: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
/opt/conda/lib/python2.7/site-packages/torch/nn/functional.py:2494: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
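The second warning concerns how bilinear upsampling maps each output pixel back to a source coordinate. Below is a rough pure-Python sketch of that 1-D mapping, assuming PyTorch's linear-mode behavior (output size given explicitly, negative source coordinates clamped to 0). It shows that the two settings sample slightly different input positions, so outputs can differ a little depending on which default a model was trained with. `source_coords` is a hypothetical helper, not a PyTorch function.

```python
def source_coords(in_size, out_size, align_corners):
    """1-D source coordinate sampled for each output index in linear upsampling.

    Sketch of the mapping assumed for PyTorch's linear modes; the
    interpolation step itself (clamping neighbour indices to in_size - 1)
    is not modelled here.
    """
    coords = []
    for dst in range(out_size):
        if align_corners:
            # Corner pixel centres coincide: 0 -> 0 and out_size-1 -> in_size-1.
            scale = float(in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
            src = dst * scale
        else:
            # Pixels are treated as unit areas: half-pixel offset, clamped at 0.
            src = (dst + 0.5) * float(in_size) / out_size - 0.5
            src = max(src, 0.0)
        coords.append(src)
    return coords


if __name__ == "__main__":
    print("align_corners=True: ", [round(c, 3) for c in source_coords(4, 8, True)])
    print("align_corners=False:", [round(c, 3) for c in source_coords(4, 8, False)])
```

Passing `align_corners` explicitly to `nn.functional.interpolate` silences the warning. This thread does not say which value matches the released weights.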

ahrnbom commented 4 years ago

After a lot of work, I was able to get a Python 2.7 environment with PyTorch 0.4.0 and Torchvision 0.2.0 with the following Singularity definition:

Bootstrap: docker
From: continuumio/anaconda

%environment
    export LANG=C.UTF-8
    export LC_ALL=C.UTF-8

%post
    /opt/conda/bin/conda config --set allow_conda_downgrades true
    /opt/conda/bin/conda install conda=4.6.14
    /opt/conda/bin/conda install -c pytorch -c conda-forge pytorch=0.4.0 torchvision=0.2.0 scikit-image matplotlib

%runscript
    export PROMPT_COMMAND="echo -n \[\ Singularity \]\ "
    exec bash

However, running model.forward fails with the following error:

[ Singularity ] ahrnbom@eisenstein:~/deeppruner/DeepPruner/deeppruner$ python lol.py
lol.py: Number of model parameters: 7390142
Traceback (most recent call last):
  File "lol.py", line 119, in <module>
    main()
  File "lol.py", line 110, in main
    disparity = test(imgL, imgR)
  File "lol.py", line 84, in test
    refined_disparity = model(imgL, imgR)
  File "/opt/conda/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 112, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/opt/conda/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ahrnbom/deeppruner/DeepPruner/deeppruner/models/deeppruner.py", line 263, in forward
    sample_count=self.patch_match_sample_count, sampler_type="patch_match")
  File "/home/ahrnbom/deeppruner/DeepPruner/deeppruner/models/deeppruner.py", line 160, in generate_disparity_samples
    max_disparity, sample_count, self.patch_match_iteration_count)
  File "/opt/conda/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ahrnbom/deeppruner/DeepPruner/deeppruner/models/patch_match.py", line 256, in forward
    normalized_disparity_samples = self.propagation(normalized_disparity_samples, device, propagation_type="horizontal")
  File "/opt/conda/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ahrnbom/deeppruner/DeepPruner/deeppruner/models/patch_match.py", line 187, in forward
    one_hot_filter = torch.zeros_like(label).scatter_(0, label, 1).float()
RuntimeError: Expected object of type torch.cuda.LongTensor but found type torch.cuda.FloatTensor for argument #3 'index'
[ Singularity ] ahrnbom@eisenstein:~/deeppruner/DeepPruner/deeppruner$ 

Also, loading any of the .tar weight files fails with errors about incorrect dimensions.

I therefore suspect that this code was actually designed to run in some other, unspecified environment.

In case anyone is wondering, here is what lol.py looks like:

from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import torch.nn.functional as F
import skimage.io
import numpy as np
import logging
from dataloader import preprocess
from PIL import Image
from models.deeppruner import DeepPruner
from models.config import config as config_args
from setup_logging import setup_logging

def ls():
    left_test = ['/home/ahrnbom/deeppruner/images/74221-0.jpg']
    right_test = ['/home/ahrnbom/deeppruner/images/74221-2.jpg']

    return left_test, right_test

parser = argparse.ArgumentParser(description='DeepPruner')
parser.add_argument('--datapath', default='/',
                    help='datapath')
parser.add_argument('--loadmodel', default=None,
                    help='load model')
parser.add_argument('--save_dir', default='./',
                    help='save directory')
parser.add_argument('--logging_filename', default='./trolololo.log',
                    help='filename for logs')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')

args = parser.parse_args()
torch.backends.cudnn.benchmark = True
args.cuda = not args.no_cuda and torch.cuda.is_available()

args.cost_aggregator_scale = config_args.cost_aggregator_scale
args.downsample_scale = args.cost_aggregator_scale * 8.0

setup_logging(args.logging_filename)

if args.cuda:
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

test_left_img, test_right_img = ls()

model = DeepPruner()

if args.cuda:
    model = nn.DataParallel(model)
    model.cuda()

logging.info('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

if args.loadmodel is not None:
    logging.info("loading model...")
    state_dict = torch.load(args.loadmodel)
    model.load_state_dict(state_dict['state_dict'], strict=True)

def test(imgL, imgR):
    model.eval()
    with torch.no_grad():
        imgL = Variable(torch.FloatTensor(imgL))
        imgR = Variable(torch.FloatTensor(imgR))

        if args.cuda:
            imgL, imgR = imgL.cuda(), imgR.cuda()

        refined_disparity = model(imgL, imgR)
        return refined_disparity

def main():

    for left_image_path, right_image_path in zip(test_left_img, test_right_img):
        imgL = np.asarray(Image.open(left_image_path))
        imgR = np.asarray(Image.open(right_image_path))

        processed = preprocess.get_transform()
        imgL = processed(imgL).numpy()
        imgR = processed(imgR).numpy()

        imgL = np.reshape(imgL, [1, 3, imgL.shape[1], imgL.shape[2]])
        imgR = np.reshape(imgR, [1, 3, imgR.shape[1], imgR.shape[2]])

        w = imgL.shape[3]
        h = imgL.shape[2]
        dw = int(args.downsample_scale - (w%args.downsample_scale + (w%args.downsample_scale==0)*args.downsample_scale))
        dh = int(args.downsample_scale - (h%args.downsample_scale + (h%args.downsample_scale==0)*args.downsample_scale))

        top_pad = dh
        left_pad = dw
        imgL = np.lib.pad(imgL, ((0, 0), (0, 0), (top_pad, 0), (0, left_pad)), mode='constant', constant_values=0)
        imgR = np.lib.pad(imgR, ((0, 0), (0, 0), (top_pad, 0), (0, left_pad)), mode='constant', constant_values=0)

        disparity = test(imgL, imgR)
        disparity = disparity[0, top_pad:, :-left_pad].data.cpu().numpy()
        output_path = os.path.join(args.save_dir, 'output.png')
        skimage.io.imsave(output_path, (disparity * 256).astype('uint16'))

        logging.info("Disparity for {} generated at: {}".format(left_image_path, output_path))

if __name__ == '__main__':
    main()

My suspicion that something is off with the environment is further supported by the fact that the .tar weight files do load under PyTorch 1.3.1, yet the code still doesn't run.
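The padding in lol.py (taken from submission_kitti.py) rounds each image dimension up to the next multiple of `downsample_scale`, padding on the top and, despite the variable name `left_pad`, on the right. A small pure-Python sketch of the same formula shows one indexing pitfall that is separate from the errors above: when a dimension is already a multiple of the scale, the pad is 0 and a crop written as `[:-left_pad]` becomes `[:-0]`, which selects nothing. Cropping up to the original width avoids this. `pad_amount`, `crop_padding` and the scale value 32.0 are illustrative assumptions.

```python
def pad_amount(size, scale):
    """Padding that rounds `size` up to the next multiple of `scale` (same formula as lol.py)."""
    remainder = size % scale
    return int(scale - (remainder + (remainder == 0) * scale))


def crop_padding(rows, top_pad, orig_width):
    """Undo top/right zero padding on a 2-D list of rows.

    Slicing with [:orig_width] instead of [:-right_pad] stays correct when the pad is 0.
    """
    return [row[:orig_width] for row in rows[top_pad:]]


if __name__ == "__main__":
    scale = 32.0  # hypothetical downsample_scale (cost_aggregator_scale * 8.0)
    for width in (1242, 1248):
        pad = pad_amount(width, scale)
        print(width, "->", width + pad, "pad", pad)
    row = list(range(5))
    print(row[:-0])  # empty list: the pitfall when the pad is zero
```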

ShivamDuggal4 commented 4 years ago

Hi @ahrnbom, I updated torch and torchvision to the latest versions, and DeepPruner runs as expected on them. Before that, I was using torch 1.2 with torchvision 0.3.0. So there is no need to downgrade to older versions of torch or torchvision.

Could you please post the entire error message you get when running DeepPruner on the latest torch version? What you posted are only minor warnings and can be ignored.

Thanks !!
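To confirm which versions a container or conda environment actually resolved to, you can print `torch.__version__` and `torchvision.__version__`. The sketch below compares them against the combination reported above as working (torch 1.2 with torchvision 0.3.0). Treating that combination as a minimum is an assumption, not a documented requirement, and `parse_version` and `at_least` are hypothetical helpers.

```python
import re

# Versions reported as working in this thread; treated as minimums here (an assumption).
REPORTED_WORKING = {"torch": "1.2", "torchvision": "0.3.0"}


def parse_version(text):
    """'1.3.1' -> (1, 3, 1); '1.2.0+cu92' -> (1, 2, 0).

    Drops a '+local' suffix and keeps the leading digits of each dot-separated part.
    """
    parts = []
    for piece in text.split("+", 1)[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def at_least(installed, minimum):
    """True if `installed` >= `minimum`, padding the shorter version with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


if __name__ == "__main__":
    try:
        import torch
        import torchvision
    except ImportError as exc:
        print("Could not import torch/torchvision:", exc)
    else:
        for name, module in (("torch", torch), ("torchvision", torchvision)):
            ok = at_least(module.__version__, REPORTED_WORKING[name])
            print(name, module.__version__, "OK" if ok else "older than reported-working")
```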

ahrnbom commented 4 years ago

Sure, here is the error I get from running the lol.py posted above in Python 3.7, with PyTorch 1.3.1 and Torchvision 0.4.2:

lol.py: Number of model parameters: 7390142
lol.py: loading model...
/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py:2404: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py:2494: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
  "See the documentation of nn.Upsample for details.".format(mode))
Traceback (most recent call last):
  File "lol.py", line 119, in <module>
    main()
  File "lol.py", line 111, in main
    disparity = disparity[0, top_pad:, :-left_pad].data.cpu().numpy()
TypeError: tuple indices must be integers or slices, not tuple

If I instead use Python 2.7 with the same versions of PyTorch/Torchvision, I get the same error:

lol.py: Number of model parameters: 7390142
lol.py: loading model...
/opt/conda/lib/python2.7/site-packages/torch/nn/functional.py:2404: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
/opt/conda/lib/python2.7/site-packages/torch/nn/functional.py:2494: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
  "See the documentation of nn.Upsample for details.".format(mode))
Traceback (most recent call last):
  File "lol.py", line 119, in <module>
    main()
  File "lol.py", line 111, in main
    disparity = disparity[0, top_pad:, :-left_pad].data.cpu().numpy()
TypeError: tuple indices must be integers, not tuple

Note that the problematic lines

disparity = test(imgL, imgR)
disparity = disparity[0, top_pad:, :-left_pad].data.cpu().numpy()

are taken straight from submission_kitti.py. It seems that test returns a tuple rather than a tensor.
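This `TypeError` is what plain Python raises when a tuple is indexed with tensor-style multi-axis syntax. In the hypothetical stand-in below (`fake_forward` and `FakeDisparity` are not DeepPruner code), the output is a tuple in "training" mode and a single indexable value in "evaluation" mode, which reproduces the symptom. The thread does not say which element of the real training-mode tuple is the refined disparity, so the sketch raises a clear error instead of guessing.

```python
class FakeDisparity(object):
    """Stand-in for a tensor: accepts multi-axis indexing like disparity[0, a:, :b]."""

    def __getitem__(self, index):
        return index  # echo the index so the call succeeds


def fake_forward(mode):
    """Hypothetical model output: several maps while training, one at evaluation."""
    if mode == "training":
        return (FakeDisparity(), FakeDisparity(), FakeDisparity())
    return FakeDisparity()


def crop(output, top_pad, right_pad):
    """Crop like lol.py does, but fail with a readable message on tuple outputs."""
    if isinstance(output, tuple):
        raise TypeError("model returned a tuple; is the config mode set to 'evaluation'?")
    return output[0, top_pad:, :-right_pad]


if __name__ == "__main__":
    try:
        fake_forward("training")[0, 4:, :-6]
    except TypeError as exc:
        print("training mode:", exc)  # tuple indices must be integers or slices, not tuple
    print("evaluation mode:", crop(fake_forward("evaluation"), 4, 6))
```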

ShivamDuggal4 commented 4 years ago

Hi @ahrnbom, it seems your config file hasn't been set correctly for evaluation: https://github.com/uber-research/DeepPruner/blob/b13a254773843b92ee8bbfb5dbc18cf9a3078ea4/deeppruner/models/config.py#L31

Please change "mode" to "evaluation" to run submission_kitti.py successfully. This should most likely fix the issue, and you can then use the latest versions of torch and torchvision.
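To turn a mode mismatch into an explicit error, a script such as lol.py could check the mode before building the model. The sketch below assumes, based on the linked line, that the config object exposes a `mode` attribute whose value should be "evaluation" for inference. `check_mode` is a hypothetical helper, and `SimpleNamespace` only stands in for the real `models.config.config` so that the snippet runs on its own.

```python
from types import SimpleNamespace


def check_mode(config, expected="evaluation"):
    """Fail early, with a readable message, if the config is not in the expected mode."""
    mode = getattr(config, "mode", None)
    if mode != expected:
        raise ValueError(
            "config.mode is {!r}; set it to {!r} in models/config.py before running "
            "inference (in training mode the forward pass returns a tuple).".format(mode, expected)
        )
    return mode


if __name__ == "__main__":
    # Stand-in for `from models.config import config as config_args` (assumed to have .mode).
    config_args = SimpleNamespace(mode="training")
    try:
        check_mode(config_args)
    except ValueError as exc:
        print(exc)
    config_args.mode = "evaluation"
    print("mode OK:", check_mode(config_args))
```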