agrimgupta92 / sgan

Code for "Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks", Gupta et al, CVPR 2018
MIT License

Question about generator step loss #54

Closed Shunichi09 closed 5 years ago

Shunichi09 commented 5 years ago

Hi! Thank you for your great paper and your code.

I have a question about the diversity loss given as Eq. (5) in your paper. In the generator step function of your code,

for _ in range(args.best_k):
    generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)

    pred_traj_fake_rel = generator_out
    pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

    if args.l2_loss_weight > 0:
        g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
            pred_traj_fake_rel,
            pred_traj_gt_rel,
            loss_mask,
            mode='raw'))

It seems that this loss function is not implemented here. I thought this code was only for the 1V-1 or 1V-20 models.

I would really appreciate any ideas on how to implement that loss function. @amiryanj Thank you!!

amiryanj commented 5 years ago

I don't understand the question. This part of the code is exactly the implementation of Eq. (5): it generates best_k samples, computes the error for each one, and takes the minimum error:

      _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / torch.sum(
          loss_mask[start:end])

This value is then accumulated into g_l2_loss_sum_rel, which is added to the loss function: loss += g_l2_loss_sum_rel
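
To make the best-of-k idea concrete, here is a minimal, framework-free sketch of the computation described above. It is an illustration only, not the repository's code: the function name `variety_loss` and the argument `valid_steps` are hypothetical, and plain Python lists stand in for the tensors. It assumes each of the k samples already has a per-pedestrian summed squared error (as `mode='raw'` appears to give) and that `seq_start_end` holds the pedestrian index range of each scene.

```python
def variety_loss(per_sample_errors, seq_start_end, valid_steps):
    """Best-of-k L2 loss (illustrative sketch, hypothetical names).

    per_sample_errors[k][p]: summed squared error of pedestrian p in sample k.
    seq_start_end: list of (start, end) pedestrian index ranges, one per scene.
    valid_steps[p]: number of unmasked time steps for pedestrian p.
    """
    total = 0.0
    for start, end in seq_start_end:
        # Sum the error over the pedestrians of this scene, once per sample.
        scene_errors = [sum(sample[start:end]) for sample in per_sample_errors]
        # Keep only the best (minimum-error) sample, then normalize by the
        # number of unmasked entries in the scene.
        total += min(scene_errors) / sum(valid_steps[start:end])
    return total


# Two samples (k = 2), three pedestrians split into two scenes.
errors = [[4.0, 2.0, 9.0],
          [1.0, 1.0, 12.0]]
scenes = [(0, 2), (2, 3)]
steps = [8, 8, 8]
loss = variety_loss(errors, scenes, steps)
```

With k = 1 the minimum is taken over a single sample, so the same function reduces to an ordinary masked L2 loss; only for k > 1 does the minimum select among several generated trajectories.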

Shunichi09 commented 5 years ago

I'm sorry, I missed that. I really appreciate your help.