openai / consistency_models

Official repo for consistency models.
MIT License

Fixing PyTorch Random Number Generator State Issue #59

Open · AndreasBergmeister opened this 3 months ago


When computing the consistency loss, PyTorch's random number generator (RNG) state is saved and later restored around the target computation, so that the teacher and student networks see the same dropout behavior. However, PyTorch keeps separate RNG states for the CPU and for each CUDA device, and the current implementation saves and restores only the CPU state. Because dropout on GPU tensors draws from the CUDA generator, the two forward passes can still use different dropout masks when training on a GPU. This commit fixes the issue by saving and restoring the GPU state as well.
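A minimal sketch of the idea, assuming a plain `nn.Dropout` layer stands in for the networks' dropout. The helper names `get_rng_states` and `set_rng_states` are hypothetical and are not taken from the repository or from this PR's diff. It captures the CPU state with `torch.get_rng_state()` and every CUDA device's state with `torch.cuda.get_rng_state_all()`, then rewinds both so that a second forward pass replays the same dropout mask:

```python
import torch
import torch.nn as nn


def get_rng_states():
    """Hypothetical helper: capture the CPU RNG state and, if CUDA is available,
    the RNG state of every visible CUDA device."""
    cpu_state = torch.get_rng_state()
    cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    return cpu_state, cuda_states


def set_rng_states(states):
    """Hypothetical helper: restore the states captured by get_rng_states()."""
    cpu_state, cuda_states = states
    torch.set_rng_state(cpu_state)
    if cuda_states is not None:
        torch.cuda.set_rng_state_all(cuda_states)


device = "cuda" if torch.cuda.is_available() else "cpu"
dropout = nn.Dropout(p=0.5).train()  # dropout is only active in training mode
x = torch.ones(8, 16, device=device)

states = get_rng_states()
student_out = dropout(x)  # first pass draws a dropout mask from the device's generator
set_rng_states(states)    # rewind both the CPU and the CUDA generators
target_out = dropout(x)   # second pass replays the identical mask

print(torch.equal(student_out, target_out))
```

On a GPU, restoring only `torch.set_rng_state(cpu_state)` would leave the CUDA generator advanced, so `target_out` would almost certainly use a different mask. That is the behavior the PR describes.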