oracle / sd4j

Stable diffusion pipeline in Java using ONNX Runtime
Universal Permissive License v1.0

Instructions to enable Core ML with Apple Silicon #6

Open: jasonwjones opened this issue 9 months ago

jasonwjones commented 9 months ago

The README indicates that Core ML is ostensibly supported, but it seems (and I could be wrong here) that a different onnxruntime build would be needed. onnxruntime_gpu won't do, as it requires CUDA, which doesn't work on an ARM Mac, and I don't see a Maven dependency for something like "onnxruntime_silicon". Are there steps to enable Core ML on an M1 (or later) Mac?

Thank you

Craigacp commented 9 months ago

You need to compile ONNX Runtime from source with `./build.sh --update --config Release --build --build_java --use_coreml --parallel`. However, at the moment this doesn't make it run any faster with the default ONNX export, as the model's input shapes are all left undefined.

I think CoreML would work better if the shapes were defined in the ONNX models and/or if the UNet were converted to fp16, but I've not added fp16 support to sd4j yet, as that only arrived in ORT Java 1.16. Fp16 support will likely come when I upgrade sd4j to ORT 1.17 in the new year. Fixing the shapes would mean a model can only generate images of one specific size. We could do that either with an override that reloads the models or by changing the ONNX models themselves, but the UX is kinda bad either way, so I didn't settle on a solution.
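The "override and reload" option amounts to caching a model built for one fixed image size and rebuilding it whenever a different size is requested. This is a hypothetical sketch, not how sd4j is structured: `FixedShapeModelCache`, `Size`, and the `loader` function are illustrative names, and the loader stands in for whatever code would fix the dimensions and create the ONNX Runtime session.

```java
import java.util.Objects;
import java.util.function.Function;

/**
 * Hypothetical sketch (not sd4j code): holds a model whose shapes were fixed for one
 * image size, and reloads it only when a different size is requested.
 */
final class FixedShapeModelCache<M> {
    /** Image size in pixels; a fixed-shape model is only valid for this size. */
    record Size(int width, int height) {}

    private final Function<Size, M> loader; // assumed: fixes dimensions and loads the model
    private Size currentSize;
    private M currentModel;
    private int loadCount;

    FixedShapeModelCache(Function<Size, M> loader) {
        this.loader = Objects.requireNonNull(loader);
    }

    /** Returns a model built for the requested size, reloading only if the size changed. */
    M get(Size size) {
        if (currentModel == null || !size.equals(currentSize)) {
            // Slow path: reloading a large model such as the UNet takes time.
            currentModel = loader.apply(size);
            currentSize = size;
            loadCount++;
        }
        return currentModel;
    }

    /** Number of times the loader has been invoked so far. */
    int loadCount() {
        return loadCount;
    }

    public static void main(String[] args) {
        FixedShapeModelCache<String> cache =
            new FixedShapeModelCache<>(s -> "unet@" + s.width() + "x" + s.height());
        cache.get(new Size(512, 512));
        cache.get(new Size(512, 512)); // same size: reuses the loaded model
        cache.get(new Size(768, 768)); // new size: triggers a reload
        System.out.println(cache.loadCount()); // prints 2
    }
}
```

Repeated generations at one size stay fast, but the first image at a new size pays the full reload cost, which is the UX trade-off described above.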

Craigacp commented 9 months ago

I'll fix the README text to make it clear that the CoreML & DirectML support is more experimental, as I think it needs model modifications to get a performance improvement. With ORT 1.16, enabling CoreML doesn't place many nodes on the CoreML device, though it looks like there are improvements that should be in the next ONNX Runtime release.

Craigacp commented 9 months ago

I've updated the README to note that CoreML is experimental. I also just merged in support for SD-Turbo, which can generate images in a few diffusion steps, so that's one way to get images generated faster.

jasonwjones commented 9 months ago

That helps, thanks. I'll definitely take a look at SD-Turbo. I'm on an M1 Max, so the idea was basically to try and speed things up. Since you're on an Intel Mac, do you know if CUDA works via an eGPU?

Craigacp commented 9 months ago

I've not tried eGPU support on macOS; I tested the CUDA parts on Linux.

Craigacp commented 9 months ago

Fixing some of the free dimensions does move more ops onto the CoreML execution provider, but not all of them with ORT v1.16.3, so on my regular M1 it's still slower than CPU only. You might get a better result with an M1 Max, as its GPU is much better, but I don't have access to one to test it.

I added these lines to the function that constructs the CoreML session options, if you're interested in testing it out, but you might need to use the default options for the VAE, as I don't know whether the dimension names collide:

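```java
// Pin the free (symbolic) dimensions to fixed values so more ops can be placed on CoreML.
opts.setSymbolicDimensionValue("channels", 4);   // latent channels
opts.setSymbolicDimensionValue("batch", 1);
opts.setSymbolicDimensionValue("height", 64);    // latent height: 512 px / 8
opts.setSymbolicDimensionValue("width", 64);     // latent width: 512 px / 8
opts.setSymbolicDimensionValue("sequence", 77);  // text encoder token sequence length
```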

This does limit you to 512x512 images for a given run of the program (the model operates in a latent space 8x smaller than pixel space, so 512 px = 64 latent dimensions). If it gives a clear speedup on some hardware, I can look at changing how the image size is selected so that it triggers model reloading, and at restructuring how the session options are constructed to make them a little easier to play with, but the UX would be kinda annoying, as reloading the UNet takes time.
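The 8x relationship between pixel and latent dimensions can be made explicit with a small helper that derives the symbolic dimension values from an image size instead of hard-coding 64. This is a minimal sketch, not sd4j code: `SymbolicDims` is a hypothetical name, and it assumes the dimension names from the snippet above (`channels`, `batch`, `height`, `width`, `sequence`), 4 latent channels, and a 77-token text sequence.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Hypothetical helper (not part of sd4j or ONNX Runtime): computes values for the
 * symbolic dimensions used in the snippet above, for a given image size.
 */
final class SymbolicDims {
    // The latent space is 8x smaller than pixel space in each spatial dimension.
    static final int LATENT_SCALE = 8;

    private SymbolicDims() {}

    /** Converts a pixel dimension to a latent dimension, e.g. 512 px to 64. */
    static long latentSize(int pixels) {
        if (pixels <= 0 || pixels % LATENT_SCALE != 0) {
            throw new IllegalArgumentException(
                "Image dimension must be a positive multiple of " + LATENT_SCALE + ", got " + pixels);
        }
        return pixels / LATENT_SCALE;
    }

    /** Builds name-to-value pairs, assuming the dimension names used in the snippet above. */
    static Map<String, Long> forImage(int widthPx, int heightPx, int batch) {
        Map<String, Long> dims = new LinkedHashMap<>();
        dims.put("channels", 4L);          // latent channels
        dims.put("batch", (long) batch);
        dims.put("height", latentSize(heightPx));
        dims.put("width", latentSize(widthPx));
        dims.put("sequence", 77L);         // text encoder token sequence length
        return dims;
    }

    public static void main(String[] args) {
        // prints {channels=4, batch=1, height=64, width=64, sequence=77}
        System.out.println(forImage(512, 512, 1));
    }
}
```

Each entry could then be passed to `opts.setSymbolicDimensionValue(name, value)` when building the UNet's session options. As noted above, the VAE may need the default options if it reuses these names with different meanings.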

This has pointed out to me that ORT Java doesn't expose the dimension names in a model's input and output info. I had to read the protobuf in a terminal to get them, so I'll fix that upstream at some point.

Craigacp commented 4 months ago

I retested it with the new CoreML support in ORT 1.18.0, and unfortunately it still doesn't support many of the nodes used here. It still needs more work inside ORT.