@cetmann Hi, this is actually something that I would like to have in MemCNN as well, and it should not be that hard with the current setup, although it might be a bit tricky to get right.
If you want to have a look at it, check how checkpointing in PyTorch is implemented, which also allows for multi-input/multi-output operations: https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint
I might have a look at this issue myself this weekend.
EDIT: If you want to implement this, I can recommend using tuples instead of lists.
@cetmann Also, looking at what you are trying to achieve, you might want to have a look at how splitting and combining are done in the `AdditiveCoupling` and `AffineCoupling` classes. These also split their inputs in half by channel and can easily be wrapped by the `InvertibleModuleWrapper`.
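For instance, wrapping a coupling could look roughly like this (a sketch following the MemCNN examples; the sub-network is just a placeholder, and keyword arguments may differ slightly between versions):

```python
import torch
import torch.nn as nn
import memcnn

# placeholder sub-network that operates on half of the channels
class HalfChannelOp(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.seq(x)

# AdditiveCoupling splits its input in half along the channel dimension internally
coupling = memcnn.AdditiveCoupling(Fm=HalfChannelOp(5), Gm=HalfChannelOp(5))
wrapper = memcnn.InvertibleModuleWrapper(fn=coupling, keep_input=True, keep_input_inverse=True)

x = torch.randn(2, 10, 8, 8, requires_grad=True)
y = wrapper.forward(x)       # forward pass through the invertible block
x_rec = wrapper.inverse(y)   # reconstructs the input from the output
```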
@silvandeleemput Sounds great! If you need any help or suggestions, I'd be more than happy to help. I'll give it a shot myself over the next few days.
The impossibility of using `detach_` instead of `detach` above, by the way, happened because index slicing (which I used) only creates a view. One would have to clone the slices before detaching, but I wonder whether that still does what we want -- why do the detaching in-place then?
Can you give me a pointer on why the input needs to be detached in-place while everything else is done with the standard `detach` operation? If detaching in-place is not necessary, I'd suggest not doing it, otherwise people might run into the same problem that I ran into.
> @silvandeleemput Sounds great! If you need any help or suggestions, I'd be more than happy to help. I'll give it a shot myself over the next few days.
Great, good luck! Here are some pointers:
```python
t1, t2 = torch.chunk(input_tensor, 2, dim=1)    # split along the channel dimension
t1, t2 = t1.contiguous(), t2.contiguous()       # chunk returns views, so make them contiguous
operation(t1, t2)
```
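For the inverse direction, concatenating the chunks along the same dimension restores the input, which is what makes the split usable as an invertible operation (a small sketch to complete the picture):

```python
# concatenating along the channel dimension reconstructs the original tensor
reconstructed = torch.cat((t1, t2), dim=1)
assert torch.equal(reconstructed, input_tensor)
```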
Since you need to specify both inputs and weights, you might want to dynamically pass those, like this:
```python
def forward(ctx, fn, fn_inverse, keep_input, num_bwd_passes, num_inputs, *inputs_and_weights):
```
and for the backward pass:
```python
def backward(ctx, *grad_outputs):  # pragma: no cover
    ...
    gradients = torch.autograd.grad(outputs=temp_outputs,
                                    inputs=detached_inputs + tuple(ctx.weights),
                                    grad_outputs=grad_outputs)
    ...
    return (None, None, None, None, None) + gradients
```
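At the call site, the variadic arguments then have to be packed to match that signature; roughly like this (illustrative only, assuming `fn` is an `nn.Module` that holds the weights):

```python
# inputs and weights are flattened into a single *args sequence;
# num_inputs tells the forward pass where the inputs end and the weights begin
inputs = (x1, x2)
weights = tuple(p for p in fn.parameters())
outputs = InvertibleCheckpointFunction.apply(
    fn, fn_inverse, keep_input, num_bwd_passes,
    len(inputs), *(inputs + weights))
```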
Those were good suggestions. I have a working version now; I am preparing some tests and will create a pull request once those are done.
I first didn't want to mess with your function arguments, so I originally went for list inputs (Keras-style) instead of combining inputs and weights as you proposed; the latter ended up being much easier, though. Internally, I now just convert everything to tuples (and, in the case of single arguments, unpack these before returning).
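The tuple conversion could look roughly like this (a sketch with illustrative helper names, not the actual pull-request code):

```python
def _as_tuple(x):
    # wrap single tensors so the rest of the code can always assume tuples
    return x if isinstance(x, tuple) else (x,)

def _maybe_unpack(xs):
    # unpack single-element tuples before returning, keeping the old interface
    return xs[0] if len(xs) == 1 else xs

def call_fn(fn, inputs):
    outputs = _as_tuple(fn(*_as_tuple(inputs)))
    return _maybe_unpack(outputs)
```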
Description
Sorry for the many questions over the last few days, I'd be happy to contribute.
I have a class of networks in which the channels are split at some point, after which several (reversible) layers are applied to each chunk of the channels. Then, these are recombined by concatenating the channels again. Channel splitting and concatenating (with fixed chunk size) are invertible functions and should work really well with this.
A possible implementation is this one:
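Roughly along these lines (a sketch of the idea, with an `inverse` method in the style of MemCNN's invertible modules):

```python
import torch
import torch.nn as nn

class ChannelSplit(nn.Module):
    """Splits a tensor into two halves along the channel dimension."""
    def forward(self, x):
        t1, t2 = torch.chunk(x, 2, dim=1)
        return t1.contiguous(), t2.contiguous()

    def inverse(self, t1, t2):
        return torch.cat((t1, t2), dim=1)

class ChannelConcat(nn.Module):
    """Concatenates two tensors along the channel dimension (inverse of the split)."""
    def forward(self, t1, t2):
        return torch.cat((t1, t2), dim=1)

    def inverse(self, x):
        t1, t2 = torch.chunk(x, 2, dim=1)
        return t1.contiguous(), t2.contiguous()
```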
One could also take the input as a list. The output should, however, always be a tensor or a tuple of tensors, because the `InvertibleCheckpointFunction` is an `autograd.Function` object (whose `forward` method wraps this) and does not accept lists as output types.
In my architecture, the splitting and recombination appear in a nested fashion, so an implementation as a single reversible layer is not possible (as is done with the coupling blocks, where the splitting and recombination happen within the same layer).
With the current implementation, this is not possible. The error thrown is the following:
What I tried
So I already tried converting the code in `revop.py` such that it operates on tuples instead of tensors where needed (e.g. detaching each element of the input tuple). A problem is that the detaching does not work in-place here (no idea why), so one would have to use `detach()` instead of `detach_()`. If I use `detach()`, however, PyTorch will also throw an error:
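For reference, the tuple-based detaching I mean is roughly this (illustrative, not the actual change to `revop.py`):

```python
import torch

def detach_inputs(inputs):
    # operate on a tuple of inputs instead of a single tensor
    if isinstance(inputs, torch.Tensor):
        inputs = (inputs,)
    # detach_() fails on the sliced inputs because they are views,
    # so the out-of-place detach() is used instead
    return tuple(x.detach() for x in inputs)
```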