anthony-nouy / tensap

Tensor Approximation Package: a Python package for the approximation of functions and tensors.
https://anthony-nouy.github.io/sphinx/tensap/master/

'eval_diag' with non-empty 'dims' is not implemented for TreeBasedTensor #38

Open chmda opened 1 year ago

chmda commented 1 year ago

An exception is raised in tensap.TreeBasedTensor.eval_diag when calling tensap.TreeBasedTensor.eval_at_indices with an indices argument of shape (N, M) and a dims argument of length M, with M > 1.

Steps to reproduce

import numpy as np
import tensap

dim_tree = tensap.DimensionTree.balanced(5)
tensor = tensap.TreeBasedTensor.randn(dim_tree, shape=(100,100,100,100,100))
indices = np.random.randint(100, size=(100, 2))
dims = [0,1]
tensor.eval_at_indices(indices,dims)
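To make the expected behaviour concrete, here is a dense NumPy reference for what eval_at_indices(indices, dims) presumably should return: row n of indices fixes the modes listed in dims, and the other modes stay free. This is a sketch, not tensap code; eval_at_indices_dense is a hypothetical helper, and the output layout (N, remaining modes in their original order) is an assumption that may differ from tensap's.

```python
import numpy as np


def eval_at_indices_dense(tensor: np.ndarray, indices: np.ndarray, dims: list[int]) -> np.ndarray:
    """Hypothetical dense reference for evaluating a tensor at partial indices.

    Row n of `indices` fixes the modes listed in `dims`; the other modes stay
    free. Returns an array of shape (N, *remaining_shape), with the remaining
    modes in their original order (an assumed output layout).
    """
    indices = np.asarray(indices)
    # Move the evaluated modes to the front, in the order given by `dims`.
    moved = np.moveaxis(tensor, dims, list(range(len(dims))))
    # One index array per evaluated mode: advanced indexing pairs them row by
    # row and produces a leading axis of length N.
    return moved[tuple(indices[:, k] for k in range(len(dims)))]


rng = np.random.default_rng(0)
t = rng.standard_normal((4, 5, 6, 7))
idx = rng.integers(0, 4, size=(3, 2))  # N = 3 rows, M = 2 evaluated modes
out = eval_at_indices_dense(t, idx, [0, 1])
print(out.shape)  # (3, 6, 7)
```

For the reproduction above, the same reference applied to the full 100^5 array would return an array of shape (100, 100, 100, 100).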

tensap.TreeBasedTensor.eval_at_indices calls tensap.TreeBasedTensor.eval_diag and passes dims through, so eval_diag receives a dims argument that is not None. However, eval_diag with non-empty dims is not implemented yet, see tree_based_tensor.py#L1199.
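Why evaluating at indices leads to a diagonal evaluation can be illustrated on a dense array: restricting each mode in dims to the N requested indices gives a tensor in which each of those modes has length N, and the requested values are its diagonal across those modes. The sketch below only checks this identity with plain NumPy; it illustrates the idea and is not a description of tensap's internal code.

```python
import numpy as np

rng = np.random.default_rng(1)
tensor = rng.standard_normal((6, 7, 8))
indices = rng.integers(0, 6, size=(5, 2))  # N = 5 rows, evaluated modes dims = [0, 1]
n = indices.shape[0]

# Step 1: restrict each evaluated mode to the N requested indices.
# Modes 0 and 1 now both have length N.
restricted = np.take(tensor, indices[:, 0], axis=0)
restricted = np.take(restricted, indices[:, 1], axis=1)  # shape (5, 5, 8)

# Step 2: take the diagonal over the restricted modes: entry k uses row k
# of `indices` for every evaluated mode.
diag = restricted[np.arange(n), np.arange(n)]  # shape (5, 8)

# Direct evaluation for comparison.
direct = tensor[indices[:, 0], indices[:, 1]]  # shape (5, 8)
assert np.allclose(diag, direct)
```

In a tree-based format, step 1 amounts to selecting rows of the leaf factors of the modes in dims, which is presumably why the diagonal evaluation is the part left to eval_diag.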

ptrunschke commented 1 year ago

Here is a first implementation that could solve this problem.

import numpy as np
import tensap


def eval_diag(tensor_tree: tensap.TreeBasedTensor, dims: list[int] = None) -> tensap.FullTensor:
    if dims is None:
        dims = set(range(tensor_tree.order))
    else:
        dims = set(dims)

    tree = tensor_tree.tree
    tensors = [full_tensor.numpy() for full_tensor in tensor_tree.tensors]

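    # Give the root core an explicit trailing rank axis of size 1 if it has none,
    # so that every core has the layout (child modes..., rank).
    if len(tree.children(tree.root)) == tensors[tree.root - 1].ndim: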
        tensors[tree.root - 1] = tensors[tree.root - 1][..., None]
    assert tensors[tree.root - 1].ndim == len(tree.children(tree.root)) + 1

    diag_size = min(tensor_tree.shape[dim] for dim in dims)

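    # Process nodes level by level, from the leaves up to the root, so that each
    # child is fully contracted before its parent uses it.
    for level in reversed(range(max(tree.level) + 1)):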
        for node in tree.nodes_with_level(level):
            children = tree.children(node)

            if tree.is_leaf[node - 1]:
                assert tensors[node - 1].ndim == 2
                assert len(tree.dims[node - 1]) == 1
                diag_indices = [0] if tree.dims[node - 1][0] in dims else []
            else:
                for child in children:
                    if not tensor_tree.is_active_node[child - 1]:
                        assert len(tree.dims[child - 1]) == 1
                diag_indices = [
                    index
                    for index, child in enumerate(children)
                    if not tensor_tree.is_active_node[child - 1] and tree.dims[child - 1][0] in dims
                ]
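            if len(diag_indices) == 0:
                # No diagonal mode at this node: add a singleton leading axis so that
                # every processed tensor starts with the diagonal axis.
                tensors[node - 1] = tensors[node - 1][None]
            else:
                # Move the diagonal modes to the front and keep only their common
                # diagonal, which becomes the leading axis.
                tensors[node - 1] = np.moveaxis(tensors[node - 1], diag_indices, np.arange(len(diag_indices)))
                tensors[node - 1] = tensors[node - 1][(np.arange(diag_size),) * len(diag_indices)]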

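            for index, child in enumerate(children):
                if index in diag_indices:
                    # This leaf axis was already merged into the diagonal axis 0,
                    # so axis 1 now belongs to the next child.
                    continue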
                if tensor_tree.is_active_node[child - 1]:
                    # tensors[node - 1] = np.einsum("ij...r, i...j -> i...r", tensors[node - 1], tensors[child - 1])
                    node_order = tensors[node - 1].ndim
                    child_order = tensors[child - 1].ndim
                    node_indices = list(range(node_order))
                    child_indices = [0] + list(range(node_order, node_order + child_order - 2)) + [1]
                    out_indices = [0] + node_indices[2:-1] + child_indices[1:-1] + [node_indices[-1]]
                    tensors[node - 1] = np.einsum(
                        tensors[node - 1], node_indices, tensors[child - 1], child_indices, out_indices
                    )
                else:
                    # tensors[node - 1] = np.einsum("ij...r -> i...jr", tensors[node - 1])
                    tensors[node - 1] = np.moveaxis(tensors[node - 1], 1, -2)

    if len(tree.children(tree.root)) == tensor_tree.tensors[tree.root - 1].numpy().ndim:
        tensors[tree.root - 1] = tensors[tree.root - 1][..., 0]

    return tensap.FullTensor(tensors[tree.root - 1])
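The leaves-to-root contraction performed by eval_diag can be illustrated without tensap on a small hand-built tree tensor. The sketch assumes a hypothetical tree whose root has children (leaf 0, internal node) and whose internal node has children (leaf 1, leaf 2), with leaf factors of shape (n, rank) and cores ordered (child ranks..., own rank), mirroring the layout the code above assumes. It takes the diagonal over dims {1, 2}: the leaf factors of the diagonal modes are paired row by row inside their parent, while mode 0 keeps a free axis.

```python
import numpy as np

rng = np.random.default_rng(2)
n, r0, r1, r2, r3 = 4, 2, 3, 2, 3

# Hypothetical tree tensor: T[i, j, k] = sum U0[i,a] U1[j,b] U2[k,c] C[b,c,d] R[a,d]
u0 = rng.standard_normal((n, r0))         # leaf factor for mode 0
u1 = rng.standard_normal((n, r1))         # leaf factor for mode 1
u2 = rng.standard_normal((n, r2))         # leaf factor for mode 2
core = rng.standard_normal((r1, r2, r3))  # internal node over modes {1, 2}
root = rng.standard_normal((r0, r3))      # root core (rank-one root axis dropped)

dense = np.einsum("ia,jb,kc,bcd,ad->ijk", u0, u1, u2, core, root)

# Diagonal over dims {1, 2}: D[k, i] = T[i, k, k].
# Internal node: both children are diagonal modes, so their rows are paired by k.
node = np.einsum("kb,kc,bcd->kd", u1, u2, core)  # shape (n, r3)
# Root: mode 0 stays free, the internal node carries the diagonal axis k.
diag = np.einsum("ia,kd,ad->ki", u0, node, root)  # shape (n, n)

expected = np.stack([dense[:, k, k] for k in range(n)])
assert np.allclose(diag, expected)
```

The diagonal axis comes first and the free modes follow, which matches the layout of the FullTensor returned above.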

A drawback of this implementation is that it always returns a FullTensor, whose size grows exponentially with the number of modes not in dims. It would therefore be sensible to apply this algorithm automatically to the subtree rooted at tree.node_with_dims(dims) only.
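To quantify this drawback: the FullTensor built above has one diagonal axis of length min(shape[d] for d in dims) plus one full axis per remaining mode. A quick count for a tensor shaped like the one in the reproduction (five modes of size 100, dims = [0, 1]); full_diag_size is an illustrative helper, not part of tensap.

```python
from math import prod


def full_diag_size(shape: tuple[int, ...], dims: list[int]) -> int:
    """Number of entries in the FullTensor built by the proposed eval_diag.

    The result has one diagonal axis of length min(shape[d] for d in dims)
    and one free axis per mode not in `dims`.
    """
    dims_set = set(dims)
    diag_len = min(shape[d] for d in dims_set)
    free_sizes = [size for mode, size in enumerate(shape) if mode not in dims_set]
    return diag_len * prod(free_sizes)


entries = full_diag_size((100,) * 5, [0, 1])
print(entries, entries * 8 / 1e6, "MB in float64")
```

This gives 10**8 entries, about 800 MB in float64, for a single call; each additional free mode of size 100 multiplies that by 100.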

Note: I am not sure whether my interpretation of tree.is_leaf and tensor_tree.is_active_node is correct. This is why I added the assert statements.

anthony-nouy commented 1 year ago

Thank you Philipp. A version has been pushed for the case where dims corresponds to a node of the tree. Your version could be merged later if we need this case. Because it returns a FullTensor, its use is limited to low-dimensional problems, or to cases where most of the dimensions are evaluated.
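Whether the node-based version applies comes down to whether some node's set of dimensions equals dims exactly (tensap's tree.node_with_dims, mentioned earlier in the thread, performs this kind of lookup). The sketch below mimics the check on a hypothetical dict-based tree; the node numbering and the tree itself are illustrative only and do not reflect DimensionTree.balanced.

```python
from typing import Optional


def node_with_dims(node_dims: dict[int, frozenset[int]], dims: list[int]) -> Optional[int]:
    """Return the node whose set of dimensions equals `dims`, or None."""
    target = frozenset(dims)
    for node, node_set in node_dims.items():
        if node_set == target:
            return node
    return None


# Hypothetical 4-leaf balanced tree: root 1 -> {2, 3}, node 2 -> leaves {4, 5},
# node 3 -> leaves {6, 7}. Node numbering is illustrative only.
tree_dims = {
    1: frozenset({0, 1, 2, 3}),
    2: frozenset({0, 1}),
    3: frozenset({2, 3}),
    4: frozenset({0}),
    5: frozenset({1}),
    6: frozenset({2}),
    7: frozenset({3}),
}

print(node_with_dims(tree_dims, [1, 0]))  # 2: handled by a node-based evaluation
print(node_with_dims(tree_dims, [1, 2]))  # None: needs a general algorithm
```

With dims = [1, 2] no single node covers exactly those modes, which is the situation the more general eval_diag above is meant for.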