Happy new year @vitkl! Your suggestion looks reasonable to me. I think we can start with an implementation of SoftplusTransform (probably without hinge_softness, for simplicity). Would you like to submit a PR? Regarding log_det_jacobian, I think both frameworks use the same convention, that is, log_abs_det_jacobian(x, y) = -softplus(-x).
After that, there are several possible approaches. The simplest one to me is to add a context manager that globally switches biject_to(constraints.positive) from ExpTransform to SoftplusTransform (i.e. modifies the behavior of this registry).
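For reference, a minimal sketch of what such a SoftplusTransform could look like in NumPyro (omitting hinge_softness; treat this as a sketch rather than the final implementation, since base-class details may differ):

import jax.numpy as jnp
from jax.nn import softplus
from numpyro.distributions import constraints
from numpyro.distributions.transforms import Transform

class SoftplusTransform(Transform):
    # y = softplus(x) = log(1 + exp(x)), mapping the reals to the positive reals
    domain = constraints.real
    codomain = constraints.positive

    def __call__(self, x):
        return softplus(x)

    def _inverse(self, y):
        # inverse softplus, x = log(exp(y) - 1), written in a numerically stable form
        return y + jnp.log(-jnp.expm1(-y))

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # dy/dx = sigmoid(x), hence log|dy/dx| = -softplus(-x)
        return -softplus(-x)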
Hi @fehiepsi,
I tested whether replacing Exp with Softplus improves the behaviour of our model (https://github.com/vitkl/cell2location_numpyro). As I mentioned before, using Softplus for transforming site_scale to positive substantially improves stability. However, I do not see additional improvement in stability (or accuracy) after adding Softplus as a default positive transformation (https://github.com/vitkl/numpyro/commit/2369cae8c1b4ffb05f7473f4c24eadc13bf86ecd). So I am not sure whether Softplus should be used as a default transformation - using it for site_scale is more important. What do you think?
I see. You don't need to change the default behaviour. After you have Softplus, you can create a handler to do the job:

import numpyro
import numpyro.distributions as dist

# marker constraint whose registered bijection is the (to-be-added) Softplus transform
class _Positive(dist.constraints.Constraint):
    def __call__(self, x):
        return x > 0

@dist.biject_to.register(_Positive)
def _transform_to_positive(constraint):
    return dist.transforms.Softplus()

# effect handler that swaps the constraint of every *_scale param to the marker constraint
class exp_to_softplus(numpyro.primitives.Messenger):
    def process_message(self, msg):
        if msg["type"] == "param" and msg["name"].endswith("_scale"):
            msg["kwargs"]["constraint"] = _Positive()

guide = exp_to_softplus(guide)
Thanks for the suggestion! I think it will be useful to set this as the default behaviour for scales rather than leaving users to add guide = exp_to_softplus(guide) after they create the guide. Could the following step be added into the definition of AutoNormal?

if msg["type"] == "param" and msg["name"].endswith("_scale"):
    msg["kwargs"]["constraint"] = _Positive()
Could the following step be added into the definition of AutoNormal?

That approach seems error prone, but here's a simpler way to change the default behavior everywhere: it should be enough to override the default biject_to registration using the existing constraint:

@dist.biject_to.register(dist.constraints.positive)
def _transform_to_positive(constraint):
    return dist.transforms.Softplus()

or simply

dist.biject_to.register(dist.constraints.positive, lambda c: dist.transforms.Softplus())
I agree we've seen that softplus is more numerically stable. The main reason we don't use softplus as the default constraint is that it is not scale invariant, and it has trouble with parameters with very large units like global_population ~ 1e10. In deep learning settings it's common to pre-scale data to have units around 1.0, but I believe pre-scaling data is less common in statistical applications.
@fritzo I am not sure I understand your motivation for not making softplus the default transformation (compared to the current default exp). I see how this might be a problem for a model with large parameter values, but so would exp, wouldn't it? Small changes to unconstrained values lead to large swings in the constrained scale: 1e10 = exp(23.02585) => 2.7e10 = exp(23.02585 + 1).
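To illustrate numerically (a small standalone sketch, not code from either library):

import jax.numpy as jnp
from jax.nn import softplus

# with exp, a +1 step in unconstrained space multiplies the constrained value by e ~ 2.7,
# so near 1e10 the absolute jump is ~1.7e10
x = jnp.log(1e10)
print(jnp.exp(x), jnp.exp(x + 1.0))    # ~1e10 -> ~2.7e10

# with softplus, softplus(x) ~ x for large x, so the same +1 step changes the value by ~1
print(softplus(23.0), softplus(24.0))  # ~23.0 -> ~24.0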
In the previous post, I suggested using softplus as a default transformation for scales (rather than for model variables/sites). Scales in the AutoNormal approximation are always positive regardless of the model. For example, pymc3 just hardcodes softplus into its version of AutoNormal (https://github.com/pymc-devs/pymc3/blob/master/pymc3/variational/approximations.py#L61, https://github.com/pymc-devs/pymc3/blob/1769258e459e8f40aa8a56e0ac911aa99e7f67de/pymc3/distributions/dist_math.py#L194).
So, there are 2 decisions:
1. The default positive transformation for all models, registered via dist.biject_to.register. I am not suggesting any changes here.
2. The transformation used for scales in AutoNormal, e.g.:

from jax.nn import softplus

site_scale_unconstrained = numpyro.param("{}_{}_scale".format(name, self.prefix),
                                         jnp.full(jnp.shape(init_loc), self._init_scale),
                                         constraint=constraints.real,
                                         event_dim=event_dim)
site_scale = softplus(site_scale_unconstrained)

(see https://github.com/vitkl/cell2location_numpyro/blob/main/cell2location_numpyro/distributions/AutoNormal.py#L137-L142 for the full AutoNormal code with this edit)
Does this make sense? If yes, I will make the changes and submit a PR.
I don't think we should change the default transformation unless there is enough evidence that softplus is better. Following this line might need further investigation. :)

Small changes to unconstrained scales lead to large swings in scales

I guess this could be desirable? The derivative of softplus is bounded by 1, so it would take a long time to move a parameter to the desired domain. The derivative of exp is unbounded, hence together with some optimizers like ClippedAdam (to make the inference fluctuate less), it can work well.
@vitkl Just curious, when you mentioned that setting the default transform to Softplus didn't help, does it mean that softplus is not good for some of your model's parameters? If so, it raises another issue: softplus is good for some parameters and exp is good for others, hence a default choice is more difficult to make.
I just hope that the issue can be solved by tuning some parameters of the optimizer. :D
Just to clarify, for our model:
1. softplus as the default positive transformation did not help compared to point 2, but it did help compared to exp.
2. softplus for transforming scales improved stability and accuracy.
This could mean that, because the posteriors of most parameters in our model are < 1 and all parameters are < 100, either softplus or exp works well. At the same time, softplus for transforming scales seems to be necessary for improving accuracy and matching it to the accuracy of the pymc3 translation of the same model.
I am using ClippedAdam. Interestingly, reducing clip_norm and the learning rate decreased stability (the ELBO started fluctuating up earlier) - but I did not do a systematic parameter tuning.
Thanks, @vitkl! I see that we can narrow down the reasons for choosing a default transformation. It seems like the posteriors of positive variables in your model look more like LogNormal than something like Softplus(Normal), hence setting the default transformation to softplus does not help much.
Now to the optimization problem for the auto_*_scale parameters. Your suggestion is that the optimizer you used works better with the softplus transform than with exp. Could you make some replicable code, or let me know some way to use your repository, so I can see this issue?
FYI, using the exp_to_softplus handler is pretty flexible: you can decide which parameters should use softplus. It is unlikely that we will change the default behavior of AutoNormal unless we do intensive tests. One of the main reasons is that it will affect current Pyro/NumPyro users, who have tuned their optimizers well with the exp transform (there are plenty of parameters to tune: learning rate, gradient clipping, learning rate schedule, ...). If softplus is better, we should think about adding some arguments to the constructor of AutoNormal, like use_softplus_transform=True/False, to switch the behavior. We still need to discuss a good API, but it will be less controversial than changing the default behavior of AutoNormal.
@vitkl I agree with @fehiepsi that this change would affect many existing users. That means (1) your observation could have widespread positive impact in both NumPyro and Pyro, but (2) we would need to thoroughly test a wide range of models to ensure existing users are not negatively impacted. The last time we made a similar change to autoguides was a change of parameters in AutoLowRankMultivariateNormal in https://github.com/pyro-ppl/pyro/pull/2127, which was accompanied by thorough experiments. If you'd like to perform similar experiments, would you mind submitting a PR with a new directory in https://github.com/pyro-ppl/sandbox containing notebooks, similar to sandbox/2019-11-lowrank, maybe calling it 2021-01-softplus? That way the experiments are public, part of pyro-ppl, and accessible by affected users.
I am new to engaging with/contributing to the project - so thanks a lot for the feedback!
I like the approach of:

adding some arguments to the construction of AutoNormal, like use_softplus_transform=True/False to switch the behavior.

Looking at the lowrank sandbox, I see that the test notebooks there test the specific issue with AutoLowRankMultivariateNormal. The changes I propose will affect all models using the Normal approximation. Do you have any test models that could be used for a more thorough experiment with different types of models?
Our model can be tested using this Colab notebook, https://colab.research.google.com/github/vitkl/cell2location_numpyro/blob/main/docs/notebooks/cell2location_short_demo_colab.ipynb; however, it is relatively computationally intensive (requires a GPU). @fritzo do you suggest submitting this notebook into sandbox/2021-01-softplus?
I am talking to @martinjankowiak next Monday - so maybe I could ask him these questions too?
WDYT of keeping exp as the default transform but adding an option to AutoNormal (and maybe other autoguides) to use softplus? We might name this option assume_prewhitened=True or something. We could make the underlying machinery easier to use by adding a constraints.prewhitened_positive whose registered transform is Softplus. That would also make it easy for users to create custom derived distributions that are known to be prewhitened, e.g.

class PrewhitenedNormal(Normal):
    arg_constraints = {"loc": constraints.real,
                       "scale": constraints.prewhitened_positive}

and use these distributions in a model.
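In NumPyro terms, the machinery might be wired up roughly like this (prewhitened_positive and SoftplusTransform are hypothetical names here, not existing API):

import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to

# hypothetical constraint: semantically the same set as constraints.positive,
# but a distinct type so a different default bijection can be registered for it
class _PrewhitenedPositive(constraints.Constraint):
    def __call__(self, x):
        return x > 0

prewhitened_positive = _PrewhitenedPositive()

@biject_to.register(_PrewhitenedPositive)
def _transform_to_prewhitened_positive(constraint):
    return dist.transforms.SoftplusTransform()  # the transform sketched earlier in this thread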
@fritzo It sounds like this can be added without any changes for existing users - so that's great. Not sure what prewhitened means, though.
We have implemented a simple SoftplusTransform for pyro models but are not certain whether it belongs in pyro or pytorch: https://github.com/BayraktarLab/cell2location/blob/master/cell2location/distributions/transforms.py (in addition to numpyro: https://github.com/vitkl/numpyro/blob/master/numpyro/distributions/transforms.py#L315-L342).
I still think that if the benefit of softplus for positive parameters can be demonstrated for several models, it is worth considering it as the default (for Normal scales but also for other positive parameters like Gamma shape and rate). The choice of transform for scales is quite far from the first idea that comes to mind when troubleshooting models (I would not have known to look at it if I were not a pymc3 user). At least it would be good to mention softplus in the SVI tutorial (https://pyro.ai/examples/svi_part_i.html).
@vitkl I think SoftplusTransform best belongs in PyTorch. Could you submit it upstream and cc me for review?
Re: terminology, can we find some sort of term, like "standardized" or "prewhitened", to mean "data has been scaled and shifted such that its values are on the order of 1"? IMHO this should be a single word we can use in interfaces and documentation, and the word should be a bit vague because this notion is a bit vague.
class exp_to_softplus(numpyro.primitives.Messenger): ...
@fehiepsi do you by any chance know how to do this in pyro?
In Pyro, Messenger belongs to the poutine module, but the implementation should be very similar. You can see how the mask messenger is implemented here. To convert a messenger to a handler, you can use the _make_handler utility.
WDYT of keeping exp as the default transform but adding an option to AutoNormal (and maybe other autoguides) to use softplus? We might name this option assume_prewhitened=True or something.

I like this solution. I remember that when converting bijectors to constraints in the TFP wrappers, I also saw some popular TFP distributions using this softplus bijector under the hood. I don't know why, but we can assume that users will want the customization you suggested.
About the names, I think prewhitened or standardized is a bit misleading (it causes confusion with the whitening approach in LocScaleReparam). I like other names that are less vague, like softplus_positive or, even better IMO, soft_positive, which is still a bit vague but should be easier to convey to users.
In Pyro, Messenger belongs to the poutine module, but the implementation should be very similar. You can see how the mask messenger is implemented here. To convert a messenger to a handler, you can use the _make_handler utility.

Thanks for the tip about handlers - but this approach does not work for me for some reason (all the more reason to have support for softplus in AutoNormal).
Code: https://github.com/vitkl/scvi-tools/blob/pyro-cell2location/scvi/external/cell2location/_module.py#L38-L69
Error when using with the AutoNormal guide (self.guide = ExpToSoftplusHandler(self.guide)):
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-10-e0e5d2386535> in <module>
18 module=None,
19 use_gpu=True,
---> 20 batch_size=2019)
/BayraktarLab/scvi-tools/scvi/external/cell2location/_model.py in __init__(self, adata, cell_state_df, var_names_read, sample_id, module, use_gpu, batch_size, **model_kwargs)
570 self.module = module(n_obs=self.n_obs, n_var=self.n_var,
571 n_fact=self.n_fact, n_exper=self.n_exper, batch_size=self.batch_size,
--> 572 cell_state_mat=self.cell_state_df.values, **model_kwargs)
/BayraktarLab/scvi-tools/scvi/external/cell2location/_module.py in __init__(self, n_obs, n_var, n_fact, n_exper, batch_size, cell_state_mat, n_comb, m_g_gene_level_prior, m_g_gene_level_var_prior, cell_number_prior, cell_number_var_prior, alpha_g_phi_hyp_prior, gene_add_alpha_hyp_prior, gene_add_mean_hyp_prior, w_sf_mean_var_ratio)
151 create_plates=self.create_plates)
152 # replace exp transform with softplus https://github.com/pyro-ppl/numpyro/issues/855
--> 153 self.guide = ExpToSoftplusHandler(self.guide)
154
155 @staticmethod
~/.local/lib/python3.7/site-packages/torch/nn/modules/module.py in __setattr__(self, name, value)
813 raise TypeError("cannot assign '{}' as child module '{}' "
814 "(torch.nn.Module or None expected)"
--> 815 .format(torch.typename(value), name))
816 modules[name] = value
817 else:
TypeError: cannot assign 'pyro.infer.autoguide.guides._bound_partial' as child module 'guide' (torch.nn.Module or None expected)
I agree about the names - softplus_positive is easier to interpret correctly than prewhitened, which I have also encountered in completely different contexts.
I think this issue has been resolved. We can easily switch the behaviors now. Let's change the default behavior to softplus when we have more evidence in sandbox that it is good. Thanks @vitkl for raising this interesting issue! I'm looking forward to exploring the notebook that you intended to put in sandbox. :)
Just a curious observation we made with @yozhikoff: the issue with the exp transform seems to be less severe with CUDA 11.1 than with CUDA 10.2 (using the corresponding PyTorch builds).
Interesting, probably the precision of some op differs between the two systems.
Happy New Year, Numpyro team!
I observe a very substantial improvement in inference stability when I replace the exp transformation with softplus for constraining site_scale (https://github.com/vitkl/cell2location_numpyro/blob/main/cell2location_numpyro/distributions/AutoNormal.py#L142). Without this change the ELBO tends to become very variable and shoot up (see plot). This also agrees with Softplus being used in pymc3 and with this discussion: https://github.com/tensorflow/probability/issues/751.
So I am wondering if you could add the softplus transformation to the library, as well as use it as the default way to enforce the positive constraint.
I tried implementing it, starting from the TFP implementation (https://github.com/tensorflow/probability/blob/v0.11.1/tensorflow_probability/python/bijectors/softplus.py#L61-L173), but I am confused by the difference in the meaning of log_det_jacobian between TFP and Numpyro.