pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Unify async_save and sync_save in state_dict_saver from distributed checkpointing #127191

Open teja-rao opened 5 months ago

teja-rao commented 5 months ago

🚀 The feature, motivation and pitch

Many PyTorch distributed methods take an async_op parameter. When it is set, they run asynchronously and return a Future-like async handle instead of the regular return value. For example: def broadcast(tensor, src, group=None, async_op=False):

For checkpointing, however, we have two separate methods, async_save and save, whose signatures are identical except for the return type. A single method with an async_op parameter would make the API consistent with the rest of PyTorch and reduce cognitive overhead for developers. It would also make the async save capability easier to discover and cut boilerplate and code duplication, both in PyTorch and in user code.
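A unified entry point could look roughly like the sketch below. It assumes the entry point delegates to the existing blocking save path and wraps that path in a concurrent.futures.Future when async_op=True. The names _blocking_save and unified_save are hypothetical stand-ins, not torch.distributed.checkpoint APIs. A real implementation would also need to snapshot the tensors (for example to CPU memory) before returning, because training keeps mutating them.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Dict, Optional, Union

# Hypothetical stand-in for the blocking checkpoint write. It only records
# which checkpoint and keys it was asked to save and returns that as "metadata".
def _blocking_save(state_dict: Dict[str, Any], checkpoint_id: Optional[str]) -> Dict[str, Any]:
    return {"checkpoint_id": checkpoint_id, "keys": sorted(state_dict)}

# A single background worker, so asynchronous saves complete in submission order.
_executor = ThreadPoolExecutor(max_workers=1)

def unified_save(
    state_dict: Dict[str, Any],
    *,
    checkpoint_id: Optional[str] = None,
    async_op: bool = False,
) -> Union[Dict[str, Any], Future]:
    """Hypothetical unified API: blocks by default, returns a Future if async_op=True."""
    if not async_op:
        return _blocking_save(state_dict, checkpoint_id)
    # Shallow-copy the dict so keys the caller adds or replaces later do not
    # change what is saved; real tensors would need a full (e.g. CPU) copy.
    snapshot = dict(state_dict)
    return _executor.submit(_blocking_save, snapshot, checkpoint_id)

sd = {"step": 10, "weights": [1.0, 2.0]}
meta = unified_save(sd, checkpoint_id="ckpt_10")                # blocking: returns metadata
fut = unified_save(sd, checkpoint_id="ckpt_11", async_op=True)  # non-blocking: returns a Future
print(meta["keys"])                    # ['step', 'weights']
print(fut.result()["checkpoint_id"])   # ckpt_11
```

The caller branches only on the value of async_op. Both modes share one signature and one code path for the actual write.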

Today, users have to write something like:

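from torch.distributed.checkpoint import async_save, save

def save_checkpoint(sd, cp_id, storage_writer_instance, planner_instance, pg,
                    is_periodic_checkpoint, have_memory):
    # Hypothetical helper wrapping the example; the argument names come from it.
    do_async = is_periodic_checkpoint and have_memory
    if do_async:
        # async_save returns a Future; the caller can return it or ignore it.
        result = async_save(
            state_dict=sd,
            checkpoint_id=cp_id,
            storage_writer=storage_writer_instance,
            planner=planner_instance,
            process_group=pg,
        )
    else:
        # save blocks until the checkpoint has been written.
        result = save(
            state_dict=sd,
            checkpoint_id=cp_id,
            storage_writer=storage_writer_instance,
            planner=planner_instance,
            process_group=pg,
        )
    return result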

Using one method would allow simpler, more flexible code:

do_async = is_periodic_checkpoint and have_memory
# Proposed API: async_op is not an existing parameter of save() today.
result = save(
    state_dict=sd,
    checkpoint_id=cp_id,
    storage_writer=storage_writer_instance,
    planner=planner_instance,
    process_group=pg,
    async_op=do_async,
)
return result  # a Future when do_async is True, otherwise the saved Metadata

or

do_async = is_periodic_checkpoint and have_memory

# Proposed API: async_op is not an existing parameter of save() today.
result = save(
    state_dict=sd,
    checkpoint_id=cp_id,
    storage_writer=storage_writer_instance,
    planner=planner_instance,
    process_group=pg,
    async_op=do_async,  # or simply pass True here
)

# When using async, the caller would typically return the Future, ignore it,
# or wait on it only when the job shuts down. async_save currently returns a
# concurrent.futures.Future, whose blocking call is .result().
return result.result() if do_async else result
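The comment above lists three ways a caller might treat the returned Future: return it, ignore it, or wait on it only at shutdown. The sketch below shows the last pattern. It assumes a hypothetical fake_async_save that stands in for save(..., async_op=True) and runs the write on a background thread. Limiting the loop to one in-flight checkpoint is an illustrative choice, not something the proposal specifies.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Dict, List, Optional

# One background worker, standing in for the checkpoint writer thread.
_executor = ThreadPoolExecutor(max_workers=1)

def fake_async_save(state_dict: Dict[str, Any], written: List[int]) -> Future:
    """Hypothetical stand-in for save(..., async_op=True): writes in the background."""
    snapshot = dict(state_dict)  # copy now so later updates are not saved by mistake

    def _write() -> int:
        time.sleep(0.01)  # simulate slow storage I/O
        written.append(snapshot["step"])
        return snapshot["step"]

    return _executor.submit(_write)

def train(num_steps: int, ckpt_every: int) -> List[int]:
    written: List[int] = []  # steps whose checkpoints have finished writing
    pending: Optional[Future] = None
    for step in range(1, num_steps + 1):
        state = {"step": step}  # placeholder for model and optimizer state
        if step % ckpt_every == 0:
            # Keep at most one checkpoint in flight to bound extra memory use.
            if pending is not None:
                pending.result()
            pending = fake_async_save(state, written)
    # On shutdown, block until the last checkpoint has finished writing.
    if pending is not None:
        pending.result()
    return written

print(train(num_steps=10, ckpt_every=3))  # [3, 6, 9]
```

Training only blocks when a new checkpoint is requested while the previous one is still being written, plus once at the very end.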

I believe checkpointing should adopt the goal of supporting training at scale. This is critical both for consolidating checkpointing internally and for building a community around it. In large-scale training, async checkpointing is more widely used and is probably the main checkpointing API, so consolidating these methods would greatly simplify integration.

Alternatives

No response

Additional context

No response

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

kwen2501 commented 4 months ago

@LucasLLC any comment?

LucasLLC commented 4 months ago

This was a major point of contention when we were reviewing the design for async_save. We went back and forth on this issue and ended up leaning toward separate functions. Our reason was to guard against additional parameters that may be relevant only to async saving.

FWIW, I agree with @kirtiteja