pytorch / tensordict

TensorDict is a pytorch dedicated tensor container.

Why is Tensorclass implemented as a decorator? #663

Open kurt-stolle opened 7 months ago

kurt-stolle commented 7 months ago

Current Tensorclass interface

To define a Tensorclass, we could write:

from torch import Tensor
from tensordict import tensorclass

@tensorclass
class MyTensorclass:
    foo: Tensor
    bar: Tensor

The @tensorclass decorator then generates a dataclass-like container class, adding various methods to implement (parts of) the TensorDictBase abstraction. Notice that, unlike dataclasses, where the __init__ signature can be inferred entirely from the class attribute annotations, the current decorator also augments __init__ with the extra keyword arguments batch_size, device and names.
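
As a quick illustration of that generated signature (a minimal sketch; the shapes, device and values below are arbitrary):

import torch

# The generated __init__ accepts the annotated fields plus the extra keyword arguments
data = MyTensorclass(
    foo=torch.zeros(3, 4),
    bar=torch.ones(3, 4),
    batch_size=[3],  # leading dimension(s) shared by all entries
    device="cpu",
)
assert data.batch_size == torch.Size([3])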

Considering that the decorator has significant (side-)effects on the resulting class, I am unclear about why the API has been defined in this way. In this issue, I ask for clarification on this design choice.

Simplified interface (conceptual)

Specifically, would a simplified representation using subclassing not suffice? For example, consider the following notation:

class MyTensorclass(Tensorclass):
    foo: Tensor
    bar: Tensor

The Tensorclass baseclass then has the same effects as @tensorclass, and also:

  1. Enables a more straightforward implementation in tensordict
  2. Works with builtins issubclass and isinstance
  3. Is compatible with static type checking (e.g. extra keyword arguments batch_size etc.)
  4. Is easily understood as canonical Python

Polyfill

To further clarify the effects of the subclass above, here's a quick polyfill that allows tensorclasses to be defined with the subclass-based API above on top of the current decorator-based implementation:

from types import resolve_bases
from typing import dataclass_transform, TYPE_CHECKING, Sequence
from dataclasses import KW_ONLY
import torch
from torch.types import Device
from tensordict import tensorclass

@dataclass_transform()
class _TensorclassMeta(type):
    def __new__(metacls, name, bases, ns, **kwds):
        # Build the class as usual, then pass it through the tensorclass decorator
        bases = resolve_bases(bases)
        cls = super().__new__(metacls, name, tuple(bases), ns, **kwds)
        return tensorclass(cls)

@dataclass_transform()
class Tensorclass(metaclass=_TensorclassMeta):  # Or: TensorDictBase subclass
    # Demonstrates point (3) above: the extra keyword arguments become visible
    # to static type checkers without affecting runtime behavior
    if TYPE_CHECKING:
        _: KW_ONLY
        batch_size: torch.Size | Sequence[int]
        device: Device | str | None = None
        ...  # Other methods
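
Continuing from the polyfill above, defining and using a tensorclass would then look roughly as follows. This is a usage sketch under the assumption that the conceptual polyfill works as intended, in particular that tensorclass() returns the class it decorates, so the subclass relationship to Tensorclass is preserved:

class MyTensorclass(Tensorclass):
    foo: torch.Tensor
    bar: torch.Tensor

# The extra keyword arguments are now part of the statically visible __init__ (point 3)
obj = MyTensorclass(foo=torch.zeros(3), bar=torch.ones(3), batch_size=[3])

# Builtin isinstance/issubclass work as expected (point 2)
assert isinstance(obj, Tensorclass)
assert issubclass(MyTensorclass, Tensorclass)
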
vmoens commented 7 months ago

Hey, thanks for proposing this! I agree with all your points: inheritance is easy and understandable for most of the Python community, and way less hacky (which is the one thing I do not like about dataclass, and consequently about tensorclass too).

We thought about that initially, but there are a couple of reasons that made us go for the @tensorclass decorator. It mainly has to do with the fact that a large community likes @dataclass for the very reason that it does not inherit from anything: you can build a class that is the parent of all its subclasses with little trouble. The idea of tensorclass is to have a dataclass on steroids. If we had made a TensorClass base class, it would have been a mixture of @dataclass and inheritance, a strange class to play with for people used to dataclasses.

Note that we could implement a class and metaclass that provide the isinstance check, if that makes things easier:

  from tensordict import tensorclass, is_tensorclass
  import torch

  class SomeClassMeta(type):
      def __instancecheck__(self, instance):
          return is_tensorclass(instance)

      def __subclasscheck__(self, subclass):
          return is_tensorclass(subclass)

  class TensorClass(metaclass=SomeClassMeta):
      pass

  @tensorclass
  class MyDC:
      x: torch.Tensor

  c = MyDC(1, batch_size=[])
  assert isinstance(c, TensorClass)
  assert issubclass(type(c), TensorClass)

That only partially addresses your points though, e.g. static type checking.

I'm personally not a huge fan of dataclasses but I understand why people like them. I think that most of the time it has to do with typing (i.e., the content is more explicit than with a dictionary). Your solution still offers that, so I wouldn't put it aside straight away, but if we consider this we must make sure that it won't harm adoption of the feature for the target audience (i.e., users who are accustomed to the @dataclass decorator).

RE this point

Notice that, unlike dataclasses, where the __init__ signature can be inferred entirely from the class attribute annotations, the current decorator also augments __init__ with the extra keyword arguments batch_size, device and names.

I don't see how that relates to the four points you raised about the benefits of subclassing. To me @dataclass suffers from the 4 pitfalls you identified there too, doesn't it?

cc @shagunsodhani @apbard @tcbegley who may have some other thoughts to share on this

kurt-stolle commented 7 months ago

Thanks for your swift reply @vmoens! This clears up my main questions on the design choices made.

While I did not intend to advocate immediate alterations to the status quo, I am curious to learn what exactly the intended role of tensorclasses is. From the user perspective, the main reasons listed for using a tensorclass come down to the keys being explicitly defined. In that case, weighting the library design toward users' code aesthetics (i.e. similarity to @dataclass) over their static type checking capabilities seems a bit off-balance.

Second, was a solution that specifically addresses typing, à la typing.TypedDict, ever explored?

RE:

It has mainly to do with the fact that a large community likes @dataclass for the very reason that it does not inherit from anything, you can build a class that is the parent of all subclasses with little trouble.

It is my perception that 'not inheriting from anything' is more relevant to dataclasses than it is to tensorclasses in this context. I would argue:

  1. Dataclasses are a method for defining a class without many implications for its functionality. As such, I can understand why one would not want two functionally different dataclasses to share a common base, as the only thing they share is how they were defined, not what they represent. Thus, @dataclass can be interpreted as a code generator that outputs the body of a class.
  2. Tensorclasses are also a method for defining a class, but they have significant implications for its functionality. This added functionality mostly takes the form of bound methods, as would normally be provided via a superclass.

Consequently, two tensorclasses will have a larger degree of shared functionality than two dataclasses. This is difficult to define in an exact manner, though.
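
To make the shared-functionality point concrete, here is a small sketch; the class names are purely illustrative, and it assumes the tensordict-style methods that @tensorclass attaches to the generated class (such as .to()):

import torch
from tensordict import tensorclass

@tensorclass
class Observation:
    pixels: torch.Tensor

@tensorclass
class Action:
    logits: torch.Tensor

obs = Observation(pixels=torch.zeros(2, 3), batch_size=[2])
act = Action(logits=torch.zeros(2, 5), batch_size=[2])

# Both classes expose the same tensordict-style methods, even though they do not
# share a user-visible base class in the current decorator-based design.
obs_cpu = obs.to("cpu")
act_cpu = act.to("cpu")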

RE:

I don't see how that relates to the four points you raised about the benefits of subclassing. To me @dataclass suffers from the 4 pitfalls you identified there too, doesn't it?

It does indeed. Let me elaborate further on why I point out this property of the __init__ signature. It is mostly related to typing problems:

  1. Classes made with @dataclass already have good static type-checking support (e.g. in Pyright and Mypy).
  2. Until recently (PEP 681), this behavior could not be extended to other classes with dataclass-like behavior.
    • To add typing support for @tensorclass, we could follow PEP 681 and decorate it with @dataclass_transform from typing (see the sketch after this list).
    • However, this would leave batch_size and device missing from the signature at type-check time. To my knowledge, there is no way around this with the decorator-based approach.
  3. If defined as a subclass, we can use @dataclass_transform and TYPE_CHECKING to define the interface of the __init__ method statically in code, providing the user with proper typing using simple and canonical Python.
    • The added batch_size and device are then recognized by type checkers as part of the keyword arguments of __init__.
    • See code block in the original issue above.
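
As a sketch of the decorator-side option mentioned in point 2 (assuming Python 3.11+ or typing_extensions for dataclass_transform; the decorator below is a stand-in, not the actual tensordict implementation), declaring the decorator as a dataclass-like transform gives type checkers the annotated fields, but not the extra keyword arguments:

from typing import TypeVar, dataclass_transform
import torch

T = TypeVar("T")

@dataclass_transform()
def tensorclass(cls: type[T]) -> type[T]:
    # Stand-in declaration: the real decorator also generates __init__ and the
    # tensordict machinery; only the static declaration matters for this sketch.
    return cls

@tensorclass
class MyDC:
    x: torch.Tensor

# A type checker now understands MyDC(x=...) and the type of MyDC.x,
# but batch_size and device remain invisible in the inferred __init__:
#     MyDC(x=torch.zeros(3), batch_size=[3])  # rejected at type-check time
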
vmoens commented 7 months ago

Thank you for your comments, a lot of good points there.

To provide some context on typing: tensordict, like pytorch, is not strongly typed. We do not currently use mypy as part of our CI, and it's unlikely that we will in the future. Regardless of my personal views on type checking (which are not pertinent to this discussion), the potential integration of tensordict's core features into the pytorch repo makes it highly improbable that it will ever be type-checked. However, we understand that some users may wish to use mypy with tensordict, and we must consider how to support this.

Proposed Solution

Your suggestion of dataclass_transform looks very cool to me. If I understand correctly, the type checker will only be satisfied with inheritance, correct? This could be a feasible path forward.

Given this, we would find ourselves in the following scenarios based on user requirements:

Would this status quo be acceptable to you? Is the dual usage (as a decorator or subclass) even desirable? It's worth noting that offering two methods to achieve the same result can lead to confusion. Many users are currently familiar with using @tensorclass, so any decision to deprecate it must be carefully considered.

We now need to evaluate the feasibility and implications of these options...