dotnet / runtime

.NET is a cross-platform runtime for cloud, mobile, desktop, and IoT apps.
https://docs.microsoft.com/dotnet/core/
MIT License

Optimize the layout of the Tensor types #102268

Open tannergooding opened 1 month ago

tannergooding commented 1 month ago

Work is being done to introduce a Tensor&lt;T&gt; and supporting types. Because these types often represent slices of multi-dimensional memory, there is quite a lot of additional data that needs to be tracked beyond what something like Span&lt;T&gt; needs. Correspondingly, the naive approach tracks multiple nint[] to support an arbitrary number of dimensions, which means an allocation is made per slice. Doing these allocations every time a slice is produced can get expensive, so the layout should ideally be optimized to avoid them for common dimension counts.

A simple approach would be to track a single nint[] holding rank pieces of data for the length of each dimension, followed by rank more pieces of data for the stride of each dimension. However, this still necessitates an allocation every time. The next best thing would be to track data inline for some common dimension counts, but this quickly grows the size of the TensorSpan, which can itself have a negative impact: larger copies are required when passing the data by value, and CPU cache usage suffers if the struct grows too large.
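For illustration, the single-array scheme can be sketched in C (the function and layout here are hypothetical, not the proposed API): a rank-R view stores R lengths followed by R strides in one heap allocation, and element lookup is a dot product of the indices with the strides. Since a slice generally has different lengths (and a different base offset) than its parent, each slice needs its own copy of this data, which is where the per-slice allocation comes from.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical sketch: one heap array holds rank lengths followed by
 * rank strides, mirroring the single nint[] layout described above. */
static ptrdiff_t flat_index(const ptrdiff_t *metadata, int rank,
                            const ptrdiff_t *indices)
{
    const ptrdiff_t *strides = metadata + rank; /* strides follow lengths */
    ptrdiff_t offset = 0;
    for (int i = 0; i < rank; i++) {
        assert(indices[i] >= 0 && indices[i] < metadata[i]); /* bounds check */
        offset += indices[i] * strides[i];
    }
    return offset;
}
```

For a dense row-major 2x3x4 tensor the strides are {12, 4, 1}, so element (1, 2, 3) maps to flat offset 1*12 + 2*4 + 3*1 = 23.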

As such, the optimal setup is likely to pick a limit that is representative of commonly encountered dimension counts and that keeps the struct no larger than a single cache line (typically assumed to be 64 bytes).
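As a back-of-the-envelope sketch of that budget (assuming a 64-bit system and a 64-byte cache line, both of which are just typical assumptions), the maximum inline rank falls out of simple arithmetic:

```c
/* Back-of-the-envelope: how many dimensions fit inline if the whole
 * struct must stay within one (assumed) 64-byte cache line?
 * headerBytes covers the reference, the metadata pointer, and any
 * scalar fields; entryBytes is the size of one length or stride. */
static int max_inline_rank(int headerBytes, int entryBytes)
{
    const int cacheLine = 64;                 /* typical assumption */
    return (cacheLine - headerBytes) / (2 * entryBytes); /* lengths + strides */
}
```

For example, with only the reference and metadata pointer as the header (16 bytes) and 8-byte nint entries, this gives (64 - 16) / 16 = 3 inline dimensions.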

dotnet-policy-service[bot] commented 1 month ago

Tagging subscribers to this area: @dotnet/area-system-numerics-tensors See info in area-owners.md if you want to be subscribed.

tannergooding commented 1 month ago

Given the constraints described above, if we were to track the data inline using nint then we can track up to 3 dimensions on a 64-bit system:

public ref struct SpanND<T>
{
    private ref T _reference;       // 8 bytes
    private nint[]? _metadata;      // 8 bytes

    private fixed nint _lengths[3]; // 3x8 bytes (24)
    private fixed nint _strides[3]; // 3x8 bytes (24)
}

While this covers the most frequent dimension counts (1, 2, and 3), it excludes other dimension counts (4 and 5) that are encountered less frequently (it depends on the domain). It also doesn't allow tracking the underlying FlattenedLength, which must instead be computed dynamically, nor the Rank, so some schema must be defined to determine how many dimensions exist (likely by stopping at the first _lengths entry that is 0).
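That schema might look like the following illustrative C sketch (hypothetical helper names; note it assumes no dimension can legitimately have length 0, since 0 is reserved as the terminator):

```c
#include <stddef.h>

#define MAX_RANK 3 /* matches the 3-dimension inline layout above */

/* Infer the rank by stopping at the first zero-valued length entry.
 * Caveat: a genuinely empty dimension cannot be represented, since
 * a length of 0 is interpreted as "no more dimensions". */
static int infer_rank(const ptrdiff_t lengths[MAX_RANK])
{
    int rank = 0;
    while (rank < MAX_RANK && lengths[rank] != 0)
        rank++;
    return rank;
}

/* Compute FlattenedLength dynamically as the product of the lengths. */
static ptrdiff_t flattened_length(const ptrdiff_t lengths[MAX_RANK])
{
    ptrdiff_t total = 1;
    for (int i = 0; i < MAX_RANK && lengths[i] != 0; i++)
        total *= lengths[i];
    return total;
}
```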


If we were to track the FlattenedLength and Rank explicitly, we end up going down to 2 dimensions, which is undesirable. However, if we say that we only need to avoid allocations for common cases, and that having more than 2.14 billion (int.MaxValue) elements in a single dimension or single tensor is uncommon, then we could define something like:

public ref struct SpanND<T>
{
    private ref T _reference;       // 8 bytes
    private nint[]? _metadata;      // 8 bytes

    private int _flattenedLength;   // 4 bytes
    private int _rank;              // 4 bytes

    private fixed int _lengths[5];  // 5x4 bytes (20)
    private fixed int _strides[5];  // 5x4 bytes (20)
}

This allows tracking up to 5 dimensions, plus all the relevant information for the common case, without allocating. Whether the inline data is in use can be trivially determined by checking whether _metadata is null. This is what I believe to be the overall best approach and likely what we should pursue first.
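That null-check dispatch can be sketched in C as follows (an illustrative model, not the proposed API; in the actual C# type the test would be `_metadata is null`):

```c
#include <stddef.h>

#define INLINE_RANK 5

/* Illustrative model of the proposed layout: the inline buffers serve
 * the common case, and metadata is only allocated for rank > 5 (or
 * when lengths overflow int). */
struct spannd_sketch {
    void      *reference;
    ptrdiff_t *metadata;           /* NULL => inline data is in use */
    int        flattenedLength;
    int        rank;
    int        lengths[INLINE_RANK];
    int        strides[INLINE_RANK];
};

/* Dispatch on whether the dimension data lives inline or on the heap. */
static ptrdiff_t get_length(const struct spannd_sketch *s, int dim)
{
    if (s->metadata == NULL)
        return s->lengths[dim];    /* common case: no allocation */
    return s->metadata[dim];       /* rare case: lengths precede strides */
}
```

The single null check keeps the fast path branch-predictable, while still allowing arbitrary ranks to fall back to the allocated representation.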

One alternative layout would be to reduce this to tracking 4 dimensions so that _flattenedLength can be nint, which also gives us space to track explicit flags for any optimizations or other work that is desired (although many such flags could also exist in the prior layout by using one's complement negatives, as is done for fromEnd in Index):

public ref struct SpanND<T>
{
    private ref T _reference;       // 8 bytes
    private nint[]? _metadata;      // 8 bytes

    private nint _flattenedLength;  // 8 bytes
    private int _rank;              // 4 bytes
    private int _flags;             // 4 bytes

    private fixed int _lengths[4];  // 4x4 bytes (16)
    private fixed int _strides[4];  // 4x4 bytes (16)
}
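The one's complement trick mentioned above (the scheme System.Index uses for fromEnd) can be sketched like this: since the tracked values are non-negative, a boolean flag can be packed into the sign bit by storing ~value, with no extra field needed.

```c
#include <stdbool.h>

/* Pack a boolean flag into a non-negative value's sign bit by storing
 * the one's complement, as System.Index does for fromEnd. */
static int encode(int value, bool flagged)
{
    return flagged ? ~value : value; /* value must be non-negative */
}

static bool is_flagged(int encoded)
{
    return encoded < 0; /* sign bit doubles as the flag */
}

static int decode(int encoded)
{
    return encoded < 0 ? ~encoded : encoded;
}
```

Encoding 5 with the flag set yields ~5 == -6, and decoding -6 recovers 5, so the full non-negative range stays representable in both flagged and unflagged forms.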

It's worth noting that any of the approaches that track inline data to avoid allocations will push TensorSpan&lt;T&gt; over the Framework Design Guidelines (FDG) recommended 24-byte maximum for struct types. This is expected, since we're in a somewhat specialized domain where we are intentionally avoiding allocations for common scenarios. However, it will entail passing these types via in to avoid the implicit copy overhead for such large types.