facebookincubator / flowtorch

This library is intended to be a permanent home for reusable components for deep probabilistic programming. It aims to build and harness a community of users and contributors by focusing initially on complete infrastructure and documentation for how to use and create components.
https://flowtorch.ai
MIT License

[Bug] AffineAutoregressive transform leads to exploding gradients #85

Closed by francois-rozet 2 years ago

francois-rozet commented 2 years ago

Issue Description

In the Affine bijector, the scale parameter is obtained by clamping the output of the parameter network. In some of my experiments this results in very unstable behavior and exploding gradients, especially in low-entropy settings. I believe this is due to the discontinuities that the clamp operation introduces in the gradients.
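
To make the discontinuity concrete, here is a minimal sketch (not FlowTorch code) of how a hard clamp on the log-scale behaves: the gradient is constant inside the clamping range and exactly zero outside it, with a kink at the boundaries.

import torch

# Illustrative only: gradient of a plain hard clamp on the log-scale.
log_scale = torch.tensor([-6.0, 0.0, 6.0], requires_grad=True)
clamped = log_scale.clamp(min=-3.0, max=3.0)  # clamp bounds chosen arbitrarily
clamped.sum().backward()

# 1 inside the range, exactly 0 outside it: flat beyond the bounds,
# kinked at them.
print(log_scale.grad)  # tensor([0., 1., 0.])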

Instead of clamping, the nflows package applies a softplus to the network's output, which also bounds the scale from below while keeping the gradients smooth. In my experiments with Pyro, softplus works better than clamping and, importantly, does not suffer from exploding gradients. I would suggest replacing clamping with softplus.
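
As a sketch of the softplus alternative (the function name and the small epsilon offset below are illustrative, not the exact nflows code), the scale stays bounded away from zero while its gradient is smooth everywhere:

import torch
import torch.nn.functional as F

def positive_scale(unconstrained: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    # Smooth, strictly positive map for the scale; eps keeps it bounded
    # below (illustrative value).
    return F.softplus(unconstrained) + eps

x = torch.linspace(-10.0, 10.0, 5, requires_grad=True)
scale = positive_scale(x)
scale.sum().backward()
print(bool((scale > 0).all()))  # True: the scale is always strictly positive
print(x.grad)                   # sigmoid(x): smooth, never exactly zero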

Expected Behavior

Avoid exploding gradients. I have implemented the replacement of clamping by softplus for FlowTorch (https://github.com/francois-rozet/flowtorch/commit/9bf41e5b67a8993aa6173d6341f9d99ae5e7178b) but haven't had the time to test it properly.

Additional Context

This issue is a replica of https://github.com/pyro-ppl/pyro/issues/2998

Merry Christmas 🎄

vmoens commented 2 years ago

Hi @francois-rozet thanks for raising this.

I agree, we should have a softplus non-linearity. I usually use f = softplus(x + bias) where bias=0.54... such that f(torch.zeros(1)) = 1.0 (otherwise the layer will 'shrink' the input).
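
For reference, that constant is the inverse softplus of 1, i.e. log(e - 1) ≈ 0.5413 (a quick check, not FlowTorch code):

import math
import torch
import torch.nn.functional as F

# Solve softplus(bias) = log(1 + exp(bias)) = 1 for bias.
bias = math.log(math.e - 1.0)
print(bias)                               # 0.5413...
print(F.softplus(torch.zeros(1) + bias))  # tensor([1.0000]): no shrinkage at init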

@stefanwebb what about letting the user choose which non-linearity is used for the positive mapping of the parameters? Something like

layer = AffineLayer(positive_map='softplus')

I think this will be useful for actnorm, batchnorm, etc. and for deep architectures (e.g. Glow, iResNet).
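
A rough sketch of what such an option could look like (AffineLayer and positive_map come from the comment above; the dispatch below is hypothetical, not the FlowTorch API):

import torch
import torch.nn.functional as F

# Hypothetical registry of positive maps for the scale parameter.
_POSITIVE_MAPS = {
    'softplus': F.softplus,
    'exp': torch.exp,
    'clamp_exp': lambda x: torch.exp(x.clamp(-3.0, 3.0)),
}

def get_positive_map(name: str = 'softplus'):
    # Look up the requested non-linearity by name.
    try:
        return _POSITIVE_MAPS[name]
    except KeyError:
        raise ValueError(f"unknown positive_map: {name!r}")

# A layer could then store the callable and apply it to the unconstrained
# scale produced by its network, e.g. scale = positive_map(raw_scale).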

vmoens commented 2 years ago

Side note: the clamp_preserve_gradients method on this transform could be simplified to

import torch

def clamp_preserve_gradients(x: torch.Tensor, min: float, max: float) -> torch.Tensor:
    """
    Clamp the values of `x` in-place while still passing the gradient
    through in the clamped regions (the clamp is applied to `x.data`,
    so autograd never sees it).
    """
    x.data.clamp_(min, max)
    return x

where all modifications are done in place.
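
For comparison, the usual out-of-place way to write this straight-through clamp (avoiding the mutation of x.data) looks roughly like this:

import torch

def clamp_preserve_gradients(x: torch.Tensor, min: float, max: float) -> torch.Tensor:
    # The forward value is clamped, but the correction term is detached,
    # so autograd sees the identity and the gradient passes through unchanged.
    return x + (x.clamp(min, max) - x).detach()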

stefanwebb commented 2 years ago

@vmoens won't f = softplus(x + bias) already have a bias term added to x, since x is the output of a feedforward network? I've removed this feature to simplify the logic and have broken it out into a separate PR: #109