Open chmda opened 1 year ago
Here is a first implementation that could solve this problem.
```python
import numpy as np
import tensap


def eval_diag(tensor_tree: tensap.TreeBasedTensor, dims: list[int] = None) -> tensap.FullTensor:
    if dims is None:
        dims = set(range(tensor_tree.order))
    else:
        dims = set(dims)
    tree = tensor_tree.tree
    tensors = [full_tensor.numpy() for full_tensor in tensor_tree.tensors]
    # Ensure the root tensor has a trailing (rank) mode.
    if len(tree.children(tree.root)) == tensors[tree.root - 1].ndim:
        tensors[tree.root - 1] = tensors[tree.root - 1][..., None]
    assert tensors[tree.root - 1].ndim == len(tree.children(tree.root)) + 1
    diag_size = min(tensor_tree.shape[dim] for dim in dims)
    for level in reversed(range(max(tree.level) + 1)):
        for node in tree.nodes_with_level(level):
            children = tree.children(node)
            if tree.is_leaf[node - 1]:
                assert tensors[node - 1].ndim == 2
                assert len(tree.dims[node - 1]) == 1
                diag_indices = [0] if tree.dims[node - 1][0] in dims else []
            else:
                for child in children:
                    if not tensor_tree.is_active_node[child - 1]:
                        assert len(tree.dims[child - 1]) == 1
                diag_indices = [
                    index
                    for index, child in enumerate(children)
                    if not tensor_tree.is_active_node[child - 1] and tree.dims[child - 1][0] in dims
                ]
            if len(diag_indices) == 0:
                tensors[node - 1] = tensors[node - 1][None]
            else:
                # Move the diagonal axes to the front and index them jointly.
                tensors[node - 1] = np.moveaxis(tensors[node - 1], diag_indices, np.arange(len(diag_indices)))
                tensors[node - 1] = tensors[node - 1][(np.arange(diag_size),) * len(diag_indices)]
            for child in children:
                if tensor_tree.is_active_node[child - 1]:
                    # tensors[node - 1] = np.einsum("ij...r, i...j -> i...r", tensors[node - 1], tensors[child - 1])
                    node_order = tensors[node - 1].ndim
                    child_order = tensors[child - 1].ndim
                    node_indices = list(range(node_order))
                    child_indices = [0] + list(range(node_order, node_order + child_order - 2)) + [1]
                    out_indices = [0] + node_indices[2:-1] + child_indices[1:-1] + [node_indices[-1]]
                    tensors[node - 1] = np.einsum(
                        tensors[node - 1], node_indices, tensors[child - 1], child_indices, out_indices
                    )
                else:
                    # tensors[node - 1] = np.einsum("ij...r -> i...jr", tensors[node - 1])
                    tensors[node - 1] = np.moveaxis(tensors[node - 1], 1, -2)
    if len(tree.children(tree.root)) == tensor_tree.tensors[tree.root - 1].ndim:
        tensors[tree.root - 1] = tensors[tree.root - 1][..., 0]
    return tensap.FullTensor(tensors[tree.root - 1])
```

Note: the assertion on inactive children originally read `len(tree.dims[child - 1][0]) == 1`, which takes the length of an integer and would raise a `TypeError`; it has been corrected to `len(tree.dims[child - 1]) == 1`, matching the leaf-node assertion above.
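For reference, the partial-diagonal extraction at the core of the loop (move the diagonal axes to the front with `np.moveaxis`, then index them all with the same range) can be illustrated on a plain NumPy array; this sketch is independent of tensap:

```python
import numpy as np

# A tensor with three modes; extract the joint diagonal over modes 1 and 2.
rng = np.random.default_rng(0)
t = rng.standard_normal((4, 5, 5))

diag_indices = [1, 2]  # axes to diagonalize
diag_size = min(t.shape[i] for i in diag_indices)

# Move the diagonal axes to the front, then index them with the same range;
# broadcasting of the identical index arrays selects the diagonal entries.
moved = np.moveaxis(t, diag_indices, np.arange(len(diag_indices)))
diag = moved[(np.arange(diag_size),) * len(diag_indices)]

# diag[i, r] == t[r, i, i]
print(diag.shape)  # (5, 4)
```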
A drawback of this implementation is that it always returns a full tensor. It would hence be sensible to automatically apply this algorithm only to the sub-tree rooted at the node `tree.node_with_dims(dims)`.

Note: I am not sure whether my interpretation of `tree.is_leaf` and `tensor_tree.is_active_node` is correct, which is why I added the `assert` statements.
Thank you Philipp. A version has been pushed for the case where `dims` corresponds to a node of the tree. Your version could be merged later if we need this specific case. The fact that it returns a `FullTensor` limits its use to low-dimensional problems, or to cases where we evaluate most of the dimensions.
An exception is raised in `tensap.TreeBasedTensor.eval_diag` when trying to call `tensap.TreeBasedTensor.eval_at_indices` with arguments `indices` of shape `(N, M)` and `dims` of shape `(M,)`, with `M > 1`.

**Steps to reproduce**
In the implementation of `tensap.TreeBasedTensor.eval_at_indices`, it calls `tensap.TreeBasedTensor.eval_diag` with an argument `dims` that is not `None`. However, `eval_diag` with non-empty `dims` is not implemented yet, see tree_based_tensor.py#L1199.
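To make the failing case concrete: with `indices` of shape `(N, M)` and `dims` of shape `(M,)`, the expected result evaluates the tensor jointly at each row of `indices` (a partial diagonal), keeping the remaining dimensions free. A NumPy sketch of that semantics on a dense stand-in array (this does not call tensap; shapes and names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(2)
T = rng.standard_normal((6, 3, 4))   # dense stand-in for the tree-based tensor

dims = [1, 2]                        # M = 2 dimensions evaluated jointly
indices = np.array([[0, 1],          # N = 3 joint index tuples, shape (N, M)
                    [2, 0],
                    [1, 3]])

# Pointwise (diagonal) evaluation over `dims`: out[:, n] = T[:, indices[n, 0], indices[n, 1]].
out = T[:, indices[:, 0], indices[:, 1]]
print(out.shape)  # (6, 3)
```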