taichi-dev / taichi

Productive, portable, and performant GPU programming in Python.
https://taichi-lang.org
Apache License 2.0

A Question about Checkpointing #8325

Open DoTulip opened 1 year ago

DoTulip commented 1 year ago

Dear Taichi Development Team,

I am a Ph.D. student from China and a beginner in Taichi. I am currently studying the paper titled “DIFFTAICHI: DIFFERENTIABLE PROGRAMMING FOR PHYSICAL SIMULATION” published by the Taichi team at ICLR 2020. This paper has opened the door to differentiable physics for me, and I have gained a lot from it! However, as I am still relatively new to this field, I have some doubts regarding certain details, and I am reaching out to seek your guidance.

In the paper, two checkpointing methods are proposed: recomputation within time steps and segment-wise recomputation. While I understand the first method, I am a bit confused about the second. If we divide the entire simulation into 1000 time steps with 5 checkpoints, so that each segment comprises 200 time steps, do we need to recompute these 200 time steps during backpropagation? This question has been puzzling me. Are there any code examples in Taichi that I could learn from and refer to?

Thank you for your time, and I apologize for any inconvenience caused. I look forward to your response.

ailzhang commented 1 year ago

cc: @erizmr

erizmr commented 1 year ago

Hi @DoTulip, yes, during backpropagation, recomputing the 200 time steps is required, as each checkpoint only stores a snapshot of one time step. There is an example, https://github.com/taichi-dev/difftaichi/blob/master/examples/diffmpm_checkpointing.py, which shows that the grid-related variables are recomputed during every time step of the backpropagation.
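To make the segment-wise idea concrete, here is a minimal sketch in plain Python. It is not Taichi code and is not taken from the linked example. The toy dynamics in `step`, its hand-written derivative `step_grad`, and all function names are assumptions made up for illustration. The forward pass stores only the state at the start of each segment. The backward pass walks the segments in reverse, re-runs the forward steps of one segment from its checkpoint, and then backpropagates through that segment.

```python
import math

def step(x, theta):
    # One time step of a hypothetical toy simulation.
    return x + 0.01 * math.sin(theta * x)

def step_grad(x, theta, gx_next):
    # Backpropagate through one step: given dL/dx_{t+1}, return
    # dL/dx_t and this step's contribution to dL/dtheta.
    c = math.cos(theta * x)
    return gx_next * (1.0 + 0.01 * theta * c), gx_next * (0.01 * x * c)

def grad_full_storage(x0, theta, n_steps):
    # Reference: keep every intermediate state in memory.
    xs = [x0]
    for _ in range(n_steps):
        xs.append(step(xs[-1], theta))
    gx, gtheta = 1.0, 0.0  # loss is the final state, so dL/dx_T = 1
    for t in reversed(range(n_steps)):
        gx, gt = step_grad(xs[t], theta, gx)
        gtheta += gt
    return xs[-1], gtheta

def grad_checkpointed(x0, theta, n_steps, segment_len):
    # Forward pass: store only the state at the start of each segment.
    checkpoints = {}
    x = x0
    for t in range(n_steps):
        if t % segment_len == 0:
            checkpoints[t] = x
        x = step(x, theta)
    loss = x

    gx, gtheta = 1.0, 0.0
    recomputed_steps = 0
    # Backward pass: visit segments from last to first.
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment_len, n_steps)
        # Recompute the states inside this segment from its checkpoint.
        xs = [checkpoints[start]]
        for _ in range(start, end - 1):
            xs.append(step(xs[-1], theta))
            recomputed_steps += 1
        # Backpropagate through the segment using the recomputed states.
        for t in reversed(range(start, end)):
            gx, gt = step_grad(xs[t - start], theta, gx)
            gtheta += gt
    return loss, gtheta, recomputed_steps

if __name__ == "__main__":
    # 1000 time steps, 5 checkpoints, 200 steps per segment.
    print(grad_full_storage(0.5, 2.0, 1000))
    print(grad_checkpointed(0.5, 2.0, 1000, 200))
```

In this sketch, at most 5 checkpoints plus the 200 states of one segment are held at any time, instead of all 1001 states, at the cost of roughly one extra forward pass (the checkpoint itself supplies the first state of each segment, so 199 of the 200 steps are re-run per segment).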

DoTulip commented 1 year ago

Thank you for your reply! The example you mentioned seems to use only one of the checkpointing methods. Is there an example in which both methods are used? @erizmr