google / TensorNetwork

A library for easy and efficient manipulation of tensor networks.

Add split_edge method. #150

Closed chaserileyroberts closed 5 years ago

chaserileyroberts commented 5 years ago

Add an option that will split an edge into multiple edges. This, plus flatten_edges, is needed for certain distributed hardware optimizations.

chaserileyroberts commented 5 years ago

The API for this should look like net.split_edge(edge, shape)
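
For illustration, a rough sketch of how the proposed call might look, assuming the TensorNetwork-object API of the time (add_node/connect); the split_edge signature is the one proposed above, not an existing implementation:

```python
# Illustrative sketch only: assumes the proposed net.split_edge(edge, shape)
# call on the TensorNetwork object. Splitting a dimension-6 edge into a
# (2, 3) pair of edges reshapes the underlying tensors without moving data.
import numpy as np
import tensornetwork as tn

net = tn.TensorNetwork()
a = net.add_node(np.random.rand(6, 4))
b = net.add_node(np.random.rand(6, 4))
edge = net.connect(a[0], b[0])  # shared edge of dimension 6

new_edges = net.split_edge(edge, shape=(2, 3))  # proposed call
# a and b would now each have shape (2, 3, 4), connected by two edges.
```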

amilsted commented 5 years ago

Could you explain how this will help distributed computations? I would have thought "slice_edge" (direct-sum decomposition of index space) and "concat_edges()" would be more useful for splitting contractions across devices.

chaserileyroberts commented 5 years ago

It's a TPU specific optimization, and I'm not sure I'm allowed to explain the reasoning here. :/

chaserileyroberts commented 5 years ago

Note that this is different from https://github.com/google/TensorNetwork/issues/260 in that split_edge would turn 1 edge into multiple edges. slice_edge would take 1 edge of dimension X and cut it so that it only has dimension Y. The underlying tensors connected to the edge in slice_edge would be sliced, whereas with split_edge they would only be reshaped.
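
A minimal plain-numpy illustration of the distinction (not TensorNetwork code): split_edge corresponds to a reshape of the underlying tensors, slice_edge to an actual slice that discards data.

```python
import numpy as np

t = np.random.rand(6, 4)   # tensor with a dimension-6 edge

# split_edge-style: the dimension-6 edge becomes two edges of dimensions 2
# and 3. No data is discarded; this is a pure reshape.
split = t.reshape(2, 3, 4)

# slice_edge-style: the dimension-6 edge is cut down to dimension 2.
# Entries outside the slice are discarded.
sliced = t[:2, :]

print(split.shape)   # (2, 3, 4)
print(sliced.shape)  # (2, 4)
```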

mcbal commented 5 years ago

I'll have a go at this.

mcbal commented 5 years ago

@Thenerdstation What's the convention for handling shapes in the codebase? I've found Tuple[int] used in the constructor of BaseNode, but the flatten_edges code uses backend tensors that then get concatenated to construct the new shape for reshaping. While writing tests for my split_edge implementation (which accepts shape: Tuple[int]), I noticed that the reshape wrappers of the pytorch and tensorflow backends accept tuples, whereas those of numpy and jax do not (the numpy .astype tries to cast the tuple immediately). Should this be made consistent? Or do the backends need helper functions to convert tuples to and from whatever backend tensor is preferred for shapes (like LongTensor for PyTorch)?

chaserileyroberts commented 5 years ago

Aw yes, the dreaded tensor as a shape issue.

This weirdness is explicitly just a quirk of TensorFlow. The reason we do this is that sometimes the shape of a tensor is unknown at construction time, so the only way to know the actual shape at run time is to store the shape as a tensor. It seems super janky, but trust me, we put it in for a reason.
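
A minimal illustration of that quirk (TF1-style graph mode, not library code): when a dimension is unknown at construction time, the concrete shape only exists as a tensor evaluated at run time.

```python
# Sketch of why shapes are sometimes carried around as tensors.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=[None, 4])  # leading dim unknown

static_shape = x.shape       # TensorShape([None, 4]); first entry unknown
dynamic_shape = tf.shape(x)  # a tensor; its value exists only at run time
```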

I'm surprised numpy and jax are the ones giving you the issue. What's your current implementation and how is it breaking for them?

mcbal commented 5 years ago

Implementation is pretty similar to flatten, but you reshape with the user-specified shape instead of the product of the to-be-flattened dimensions. Like I said, numpy and jax complain because you cannot feed tuples to their reshape wrappers: the numpy backend expects a numpy array (https://github.com/google/TensorNetwork/blob/acf77a14d08d8c1e355d02d81c1dc14da7c78d60/tensornetwork/backends/numpy/numpy_backend.py#L35) and will throw AttributeError: 'tuple' object has no attribute 'astype'.
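
A minimal reproduction of the failure mode; the wrapper below is paraphrased from the linked numpy backend code, not copied verbatim:

```python
import numpy as np

def reshape(tensor, shape):
    # The backend assumes `shape` is already an ndarray, so .astype exists.
    # A plain tuple has no .astype method.
    return np.reshape(tensor, shape.astype(np.int32))

t = np.random.rand(6, 4)
reshape(t, np.array([2, 3, 4]))  # works: ndarray shape
reshape(t, (2, 3, 4))  # AttributeError: 'tuple' object has no attribute 'astype'
```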

I think casting shapes to numpy arrays in network_components.py is a bit ugly for what should be backend-agnostic code. Maybe we should add an np.asarray call before the typecast to integer? That's what the PyTorch backend does in https://github.com/google/TensorNetwork/blob/acf77a14d08d8c1e355d02d81c1dc14da7c78d60/tensornetwork/backends/pytorch/pytorch_backend.py#L44

chaserileyroberts commented 5 years ago

Oh I see what you mean now. Yes that is totally an oversight.

Go ahead and change that to np.asarray(shape).astype(...).

Good catch btw! I'm surprised we didn't hit that earlier.
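
A hedged sketch of the agreed fix for the numpy backend's reshape wrapper: np.asarray is a no-op for ndarrays and converts tuples or lists, so both shape representations are accepted.

```python
import numpy as np

def reshape(tensor, shape):
    # np.asarray makes the wrapper tolerant of tuples, lists, and ndarrays.
    return np.reshape(tensor, np.asarray(shape).astype(np.int32))

t = np.random.rand(6, 4)
reshape(t, (2, 3, 4))            # now works with a tuple
reshape(t, np.array([2, 3, 4]))  # still works with an ndarray
```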