metaopt / optree

OpTree: Optimized PyTree Utilities
https://optree.readthedocs.io
Apache License 2.0

feat: add `PyStructSequence` types as internal node types #30

Closed: XuehaiPan closed this pull request 1 year ago.

XuehaiPan commented 1 year ago

Description


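Add `PyStructSequence` types (e.g., `sys.float_info`, `sys.int_info`, and `torch.return_types.*`) as internal node types, so that they are flattened like tuples instead of being treated as leaves.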

Motivation and Context


Resolves #29.

Without this PR, `PyStructSequence` instances are treated as leaves:

>>> import sys
>>> import torch
>>> from optree import *

>>> tree_flatten(sys.float_info)
(
    [sys.float_info(max=1.7976931348623157e+308, max_exp=1024, max_10_exp=308, min=2.2250738585072014e-308, min_exp=-1021, min_10_exp=-307, dig=15, mant_dig=53, epsilon=2.220446049250313e-16, radix=2, rounds=1)],
    PyTreeSpec(*)
)

>>> tree_flatten(sys.int_info)
(
    [sys.int_info(bits_per_digit=30, sizeof_digit=4, default_max_str_digits=4300, str_digits_check_threshold=640)],
    PyTreeSpec(*)
)

>>> tree_flatten(torch.max(torch.arange(12).reshape(3, 4), dim=-1))
(
    [torch.return_types.max(values=tensor([ 3,  7, 11]),
                            indices=tensor([3, 3, 3]))],
    PyTreeSpec(*)
)
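The following standard-library-only snippet shows what a struct sequence actually is: an instance of a tuple subclass whose exact type is not `tuple`. A registry keyed on exact types (an assumption about how such lookups commonly work, not a statement about optree's internals) would therefore not recognize it as a tuple node.

```python
import sys

fi = sys.float_info
cls = type(fi)

print(isinstance(fi, tuple))    # True: struct sequences subclass tuple
print(cls is tuple)             # False: the exact type is sys.float_info
print(cls.__bases__)            # (<class 'tuple'>,)
print(cls.n_sequence_fields)    # 11 visible fields
print(tuple(fi) == fi)          # True: same elements as a plain tuple
```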

With this PR, `PyStructSequence` instances are flattened like tuples:

>>> tree_flatten(sys.float_info)
(
    [1.7976931348623157e+308, 1024, 308, 2.2250738585072014e-308, -1021, -307, 15, 53, 2.220446049250313e-16, 2, 1],
    PyTreeSpec(sys.float_info(max=*, max_exp=*, max_10_exp=*, min=*, min_exp=*, min_10_exp=*, dig=*, mant_dig=*, epsilon=*, radix=*, rounds=*))
)

>>> tree_flatten(sys.int_info)
(
    [30, 4, 4300, 640],
    PyTreeSpec(sys.int_info(bits_per_digit=*, sizeof_digit=*, default_max_str_digits=*, str_digits_check_threshold=*))
)

>>> tree_flatten(torch.max(torch.arange(12).reshape(3, 4), dim=-1))
(
    [tensor([ 3,  7, 11]), tensor([3, 3, 3])],
    PyTreeSpec(torch.return_types.max(values=*, indices=*))
)
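To illustrate the idea behind the new behavior, here is a toy pure-Python flatten/unflatten pair that treats tuples, lists, and struct sequences as internal nodes and everything else as a leaf. It is a minimal sketch, not optree's implementation: `tree_flatten_sketch`, `tree_unflatten_sketch`, and `_is_structseq_type` are hypothetical names, and the struct-sequence check is a heuristic (a tuple subclass with an integer `n_sequence_fields` attribute).

```python
import sys
import time
from typing import Any, Callable, Iterator, List, Tuple

# Hypothetical helpers for illustration only; not part of optree.
Builder = Callable[[Iterator[Any]], Any]


def _is_structseq_type(cls: type) -> bool:
    # Heuristic: struct sequence types subclass tuple and carry an integer
    # `n_sequence_fields` class attribute; plain tuples and lists do not.
    return issubclass(cls, tuple) and isinstance(
        getattr(cls, "n_sequence_fields", None), int
    )


def tree_flatten_sketch(tree: Any) -> Tuple[List[Any], Builder]:
    """Return (leaves, build), where build(iterator_of_leaves) rebuilds the tree."""
    cls = type(tree)
    if cls in (tuple, list) or _is_structseq_type(cls):
        # Iterating a struct sequence yields only its visible sequence fields.
        parts = [tree_flatten_sketch(child) for child in tree]
        leaves = [leaf for child_leaves, _ in parts for leaf in child_leaves]
        child_builders = [child_build for _, child_build in parts]

        def build(it: Iterator[Any]) -> Any:
            # tuple, list, and struct sequence constructors all accept one iterable.
            return cls([child_build(it) for child_build in child_builders])

        return leaves, build
    # Leaf: consume exactly one value from the leaf iterator when rebuilding.
    return [tree], lambda it: next(it)


def tree_unflatten_sketch(build: Builder, leaves: List[Any]) -> Any:
    return build(iter(leaves))


st = time.struct_time((2020, 1, 2, 3, 4, 5, 3, 2, 0))
leaves, build = tree_flatten_sketch((st, [10, "x"]))
print(leaves)                                             # [2020, 1, 2, 3, 4, 5, 3, 2, 0, 10, 'x']
print(tree_unflatten_sketch(build, leaves) == (st, [10, "x"]))  # True
print(len(tree_flatten_sketch(sys.float_info)[0]))       # 11
```

Note that this sketch makes no attempt to mirror optree's handling of other node types: dicts and namedtuples, for example, simply fall through to the leaf branch here.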

New helper functions:

>>> is_structseq(sys.float_info)
True
>>> is_structseq(type(sys.float_info))
True
>>> is_structseq_class(sys.float_info)
False
>>> is_structseq_class(type(sys.float_info))
True
>>> structseq_fields(sys.float_info)
('max', 'max_exp', 'max_10_exp', 'min', 'min_exp', 'min_10_exp', 'dig', 'mant_dig', 'epsilon', 'radix', 'rounds')
>>> structseq_fields(type(sys.float_info))
('max', 'max_exp', 'max_10_exp', 'min', 'min_exp', 'min_10_exp', 'dig', 'mant_dig', 'epsilon', 'radix', 'rounds')
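Roughly equivalent helpers can be approximated with the standard library alone. The sketch below is an assumption-laden approximation, not the PR's code: the names ending in `_sketch` are hypothetical, the class check is a heuristic based on the `n_fields` / `n_sequence_fields` / `n_unnamed_fields` attributes, and, unlike the PR's `is_structseq`, `is_structseq_sketch` returns `False` when given a class. Field names are read from the member descriptors that CPython creates for each named field, which only works when the type has no unnamed fields.

```python
import sys
import time
import types
from typing import Any, Tuple


def is_structseq_class_sketch(cls: Any) -> bool:
    """Heuristic: struct sequence types derive directly from tuple and carry
    integer n_fields / n_sequence_fields / n_unnamed_fields attributes."""
    return (
        isinstance(cls, type)
        and cls.__bases__ == (tuple,)
        and all(
            isinstance(getattr(cls, name, None), int)
            for name in ("n_fields", "n_sequence_fields", "n_unnamed_fields")
        )
    )


def is_structseq_sketch(obj: Any) -> bool:
    """Instance check only: passing a class returns False."""
    return is_structseq_class_sketch(type(obj))


def structseq_fields_sketch(obj_or_cls: Any) -> Tuple[str, ...]:
    cls = obj_or_cls if isinstance(obj_or_cls, type) else type(obj_or_cls)
    if not is_structseq_class_sketch(cls):
        raise TypeError(f"{cls!r} is not a struct sequence type")
    if cls.n_unnamed_fields:
        # Unnamed fields have no member descriptor, so names would misalign.
        raise NotImplementedError("struct sequences with unnamed fields")
    # Named fields appear as member descriptors in declaration order.
    members = [
        name
        for name, attr in vars(cls).items()
        if isinstance(attr, types.MemberDescriptorType)
    ]
    # Keep only the fields that belong to the visible tuple part.
    return tuple(members[: cls.n_sequence_fields])


print(structseq_fields_sketch(time.localtime()))
```

For `time.localtime()`, this returns the nine visible names from `tm_year` to `tm_isdst`; the extra non-sequence fields (`tm_zone`, `tm_gmtoff`) are dropped by the `n_sequence_fields` slice.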


codecov[bot] commented 1 year ago

Codecov Report

Base: 100.00% // Head: 100.00% // No change to project coverage :thumbsup:

Coverage data is based on head (0b6aa78) compared to base (b41db31). Patch coverage: 100.00% of modified lines in pull request are covered.

:exclamation: Current head 0b6aa78 differs from pull request most recent head 25d3872. Consider uploading reports for the commit 25d3872 to get more accurate results

Additional details and impacted files:

```diff
@@            Coverage Diff            @@
##              main       #30   +/-   ##
=========================================
  Coverage   100.00%   100.00%
=========================================
  Files            4         4
  Lines          347       351    +4
=========================================
+ Hits           347       351    +4
```

| Flag | Coverage Δ |
|---|---|
| unittests | `100.00% <100.00%> (ø)` |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
|---|---|
| `optree/__init__.py` | `100.00% <ø> (ø)` |
| `optree/ops.py` | `100.00% <100.00%> (ø)` |
