The API for this should look like `net.split_edge(edge, shape)`.
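A rough sketch of how the proposed call might be used. The node/edge setup assumes the pre-1.0 `TensorNetwork` graph API (`net.add_node` / `net.connect`); `split_edge` itself is the proposal here, so its return value is only illustrative:

```python
import numpy as np
import tensornetwork as tn

net = tn.TensorNetwork()
a = net.add_node(np.ones((6, 4)))
b = net.add_node(np.ones((4, 6)))
net.connect(a[1], b[0])  # shared edge of dimension 4

# Proposed call: split the dangling dimension-6 edge of `a` into two
# edges of dimensions 2 and 3 (a's tensor would be reshaped to (2, 3, 4)).
new_edges = net.split_edge(a[0], (2, 3))
```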
Could you explain how this will help distributed computations? I would have thought "slice_edge" (direct-sum decomposition of index space) and "concat_edges()" would be more useful for splitting contractions across devices.
It's a TPU specific optimization, and I'm not sure I'm allowed to explain the reasoning here. :/
Note that this is different from https://github.com/google/TensorNetwork/issues/260 in that `split_edge` would turn 1 edge into multiple edges, whereas `slice_edge` would take 1 edge of dimension X and cut it so that it only has dimension Y. The underlying tensors connected to the edge would be sliced by `slice_edge`, whereas with `split_edge` they would only be reshaped.
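To make the distinction concrete, here is a plain-numpy sketch of what would happen to the underlying tensor in each case (axis placement and sizes are just for illustration):

```python
import numpy as np

t = np.arange(24).reshape(6, 4)  # tensor carrying a dimension-6 edge on axis 0

# split_edge: one dim-6 edge becomes two edges of dims (2, 3); the tensor
# is only reshaped, no data is lost.
split = t.reshape(2, 3, 4)

# slice_edge (issue #260): the dim-6 edge is cut down to, say, dimension 2;
# the underlying tensor is actually sliced and part of the data is discarded.
sliced = t[:2, :]
```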
I'll have a go at this.
@Thenerdstation What's the convention for handling shapes in the codebase? I've found `Tuple[int]` being used in the constructor of `BaseNode`, but the `flatten_edges` code uses backend tensors that then get concatenated to construct the new shape for reshaping. Writing tests for my `split_edge` implementation (which accepts `shape: Tuple[int]`), I'm seeing that the `reshape` wrappers of the `pytorch` and `tensorflow` backends accept tuples, whereas `numpy` and `jax` do not (because they call `.astype` on the shape, which fails on a tuple). Should this be made consistent? Or do the backends need helper functions to convert tuples to and from whatever backend tensor is preferred for shapes (like `LongTensor` for PyTorch)?
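For reference, a minimal reproduction of the inconsistency being described (the backend wrappers call `.astype` on the shape argument, which exists on arrays and backend tensors but not on Python tuples):

```python
import numpy as np

shape = (2, 3, 4)

# Numpy arrays (and backend "shape tensors") have .astype, so this is fine:
np.asarray(shape).astype(np.int32)

# A plain Python tuple does not, which is what the numpy/jax reshape
# wrappers currently trip over:
try:
    shape.astype(np.int32)
except AttributeError as err:
    print(err)  # 'tuple' object has no attribute 'astype'
```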
Aw yes, the dreaded tensor as a shape issue.
This weirdness is explicitly just a quirk of TensorFlow. The reason we do this is that sometimes the shape of a tensor is unknown at construction time, so the only way to know the actual shape at run time is to store the shape as a tensor. It seems super janky, but trust me, we put it in for a reason.
I'm surprised numpy and jax are the ones giving you the issue. What's your current implementation and how is it breaking for them?
The implementation is pretty similar to flatten, except you reshape with the user-specified shape instead of the product of the to-be-flattened dimensions. Like I said, numpy and jax complain because you cannot feed tuples to their `reshape` wrappers: the numpy backend expects a numpy array (https://github.com/google/TensorNetwork/blob/acf77a14d08d8c1e355d02d81c1dc14da7c78d60/tensornetwork/backends/numpy/numpy_backend.py#L35) and will throw `AttributeError: 'tuple' object has no attribute 'astype'`.
I think casting to a numpy array in `network_components.py` is a bit ugly for what should be backend-agnostic code. Maybe we should add a numpy `asarray` before the typecast to integer? That's what the PyTorch backend does in https://github.com/google/TensorNetwork/blob/acf77a14d08d8c1e355d02d81c1dc14da7c78d60/tensornetwork/backends/pytorch/pytorch_backend.py#L44
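A minimal, standalone sketch of the suggested change. The real method lives on the numpy backend class in the linked file, so the signature here is simplified:

```python
import numpy as np

def reshape(tensor, shape):
  # Suggested fix: np.asarray accepts tuples, lists and ndarrays alike,
  # so the integer cast no longer chokes on a plain Python tuple.
  return np.reshape(tensor, np.asarray(shape).astype(np.int32))

x = np.ones((6, 4))
reshape(x, (2, 3, 4))            # now works with a tuple...
reshape(x, np.array([2, 3, 4]))  # ...as well as with an array "shape tensor"
```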
Oh I see what you mean now. Yes that is totally an oversight.
Go ahead and change that to `np.asarray(shape).astype(...)`.
Good catch btw! I'm surprised we didn't hit that earlier.
Add an option that will split an edge into multiple edges. This plus `flatten_edges` are needed for certain distributed hardware optimizations.