johnmarktaylor91 / torchlens

Package for extracting and mapping the results of every single tensor operation in a PyTorch model in one line of code.
GNU General Public License v3.0

Use torchlens with ESCNN model #18

Open kalekundert opened 11 months ago

kalekundert commented 11 months ago

I just tried using torchlens on a model built using a library called ESCNN, and I ran into some errors that I'm hoping you can help me with. ESCNN is a relatively niche library for geometric deep learning; think CNNs where the filters are matched in all possible orientations, in addition to all possible locations.

Here's a script that creates a single convolutional layer and tries to visualize it with torchlens:

import torch
import torchlens as tl

from escnn.nn import FieldType, R2Conv, GeometricTensor
from escnn.gspaces import rot2dOnR2

gspace = rot2dOnR2(4)
so2 = gspace.fibergroup

in_type = FieldType(gspace, [so2.trivial_representation])
out_type = FieldType(gspace, [so2.regular_representation])

conv = R2Conv(in_type, out_type, kernel_size=3)
conv.eval()

x = GeometricTensor(
        tensor=torch.randn(1, 1, 5, 5),
        type=in_type,
)

log = tl.log_forward_pass(conv, [x])
print(log)

Just to explain the example a little bit, this is a 2D convolution where any 90° rotation of the filters should be matched. ESCNN requires a more sophisticated concept of "channels" than a normal CNN, and that's what the in_type and out_type variables establish. The input to the convolutional layer is not a normal tensor, but a tensor wrapped in a GeometricTensor object that also keeps track of the associated "channels".

I think it might be helpful to include (and briefly describe) the source code for the R2Conv.forward() method:

def forward(self, input: GeometricTensor):
    assert input.type == self.in_type

    if not self.training:
        _filter = self.filter
        _bias = self.expanded_bias
    else:
        # retrieve the filter and the bias
        _filter, _bias = self.expand_parameters()

    if self.padding_mode == 'zeros':
        output = conv2d(input.tensor, _filter,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups,
                        bias=_bias)
    else:
        output = conv2d(pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode),
                        _filter,
                        stride=self.stride,
                        dilation=self.dilation,
                        groups=self.groups,
                        bias=_bias)

    return GeometricTensor(output, self.out_type, coords=None)

If the model is in training mode, some complicated calculations are performed to get the necessary filter and bias tensors. But if the model is in evaluation mode, as it is in the above example, then these calculations have already happened. Either way, the actual convolution is just done using torch.nn.functional.conv2d. There's nothing really fancy going on under the hood.


When I run the above script, it gets stuck on the following line:

Traceback (most recent call last):
  File "/home/kale/hacking/bugs/torchlens_escnn/torchlens_escnn.py", line 21, in <module>
    log = tl.log_forward_pass(conv, x)
  File "/home/kale/hacking/forks/torchlens/torchlens/user_funcs.py", line 98, in log_forward_pass
    model_history = run_model_and_save_specified_activations(
  File "/home/kale/hacking/forks/torchlens/torchlens/model_history.py", line 7001, in run_model_and_save_specified_activations
    model_history.run_and_log_inputs_through_model(
  File "/home/kale/hacking/forks/torchlens/torchlens/model_history.py", line 1762, in run_and_log_inputs_through_model
    input_args = [copy.deepcopy(arg) for arg in input_args]

This line never completes, and keeps allocating memory until my machine runs out (>16 GB). It seems that deepcopy() is getting caught in an infinite loop, probably while trying to copy in_type (which is a relatively complicated object). My first thought was that there might be a reference cycle, but it seems that deepcopy() automatically handles reference cycles, so that's probably not it. The obvious solution would be to somehow modify the ESCNN objects so that they can be deep-copied, but that might not be a trivial change, and it seems to me that torchlens shouldn't require accommodations by libraries such as this if at all possible.
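For reference, copy.deepcopy() records every object it has already copied in a memo dictionary, so a self-referencing structure is copied once rather than followed forever. A minimal check with a list that contains itself:

```python
import copy

# A list that contains itself: a one-element reference cycle.
original = []
original.append(original)

clone = copy.deepcopy(original)

# The memo maps the original list to its copy, so the cycle is
# reproduced inside the clone instead of being traversed endlessly.
print(clone is not original)  # True
print(clone[0] is clone)      # True
```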

I tried side-stepping this problem by just removing the deepcopy() call. That results in the following stack trace, which I wasn't able to make any sense of. I don't know if this is just the consequence of removing the deep copy, or indicative of some other problem:

Traceback (most recent call last):
  File "/home/kale/hacking/bugs/torchlens_escnn/torchlens_escnn.py", line 21, in <module>
    log = tl.log_forward_pass(conv, [x])
  File "/home/kale/hacking/forks/torchlens/torchlens/user_funcs.py", line 98, in log_forward_pass
    model_history = run_model_and_save_specified_activations(
  File "/home/kale/hacking/forks/torchlens/torchlens/model_history.py", line 7001, in run_model_and_save_specified_activations
    model_history.run_and_log_inputs_through_model(
  File "/home/kale/hacking/forks/torchlens/torchlens/model_history.py", line 1818, in run_and_log_inputs_through_model
    raise e
  File "/home/kale/hacking/forks/torchlens/torchlens/model_history.py", line 1802, in run_and_log_inputs_through_model
    self.output_layers.append(t.tl_tensor_label_raw)
AttributeError: 'Tensor' object has no attribute 'tl_tensor_label_raw'

I'd really like to be able to use torchlens, so I'd appreciate any help you can offer.

kalekundert commented 11 months ago

Also, if you want to install ESCNN to try running this example yourself, there's a gotcha to be aware of. With modern versions of Python, a dependency called lie_learn has to be installed from GitHub instead of PyPI. I think the following commands should work, but let me know if they don't:

$ pip install git+https://github.com/AMLab-Amsterdam/lie_learn
$ pip install escnn

kalekundert commented 11 months ago

I figured out the infinite loop. It was my fault; there was a typo in the script I was running. I was calling tl.log_forward_pass(conv, x) instead of tl.log_forward_pass(conv, [x]). You can even see this in the first stack trace I posted. It turns out that GeometricTensor objects are considered iterable because they implement __getitem__(), but they do so in such a way that the iteration never ends. Sorry for the confusion. I was playing with multiple versions of my test script and got things mixed up a bit.
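Python falls back to the legacy sequence protocol when an object defines __getitem__() but not __iter__(): iteration calls obj[0], obj[1], ... and stops only when IndexError is raised. The sketch below uses hypothetical classes (not ESCNN's actual code) to show both outcomes, plus a hypothetical as_input_list() guard showing how a caller could wrap a lone input instead of iterating over it; it is an illustration, not TorchLens's actual input handling.

```python
from itertools import islice


class EndlessItems:
    """Hypothetical stand-in for an object whose __getitem__ accepts any index."""

    def __getitem__(self, index):
        return index  # never raises IndexError, so iteration never stops


class BoundedItems:
    """Same protocol, but indexing past the end raises IndexError."""

    def __init__(self, items):
        self._items = list(items)

    def __getitem__(self, index):
        return self._items[index]  # IndexError at the end stops iteration


def as_input_list(inputs):
    """Hypothetical guard: wrap a single non-list input rather than iterate it."""
    return list(inputs) if isinstance(inputs, (list, tuple)) else [inputs]


print(list(BoundedItems("abc")))           # ['a', 'b', 'c']
print(list(islice(EndlessItems(), 5)))     # [0, 1, 2, 3, 4]; list(EndlessItems()) would never return
print(len(as_input_list(EndlessItems())))  # 1
```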

So there's nothing wrong with the deep copy, but I'm still running into the 'Tensor' object has no attribute 'tl_tensor_label_raw' stack trace.

johnmarktaylor91 commented 11 months ago

Thanks so much for describing this so carefully! Looks like an interesting case I didn't think of from the armchair; I'll check it out pronto.

johnmarktaylor91 commented 11 months ago

Okay, I think I found the issue (bear with me): in r2convolution.py, there is the following seemingly innocuous line:

from torch.nn.functional import conv2d, pad

Later on in that file, it calls conv2d and pad under those names, instead of calling them as torch.nn.functional.conv2d and torch.nn.functional.pad. This ends up mattering because TorchLens works by replacing all the functions in the PyTorch namespace with modified versions of themselves, such that they log their results whenever they get called. But if you import a function directly (e.g., as conv2d), the name in your module is bound to the original function object at import time. TorchLens isn't able to modify that reference (it's "dangling", no longer attached to the torch namespace), so the function doesn't log its results as it should.

This is a subtle issue that will require a bit of refactoring to fix. In the meantime, I think the easiest fix would be to replace the line "from torch.nn.functional import conv2d, pad" with "import torch.nn.functional as f", and then call f.conv2d and f.pad later on.

Apologies for the headache over something so silly--I'll bump this up my priority list to fix more thoroughly, and let me know if the stopgap solution doesn't work.

kalekundert commented 11 months ago

Thanks for looking into this so quickly! I'm traveling today, but I'll check if the quick fix works for me as soon as I get the chance.

kalekundert commented 11 months ago

Sorry it took me a few days to get back to you on this, but I can confirm that the quick fix you suggested works for me. Thanks again for such a quick response! I'll leave the issue open in case you're planning to do some refactoring to accommodate "from torch.nn.functional import ..." imports (which seems to me like it would be really difficult), but feel free to close the issue if you're not.

johnmarktaylor91 commented 11 months ago

Delighted to hear that it worked :) I haven’t figured out a totally general fix yet, and indeed it looks like it’s not going to be easy, but I’ll brainstorm some more…

johnmarktaylor91 commented 11 months ago

I think this issue is impossible to solve in the totally general case, so I'm closing it unless someone has an idea.

kalekundert commented 11 months ago

I can imagine a few ways to wrap the PyTorch functions regardless of how they're imported. I don't know how torchlens works at all, so I don't know if any of these approaches would really solve the problem, but I figure they're worth mentioning:

johnmarktaylor91 commented 11 months ago

Wait, you are a genius :] I hadn't thought of any of these options, thanks so much for laying them out and describing them so carefully. For reference, the way TorchLens works is by attaching a rather elaborate decorator to all the functions in the PyTorch namespace such that their results get logged every time they're called. Currently, this decoration is done when a TorchLens function (e.g., log_forward_pass) is called, not when TorchLens is imported, so that the functions are in their "clean" original states outside of the TorchLens function calls.

The question is, which of these options makes for the minimum confusion from the user's perspective and wouldn't slow down performance too much. What about something like this: have the decoration occur when TorchLens is imported, and do this in an "exhaustive" way (potentially involving gc.get_referrers; e.g., get the references, and check for any references outside of the torch namespace), but make the decorator "silent" unless some attribute attached to the function is toggled on? This way, the decoration step only has to occur once (on import), it should work no matter the order of the imports, and all the torch functions should behave normally outside of the TorchLens function calls. If torch hasn't been imported yet, then all the functions will get decorated, so all subsequent imports will get the decorated version; if torch has already been imported, it'll catch all the references.
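A minimal sketch of this proposal, assuming hypothetical helpers make_silent_logger() and patch_all_references() and using a plain-Python function in place of a torch function. The wrapper is installed once and stays silent until its logging_enabled attribute is switched on; gc.get_referrers() is used, as suggested above, to find dictionaries (including module namespaces that did a from-import) that still point at the original. Note that gc.get_referrers() is intended mainly for debugging, and this sketch only rewrites dict entries, so references held in closures, default arguments, or containers such as lists would still be missed:

```python
import gc
import types


def make_silent_logger(func):
    """Wrap func; record calls only while wrapper.logging_enabled is True."""
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        if wrapper.logging_enabled:  # the toggle attribute
            wrapper.calls.append((func.__name__, result))
        return result
    wrapper.logging_enabled = False  # silent by default
    wrapper.calls = []
    wrapper.__name__ = func.__name__
    return wrapper


def patch_all_references(func):
    """Replace func with one silent wrapper in every dict that refers to it."""
    wrapper = make_silent_logger(func)
    for referrer in gc.get_referrers(func):
        if type(referrer) is dict:  # module namespaces are plain dicts
            for key, value in list(referrer.items()):
                if value is func:
                    referrer[key] = wrapper
    return wrapper


def conv2d(x):
    return x * 2  # placeholder arithmetic, not a real convolution


# Hypothetical stand-ins: a torch-like namespace, and a library module that
# bound the function directly, as `from torch.nn.functional import conv2d` would.
fake_functional = types.ModuleType("fake_functional")
fake_library = types.ModuleType("fake_library")
fake_functional.conv2d = conv2d
fake_library.conv2d = conv2d
del conv2d  # leave only the two module-level bindings

wrapper = patch_all_references(fake_functional.conv2d)

fake_library.conv2d(3)          # behaves normally and records nothing
wrapper.logging_enabled = True  # e.g. only while log_forward_pass runs
fake_library.conv2d(3)          # recorded, despite the direct binding
wrapper.logging_enabled = False
print(wrapper.calls)            # [('conv2d', 6)]
```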

If this sounds reasonable I'll make this a priority in the next TorchLens update. Thanks again for the great suggestions, there's no way I would have thought of these on my own.

kalekundert commented 11 months ago

Sorry for the slow reply, I kinda let this issue fall off my radar. The big question (as I see it) is whether to do the decoration automatically or to require the user to call a function that does it. Either way, I imagine that the decorator would work pretty much like you described. Some pros/cons:

Do it automatically:

Do it manually:

I'm not sure which approach I prefer.