echowve / meshGraphNets_pytorch

PyTorch implementations of Learning Mesh-based Simulation With Graph Networks
Apache License 2.0
152 stars · 33 forks

add prefetch and accumulation steps according to original repo #2

Closed FishWoWater closed 2 years ago

FishWoWater commented 2 years ago

When I was training a cloth simulation network with your implementation, it was much slower than the original TensorFlow implementation. The bottleneck was CPU I/O: on my task, I/O takes 2 s per batch while the forward+backward pass takes only 0.2 s. So I added a prefetch_factor flag, following https://github.com/deepmind/deepmind-research/blob/1642ae3499c8d1135ec6fe620a68911091dd25ef/meshgraphnets/dataset.py#L55
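In PyTorch, prefetching is exposed through the `prefetch_factor` argument of `torch.utils.data.DataLoader`. A minimal sketch (the dataset here is a hypothetical stand-in, not the repo's actual class):

```python
# Sketch: background prefetching with a PyTorch DataLoader.
# DummyFrames is an illustrative stand-in for the mesh dataset.
import torch
from torch.utils.data import DataLoader, Dataset

class DummyFrames(Dataset):
    """Returns one frame (here, a random feature vector) per index."""
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(8)

# prefetch_factor only takes effect when num_workers > 0: each worker
# keeps `prefetch_factor` batches ready in advance, so slow CPU I/O
# overlaps with the GPU forward/backward pass instead of stalling it.
loader = DataLoader(DummyFrames(), batch_size=4,
                    num_workers=2, prefetch_factor=4)

for batch in loader:
    pass  # training step would go here
```

Whether this hides the I/O cost depends on how many workers the CPU can sustain; with 2 s of I/O per batch against 0.2 s of compute, several workers are needed to keep the GPU busy.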

I also added a flag to accumulate batch statistics, as in https://github.com/deepmind/deepmind-research/blob/1642ae3499c8d1135ec6fe620a68911091dd25ef/meshgraphnets/run_model.py#L72
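For context, the DeepMind training script spends the first steps only feeding batches through the input normalizers so their running statistics stabilize before gradient updates begin. A minimal sketch of that idea, with all names illustrative rather than taken from either repo:

```python
# Sketch: accumulate normalization statistics for the first few batches
# before training on normalized inputs. Names are illustrative.
import torch

class Normalizer:
    def __init__(self, size, eps=1e-8):
        self.count = 0
        self.sum = torch.zeros(size)
        self.sum_sq = torch.zeros(size)
        self.eps = eps

    def accumulate(self, x):
        # x: (batch, size) -- fold this batch into the running statistics
        self.count += x.shape[0]
        self.sum += x.sum(dim=0)
        self.sum_sq += (x ** 2).sum(dim=0)

    def __call__(self, x):
        mean = self.sum / max(self.count, 1)
        var = self.sum_sq / max(self.count, 1) - mean ** 2
        std = torch.sqrt(var.clamp(min=0)) + self.eps
        return (x - mean) / std

norm = Normalizer(3)
accumulate_steps = 10
for step, batch in enumerate(torch.randn(50, 4, 3)):
    if step < accumulate_steps:
        norm.accumulate(batch)   # warm-up: gather statistics only
        continue
    normalized = norm(batch)     # later steps: train on normalized input
```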

echowve commented 2 years ago

Thank you!

FishWoWater commented 2 years ago

@echowve Hi! Even with prefetching, the PyTorch version is still about 4x slower than the TensorFlow version. One suggestion: read a whole sequence into memory instead of reading each frame separately, e.g. by maintaining a dict self.tra_data[tra_index] = {}. Frequent random access to the h5 file is very time-consuming.

I have implemented this for my task, and afterwards its speed is comparable to the TensorFlow version. I am not sure whether it works for the CFD problem.
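A sketch of the suggested caching, using h5py: each trajectory is read sequentially from the h5 file once and kept in the self.tra_data dict, so per-frame lookups never touch the file again. The file layout (one group per trajectory index) and all method names here are assumptions for illustration, not the repo's actual dataset class:

```python
# Sketch: cache whole trajectories instead of random-accessing frames.
# Assumes one h5 group per trajectory index; layout is an assumption.
import h5py
import numpy as np

class TrajectoryCachedDataset:
    def __init__(self, h5_path):
        self.h5_path = h5_path
        self.tra_data = {}  # tra_index -> dict of arrays for that trajectory

    def _load_trajectory(self, tra_index):
        # One sequential read per trajectory; later frames hit the cache.
        if tra_index not in self.tra_data:
            with h5py.File(self.h5_path, "r") as f:
                grp = f[str(tra_index)]
                self.tra_data[tra_index] = {k: np.asarray(grp[k]) for k in grp}
        return self.tra_data[tra_index]

    def get_frame(self, tra_index, frame_index):
        tra = self._load_trajectory(tra_index)
        return {k: v[frame_index] for k, v in tra.items()}
```

The trade-off is memory: caching entire trajectories is only viable when a sequence fits comfortably in RAM, which is why it may behave differently on the CFD data than on cloth.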

Thanks for your great repo!

echowve commented 2 years ago

@FishWoWater Hi, I tested the dataset class on the CFD problem and found that loading from the h5 file costs about 3 ms per sample. Also, maintaining self.tra_data[tra_index] made no noticeable difference. The newer version has been uploaded; please keep making suggestions. Thanks!