Closed richardtml closed 1 year ago
Interesting idea, I'll definitely consider this feature; it sounds like a good addition. Open to accepting PRs adding this functionality!
Thank you for the quick reply @TylerYep . I could try to add this feature if you don't mind.
Go for it! As you work on it we can discuss the API design, naming, etc
An approach could be adding the following parameters to summary():
params_unit (str):
"A" : Auto, infers a suitable unit
"" : No conversion
"M" : To millions
"G" : To billions
"T" : To trillions
Default: "A".
macs_unit (str):
"A" : Auto, infers a suitable unit
"" : No conversion
"M" : To millions
"G" : To billions
"T" : To trillions
Default: "A".
What do you think @TylerYep ? I'm considering adding these params to the FormattingOptions class.
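To make the proposal concrete, here is a minimal sketch of how the specifiers above could be mapped to conversion factors. The helper name convert_count and the rule used for "A" (pick the largest unit that keeps the value at or above 1) are assumptions for illustration, not torchinfo's implementation.

```python
from typing import Tuple

# Hypothetical helper, not part of torchinfo.
CONVERSION_FACTORS = {"": 1, "M": 1e6, "G": 1e9, "T": 1e12}


def convert_count(count: int, unit: str = "A") -> Tuple[str, float]:
    """Return (unit, converted value) for a raw parameter or MAC count."""
    if unit == "A":
        # Assumed "auto" rule: largest unit that keeps the value >= 1,
        # falling back to no conversion for small counts.
        unit = ""
        for candidate in ("T", "G", "M"):
            if count >= CONVERSION_FACTORS[candidate]:
                unit = candidate
                break
    if unit not in CONVERSION_FACTORS:
        raise ValueError(f"Unknown unit specifier: {unit!r}")
    return unit, count / CONVERSION_FACTORS[unit]
```

Under this rule a count of 3,410,000 would be reported in millions, while 21,840 would stay unconverted. The library's actual "auto" behavior may differ.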
I don't think we should add two parameters. If in the future we start caring about some other measurement (e.g. flops), we would need to add a new parameter again and again.
This conversation makes me wonder whether it would be better to set these on the result of summary() instead, i.e.:
stats = torchinfo.summary(...)
stats.formatting.units = "M"
stats.formatting.col_width = 30
This way the actual call to summary() stays very simple and only takes in parameters required to execute the forward pass correctly.
I realize I wasn't clear initially. I suggested this feature to just modify the global report.
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 3.41
Do you think it is a good idea for the unit specifiers to also apply to the layer-level counts (the columns)?
On the other hand, often MACS are 1 order of magnitude larger than parameters, for example. Having a single argument for both could result in reporting very small or large numbers for one of them.
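A quick illustration of that concern, using the totals from the report above (the mult-adds figure is taken as roughly 3.41 million). The function shared_unit_report is a hypothetical name used only for this sketch.

```python
from typing import List


def shared_unit_report(total_params: int, total_macs: int, unit: str) -> List[str]:
    """Format both totals with one shared unit specifier (hypothetical helper)."""
    factors = {"": 1, "M": 1e6, "G": 1e9}
    factor = factors[unit]
    label = f" ({unit})" if unit else ""
    return [
        f"Total params{label}: {total_params / factor:,.2f}",
        f"Total mult-adds{label}: {total_macs / factor:,.2f}",
    ]


# Totals from the example report; 3_410_000 approximates "3.41 (M)".
for unit in ("", "M", "G"):
    print("\n".join(shared_unit_report(21_840, 3_410_000, unit)))
```

With a shared "M" the parameter count collapses to 0.02, and with a shared "G" both values round to 0.00, which is the motivation for having separate specifiers.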
Thought through this some more and I've decided these shouldn't be added to summary(), as it will crowd up the API for what to me seems like a very uncommon use case (we could print stats.total_params to compare the exact number of parameters, for example).
I think the best place to support this would be a field in FormattingOptions:
stats = torchinfo.summary(...)
stats.formatting.params_units = "M"
stats.formatting.macs_units = "auto"
print(stats)
In the future I would likely move fields like col_width there as well.
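For reference, the pattern of configuring units on the returned object rather than in the summary() call could look roughly like the sketch below. SketchStats, SketchFormattingOptions and format_total are stand-in names invented here, and the "auto" rule is an assumption; none of this is torchinfo's actual code.

```python
from dataclasses import dataclass, field

# Hypothetical names throughout; this is not torchinfo's code.
FACTORS = {"": 1, "M": 1e6, "G": 1e9, "T": 1e12}


def format_total(label: str, count: int, units: str) -> str:
    if units == "auto":
        # Assumed rule: largest unit that keeps the value >= 1, else raw.
        units = next((u for u in ("T", "G", "M") if count >= FACTORS[u]), "")
    if units == "":
        return f"{label}: {count:,}"
    return f"{label} ({units}): {count / FACTORS[units]:.2f}"


@dataclass
class SketchFormattingOptions:
    params_units: str = "auto"
    macs_units: str = "auto"


@dataclass
class SketchStats:
    total_params: int
    total_mult_adds: int
    formatting: SketchFormattingOptions = field(
        default_factory=SketchFormattingOptions
    )

    def __str__(self) -> str:
        # Formatting is resolved at print time, so options can be
        # changed after the statistics have been computed.
        return "\n".join(
            [
                format_total("Total params", self.total_params,
                             self.formatting.params_units),
                format_total("Total mult-adds", self.total_mult_adds,
                             self.formatting.macs_units),
            ]
        )


stats = SketchStats(total_params=21_840, total_mult_adds=3_410_000)
stats.formatting.params_units = "M"
print(stats)
```

Because the units are read only when the object is printed, the forward pass never needs to know about them, which keeps the call that produces the statistics free of display options.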
I added results.formatting.params_units = "" and results.formatting.macs_units = "" to output values without units (the raw number of params or MACs). Is an empty string "" a good argument value?
An empty unit specifier, macs_units = "", changes the output summary from:
Total mult-adds (M): 0.00
to:
Total mult-adds: 0
Some tests explicitly expect the first output, for instance torchinfo_test.py::test_uninitialized_tensor. Do you agree with changing this? Would the best way to update the tests be to update the output files in test_output/?
The defaults you specify should not change the existing behavior unless specified, so after your changes I expect all existing output tests to pass without modification.
As an FYI though, you can update all output files with pytest --overwrite, which might make it easier to see what you've changed across all files.
I'm having trouble passing the mypy check. I get the following error even though I haven't touched torchinfo.py.
torchinfo/torchinfo.py:619: error: No overload variant of "reversed" matches argument type "dict_items[str, Optional[Module]]" [call-overload]
torchinfo/torchinfo.py:619: note: Possible overload variants:
torchinfo/torchinfo.py:619: note: def [_T] reversed(self, Reversible[_T]) -> reversed[_T]
torchinfo/torchinfo.py:619: note: def [_T] reversed(self, SupportsLenAndGetItem[_T]) -> reversed[_T]
Found 1 error in 1 file (checked 2 source files)
@TylerYep , do you have any suggestion? Thanks!
Can you give me your versions of:
I am using Python 3.10.6, Pytorch 1.13.0, mypy 0.981
I have Python 3.7.13, torch 1.13.0, mypy 0.982, mypy-extensions 0.4.3, pylint 2.15.5.
My guess is your python version is too old for torchinfo development. torchinfo as a package works on python 3.7, but all of our typechecking runs on 3.9. See https://github.com/TylerYep/torchinfo/blob/main/.github/workflows/test.yml for the pairings we support for developer work
I still need to add Python 3.10 and Pytorch 1.13 actually
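For anyone hitting the same error: it is consistent with the Python version explanation, since dictionary views such as dict.items() only became reversible in Python 3.8, and mypy checks against the interpreter version it runs under. The snippet below is just an illustration of a version-independent pattern, not a change made in torchinfo; reversed_items is a hypothetical name.

```python
from typing import Dict, List, Optional, Tuple


def reversed_items(
    mapping: Dict[str, Optional[int]]
) -> List[Tuple[str, Optional[int]]]:
    """Return the mapping's items in reverse insertion order."""
    # reversed(mapping.items()) needs Python 3.8+; copying the view into
    # a list first gives a sequence that is reversible on older versions
    # too, at the cost of one extra copy.
    return list(reversed(list(mapping.items())))


print(reversed_items({"conv1": 1, "fc": None}))
```

Type-checking on one of the Python versions listed in the workflow file avoids the need for any workaround like this.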
Thank you, that worked. I opened the PR #188
Resolved by #188, will be released in v1.7.2.
Thank you for the contribution @richardtml !
It would be very useful to have a way to specify the units (MB, GB, etc.) in which the number of parameters and MACs are reported. This could help quickly compare different architectures.
I'm thinking of something like adding the arguments params_units and macs_units to the summary() function, with a default value of 'auto' to preserve the current behavior.