kevinzakka / recurrent-visual-attention

A PyTorch Implementation of "Recurrent Models of Visual Attention"
MIT License
468 stars 123 forks source link

PyCharm freezes during step debugging (bug fixed) #43

Open · fatalfeel opened this issue 3 years ago

fatalfeel commented 3 years ago

python3 main.py --use_gpu False --is_train True


#kwargs = {}
    if config.use_gpu:
        torch.cuda.manual_seed(config.random_seed)
        kwargs = {"num_workers": 1, "pin_memory": True}
    else:
        kwargs = {}

    # instantiate data loaders
   '''if config.is_train:
        dloader = data_loader.get_train_valid_loader(config.data_dir,
                                                    config.batch_size,
                                                    config.random_seed,
                                                    config.valid_size,
                                                    config.shuffle,
                                                    config.show_sample,
                                                    **kwargs)
    else:
        dloader = data_loader.get_test_loader(config.data_dir,
                                              config.batch_size,
                                              **kwargs)'''

    if config.is_train:
        dloader = data_loader.get_train_valid_loader(config.data_dir,
                                                    config.batch_size,
                                                    config.random_seed,
                                                    config.valid_size,
                                                    config.shuffle,
                                                    config.show_sample,
                                                    kwargs)
    else:
        dloader = data_loader.get_test_loader(config.data_dir,
                                              config.batch_size,
                                              kwargs)

~~~~~~~~~~~~~~~~~~~~~data_loader.py~~~~~~~~~~~~~~~
def get_train_valid_loader(data_dir,
                           batch_size,
                           random_seed,
                           valid_size,
                           shuffle,
                           show_sample,
                           kwargs=None):
    """Train and validation data loaders.

    The previous signature took ``num_workers=4`` and ``pin_memory=False``
    keyword defaults. That meant CPU runs spawned worker processes even when
    the caller supplied no options. Those options now come only from
    ``kwargs``.

    Args:
        data_dir: path directory to the dataset.
        batch_size: how many samples per batch to load.
        random_seed: not used by the code shown here; presumably consumed by
            the dataset/sampler setup that is elided from this excerpt --
            confirm.
        valid_size: likewise presumably the validation split size used by
            that elided setup -- confirm.
        shuffle: passed as ``shuffle`` to the sample-preview loader.
        show_sample: if true, plot one batch of 9 images from ``dataset``.
        kwargs: optional dict of extra keyword arguments unpacked into every
            ``torch.utils.data.DataLoader`` built here (e.g. ``num_workers``,
            ``pin_memory``). ``None`` is treated as ``{}``, which leaves
            DataLoader at its defaults (loading in the main process).
            Callers pass the dict itself, not ``**kwargs``.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    if kwargs is None:
        kwargs = {}

    # NOTE(review): the code that builds ``dataset``, ``train_sampler`` and
    # ``valid_sampler`` is omitted from this excerpt and must precede the
    # loaders below.

    # Both loaders share one dataset and differ only in their sampler.
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=train_sampler, **kwargs
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=valid_sampler, **kwargs
    )

    # visualize some images
    if show_sample:
        sample_loader = torch.utils.data.DataLoader(
            dataset, batch_size=9, shuffle=shuffle, **kwargs
        )
        # Use the builtin next(): the ``.next()`` method alias is not
        # available on DataLoader iterators in newer PyTorch releases.
        images, labels = next(iter(sample_loader))
        # Move axis 1 to the end (presumably channels-first -> channels-last
        # for plotting; confirm against plot_images).
        X = np.transpose(images.numpy(), [0, 2, 3, 1])
        plot_images(X, labels)

    return (train_loader, valid_loader)

def get_test_loader(data_dir, batch_size, kwargs=None):
    """Test dataloader.

    Builds a DataLoader over the MNIST test split (``train=False``), using
    ``data_dir`` as the dataset root with ``download=True``. Images are
    converted to tensors and normalized with mean 0.1307 and std 0.3081.
    They are iterated in a fixed order (``shuffle=False``).

    If using CUDA, pass ``{"num_workers": 1, "pin_memory": True}``. An empty
    dict (or ``None``) leaves DataLoader at its defaults, i.e. data is loaded
    in the main process without worker subprocesses.

    Args:
        data_dir: path directory to the dataset.
        batch_size: how many samples per batch to load.
        kwargs: optional dict of extra keyword arguments forwarded to
            ``torch.utils.data.DataLoader`` (typically ``num_workers`` and
            ``pin_memory``). ``None`` is treated as an empty dict. It must
            not contain ``batch_size`` or ``shuffle``, which are set here;
            Python would raise TypeError for the duplicate keyword.

    Returns:
        A ``torch.utils.data.DataLoader`` over the MNIST test set.
    """
    if kwargs is None:
        kwargs = {}

    # define transforms
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    # load dataset
    dataset = datasets.MNIST(data_dir, train=False, download=True, transform=trans)

    # Callers hand over the options dict itself; it is unpacked only here.
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, **kwargs
    )

    return test_loader