LaurentMazare / diffusers-rs

An implementation of the diffusers api in Rust
Apache License 2.0

m1 mac gpu #42

Closed. emmajane1313 closed this issue 1 year ago.

emmajane1313 commented 1 year ago

Is there a way to make it use the GPU on M1 Macs? It runs on my CPU when generating, even though I have more than 8GB of memory. I'm using this command:

cargo run --example stable-diffusion --features clap -- --prompt "A very rusty robot holding a fire torch." --cpu all --sd-version v1-5

mspronesti commented 1 year ago

Hi, unfortunately I don't have an M1 Mac to test it myself, but it seems that the tch crate supports MPS (https://github.com/LaurentMazare/tch-rs/issues/542).

I think that just setting the device to Device::Mps in the examples (e.g. here) and running the command you mentioned above without the --cpu all option might work. I guess you will also need to export these variables (as mentioned in the issue above):

export LIBTORCH=$(python -c 'import torch; from pathlib import Path; print(Path(torch.__file__).parent)')

export DYLD_LIBRARY_PATH=${LIBTORCH}/lib

export LIBTORCH_CXX11_ABI=0
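These three exports only help if they agree with each other: DYLD_LIBRARY_PATH should contain the lib directory inside LIBTORCH, and LIBTORCH_CXX11_ABI should be 0. Below is a minimal sanity-check sketch using only the Rust standard library. The function name check_libtorch_env is hypothetical. Variables are read through a lookup closure so the check can be exercised without touching the real environment.

```rust
/// Hypothetical helper: lists problems with the libtorch-related environment.
/// `get` looks up a variable, e.g. `|k| std::env::var(k).ok()`.
pub fn check_libtorch_env<F: Fn(&str) -> Option<String>>(get: F) -> Vec<String> {
    let mut problems = Vec::new();

    // LIBTORCH must be set; nothing else can be checked without it.
    let libtorch = match get("LIBTORCH") {
        Some(p) if !p.is_empty() => p,
        _ => {
            problems.push("LIBTORCH is not set".to_string());
            return problems;
        }
    };

    // DYLD_LIBRARY_PATH is colon-separated; one entry must be ${LIBTORCH}/lib.
    let expected_lib = format!("{}/lib", libtorch.trim_end_matches('/'));
    match get("DYLD_LIBRARY_PATH") {
        Some(p) if p.split(':').any(|entry| entry.trim_end_matches('/') == expected_lib) => {}
        _ => problems.push(format!("DYLD_LIBRARY_PATH does not contain {}", expected_lib)),
    }

    // The exports above set the pre-C++11 ABI.
    if get("LIBTORCH_CXX11_ABI").as_deref() != Some("0") {
        problems.push("LIBTORCH_CXX11_ABI is not 0".to_string());
    }
    problems
}
```

For example, check_libtorch_env(|k| std::env::var(k).ok()) returns an empty list when all three exports above are in place.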

As I said, I can't try it myself, but I hope it works. Alternatively, here you can find a Colab notebook for running diffusers-rs with CUDA.
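Dropping --cpu all matters because of how the example places each model component. Judging from the closure visible in the diff later in this thread, a component goes to Device::Cpu whenever its name, or "all", appears in the --cpu list. The following is a minimal self-contained model of that rule. The local Device enum stands in for tch::Device, and component names such as "unet" and "vae" are illustrative.

```rust
/// Local stand-in for tch::Device, kept minimal for this sketch.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Cuda(usize),
    Mps,
}

/// Models the example's per-component placement: a component falls back to
/// the CPU when its name, or "all", was passed through --cpu; otherwise it
/// uses the accelerator.
pub fn device_for(component: &str, cpu_flags: &[&str], accelerator: Device) -> Device {
    if cpu_flags.iter().any(|c| *c == "all" || *c == component) {
        Device::Cpu
    } else {
        accelerator
    }
}
```

With --cpu all, every component resolves to the CPU no matter which accelerator is passed in, so the GPU is never used.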

bakkot commented 1 year ago

I got this working, but it took a few more steps.

Setting those exports is enough to get it to compile and execute, but it still runs on the CPU.

Changing let cuda_device = Device::cuda_if_available() to let cuda_device = Device::Mps causes it to fail:

Cuda available: false
Cudnn available: false
Running with prompt "A rusty robot holding a fire torch.".
Building the Clip transformer.
Error: Internal torch error: supported devices include CPU, CUDA and HPU, however got MPS
Exception raised from readInstruction at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/jit/serialization/unpickler.cpp:531 (most recent call first):

But there is a workaround suggested upstream (for a bug further upstream), and that works:

diff --git a/examples/stable-diffusion/main.rs b/examples/stable-diffusion/main.rs
index 98d99e5..477518f 100644
--- a/examples/stable-diffusion/main.rs
+++ b/examples/stable-diffusion/main.rs
@@ -230,7 +230,7 @@ fn run(args: Args) -> anyhow::Result<()> {
             stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size)
         }
     };
-    let cuda_device = Device::cuda_if_available();
+    let cuda_device = Device::Mps;
     let cpu_or_cuda = |name: &str| {
         if cpu.iter().any(|c| c == "all" || c == name) {
             Device::Cpu
diff --git a/src/pipelines/stable_diffusion.rs b/src/pipelines/stable_diffusion.rs
index e5a5813..bb0c65c 100644
--- a/src/pipelines/stable_diffusion.rs
+++ b/src/pipelines/stable_diffusion.rs
@@ -97,10 +97,12 @@ impl StableDiffusionConfig {
         vae_weights: &str,
         device: Device,
     ) -> anyhow::Result<vae::AutoEncoderKL> {
-        let mut vs_ae = nn::VarStore::new(device);
+        let mut vs_ae = nn::VarStore::new(tch::Device::Mps);
         // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
         let autoencoder = vae::AutoEncoderKL::new(vs_ae.root(), 3, 3, self.autoencoder.clone());
+        vs_ae.set_device(tch::Device::Cpu);
         vs_ae.load(vae_weights)?;
+        vs_ae.set_device(tch::Device::Mps);
         Ok(autoencoder)
     }

@@ -110,10 +112,12 @@ impl StableDiffusionConfig {
         device: Device,
         in_channels: i64,
     ) -> anyhow::Result<unet_2d::UNet2DConditionModel> {
-        let mut vs_unet = nn::VarStore::new(device);
+        let mut vs_unet = nn::VarStore::new(tch::Device::Mps);
         let unet =
             unet_2d::UNet2DConditionModel::new(vs_unet.root(), in_channels, 4, self.unet.clone());
+        vs_unet.set_device(tch::Device::Cpu);
         vs_unet.load(unet_weights)?;
+        vs_unet.set_device(tch::Device::Mps);
         Ok(unet)
     }

@@ -126,9 +130,11 @@ impl StableDiffusionConfig {
         clip_weights: &str,
         device: tch::Device,
     ) -> anyhow::Result<clip::ClipTextTransformer> {
-        let mut vs = tch::nn::VarStore::new(device);
+        let mut vs = tch::nn::VarStore::new(tch::Device::Mps);
         let text_model = clip::ClipTextTransformer::new(vs.root(), &self.clip);
+        vs.set_device(tch::Device::Cpu);
         vs.load(clip_weights)?;
+        vs.set_device(tch::Device::Mps);
         Ok(text_model)
     }
 }

Obviously hardcoding the device isn't what you'd want to do in the actual project, but it works if you just want to get something working locally.
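The workaround boils down to one ordering rule: load the weights while the VarStore is on the CPU, then move it to the accelerator. The sketch below models that rule. Unlike the local patch, it passes the caller's device through instead of hardcoding Mps. MockVarStore is a hypothetical stand-in for tch::nn::VarStore that only tracks its device. Its load deliberately fails on MPS to mimic the unpickler error in the log above.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Mps,
}

/// Hypothetical stand-in for tch::nn::VarStore that only tracks its device.
pub struct MockVarStore {
    pub device: Device,
}

impl MockVarStore {
    pub fn new(device: Device) -> Self {
        MockVarStore { device }
    }

    pub fn set_device(&mut self, device: Device) {
        self.device = device;
    }

    /// Mimics the unpickler limitation: loading while the store targets MPS fails.
    pub fn load(&mut self, _path: &str) -> Result<(), String> {
        match self.device {
            Device::Mps => Err("supported devices include CPU, CUDA and HPU, however got MPS".to_string()),
            Device::Cpu => Ok(()),
        }
    }
}

/// Loads on the CPU, then moves the store to `target` (the caller's device).
pub fn load_then_move(vs: &mut MockVarStore, path: &str, target: Device) -> Result<(), String> {
    vs.set_device(Device::Cpu);
    vs.load(path)?;
    vs.set_device(target);
    Ok(())
}
```

With the real crate, the same sequence would be the three calls already used in the diff: vs.set_device(tch::Device::Cpu), vs.load(weights)?, and then vs.set_device(device) with the function's own device argument.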

It looks like tch-rs may add this workaround to the crate itself, so you may just want to wait for that to land and be released.

LaurentMazare commented 1 year ago

The MPS changes on the tch-rs side have been released (PR-623). I've published a new version of the tch crate that includes the fix. I've also published a new version of the diffusers crate that uses this fixed version.

bakkot commented 1 year ago

@LaurentMazare Sweet! Would it be possible to update the logic to default to the MPS device when it's available?
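One way to express the requested default is a fixed preference order: CUDA if available, otherwise MPS, otherwise CPU. The sketch below only illustrates that policy and is not necessarily what #50 does. Availability is passed in as booleans, so the policy can be checked without a GPU. With tch, the CUDA flag could come from tch::Cuda::is_available(). The MPS flag could come from tch::utils::has_mps() in tch versions that provide it; that is an assumption worth checking against the version you build with. Returning Cuda(0) mirrors what Device::cuda_if_available() picks.

```rust
/// Local stand-in for tch::Device.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Cuda(usize),
    Mps,
}

/// Default-device policy: prefer CUDA, then MPS, then fall back to the CPU.
pub fn default_device(cuda_available: bool, mps_available: bool) -> Device {
    if cuda_available {
        Device::Cuda(0)
    } else if mps_available {
        Device::Mps
    } else {
        Device::Cpu
    }
}
```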

LaurentMazare commented 1 year ago

@bakkot Sounds like a good idea. Could you try the stable-diffusion example using #50 and check that it works well on a device where MPS is available? Please also check that it actually uses that device rather than the CPU.

LaurentMazare commented 1 year ago

Closing this, as the related PR was merged a while ago. Feel free to reopen if it's still an issue (I don't have a Mac at hand to test).