NVIDIA / modulus-sym

Framework providing pythonic APIs, algorithms and utilities to be used with Modulus core to physics inform model training as well as higher level abstraction for domain experts
https://developer.nvidia.com/modulus
Apache License 2.0
138 stars, 56 forks

Add a fix for gradient aggregation #82

Closed: ktangsali closed this 8 months ago

ktangsali commented 8 months ago

Modulus Pull Request

Description

Gradient aggregation was not functioning correctly because it computed the losses on the same batch at every aggregation step, rather than on a different batch at each step, as gradient aggregation requires. This PR adds a fix that enables gradient aggregation for cases that do not use CUDA Graphs.
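To illustrate why the batch must change between aggregation steps, here is a minimal sketch in plain Python. It is not the Modulus Sym implementation: the toy model `y = w * x`, the helper names `grad_mse` and `aggregated_grad`, and the synthetic data are all assumptions made for illustration. It compares aggregating over 10 different micro-batches with aggregating over the same micro-batch 10 times.

```python
import random


def grad_mse(w, batch):
    """Gradient w.r.t. w of the mean squared error of y = w * x on one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def aggregated_grad(w, batches):
    """Average the per-batch gradients, as one aggregated optimizer step would."""
    return sum(grad_mse(w, b) for b in batches) / len(batches)


# Hypothetical synthetic data: 100 noisy samples of y = 3x.
random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(100)]
data = [(x, 3.0 * x + random.gauss(0.0, 0.1)) for x in xs]

w = 0.5
full_grad = grad_mse(w, data)  # baseline: one full batch of 100 samples

# 0.1x batch size: ten disjoint micro-batches of 10 samples each.
micro_batches = [data[i:i + 10] for i in range(0, len(data), 10)]

# Intended behavior: a different micro-batch at every aggregation step.
correct_grad = aggregated_grad(w, micro_batches)

# Behavior described in this PR as broken: the same micro-batch at every step.
repeated_grad = aggregated_grad(w, [micro_batches[0]] * 10)

print(f"full batch:             {full_grad:.6f}")
print(f"10 different batches:   {correct_grad:.6f}")
print(f"same batch 10 times:    {repeated_grad:.6f}")
```

With equally sized micro-batches, averaging the gradients of 10 different batches reproduces the full-batch gradient up to floating-point error. Repeating one batch only reproduces that single small batch's gradient, so the extra aggregation steps add cost without enlarging the effective batch size.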

Closes #51

The annular ring case now works as expected with gradient aggregation.

[Image] Legend: pink = baseline; blue = 0.1x batch size; red = 0.1x batch size + 10 gradient aggregation steps.

ktangsali commented 8 months ago

/blossom-ci
