Open kalekundert opened 1 year ago
Also, if you want to install ESCNN to try running this example yourself, there's a gotcha to be aware of. A dependency called lie_learn has to be installed from GitHub instead of PyPI on modern versions of Python. I think the following commands should work, but let me know if they don't:
$ pip install git+https://github.com/AMLab-Amsterdam/lie_learn
$ pip install escnn
I figured out the infinite loop. It was my fault; there was a typo in the script I was running. I was calling tl.log_forward_pass(conv, x) instead of tl.log_forward_pass(conv, [x]). You can even see this by looking at the stack traces I posted. It turns out that GeometricTensor objects are considered iterable because they implement __getitem__(), but they do so in such a way that the iteration never ends. Sorry for the confusion. I was playing with multiple versions of my test script, and got things mixed up a bit.
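For anyone curious why an object with only __getitem__() counts as iterable, here is a minimal sketch. EndlessSequence is a made-up stand-in, not ESCNN's GeometricTensor; the assumption is simply that __getitem__() never raises IndexError, which is what lets Python's legacy sequence-iteration protocol run forever.

```python
from itertools import islice


class EndlessSequence:
    """Hypothetical stand-in: defines __getitem__ but not __iter__."""

    def __getitem__(self, index):
        # Never raises IndexError, so iteration has no stopping point.
        return index * 2


seq = EndlessSequence()

# iter() falls back to calling __getitem__ with 0, 1, 2, ...
iterator = iter(seq)

# Take only the first five items; a plain `for` loop would never finish.
first_five = list(islice(iterator, 5))
print(first_five)
```

Any code that loops over its argument (for example, to treat it as a list of inputs) would hang on an object like this, which is consistent with what I saw when passing x instead of [x].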
So there's nothing wrong with the deep copy, but I'm still running into the 'Tensor' object has no attribute 'tl_tensor_label_raw' stack trace.
Thanks so much for describing this so carefully! Looks like an interesting case I didn’t think of from the armchair, I’ll check it out pronto.
Okay, I think I found the issue (bear with me): in r2convolution.py, there is the following seemingly innocuous line:
from torch.nn.functional import conv2d, pad
Later on in that file, it calls conv2d and pad under those names, instead of calling them as torch.nn.functional.conv2d. This ends up mattering because TorchLens works by replacing all the functions in the PyTorch namespace with modified versions of themselves, such that they log their results whenever they get called. But, if you just import a function as (e.g.) conv2d, TorchLens isn't able to modify that instance of the function (since it's "dangling", no longer attached to the torch namespace), so it doesn't log its results as it should.
This is a subtle issue that will require a bit of refactoring to fix. In the meantime, I think the easiest fix would be to remove from torch.nn.functional import conv2d, pad and replace it with import torch.nn.functional as f, then later on call f.conv2d and f.pad.
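To make the "dangling" reference problem concrete, here is a minimal sketch. It uses a made-up module called fake_functional in place of torch.nn.functional and a toy logging wrapper in place of the real TorchLens decorator; none of these names come from TorchLens or PyTorch.

```python
import types

# Hypothetical stand-in for a library namespace such as torch.nn.functional.
fake_functional = types.ModuleType("fake_functional")


def conv2d(x):
    return x + 1


fake_functional.conv2d = conv2d

# Equivalent of `from fake_functional import conv2d`:
# a direct reference to the function object, taken before any patching.
direct_conv2d = fake_functional.conv2d

calls = []


def logged(fn):
    """Toy wrapper that records each call's name and result."""
    def wrapper(*args, **kwargs):
        result = fn(*args, **kwargs)
        calls.append((fn.__name__, result))
        return result
    return wrapper


# Patch the namespace attribute, as a monkeypatching logger would.
fake_functional.conv2d = logged(fake_functional.conv2d)

direct_conv2d(1)           # bypasses the wrapper, nothing is logged
fake_functional.conv2d(1)  # goes through the wrapper and is logged
print(calls)
```

Only the call made through the namespace attribute shows up in calls, which is why switching to f.conv2d and f.pad works as a stopgap.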
Apologies for the headache over something so silly--I'll bump this up my priority list to fix more thoroughly, and let me know if the stopgap solution doesn't work.
Thanks for looking into this so quickly! I'm traveling today, but I'll check if the quick fix works for me as soon as I get the chance.
Sorry it took me a few days to get back to you on this, but I can confirm that the quick-fix you suggested works for me. Thanks again for such a quick response! I'll leave the issue open in case you're planning to do some refactoring to accommodate from torch.nn.functional import ... imports (which seems like it would be really difficult to me), but feel free to close the issue if you're not.
Delighted to hear that it worked :) I haven’t figured out a totally general fix yet, and indeed it looks like it’s not going to be easy, but I’ll brainstorm some more…
I think this issue is impossible to solve in the totally general case, so closing this issue unless someone has an idea.
I can imagine a few ways to wrap the pytorch functions regardless of how they're imported. I don't know how torchlens works at all, so I don't know if any of these approaches would really solve the problem, but I figure they're worth mentioning:
Provide a function that the user must call before importing any libraries that would import torch. This function would import torch itself and replace the necessary pytorch functions with wrappers that can be further manipulated later. The nice thing about this approach is that it is explicit, but it requires the user to understand the need to call such a function.
Set up an import hook to perform the necessary monkeypatching as soon as torch is first imported. See here for some ways to do this. (I think the MetaPathFinder approach is better than the __import__ approach, for what it's worth.) This is more "magical" than having the user call a function, but it just requires torchlens to be imported before torch, and it would even be possible to issue a warning if the user imports the libraries in the wrong order.
The nuclear option: Use gc.get_referrers() to get the list of all objects that hold a reference to each pytorch function, and then replace each of those references. This would be very hard to do reliably, because different kinds of references (e.g. dicts, lists, sets, etc.) would have to be replaced differently. But it might not be too hard to cover the most common use cases. There's an out-of-date library called pyjack that tries to do this; it doesn't support python3, but you could look at its source to see how this kind of thing is done more specifically. At the very least, it might be worth using gc.get_referrers() to warn the user if it looks like there are copies of the pytorch functions that can't be wrapped.
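As a rough illustration of the gc.get_referrers() idea, here is a minimal sketch that only handles references stored as dict values (which covers module namespaces and from-imports). The names fake_lib, user_namespace, make_logging_wrapper, and replace_dict_references are all made up for this example; this is not how torchlens or pyjack actually does it.

```python
import gc
import types

# Hypothetical stand-in for a library module such as torch.nn.functional.
fake_lib = types.ModuleType("fake_lib")


def conv2d(x):
    return x + 1


fake_lib.conv2d = conv2d

# Simulates the namespace of a module that did `from fake_lib import conv2d`.
user_namespace = {"conv2d": fake_lib.conv2d}

calls = []


def make_logging_wrapper(fn):
    """Toy wrapper; keeps the original in a closure, not in a dict."""
    def wrapper(*args, **kwargs):
        calls.append(fn.__name__)
        return fn(*args, **kwargs)
    return wrapper


def replace_dict_references(func, make_replacement):
    """Swap func for a wrapper in every dict that holds it as a value."""
    replacement = make_replacement(func)
    count = 0
    for referrer in gc.get_referrers(func):
        if isinstance(referrer, dict):
            # Snapshot the items so the dict can be modified safely.
            for key, value in list(referrer.items()):
                if value is func:
                    referrer[key] = replacement
                    count += 1
    return count


replaced = replace_dict_references(fake_lib.conv2d, make_logging_wrapper)
result = user_namespace["conv2d"](1)
print(replaced, result, calls)
```

References held in lists, tuples, closures, or default arguments are not touched here, which is exactly the part that would be hard to do reliably.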
Wait, you are a genius :] I hadn't thought of any of these options, thanks so much for laying them out and describing them so carefully. For reference, the way TorchLens works is by attaching a rather elaborate decorator to all the functions in the PyTorch namespace such that their results get logged every time they're called. Currently, this decoration is done when a TorchLens function (e.g., log_forward_pass) is called, not when TorchLens is imported, so that the functions are in their "clean" original states outside of the TorchLens function calls.
The question is, which of these options makes for the minimum confusion from the user's perspective and wouldn't slow down performance too much. What about something like this: have the decoration occur when TorchLens is imported, and do this in an "exhaustive" way (potentially involving gc.get_referrers; e.g., get the references, and check for any references outside of the torch namespace), but make the decorator "silent" unless some attribute attached to the function is toggled on or off? This way, the decoration step only has to occur once (on import), it should work no matter the order of the imports, and all the torch functions should behave normally outside of the TorchLens function calls. If torch hasn't been imported yet then all the functions will get decorated so all subsequent imports will get the decorated version, and if torch has already been imported it'll catch all the references.
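Roughly, the "silent unless toggled" decorator could look like the sketch below. This is only an illustration of the idea under my own assumptions: LoggingState, silent_until_enabled, and logging_enabled are hypothetical names, not existing TorchLens code, and a plain Python function stands in for a torch function.

```python
import functools
from contextlib import contextmanager


class LoggingState:
    """Hypothetical shared switch that the wrappers consult."""
    active = False
    records = []


def silent_until_enabled(fn):
    """Wrap fn once; it only logs while LoggingState.active is True."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        result = fn(*args, **kwargs)
        if LoggingState.active:
            LoggingState.records.append((fn.__name__, result))
        return result
    return wrapper


@contextmanager
def logging_enabled():
    """Turn logging on for the duration of a `with` block."""
    LoggingState.active = True
    try:
        yield LoggingState.records
    finally:
        LoggingState.active = False


@silent_until_enabled
def add_one(x):
    return x + 1


add_one(1)                  # behaves normally, nothing recorded
with logging_enabled() as records:
    add_one(2)              # recorded while the switch is on
add_one(3)                  # silent again
print(records)
```

The per-call cost outside of logging is one attribute check, so the decorated functions should behave almost exactly like the originals when the switch is off.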
If this sounds reasonable I'll make this a priority in the next TorchLens update. Thanks again for the great suggestions, there's no way I would have thought of these on my own.
Sorry for the slow reply, I kinda let this issue fall off my radar. The big question (as I see it) is whether or not to do the decoration automatically, or to require the user to call a function that does it. Either way, I imagine that the decorator would work pretty much like you described. Some pros/cons:
Do it automatically:
torch is imported before torchlens.
Do it manually:
I'm not sure which approach I prefer.
I just tried using torchlens on a model built using a library called ESCNN, and I ran into some errors that I'm hoping you can help me with. ESCNN is a relatively niche library for geometric deep learning; think CNNs where the filters are matched in all possible orientations, in addition to all possible locations.
Here's a script that creates a single convolutional layer and tries to visualize it with torchlens:
Just to explain the example a little bit, this is a 2D convolution where any 90° rotation of the filters should be matched. ESCNN requires a more sophisticated concept of "channels" than a normal CNN, and that's what the in_type and out_type variables establish. The input to the convolutional layer is not a normal tensor, but a tensor wrapped in a GeometricTensor object that also keeps track of the associated "channels".
I think it might be helpful to include (and briefly describe) the source code for the R2Conv.forward() method:
If the model is in training mode, some complicated calculations are performed to get the necessary filter and bias tensors. But if the model is in evaluation mode, as it is in the above example, then these calculations have already happened. Either way, the actual convolution is just done using torch.nn.functional.conv2d. There's nothing really fancy going on under the hood.
When I run the above script, it gets stuck on the following line:
This line never completes, and keeps allocating memory until my machine runs out (>16 GB). It seems that deepcopy() is getting caught in an infinite loop, probably while trying to copy in_type (which is a relatively complicated object). My first thought was that there might be a reference cycle, but it seems that deepcopy() automatically handles reference cycles, so that's probably not it. The obvious solution would be to somehow modify the ESCNN objects so that they can be deep-copied, but that might not be a trivial change, and it seems to me that torchlens shouldn't require accommodations by libraries such as this if at all possible.
I tried side-stepping this problem by just removing the deepcopy() call. That results in the following stack trace, which I wasn't able to make any sense of. I don't know if this is just the consequence of removing the deep copy, or indicative of some other problem:
I'd really like to be able to use torchlens, so I'd appreciate any help you can offer.