🌾 JAX/Flax integration for super fast Stable Diffusion on TPUs.
We added JAX support for Stable Diffusion! You can now run Stable Diffusion on Colab TPUs (and GPUs too!) for faster inference.
Check out this TPU-ready Colab notebook for the Stable Diffusion pipeline:
There is also a detailed blog post on Stable Diffusion and parallelism in JAX / Flax 🤗: https://huggingface.co/blog/stable_diffusion_jax
The most commonly used models, schedulers, and pipelines have been ported to JAX/Flax (a short usage sketch follows the list):
- Models: `FlaxAutoencoderKL`, `FlaxUNet2DConditionModel`
- Schedulers: `FlaxDDIMScheduler`, `FlaxDDPMScheduler`, `FlaxPNDMScheduler`
- Pipelines: `FlaxStableDiffusionPipeline`
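
As a quick illustration of the new Flax classes, here is a minimal sketch of TPU inference with `FlaxStableDiffusionPipeline`, following the linked blog post. The checkpoint `CompVis/stable-diffusion-v1-4`, its `bf16` revision, and the `replicate`/`shard` helpers are assumptions taken from the blog post rather than part of this release note:

```python
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

# Load the Flax weights; the bf16 revision keeps memory usage low on TPU.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16,
)

# One prompt per device; generation runs in parallel across all TPU cores.
num_devices = jax.device_count()
prompts = ["a photo of an astronaut riding a horse on mars"] * num_devices
prompt_ids = pipeline.prepare_inputs(prompts)

# Replicate the parameters, shard the inputs, and give each device its own RNG.
params = replicate(params)
prompt_ids = shard(prompt_ids)
rng = jax.random.split(jax.random.PRNGKey(0), num_devices)

# jit=True uses a pmapped generation loop over all available devices.
images = pipeline(prompt_ids, params, rng, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((-1,) + images.shape[-3:])))
```

On a TPU v2-8 this generates eight images in a single parallel call; on a single GPU the same code simply runs with `jax.device_count() == 1`.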
Changelog:
- Implement FlaxModelMixin #493 by @mishig25, @patil-suraj, @patrickvonplaten, @pcuenca
- Karras VE, DDIM and DDPM flax schedulers #508 by @kashif
- initial flax pndm scheduler #492 by @kashif
- FlaxDiffusionPipeline & FlaxStableDiffusionPipeline #559 by @mishig25, @patrickvonplaten, @pcuenca
- Flax pipeline pndm #583 by @pcuenca
- Add from_pt argument in .from_pretrained #527 by @younesbelkada (a usage sketch follows this list)
- Make flax from_pretrained work with local subfolder #608 by @mishig25
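
As an aside on the `from_pt` argument added in #527, here is a hedged sketch of loading a PyTorch checkpoint directly into a Flax model class; the checkpoint name and the `(model, params)` tuple return follow the loading pattern used in the Flax examples and are assumptions, not text from this release:

```python
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

# from_pt=True converts PyTorch weights to Flax parameters on the fly,
# so the Flax classes can reuse checkpoints that only ship PyTorch weights.
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="unet",
    from_pt=True,
    dtype=jnp.float32,
)
```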
🔥 DeepSpeed low-memory training
Thanks to the 🤗 Accelerate integration with DeepSpeed, a few of our training examples became even more optimized in terms of VRAM usage and speed:
- DreamBooth is now trainable on 8 GB GPUs thanks to a contribution from @Ttl! Find out how to run it here; a rough launch sketch also follows this list.
- The Text2Image finetuning example is also fully compatible with DeepSpeed.
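
For readers who want a concrete starting point, here is a rough sketch of the 8 GB DreamBooth setup; the exact flags and values are assumptions based on the DreamBooth example script, so the linked instructions remain the authoritative reference:

```bash
# Enable DeepSpeed (ZeRO stage 2 with CPU offload) when prompted by:
accelerate config

# Then launch the DreamBooth example with fp16 and gradient checkpointing.
accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
  --instance_data_dir="./instance_images" \
  --output_dir="./dreambooth-model" \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --learning_rate=5e-6 \
  --max_train_steps=800 \
  --mixed_precision=fp16
```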
✏️ Changelog
- Revert "[v0.4.0] Temporarily remove Flax modules from the public API" by @anton-l in #755
- Fix push_to_hub for dreambooth and textual_inversion by @YaYaB in #748
- Fix ONNX conversion script opset argument type by @justinchuby in #739
- Add final latent slice checks to SD pipeline intermediate state tests by @jamestiotio in #731
- fix(DDIM scheduler): use correct dtype for noise by @keturn in #742
- [Tests] Fix tests by @patrickvonplaten in #774
- debug an exception by @LowinLi in #638
- Clean up resnet.py file by @natolambert in #780
- add sigmoid betas by @natolambert in #777
- [Low CPU memory] + device map by @patrickvonplaten in #772
- Fix gradient checkpointing test by @patrickvonplaten in #797
- fix typo docstring in unet2d by @natolambert in #798
- DreamBooth DeepSpeed support for under 8 GB VRAM training by @Ttl in #735
- support bf16 for stable diffusion by @patil-suraj in #792
- stable diffusion fine-tuning by @patil-suraj in #356
- Flax: Trickle down `norm_num_groups` by @akash5474 in #789
- Eventually preserve this typo? :) by @spezialspezial in #804
- Fix indentation in the code example by @osanseviero in #802
- [Img2Img] Fix batch size mismatch prompts vs. init images by @patrickvonplaten in #793
- Minor package fixes by @anton-l in #809
- [Dummy imports] Better error message by @patrickvonplaten in #795
- add or fix license formatting in models directory by @natolambert in #808
- [train_text2image] Fix EMA and make it compatible with deepspeed. by @patil-suraj in #813
- Fix fine-tuning compatibility with deepspeed by @pink-red in #816
- Add diffusers version and pipeline class to the Hub UA by @anton-l in #814
- [Flax] Add test by @patrickvonplaten in #824
- update flax scheduler API by @patil-suraj in #822
- Fix dreambooth loss type with prior_preservation and fp16 by @anton-l in #826
- Fix type mismatch error, add tests for negative prompts by @anton-l in #823
- Give more customizable options for safety checker by @patrickvonplaten in #815
- Flax safety checker by @pcuenca in #825
- Align PT and Flax API - allow loading checkpoint from PyTorch configs by @patrickvonplaten in #827