I am thinking of adding another sharded tensor type: one that represents a sharded tensor whose shards must be reduced (summed) to obtain the actual tensor.
As I understand it, the point is to reorder the summation so that it happens after a subsequent operation, as shown here.
This information should be carried in the type so that downstream ops can dispatch on it and decide whether to sum before or after the op.
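
To make the idea concrete, here is a minimal sketch of what I have in mind, in plain Python with lists standing in for tensors. All names (`UnreducedTensor`, `reduce_shards`, `scale`, `relu`) are hypothetical and not part of any existing API. The assumption is that a linear op can be applied per shard and the sum deferred, while a nonlinear op has to reduce first.

```python
from dataclasses import dataclass
from typing import List

Vector = List[float]


@dataclass
class UnreducedTensor:
    """Hypothetical type: the actual tensor is the elementwise sum of `shards`."""
    shards: List[Vector]


def reduce_shards(t: UnreducedTensor) -> Vector:
    """Sum the shards elementwise to materialize the actual tensor."""
    return [sum(vals) for vals in zip(*t.shards)]


def scale(x, factor: float):
    """Linear op: it commutes with the sum, so the reduction can stay deferred."""
    if isinstance(x, UnreducedTensor):
        return UnreducedTensor([[factor * v for v in s] for s in x.shards])
    return [factor * v for v in x]


def relu(x):
    """Nonlinear op: it does not commute with the sum, so reduce first."""
    if isinstance(x, UnreducedTensor):
        x = reduce_shards(x)
    return [max(0.0, v) for v in x]


# Two partial results whose sum is the actual tensor [4.0, 3.0].
partial = UnreducedTensor([[1.0, -2.0], [3.0, 5.0]])

scaled = scale(partial, 2.0)   # still unreduced; summation happens later
activated = relu(partial)      # reduced before the op is applied
```

Here `scale` returns another `UnreducedTensor`, so the summation can be pushed further downstream, whereas `relu` has to sum up front: applying it per shard would give `[4.0, 5.0]` instead of `[4.0, 3.0]`. Because each op inspects the type, the choice of summing before or after is made per op.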