finch-tensor / finch-tensor-python

Sparse and Structured Tensor Programming in Python
MIT License
8 stars 3 forks source link

Refactor punting of array storage to `Tensor.device` #13

Open hameerabbasi opened 8 months ago

hameerabbasi commented 8 months ago

Motivation

Users may need a way to:

  1. Query the format that a Tensor is currently in
  2. Create arrays with a specific format
  3. Convert to another format
  4. Materialize an array in a given format
  5. Do the above also with a given memory, i.e. CPU or GPU memory

These capabilities would let users make format decisions at the code level and regain some control over storage. In a meeting with @willow-ahrens and @mtsokol we decided on the following interface for the long term.

Proposed interface

I propose a number of new classes (stubs and descriptions below):

class LeafLevel(abc.ABC):
    """Abstract leaf (element-storage) level; terminates a level tree.

    Concrete subclasses build the Julia-side leaf level object that holds
    the actual element values.
    """

    @abc.abstractmethod
    def _construct(self, *, dtype, fill_value) -> "jl.LeafLevel":
        """Build the Julia leaf level storing elements of `dtype`, with
        `fill_value` assumed for unstored entries."""
        ...

class Level(abc.ABC):
    """Abstract non-leaf storage level.

    A `Level` wraps an inner level (another `Level` or the terminal
    `LeafLevel`) to describe the storage of one tensor dimension.
    """

    @abc.abstractmethod
    def _construct(self, *, inner_level: "Level | LeafLevel") -> "jl.Level":
        """Build the Julia level object wrapping *inner_level*.

        Bug fix: the original annotation `"Level" | LeafLevel` combines a
        plain string with a class; since annotations are evaluated at `def`
        time, `str | type` raises TypeError. The whole union must be a
        single string forward reference.
        """
        ...

# Example impl of `Level`
class SparseList(Level):
    """Example `Level` implementation: sparse list-of-stored-indices storage
    for one dimension, parameterized on the index and pointer dtypes."""

    def __init__(self, index_type=dtypes.intp, pointer_type=dtypes.intp):
        self.index_type = index_type
        self.pointer_type = pointer_type

    def _construct(self, *, inner_level) -> "jl.SparseList":
        """Build the Julia `SparseList` level wrapping *inner_level*.

        Bug fixes vs. the original stub:
        - `self.pos_type` was referenced but `__init__` only ever sets
          `self.pointer_type`; that attribute is used now.
        - The override's keyword-only parameter is renamed `level` ->
          `inner_level` to match the `Level` ABC's declared signature.
        - The return annotation `-> SparseList` named the class being
          defined (a NameError at `def` time inside the class body) and the
          wrong type; it is the Julia level, quoted as a forward reference.
        """
        return jl.SparseList[self.index_type, self.pointer_type](inner_level)

class Format:
    """Describes the storage format of a `Tensor`: one `Level` per
    dimension, a dimension ordering, and the leaf (element-storage) level.
    """

    levels: "tuple[Level, ...]"
    order: tuple[int, ...]
    leaf: "LeafLevel"

    def __init__(self, *, levels: "tuple[Level, ...]", order: "tuple[int, ...] | None", leaf: "LeafLevel") -> None:
        """Validate and store the format description.

        `order=None` means the identity ordering. Raises `ValueError` if
        `order` is not a permutation of `range(len(levels))`.
        """
        if order is None:
            # Default to the identity ordering, one entry per level.
            order = tuple(range(len(levels)))

        if len(order) != len(levels):
            raise ValueError(f"len(order) != len(levels), {order=}, {levels=}")

        # Bug fix: the original compared a list against a `range` object
        # (`sorted(order) != range(...)`), which is always unequal in
        # Python 3, so every order — including valid ones — was rejected.
        if sorted(order) != list(range(len(order))):
            raise ValueError(f"sorted(order) != range(len(order)), {order=}")

        # Coerce to tuple so the stored attribute matches its declared type
        # even when a caller passes a list.
        self.order = tuple(order)
        self.levels = levels
        self.leaf = leaf

    def _construct(self, *, fill_value, dtype) -> "jl.Swizzle":
        """Build the Julia level tree from the leaf outward, then wrap it in
        a `Swizzle` applying the dimension order."""
        out_level = self.leaf._construct(dtype=dtype, fill_value=fill_value)
        for level in reversed(self.levels):
            # Bug fix: `Level._construct` declares its parameter
            # keyword-only, so the original positional call would raise
            # TypeError; pass it as `inner_level=`.
            out_level = level._construct(inner_level=out_level)

        return jl.Swizzle(out_level, *reversed(self.order))

class Device:
    """The memory the `Tensor` will live on (e.g. CPU or GPU memory), as
    well as the execution context. Operations that mix tensors on different
    devices will raise an error."""

class Tensor:
    """Stub of the proposed `Tensor` interface (see the Motivation list:
    query/create/convert/materialize formats, on a given memory)."""

    # Memory and execution context this tensor lives on.
    device: Device
    # Storage format: per-dimension levels, ordering, and leaf level.
    format: Format
    # Value assumed for unstored (implicit) entries.
    fill_value: Any
    # Element dtype.
    dtype: dtypes.DType

    ...

    def to_device(self, device, /) -> "Tensor":
        """Return this tensor moved/converted to `device` (Motivation #5)."""
        ...

    def to_format(self, format, /) -> "Tensor":
        """Return this tensor converted to `format` (Motivation #3)."""
        ...

def asarray(x: Any, /, *, dtype=None, device: "Device | None" = None, format: "Format | None" = None, fill_value: Any = None) -> "Tensor":
    """Convert `x` to a `Tensor` with the requested dtype/device/format.

    Fix: `fill_value` was the only keyword-only parameter without a default,
    inconsistent with `dtype`/`device`/`format` and with the massaging
    comment below; it now defaults to None, meaning "infer from `x`".
    Backward compatible — callers passing `fill_value` are unaffected.
    """
    # Massage dtype/device/format/fill_value into acceptable form iff `None`
    # (derive them from `x`); after this point `format` is a `Format`.
    # NOTE(review): the stub below ignores the data in `x` — a real
    # implementation must also copy/wrap `x`'s contents into the
    # constructed level tree.
    return Tensor(jl_data=format._construct(fill_value=fill_value, dtype=dtype))