DavidLandup0 closed this 1 week ago
@divyashreepathihalli thank you for the comments! I've addressed most of them; the rest are left as WIP:
Backbone
Let me know if any issues arise when adopting the SD3 pattern 😃 I'm unsure whether the newly introduced tasks are fully compatible with Flux.
@divyashreepathihalli I've turned it into a functional subclassing module - I had to wrestle a bit with shapes/AutoGraph, but it should be ready for another review.
Here's the notebook showing numerical equivalence to atol=1e-5 on all modules, as well as the final output of the core model: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=Bi_WbOjk7C4k
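For reference, the per-module check in the notebook amounts to an `np.allclose` comparison at the stated tolerance. A minimal sketch (the arrays below are placeholders, not the actual module outputs):

```python
import numpy as np

def assert_close(torch_out, keras_out, atol=1e-5):
    """Compare a reference module's output against the ported module's output."""
    a = np.asarray(torch_out, dtype=np.float32)
    b = np.asarray(keras_out, dtype=np.float32)
    max_diff = float(np.max(np.abs(a - b)))
    assert np.allclose(a, b, atol=atol), f"max abs diff {max_diff} exceeds atol={atol}"
    return max_diff

# Placeholder arrays standing in for the two implementations' outputs:
ref = np.array([0.1, 0.2, 0.3], dtype=np.float32)
port = ref + 1e-6  # within tolerance
print(assert_close(ref, port))
```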
Adding a preprocessing flow and we can open a PR for integrating with T5 and CLIP.
@divyashreepathihalli could we do another review here?
Thanks, David! I left a few comments. Do you have a demo Colab to verify the outputs?
Yes - here: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=_ys5NSkcoQ_O
With the converted weights (also in the Colab), we get identical outputs between the official model and the port, within a 1e-3 tolerance:
This PR ports the core model to Keras and includes a weight conversion script. The VAE and the rest of the pipeline would make sense as a separate PR.
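The conversion script's core loop follows the usual PyTorch-to-Keras pattern: walk a `{name: array}` state dict, rename parameters, and transpose dense kernels. A sketch with hypothetical parameter names (the real script maps Flux's actual parameter names):

```python
import numpy as np

# Hypothetical name mapping from PyTorch parameter names to Keras weight names;
# illustrative only, not the mapping used in the actual conversion script.
NAME_MAP = {
    "linear.weight": "dense/kernel",
    "linear.bias": "dense/bias",
}

def convert_state_dict(torch_state_dict):
    """Convert a {name: np.ndarray} PyTorch state dict to Keras conventions.

    Dense kernels are stored as (out_features, in_features) in PyTorch but
    (in_features, out_features) in Keras, so 2-D kernels are transposed.
    """
    keras_weights = {}
    for torch_name, tensor in torch_state_dict.items():
        keras_name = NAME_MAP[torch_name]
        value = np.asarray(tensor)
        if keras_name.endswith("kernel") and value.ndim == 2:
            value = value.T
        keras_weights[keras_name] = value
    return keras_weights

state = {"linear.weight": np.ones((4, 8)), "linear.bias": np.zeros(4)}
converted = convert_state_dict(state)
print(converted["dense/kernel"].shape)  # (8, 4)
```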
Each layer is numerically compared against the original PyTorch implementation here: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=Bi_WbOjk7C4k
Modules included:
Output Comparison
The core model's outputs are latents. We plot the PCA of the outputs from the original implementation and the Keras re-implementation on the same input:
Numerically, the outputs are equivalent to within 1e-3: