pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Softplus transform as a more numerically stable way to enforce positive constraint #855

Closed vitkl closed 3 years ago

vitkl commented 3 years ago

Happy New Year Numpyro team!

I observe a very substantial improvement in inference stability when I replace the exp transformation with softplus for constraining site_scale (https://github.com/vitkl/cell2location_numpyro/blob/main/cell2location_numpyro/distributions/AutoNormal.py#L142). Without this change, the ELBO tends to become very variable and shoot up (see plot).

This is also consistent with pymc3's use of Softplus and with this discussion: https://github.com/tensorflow/probability/issues/751.

So I am wondering whether you could add a softplus transformation to the library and use it as the default way to enforce the positive constraint.

I tried implementing it, starting from the TFP implementation (https://github.com/tensorflow/probability/blob/v0.11.1/tensorflow_probability/python/bijectors/softplus.py#L61-L173), but I am confused by the difference in the meaning of log_det_jacobian between TFP and NumPyro.

fehiepsi commented 3 years ago

Happy new year @vitkl! Your suggestion looks reasonable to me. I think we can start with an implementation of SoftplusTransform (probably without hinge_softness for simplicity). Would you like to submit a PR? Regarding log_det_jacobian, I think both frameworks use the same terminology: that is, log_abs_det_jacobian(x, y) = -softplus(-x).
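The log_abs_det_jacobian identity above can be checked numerically. Below is a minimal NumPy sketch, not NumPyro's actual SoftplusTransform: the class name SoftplusSketch is hypothetical, and a real NumPyro transform would also declare its domain/codomain and use jax.numpy. Since dy/dx = sigmoid(x) = 1 / (1 + exp(-x)), its log is -log(1 + exp(-x)) = -softplus(-x).

```python
import numpy as np


class SoftplusSketch:
    """Hypothetical minimal softplus bijector from the real line to (0, inf)."""

    def __call__(self, x):
        # y = log(1 + exp(x)), computed without overflow
        return np.logaddexp(0.0, x)

    def inv(self, y):
        # x = log(exp(y) - 1), rewritten as y + log(1 - exp(-y)) for stability
        return y + np.log(-np.expm1(-y))

    def log_abs_det_jacobian(self, x, y):
        # dy/dx = sigmoid(x) > 0, and log(sigmoid(x)) = -softplus(-x)
        return -np.logaddexp(0.0, -x)


t = SoftplusSketch()
x = np.linspace(-5.0, 5.0, 11)
y = t(x)

# compare the analytic log-Jacobian with a central finite difference
h = 1e-6
finite_diff = (t(x + h) - t(x - h)) / (2 * h)
print(np.max(np.abs(t.log_abs_det_jacobian(x, y) - np.log(finite_diff))))
```

Note that the Jacobian term depends only on x, which is why the (x, y) signature can ignore y here.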

After that, there are several possible approaches. The simplest one to me is to add a context manager that switches from ExpTransform to SoftplusTransform for biject_to(constraints.positive) globally (i.e. modifies the behavior of this register).

vitkl commented 3 years ago

Hi @fehiepsi ,

I tested whether replacing Exp with Softplus improves the behaviour of our model (https://github.com/vitkl/cell2location_numpyro). As I mentioned before, using Softplus for transforming site_scale to positive substantially improves stability. However, I do not see additional improvement in the stability (or accuracy) after adding Softplus as a default positive transformation (https://github.com/vitkl/numpyro/commit/2369cae8c1b4ffb05f7473f4c24eadc13bf86ecd). So I am not sure whether Softplus should be used as a default transformation - using it for site_scale is more important. What do you think?

fehiepsi commented 3 years ago

I see. You don't need to change the default behaviour. After you have Softplus, you can create a handler to do the job

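import numpyro
import numpyro.distributions as dist

# marker "constraint" used only to select a different bijector
class _Positive:
    pass

@dist.biject_to.register(_Positive)
def _transform_to_positive(constraint):
    return dist.transforms.SoftplusTransform()

class exp_to_softplus(numpyro.primitives.Messenger):
    def process_message(self, msg):
        # swap the constraint of every param whose name ends in "_scale",
        # e.g. the "<site>_auto_scale" params created by AutoNormal
        if msg["type"] == "param" and msg["name"].endswith("_scale"):
            msg["kwargs"]["constraint"] = _Positive()

# `guide` is an existing autoguide, e.g. AutoNormal(model)
guide = exp_to_softplus(guide)
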
vitkl commented 3 years ago

Thanks for the suggestion! I think it will be useful to set this as the default behaviour for scales rather than leaving users to add guide = exp_to_softplus(guide) after they create the guide. Could the following step be added into the definition of AutoNormal?

if msg["type"] == "param" and msg["name"].endswith("_scale"):
    msg["kwargs"]["constraint"] = _Positive()
fritzo commented 3 years ago

Could the following step be added into the definition of AutoNormal

That approach seems error-prone, but here's a simpler way to change the default behavior everywhere: it should be enough to override the default biject_to registration using the existing constraint:

import numpyro.distributions as dist

@dist.biject_to.register(dist.constraints.positive)
def _transform_to_positive(constraint):
    return dist.transforms.SoftplusTransform()

or simply

dist.biject_to.register(dist.constraints.positive, lambda c: dist.transforms.SoftplusTransform())

I agree we've seen that softplus is more numerically stable. The main reason we don't use softplus as the default constraint is that it is not scale invariant, and it has trouble with parameters with very large units like global_population ~ 1e10. In deep learning settings, it's common to pre-scale data to have units around 1.0, but I believe pre-scaling data is less common in statistical applications.
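The scale-invariance point can be made concrete with a small NumPy sketch (the helpers softplus_inv and relative_change are illustrative, not NumPyro API). Under exp, a unit step in unconstrained space multiplies the constrained value by e at any magnitude, so changing the units only shifts the unconstrained parameter by a constant. Under softplus, the same step adds roughly 1, so its relative effect shrinks as the units grow.

```python
import numpy as np


def softplus(x):
    # log(1 + exp(x)), computed without overflow
    return np.logaddexp(0.0, x)


def softplus_inv(y):
    # inverse of softplus: log(exp(y) - 1) = y + log(1 - exp(-y))
    return y + np.log(-np.expm1(-y))


def relative_change(forward, inverse, y, step=1.0):
    """Relative change of the constrained value y after moving its
    unconstrained counterpart by `step`."""
    x = inverse(y)
    return forward(x + step) / y - 1.0


for y in [1.0, 1e3, 1e10]:
    print(f"y={y:g}  exp: {relative_change(np.exp, np.log, y):.3g}  "
          f"softplus: {relative_change(softplus, softplus_inv, y):.3g}")
```

For a parameter like global_population ~ 1e10, the softplus parameter sits near 1e10 in unconstrained space, whereas the exp parameter sits near log(1e10) ≈ 23.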

vitkl commented 3 years ago

@fritzo I am not sure I understand your reasons for not making softplus the default transformation (compared to the current default, exp). I see how this might be a problem for a model with large parameter values, but so will exp, won't it? Small changes to unconstrained scales lead to large swings in scales: 1e10 = exp(23.02585) => 2.7e10 = exp(23.02585 + 1).

In the previous post, I suggested using softplus as the default transformation for scales (rather than for model variables/sites). Scales in the AutoNormal approximation are always positive regardless of the model. For example, pymc3 simply hard-codes softplus into its version of AutoNormal (https://github.com/pymc-devs/pymc3/blob/master/pymc3/variational/approximations.py#L61, https://github.com/pymc-devs/pymc3/blob/1769258e459e8f40aa8a56e0ac911aa99e7f67de/pymc3/distributions/dist_math.py#L194).

So, there are 2 decisions:

  1. Default positive transformation, registered via dist.biject_to.register. I am not suggesting any changes.
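  2. Transformation for scales in AutoNormal. I suggest changing this to softplus, which can be done simply (without registering a transformation):
import jax.numpy as jnp
from jax.nn import softplus
import numpyro
from numpyro.distributions import constraints
# inside AutoNormal's per-site loop; initialise at softplus^{-1}(init_scale)
# so that the scale starts at init_scale rather than softplus(init_scale)
site_scale_unconstrained = numpyro.param("{}_{}_scale".format(name, self.prefix),
                                         jnp.full(jnp.shape(init_loc), jnp.log(jnp.expm1(self._init_scale))),
                                         constraint=constraints.real,
                                         event_dim=event_dim)
site_scale = softplus(site_scale_unconstrained)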

(see https://github.com/vitkl/cell2location_numpyro/blob/main/cell2location_numpyro/distributions/AutoNormal.py#L137-L142 for full AutoNormal code with this edit)

Does this make sense? If yes I will make changes and submit a PR.

fehiepsi commented 3 years ago

I don't think we should change the default transformation unless there is enough evidence that softplus is better. Pursuing this would need further investigation. :)

Small changes to unconstrained scales lead to large swings in scales

I guess this could be desirable? The derivative of softplus is bounded by 1, so it would take a long time to move a parameter to the desired range. The derivative of exp is unbounded, hence, together with optimizers like ClippedAdam (which make the inference fluctuate less), it can work well.
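A rough way to see the bounded-derivative argument: count how many fixed-size steps in unconstrained space it takes to move a scale from 1 to 100. Fixed-size steps are only a crude stand-in for an adaptive optimizer like Adam, whose step size is roughly constant; the helper steps_to_reach is illustrative. Because softplus'(x) = sigmoid(x) ≤ 1, softplus grows at most linearly in x, while exp grows exponentially.

```python
import math


def softplus(x):
    # numerically stable log(1 + exp(x))
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)


def softplus_inv(y):
    # log(exp(y) - 1), written stably as y + log(1 - exp(-y))
    return y + math.log(-math.expm1(-y))


def steps_to_reach(forward, inverse, start, target, step=0.1):
    """Count fixed-size unconstrained steps needed to move the constrained
    value from `start` to at least `target`."""
    x = inverse(start)
    n = 0
    while forward(x) < target:
        x += step
        n += 1
    return n


n_exp = steps_to_reach(math.exp, math.log, 1.0, 100.0)
n_softplus = steps_to_reach(softplus, softplus_inv, 1.0, 100.0)
print(n_exp, n_softplus)
```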

@vitkl Just curious, when you mentioned that setting default transform to Softplus didn't help, does it mean that softplus is not good for some of your model's parameters? If so, it triggers another issue: softplus is good for some parameters and exp is good for some other parameters, hence a default choice is more difficult to make.

I just hope that the issue can be solved by tuning some parameters of the optimizers. :D

vitkl commented 3 years ago

Just to clarify, for our model:

  1. Setting default transformation to softplus did not help compared to point 2 but it did help compared to exp.
  2. Setting softplus for transforming scales improved stability and accuracy.

This could mean that, because the posteriors of most parameters in our model are < 1 and all parameters are < 100, either softplus or exp works well. At the same time, softplus for transforming scales seems to be necessary for improving accuracy and matching the accuracy of the pymc3 translation of the same model.

I am using ClippedAdam. Interestingly, reducing clip_norm and learning rate decreased stability (ELBO started fluctuating up earlier) - but I did not do a systematic parameter tuning.

fehiepsi commented 3 years ago

Thanks, @vitkl! I see that we can narrow down the reasons for choosing a default transformation. It seems the posteriors of positive variables in your model look more like LogNormal than something like Softplus(Normal), hence setting the default transformation to softplus does not help much.
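One way to see how different LogNormal and Softplus(Normal) are: push the same standard Normal through both maps and compare right tails. Because both transforms are increasing, P(T(X) > t) = P(X > T^{-1}(t)). This is a sketch assuming a Normal(0, 1) base; the threshold of 10 is arbitrary.

```python
import math


def normal_sf(z):
    # P(X > z) for X ~ Normal(0, 1)
    return 0.5 * math.erfc(z / math.sqrt(2.0))


def softplus_inv(y):
    # inverse of softplus(x) = log(1 + exp(x))
    return y + math.log(-math.expm1(-y))


threshold = 10.0
# exp(X) > 10  <=>  X > log(10)
p_lognormal = normal_sf(math.log(threshold))
# softplus(X) > 10  <=>  X > softplus^{-1}(10), which is almost exactly 10
p_softplus = normal_sf(softplus_inv(threshold))
print(p_lognormal, p_softplus)
```

The LogNormal tail is heavy (about 1% above 10), while Softplus(Normal) behaves like a Normal for large values and is essentially never above 10.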

Now, coming to the optimization problem for auto_*_scale parameters: your suggestion is that the optimizers you used work better with softplus than with the exp transform. Could you make some reproducible code, or let me know how to use your repository to see this issue?

FYI, the exp_to_softplus handler is pretty flexible: you can decide which parameters should use softplus. It is unlikely that we will change the default behavior of AutoNormal without intensive tests. One of the main reasons is that it would affect current Pyro/NumPyro users who have tuned their optimizers well with the exp transform (there are plenty of parameters to tune, like learning rate, gradient clipping, learning rate schedule, ...). If softplus is better, we should think about adding some arguments to the construction of AutoNormal, like use_softplus_transform=True/False, to switch the behavior. We still need to discuss a good API, but it will be less controversial than changing the default behavior of AutoNormal.

fritzo commented 3 years ago

@vitkl I agree with @fehiepsi that this change would affect many existing users. That means (1) your observation could have a widespread positive impact in both NumPyro and Pyro, but (2) we would need to thoroughly test a wide range of models to ensure existing users are not negatively impacted. The last time we made a similar change to autoguides was a change of parameters to AutoLowRankMultivariateNormal in https://github.com/pyro-ppl/pyro/pull/2127, which was accompanied by thorough experiments. If you'd like to perform similar experiments, would you mind submitting a PR with a new directory in https://github.com/pyro-ppl/sandbox containing notebooks, similar to sandbox/2019-11-lowrank, maybe calling it 2021-01-softplus? That way the experiments are public, part of pyro-ppl, and accessible to affected users.

vitkl commented 3 years ago

I am new to engaging with/contributing to the project, so thanks a lot for the feedback!

I like the approach of:

adding some arguments to the construction of AutoNormal, like use_softplus_transform=True/False to switch the behavior.

Looking at the lowrank sandbox, I see that the test notebooks there test the specific issue with AutoLowRankMultivariateNormal. The changes I propose will affect all models using Normal approximation. Do you have any test models that could be used for a more thorough experiment with different types of models?

Our model can be tested using this Colab notebook, https://colab.research.google.com/github/vitkl/cell2location_numpyro/blob/main/docs/notebooks/cell2location_short_demo_colab.ipynb, however, it is relatively computationally intensive (requires GPU). @fritzo do you suggest submitting this notebook into sandbox/2021-01-softplus?

I am talking to @martinjankowiak next Monday, so maybe I could ask him these questions too?

fritzo commented 3 years ago

WDYT of keeping exp as the default transform but adding an option to AutoNormal (and maybe other autoguides) to use softplus? We might name this option assume_prewhitened=True or something. We could make the underlying machinery easier to use by adding a constraints.prewhitened_positive whose registered transform is Softplus. That would also make it easy for users to create custom derived distributions that are known to be prewhitened, e.g.

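from numpyro.distributions import Normal, constraints

# constraints.prewhitened_positive is the proposed (not yet existing) constraint
class PrewhitenedNormal(Normal):
    arg_constraints = {"loc": constraints.real,
                       "scale": constraints.prewhitened_positive}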

and use these distributions in a model.

vitkl commented 3 years ago

@fritzo This sounds like this can be added without any changes for existing users - so that's great. Not sure what prewhitened means though.

We have implemented a simple SoftplusTransform for Pyro models but are not certain whether it belongs in Pyro or PyTorch: https://github.com/BayraktarLab/cell2location/blob/master/cell2location/distributions/transforms.py (in addition to NumPyro: https://github.com/vitkl/numpyro/blob/master/numpyro/distributions/transforms.py#L315-L342)

I still think that if the benefit of softplus for positive parameters can be demonstrated for several models, it is worth considering it as the default (for Normal scales but also for other positive parameters like Gamma shape and rate). The choice of transform for scales is quite far from the first idea that comes to mind when troubleshooting models (I would not have known to look at it if I were not a pymc3 user). At the very least, it would be good to mention softplus in the SVI tutorial (https://pyro.ai/examples/svi_part_i.html).

fritzo commented 3 years ago

@vitkl I think SoftplusTransform best belongs in PyTorch. Could you submit upstream and cc me for review?

Re: terminology, can we find some sort of term for "standardized" or "prewhitened" or something to mean "data has been scaled and shifted such that its values are on the order of 1"? IMHO this should be a single word we can use in interfaces and documentation, and the word should be a bit vague because this notion is a bit vague.

vitkl commented 3 years ago
class exp_to_softplus(numpyro.primitives.Messenger):
    ...

@fehiepsi do you by any chance know how to do this in pyro?

fehiepsi commented 3 years ago

In Pyro, Messenger belongs to the poutine module, but the implementation should be very similar. You can see how the mask messenger is implemented here. To convert a messenger to a handler, you can use the _make_handler utility.

fehiepsi commented 3 years ago

WDYT of keeping exp as the default transform but adding an option to AutoNormal (and maybe other autoguides) to use softplus? We might name this option assume_prewhitened=True or something.

I like this solution. I remember that when converting bijectors to constraints in the TFP wrappers, I also saw some popular TFP distributions using this softplus bijector under the hood. I don't know why, but we can assume that users will want the customization you suggested.

About the names, I think prewhitened or standardized is a bit misleading (it could be confused with the whitening approach in LocScaleReparam). I prefer less vague names, like softplus_positive or, even better IMO, soft_positive, which is still vague but should convey the meaning to users more easily.

vitkl commented 3 years ago

In Pyro, Messenger belongs to poutine module, but the implementation should be very similar. You can see how mask messenger implemented here. To convert a messenger to a handler, you can use _make_handler utility.

Thanks for the tip about handlers, but this approach does not work for me for some reason (even more reason to have support for softplus in AutoNormal). Code: https://github.com/vitkl/scvi-tools/blob/pyro-cell2location/scvi/external/cell2location/_module.py#L38-L69 Error when using it with an AutoNormal guide (self.guide = ExpToSoftplusHandler(self.guide)):

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-e0e5d2386535> in <module>
     18         module=None,
     19         use_gpu=True,
---> 20         batch_size=2019)

/BayraktarLab/scvi-tools/scvi/external/cell2location/_model.py in __init__(self, adata, cell_state_df, var_names_read, sample_id, module, use_gpu, batch_size, **model_kwargs)
    570         self.module = module(n_obs=self.n_obs, n_var=self.n_var,
    571                              n_fact=self.n_fact, n_exper=self.n_exper, batch_size=self.batch_size,
--> 572                              cell_state_mat=self.cell_state_df.values, **model_kwargs)

/BayraktarLab/scvi-tools/scvi/external/cell2location/_module.py in __init__(self, n_obs, n_var, n_fact, n_exper, batch_size, cell_state_mat, n_comb, m_g_gene_level_prior, m_g_gene_level_var_prior, cell_number_prior, cell_number_var_prior, alpha_g_phi_hyp_prior, gene_add_alpha_hyp_prior, gene_add_mean_hyp_prior, w_sf_mean_var_ratio)
    151                                 create_plates=self.create_plates)
    152         # replace exp transform with softplus https://github.com/pyro-ppl/numpyro/issues/855
--> 153         self.guide = ExpToSoftplusHandler(self.guide)
    154 
    155     @staticmethod

~/.local/lib/python3.7/site-packages/torch/nn/modules/module.py in __setattr__(self, name, value)
    813                     raise TypeError("cannot assign '{}' as child module '{}' "
    814                                     "(torch.nn.Module or None expected)"
--> 815                                     .format(torch.typename(value), name))
    816                 modules[name] = value
    817             else:

TypeError: cannot assign 'pyro.infer.autoguide.guides._bound_partial' as child module 'guide' (torch.nn.Module or None expected)
vitkl commented 3 years ago

I agree about the names: softplus_positive is easier to interpret correctly than prewhitened, which I have also encountered in completely different contexts.

fehiepsi commented 3 years ago

I think this issue has been resolved. We can easily switch the behaviors now. Let's change the default behavior to softplus when we have more evidence in sandbox that it is good. Thanks @vitkl for raising this interesting issue! I'm looking forward to exploring the notebook that you intended to put in sandbox. :)

vitkl commented 3 years ago

Just a curious observation we made with @yozhikoff: the issue with the exp transform seems to be less severe with CUDA 11.1 than with CUDA 10.2 (using the corresponding PyTorch builds).


fehiepsi commented 3 years ago

Interesting; probably the precision of some op differs across systems.