eladmw closed this issue 3 years ago.
Hmm, this error appears to be a PyTorch error, not a PFRL error (especially given the last line). Could you post more of the stack trace (including the part of the PFRL code used) so we can see how this is called in PFRL?
Yes, it looks like the error happens because the default `batch_states` calls PyTorch's `default_collate` function, which can't handle the `dgl.heterograph.DGLHeteroGraph` type. You will need to write your own `batch_states` function, which can then be passed to the agent in its constructor.
I know I still need to improve my kernel and the NN structure, but I wanted to check the data loading first. The full trace is here: https://www.kaggle.com/eladwar/dgl-football1
Looking at the full trace, I think you should proceed with @marioyc's comment: define your own `custom_batch_states` function and pass it to the A2C agent together with its usual constructor arguments, e.g. `agent = pfrl.agents.A2C(model, optimizer, ..., batch_states=custom_batch_states)`.
Thank you for your suggestion
I am currently trying to use PFRL with deep graph neural networks, but data loading fails because the graph type is not accepted. This looks like a PyTorch problem, but it would be nice to get some input. Do you think this problem is fixable?
The error is:

```
/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py in <listcomp>(.0)
     82 raise RuntimeError('each element in list of batch should be of equal size')
     83 transposed = zip(*batch)
---> 84 return [default_collate(samples) for samples in transposed]
     85
     86 raise TypeError(default_collate_err_msg_format.format(elem_type))

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
     84 return [default_collate(samples) for samples in transposed]
     85
---> 86 raise TypeError(default_collate_err_msg_format.format(elem_type))

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'dgl.heterograph.DGLHeteroGraph'>
```
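The last line of the trace comes from PyTorch's own type check, not from PFRL. The small reproduction below uses PyTorch alone to follow the same path. Tuple observations are transposed (the `zip(*batch)` step in the trace), and each position is collated separately, so any element type `default_collate` does not recognise raises this `TypeError`. `FakeGraph` and `try_collate` are hypothetical stand-ins, not DGL or PyTorch APIs, and the exact message wording may differ slightly between PyTorch versions.

```python
from torch.utils.data._utils.collate import default_collate


class FakeGraph:
    """Hypothetical stand-in for an object type default_collate does not know."""


def try_collate(batch):
    """Return the collated batch, or the TypeError message if collation fails."""
    try:
        return default_collate(batch)
    except TypeError as err:
        return str(err)


# Plain numbers are fine: two observations become one stacked tensor.
print(try_collate([1.0, 2.0]))
# Tuple observations are transposed first, so the graph-like position fails.
print(try_collate([(FakeGraph(), 1.0), (FakeGraph(), 2.0)]))
```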