johnmarktaylor91 / torchlens

Package for extracting and mapping the results of every single tensor operation in a PyTorch model in one line of code.
GNU General Public License v3.0

Exposing ModelHistory and TensorLogEntry for Better Usability in Torchlens #22

Closed: drwiner closed this issue 3 months ago

drwiner commented 3 months ago

I am currently using your library and have found it very useful. However, I've run into one spot where I think its usability could be improved.

Torchlens functions like log_forward_pass return objects of type torchlens.model_history.ModelHistory, which is essentially a list of TensorLogEntry objects. These classes are not currently exposed at the top level of the package, which makes it awkward to use ModelHistory as a declared type in functions like:

```python
import torchlens as tl
def method_call(model_history: tl.ModelHistory) -> tl.TensorLogEntry:
    pass
```

To improve this, I suggest adding the following line to the bottom of the __init__.py file:

```python
from torchlens.model_history import ModelHistory, TensorLogEntry
```

This line would go after the existing imports of the user-facing functions:

""" Top level package: make the user-facing functions top-level, rest accessed as submodules.
"""
from torchlens.user_funcs import (
    log_forward_pass,
    show_model_graph,
    validate_saved_activations,
    validate_batch_of_models_and_inputs,
)
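To illustrate what the proposed re-export does, here is a minimal, self-contained sketch. It builds a throwaway package (`demo_pkg`, a hypothetical stand-in for torchlens, not its real code) with a `model_history` submodule and an `__init__.py` that re-exports its two classes, then checks that the top-level names are the very same objects as the submodule's. The `__all__` list is an optional extra that is not part of the proposal above.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Build a throwaway package on disk whose layout loosely mirrors torchlens:
# a submodule that defines the classes, and an __init__.py that re-exports them.
# "demo_pkg" and its empty class bodies are hypothetical stand-ins.
root = Path(tempfile.mkdtemp())
pkg = root / "demo_pkg"
pkg.mkdir()
(pkg / "model_history.py").write_text(
    "class TensorLogEntry:\n"
    "    pass\n"
    "\n"
    "class ModelHistory:\n"
    "    pass\n"
)
(pkg / "__init__.py").write_text(
    '"""Top level package: re-export the result types."""\n'
    "from demo_pkg.model_history import ModelHistory, TensorLogEntry\n"
    "\n"
    '__all__ = ["ModelHistory", "TensorLogEntry"]\n'
)

# Make the temporary directory importable and load the package.
sys.path.insert(0, str(root))
importlib.invalidate_caches()
demo_pkg = importlib.import_module("demo_pkg")

# The re-exported names are the same class objects as in the submodule,
# so annotations written either way refer to identical types.
print(demo_pkg.ModelHistory is demo_pkg.model_history.ModelHistory)  # prints True
```

Because the re-export binds the same objects rather than copies, existing code that imports from the submodule keeps working unchanged.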

Note that the code still runs without this change, but the code inspection in IDEs such as Visual Studio and PyCharm reports the error "Cannot find reference 'model_history' in '__init__.py'". Exposing ModelHistory and TensorLogEntry at the top level would let developers use these types directly, making it easier to integrate this library into programs that build on it.
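Until such a re-export exists, one possible caller-side workaround is sketched below, assuming the classes live in `torchlens.model_history` as described in this issue. The import sits behind `typing.TYPE_CHECKING`, so only a type checker or IDE resolves it, and the annotations are written as strings so nothing is looked up when the function is defined. The body is a placeholder, as in the original example.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers and IDE inspections; never executed
    # at runtime, so this module imports even where torchlens is absent.
    # Module path assumed from the issue text: torchlens.model_history.
    from torchlens.model_history import ModelHistory, TensorLogEntry


def method_call(model_history: "ModelHistory") -> "TensorLogEntry":
    # String annotations are not evaluated at definition time, so the names
    # need not exist at runtime. Placeholder body, as in the example above.
    ...
```

This avoids the inspection error without touching the library, at the cost of a more verbose import block in every file that needs the types.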

Thank you for your time and for the great work you've done on this library.

Best regards, David