Open jackdent opened 4 months ago
Hey @jackdent, this is not supported, and as far as I know we don't document any examples of nesting LightningModules like that. Since this goes against the design principles behind LightningModules, we can't really "fix" this. Even if the trainer reference were set, you would still run into issues from a design perspective, where hook calling is suddenly no longer well defined.
The LightningModule is meant to be the top-level module that organizes your code. Nesting it does not make sense conceptually. I'm guessing here, but the reason you are doing this could be that you'd like to reuse some code you wrote and inherit from it. This by itself is not a bad idea, but it can be achieved without making your child modules LightningModules. I recommend refactoring your code so that your LightningModule is the top-level module.
@awaelchli thank you for the response -- it's reasonable that this behaviour is unsupported. However, it is worth pointing out that making a LightningModule a direct child of another LightningModule does seem to be well supported (e.g. the code snippet I shared above handles this case explicitly, and sets the trainer on all children).
Perhaps it's better to take a step back and explain what I'm trying to accomplish. I want to be able to log intermediate values (e.g. the mean and stddev of some activations in an nn.Module). Having the ability to introspect operations inside nn.Modules is extremely important, since we regularly need to look into the internal state of the model to find bugs (e.g. we want to track the flow of grad norms through our network).
Since the Lightning logger is only defined on the LightningModule, not on nn.Modules, we can't call log inside an nn.Module. Right now, our solution is to pass the logger as an argument to every forward method in all our child submodules, but that's fairly inelegant. Ideally, we'd just be able to call lightning_logger.log from anywhere inside our child modules -- creating a Logger class that inherits from LightningModule and setting it as a child on our nn.Modules was my attempt to achieve the desired behaviour, but I'm open to better solutions if you can think of any (e.g. should Lightning expose a singleton logger?).
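Roughly, the structure I'm describing looks like this (a minimal sketch; module and layer names are just for illustration):

import torch.nn as nn
from lightning.pytorch import LightningModule


class ActivationLogger(LightningModule):
    # grandchild LightningModule whose only job is to expose self.log
    def log_stats(self, name, tensor):
        self.log(f"{name}_mean", tensor.mean())
        self.log(f"{name}_std", tensor.std())


class Encoder(nn.Module):
    # plain nn.Module child
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)
        self.monitor = ActivationLogger()  # grandchild

    def forward(self, x):
        out = self.linear(x)
        # warns, because the grandchild never receives the trainer reference
        self.monitor.log_stats("encoder", out)
        return out


class Task(LightningModule):
    # parent LightningModule
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()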
I don't know why that code for assigning the trainer to children is there, but I have never seen a use case where it can be exploited to great benefit.
Having read about your use case (thanks for the context), it only makes me more confident that this is the wrong approach. Suppose what you proposed were supported in Lightning. You'd have a PyTorch model with nested layers, some of which are LightningModules. While you might be able to implement your logging strategy this way, what happens when you're done training? Very likely you want to use the model by loading a checkpoint. But now your model contains lots of code that is unrelated to your inference/deployment pipeline. In fact, at inference there won't be a trainer object defined! So you'll have to change/update your model code after the fact and guard all your logging calls. All of this will unnecessarily complicate your model code.
This is exactly why model code should not be mixed with orchestration code: it's a trap. The better way, the Lightning way, is to separate training orchestration code from model code (the definition of your forward). That's what the LightningModule, as the top-level system, is there for:
import torch.nn as nn
from lightning.pytorch import LightningModule


# The pure nn.Module. Contains PURE modelling code, no training, logging or testing
class MyModel(nn.Module):
    def forward(self, x):
        ...


# The LightningModule contains code for all interactions with your model,
# for example training, evaluation, or inference
class MyTask(LightningModule):
    def __init__(self, ...):
        super().__init__()
        self.model = MyModel()

    # special hooks


# Later on, when we're done training, we can just use MyModel directly
# and throw away the LightningModule (no need for it anymore)
model = MyModel()
model.load_state_dict(...)
model(input)
This is the high-level approach behind Lightning's design principles.
To achieve this:
"Having the ability to introspect operations inside nn.Modules is extremely important, since we regularly need to look into the internal state of the model to find bugs (e.g. we want to track the flow of grad norms through our network)."
there are other ways. I see at least three:
a) Return the debugging information of your intermediate outputs as metadata from your forward call, perhaps as a dict. Collect that output in your training_step(), then process and log it there (see the sketch after this list). Pro: Maximum flexibility, no orchestration dependencies. Con: Still some debugging-related code tied to your model.
b) Use forward or backward hooks (a feature in PyTorch) to collect your intermediate module outputs and do something with them. For example, I've done that in the past for plotting histograms of intermediate outputs. Pro: your model code remains completely clean of any debugging code! Con: Less flexible, more code to write.
c) If you can't avoid using a logger in your PyTorch module directly, then I suggest passing it to __init__ or saving it as a reference, rather than passing it through forward. You can access the logger in your LightningModule as self.logger. Pro: Closest to what you've been doing so far, no new Lightning features required. Con: The logger is tied to your model.
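To make option a) concrete, here is a rough sketch along the lines of the MyModel/MyTask skeleton above (layer names, shapes and the loss are made up for illustration):

import torch.nn as nn
from lightning.pytorch import LightningModule


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 16)
        self.head = nn.Linear(16, 1)

    def forward(self, x):
        h = self.encoder(x)
        out = self.head(h)
        # return debugging info as plain metadata alongside the prediction
        stats = {"encoder_mean": h.mean(), "encoder_std": h.std()}
        return out, stats


class MyTask(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MyModel()

    def training_step(self, batch, batch_idx):
        x, y = batch
        out, stats = self.model(x)
        loss = nn.functional.mse_loss(out, y)
        # logging happens here, where self.log is well defined
        self.log_dict(stats)
        self.log("train_loss", loss)
        return loss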
I hope one of these fits your needs and you can give it some thought.
Thank you for the comprehensive answer @awaelchli -- using forward/backward hooks, and passing the trainer/logger through as a closure when defining those hooks at the level of the task, is a great solution (far better than the direction I was going in). Your data monitor snippet is an extremely helpful reference implementation.
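For anyone who finds this later, what we ended up with looks roughly like the following (a simplified sketch; names are illustrative): the hooks are registered from the LightningModule, so self.log is captured in the closures and the nn.Module itself stays free of any logging code.

import torch.nn as nn
from lightning.pytorch import LightningModule


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 16)
        self.head = nn.Linear(16, 1)

    def forward(self, x):
        return self.head(self.encoder(x))


class MyTask(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MyModel()
        # register hooks at the task level; self.log is captured in the closures
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                module.register_forward_hook(self._make_stats_hook(name))

    def _make_stats_hook(self, name):
        def hook(module, inputs, output):
            if self.training:  # only log while in training mode
                self.log(f"{name}_mean", output.mean())
                self.log(f"{name}_std", output.std())
        return hook

    def training_step(self, batch, batch_idx):
        ...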
Bug description
Suppose I have a LightningModule (parent) that contains an nn.Module (child), which in turn contains another LightningModule (grandchild). Calling .log inside the grandchild LightningModule triggers a warning. The trainer is only set on the direct children of the parent LightningModule, not on all of its descendants, since the trainer.setter uses self.children() rather than self.modules(): https://github.com/Lightning-AI/pytorch-lightning/blob/3730e980e388c23f7e9d1f535793e8d614633362/src/lightning/pytorch/core/module.py#L221-L226
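For illustration, children() stops at the direct submodules while modules() recurses through all descendants (standard PyTorch behaviour, toy example):

import torch.nn as nn

parent = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))

# children() yields only the immediate submodules
print([type(m).__name__ for m in parent.children()])
# ['Linear', 'Sequential']

# modules() yields the module itself plus every descendant, recursively
print([type(m).__name__ for m in parent.modules()])
# ['Sequential', 'Linear', 'Sequential', 'Linear']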
What version are you seeing the problem on?
master
How to reproduce the bug
Error messages and logs
Environment
Current environment
* CUDA:
  - GPU:
    - NVIDIA A100-SXM4-80GB
  - available: True
  - version: 12.1
* Lightning:
  - lightning: 2.2.1
  - lightning-utilities: 0.11.2
  - pytorch-lightning: 2.2.1
  - torch: 2.3.1
  - torchmetrics: 1.3.2
  - torchvision: 0.18.1
* System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.11.9
  - release: 5.15.0-113-generic
  - version: #123-Ubuntu SMP Mon Jun 10 08:16:17 UTC 2024
More info
No response
cc @carmocca @justusschock @awaelchli @borda