keras-team / keras-hub

Pretrained model hub for Keras 3
Apache License 2.0

[Flux] Port Flux Core Model #1864

Closed. DavidLandup0 closed this PR 1 week ago.

DavidLandup0 commented 2 months ago

This PR ports the Flux core model to Keras and includes a weight conversion script. The VAE and the rest of the pipeline would make more sense in a separate PR.
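The conversion script itself is not reproduced in this thread. As a minimal sketch of one step such a script typically performs, assuming the port uses Keras `Dense` layers where the original uses PyTorch `nn.Linear`: PyTorch stores a linear weight as `(out_features, in_features)`, while a Keras `Dense` kernel is `(in_features, out_features)`, so the matrix has to be transposed. The helper name `convert_linear` is hypothetical.

```python
import numpy as np


def convert_linear(pt_weight, pt_bias=None):
    """Map a PyTorch nn.Linear weight/bias to a Keras Dense kernel/bias.

    Hypothetical helper for illustration, not taken from the PR's script.
    pt_weight has shape (out_features, in_features).
    """
    kernel = np.transpose(pt_weight)  # -> (in_features, out_features)
    weights = [kernel]
    if pt_bias is not None:
        # The bias layout is the same in both frameworks: (out_features,)
        weights.append(np.asarray(pt_bias))
    return weights


# Self-check with random data: both layouts must give the same result.
rng = np.random.default_rng(0)
w = rng.standard_normal((8, 4)).astype("float32")  # PyTorch layout (out=8, in=4)
b = rng.standard_normal(8).astype("float32")
x = rng.standard_normal((2, 4)).astype("float32")

pt_out = x @ w.T + b               # what nn.Linear computes
kernel, bias = convert_linear(w, b)
keras_out = x @ kernel + bias      # what a Dense layer without activation computes
print(np.allclose(pt_out, keras_out))
```

In a full script, the returned list would be handed to the matching Keras layer via `set_weights`.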

Each layer is numerically compared against the original PyTorch implementation here: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=Bi_WbOjk7C4k
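A layer-by-layer comparison like the one in the notebook can be sketched as a small NumPy harness. This assumes the intermediate outputs of both implementations have already been collected into dicts keyed by layer name (for example, with forward hooks on the PyTorch side); the function and layer names below are hypothetical.

```python
import numpy as np


def compare_layers(keras_outputs, torch_outputs, atol=1e-5):
    """Compare matching intermediate outputs by layer name.

    Both arguments map a layer name to an array. Returns a dict of
    name -> (max_abs_diff, passed), where `passed` uses np.allclose.
    """
    report = {}
    for name, ref in torch_outputs.items():
        if name not in keras_outputs:
            raise KeyError(f"missing Keras output for layer {name!r}")
        out = np.asarray(keras_outputs[name])
        ref = np.asarray(ref)
        if out.shape != ref.shape:
            raise ValueError(f"{name}: shape {out.shape} != {ref.shape}")
        max_diff = float(np.max(np.abs(out - ref)))
        report[name] = (max_diff, bool(np.allclose(out, ref, atol=atol)))
    return report


# Toy data standing in for real activations.
ref = {"block_0": np.ones((1, 4)), "final": np.zeros((1, 4))}
ours = {"block_0": np.ones((1, 4)) + 1e-7, "final": np.full((1, 4), 1e-2)}
for name, (diff, ok) in compare_layers(ours, ref).items():
    print(f"{name}: max |diff| = {diff:.1e}, pass = {ok}")
```

Reporting the maximum absolute difference alongside the pass/fail flag makes it easier to spot which layer first drifts when a tolerance is tightened.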

Modules included:

Output Comparison

The core model's outputs are latents. We plot the PCA of the output from the original implementation and the Keras re-implementation on the same input:

[Figure: PCA of the output latents, original PyTorch implementation vs. Keras re-implementation]
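The plotting code is not part of the thread. A minimal sketch of the projection step, assuming the latents are flattened to `(num_tokens, channels)` and that one PCA basis is fit on the reference output and reused for the port, so both projections share a coordinate system (`pca_project` is a hypothetical name):

```python
import numpy as np


def pca_project(reference, other, n_components=3):
    """Project two (N, C) arrays onto the top principal components of `reference`.

    Fitting the basis on the reference only (an assumption of this sketch)
    keeps both projections directly comparable.
    """
    mean = reference.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, sorted by singular value.
    _, _, vt = np.linalg.svd(reference - mean, full_matrices=False)
    basis = vt[:n_components].T  # (C, n_components)
    return (reference - mean) @ basis, (other - mean) @ basis


rng = np.random.default_rng(0)
latents_pt = rng.standard_normal((256, 64))
latents_keras = latents_pt + 1e-4 * rng.standard_normal((256, 64))
proj_pt, proj_keras = pca_project(latents_pt, latents_keras)
print(proj_pt.shape, np.abs(proj_pt - proj_keras).max())
```

For display, the three components can be rescaled to [0, 1] and reshaped to the latent grid to render each output as an RGB image.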

Numerically, the outputs are equivalent within an absolute tolerance of 1e-3:

>>> np.allclose(output_keras.numpy(), output_pt.detach().numpy(), atol=1e-3)
True
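One detail worth keeping in mind when reading the check above: `np.allclose(a, b, atol=...)` tests `|a - b| <= atol + rtol * |b|` elementwise, and `rtol` keeps its default of `1e-5` unless set explicitly. For large-magnitude values the relative term can dominate, so the check is not purely absolute. The arrays below are made-up illustrations, not outputs from the model.

```python
import numpy as np

# np.allclose(a, b, atol=...) checks, elementwise:
#     |a - b| <= atol + rtol * |b|
# rtol defaults to 1e-5, and b acts as the reference, so the test
# is not symmetric in a and b.
a = np.array([1000.0, 0.5])
b = np.array([1000.009, 0.5])

print(np.allclose(a, b, atol=1e-3))            # passes thanks to rtol * |b|
print(np.allclose(a, b, atol=1e-3, rtol=0.0))  # a purely absolute check fails

# The same test written out by hand.
manual = np.all(np.abs(a - b) <= 1e-3 + 1e-5 * np.abs(b))
print(manual == np.allclose(a, b, atol=1e-3))
```

To make a comparison strictly absolute, pass `rtol=0`.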
DavidLandup0 commented 1 month ago

@divyashreepathihalli thank you for the comments! I addressed most of them; the following are still WIP:

james77777778 commented 1 month ago

Let me know if any issues arise when adopting the SD3 pattern 😃 I'm unsure whether the newly introduced tasks are fully compatible with Flux.

DavidLandup0 commented 1 month ago

@divyashreepathihalli I turned it into a functional subclassing module. I had to wrestle a bit with shapes/autograph, but it should be ready for another review.

Here's the notebook showing numerical equivalence to atol=1e-5 on all modules, as well as the final output of the core model: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=Bi_WbOjk7C4k

Next, I'm adding a preprocessing flow, and then we can open a PR for integrating T5 and CLIP.
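The preprocessing flow is not part of this PR. As context, a minimal NumPy sketch of one step it would likely need, assuming it mirrors the packing done in the reference Flux sampling code: the channels-first latent `(batch, channels, height, width)` is cut into 2x2 patches and flattened into a token sequence, and each token gets a 3-axis position id holding its patch row and column. Height and width are assumed even; `pack_latents`, `unpack_latents`, and `image_position_ids` are hypothetical names.

```python
import numpy as np


def pack_latents(latents, patch=2):
    """(B, C, H, W) -> (B, (H/p)*(W/p), C*p*p): 'b c (h ph) (w pw) -> b (h w) (c ph pw)'."""
    b, c, h, w = latents.shape
    if h % patch or w % patch:
        raise ValueError("height and width must be divisible by the patch size")
    x = latents.reshape(b, c, h // patch, patch, w // patch, patch)
    x = x.transpose(0, 2, 4, 1, 3, 5)  # (B, h, w, C, ph, pw)
    return x.reshape(b, (h // patch) * (w // patch), c * patch * patch)


def unpack_latents(tokens, height, width, patch=2):
    """Inverse of pack_latents; height/width are the original latent sizes."""
    b, _, d = tokens.shape
    c = d // (patch * patch)
    h, w = height // patch, width // patch
    x = tokens.reshape(b, h, w, c, patch, patch)
    x = x.transpose(0, 3, 1, 4, 2, 5)  # (B, C, h, ph, w, pw)
    return x.reshape(b, c, height, width)


def image_position_ids(height, width, batch, patch=2):
    """Per-token ids (B, N, 3): axis 0 left at zero, axes 1/2 hold patch row/column."""
    h, w = height // patch, width // patch
    ids = np.zeros((h, w, 3), dtype=np.float32)
    ids[..., 1] = np.arange(h)[:, None]
    ids[..., 2] = np.arange(w)[None, :]
    return np.broadcast_to(ids.reshape(1, h * w, 3), (batch, h * w, 3))


latents = np.random.default_rng(0).standard_normal((1, 16, 8, 6)).astype("float32")
tokens = pack_latents(latents)
ids = image_position_ids(8, 6, batch=1)
print(tokens.shape, ids.shape)
```

Keeping `unpack_latents` as an exact inverse makes it easy to round-trip the core model's output tokens back into the latent layout the VAE decoder expects.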

DavidLandup0 commented 4 weeks ago

@divyashreepathihalli could we do another review here?

DavidLandup0 commented 4 weeks ago

> Thanks David! Left a few comments. Do you have a demo colab to verify the outputs?

Yes - here: https://colab.research.google.com/drive/1Jr5pa9BGAxP6lZPimlpb22rD5DMijN3H#scrollTo=_ys5NSkcoQ_O

With the converted weights (also in the Colab), the official model and the port produce identical outputs within a tolerance of 1e-3:

[Figure: outputs of the official model and the Keras port]