recursionpharma / gflownet

GFlowNet library specialized for graph & molecular data
MIT License

add faster subTB #102

Closed SobhanMP closed 1 year ago

SobhanMP commented 1 year ago
self.subtb_loss_fast(log_p_F, log_p_B, log_F_preds, clip_log_R, batch.traj_lens)
  Median: 59.95 ms
  IQR:    0.29 ms (59.84 to 60.13)
  17 measurements, 1 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f89941df130>
self.subtb_cum(log_p_F, log_p_B, log_F_preds, clip_log_R, batch.traj_lens)
  Median: 16.43 ms
  IQR:    0.12 ms (16.41 to 16.53)
  7 measurements, 10 runs per measurement, 1 thread

image

subtb_diff is the mean squared error (MSE) between the losses computed by the two implementations.

The diff is not zero because I precompute P_F - P_B. Computing cross(P_F) - cross(P_B) instead of cross(P_F - P_B) yields exactly the same results.
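
To illustrate the point about precomputing P_F - P_B, here is a minimal sketch in plain Python (no PyTorch, so it runs anywhere). The `cross` helper below is a hypothetical stand-in, since the real function is not shown in this thread: it is assumed to return all sub-trajectory sums via a cumulative sum. The log-probabilities are made-up random numbers, not values from the benchmark above. Both orderings are mathematically identical, so any gap between them can only come from floating-point rounding.

```python
import random
from itertools import accumulate


def cross(x):
    """Hypothetical helper: entry [i][j] holds sum(x[i:j]) for i < j, else 0.0.

    Built from a cumulative sum, so each entry is a difference of two prefix sums.
    """
    c = [0.0] + list(accumulate(x))  # c[k] == sum(x[:k])
    n = len(x)
    return [[c[j] - c[i] if j > i else 0.0 for j in range(n + 1)]
            for i in range(n + 1)]


def mse(a, b):
    """Mean squared error between two equally shaped nested lists."""
    sq = [(u - v) ** 2 for ra, rb in zip(a, b) for u, v in zip(ra, rb)]
    return sum(sq) / len(sq)


random.seed(0)
T = 50  # trajectory length, arbitrary for this sketch
log_p_F = [-5.0 * random.random() for _ in range(T)]  # made-up forward log-probs
log_p_B = [-5.0 * random.random() for _ in range(T)]  # made-up backward log-probs

# Variant 1: subtract first, then take sub-trajectory sums.
fused = cross([f - b for f, b in zip(log_p_F, log_p_B)])

# Variant 2: take sub-trajectory sums separately, then subtract.
cross_F, cross_B = cross(log_p_F), cross(log_p_B)
separate = [[u - v for u, v in zip(rf, rb)] for rf, rb in zip(cross_F, cross_B)]

# Any non-zero value here is rounding error only.
diff = mse(fused, separate)
print(f"mse between the two orderings: {diff:.3e}")
```

The sketch only shows that the order of subtraction and summation can change the rounding; it does not reproduce the actual subtb_diff values from the PR.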

SobhanMP commented 1 year ago

(squash)