Closed francois-rozet closed 2 years ago
Hi @francois-rozet thanks for raising this.
I agree, we should have a softplus non-linearity. I usually use `f = softplus(x + bias)` where `bias = 0.54...`, such that `f(torch.zeros(1)) = 1.0` (otherwise the layer will 'shrink' the input).
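For reference, the `0.54...` above is just the inverse softplus of 1, which can be checked in plain Python:

```python
import math

# softplus(x) = log(1 + e^x). We want softplus(bias) = 1, so
# bias = softplus^{-1}(1) = log(e - 1) ~= 0.5413 (the 0.54... above).
bias = math.log(math.e - 1.0)
print(round(bias, 4))                   # 0.5413
print(math.log(1.0 + math.exp(bias)))   # ~1.0
```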
@stefanwebb what about letting the user choose which non-linearity is used for the positive mapping of the parameters? Something like

```python
layer = AffineLayer(positive_map='softplus')
```

I think this will be useful for `actnorm`, `batchnorm`, etc., and for deep architectures (e.g. Glow, iResNet).
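A minimal sketch of what such an option could look like. Note that `AffineLayer`, `positive_map`, and the dispatch table are illustrative names for this suggestion, not FlowTorch's actual API:

```python
import torch
import torch.nn.functional as F

# Hypothetical dispatch table: user-selectable maps from the
# unconstrained network output to a strictly positive scale.
POSITIVE_MAPS = {
    "softplus": F.softplus,
    "exp": torch.exp,
}

class AffineLayer:
    """Illustrative stand-in for the proposed layer, not FlowTorch code."""

    def __init__(self, positive_map: str = "softplus"):
        if positive_map not in POSITIVE_MAPS:
            raise ValueError(f"unknown positive_map: {positive_map!r}")
        self._positive_map = POSITIVE_MAPS[positive_map]

    def scale(self, raw: torch.Tensor) -> torch.Tensor:
        # Map raw (unconstrained) parameters to positive scales.
        return self._positive_map(raw)

layer = AffineLayer(positive_map="softplus")
scale = layer.scale(torch.tensor([-3.0, 0.0, 3.0]))
assert (scale > 0).all()
```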
Side note: the `clamp_preserve_gradients` method could be simplified to

```python
def clamp_preserve_gradients(x: torch.Tensor, min: float, max: float) -> torch.Tensor:
    """
    This helper function clamps the values of `x` but still passes the
    gradient through unchanged in the clamped regions.
    """
    x.data.clamp_(min, max)
    return x
```

where all modifications are done in place.
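A quick sanity check (assuming PyTorch, repeating the helper so the snippet is self-contained) that this in-place variant clamps the forward values while leaving the gradient untouched:

```python
import torch

def clamp_preserve_gradients(x: torch.Tensor, min: float, max: float) -> torch.Tensor:
    """Clamp the values of x in place; gradients pass through unchanged."""
    x.data.clamp_(min, max)
    return x

# `x + 0` gives a non-leaf copy so the leaf's own data is not mutated.
x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
y = clamp_preserve_gradients(x + 0, -1.0, 1.0)
y.sum().backward()
print(y.detach())  # values clamped to [-1, 1]
print(x.grad)      # tensor([1., 1., 1.]) -- gradient is 1 even where clamped
```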
@vmoens won't `f = softplus(x + bias)` already have a bias term added to `x`, since `x` is the output of a feedforward network?

I've removed this feature to simplify the logic and have broken it out into a separate PR: #109
Issue Description

In the `Affine` bijector, the scale parameter is obtained by clamping the network's output parameters. According to some of my experiments, this results in very unstable behavior and exploding gradients, especially in low-entropy settings. I believe this is due to the discontinuities that the clamp operation introduces in the gradients.

Instead of clamping, the `nflows` package applies `softplus` to the network's output, which also has the effect of bounding the scale from below while keeping the gradients smooth. According to my experiments with Pyro, `softplus` works better than clamping and, importantly, is not subject to exploding gradients. I would suggest replacing clamping with `softplus`.

Expected Behavior

Avoid exploding gradients. I have implemented the replacement of clamping with `softplus` for FlowTorch (https://github.com/francois-rozet/flowtorch/commit/9bf41e5b67a8993aa6173d6341f9d99ae5e7178b) but haven't had the time to test it properly.

Additional Context
This issue is a replica of https://github.com/pyro-ppl/pyro/issues/2998
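The clamp-vs-softplus gradient behavior described in this issue can be illustrated in a few lines (a sketch with an arbitrary clamp floor of 0.1, not FlowTorch's actual bounds):

```python
import torch
import torch.nn.functional as F

# The raw (unconstrained) scale parameter sits far below the clamp floor.
raw = torch.tensor([-5.0], requires_grad=True)

# Hard clamp: the gradient is exactly zero outside the clamp range,
# so the parameter receives no learning signal at all.
raw.clamp(min=0.1).backward()
print(raw.grad)  # tensor([0.])

# Softplus: the scale is still bounded below by 0, but the gradient
# stays smooth and nonzero (softplus'(x) = sigmoid(x) ~= 0.0067 at x = -5).
raw.grad = None
F.softplus(raw).backward()
print(raw.grad)  # small but nonzero
```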
Merry Christmas 🎄