LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Update & Rewrite the DDIM example #661

avik-pal closed this pull request 1 month ago

avik-pal commented 1 month ago

Currently, the example trains for 80 epochs in 34 min on a V100. The Keras example reports 30 min on an A100, so we are in the same ballpark as TF2.

The previous code was somewhat slow: on an RTX 3060 Ti it took 53 s per epoch, whereas the updated code with the same configuration takes about 20–25 s per epoch.
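Converting the figures above to per-epoch numbers makes them easier to compare. This is plain arithmetic on the reported timings. The V100 and 3060 Ti runs are on different hardware, so the two lines are not a head-to-head comparison:

```math
\begin{aligned}
\text{V100 (updated code): } & \frac{34\,\text{min} \times 60\,\text{s/min}}{80\,\text{epochs}} = 25.5\,\text{s/epoch} \\
\text{3060 Ti speedup (old vs. new): } & \frac{53\,\text{s}}{25\,\text{s}} \approx 2.1\times \quad\text{to}\quad \frac{53\,\text{s}}{20\,\text{s}} = 2.65\times
\end{aligned}
```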

Improvements to Lux

  1. `Optimisers.adjust!` and `Optimisers.adjust` can now be applied directly to a `TrainState`.
  2. `StatefulLuxLayer` now has pretty printing.
  3. `StatefulLuxLayer` is compatible with Adapt, so `gpu_device()` / `cpu_device()` can be applied directly to it.
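The following minimal sketch exercises the three changes together. It assumes a recent Lux release. In particular, the `Lux.Training.TrainState(model, ps, st, opt)` constructor and the `StatefulLuxLayer{true}(model, ps, st)` form may be spelled differently in the version this PR targeted. The small `Chain`, the learning rates, and the input shape are hypothetical placeholders, not the DDIM network from the example:

```julia
using Lux, Optimisers, Random

rng = Xoshiro(0)

# Hypothetical toy model standing in for the real DDIM network
model = Chain(Dense(2 => 16, relu), Dense(16 => 2))
ps, st = Lux.setup(rng, model)

# (1) Change the learning rate directly on a TrainState.
# The non-mutating `adjust` returns an updated TrainState; per this PR,
# `Optimisers.adjust!` is also supported on TrainState.
ts = Lux.Training.TrainState(model, ps, st, Adam(1.0f-3))
ts = Optimisers.adjust(ts, 1.0f-4)

# (2) Wrap the model with its parameters and state, then pretty-print it
smodel = StatefulLuxLayer{true}(model, ps, st)
show(stdout, MIME"text/plain"(), smodel)
println()

# (3) Move the wrapped layer to a device in one call (relies on Adapt support).
# cpu_device() keeps this runnable anywhere; use gpu_device() on a GPU machine.
dev = cpu_device()
smodel = dev(smodel)

# Calling a StatefulLuxLayer returns only the output; its state is kept internally
x = dev(randn(rng, Float32, 2, 4))
y = smodel(x)
```

Because `StatefulLuxLayer` carries its own parameters and state, moving it with a single device call avoids transferring `ps` and `st` separately.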

TODOs

codecov[bot] commented 1 month ago

Codecov Report

Attention: Patch coverage is 17.24138%, with 24 lines in your changes missing coverage. Please review.

Project coverage is 86.07%. Comparing base (ca23485) to head (6a19c4f).

Current head 6a19c4f differs from the pull request's most recent head 03e2591.

Please upload reports for the commit 03e2591 to get more accurate results.

| Files | Patch % | Lines |
| --- | --- | --- |
| ext/LuxOptimisersExt.jl | 0.00% | 16 Missing ⚠️ |
| src/helpers/stateful.jl | 38.46% | 8 Missing ⚠️ |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #661      +/-   ##
==========================================
- Coverage   87.17%   86.07%   -1.11%
==========================================
  Files          50       50
  Lines        2527     2520       -7
==========================================
- Hits         2203     2169      -34
- Misses        324      351      +27
```
