
Comments (6)

LaurentMazare commented on June 26, 2024

The MPS changes on the tch-rs side have been released (PR-623). I've published a new version of the tch crate that includes the fix, as well as a new version of the diffusers crate that uses it.

from diffusers-rs.

mspronesti commented on June 26, 2024

Hi,
unfortunately I don't have an M1 Mac to test this myself, but it seems that the tch crate supports MPS (LaurentMazare/tch-rs#542).

I think that setting the device to Device::Mps in the examples (e.g. here) and running the command you mentioned above (without the --cpu all option) might work. I guess you will also need to export these variables (as mentioned in the issue above):

export LIBTORCH=$(python -c 'import torch; from pathlib import Path; print(Path(torch.__file__).parent)')
export DYLD_LIBRARY_PATH=${LIBTORCH}/lib
export LIBTORCH_CXX11_ABI=0

As I said, I can't try it myself, but I hope it works. Alternatively, here you can find a Colab notebook for running diffusers-rs with CUDA.


bakkot commented on June 26, 2024

I got this working, but it took a few more steps.

Setting those exports is enough to get it to compile and run, but only on the CPU.

Changing let cuda_device = Device::cuda_if_available() to let cuda_device = Device::Mps causes it to fail:

Cuda available: false
Cudnn available: false
Running with prompt "A rusty robot holding a fire torch.".
Building the Clip transformer.
Error: Internal torch error: supported devices include CPU, CUDA and HPU, however got MPS
Exception raised from readInstruction at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/jit/serialization/unpickler.cpp:531 (most recent call first):

But there is a workaround suggested upstream (for a bug further upstream), and that works:

diff --git a/examples/stable-diffusion/main.rs b/examples/stable-diffusion/main.rs
index 98d99e5..477518f 100644
--- a/examples/stable-diffusion/main.rs
+++ b/examples/stable-diffusion/main.rs
@@ -230,7 +230,7 @@ fn run(args: Args) -> anyhow::Result<()> {
             stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size)
         }
     };
-    let cuda_device = Device::cuda_if_available();
+    let cuda_device = Device::Mps;
     let cpu_or_cuda = |name: &str| {
         if cpu.iter().any(|c| c == "all" || c == name) {
             Device::Cpu
diff --git a/src/pipelines/stable_diffusion.rs b/src/pipelines/stable_diffusion.rs
index e5a5813..bb0c65c 100644
--- a/src/pipelines/stable_diffusion.rs
+++ b/src/pipelines/stable_diffusion.rs
@@ -97,10 +97,12 @@ impl StableDiffusionConfig {
         vae_weights: &str,
         device: Device,
     ) -> anyhow::Result<vae::AutoEncoderKL> {
-        let mut vs_ae = nn::VarStore::new(device);
+        let mut vs_ae = nn::VarStore::new(tch::Device::Mps);
         // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
         let autoencoder = vae::AutoEncoderKL::new(vs_ae.root(), 3, 3, self.autoencoder.clone());
+        vs_ae.set_device(tch::Device::Cpu);
         vs_ae.load(vae_weights)?;
+        vs_ae.set_device(tch::Device::Mps);
         Ok(autoencoder)
     }
 
@@ -110,10 +112,12 @@ impl StableDiffusionConfig {
         device: Device,
         in_channels: i64,
     ) -> anyhow::Result<unet_2d::UNet2DConditionModel> {
-        let mut vs_unet = nn::VarStore::new(device);
+        let mut vs_unet = nn::VarStore::new(tch::Device::Mps);
         let unet =
             unet_2d::UNet2DConditionModel::new(vs_unet.root(), in_channels, 4, self.unet.clone());
+        vs_unet.set_device(tch::Device::Cpu);
         vs_unet.load(unet_weights)?;
+        vs_unet.set_device(tch::Device::Mps);
         Ok(unet)
     }
 
@@ -126,9 +130,11 @@ impl StableDiffusionConfig {
         clip_weights: &str,
         device: tch::Device,
     ) -> anyhow::Result<clip::ClipTextTransformer> {
-        let mut vs = tch::nn::VarStore::new(device);
+        let mut vs = tch::nn::VarStore::new(tch::Device::Mps);
         let text_model = clip::ClipTextTransformer::new(vs.root(), &self.clip);
+        vs.set_device(tch::Device::Cpu);
         vs.load(clip_weights)?;
+        vs.set_device(tch::Device::Mps);
         Ok(text_model)
     }
 }

Obviously hardcoding the device isn't what you'd want to do in the actual project, but it works if you just want to get something working locally.
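The diff repeats the same three-step dance for each model, so the pattern could also be factored into a helper. This is only a sketch (the helper name is hypothetical; it assumes the tch and anyhow crates already used by the project):

```rust
use tch::{nn, Device};

// Hypothetical helper illustrating the workaround above: keep the VarStore
// on MPS, switch it to the CPU while deserializing the checkpoint (the
// unpickler only supports CPU, CUDA and HPU), then move the loaded tensors
// back to MPS.
fn load_weights_via_cpu(vs: &mut nn::VarStore, weights: &str) -> anyhow::Result<()> {
    vs.set_device(Device::Cpu);
    vs.load(weights)?;
    vs.set_device(Device::Mps);
    Ok(())
}
```

Calling this after constructing each model against `vs.root()` would replace the three duplicated set_device/load/set_device blocks in the diff.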

It looks like tch-rs may implement this workaround in the crate itself, so you may want to wait for that to land and be released.


bakkot commented on June 26, 2024

@LaurentMazare Sweet! Would it be possible to update the logic to default to the MPS device when it's available?
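Such a default could be a small sketch along these lines (assuming the tch crate; the `tch::utils::has_mps()` check is an assumption here -- verify that your tch version exposes it):

```rust
use tch::Device;

// Hypothetical device-selection helper: prefer CUDA, then MPS, then CPU.
// `Device::cuda_if_available()` is the call already used in the example;
// the MPS availability check is assumed to exist in tch's utils module.
fn best_device() -> Device {
    match Device::cuda_if_available() {
        Device::Cpu if tch::utils::has_mps() => Device::Mps,
        d => d,
    }
}
```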


LaurentMazare commented on June 26, 2024

@bakkot Sounds like a good idea. Could you try the stable-diffusion example using #50 and see whether it works well on a device where MPS is available (and that it actually uses the device rather than the CPU)?


LaurentMazare commented on June 26, 2024

Closing this as the related PR has been merged for a while; feel free to re-open if it's still an issue (I don't have a Mac at hand to test).

