tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0
4.27k stars 1.11k forks

Regularization in FFJORD #1095

Open gregreen opened 4 years ago

gregreen commented 4 years ago

It would be great to implement regularization techniques that have recently been developed for FFJORD. For example,

The simplest regularization adds a penalty for the path length of each sample's trajectory. I don't know how to begin implementing this in TensorFlow Probability, as I don't understand the internals well enough. During training, when calculating log_prob(batch), it's necessary to calculate the integral of |f(t,x)| over the path taken by each sample in the batch, and to somehow expose this information so that it can be used in an additional penalty term.
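To make the quantity concrete: the path-length penalty for one sample is the integral of ‖f(t, x(t))‖ dt along its trajectory, and one way to expose it is to append an accumulator to the ODE state whose time derivative is ‖f‖, so a single solve returns both the transformed sample and its path length. The sketch below assumes plain NumPy with a hand-written fixed-step RK4 solver and a hypothetical helper `integrate_with_path_length`; it illustrates the augmentation idea, not TFP's internals.

```python
import math

import numpy as np


def integrate_with_path_length(f, x0, t0, t1, num_steps=1000):
    """Integrates dx/dt = f(t, x) together with the path length L(t).

    The state is augmented as z = [x, L] with dL/dt = ||f(t, x)||; the
    accumulator never feeds back into the dynamics of x.
    Uses classical fixed-step RK4 (an illustrative choice of solver).
    """
    def augmented_f(t, z):
        dx = f(t, z[:-1])
        return np.append(dx, np.linalg.norm(dx))

    z = np.append(np.asarray(x0, dtype=float), 0.0)
    h = (t1 - t0) / num_steps
    t = t0
    for _ in range(num_steps):
        k1 = augmented_f(t, z)
        k2 = augmented_f(t + h / 2, z + h / 2 * k1)
        k3 = augmented_f(t + h / 2, z + h / 2 * k2)
        k4 = augmented_f(t + h, z + h * k3)
        z = z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return z[:-1], z[-1]  # final state, accumulated path length


# Rotation field: the point (1, 0) travels half the unit circle in time pi.
rotate = lambda t, x: np.array([-x[1], x[0]])
x_final, path_length = integrate_with_path_length(rotate, [1.0, 0.0], 0.0, math.pi)
print(x_final, path_length)  # approximately [-1, 0] and pi
```

In a training loop, a weighted multiple of the accumulated length would be added to the negative log-likelihood, while at evaluation time it would simply be ignored.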

gitlabspy commented 4 years ago

I just finished reading this paper. Isn't that just taking the norms of two outputs of the ode_solver and then adding them to the log_det_jacobian? Please correct me if I'm making mistakes 😂 RNODE

https://github.com/tensorflow/probability/blob/17af734f859c73e8e1dc39d8df7b058f1fb1e742/tensorflow_probability/python/bijectors/ffjord.py#L356

gregreen commented 4 years ago

Almost. During training, I think you can add those two terms to the log_det_jacobian. However, after training, when evaluating the log probability, you don't want to add those additional terms in.

I've been trying to understand the internal mechanics of the TFP implementation of FFJORD. I think that augmented_ode_fn needs to be updated to calculate the additional regularization terms (\dot{E}_j and \dot{n}_j). They would then be cached, just like is done with the log_det_jacobian. In particular, this line would be changed to something like the following:

y, fldj, Edot, ndot = self._solve_ode(augmented_ode_fn, augmented_x)

I'm willing to take a shot at implementing this regularization, but it might be easier for people who already understand the internal workings of the FFJORD implementation.
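A minimal NumPy sketch of what the four-output `_solve_ode` call could compute, assuming an RNODE-style setup in which the state is augmented with the log-det-Jacobian, the kinetic energy ‖f‖² (the \dot{E}_j term) and the squared Frobenius norm ‖∇f‖²_F (the \dot{n}_j term). The helper names, the exact-Jacobian callback `jac_fn`, and the RK4 solver are assumptions for illustration, not TFP code.

```python
import numpy as np


def rnode_augmented_fn(ode_fn, jac_fn):
    """Returns g(t, x) -> (dx/dt, d(log det)/dt, dE/dt, dn/dt).

    ode_fn(t, x) gives f(t, x); jac_fn(t, x) gives the full Jacobian df/dx.
    """
    def g(t, x):
        dx = ode_fn(t, x)
        jac = jac_fn(t, x)
        trace = np.trace(jac)          # integrates to the log-det-Jacobian
        kinetic = np.sum(dx ** 2)      # ||f||^2, the "Edot" term
        frobenius = np.sum(jac ** 2)   # ||df/dx||_F^2, the "ndot" term
        return dx, trace, kinetic, frobenius
    return g


def solve_augmented(g, x0, t0=0.0, t1=1.0, num_steps=1000):
    """Fixed-step RK4 on the flattened state [x, log_det, E, n]."""
    def flat_g(t, z):
        dx, trace, kinetic, frobenius = g(t, z[:-3])
        return np.concatenate([dx, [trace, kinetic, frobenius]])

    z = np.concatenate([np.asarray(x0, dtype=float), np.zeros(3)])
    h = (t1 - t0) / num_steps
    t = t0
    for _ in range(num_steps):
        k1 = flat_g(t, z)
        k2 = flat_g(t + h / 2, z + h / 2 * k1)
        k3 = flat_g(t + h / 2, z + h / 2 * k2)
        k4 = flat_g(t + h, z + h * k3)
        z = z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return z[:-3], z[-3], z[-2], z[-1]  # y, fldj, E, n


# Linear contraction f(x) = -x in 2-D, integrated over t in [0, 1].
g = rnode_augmented_fn(lambda t, x: -x, lambda t, x: -np.eye(2))
y, fldj, kinetic, frobenius = solve_augmented(g, [1.0, 1.0])
# y ≈ exp(-1) * [1, 1], fldj ≈ -2, kinetic ≈ 1 - exp(-2), frobenius ≈ 2
```

For this linear contraction all four outputs have closed forms, which makes the sketch easy to check; during training the weighted E and n would be added to the loss, and at evaluation only y and fldj would be used.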

gitlabspy commented 4 years ago

How about passing is_training to _forward/_inverse as a parameter?

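def _forward(self, x, is_training=True):
    y, _ = self._augmented_forward(x, is_training=is_training)
    return y

def _augmented_forward(self, x, is_training=True):
    ...  # solve the augmented ODE to obtain y, fldj, Edot and ndot
    if is_training:
        # Training: fold the regularization terms into the log-det-Jacobian.
        return y, fldj + Edot + ndot
    else:
        # Evaluation: return the plain log-det-Jacobian only.
        return y, fldj

# and when caching, key on is_training as well:
...
cached = self._cache.forward_attributes(x, is_training=is_training)
...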

The cache is then keyed on both the input x and is_training. If all we need are the squares of y and fldj/ildj, both outputs of y, fldj = self._solve_ode(augmented_ode_fn, augmented_x), then we don't need to modify augmented_ode_fn in this way.
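A toy, pure-Python illustration of the caching proposal, assuming a hypothetical `ToyCachedFlow` whose "ODE solve" is just y = 2x and whose regularizer is a placeholder sum of squares. Because the cache key includes `is_training`, a training call and an evaluation call on the same input are stored separately.

```python
import math


class ToyCachedFlow:
    """Hypothetical stand-in for a bijector that caches (y, log-det) per input."""

    def __init__(self):
        self._cache = {}
        self.solver_calls = 0

    def _solve(self, x):
        self.solver_calls += 1
        y = tuple(2.0 * xi for xi in x)      # placeholder "ODE solve": y = 2x
        fldj = len(x) * math.log(2.0)        # log|det(2 I)|
        penalty = sum(yi ** 2 for yi in y)   # placeholder regularizer
        return y, fldj, penalty

    def augmented_forward(self, x, is_training=True):
        # The key includes is_training, so training and evaluation results
        # for the same x live in separate cache entries.
        key = (tuple(x), is_training)
        if key not in self._cache:
            y, fldj, penalty = self._solve(x)
            self._cache[key] = (y, (fldj + penalty) if is_training else fldj)
        return self._cache[key]


flow = ToyCachedFlow()
flow.augmented_forward([1.0, 2.0], is_training=True)
flow.augmented_forward([1.0, 2.0], is_training=True)   # cache hit
flow.augmented_forward([1.0, 2.0], is_training=False)  # new entry, solves again
print(flow.solver_calls)  # 2
```

A design alternative would be to cache the raw (y, fldj, penalty) triple keyed on x alone and apply the flag afterwards, which avoids the second solve.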

brianwa84 commented 4 years ago

If you want to pass a kwarg to bijector.forward, you can pass it via transformed_dist.log_prob(x, bijector_kwargs=dict(is_training=...)). Is that helpful? We don't currently have FFJORD improvements on our roadmap, so we'd be happy to look at a PR adding an option to turn this on, assuming this is a generally useful improvement.

(In reply to gitlabspy's comment above, sent Tue, Oct 6, 2020 at 10:47 AM.)
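The routing brianwa84 describes can be mimicked with two toy classes (plain Python, not TFP's `TransformedDistribution`): `log_prob` forwards whatever is in `bijector_kwargs` to the bijector's inverse and inverse log-det-Jacobian. `ToyScale`, `ToyTransformed`, and the 0.1 training penalty are all hypothetical.

```python
import math


class ToyScale:
    """Hypothetical bijector y = scale * x with an optional training penalty."""

    def __init__(self, scale):
        self.scale = scale

    def inverse(self, y, is_training=False):
        return y / self.scale

    def inverse_log_det_jacobian(self, y, is_training=False):
        ildj = -math.log(abs(self.scale))
        # Placeholder regularizer, applied only when is_training=True.
        return ildj - (0.1 * y ** 2 if is_training else 0.0)


class ToyTransformed:
    """Hypothetical change-of-variables density that forwards bijector_kwargs."""

    def __init__(self, base_log_prob, bijector):
        self.base_log_prob = base_log_prob
        self.bijector = bijector

    def log_prob(self, y, bijector_kwargs=None):
        kwargs = bijector_kwargs or {}
        x = self.bijector.inverse(y, **kwargs)
        return self.base_log_prob(x) + self.bijector.inverse_log_det_jacobian(y, **kwargs)


std_normal_log_prob = lambda x: -0.5 * x ** 2 - 0.5 * math.log(2 * math.pi)
dist = ToyTransformed(std_normal_log_prob, ToyScale(2.0))
print(dist.log_prob(1.0))                                          # plain density
print(dist.log_prob(1.0, bijector_kwargs=dict(is_training=True)))  # lower by 0.1
```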

gregreen commented 4 years ago

We'll need a few hyperparameters as well (how much to weight Edot and ndot), which can either be passed to _forward and cached or saved as attributes of FFJORD. Thoughts on which is better? I can try to implement this addition.

brianwa84 commented 4 years ago

If we take the presence of one or both of those hparams to indicate whether or not is_training=True, then it seems like log_prob (equivalently, _forward) could just take them as kwargs. Let's give them readable names, though. Edot and ndot don't say anything (to me) about what they are regularizing.

(In reply to Gregory Green's comment above, sent Tue, Oct 6, 2020 at 12:42 PM.)

brianwa84 commented 4 years ago

BTW: _forward could take those as kwargs, in lieu of an is_training kwarg.

gitlabspy commented 4 years ago

Edot and ndot stand for the kinetic penalty and the Jacobian penalty, respectively.


def _augmented_forward(self, x, is_training=True):
    ...
    if is_training:
        return y, fldj + Edot * self.kinetic_penalty + ndot * self.jacobian_penalty
    return y, fldj

I think the weights kinetic_penalty and jacobian_penalty could be set as hparams of FFJORD. When no regularization is needed, kinetic_penalty and jacobian_penalty default to 0, and is_training should be treated as a kwarg of forward.

ffjord = FFJORD(odefunc, kinetic_penalty=0.01, jacobian_penalty=0.01)
z = ffjord.forward(x, is_training=True)

gitlabspy commented 4 years ago

BTW: _forward could take those as kwargs, in lieu of an is_training kwarg.

I realize your suggestion is a more elegant implementation than mine above: just put kinetic_penalty and jacobian_penalty in forward's kwargs, with a default of 0.
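Under the kwargs-with-default-0 design, the combination could look like the hypothetical helper below, which mirrors the formula in the `_augmented_forward` snippet above (`fldj + Edot * kinetic_penalty + ndot * jacobian_penalty`). Leaving both weights at 0 reproduces the unregularized log-det-Jacobian, so omitting the kwargs acts as evaluation mode.

```python
def regularized_log_det(fldj, kinetic, jacobian_frobenius,
                        kinetic_penalty=0.0, jacobian_penalty=0.0):
    """Hypothetical helper: log-det-Jacobian plus weighted RNODE-style terms.

    With both weights at their default of 0 the result is just fldj,
    so callers that pass no penalty kwargs get the evaluation-time value.
    """
    return fldj + kinetic_penalty * kinetic + jacobian_penalty * jacobian_frobenius


print(regularized_log_det(1.5, 3.0, 4.0))  # evaluation: 1.5
print(regularized_log_det(1.5, 3.0, 4.0, kinetic_penalty=0.01, jacobian_penalty=0.01))
```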

gregreen commented 4 years ago

I've realized that it's possible to impose the regularization by providing a modified trace_augmentation_fn, without actually altering the internals of the FFJORD class.

For example, the kinetic regularization term can be added with a wrapper around either of the trace augmentation functions provided by ffjord.py:

import tensorflow_probability as tfp

tfb = tfp.bijectors


def add_kinetic_regularization(trace_augmentation_fn, kinetic_penalty=0.):
  """Wraps a trace augmentation function to subtract a kinetic-energy penalty."""
  def get_aug_ode_fn(ode_fn, state_shape, dtype):
    augmented_ode_fn = trace_augmentation_fn(ode_fn, state_shape, dtype)
    def reg_augmented_ode_fn(time, state_log_det_jac):
      state_time_derivative, trace_value = augmented_ode_fn(time, state_log_det_jac)
      # Elementwise ||f||^2 contribution, with the same shape as trace_value.
      kinetic_reg = kinetic_penalty * state_time_derivative**2
      return state_time_derivative, trace_value - kinetic_reg
    return reg_augmented_ode_fn
  return get_aug_ode_fn

# Pass the result to tfb.FFJORD(..., trace_augmentation_fn=trace_augmentation_fn).
trace_augmentation_fn = add_kinetic_regularization(
    tfb.ffjord.trace_jacobian_exact, kinetic_penalty=0.05)

I don't think the Jacobian penalty can be added with a simple wrapper like this, because the Frobenius norm of ∇f (see @gitlabspy's above comment) is not computable from trace_value (which only contains the diagonal components of ∇f). To implement the Jacobian penalty, one actually has to alter the internals of the individual trace augmentation functions. It shouldn't be too difficult to do so, though.
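A NumPy sketch of why the diagonal is not enough, using central finite differences as a stand-in for TFP's exact Jacobian computation (an assumption for illustration; the helper names are hypothetical). Two linear vector fields with the same diagonal have the same trace but different squared Frobenius norms, so no function of the diagonal alone can recover the Jacobian penalty; an exact variant has to form the full Jacobian.

```python
import numpy as np


def finite_difference_jacobian(f, x, eps=1e-6):
    """Full Jacobian df/dx by central differences (column j is df/dx_j)."""
    x = np.asarray(x, dtype=float)
    columns = []
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = eps
        columns.append((f(x + step) - f(x - step)) / (2 * eps))
    return np.stack(columns, axis=1)


def trace_and_frobenius(f, x):
    jac = finite_difference_jacobian(f, x)
    diag = np.diag(jac)  # all that a diagonal-only trace_value keeps
    return diag.sum(), np.sum(jac ** 2)


A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.array([[1.0, 0.0], [0.0, 4.0]])  # same diagonal as A
x = np.array([0.3, -0.7])
print(trace_and_frobenius(lambda v: A @ v, x))  # trace 5, squared Frobenius 30
print(trace_and_frobenius(lambda v: B @ v, x))  # trace 5, squared Frobenius 17
```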

gitlabspy commented 4 years ago

Isn't Ldot (the second term of FFJORD, circled by the blue box in the image I uploaded above) calculated by the Hutchinson estimator? I notice that you use trace_jacobian_exact, which does not calculate the Ldot term. From the original paper:

Thus Jacobian Frobenius norm regularization is available with essentially no extra computational cost.

If I'm not mistaken, we can use this term to calculate the Frobenius norm. https://github.com/tensorflow/probability/blob/cfeae22d71766041d2b4108f5b7675e9e7175e34/tensorflow_probability/python/bijectors/ffjord.py#L97
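A NumPy sketch of the "no extra cost" point, assuming Gaussian probe vectors ε. As we read the linked TFP code, the vector named `jvp` in trace_jacobian_hutchinson comes from `tape.gradient` with the random sample as output gradients, i.e. the product εᵀ∇f. That one product yields both estimates: E[εᵀ∇f ε] = tr(∇f) and E[‖εᵀ∇f‖²] = ‖∇f‖²_F, so it is the squared norm of this vector that estimates the squared Frobenius norm. The sketch forms the Jacobian explicitly only to stay self-contained; the helper name is hypothetical.

```python
import numpy as np


def hutchinson_trace_and_frobenius(jac, num_samples, seed=0):
    """Monte Carlo estimates of tr(J) and ||J||_F^2 from shared products e^T J.

    For e ~ N(0, I): E[e^T J e] = tr(J) and E[||e^T J||^2] = ||J||_F^2.
    """
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((num_samples, jac.shape[0]))
    vjp = eps @ jac  # row i is e_i^T J, the only Jacobian product needed
    trace_est = np.mean(np.sum(eps * vjp, axis=1))
    frobenius_est = np.mean(np.sum(vjp ** 2, axis=1))
    return trace_est, frobenius_est


J = np.array([[1.0, 2.0], [3.0, 4.0]])
print(hutchinson_trace_and_frobenius(J, num_samples=200_000))  # close to (5, 30)
```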

gregreen commented 4 years ago

Ldot can be calculated from an intermediate variable of the Hutchinson estimator, but it needs to be calculated within the function trace_jacobian_hutchinson. I think that Ldot is the norm of the vector jvp:

https://github.com/tensorflow/probability/blob/cfeae22d71766041d2b4108f5b7675e9e7175e34/tensorflow_probability/python/bijectors/ffjord.py#L96

One can also modify trace_jacobian_exact to calculate Ldot.