spfrommer / torchexplorer

Interactively inspect module inputs, outputs, parameters, and gradients.
https://spfrommer.github.io/torchexplorer/
Apache License 2.0
310 stars 22 forks

Add histograms for intermediate tensors in a `forward` call #32

Closed spfrommer closed 10 months ago

spfrommer commented 10 months ago

Imagined API is something like:

```python
def forward(self, x):
    y = x + 2
    y = torchexplorer.attach(y, self, name='intermediate_y', log=['val', 'grad', 'grad_norm'])
    return y * 4
```

Then in the associated panel, below the i/o histograms and above the param histograms, you'd get the histograms for this particular intermediate value.

I think this shouldn't require any major rearchitecting. Just create a dummy nn.Module with the correct hooks that adds to the histograms in self.torchexplorer_metadata.
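A rough sketch of the dummy-module idea, using only standard PyTorch hooks. `TensorTap`, `record_value`, and the `recorded` dict are hypothetical names for illustration; here `recorded` stands in for the histograms kept in `self.torchexplorer_metadata`, whose real structure is not shown in this thread.

```python
import torch
from torch import nn


class TensorTap(nn.Module):
    """Hypothetical pass-through module; it exists only so hooks can observe a tensor."""

    def forward(self, x):
        return x


# Stand-in for the histogram storage (assumption: the real metadata differs).
recorded = {"val": [], "grad_norm": []}


def record_value(module, inputs, output):
    # Forward hook: keep a detached copy of the tensor passing through.
    recorded["val"].append(output.detach().clone())
    if output.requires_grad:
        # Tensor hook: fires during backward with the gradient w.r.t. `output`.
        output.register_hook(
            lambda grad: recorded["grad_norm"].append(grad.norm().item())
        )


tap = TensorTap()
tap.register_forward_hook(record_value)

x = torch.ones(3, requires_grad=True)
y = tap(x + 2)             # the intermediate tensor passes through unchanged
(y * 4).sum().backward()   # triggers the tensor hook registered above
```

After the backward pass, `recorded["val"]` holds the intermediate value and `recorded["grad_norm"]` holds the norm of its gradient, which roughly corresponds to the `'val'` and `'grad_norm'` entries of the imagined `log` argument.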

spfrommer commented 10 months ago

The best way to do this is to have attach just add a submodule with the name torchexplorer_attach_NAME. The submodule should also have an attribute torchexplorer_attach_name=NAME. This should happen in the following order:

1) On the first hook invocation, any hook applied to a module should be stored in that module's .torchexplorer_metadata.hooks, which is a list.
2) Then, when torchexplorer.attach is executed, it should create a module and add it to self, first checking whether the name already exists and throwing an error if it does. It should apply to the new module all the hooks that it finds in self.torchexplorer_metadata.hooks.
3) In layout.py: 272, check whether node.module has the attribute torchexplorer_attach_name, and if so use that as the label. Add a prefix [A] for attach, which we can later use in the interface to make it show up as "Tensor" instead of "Input/Output".
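The three steps above could be sketched roughly as follows. This is not the library's implementation: `Metadata`, `ATTACH_PREFIX`, `node_label`, and `AddThenScale` are hypothetical names, `attach` here omits the `log` argument, and the hooks list is filled by hand instead of on the first hook invocation.

```python
import torch
from torch import nn

ATTACH_PREFIX = "torchexplorer_attach_"


class Metadata:
    """Hypothetical stand-in for the real torchexplorer_metadata object."""

    def __init__(self):
        self.hooks = []  # step 1: forward hooks that were applied to the parent


def attach(tensor, module, name):
    """Sketch of step 2: route `tensor` through a named pass-through submodule."""
    submodule_name = ATTACH_PREFIX + name
    if hasattr(module, submodule_name):
        raise ValueError(f"attach name already in use: {name!r}")
    tap = nn.Identity()
    tap.torchexplorer_attach_name = name
    # Give the new submodule the same hooks as its parent.
    for hook in module.torchexplorer_metadata.hooks:
        tap.register_forward_hook(hook)
    module.add_module(submodule_name, tap)
    return tap(tensor)


def node_label(module, default):
    """Sketch of step 3: attached tensors get an "[A] " prefixed label."""
    attach_name = getattr(module, "torchexplorer_attach_name", None)
    return f"[A] {attach_name}" if attach_name is not None else default


class AddThenScale(nn.Module):
    def __init__(self):
        super().__init__()
        self.torchexplorer_metadata = Metadata()

    def forward(self, x):
        y = x + 2
        y = attach(y, self, name="intermediate_y")
        return y * 4


seen = []  # values observed by the hook, in place of real histograms
model = AddThenScale()
model.torchexplorer_metadata.hooks.append(
    lambda mod, inputs, output: seen.append(output.detach().clone())
)
out = model(torch.zeros(2))
tap = model.torchexplorer_attach_intermediate_y
label = node_label(tap, default="Identity")
```

As written, a second forward pass through `model` would hit the duplicate-name check, so a fuller version would presumably need to reuse the existing submodule on later calls.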