ClementPinard / SfmLearner-Pytorch

PyTorch version of SfmLearner from Tinghui Zhou et al.
MIT License
1.01k stars 226 forks

ref_imgs for pose_exp_net #68

Closed: feitongt closed this issue 5 years ago

feitongt commented 5 years ago

Hi,

In train.py, you prepare ref_imgs with the following code:

    ref_imgs = [img.to(device) for img in ref_imgs]

It generates a list of ref_imgs from different batches.

Then you concatenate target_image and ref_imgs with the following code:

    input = [target_image]
    input.extend(ref_imgs)
    input = torch.cat(input, 1)

It seems that target_image is concatenated with ref_imgs from other batches.

Is the size of ref_imgs [batch, nb_refs, 3, height, width]?

Thanks

ClementPinard commented 5 years ago

ref_imgs is not a tensor; it's a list of tensors, and every tensor has a shape of [batch, 3, height, width], just like target_image.

The resulting tensor (after torch.cat along dimension 1) is of size [batch, 3*seq_length, height, width].
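A minimal sketch of the shapes described above, assuming small made-up sizes (batch 4, seq_length 3, 8x16 frames) and random tensors in place of real images; only the list-plus-torch.cat(..., 1) pattern comes from the thread. The asserts also check the worry raised in the question: concatenating along dimension 1 leaves dimension 0 alone, so sample i of the result only combines sample i of each input.

```python
import torch

# Illustrative sizes (assumptions, not the values used by train.py).
batch, height, width = 4, 8, 16
seq_length = 3  # one target frame + two reference frames

# Same structure as in train.py: a tensor for the target frame and a
# Python list of tensors (one per reference frame), all [batch, 3, H, W].
tgt_img = torch.randn(batch, 3, height, width)
ref_imgs = [torch.randn(batch, 3, height, width) for _ in range(seq_length - 1)]

# Concatenate along dim 1 (channels), as in the snippet quoted in the question.
inputs = [tgt_img]
inputs.extend(ref_imgs)
stacked = torch.cat(inputs, 1)
print(stacked.shape)  # torch.Size([4, 9, 8, 16])

# Dim 0 is untouched: sample i of the result only holds sample i of the
# target followed by sample i of each reference frame.
for i in range(batch):
    assert torch.equal(stacked[i, 0:3], tgt_img[i])
    for k, ref in enumerate(ref_imgs):
        assert torch.equal(stacked[i, 3 * (k + 1):3 * (k + 2)], ref[i])
```

The variable is named inputs here only to avoid shadowing Python's built-in input; the behaviour is the same as the quoted snippet.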

feitongt commented 5 years ago

Hi,

But in train.py, the list is generated by the following code:

    ref_imgs = [img.to(device) for img in ref_imgs]

But I think ref_imgs coming out of train_loader should have size [batch, seq_length-1, 3, height, width].

Here is the code from train.py:

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1/disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

ClementPinard commented 5 years ago

ref_imgs is a list even when it comes out of train_loader: it's a list of tensors.

I'm not sure why you think it should be a single tensor of size [batch, seq_length-1, 3, height, width]. The collate_fn used by the default train_loader keeps the Python structure of each sample, even with nested lists. See here: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py#L31
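A toy reproduction of this collate behaviour, assuming a hypothetical ToySequenceDataset that stands in for SequenceFolder (no image loading, no intrinsics) and returns (tgt_img, ref_imgs) with ref_imgs as a plain Python list. The default collate_fn of DataLoader batches each list element separately, so ref_imgs comes out as a list of [batch, 3, height, width] tensors rather than one 5-D tensor.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToySequenceDataset(Dataset):
    """Hypothetical stand-in for SequenceFolder: each sample is
    (tgt_img, ref_imgs), where ref_imgs is a Python list of tensors."""

    def __init__(self, n_samples=8, n_refs=2, height=8, width=16):
        self.n_samples = n_samples
        self.n_refs = n_refs
        self.height = height
        self.width = width

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        # Fill every pixel with the sample index so batches are easy to trace.
        shape = (3, self.height, self.width)
        tgt_img = torch.full(shape, float(index))
        ref_imgs = [torch.full(shape, float(index)) for _ in range(self.n_refs)]
        return tgt_img, ref_imgs


# Default collate_fn, no shuffling: the first batch holds samples 0..3.
loader = DataLoader(ToySequenceDataset(), batch_size=4, shuffle=False)
tgt_img, ref_imgs = next(iter(loader))

print(type(ref_imgs), len(ref_imgs))  # <class 'list'> 2
print(tgt_img.shape)                  # torch.Size([4, 3, 8, 16])
print([r.shape for r in ref_imgs])    # [torch.Size([4, 3, 8, 16]), torch.Size([4, 3, 8, 16])]
```

Because each pixel carries its sample index, ref_imgs[k][i] can be traced back to dataset item i, which confirms that every element of the list is already batched along dimension 0.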

feitongt commented 5 years ago

Oh! Thank you so much!

That's because I thought imgs[1:] in class SequenceFolder(data.Dataset) was an array.

Thanks for your kind patience.
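The list-versus-array distinction is what decides the batched layout. The sketch below, assuming a hypothetical make_sample helper whose pixel values encode 10 * sample index + reference index, collates the same references once as a per-sample list and once pre-stacked into a single tensor. It then shows that torch.stack / torch.unbind along dimension 1 convert between the two layouts. It calls torch.utils.data.default_collate directly, which is exported under that name in recent PyTorch releases.

```python
import torch
from torch.utils.data import default_collate

batch, n_refs, height, width = 4, 2, 8, 16


def make_sample(index, stacked_refs):
    """Hypothetical sample builder; pixel values encode 10*index + ref_number."""
    tgt = torch.full((3, height, width), float(10 * index))
    refs = [torch.full((3, height, width), float(10 * index + k)) for k in range(n_refs)]
    # Either keep the references as a list, or pre-stack them like a single array.
    return tgt, (torch.stack(refs) if stacked_refs else refs)


# Case 1: list per sample -> list of n_refs tensors, each [batch, 3, H, W].
_, refs_as_list = default_collate([make_sample(i, stacked_refs=False) for i in range(batch)])
# Case 2: one stacked tensor per sample -> a single [batch, n_refs, 3, H, W] tensor.
_, refs_as_tensor = default_collate([make_sample(i, stacked_refs=True) for i in range(batch)])

print(len(refs_as_list), refs_as_list[0].shape)  # 2 torch.Size([4, 3, 8, 16])
print(refs_as_tensor.shape)                      # torch.Size([4, 2, 3, 8, 16])

# Same data, two layouts: stack/unbind along dim 1 converts between them.
assert torch.equal(torch.stack(refs_as_list, dim=1), refs_as_tensor)
assert all(torch.equal(a, b) for a, b in zip(torch.unbind(refs_as_tensor, dim=1), refs_as_list))
```

Keeping the list form, as train.py does, lets the code move each reference frame to the device and pass it to torch.cat without any reshaping.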