TylerYep / torchinfo

View model summaries in PyTorch!
MIT License

Params and MACs Unit Specifier #183

Closed richardtml closed 1 year ago

richardtml commented 2 years ago

It would be very useful to have a way to specify the units (M, G, etc.) in which the number of parameters and MACs are reported. This could help quickly compare different architectures.

I'm thinking of something like adding params_units and macs_units arguments to the summary() function, with a default value of 'auto' to preserve the current behavior.

TylerYep commented 2 years ago

Interesting idea, I'll definitely consider this feature; it sounds like a good addition. Open to accepting PRs adding this functionality!

richardtml commented 2 years ago

Thank you for the quick reply, @TylerYep. I could try to add this feature if you don't mind.

TylerYep commented 2 years ago

Go for it! As you work on it, we can discuss the API design, naming, etc.

richardtml commented 1 year ago

An approach could be adding the following parameters to summary():

        params_unit (str):
                "A" : Auto, infers a suitable unit
                ""  : No conversion
                "M" : To millions
                "G" : To billions
                "T" : To trillions
                Default: "A".

        macs_unit (str):
                "A" : Auto, infers a suitable unit
                ""  : No conversion
                "M" : To millions
                "G" : To billions
                "T" : To trillions
                Default: "A".
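The specifiers above could be implemented with a small conversion helper. A minimal sketch of the idea (the names to_readable and CONVERSION_FACTORS are illustrative, not torchinfo's actual API):

```python
# Hypothetical helper sketching the proposed unit specifiers.
CONVERSION_FACTORS = {"": 1, "M": 1e6, "G": 1e9, "T": 1e12}

def to_readable(num: float, unit: str = "A") -> tuple[str, float]:
    """Return (unit_label, scaled_value) for a raw parameter or MACs count."""
    if unit == "A":  # auto: pick the largest unit that keeps the value >= 1
        for label in ("T", "G", "M"):
            if num >= CONVERSION_FACTORS[label]:
                return label, num / CONVERSION_FACTORS[label]
        return "", num
    return unit, num / CONVERSION_FACTORS[unit]
```

For example, to_readable(3_410_000) would infer "M" and scale to 3.41, while to_readable(3_410_000, "") would return the raw count unchanged.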

What do you think, @TylerYep? I'm considering adding these params to the FormattingOptions class.

TylerYep commented 1 year ago

I don't think we should add two parameters. If in the future we start caring about some other measurement (e.g. flops), we would need to add a new parameter again and again.

This conversation makes me wonder whether it would be better to set these on the result of summary() instead, i.e.:

stats = torchinfo.summary(...)

stats.formatting.units = "M"
stats.formatting.col_width = 30

This way the actual call to summary() stays very simple and only takes in parameters required to execute the forward pass correctly.

richardtml commented 1 year ago

I realize I wasn't clear initially. I suggested this feature to modify only the global report:

Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 3.41

Do you think it's a good idea for the unit specifiers to also apply to the per-layer counts (columns)?

On the other hand, MACs are often an order of magnitude (or more) larger than the parameter count. Having a single argument for both could result in very small or very large numbers being reported for one of them.
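Using the totals from the report above, a quick sketch (not torchinfo code) shows the problem with a single shared specifier:

```python
# With one shared unit specifier, one of the two metrics formats poorly.
params, macs = 21_840, 3_410_000  # totals from the summary above
unit, divisor = "M", 1e6

print(f"Params ({unit}): {params / divisor:.2f}")  # 0.02 -- nearly unreadable
print(f"MACs ({unit}):   {macs / divisor:.2f}")    # 3.41 -- fine
```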

TylerYep commented 1 year ago

Thought through this some more and I've decided these shouldn't be added to summary(), as it will crowd up the API for what to me seems like a very uncommon use case (we could print stats.total_params to compare the exact number of parameters, for example).

I think the best place to support this would be a field in FormattingOptions:

stats = torchinfo.summary(...)

stats.formatting.params_units = "M"
stats.formatting.macs_units = "auto"
print(stats)

In the future I would likely move fields like col_width there as well.
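A rough sketch of where such fields could live, assuming a dataclass-style FormattingOptions (the defaults here mirror the discussion and are not the actual class):

```python
from dataclasses import dataclass

@dataclass
class FormattingOptions:
    # Hypothetical fields; "auto" preserves the current behavior.
    col_width: int = 25
    params_units: str = "auto"
    macs_units: str = "auto"

# Usage mirroring the snippet above:
opts = FormattingOptions()
opts.params_units = "M"
```

Because the fields are plain mutable attributes, they can be set after summary() returns, keeping the summary() signature itself unchanged.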

richardtml commented 1 year ago

I've added results.formatting.params_units = "" and results.formatting.macs_units = "" to output values without units (the raw number of params or MACs). Is an empty string "" a good argument value?

The empty unit specifier macs_units = "" changes the output summary from:

Total mult-adds (M): 0.00

to:

Total mult-adds: 0
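The label change can be sketched as follows (format_total_macs is a hypothetical helper, not torchinfo's actual formatting code):

```python
def format_total_macs(macs: int, unit: str) -> str:
    # An empty specifier drops both the scaling and the "(unit)" label.
    divisors = {"M": 1e6, "G": 1e9, "T": 1e12}
    if unit == "":
        return f"Total mult-adds: {macs:,}"
    return f"Total mult-adds ({unit}): {macs / divisors[unit]:.2f}"
```

Under this sketch, format_total_macs(0, "M") gives "Total mult-adds (M): 0.00" while format_total_macs(0, "") gives "Total mult-adds: 0", matching the two outputs above.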

Some tests explicitly expect the first output, for instance torchinfo_test.py::test_uninitialized_tensor. Do you agree with changing this? Would updating the output files in test_output/ be the best way to update the tests?

TylerYep commented 1 year ago

The defaults you specify should not change the existing behavior unless specified, so after your changes I expect all existing output tests to pass without modification.

As an FYI though, you can update all output files with pytest --overwrite, which might make it easier to see what you've changed across all files.

richardtml commented 1 year ago

I'm having trouble passing the mypy check. I get the following error even though I haven't touched torchinfo.py:

torchinfo/torchinfo.py:619: error: No overload variant of "reversed" matches argument type "dict_items[str, Optional[Module]]"  [call-overload]
torchinfo/torchinfo.py:619: note: Possible overload variants:
torchinfo/torchinfo.py:619: note:     def [_T] reversed(self, Reversible[_T]) -> reversed[_T]
torchinfo/torchinfo.py:619: note:     def [_T] reversed(self, SupportsLenAndGetItem[_T]) -> reversed[_T]
Found 1 error in 1 file (checked 2 source files)

@TylerYep , do you have any suggestion? Thanks!

TylerYep commented 1 year ago

Can you give me your versions of Python, PyTorch, and mypy?

I am using Python 3.10.6, PyTorch 1.13.0, mypy 0.981.

richardtml commented 1 year ago

I have Python 3.7.13, torch 1.13.0, mypy 0.982, mypy-extensions 0.4.3, and pylint 2.15.5.

TylerYep commented 1 year ago

My guess is that your Python version is too old for torchinfo development. torchinfo as a package works on Python 3.7, but all of our typechecking runs on 3.9. See https://github.com/TylerYep/torchinfo/blob/main/.github/workflows/test.yml for the pairings we support for development work.

I still need to add Python 3.10 and PyTorch 1.13, actually.

richardtml commented 1 year ago

Thank you, that worked. I opened PR #188.

TylerYep commented 1 year ago

Resolved by #188, will be released in v1.7.2.

Thank you for the contribution @richardtml !