Seems like a reasonable point; let me think about it to be sure. (I think you're right, though.)
Say n_batch = 2.
The loss from the AALR method is:
l1 = BCE(theta1, x1, 1) + BCE(theta2, x1, 0)
l2 = BCE(theta2, x2, 1) + BCE(theta1, x2, 0)
The average across batches is therefore (l1 + l2) / n_batch = TOTAL_LOSS
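For concreteness, here is a minimal sketch of that bookkeeping in PyTorch, assuming a hypothetical classifier `d(x, theta)` that returns a logit; the `bce` helper mirrors the `BCE(theta, x, label)` notation above and none of this is swyft's actual loss code:

```python
import torch
import torch.nn.functional as F

def bce(logit, label):
    # BCE(theta, x, label) from the text above, written on a logit for stability
    return F.binary_cross_entropy_with_logits(logit, torch.tensor(label))

def aalr_loss(d, theta1, theta2, x1, x2, n_batch=2):
    # l1: joint pair (theta1, x1) labelled 1, swapped pair (theta2, x1) labelled 0
    l1 = bce(d(x1, theta1), 1.0) + bce(d(x1, theta2), 0.0)
    # l2: joint pair (theta2, x2) labelled 1, swapped pair (theta1, x2) labelled 0
    l2 = bce(d(x2, theta2), 1.0) + bce(d(x2, theta1), 0.0)
    # average across the batch of size n_batch
    return (l1 + l2) / n_batch
```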
--
Our loss fn:
l1_us = BCE(theta1, x1, 1) + BCE(theta2, x1, 0) + BCE(theta2, x2, 1) + BCE(theta1, x2, 0)
l2_us = 0; this doesn't exist, since we move everything over to the second dimension into groups of 4.
(l1_us + l2_us) / 2 = (l1 + l2 + 0) / n_batch = (l1 + l2) / n_batch == TOTAL_LOSS
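The same point as a sketch, reusing the hypothetical `bce` helper and classifier `d` from the sketch above (again, not swyft's code):

```python
def grouped_loss(d, theta1, theta2, x1, x2, n_batch=2):
    # all four BCE terms for this pair of examples land in a single group of 4
    l1_us = (
        bce(d(x1, theta1), 1.0) + bce(d(x1, theta2), 0.0)
        + bce(d(x2, theta2), 1.0) + bce(d(x2, theta1), 0.0)
    )
    l2_us = 0.0  # does not exist: everything moved into the group of 4
    # dividing the summed group losses by n_batch reproduces TOTAL_LOSS above
    return (l1_us + l2_us) / n_batch
```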
--
Is this wrong?
I looked into it in a little more detail and everything checks out after all.
My main mistake was that the view of lnL,
`lnL.view(-1, 4, lnL.shape[-1])`,
changes the size of the batch dimension from 2*n_batch to n_batch/2. So dividing by n_batch at the end leaves you with the correct average!
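A quick shape check of that reshape (toy numbers, not the actual swyft tensors):

```python
import torch

n_batch = 4                            # even, so the groups of 4 work out
lnL = torch.randn(2 * n_batch, 3)      # batch dimension is 2 * n_batch; last dim arbitrary
grouped = lnL.view(-1, 4, lnL.shape[-1])
print(grouped.shape)                   # torch.Size([2, 4, 3]): first dim is n_batch / 2
```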
Since the loss is calculated from interleaving jointly drawn parameter-observation pairs (doubling the effective batch size), there should be an additional factor of 1/2 for the loss to approach the expectation value

$$\mathbb{E}_{z \sim p(z),\, x \sim p(x \mid z),\, z' \sim p(z')} \left[\ln d(x, z) + \ln\left(1 - d(x, z')\right)\right].$$
https://github.com/undark-lab/swyft/blob/8a6d53e347b2dc5219348e2ebb414856f88cb50a/swyft/inference/loss.py#L52
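For reference, a hedged sketch of the batch construction described above (hypothetical names and shapes, not the linked swyft code): interleaving jointly drawn pairs with shuffled ones is what doubles the number of rows the classifier sees.

```python
import torch

n_batch = 2
z = torch.randn(n_batch, 3)               # z ~ p(z)
x = z + 0.1 * torch.randn(n_batch, 3)     # stand-in for x ~ p(x|z)
z_shuffled = z.roll(1, dims=0)            # z' ~ p(z'), paired with a mismatched x

pairs_z = torch.cat([z, z_shuffled], dim=0)
pairs_x = torch.cat([x, x], dim=0)
labels = torch.cat([torch.ones(n_batch), torch.zeros(n_batch)])
print(pairs_x.shape[0])                   # 2 * n_batch rows enter the classifier d(x, z)
```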