pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Learning rewards using graph neural networks #3457

Open akpas001 opened 2 years ago

akpas001 commented 2 years ago

Discussed in https://github.com/pyg-team/pytorch_geometric/discussions/3455

Originally posted by **akpas001** November 8, 2021

I need to train a GNN on pre-existing data of next states and rewards. The next states are one-hot encoded, and the rewards are random numbers (floats, which can also be negative). I have constructed a GNN using GINConv. In a Markov decision process in reinforcement learning, a state-action pair given as input to the network yields the next state and the reward (state, action, and next state are all one-hot encoded). So I am forming the state-action pair inside the network and computing a next-state loss and a reward loss. The loss graph for the next states is good, but the graph for the rewards is very bad (the rewards are not being learned). I do not understand what mistake I am making here. I am attaching snippets of the code, and the graphs of the next-state loss and the reward loss, respectively. Can somebody help me with that?

![1](https://user-images.githubusercontent.com/56883247/140769475-d950a306-9e82-4486-a886-d1611d147c80.PNG) ![2](https://user-images.githubusercontent.com/56883247/140769479-79a5fb26-b531-470c-ad9d-275b716722a1.PNG) ![3](https://user-images.githubusercontent.com/56883247/140769481-e98755f5-81ba-4dec-aa06-bb2b11e0e657.PNG) ![4](https://user-images.githubusercontent.com/56883247/140769483-47bcec0e-ec11-4818-9f9e-ba1c5ea3d963.PNG)

rusty1s commented 2 years ago

I'm not sure I understand your action for-loop (the for-loop around i seems to be unnecessary as far as I can tell). Otherwise, the code looks good to me. How does the model perform when you replace the GINConv layers with simple MLPs?

akpas001 commented 2 years ago

So, what I am trying to do with the action for-loop is this:

I am basically trying to append the actions to the node features of every node (I have 9 nodes in total). My x is (batch_size * 9, 30), with num_features = 30. I am reconstructing it to (batch_size * 9, 10) and then appending the actions to it, which makes it (batch_size * 9, 40).

Since I have 9 nodes, I need to repeat the same action 9 times in order to append it to every node of the same graph.
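For readers following along, the repetition described above can be written without an explicit loop. This is only a minimal sketch in plain PyTorch, not the code from the screenshots: the sizes (32 graphs, 9 nodes, 10 node features, 30 action classes) are taken from the thread, while the variable names `act`, `act_per_node`, and `x_sa` and the random inputs are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

batch_size, num_nodes = 32, 9
node_dim, num_actions = 10, 30  # sizes mentioned in the thread

# Node features of all graphs stacked along dim 0: [batch_size * num_nodes, node_dim]
x = torch.randn(batch_size * num_nodes, node_dim)

# One one-hot action per graph: [batch_size, num_actions] (random, for illustration)
action_idx = torch.randint(0, num_actions, (batch_size,))
act = F.one_hot(action_idx, num_classes=num_actions).float()

# Repeat each graph's action once per node; rows of the same graph stay contiguous
act_per_node = act.repeat_interleave(num_nodes, dim=0)  # [288, 30]

# Append the action to every node's features
x_sa = torch.cat([x, act_per_node], dim=1)  # [288, 40]
```

If a per-node graph-assignment vector is available (PyG loaders provide one as `batch`), indexing `act[batch]` should give the same result without assuming that every graph has exactly 9 nodes.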

I haven't checked the performance of simple MLPs yet. I thought this to be a coding mistake and was trying to figure out the reason for it.

rusty1s commented 2 years ago

So the shape of x is [batch_size, num_nodes, num_features] or [batch_size * num_nodes, num_features]? In case of the latter, you will also need to replicate edge_index, as it currently only points to the indices of the first 9 nodes.
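To illustrate what replicating edge_index means here, the following is a rough sketch in plain PyTorch. It assumes every graph in the batch shares the same 9-node topology; the chain-shaped `edge_index` and the name `batched_edge_index` are hypothetical. A PyG DataLoader applies this kind of index offset automatically when it collates graphs, so manual replication is only needed when the batch is assembled by hand.

```python
import torch

num_nodes, batch_size = 9, 4

# Hypothetical edge list of a single 9-node graph (a simple chain): [2, num_edges]
edge_index = torch.tensor([list(range(0, 8)), list(range(1, 9))])

# Graph i owns the node indices i * num_nodes ... (i + 1) * num_nodes - 1
offsets = torch.arange(batch_size) * num_nodes               # [batch_size]
shifted = edge_index.unsqueeze(0) + offsets.view(-1, 1, 1)   # [batch_size, 2, num_edges]

# Put all copies side by side along the edge dimension
batched_edge_index = shifted.permute(1, 0, 2).reshape(2, -1)  # [2, batch_size * num_edges]
```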

akpas001 commented 2 years ago

The latter. I processed the data through dataloader, so all of that had been taken care of.

rusty1s commented 2 years ago

What does the final act look like? Is that a one-hot encoded vector?

akpas001 commented 2 years ago

Yes, it is a one-hot encoded vector. I am attaching an image for your reference. The size is (288, 30) because batch size = 32 and num_nodes = 9. This then gets appended along dim = 1 for every node.

rusty1s commented 2 years ago

Thanks. One thing I don't understand in your network is the flatten() call. Shouldn't this be a global_mean_pool?
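As a rough illustration of the difference, mean pooling reduces the node embeddings of each graph to a single vector, so the output has one row per graph regardless of how many graphs are in the batch. The sketch below reimplements that reduction in plain PyTorch to stay self-contained; in PyG the same role is played by `global_mean_pool(x, batch)`. The hidden size of 16 and the variable names are assumptions.

```python
import torch

batch_size, num_nodes, hidden = 32, 9, 16  # hidden size chosen arbitrarily

# Node embeddings after the conv layers: [batch_size * num_nodes, hidden]
x = torch.randn(batch_size * num_nodes, hidden)

# Graph assignment of every node: [0, 0, ..., 0, 1, 1, ..., 31]
batch = torch.arange(batch_size).repeat_interleave(num_nodes)

# Sum the node embeddings per graph, then divide by the node count of each graph
sums = torch.zeros(batch_size, hidden).index_add_(0, batch, x)
counts = torch.bincount(batch, minlength=batch_size).clamp(min=1).unsqueeze(1)
pooled = sums / counts  # [batch_size, hidden], one row per graph
```

A plain `flatten()` over the whole batch would instead merge all graphs into one long vector, which ties the size of the following layer to the batch size.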

akpas001 commented 2 years ago

Even I'm not sure about that; I thought of trying both.

akpas001 commented 2 years ago

But the network is only giving a loss of 0.3... I still do not understand the mistakes I am making here.

rusty1s commented 2 years ago

What does your new network look like? In the current one, you have a Linear layer with input channels batch_size * 360, which is quite weird as the number of parameters should never depend on the batch size :)
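One way to keep the layer size independent of the batch size, if the flattened per-graph features are still wanted, is to reshape to one row per graph instead of flattening the whole batch. This is a sketch under the assumption that 360 stands for 9 nodes times 40 features per node, which the thread does not state explicitly; `reward_head` and `predict_reward` are hypothetical names, not the author's code.

```python
import torch
from torch import nn

num_nodes, feat = 9, 40  # assumed: 9 * 40 = 360 features per graph

# The layer sees 360 inputs per graph, whatever the batch size is
reward_head = nn.Linear(num_nodes * feat, 1)

def predict_reward(x: torch.Tensor) -> torch.Tensor:
    """Map node features [batch_size * num_nodes, feat] to one reward per graph."""
    per_graph = x.view(-1, num_nodes * feat)   # [batch_size, 360]
    return reward_head(per_graph).squeeze(-1)  # [batch_size]

# The same layer handles batches of different sizes
small = predict_reward(torch.randn(4 * num_nodes, feat))
large = predict_reward(torch.randn(32 * num_nodes, feat))
```

This reshape relies on every graph having exactly 9 nodes stored contiguously; with variable-sized graphs, a pooling operator is the safer choice.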

akpas001 commented 2 years ago

I don't think that is an issue, because I am using fc1 to get one of the outputs, and the loss doesn't improve even in that case.

akpas001 commented 2 years ago

Regarding the network, I am using almost the same one, except that I rewrote it as a GraphAutoEncoder. Here is the file for your reference: 1.txt

rusty1s commented 2 years ago

You are right. As far as I can see, there isn't a mistake in computing x1, although the computation of x2 might still be worth fixing. I'm sorry that I cannot be of more help :(

akpas001 commented 2 years ago

Thanks for the help!