metaopt / optree

OpTree: Optimized PyTree Utilities
https://optree.readthedocs.io
Apache License 2.0

[Feature Request] Consider `PyStructSequence` types as internal node types #29

Closed XuehaiPan closed 1 year ago

XuehaiPan commented 1 year ago

Required prerequisites

Motivation

Currently, only tuple and subclasses of namedtuple are considered node types. Other subclasses of tuple must be registered manually as node types; otherwise, they are treated as opaque leaves.

From the Python documentation, Struct Sequence Objects:

Struct sequence objects are the C equivalent of namedtuple() objects, i.e. a sequence whose items can also be accessed through attributes. To create a struct sequence, you first have to create a specific struct sequence type.

PyStructSequence types are also subclasses of tuple. They are widely used in the Python standard library (e.g., time.struct_time) and in other packages (e.g., torch.return_types.*).
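
As a small illustration of why a namedtuple check alone misses these types, the sketch below inspects time.struct_time using only the standard library. The variable names are illustrative, not part of optree.

```python
import time

# time.struct_time is a PyStructSequence type from the standard library.
cls = time.struct_time

# Instances are real tuples, so the type is a tuple subclass.
is_tuple_subclass = issubclass(cls, tuple)

# Unlike namedtuple classes, struct sequence types have no `_fields` attribute.
has_namedtuple_fields = hasattr(cls, "_fields")

# They expose integer counters describing their layout instead.
counters = {
    name: getattr(cls, name)
    for name in ("n_sequence_fields", "n_fields", "n_unnamed_fields")
}

print(is_tuple_subclass, has_namedtuple_fields, counters["n_sequence_fields"])
```

Because `_fields` is missing, a namedtuple-style check returns False for such a type, so it falls through to the leaf case unless it is handled separately.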

It would be nice if we considered PyStructSequence types as internal node types by default.

Ref:

Solution

Add special handling for PyStructSequence types in the C extension, as we did for namedtuple.

A snippet I posted in "What’s the best way to check if a type is PyStructSequence":

import inspect

def is_namedtuple(obj: object | type) -> bool:
    # Accept either a class or an instance.
    cls = obj if inspect.isclass(obj) else type(obj)
    return (
        issubclass(cls, tuple)
        and isinstance(getattr(cls, '_fields', None), tuple)
        and all(isinstance(field, str) for field in cls._fields)
    )

def is_structseq(obj: object | type) -> bool:
    # Accept either a class or an instance.
    cls = obj if inspect.isclass(obj) else type(obj)
    if (
        cls.__base__ is tuple
        and isinstance(getattr(cls, 'n_sequence_fields', None), int)
        and isinstance(getattr(cls, 'n_fields', None), int)
        and isinstance(getattr(cls, 'n_unnamed_fields', None), int)
    ):
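        try:
            # Struct sequence types cannot be subclassed; a failure here
            # separates them from ordinary tuple subclasses with the same attributes.
            class subcls(cls):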
                pass

        except (
            TypeError,       # CPython
            AssertionError,  # PyPy
        ):
            return True

    return False

>>> import time
>>> import torch
>>> is_structseq(torch.return_types.max)
True
>>> is_structseq(time.struct_time)
True
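
To make the proposed special handling more concrete, here is a rough pure-Python sketch of flattening and rebuilding a struct sequence. It is not the optree implementation: `structseq_flatten` and `structseq_unflatten` are hypothetical helper names, the field-name lookup relies on CPython exposing named fields as member descriptors in definition order, and the type is assumed to have no unnamed fields.

```python
import inspect
import time


def structseq_flatten(obj):
    """Hypothetical helper: return (children, type, visible field names)."""
    cls = type(obj)
    # Assumption (CPython): each named field is a member descriptor on the
    # class, stored in definition order. Only the first `n_sequence_fields`
    # of them are visible when the object is iterated as a tuple.
    names = [
        name
        for name, member in vars(cls).items()
        if inspect.ismemberdescriptor(member)
    ]
    return tuple(obj), cls, tuple(names[: cls.n_sequence_fields])


def structseq_unflatten(cls, children):
    """Hypothetical helper: rebuild the struct sequence from its children."""
    # Struct sequence types take a single iterable, whereas namedtuple
    # classes take one positional argument per field.
    return cls(children)


t = time.gmtime(0)  # the Unix epoch as a time.struct_time
children, cls, names = structseq_flatten(t)
rebuilt = structseq_unflatten(cls, children)

print(children)
print(names)
print(rebuilt == t)
```

Fields beyond `n_sequence_fields` (for example `tm_zone` on some platforms) are not part of the tuple view, so this round trip does not preserve them.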

Alternatives

Keep the current API unchanged. Users would need to register PyStructSequence types manually.

Additional context

No response