Comments (6)
The MPS changes on the tch-rs side have been released (PR-623); I've published a new version of the tch
crate including the fix, as well as a new version of the diffusers
crate that uses this fixed version.
from diffusers-rs.
Hi,
unfortunately I don't have an M1 Mac to test it myself, but it seems that the tch
crate supports MPS (LaurentMazare/tch-rs#542).
I think that just setting the device to Device::Mps
in the examples (e.g. here) and running the command you mentioned above (without the --cpu all
option) might work. I guess you will also need to export these variables (as mentioned in the issue above):
export LIBTORCH=$(python -c 'import torch; from pathlib import Path; print(Path(torch.__file__).parent)')
export DYLD_LIBRARY_PATH=${LIBTORCH}/lib
export LIBTORCH_CXX11_ABI=0
As I said, I can't try it myself, but I hope it works. Alternatively, here you can find a Colab notebook for running diffusers-rs with CUDA.
I got this working, but it took a few more steps.
Setting those exports is enough to get it to compile and execute, but it still runs on the CPU.
Changing let cuda_device = Device::cuda_if_available()
to let cuda_device = Device::Mps
causes it to fail:
Cuda available: false
Cudnn available: false
Running with prompt "A rusty robot holding a fire torch.".
Building the Clip transformer.
Error: Internal torch error: supported devices include CPU, CUDA and HPU, however got MPS
Exception raised from readInstruction at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/jit/serialization/unpickler.cpp:531 (most recent call first):
But there is a workaround suggested upstream (for a bug that is itself further upstream, in PyTorch's unpickler), and that works:
diff --git a/examples/stable-diffusion/main.rs b/examples/stable-diffusion/main.rs
index 98d99e5..477518f 100644
--- a/examples/stable-diffusion/main.rs
+++ b/examples/stable-diffusion/main.rs
@@ -230,7 +230,7 @@ fn run(args: Args) -> anyhow::Result<()> {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size)
}
};
- let cuda_device = Device::cuda_if_available();
+ let cuda_device = Device::Mps;
let cpu_or_cuda = |name: &str| {
if cpu.iter().any(|c| c == "all" || c == name) {
Device::Cpu
diff --git a/src/pipelines/stable_diffusion.rs b/src/pipelines/stable_diffusion.rs
index e5a5813..bb0c65c 100644
--- a/src/pipelines/stable_diffusion.rs
+++ b/src/pipelines/stable_diffusion.rs
@@ -97,10 +97,12 @@ impl StableDiffusionConfig {
vae_weights: &str,
device: Device,
) -> anyhow::Result<vae::AutoEncoderKL> {
- let mut vs_ae = nn::VarStore::new(device);
+ let mut vs_ae = nn::VarStore::new(tch::Device::Mps);
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKL::new(vs_ae.root(), 3, 3, self.autoencoder.clone());
+ vs_ae.set_device(tch::Device::Cpu);
vs_ae.load(vae_weights)?;
+ vs_ae.set_device(tch::Device::Mps);
Ok(autoencoder)
}
@@ -110,10 +112,12 @@ impl StableDiffusionConfig {
device: Device,
in_channels: i64,
) -> anyhow::Result<unet_2d::UNet2DConditionModel> {
- let mut vs_unet = nn::VarStore::new(device);
+ let mut vs_unet = nn::VarStore::new(tch::Device::Mps);
let unet =
unet_2d::UNet2DConditionModel::new(vs_unet.root(), in_channels, 4, self.unet.clone());
+ vs_unet.set_device(tch::Device::Cpu);
vs_unet.load(unet_weights)?;
+ vs_unet.set_device(tch::Device::Mps);
Ok(unet)
}
@@ -126,9 +130,11 @@ impl StableDiffusionConfig {
clip_weights: &str,
device: tch::Device,
) -> anyhow::Result<clip::ClipTextTransformer> {
- let mut vs = tch::nn::VarStore::new(device);
+ let mut vs = tch::nn::VarStore::new(tch::Device::Mps);
let text_model = clip::ClipTextTransformer::new(vs.root(), &self.clip);
+ vs.set_device(tch::Device::Cpu);
vs.load(clip_weights)?;
+ vs.set_device(tch::Device::Mps);
Ok(text_model)
}
}
Obviously hardcoding the device isn't what you'd want in the actual project, but it works if you just want to get something running locally.
It looks like tch-rs may implement this workaround inside that crate, so you may want to just wait for that to land and be released.
@LaurentMazare Sweet! Would it be possible to update the logic to default to the MPS device when it's available?
@bakkot sounds like a good idea, could you try the stable-diffusion example using #50 and see whether that works well on a device where MPS is available (and that it actually uses the device rather than the CPU)?
Closing this as the related PR has been merged for a while, feel free to re-open if it's still an issue (I don't have a mac at hand to test).