gtegner / mine-pytorch

Mutual Information Neural Estimation in Pytorch
MIT License

Understanding EMA loss implementation #1

anhhuyalex commented 4 years ago

Hi,

I'm curious about your implementation of the EMA loss:

def ema_loss(x, running_mean, alpha):
    t_exp = torch.exp(torch.logsumexp(x, 0) - math.log(x.shape[0])).detach()
    if running_mean == 0:
        running_mean = t_exp
    else:
        running_mean = ema(t_exp, alpha, running_mean.item())
    t_log = EMALoss.apply(x, running_mean)

    # Recalculate ema

    return t_log, running_mean

When I ran mine.optimize, the code complained that running_mean is a tensor on the second batch. I'm confused about what type running_mean is supposed to be.
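For concreteness, running_mean goes into the function as the plain number 0 but comes back as a tensor after the first batch, so every later batch sees a different type. A minimal illustration (my own reproduction, not code from the repo):

import math
import torch

running_mean = 0                      # initial value: a plain Python int
x = torch.randn(16, 1)                # stand-in for the statistics network outputs
t_exp = torch.exp(torch.logsumexp(x, 0) - math.log(x.shape[0])).detach()

running_mean = t_exp                  # first batch takes the `running_mean == 0` branch
print(type(running_mean))             # <class 'torch.Tensor'>
print(running_mean == 0)              # later batches compare a tensor against 0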

Did you mean

if running_mean == 0:
    running_mean = t_exp.mean()

or something like that?

ZachFriedenberger commented 2 years ago

Your question is quite old, but I hope others find this useful. The line you are referring to only runs during the first batch of each epoch. It comes from the definition of an exponential moving average.
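Concretely, an exponential moving average is the recursion m_k = alpha * t_k + (1 - alpha) * m_{k-1}, and the recursion needs a base case: on the first batch there is no past value, so the running mean is seeded with that batch's t_exp. Note that t_exp = exp(logsumexp(x, 0) - log N) is already the batch mean of exp(x), so t_exp.mean() would be redundant. Here is a sketch of the two pieces the snippet relies on (the ema helper, plus my guess at what EMALoss does based on the bias-corrected gradient from the MINE paper; treat it as illustrative, not a verbatim copy of the repo):

import math
import torch

def ema(mu, alpha, past_ema):
    # EMA recursion: m_k = alpha * t_k + (1 - alpha) * m_{k-1}
    return alpha * mu + (1.0 - alpha) * past_ema

class EMALossSketch(torch.autograd.Function):
    # Forward: the ordinary log-mean-exp over the batch.
    # Backward: swaps the batch mean of exp(x) in the denominator for the
    # running mean, which de-biases the minibatch gradient of log E[exp(T)].
    @staticmethod
    def forward(ctx, x, running_mean):
        ctx.save_for_backward(x)
        ctx.running_mean = running_mean
        return torch.logsumexp(x, 0) - math.log(x.shape[0])

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d/dx_i of log((1/N) * sum_j exp(x_j)), with the running mean
        # standing in for the per-batch mean:
        grad = grad_output * x.exp().detach() / (x.shape[0] * ctx.running_mean)
        return grad, None

So the forward value is the usual Donsker-Varadhan log-mean-exp term, while the gradient flowing through it uses the smoother running mean; that is why the running mean has to be threaded through every batch rather than recomputed from scratch.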