atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License

Fixed `global_step` in `train_cifar10_ddp.py` #144

Closed. Xiaoming-Zhao closed this 1 week ago.

Xiaoming-Zhao commented 1 week ago

What does this PR do?

The current `global_step` in `train_cifar10_ddp.py` is not correct. The global step should increase by one per training iteration instead of accumulating the current step index.
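To make the bug concrete, here is a minimal Python sketch. It assumes the old loop did `global_step += step` inside a per-epoch `for step in ...` loop, which is an assumption based on the description above. The function `global_steps` is hypothetical and only tracks the counter.

```python
def global_steps(num_epochs, steps_per_epoch, accumulate_step_index):
    """Return the global_step value seen at each iteration of a nested training loop."""
    global_step = 0
    history = []
    for _ in range(num_epochs):
        for step in range(steps_per_epoch):
            if accumulate_step_index:
                # Assumed buggy form: adds the inner-loop index on every iteration
                global_step += step
            else:
                # Fix proposed in this PR: exactly one increment per iteration
                global_step += 1
            history.append(global_step)
    return history


print(global_steps(2, 4, accumulate_step_index=True))   # [0, 1, 3, 6, 6, 7, 9, 12]
print(global_steps(2, 4, accumulate_step_index=False))  # [1, 2, 3, 4, 5, 6, 7, 8]
```

With the accumulating form, the counter after E epochs of S steps is E * S * (S - 1) / 2, so checkpoint labels quickly overshoot the true iteration count.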

Before submitting

Did you have fun?

Make sure you had fun coding 🙃

ImahnShekhzadeh commented 1 week ago

Yes, `train_cifar10_ddp.py` currently does not handle the training loop correctly. See also my earlier comment on this: https://github.com/atong01/conditional-flow-matching/pull/116#discussion_r1695722539

Can you explain why global_step += 1 is correct? I still think that switching from steps to epochs would be easiest (@kilianFatras).
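If training is driven by epochs instead, a fixed step budget can still be translated into an epoch count. The sketch below is hypothetical: the helper names and the batch size, world size, and step budget are illustrative, not the script's defaults. It assumes `DistributedSampler` and `DataLoader` are both used with `drop_last=False`, so each rank sees `ceil(N / world_size)` samples and keeps the last partial batch.

```python
def ceil_div(a, b):
    # Integer ceiling division, avoiding float rounding
    return -(-a // b)


def steps_per_epoch(dataset_size, batch_size, world_size):
    # DistributedSampler(drop_last=False) pads so each rank gets ceil(N / world_size) indices
    samples_per_rank = ceil_div(dataset_size, world_size)
    # DataLoader(drop_last=False) keeps the final partial batch
    return ceil_div(samples_per_rank, batch_size)


def num_epochs_for(total_steps, epoch_steps):
    # Smallest number of full epochs whose steps cover the step budget
    return ceil_div(total_steps, epoch_steps)


# 50,000 is the size of the CIFAR-10 training set; the other numbers are illustrative
spe = steps_per_epoch(50_000, batch_size=128, world_size=2)
print(spe, num_epochs_for(400_000, spe))  # 196 2041
```

Rounding up means the run takes slightly more steps than the budget (here 2041 * 196 = 400,036).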

Xiaoming-Zhao commented 1 week ago

Thanks for the pointer.

I was running the script as-is to check whether I could reproduce the results, and noticed that the saved checkpoints did not carry the correct iteration number.

Switching from steps to epochs sounds good to me. With this change, I was mainly trying to follow the repo's existing convention of using the step count as the training-progress indicator.

I also noticed that the script uses `sampler.set_epoch(epoch)`. In my experience, this is crucial for getting a different shuffle in each epoch. However, with the current generator provided by `infiniteloop`, I am not sure whether `set_epoch` actually affects the dataloader; I need to double-check.
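The `set_epoch` concern can be illustrated with a pure-Python toy. `ToySampler` is a hypothetical stand-in for `DistributedSampler` that, like the real one, seeds its shuffle with `seed + epoch` when it is iterated; `infiniteloop` here is an assumed simplification of the repo helper. It shows that the order of a pass is fixed when the pass starts, so a `set_epoch` call in the middle of a pass only takes effect on the next one.

```python
import random


class ToySampler:
    """Hypothetical single-rank stand-in for DistributedSampler.

    The shuffle is seeded with seed + epoch and fixed when iter() is called.
    """

    def __init__(self, n, seed=0):
        self.n = n
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        order = list(range(self.n))
        random.Random(self.seed + self.epoch).shuffle(order)
        return iter(order)


def infiniteloop(loader):
    # Assumed shape of the repo helper: restart the loader forever
    while True:
        for x in iter(loader):
            yield x


def order_for(n, epoch):
    # Reference order a fresh pass would produce at the given epoch
    sampler = ToySampler(n)
    sampler.set_epoch(epoch)
    return list(sampler)


sampler = ToySampler(6)
looper = infiniteloop(sampler)

sampler.set_epoch(0)
first = [next(looper) for _ in range(3)]   # first pass starts here, with epoch 0
sampler.set_epoch(1)                       # mid-pass: does not change the current order
rest = [next(looper) for _ in range(3)]    # still the epoch-0 permutation
second = [next(looper) for _ in range(6)]  # new pass starts, now using epoch 1

print(first + rest == order_for(6, 0), second == order_for(6, 1))  # True True
```

In this toy the reshuffle lines up with `set_epoch` only if each epoch consumes exactly one full pass. The real `DataLoader` consumes the sampler lazily and may prefetch with worker processes, so this does not settle how the actual script behaves; it only shows why the behavior is worth double-checking.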

That said, it should be easy to switch from the `datalooper` to iterating over the `dataloader` directly, which guarantees a fresh shuffle each epoch. It only requires changing this line: https://github.com/atong01/conditional-flow-matching/blob/72ae2fdcd8dfb8e421f0847fee109f7eb0f0c909/examples/images/cifar10/train_cifar10_ddp.py#L165

to `for batch in tqdm.tqdm(dataloader, total=len(dataloader)):`. This way, I am sure that `sampler.set_epoch` will work as expected.
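The proposed change can be sketched with the real `torch.utils.data` APIs on a toy dataset. `train_by_epochs` is a hypothetical name, the model update is elided, and `tqdm` is left out. Passing `num_replicas=1, rank=0` explicitly lets `DistributedSampler` run without an initialized process group.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def train_by_epochs(num_epochs, n=32, batch_size=8):
    dataset = TensorDataset(torch.arange(n))
    # num_replicas/rank given explicitly so this demo needs no init_process_group
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=0)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    global_step = 0
    orders = []
    for epoch in range(num_epochs):
        # Reseeds the shuffle; takes effect when the loader is iterated below
        sampler.set_epoch(epoch)
        seen = []
        # Fresh pass each epoch (the script would wrap this in tqdm.tqdm(...))
        for (x,) in dataloader:
            # ... forward pass, loss, backward, optimizer step would go here ...
            seen.extend(x.tolist())
            global_step += 1  # one increment per optimizer step
        orders.append(seen)
    return global_step, orders


steps, orders = train_by_epochs(3)
print(steps)  # 12
```

Because each epoch starts a fresh pass over `dataloader` right after `set_epoch`, every epoch gets its own shuffle, and `global_step` counts optimizer steps one by one.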

Xiaoming-Zhao commented 1 week ago

@ImahnShekhzadeh Added a working example in #145

atong01 commented 1 week ago

Closing. Superseded by #145