There's still a performance issue: `_log_abs_det_jacobian` can call `params(x)` even though `params(x)` may have been computed already (e.g. [here](https://github.com/facebookincubator/flowtorch/blob/b60d2274fd7f6bc0a8779d7326fef96ca4881756/flowtorch/bijectors/ops/affine.py#L77)). In other words, we should cache the value returned by `params(x)` too.
The way I suggest doing that is to return the result of `params` (if any) from `_forward` and `_inverse`:
```python
def _forward(self, x: torch.Tensor, ...) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
    _params = self.params(x, ...)
    ...
    return y, _params
```
and similarly for `_inverse`.
Then, the public `forward` and `inverse` methods would cache the `params` in the `BijectiveTensor`. For instance, in `base.py`:
```python
def forward(
    self,
    x: torch.Tensor,
    context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    y, params = self._forward(x, context)
    if is_record_flow_graph_enabled():
        y = to_bijective_tensor(x, y, self, mode="forward", params=params)
    return y
```
The final missing piece of this puzzle would be for the `Parameters` class to check if the input is a `BijectiveTensor`:
```python
def forward(self, x):
    if isinstance(x, BijectiveTensor):
        return x.params
    ...
```
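To make the whole chain concrete, here is a minimal method-level sketch (in the same style as the snippets above, not the actual flowtorch code) of how `log_abs_det_jacobian` could then reuse the cached params; the `from_forward`/`from_inverse` helpers and the `_log_abs_det_jacobian(x, y, params)` signature are assumptions:

```python
from typing import Optional, Sequence

import torch


def log_abs_det_jacobian(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
    context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    params: Optional[Sequence[torch.Tensor]] = None
    if isinstance(y, BijectiveTensor) and y.from_forward():
        # y was produced by self.forward(x): reuse the params cached there
        params = y.params
    elif isinstance(x, BijectiveTensor) and x.from_inverse():
        # x was produced by self.inverse(y): reuse the params cached there
        params = x.params
    if params is None:
        # no cache available: fall back to recomputing the parameters
        params = self.params(x, context)
    return self._log_abs_det_jacobian(x, y, params)
```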
Since this change is probably BC-breaking, or at least more disruptive, I wanted to check with you first, @stefanwebb: what are your thoughts on this? I saw you mention on several occasions in the code that params should cache values; this may be the way to do it.
@vmoens I've had a small go at making your suggested changes. It occurred to me, though: why do we need to store `params(x, context)`? Why can't we just store `log_det_J` in the forward or inverse pass?
What do you think about this solution?
- `params(x, context)` is calculated inside `.forward`, `.inverse` (and `.log_abs_det_jacobian`).
- `._forward` and `._inverse` take a precalculated params as input, and also return the `log(abs(det(J)))`.
- `BijectiveTensor` stores `(x, y, context, bijector, log_detJ)`.
- `.forward` uses the stored value for `y` if `x` is a `BijectiveTensor`.
- `.inverse` uses the stored value for `x` if `y` is a `BijectiveTensor`.
- `.log_abs_det_jacobian` uses the stored value for `log_detJ` if either `x` or `y` is a `BijectiveTensor` (see the sketch after this list).
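For concreteness, a hedged sketch of what the public methods could look like under this LADJ-caching scheme, in contrast with the params-caching sketch above; helper names such as `from_inverse`, `parent` and `log_detJ`, and the `to_bijective_tensor` signature, are illustrative rather than the actual flowtorch API:

```python
from typing import Optional

import torch


def forward(
    self,
    x: torch.Tensor,
    context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # use the stored value for y if x is a BijectiveTensor produced by self.inverse
    if isinstance(x, BijectiveTensor) and x.from_inverse():
        return x.parent  # hypothetical accessor for the stored y
    params = self.params(x, context)
    y, log_detJ = self._forward(x, params)  # _forward also returns log(abs(det(J)))
    if is_record_flow_graph_enabled():
        # the BijectiveTensor stores (x, y, context, bijector, log_detJ)
        y = to_bijective_tensor(x, y, context, self, log_detJ, mode="forward")
    return y


def log_abs_det_jacobian(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
    context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # reuse the cached LADJ when either argument carries one
    for t in (y, x):
        if isinstance(t, BijectiveTensor) and t.log_detJ is not None:
            return t.log_detJ
    params = self.params(x, context)
    return self._log_abs_det_jacobian(x, y, params)
```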
This requires some quite extensive changes to the code in `flowtorch.bijectors`, so I've only got a partial implementation right now... Would you like to pick up the (flow)torch from here?
> This requires some quite extensive changes to the code in `flowtorch.bijectors`, so I've only got a partial implementation right now... Would you like to pick up the (flow)torch from here?
I'd be glad to! I had thought about that, but what made me consider storing params rather than the LADJ is that for some flows the computation of the LADJ is straightforward, while for others it comes at a high cost (e.g. invertible ResNet). I would like to think of an option where we can disable the computation of the LADJ if we don't need it (e.g. using an invertible ResNet for image classification does not require the LADJ).
That being said, we could flag the forward (or inverse) pass with some decorator / context manager / flag to say that we don't want to compute the LADJ:
```python
y = flow(x)
ladj = flow.log_abs_det_jacobian(x, y)  # simply sums all the stored ladj from the BijectiveTensors

with set_requires_ladj(False):  # default is True. This context manager disables the computation of the LADJ to save time
    y = flow(x)
    ladj = flow.log_abs_det_jacobian(x, y)  # works, but the LADJ will be computed using the cached x and y, not a cached ladj
```
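A minimal sketch of such a context manager, assuming a module-level flag; the name `set_requires_ladj` and this implementation are purely illustrative, nothing like it exists in flowtorch today:

```python
import contextlib
from typing import Iterator

_REQUIRES_LADJ: bool = True  # module-level switch read by forward/inverse


def requires_ladj_enabled() -> bool:
    return _REQUIRES_LADJ


@contextlib.contextmanager
def set_requires_ladj(mode: bool) -> Iterator[None]:
    """Temporarily enable or disable the computation of the LADJ."""
    global _REQUIRES_LADJ
    previous, _REQUIRES_LADJ = _REQUIRES_LADJ, mode
    try:
        yield
    finally:
        _REQUIRES_LADJ = previous
```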
We could also display a one-time warning when calling `log_abs_det_jacobian` inside the `set_requires_ladj(False)` context, like this:
```
Warning: computing the LADJ when set_requires_ladj is set to False may slow down your forward and backward passes. Consider removing the `set_requires_ladj(False)` context manager.
```
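A sketch of how that one-time warning could be emitted; by default Python's `warnings` module reports a given warning only once per call site, so `warnings.warn` already gives the "display once" behaviour (`requires_ladj_enabled` is the hypothetical flag from the sketch above):

```python
import warnings


def log_abs_det_jacobian(self, x, y, context=None):
    if not requires_ladj_enabled():
        warnings.warn(
            "computing the LADJ when set_requires_ladj is set to False may slow "
            "down your forward and backward passes. Consider removing the "
            "`set_requires_ladj(False)` context manager.",
            UserWarning,
        )
    ...
```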
Side note: there are cases where we want an output that requires grad but don't want a LADJ, and the opposite is also true. So we can't just use `torch.is_grad_enabled` in place of this context manager.
```python
y = flow(x, requires_ladj=True)  # we may want to set requires_ladj=True as default behaviour
ladj = flow.log_abs_det_jacobian(x, y)  # works fast

y = flow(x, requires_ladj=False)
ladj = flow.log_abs_det_jacobian(x, y)  # works but slower, as it only uses the cached x and y and may recompute params(x)
```
Merging #89 (6977b08) into main (e884bc8) will increase coverage by 0.19%. The diff coverage is 98.71%.
```
@@            Coverage Diff             @@
##             main      #89      +/-   ##
==========================================
+ Coverage   98.03%   98.23%   +0.19%
==========================================
  Files           5        6       +1
  Lines         153      227      +74
==========================================
+ Hits          150      223      +73
- Misses          3        4       +1
```
| Flag | Coverage Δ | |
|---|---|---|
| unittests | 98.23% <98.71%> (+0.19%) | :arrow_up: |

Flags with carried forward coverage won't be shown.
| Impacted Files | Coverage Δ | |
|---|---|---|
| tests/test_bijectivetensor.py | 98.64% <98.64%> (ø) | |
| tests/test_bijector.py | 100.00% <100.00%> (ø) | |
| tests/test_distribution.py | 100.00% <100.00%> (ø) | |
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
@stefanwebb has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Motivation
As described in #88, we would like a way of caching intermediate values computed in the flow.
Changes proposed
This PR assigns this responsibility to a new class, `BijectiveTensor`. A `BijectiveTensor` keeps track of the layer that created it, the original tensor, and whether it comes from a call to `forward` or `inverse`. It inherits from `torch.Tensor`. By default, an operation on a `BijectiveTensor` returns a `torch.Tensor` (except if this operation is a `Bijector`). One can control whether `BijectiveTensor`s should be used (which is the case by default) with the context manager `set_record_flow_graph`.
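For illustration, a simplified, hedged sketch of the design just described; the actual class added by this PR differs in detail, and the attribute and helper names here are illustrative:

```python
from typing import Any

import torch


class BijectiveTensor(torch.Tensor):
    """A tensor that remembers which bijector produced it and from which input."""

    def register(self, original: torch.Tensor, bijector: Any, mode: str) -> "BijectiveTensor":
        self._original = original  # the tensor the bijector was applied to
        self._bijector = bijector  # the layer that created this tensor
        self._mode = mode          # "forward" or "inverse"
        return self

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Run the op as usual, then drop the subclass on the result so that
        # ordinary tensor operations return a plain torch.Tensor; only
        # bijectors (which build BijectiveTensors explicitly) extend the graph.
        out = super().__torch_function__(func, types, args, kwargs or {})
        if isinstance(out, BijectiveTensor) and not hasattr(out, "_bijector"):
            out = out.as_subclass(torch.Tensor)
        return out


def to_bijective_tensor(
    x: torch.Tensor, y: torch.Tensor, bijector: Any, mode: str = "forward"
) -> BijectiveTensor:
    return y.as_subclass(BijectiveTensor).register(x, bijector, mode)
```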
Test Plan
A test file can be found in `test/test_bijectivetensor.py`.
Types of changes
Checklist