Is there a way to specify GPU use for M1 Macs? It's using my CPU when generating, but I have more than 8GB of memory. I'm using this command:
cargo run --example stable-diffusion --features clap -- --prompt "A very rusty robot holding a fire torch." --cpu all --sd-version v1-5
Hi,
unfortunately I don't have an M1 Mac to test this myself, but it seems that the tch crate supports MPS (https://github.com/LaurentMazare/tch-rs/issues/542).
I think that setting the device to Device::Mps in the examples (e.g. here) and running the command you mentioned above (without the --cpu all option) might work. You will probably also need to export these variables (as mentioned in the issue above):
export LIBTORCH=$(python -c 'import torch; from pathlib import Path; print(Path(torch.__file__).parent)')
export DYLD_LIBRARY_PATH=${LIBTORCH}/lib
export LIBTORCH_CXX11_ABI=0
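For concreteness, a minimal sketch of the idea, with device_for standing in for the example's cpu_or_cuda closure (the free-function shape and its name are mine for illustration; the --cpu matching logic is the example's):

use tch::Device;

// Components named via --cpu stay on the CPU; everything else goes to MPS.
fn device_for(name: &str, cpu: &[String]) -> Device {
    if cpu.iter().any(|c| c == "all" || c == name) {
        Device::Cpu
    } else {
        Device::Mps
    }
}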
As I said, I can't try it myself, but I hope it works. Alternatively, here you can find a Colab notebook for using diffusers-rs with CUDA.
I got this working, but it took a few more steps.
Setting those exports is enough to get it to compile and run, but only on the CPU.
Changing let cuda_device = Device::cuda_if_available()
to let cuda_device = Device::Mps
causes it to fail:
Cuda available: false
Cudnn available: false
Running with prompt "A rusty robot holding a fire torch.".
Building the Clip transformer.
Error: Internal torch error: supported devices include CPU, CUDA and HPU, however got MPS
Exception raised from readInstruction at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/jit/serialization/unpickler.cpp:531 (most recent call first):
But there is a workaround suggested upstream (for a bug that is itself further upstream), and it works:
diff --git a/examples/stable-diffusion/main.rs b/examples/stable-diffusion/main.rs
index 98d99e5..477518f 100644
--- a/examples/stable-diffusion/main.rs
+++ b/examples/stable-diffusion/main.rs
@@ -230,7 +230,7 @@ fn run(args: Args) -> anyhow::Result<()> {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size)
}
};
- let cuda_device = Device::cuda_if_available();
+ let cuda_device = Device::Mps;
let cpu_or_cuda = |name: &str| {
if cpu.iter().any(|c| c == "all" || c == name) {
Device::Cpu
diff --git a/src/pipelines/stable_diffusion.rs b/src/pipelines/stable_diffusion.rs
index e5a5813..bb0c65c 100644
--- a/src/pipelines/stable_diffusion.rs
+++ b/src/pipelines/stable_diffusion.rs
@@ -97,10 +97,12 @@ impl StableDiffusionConfig {
vae_weights: &str,
device: Device,
) -> anyhow::Result<vae::AutoEncoderKL> {
- let mut vs_ae = nn::VarStore::new(device);
+ let mut vs_ae = nn::VarStore::new(tch::Device::Mps);
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKL::new(vs_ae.root(), 3, 3, self.autoencoder.clone());
+ vs_ae.set_device(tch::Device::Cpu);
vs_ae.load(vae_weights)?;
+ vs_ae.set_device(tch::Device::Mps);
Ok(autoencoder)
}
@@ -110,10 +112,12 @@ impl StableDiffusionConfig {
device: Device,
in_channels: i64,
) -> anyhow::Result<unet_2d::UNet2DConditionModel> {
- let mut vs_unet = nn::VarStore::new(device);
+ let mut vs_unet = nn::VarStore::new(tch::Device::Mps);
let unet =
unet_2d::UNet2DConditionModel::new(vs_unet.root(), in_channels, 4, self.unet.clone());
+ vs_unet.set_device(tch::Device::Cpu);
vs_unet.load(unet_weights)?;
+ vs_unet.set_device(tch::Device::Mps);
Ok(unet)
}
@@ -126,9 +130,11 @@ impl StableDiffusionConfig {
clip_weights: &str,
device: tch::Device,
) -> anyhow::Result<clip::ClipTextTransformer> {
- let mut vs = tch::nn::VarStore::new(device);
+ let mut vs = tch::nn::VarStore::new(tch::Device::Mps);
let text_model = clip::ClipTextTransformer::new(vs.root(), &self.clip);
+ vs.set_device(tch::Device::Cpu);
vs.load(clip_weights)?;
+ vs.set_device(tch::Device::Mps);
Ok(text_model)
}
}
Obviously hardcoding the device isn't what you'd want in the actual project, but it works if you just want to get something running locally.
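The pattern in the diff is the same in all three loaders: build the model with its VarStore on MPS, hop to the CPU just for deserialization, then hop back. A condensed sketch, assuming tch's VarStore API (new, set_device, and load are the same calls the diff uses; the function name and path argument are placeholders):

use tch::{nn, Device};

fn load_on_mps(vs: &mut nn::VarStore, weights: &str) -> anyhow::Result<()> {
    // The unpickler in this torch build can't deserialize directly onto MPS,
    // so move the (already-built) variables to the CPU for the load...
    vs.set_device(Device::Cpu);
    vs.load(weights)?;
    // ...then move them back to MPS for inference.
    vs.set_device(Device::Mps);
    Ok(())
}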
It looks like tch-rs might adopt this workaround in the crate itself, so you may want to wait for that to land and be released.
The MPS changes on the tch-rs side have been released (PR-623); I've published a new version of the tch crate including the fix, as well as a new version of the diffusers crate that uses this fixed version.
@LaurentMazare Sweet! Would it be possible to update the logic to default to the MPS device when it's available?
@bakkot sounds like a good idea, could you try the stable-diffusion example using #50 and see if that works well on a device where MPS is available? (And check that it actually uses the device rather than the CPU.)
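For reference, a device-selection helper in the spirit of what #50 aims at might look like the sketch below; the function name is illustrative, and has_mps is assumed to be exposed in tch::utils (it is in recent tch releases):

use tch::Device;

// Prefer CUDA, then MPS, then fall back to the CPU.
fn preferred_device() -> Device {
    if tch::Cuda::is_available() {
        Device::cuda_if_available()
    } else if tch::utils::has_mps() {
        Device::Mps
    } else {
        Device::Cpu
    }
}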
Closing this as the related PR has been merged for a while; feel free to re-open if it's still an issue (I don't have a Mac at hand to test).