pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

BatchNormalization Transform [feature request] #1517

Closed boathit closed 5 years ago

boathit commented 6 years ago

It would be useful to add a BatchNorm transform to stabilize training in normalizing flows, but I did not find one in the current dev branch. Are there any plans to add support?

Kelvinson commented 6 years ago

Hi @eb8680, could you please give some hints on how to implement this feature? Should it be added as a bijector module like in TFP, or added to the utility module? Thanks!

fritzo commented 6 years ago

This more naturally lives in torch.distributions.transforms. Take a look at the other transforms to see how they're implemented.
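For context, a minimal parameter-free transform in the style of the existing classes in torch.distributions.transforms looks roughly like the sketch below (it mirrors ExpTransform and is purely illustrative, not new library code):

from torch.distributions import constraints
from torch.distributions.transforms import Transform

class MyExpTransform(Transform):
    domain = constraints.real
    codomain = constraints.positive
    bijective = True

    def _call(self, x):
        return x.exp()

    def _inverse(self, y):
        return y.log()

    def log_abs_det_jacobian(self, x, y):
        # d/dx exp(x) = exp(x), so log|J| = x elementwise.
        return x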

boathit commented 6 years ago

None of the transforms currently implemented in torch.distributions.transforms have learnable parameters, which is not the case for a BatchNorm transform.

TFP has 22 composable transforms right now; wouldn't it be more convenient to include those transforms as an out-of-the-box toolkit instead of asking users to write them from scratch?

fritzo commented 6 years ago

@boathit oh I didn't realize there are trainable parameters. I think it's fine to add this to Pyro in a new file pyro/distributions/batch_norm_transform.py, and possibly move it upstream to torch.distributions once we work out the details.

I'm not sure of the cleanest way to implement this, but it would be nice to leverage AffineTransform for ._inverse() and .log_abs_det_jacobian().
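A minimal sketch of that idea, under the assumption that the batch statistics are held fixed (the class and parameter names here, BatchNormAffineSketch, gamma, beta, are illustrative, not Pyro's API): with mean and variance fixed, batch norm is just an affine map, so AffineTransform can supply the inverse and the log-determinant.

import torch
from torch.distributions import constraints
from torch.distributions.transforms import AffineTransform, Transform

class BatchNormAffineSketch(Transform):
    domain = constraints.real
    codomain = constraints.real
    bijective = True

    def __init__(self, gamma, beta, mean, var, eps=1e-5):
        super().__init__(cache_size=1)
        # With mean/var held fixed, y = (x - mean) / sqrt(var + eps) * gamma + beta
        # is the affine map y = x * scale + loc.
        scale = gamma / (var + eps).sqrt()
        loc = beta - mean * scale
        self._affine = AffineTransform(loc, scale)

    def _call(self, x):
        return self._affine(x)

    def _inverse(self, y):
        return self._affine.inv(y)

    def log_abs_det_jacobian(self, x, y):
        # Elementwise log|dy/dx| = log(gamma) - 0.5 * log(var + eps).
        return self._affine.log_abs_det_jacobian(x, y)

# e.g. bn = BatchNormAffineSketch(torch.ones(5), torch.zeros(5), batch.mean(0), batch.var(0))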

boathit commented 6 years ago

@fritzo I have managed to implement a simple version that assumes the input is just a D-dimensional vector. I found that it significantly stabilizes training in the density estimation task.

Since Transform is not a subclass of nn.Module, I dealt with this by first implementing a class BatchNorm(nn.Module) and then wrapping it with BatchNormTransform(Transform). In the training stage, I need to explicitly collect the parameters using nn.ModuleList in this line. I am not sure whether there is a more elegant way to handle this.

Anyway, I am looking forward to a more general implementation in the next release of Pyro.
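As a rough illustration of the wrapper pattern described above (all class and attribute names here are hypothetical, not boathit's actual code): a BatchNorm nn.Module owns the learnable gamma/beta and the running statistics, a thin Transform wrapper delegates to it, and the modules then have to be collected explicitly (for example in an nn.ModuleList) so the optimizer sees their parameters.

import torch
import torch.nn as nn
from torch.distributions import constraints
from torch.distributions.transforms import Transform

class BatchNorm1dFlow(nn.Module):
    """Holds gamma/beta and running statistics for a 1-D batch-norm bijector."""

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        super().__init__()
        self.log_gamma = nn.Parameter(torch.zeros(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.register_buffer("running_mean", torch.zeros(dim))
        self.register_buffer("running_var", torch.ones(dim))
        self.eps, self.momentum = eps, momentum

    def forward(self, x):
        if self.training:
            mean, var = x.mean(0), x.var(0, unbiased=False)
            self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean.detach())
            self.running_var.mul_(1 - self.momentum).add_(self.momentum * var.detach())
        else:  # at test time (and for the inverse) the running statistics are used
            mean, var = self.running_mean, self.running_var
        y = (x - mean) / (var + self.eps).sqrt() * self.log_gamma.exp() + self.beta
        log_det = (self.log_gamma - 0.5 * (var + self.eps).log()).sum(-1)
        return y, log_det

class BatchNormTransform(Transform):
    domain = constraints.real
    codomain = constraints.real
    bijective = True
    event_dim = 1

    def __init__(self, module):
        super().__init__(cache_size=1)
        self.module = module  # plain attribute: Transform does not register parameters

    def _call(self, x):
        y, self._log_det = self.module(x)
        return y

    def log_abs_det_jacobian(self, x, y):
        # Relies on being called right after _call, as TransformedDistribution does;
        # the inverse is omitted in this sketch.
        return self._log_det

# The modules must be gathered by hand so their parameters reach the optimizer:
# modules = nn.ModuleList([t.module for t in transforms if isinstance(t, BatchNormTransform)])
# optimizer = torch.optim.Adam(modules.parameters())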

eb8680 commented 6 years ago

@boathit neat, thanks for sharing your code. I don't think we'll get to this ourselves, but maybe you or @Kelvinson could turn your implementation into a PR?

boathit commented 6 years ago

@eb8680 Thanks. My only concern is that my current implementation is not general enough (it assumes the input is a 1-D vector). I will try to figure out how to generalize it to handle tensors of arbitrary dimension.

fritzo commented 6 years ago

@boathit if you'd be willing to submit a PR for your 1-D version (with tests :smile:) we could extend it to higher rank in a follow-up PR.

Also some comments:

boathit commented 6 years ago

@fritzo I will be happy to make a PR.

Can you help me understand why it needs to be an nn.Module? Do the beta and epsilon get trained simultaneously with other parameters?

Yes, gamma and beta need to be trained simultaneously with all other parameters. Since the BatchNorm transform has to distinguish between the training and testing stages in the inverse computation, there is no other way but to create a new nn.Module.

It would be desirable if Transform could inherit methods from nn.Module; then we could put all trainable parameters directly in the Transform without explicitly creating a new Module.

fritzo commented 6 years ago

It would be desirable if Transform could inherit methods from nn.Module

Again, I think this would be more cleanly accomplished via multiple inheritance, but we may need to modify Transform.__init__() to call super(Transform, self).__init__().

class BatchNormTransform(torch.distributions.Transform, torch.nn.Module):
    def __init__(self):
        # TODO Fix torch.distributions.Transform.__init__() so that torch.nn.Module is called:
        # super(BatchNormTransform, self).__init__()
        # FIXME Instead we manually call both __init__ methods:
        torch.distributions.Transform.__init__(self)
        torch.nn.Module.__init__(self)
        ...

Can any Python expert comment?
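One way to answer: with cooperative super() calls, Python's MRO initializes both bases from a single super().__init__(). A small standalone illustration (Mixin stands in for torch.distributions.Transform; names are illustrative only):

import torch
import torch.nn as nn

class Mixin:
    def __init__(self):
        # If Transform.__init__ called super().__init__() like this,
        # the MRO would continue on to nn.Module automatically.
        super().__init__()

class Combined(Mixin, nn.Module):
    def __init__(self):
        super().__init__()  # Mixin.__init__ -> nn.Module.__init__
        self.gamma = nn.Parameter(torch.ones(3))  # registered because nn.Module was initialized

print(list(Combined().parameters()))  # the gamma parameter shows up as expected

This is essentially the multiple-inheritance approach that the TransformModule class mentioned in the next comment ends up taking.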

stefanwebb commented 5 years ago

@fritzo this was taken care of in a recent PR with the TransformModule class as you'll recall :)

Hi @boathit, I'm a contributor to Pyro who's been working on adding normalizing flow features. I don't mean to step on your turf, but I was thinking there might be a lower opportunity cost for me to add this feature to Pyro since I'm familiar with the code. Would you mind if I took your code and adapted it so it can be added to Pyro?

p.s. in Line 201 should it read x.var instead of x.std?

boathit commented 5 years ago

Hi Stefan, I will be very happy if you could add it to Pyro.

stefanwebb commented 5 years ago

Great, thanks :) If you're interested in checking my implementation (based on yours)

I still have to add some tests before I do a PR.

boathit commented 5 years ago

@stefanwebb You are correct, it should be x.var instead of x.std. The implementation also looks good to me.

Also, I think your IAF implementation may blow up memory: the cached x (this line) may never be popped, since we rarely do the inverse computation for the same y when using IAF.

stefanwebb commented 5 years ago

Thanks, there's a fix on the way for this: #1638

stefanwebb commented 5 years ago

This issue can be closed now after #1803