DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

Dynamical system on collections of tensors #162

Status: Open. SimonKitSangChu opened this issue 2 years ago.

SimonKitSangChu commented 2 years ago

A dynamical system can often be described not by a single tensor but by multiple ones. For example, a system of particles can have node features, edge features, and global features, each with different feature dimensions.

While each project could handle this on its own, has there already been any effort to support a collection of tensors as the state (input/output)? For example:

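from typing import Dict

import torch
from torchdyn.core import NeuralODE

# Proposed interface: a vector field mapping a dict of tensors to a dict of tensors
def f(x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    ...  # compute out_node, out_edge, out_global from x
    return {'node': out_node, 'edge': out_edge, 'global': out_global}

x = {'node': x_node, 'edge': x_edge, 'global': x_global}
model = NeuralODE(f)  # dict-valued states are the feature requested here
out = model(x)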

Alternatively, the tensors can be flattened and concatenated into a single state tensor, then split back into their components inside f. This becomes non-trivial when the components have different shapes, or when the state is itself a graph object.
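A minimal sketch of the concatenation workaround in plain PyTorch, assuming every component shares a leading batch dimension and that f returns a dict with the same keys and shapes as its input. The helpers `pack`, `unpack`, and `make_flat_field` are hypothetical names, not torchdyn API; the wrapped field takes `(t, x)` so it can be passed to a solver that expects a single flat tensor.

```python
import math
from typing import Callable, Dict, List

import torch

Tensors = Dict[str, torch.Tensor]


def pack(state: Tensors, keys: List[str]) -> torch.Tensor:
    # Flatten each component past the batch dimension and concatenate: (batch, total).
    return torch.cat([state[k].flatten(start_dim=1) for k in keys], dim=1)


def unpack(flat: torch.Tensor, keys: List[str], shapes: Dict[str, torch.Size]) -> Tensors:
    # Split the flat state back into components and restore each per-sample shape.
    sizes = [math.prod(shapes[k]) for k in keys]
    chunks = torch.split(flat, sizes, dim=1)
    return {k: c.reshape(flat.shape[0], *shapes[k]) for k, c in zip(keys, chunks)}


def make_flat_field(
    f: Callable[[Tensors], Tensors], keys: List[str], shapes: Dict[str, torch.Size]
) -> Callable[[float, torch.Tensor], torch.Tensor]:
    # Wrap a dict-valued vector field so that it reads and writes one flat tensor.
    def flat_f(t: float, x_flat: torch.Tensor) -> torch.Tensor:
        return pack(f(unpack(x_flat, keys, shapes)), keys)

    return flat_f


# Toy dynamics: nodes decay, edges grow, global features stay constant.
def f(x: Tensors) -> Tensors:
    return {"node": -x["node"], "edge": 0.5 * x["edge"], "global": torch.zeros_like(x["global"])}


keys = ["node", "edge", "global"]
x = {"node": torch.ones(2, 3, 4), "edge": torch.ones(2, 5, 2), "global": torch.ones(2, 6)}
shapes = {k: v.shape[1:] for k, v in x.items()}

flat = pack(x, keys)                        # shape (2, 12 + 10 + 6) = (2, 28)
flat_f = make_flat_field(f, keys, shapes)
flat_next = flat + 0.1 * flat_f(0.0, flat)  # one explicit Euler step on the flat state
out = unpack(flat_next, keys, shapes)
```

The bookkeeping is where this gets awkward: `shapes` has to travel alongside the flat tensor, and a state that is itself a graph object cannot be expressed this way at all.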

Zymrael commented 2 years ago

There have been a few discussions on this point: several of our downstream users are interested in GNNs, so we have decided to support heterogeneous states. See, for example, issue #137.

We should choose a state type that TorchScript supports (TorchScript support is a work in progress in #163), most likely a NamedTuple of tensors (see the TorchScript list of supported types).
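A minimal sketch of what a NamedTuple state could look like, assuming a batch-first layout. `GraphState`, `vector_field`, and `axpy` are hypothetical names chosen for illustration. Note that `global` is a Python keyword and cannot be used as a field name, so the global component is called `glob` here. TorchScript compilation is not shown.

```python
from typing import NamedTuple

import torch


class GraphState(NamedTuple):
    # Hypothetical heterogeneous state; each field keeps its own shape.
    node: torch.Tensor  # (batch, n_nodes, d_node)
    edge: torch.Tensor  # (batch, n_edges, d_edge)
    glob: torch.Tensor  # (batch, d_global); "global" is a reserved keyword


def vector_field(t: float, x: GraphState) -> GraphState:
    # Toy dynamics: nodes relax toward the mean edge feature, edges decay,
    # and global features are held constant (zero derivative).
    edge_mean = x.edge.mean(dim=(1, 2), keepdim=True)  # (batch, 1, 1), broadcasts over nodes
    return GraphState(node=edge_mean - x.node, edge=-x.edge, glob=torch.zeros_like(x.glob))


def axpy(a: float, dx: GraphState, x: GraphState) -> GraphState:
    # Field-wise x + a * dx, returned as the same NamedTuple type.
    return GraphState(*[xi + a * dxi for xi, dxi in zip(x, dx)])


x0 = GraphState(node=torch.zeros(2, 3, 4), edge=torch.ones(2, 5, 2), glob=torch.ones(2, 6))
x1 = axpy(0.1, vector_field(0.0, x0), x0)  # one explicit Euler step
```

Because a NamedTuple is still a tuple, field-wise operations can be written generically with `zip`, while the field names keep the vector field readable.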

The solver steps would then need to be modified to operate on named tuples.
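To illustrate what a solver step operating on named tuples might involve, here is a classical RK4 step written against any NamedTuple of tensors. The `State` type and the `lincomb` helper are assumptions for illustration, not torchdyn's solver code; the key trick is `type(x)(*fields)`, which rebuilds the same NamedTuple type after every stage.

```python
from typing import Callable, NamedTuple, Tuple

import torch


class State(NamedTuple):
    # Hypothetical two-component state with different shapes.
    a: torch.Tensor
    b: torch.Tensor


def lincomb(x: tuple, *terms: Tuple[float, tuple]) -> tuple:
    # Field-wise x + sum_i c_i * k_i for any NamedTuple of tensors;
    # type(x)(*fields) rebuilds the same NamedTuple type.
    fields = []
    for j, xj in enumerate(x):
        acc = xj
        for c, k in terms:
            acc = acc + c * k[j]
        fields.append(acc)
    return type(x)(*fields)


def rk4_step(f: Callable, t: float, x: tuple, h: float) -> tuple:
    # Classical fourth-order Runge-Kutta step; every stage is a NamedTuple.
    k1 = f(t, x)
    k2 = f(t + h / 2, lincomb(x, (h / 2, k1)))
    k3 = f(t + h / 2, lincomb(x, (h / 2, k2)))
    k4 = f(t + h, lincomb(x, (h, k3)))
    return lincomb(x, (h / 6, k1), (h / 3, k2), (h / 3, k3), (h / 6, k4))


def f(t: float, x: State) -> State:
    # Linear test dynamics with known solution: a(t) = a0 * exp(-t), b(t) = b0 * exp(t / 2).
    return State(a=-x.a, b=0.5 * x.b)


x = State(a=torch.ones(2, 3, dtype=torch.float64), b=torch.ones(4, dtype=torch.float64))
t, h = 0.0, 0.1
for _ in range(10):  # integrate from t = 0 to t = 1
    x = rk4_step(f, t, x, h)
    t += h
```

The same pattern carries over to any explicit Runge-Kutta tableau: only `lincomb` needs to know about the tuple structure, so the step logic itself is unchanged from the single-tensor case.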

SimonKitSangChu commented 2 years ago

Thanks. I will close this issue once heterogeneous states are also supported by the adaptive solvers. Let me know when that is done.
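For adaptive solvers, the extra question is how to reduce a per-field error estimate to a single number when accepting or rejecting a step. A minimal sketch, assuming an embedded Heun-Euler pair (a low-order method swapped in for brevity, not necessarily what torchdyn uses) and a mixed absolute/relative RMS norm taken jointly over every element of every field. `State`, `error_norm`, and `adaptive_heun_euler` are hypothetical names.

```python
import math
from typing import Callable, NamedTuple, Tuple

import torch


class State(NamedTuple):
    # Hypothetical two-component state with different shapes.
    a: torch.Tensor
    b: torch.Tensor


def error_norm(err: tuple, x0: tuple, x1: tuple, atol: float, rtol: float) -> float:
    # Mixed absolute/relative RMS norm taken jointly over every element of every field.
    # Converted to a Python float: step-size control does not need gradients.
    sq_sum, count = 0.0, 0
    for e, a, b in zip(err, x0, x1):
        scale = atol + rtol * torch.maximum(a.abs(), b.abs())
        sq_sum += float(((e / scale) ** 2).sum())
        count += e.numel()
    return math.sqrt(sq_sum / count)


def adaptive_heun_euler(
    f: Callable, t0: float, t1: float, x: tuple, h: float = 0.1,
    atol: float = 1e-6, rtol: float = 1e-3, safety: float = 0.9,
) -> Tuple[tuple, int]:
    # Embedded Heun (order 2) / Euler (order 1) pair, advancing with the Heun solution.
    t, n_accepted = t0, 0
    while t < t1:
        h = min(h, t1 - t)
        k1 = f(t, x)
        x_euler = type(x)(*[xi + h * ki for xi, ki in zip(x, k1)])
        k2 = f(t + h, x_euler)
        x_heun = type(x)(*[xi + 0.5 * h * (p + q) for xi, p, q in zip(x, k1, k2)])
        err = type(x)(*[p - q for p, q in zip(x_heun, x_euler)])
        e = error_norm(err, x, x_heun, atol, rtol)
        if e <= 1.0:  # accept the step
            t, x = t + h, x_heun
            n_accepted += 1
        # The error estimate is O(h^2), so the controller exponent is -1/2.
        factor = 5.0 if e == 0.0 else min(5.0, max(0.2, safety * e ** -0.5))
        h *= factor
    return x, n_accepted


def f(t: float, x: State) -> State:
    # Linear test dynamics: a decays, b grows.
    return State(a=-x.a, b=0.5 * x.b)


x0 = State(a=torch.ones(2, 3, dtype=torch.float64), b=torch.ones(4, dtype=torch.float64))
x_end, n_steps = adaptive_heun_euler(f, 0.0, 1.0, x0)
```

Taking the RMS jointly over all elements means large components (such as edge features) dominate the norm; taking the maximum of per-field RMS norms is a stricter alternative that keeps a small global component from being overlooked.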