undark-lab / swyft

A system for scientific simulation-based inference at scale.

Small difference between actual loss and defined loss #78

Closed CHBKoenders closed 3 years ago

CHBKoenders commented 3 years ago

Since the loss is calculated from interleaved, jointly drawn parameter-observation pairs (doubling the effective batch size), there should be an additional factor of 1/2 for the loss to approach the expectation value $$\mathbb{E}_{z \sim p(z),\, x \sim p(x \mid z),\, z' \sim p(z')} \left[\ln d(x, z) + \ln\left(1 - d(x, z')\right)\right].$$

https://github.com/undark-lab/swyft/blob/8a6d53e347b2dc5219348e2ebb414856f88cb50a/swyft/inference/loss.py#L52
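For reference, the expectation above is just a binary cross-entropy over jointly drawn pairs (target 1) and prior-resampled pairs (target 0). A minimal PyTorch sketch, not swyft's actual implementation; `d_joint` and `d_marginal` are hypothetical placeholder tensors of classifier outputs:

```python
import torch
import torch.nn.functional as F

def aalr_bce_loss(d_joint: torch.Tensor, d_marginal: torch.Tensor) -> torch.Tensor:
    """Negative of E[ln d(x, z) + ln(1 - d(x, z'))], estimated by Monte Carlo.

    d_joint:    classifier outputs d(x, z) for jointly drawn (x, z) pairs
    d_marginal: classifier outputs d(x, z') with z' re-drawn from the prior
    (both are hypothetical names, not swyft's tensor layout)
    """
    loss_joint = F.binary_cross_entropy(d_joint, torch.ones_like(d_joint))
    loss_marginal = F.binary_cross_entropy(d_marginal, torch.zeros_like(d_marginal))
    return loss_joint + loss_marginal
```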

bkmi commented 3 years ago

Seems like a reasonable point; let me think about it to be sure. (I think you're right, though.)

bkmi commented 3 years ago

Say n_batch = 2.

The loss from the AALR method is:

l1 = BCE(theta1, x1, 1) + BCE(theta2, x1, 0)
l2 = BCE(theta2, x2, 1) + BCE(theta1, x2, 0)

The average across batches is therefore (l1 + l2) / n_batch = TOTAL_LOSS

--

Our loss fn:

l1_us = BCE(theta1, x1, 1) + BCE(theta2, x1, 0) + BCE(theta2, x2, 1) + BCE(theta1, x2, 0)
l2_us = 0; this row doesn't exist, since we move everything over into groups of 4 along the second dimension

(l1_us + l2_us) / 2 = (l1 + l2 + 0) / n_batch == TOTAL_LOSS
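A quick numeric check of this with made-up logits (a sketch, not swyft code; the pairings follow the n_batch = 2 example above):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# hypothetical classifier logits for the four pairings with n_batch = 2:
# d(theta1, x1), d(theta2, x1), d(theta2, x2), d(theta1, x2)
logits = torch.randn(4)
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])
bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")

# per-pair grouping: l1 and l2, averaged over n_batch = 2
l1, l2 = bce[0] + bce[1], bce[2] + bce[3]
total_loss = (l1 + l2) / 2

# grouped-by-4 version: all four terms in one row plus an "empty" row, averaged over 2
ours = (bce.sum() + 0.0) / 2

assert torch.allclose(total_loss, ours)
```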

--

Is this wrong?

CHBKoenders commented 3 years ago

I looked into it in a little more detail and everything checks out after all.

My main mistake was that the view `lnL.view(-1, 4, lnL.shape[-1])` changes the size of the batch dimension from 2*n_batch to n_batch/2. So dividing by n_batch at the end leaves you with the correct average!
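To make the reshape concrete, a toy shape check (hypothetical shapes, not swyft's actual tensors):

```python
import torch

n_batch = 4                           # jointly drawn (theta, x) pairs
lnL = torch.randn(2 * n_batch, 3)     # interleaved rows, 3 hypothetical ratios

grouped = lnL.view(-1, 4, lnL.shape[-1])
print(grouped.shape)                  # torch.Size([2, 4, 3]): batch dim is n_batch / 2
```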