pfnet / pfrl

PFRL: a PyTorch-based deep reinforcement learning library
MIT License

Integration of PFRL with Deep Graph Library #85

Closed: eladmw closed this issue 3 years ago

eladmw commented 3 years ago

I am currently trying to use PFRL with graph neural networks built with Deep Graph Library (DGL), but data loading produces errors because the graph type is not accepted. This looks like a PyTorch problem, but it would be nice to get some input. Do you think this problem is fixable?

The error is:

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py in (.0)
     82 raise RuntimeError('each element in list of batch should be of equal size')
     83 transposed = zip(*batch)
---> 84 return [default_collate(samples) for samples in transposed]
     85
     86 raise TypeError(default_collate_err_msg_format.format(elem_type))

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
     84 return [default_collate(samples) for samples in transposed]
     85
---> 86 raise TypeError(default_collate_err_msg_format.format(elem_type))

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'dgl.heterograph.DGLHeteroGraph'>

prabhatnagarajan commented 3 years ago

Hmm, this error appears to be a PyTorch error, not a PFRL error (especially given the last line). Could you post more of the stack trace (including the part of the PFRL code used) so we can see how this is called in PFRL?

marioyc commented 3 years ago

Yes, it looks like the error happens because the default batch_states calls PyTorch's default_collate function here, which can't handle the dgl.heterograph.DGLHeteroGraph type. You will need to write your own batch_states function, which can then be passed to the agent in its constructor.
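The diagnosis can be checked without PFRL or DGL. The minimal sketch below feeds PyTorch's default_collate (imported from the same private module path that appears in the trace) instances of a class it does not recognize. FakeGraph is a hypothetical stand-in for DGLHeteroGraph, on the assumption that any unrecognized type triggers the same branch.

```python
from torch.utils.data._utils.collate import default_collate


class FakeGraph:
    """Hypothetical stand-in for dgl.heterograph.DGLHeteroGraph.

    Assumption: default_collate treats any class it does not know
    (not a tensor, array, number, dict, or sequence) the same way.
    """


def collate_error_message(states):
    """Return the TypeError message default_collate raises, or None."""
    try:
        default_collate(states)
    except TypeError as err:
        return str(err)
    return None


msg = collate_error_message([FakeGraph(), FakeGraph()])
print(msg)
```

The printed message matches the last line of the reported trace, with FakeGraph in place of the DGL graph class.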

eladmw commented 3 years ago

I know I still need to improve my kernel and the NN structure, but I wanted to check data loading first. The full trace is here: https://www.kaggle.com/eladwar/dgl-football1

prabhatnagarajan commented 3 years ago

Looking at the full trace, I think you should proceed with @marioyc's suggestion. You would define your own function, e.g. def custom_batch_states(states, device, phi), and pass it to the A2C agent along with its other constructor arguments: agent = pfrl.agents.A2C(..., batch_states=custom_batch_states)

eladmw commented 3 years ago

Thank you for your suggestion.