facebookincubator / flowtorch

This library would form a permanent home for reusable components for deep probabilistic programming. The library would form and harness a community of users and contributors by focusing initially on complete infra and documentation for how to use and create components.
https://flowtorch.ai
MIT License
301 stars, 21 forks

Bijective tensors for caching intermediate values #89

Closed vmoens closed 2 years ago

vmoens commented 2 years ago

Motivation

As described in #88, we would like a way to cache intermediate values computed in the flow.

Changes proposed

This PR assigns this responsibility to a new class, BijectiveTensor, which inherits from torch.Tensor. A BijectiveTensor keeps track of the layer that created it, the original tensor, and whether it comes from a call to `forward` or `inverse`. By default, an operation on a BijectiveTensor returns a plain torch.Tensor, unless the operation is a Bijector. BijectiveTensors are used by default; this can be switched on or off with the context manager set_record_flow_graph.
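A minimal sketch of this behaviour, assuming PyTorch. The module-level flag behind `set_record_flow_graph`, the signature of `to_bijective_tensor`, the `_parent`/`_layer`/`_mode` attribute names, and `ExpLayer` are hypothetical illustrations rather than the PR's code. Making ordinary ops return a plain `torch.Tensor` relies on `torch._C._disabled_torch_function_impl`, a private PyTorch hook that `torch.nn.Parameter` also uses.

```python
import torch
from contextlib import contextmanager

_RECORD_FLOW_GRAPH = True  # BijectiveTensors are used by default


@contextmanager
def set_record_flow_graph(enabled: bool):
    """Temporarily turn the wrapping of bijector outputs on or off."""
    global _RECORD_FLOW_GRAPH
    previous = _RECORD_FLOW_GRAPH
    _RECORD_FLOW_GRAPH = enabled
    try:
        yield
    finally:
        _RECORD_FLOW_GRAPH = previous


def is_record_flow_graph_enabled() -> bool:
    return _RECORD_FLOW_GRAPH


class BijectiveTensor(torch.Tensor):
    # Private PyTorch hook, also used by torch.nn.Parameter: ordinary torch ops
    # on a BijectiveTensor return a plain torch.Tensor instead of the subclass.
    __torch_function__ = torch._C._disabled_torch_function_impl


def to_bijective_tensor(parent, output, layer, mode):
    """Wrap `output` without copying (it keeps its data and autograd history)
    and record its input tensor, the layer that produced it, and the direction."""
    out = output.as_subclass(BijectiveTensor)
    out._parent, out._layer, out._mode = parent, layer, mode
    return out


class ExpLayer:
    """Hypothetical bijector y = exp(x), used only for illustration."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = torch.exp(x)
        if is_record_flow_graph_enabled():
            y = to_bijective_tensor(x, y, self, "forward")
        return y

    def inverse(self, y: torch.Tensor) -> torch.Tensor:
        x = torch.log(y)
        if is_record_flow_graph_enabled():
            x = to_bijective_tensor(y, x, self, "inverse")
        return x
```

Because `as_subclass` shares storage and stays attached to the autograd graph, wrapping costs no copy; only the bookkeeping attributes are added.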

Test Plan

A test file can be found in tests/test_bijectivetensor.py.


vmoens commented 2 years ago

There's still a performance issue: _log_abs_det_jacobian can call params(x) even though params(x) may already have been computed (e.g. [here](https://github.com/facebookincubator/flowtorch/blob/b60d2274fd7f6bc0a8779d7326fef96ca4881756/flowtorch/bijectors/ops/affine.py#L77)).

In other words, we should cache the value returned by params(x) too.

I suggest doing this by having _forward and _inverse return the result of params (if any) alongside their output:

```python
def _forward(self, x: torch.Tensor, ...) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
    _params = self.params(x, ...)
    ...
    return y, _params
```

and similarly for _inverse. Then, the forward and inverse public methods would cache the params in the BijectiveTensor. For instance, in base.py:

```python
def forward(
    self,
    x: torch.Tensor,
    context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    y, params = self._forward(x, context)
    if is_record_flow_graph_enabled():
        y = to_bijective_tensor(x, y, self, mode="forward", params=params)
    return y
```

The final missing piece of this puzzle would be for the Parameters class to check if the input is a BijectiveTensor:

```python
def forward(self, x):
    if isinstance(x, BijectiveTensor):
        return x.params
    ...
```

Since this change is probably BC-breaking, or at least more disruptive, I wanted to check with you first, @stefanwebb, to hear your thoughts. I saw you mention on several occasions in the code that params should cache values; this may be the way to do it.

stefanwebb commented 2 years ago

@vmoens I've had a small go at making your suggested changes. It occurred to me, though: why do we need to store params(x, context)? Why can't we just store log_det_J in the forward or inverse pass?

What do you think about this solution?

  1. We calculate params(x, context) inside .forward, .inverse (and .log_abs_det_jacobian)
  2. ._forward and ._inverse take the precalculated params as input, and also return log(abs(det(J)))
  3. BijectiveTensor stores (x, y, context, bijector, log_detJ)
  4. .forward uses the stored value for y if x is a BijectiveTensor
  5. .inverse uses the stored value for x if y is a BijectiveTensor
  6. .log_abs_det_jacobian uses the stored value for log_detJ if either x or y is a BijectiveTensor
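The six steps above can be sketched as follows. This is a minimal illustration of the proposal, not FlowTorch code: the `Exp` bijector, the attribute names (`_x`, `_y`, `_log_detJ`, ...), and this signature of `to_bijective_tensor` are hypothetical. Making ordinary torch ops on the subclass return a plain `torch.Tensor` relies on `torch._C._disabled_torch_function_impl`, a private PyTorch hook that `torch.nn.Parameter` also uses.

```python
import torch
from typing import Optional, Tuple


class BijectiveTensor(torch.Tensor):
    # Private PyTorch hook, also used by torch.nn.Parameter: ordinary torch ops
    # on a BijectiveTensor return a plain torch.Tensor.
    __torch_function__ = torch._C._disabled_torch_function_impl


def to_bijective_tensor(x, y, context, bijector, log_detJ, mode):
    # Step 3: the tensor handed back is y after a forward pass and x after an
    # inverse pass; it stores (x, y, context, bijector, log_detJ).
    out = (y if mode == "forward" else x).as_subclass(BijectiveTensor)
    out._x, out._y, out._context = x, y, context
    out._bijector, out._log_detJ, out._mode = bijector, log_detJ, mode
    return out


class Exp:
    """Hypothetical bijector y = exp(x), for which log|det J| = x elementwise."""

    def params(self, x: torch.Tensor, context: Optional[torch.Tensor]):
        return None  # exp has no parameters; a real bijector would compute them here

    def _forward(self, x, params) -> Tuple[torch.Tensor, torch.Tensor]:
        # Step 2: take precalculated params, return (y, log|det J|)
        return torch.exp(x), x.clone()

    def _inverse(self, y, params) -> Tuple[torch.Tensor, torch.Tensor]:
        x = torch.log(y)
        return x, x.clone()  # log|det J| of the forward map, evaluated at x

    def _cached(self, t, mode: str) -> bool:
        # True only for tensors this bijector produced in the given direction
        return isinstance(t, BijectiveTensor) and t._bijector is self and t._mode == mode

    def forward(self, x, context=None):
        if self._cached(x, "inverse"):
            return x._y  # Step 4: x came from self.inverse(y), so y is known
        y, ladj = self._forward(x, self.params(x, context))  # Step 1
        return to_bijective_tensor(x, y, context, self, ladj, "forward")

    def inverse(self, y, context=None):
        if self._cached(y, "forward"):
            return y._x  # Step 5: y came from self.forward(x), so x is known
        x, ladj = self._inverse(y, self.params(y, context))  # Step 1
        return to_bijective_tensor(x, y, context, self, ladj, "inverse")

    def log_abs_det_jacobian(self, x, y, context=None):
        # Step 6: reuse the stored value if either argument carries it
        for t, mode in ((y, "forward"), (x, "inverse")):
            if self._cached(t, mode):
                return t._log_detJ
        return self._forward(x, self.params(x, context))[1]
```

Checking both `_bijector is self` and the stored direction keeps a tensor produced by one bijector (or by composing a bijector with itself) from being mistaken for a cache hit.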

This requires some quite extensive changes to the code in flowtorch.bijectors so I've only got a partial implementation right now... Would you like to pick up the (flow)torch from here?

vmoens commented 2 years ago

> This requires some quite extensive changes to the code in flowtorch.bijectors so I've only got a partial implementation right now... Would you like to pick up the (flow)torch from here?

I'd be glad to! I had thought about that, but what made me consider storing params rather than the LADJ is that for some flows the LADJ is cheap to compute, whereas for others it is costly (e.g. invertible ResNets). I would like an option to disable the LADJ computation when we don't need it (e.g. an invertible ResNet used for image classification does not require the LADJ).

That being said, we could mark the forward (or inverse) pass with a decorator, context manager, or flag to indicate that we don't want to compute the LADJ:

```python
with set_requires_ladj(False):  # default is True; disables the LADJ computation to save time
    y = flow(x)
    ladj = flow.log_abs_det_jacobian(x, y)  # works, but the LADJ is computed from the cached x and y, not read from a cached LADJ
```

We could also display a one-time warning when `log_abs_det_jacobian` is called inside `set_requires_ladj(False)`, like this:
```
Warning: computing the LADJ when set_requires_ladj is set to False may slow down your forward and backward passes. Consider removing the `set_requires_ladj(False)` context manager.
```

Side note: there are cases where we want an output that requires grad but don't need the LADJ, and vice versa, so we can't just use torch.is_grad_enabled in place of this context manager.

Alternatively, with a per-call flag:

```python
y = flow(x, requires_ladj=False)
ladj = flow.log_abs_det_jacobian(x, y)  # works, but slower: it can only use the cached x and y, and may need to recompute params(x)
```
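A minimal sketch, in plain Python, of how the `set_requires_ladj` context manager, a per-call `requires_ladj` flag, and the one-time warning could fit together. `ladj_required`, `warn_ladj_disabled_once`, and the scalar `exp_forward`/`exp_log_abs_det_jacobian` pair are hypothetical names chosen for illustration; a real bijector would operate on tensors.

```python
import math
import warnings
from contextlib import contextmanager
from typing import Optional, Tuple

_REQUIRES_LADJ = True  # default: compute the LADJ during the pass
_WARNED = False  # ensures the warning is shown only once per process


@contextmanager
def set_requires_ladj(enabled: bool):
    """Temporarily enable or disable the LADJ computation."""
    global _REQUIRES_LADJ
    previous = _REQUIRES_LADJ
    _REQUIRES_LADJ = enabled
    try:
        yield
    finally:
        _REQUIRES_LADJ = previous


def ladj_required(flag: Optional[bool] = None) -> bool:
    """An explicit per-call flag wins; otherwise use the context manager's setting."""
    return _REQUIRES_LADJ if flag is None else flag


def warn_ladj_disabled_once() -> None:
    global _WARNED
    if not _REQUIRES_LADJ and not _WARNED:
        _WARNED = True
        warnings.warn(
            "computing the LADJ when set_requires_ladj is set to False may slow "
            "down your forward and backward passes. Consider removing the "
            "`set_requires_ladj(False)` context manager.",
            UserWarning,
            stacklevel=2,
        )


def exp_forward(x: float, requires_ladj: Optional[bool] = None) -> Tuple[float, Optional[float]]:
    """Scalar bijector y = exp(x); the LADJ log|dy/dx| = x is skipped when not required."""
    y = math.exp(x)
    ladj = x if ladj_required(requires_ladj) else None
    return y, ladj


def exp_log_abs_det_jacobian(x: float, y: float) -> float:
    """Recompute the LADJ from the inputs, warning once if it was disabled."""
    warn_ladj_disabled_once()
    return x
```

Keeping this flag separate from `torch.is_grad_enabled` means gradients and the LADJ can be switched independently, which is what the side note above asks for.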

codecov-commenter commented 2 years ago

Codecov Report

Merging #89 (6977b08) into main (e884bc8) will increase coverage by 0.19%. The diff coverage is 98.71%.


```diff
@@            Coverage Diff             @@
##             main      #89      +/-   ##
==========================================
+ Coverage   98.03%   98.23%   +0.19%     
==========================================
  Files           5        6       +1     
  Lines         153      227      +74     
==========================================
+ Hits          150      223      +73     
- Misses          3        4       +1     
```
| Flag | Coverage Δ |
|---|---|
| unittests | `98.23% <98.71%> (+0.19%)` :arrow_up: |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
|---|---|
| tests/test_bijectivetensor.py | `98.64% <98.64%> (ø)` |
| tests/test_bijector.py | `100.00% <100.00%> (ø)` |
| tests/test_distribution.py | `100.00% <100.00%> (ø)` |


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update e884bc8...6977b08.

facebook-github-bot commented 2 years ago

@stefanwebb has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.