Closed SobhanMP closed 1 year ago
self.subtb_loss_fast(log_p_F, log_p_B, log_F_preds, clip_log_R, batch.traj_lens)
  Median: 59.95 ms
  IQR: 0.29 ms (59.84 to 60.13)
  17 measurements, 1 runs per measurement, 1 thread

<torch.utils.benchmark.utils.common.Measurement object at 0x7f89941df130>
self.subtb_cum(log_p_F, log_p_B, log_F_preds, clip_log_R, batch.traj_lens)
  Median: 16.43 ms
  IQR: 0.12 ms (16.41 to 16.53)
  7 measurements, 10 runs per measurement, 1 thread
`subtb_diff` is the MSE between the losses returned by the two implementations.
It is not exactly zero because I precompute `P_F - P_B`; computing `cross(P_F) - cross(P_B)` instead of `cross(P_F - P_B)` yields exactly the same results.
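
To illustrate why the order of operations can matter, here is a minimal sketch in plain Python. The `cross` below is a hypothetical stand-in, not the function from this PR: it is assumed to return every sub-trajectory sum `sum(x[i:j])` via prefix sums. Under that assumption the operation is linear, so the two variants are mathematically equal and can differ only by floating-point rounding.

```python
import random
from itertools import accumulate


def cross(x):
    """Hypothetical stand-in: cross(x)[i][j] = sum(x[i:j]) for i <= j, else 0."""
    prefix = [0.0] + list(accumulate(x))  # prefix[k] = x[0] + ... + x[k-1]
    n = len(x)
    return [
        [prefix[j] - prefix[i] if j >= i else 0.0 for j in range(n + 1)]
        for i in range(n + 1)
    ]


random.seed(0)
n = 8
# Illustrative per-step log-probabilities (made-up values, not from the PR).
log_p_F = [random.uniform(-5.0, 0.0) for _ in range(n)]
log_p_B = [random.uniform(-5.0, 0.0) for _ in range(n)]

# Variant A: subtract first, then accumulate, i.e. cross(P_F - P_B).
a = cross([f - b for f, b in zip(log_p_F, log_p_B)])

# Variant B: accumulate each term, then subtract, i.e. cross(P_F) - cross(P_B).
cf, cb = cross(log_p_F), cross(log_p_B)
b = [[cf[i][j] - cb[i][j] for j in range(n + 1)] for i in range(n + 1)]

# Largest elementwise gap between the two variants; any gap is rounding error.
max_abs_diff = max(
    abs(a[i][j] - b[i][j]) for i in range(n + 1) for j in range(n + 1)
)
print(max_abs_diff)
```

Any gap printed here comes purely from rounding, since the intermediate values are rounded at different points in the two variants. With lower-precision tensors the gap would be expected to be larger, which is consistent with a small but nonzero `subtb_diff`.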
(squash)