Closed: sogartar closed this 2 weeks ago.
Add one case of resnet block sharding: ResnetBlock2DSplitOutputChannelsSharding.
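As a rough illustration of why splitting a convolution along its output channels works, here is a minimal pure-Python sketch. It is not the implementation of ResnetBlock2DSplitOutputChannelsSharding. It models only a 1x1 convolution on a feature map stored as [channels][pixels]. The helpers conv1x1, split_output_channels and sharded_conv1x1 are hypothetical. Each shard receives the full (replicated) input and a contiguous block of output-channel weights, so concatenating the shard outputs along the channel axis reproduces the unsharded result.

```python
from typing import List

# Hypothetical layout: a feature map is [channels][pixels] (spatial dims flattened).
FeatureMap = List[List[float]]
Weight = List[List[float]]  # [out_channels][in_channels]


def conv1x1(x: FeatureMap, w: Weight) -> FeatureMap:
    """1x1 convolution: each output channel is a weighted sum of the input channels."""
    num_pixels = len(x[0])
    return [
        [sum(w_oc[ic] * x[ic][p] for ic in range(len(x))) for p in range(num_pixels)]
        for w_oc in w
    ]


def split_output_channels(w: Weight, shard_count: int) -> List[Weight]:
    """Split the weight into contiguous blocks of output channels, one per device."""
    out_channels = len(w)
    assert out_channels % shard_count == 0, "sketch assumes an even split"
    step = out_channels // shard_count
    return [w[i * step:(i + 1) * step] for i in range(shard_count)]


def sharded_conv1x1(x: FeatureMap, w: Weight, shard_count: int) -> List[FeatureMap]:
    """Every shard sees the whole input and produces only its own output channels."""
    return [conv1x1(x, shard) for shard in split_output_channels(w, shard_count)]
```

Because no shard needs another shard's output channels, the per-device outputs stay split along the channel dimension until something downstream needs the full tensor.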
Move AnyTensor to avoid cyclic dependencies.
Add broadcast_dim(s) functions and handle broadcasting for element-wise binary ops on split sharded tensors.
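The core bookkeeping behind broadcast-aware sharding can be sketched as follows, assuming NumPy/PyTorch broadcasting rules (shapes are right-aligned). The signatures here are hypothetical and may differ from the broadcast_dim(s) functions added in this change. When a rank-r operand is broadcast against operands of maximum rank R, its dimension i lands at result dimension i + (R - r). A split tensor's split dimension therefore has to be shifted accordingly in the result.

```python
from typing import List, Optional, Sequence


def broadcast_dim(dim: int, ranks: Sequence[int]) -> List[Optional[int]]:
    """Map a (non-negative) dimension of the broadcast result back to each operand.

    Returns None for operands that do not have that dimension at all
    (they are implicitly broadcast along it).
    """
    max_rank = max(ranks)
    mapped: List[Optional[int]] = []
    for rank in ranks:
        d = dim - (max_rank - rank)
        mapped.append(d if d >= 0 else None)
    return mapped


def result_dim(operand_dim: int, operand_rank: int, ranks: Sequence[int]) -> int:
    """Inverse direction: where an operand's dimension ends up in the result."""
    return operand_dim + (max(ranks) - operand_rank)
```

For example, adding a rank-1 tensor split along dimension 0 to a rank-3 tensor yields a result whose split dimension is 2, because result_dim(0, 1, [3, 1]) is 2.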
Make the sharded linear override apply to the Cartesian product of the type sets. Let the constituent ops fail if they cannot handle the input.
Add matmul for replicated LHS and split RHS.
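This sketch shows why a replicated LHS times a split RHS needs no communication. It assumes the RHS is split along its last (column) dimension, so the result comes out split along its last dimension too. That is an assumption for illustration, and the actual op may support other split dimensions. All helper names are hypothetical, and matrices are plain nested lists.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    inner = len(b)
    assert all(len(row) == inner for row in a), "inner dimensions must match"
    cols = len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)] for row in a]


def split_columns(m: Matrix, shard_count: int) -> List[Matrix]:
    """Split a matrix into equal column blocks, one per device."""
    cols = len(m[0])
    assert cols % shard_count == 0, "sketch assumes an even split"
    step = cols // shard_count
    return [[row[i * step:(i + 1) * step] for row in m] for i in range(shard_count)]


def concat_columns(shards: List[Matrix]) -> Matrix:
    """Glue column blocks back together row by row."""
    return [sum((shard[r] for shard in shards), []) for r in range(len(shards[0]))]


def matmul_replicated_split(lhs: Matrix, rhs_shards: List[Matrix]) -> List[Matrix]:
    # Each device already holds the full LHS, so device i simply multiplies
    # it by its own column block of the RHS; the result stays column-split.
    return [matmul(lhs, shard) for shard in rhs_shards]
```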
Fix resharding of thetas.
Add unshard op.
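One plausible meaning of an unshard op is sketched below: turn a sharded tensor back into a single ordinary tensor. A split tensor is concatenated along its split dimension, and a replicated tensor just returns one of its identical copies. The classes and helper are hypothetical, with nested Python lists standing in for device tensors.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class SplitTensor:
    shards: List[Any]  # nested lists, one per device
    split_dim: int


@dataclass
class ReplicatedTensor:
    shards: List[Any]  # identical copies, one per device


def _concat(parts: List[Any], dim: int) -> Any:
    """Concatenate nested lists along `dim` (0 is the outermost dimension)."""
    if dim == 0:
        return [item for part in parts for item in part]
    # Pair up matching sub-lists and recurse until the target dim is reached.
    return [_concat(list(sub), dim - 1) for sub in zip(*parts)]


def unshard(t: Any) -> Any:
    if isinstance(t, SplitTensor):
        return _concat(t.shards, t.split_dim)
    if isinstance(t, ReplicatedTensor):
        return t.shards[0]
    return t  # already an ordinary tensor
```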
Fix slicing of split tensors.
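The text does not say what the slicing bug was. One detail any split-tensor slicing has to get right is where the split dimension ends up after indexing, since integer indices drop dimensions. This hypothetical helper only handles keys made of ints and slices. It rejects keys that cut across the split dimension, because only then can each shard apply the key locally.

```python
from typing import Sequence, Union

Index = Union[int, slice]


def split_dim_after_indexing(split_dim: int, key: Sequence[Index]) -> int:
    """Return the split dimension of `tensor[key]` for a split tensor.

    Assumes `key` contains only ints and slices (no None or Ellipsis).
    """
    if split_dim < len(key):
        k = key[split_dim]
        if isinstance(k, int) or k != slice(None):
            raise NotImplementedError("indexing across the split dimension")
    # Every integer index before the split dim removes one leading dimension.
    removed_before = sum(1 for k in key[:split_dim] if isinstance(k, int))
    return split_dim - removed_before
```

Once the key is accepted, every shard applies the same key locally and the result is a split tensor along the returned dimension.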
Refactor Theta to reflect that it is actually a tree of tensors and expose the tree. Fix some asserts during construction.
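Treating a Theta as a tree of tensors makes whole-model transforms such as resharding simple leaf maps. The sketch below uses nested dicts with tensors at the leaves. The functions flatten and map_leaves are hypothetical and do not reflect the Theta API exposed by this change.

```python
from typing import Any, Callable, Dict, Iterator, Tuple

Tree = Dict[str, Any]  # inner nodes are dicts, leaves are tensors


def flatten(tree: Tree, prefix: str = "") -> Iterator[Tuple[str, Any]]:
    """Yield (dotted_name, leaf) pairs in insertion order."""
    for key, value in tree.items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            yield from flatten(value, name)
        else:
            yield name, value


def map_leaves(fn: Callable[[Any], Any], tree: Tree) -> Tree:
    """Apply `fn` to every leaf, keeping the tree structure unchanged."""
    return {k: map_leaves(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items()}
```

Under this model, resharding a theta amounts to map_leaves with a per-tensor resharding function.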