alirezanobakht13 opened 1 year ago
I'm sure we can make this work let me have a look
Do you have a more precise use case? Looking at torch geometric data, there is no real concept of shape or batch of a Data object.
We could support hosting it in tensordicts that have no shape and allow shape operations on lazy stacks only (i.e., you could not stack Data objects, but they could appear as stacked).
If you have some minimal example of what you want to do that would be really helpful!
One example is using reinforcement learning to solve the TSP. Suppose that at each timestep, the remaining graph (a subgraph containing the nodes that have not been added to the TSP sequence yet) is given as the observation, and the action is selecting a node. This continues until all nodes have been added to the sequence.
To achieve this, one approach is a DQN with a GNN as its neural network: the graph is given as input, passes through the GNN, and the Q-value of each node comes out.
You can carry torch geometric data in a tensordict without necessarily carrying the torch_geometric classes.
For example, adjacency matrices in torch_geometric are [2, n_edges] tensors. Also, data.batch, data.x, data... are all tensors that carry different features.
You can imagine putting it in a tensordict like this:
TensorDict{
    "graph": {
        "pos": Tensor,
        "x": Tensor,
        "edge_index": Tensor,
        "edge_attr": Tensor,
    }
}
Yes, each attribute of the Data object can be a distinct tensor and can be passed to the tensordict. However, this creates a need to convert between the two forms throughout the flow of an algorithm: in certain steps it becomes necessary to use PyG's methods on the graph. This may increase computation time (I'm not really aware of the inner mechanism of these two, i.e., whether they hold references or make full copies of the object), and it also requires writing additional code.
We could write functions that transform pytorch_geom Data classes into tensordicts and vice versa, without any memory copying. WDYT?
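To back up the zero-copy claim: Python containers store tensor references, not copies, and a TensorDict built from a Data object's attribute tensors would behave the same way. A minimal sketch, using a plain dict as a stand-in for TensorDict and a hypothetical converter name:

```python
import torch

def data_to_flat_dict(x, edge_index, edge_attr=None):
    """Hypothetical converter: collect a graph's attribute tensors
    into a dict without copying any storage."""
    out = {"x": x, "edge_index": edge_index}
    if edge_attr is not None:
        out["edge_attr"] = edge_attr
    return out

x = torch.randn(5, 3)                        # 5 nodes, 3 features
edge_index = torch.tensor([[0, 1], [1, 2]])  # [2, n_edges] COO edges

td = data_to_flat_dict(x, edge_index)
# The container holds references to the original tensors: no memory copy.
assert td["x"] is x and td["edge_index"] is edge_index
# Mutating through one alias is visible through the other.
td["x"][0, 0] = 42.0
assert x[0, 0] == 42.0
```

The reverse direction is symmetric: a Data object rebuilt from these tensors would again just hold references.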
I think it's OK, but I don't know how we can handle a batch of graphs. For non-graph data, a batch adds an additional dimension to the tensor, but in PyG a batch is one big (disconnected) graph (link), so there is no additional dimension; it's just a bigger 1d tensor. There should be a mechanism to deal with this too.
Yeah, that is a constant pain with torch_geometric. We can write tools that get you a torch_geometric batch from a batched tensor (with actual dimensions).
Here is an example of the one I wrote for RLlib: https://github.com/proroklab/HetGPPO/blob/main/models/gppo.py#L115 but the principle is the same any time you pass from dense to sparse tensors.
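The dense-to-sparse principle in the linked code can be sketched as follows (the function name is mine, not from the repo; like that example, it assumes every graph in the batch has the same number of nodes):

```python
import torch

def dense_to_sparse_batch(x_dense, adj_dense):
    """Convert a dense batch of graphs (all with the same number of
    nodes) into PyG-style flat tensors.

    x_dense:   [B, N, F] node features
    adj_dense: [B, N, N] 0/1 adjacency matrices
    Returns flat node features [B*N, F], a global edge_index
    [2, total_edges] with per-graph node offsets applied, and the
    batch vector [B*N] mapping each node to its graph.
    """
    B, N, F = x_dense.shape
    x_flat = x_dense.reshape(B * N, F)
    batch = torch.arange(B).repeat_interleave(N)      # node -> graph id
    b, src, dst = adj_dense.nonzero(as_tuple=True)    # per-graph edges
    offset = b * N                                    # shift node ids
    edge_index = torch.stack([src + offset, dst + offset])
    return x_flat, edge_index, batch

x = torch.randn(2, 3, 4)   # 2 graphs, 3 nodes each, 4 features
adj = torch.zeros(2, 3, 3)
adj[0, 0, 1] = 1           # graph 0: edge 0 -> 1
adj[1, 2, 0] = 1           # graph 1: edge 2 -> 0
x_flat, edge_index, batch = dense_to_sparse_batch(x, adj)
# edge_index is [[0, 5], [1, 3]]: graph 1's edge is offset by N=3.
```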
If I got it right, your code assumes that all the graphs have the same number of nodes, but in PyG a batch can contain graphs with different numbers of nodes (I think that's the reason why they can't use an additional dimension for the batch). So in my opinion, instead of creating tensors with a dimension for the batch, each element in the batch should be converted to a tensordict itself, and then all of them should go into another tensordict.
For example, pyg_batch -> tensordict_batch could be something like:
tensordict_batch = TensorDict(
    {f"graph{i}": TensorDict({
        'x': pyg_batch[i].x,
        'edge_index': pyg_batch[i].edge_index,
        # other attributes
    }, batch_size=[]) for i in range(pyg_batch.num_graphs)},
    batch_size=[],
)
Oh, I thought your question was about how to port torchrl batches to geometric.
Yeah, in the case where every graph has a different number of nodes, the only solution is to use the original geometric tensors, with all the batches flattened in an unbatched tensordict.
The solution above is very slow; in general we never want to iterate over the batch. I would suggest my original solution with the torch_geometric tensors instead:
TensorDict{
    "graph": {
        "pos": Tensor,
        "x": Tensor,
        "edge_index": Tensor,
        "edge_attr": Tensor,
        "batch": Tensor,
    }
}
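The flattened layout above can be built without a Python loop over tensordicts: graphs of different sizes are concatenated into one big disconnected graph, with edge indices shifted by per-graph node offsets and a "batch" tensor recording each node's graph, mimicking what torch_geometric's Batch does. A sketch with a hypothetical collate function:

```python
import torch

def collate_graphs(graphs):
    """Hypothetical collate: merge graphs with different node counts
    into one flat (disconnected) graph, PyG Batch-style.

    graphs: list of (x, edge_index) pairs, x: [n_i, F],
            edge_index: [2, e_i]
    """
    xs, eis, batches = [], [], []
    offset = 0
    for i, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        eis.append(edge_index + offset)  # shift node ids into global range
        batches.append(torch.full((x.shape[0],), i, dtype=torch.long))
        offset += x.shape[0]
    return {
        "x": torch.cat(xs),                   # [sum n_i, F]
        "edge_index": torch.cat(eis, dim=1),  # [2, sum e_i]
        "batch": torch.cat(batches),          # [sum n_i]
    }

g0 = (torch.randn(2, 3), torch.tensor([[0], [1]]))        # 2 nodes, 1 edge
g1 = (torch.randn(3, 3), torch.tensor([[0, 2], [1, 0]]))  # 3 nodes, 2 edges
flat = collate_graphs([g0, g1])
# flat["batch"] is [0, 0, 1, 1, 1]; g1's edges are offset by 2.
```

The resulting dict of flat tensors is exactly the shape-free "graph" sub-structure sketched above and could be stored in a tensordict as-is.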
My point is that it would be better if some built-in functions handled these conversions between the torch_geometric objects (Data and Batch) and tensordict, because they are both inside the PyTorch ecosystem.
I agree it would be nice to have that. I need to wrap my head around how batch-size would play with the dimensionality of those things. As @matteobettini said, as long as we concatenate things with lazy stacks we're good, but I wonder how that would play in the ecosystem.
Like: does your env return Data objects? If so:
env.rollout(..., return_contiguous=False) would return a lazy stack of data, and that is fine.
If you want to put it in a replay buffer in contiguous memory, you're gonna get into trouble.
Then: would the loss function read stacks of these objects? If not: how can we make sure that it never happens? If so: how do we stack them?
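To make the stacking problem concrete: node tensors from graphs of different sizes simply cannot be stacked eagerly, which is what forces either lazy stacks or the flattened batch-plus-offsets layout. A minimal illustration in plain torch:

```python
import torch

x_a = torch.randn(5, 3)  # graph with 5 nodes
x_b = torch.randn(7, 3)  # graph with 7 nodes

# A regular stack fails because the leading dimensions differ...
try:
    torch.stack([x_a, x_b])
except RuntimeError:
    print("cannot stack ragged tensors eagerly")

# ...which is exactly why graph data needs either a lazy stack
# (deferring materialisation) or the flat batch/offset representation.
```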
Let's try to figure this out!
Motivation
I'm trying to create an RL agent which works with graph data structures, but when I pass the torch_geometric Data object to TensorDict, it fails with an error.
Solution
I'm not really sure how these two should be made compatible, but I think that with their current structure they can't be matched together; something should change in one of them.